In [1]:
from acs_se_cnn.model import SEBlock, ACSLayer

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torch.functional import F

import pickle


###################
## Configuration ##
###################
# SYSTEM
class Args:
    """Run-level settings, bundled as class attributes."""
    name = "baseline"
    device = "cpu"
    subject = 1


args = Args()
device = torch.device(args.device)

# LEARNING STRATEGY
batch_size = 20
epochs = 500
criterion = nn.BCEWithLogitsLoss()  # takes raw logits (JHModel applies no sigmoid)
Optimizer = torch.optim.RMSprop     # optimizer *class*; presumably instantiated later with `lr`
lr = 0.001

# HYPER PARAMETER
# Presumably the weight of the sparsity penalty on the channel-selection
# scores (cf. SWModel's `sparse_loss`). TODO: confirm usage and tune the value.
sparse_lambda = 1

# DATA
# Selects the data directory and JHModel's default input-channel count.
fit_data = "2a"
if fit_data not in ("2a", "2b"):
    # Previously any other value silently fell back to the "2b" data.
    raise ValueError(f"fit_data must be '2a' or '2b', got {fit_data!r}")
data_path = f"cwt_data/{fit_data}"

###################
#### Modeling #####
###################

class JHModel(nn.Module):
    """Auto-channel-selection (ACS) + squeeze-excitation CNN with one output per sample.

    Layout: ``ACSLayer`` -> 3 x (Conv2d -> ELU -> SEBlock) -> ``fc1`` over the
    spatial positions -> ELU -> ``fc2`` over the kernels. No sigmoid is
    applied, so the output is a raw logit (suitable for the configured
    ``nn.BCEWithLogitsLoss``).

    Args
    ----
        n_channels: number of input channels. The default (22 if
            ``fit_data == "2a"`` else 3) is evaluated once, when the class
            is defined.
        n_kerenls: number of kernels in every convolution stage (spelling
            kept for backward compatibility with keyword callers).
        r: reduction ratio passed to ``ACSLayer`` and every ``SEBlock``.
    """

    def __init__(
        self,
        # MODEL HYPER PARAMETERS
        n_channels=22 if fit_data == "2a" else 3,
        n_kerenls=64,
        r=2,
    ):
        super().__init__()

        # Keep this module creation order: it determines parameter order and
        # RNG consumption at init, which the parameter-by-parameter
        # comparison against SWModel relies on.

        # AUTO CHANNEL SELECTION
        self.acs_layer = ACSLayer(c=n_channels, r=r)

        # FEATURE EXTRACTION
        self.conv_layer1 = nn.Conv2d(n_channels, n_kerenls, kernel_size=(4, 4), stride=(2, 2), padding=1)
        self.se_block1 = SEBlock(c=n_kerenls, r=r)

        self.conv_layer2 = nn.Conv2d(n_kerenls, n_kerenls, kernel_size=(4, 4), stride=(4, 4))
        self.se_block2 = SEBlock(c=n_kerenls, r=r)

        self.conv_layer3 = nn.Conv2d(n_kerenls, n_kerenls, kernel_size=(4, 4), stride=(4, 4))
        self.se_block3 = SEBlock(c=n_kerenls, r=r)

        # OUTPUT
        # fc1 mixes the H*W == 4 positions of the last feature map (2x2 for
        # the expected input size -- TODO confirm for other input sizes).
        self.fc1 = nn.Linear(4, 1)
        # Width follows n_kerenls instead of a hard-coded 64.
        self.fc2 = nn.Linear(n_kerenls, 1)

    def forward(self, inputs, return_s_acs=False):
        """Run the network.

        Args
        ----
            inputs: tensor of shape (batch, channel, height, width).
            return_s_acs: if True, also return the second output of
                ``ACSLayer`` (presumably the per-channel selection weights).

        Returns
        -------
            ``out`` of shape (batch, 1) holding raw logits, or
            ``(out, s_acs)`` when ``return_s_acs`` is True.
        """
        # AUTO CHANNEL SELECTION
        x, s_acs = self.acs_layer(inputs)

        # FEATURE EXTRACTION
        x = self.se_block1(F.elu(self.conv_layer1(x)))
        x = self.se_block2(F.elu(self.conv_layer2(x)))
        x = self.se_block3(F.elu(self.conv_layer3(x)))

        # OUTPUT
        batch, channels = x.shape[:2]
        x = x.reshape(batch, channels, -1)  # (B, C, H*W)
        x = self.fc1(x)                     # (B, C, 1)
        # Squeeze only the trailing axis: a bare .squeeze() would also drop
        # the batch axis when batch == 1 and change the output shape.
        x = x.squeeze(-1)                   # (B, C)
        x = F.elu(x)
        out = self.fc2(x)                   # (B, 1)

        if return_s_acs:
            return out, s_acs
        return out
        

In [2]:
class HParams:
    """Hyper-parameters for the SWModel / Recognizer stack, read as class attributes."""

    # --- model ---
    input_channel = 22   # channels entering the ACS block and conv1
    output_channel = 64  # kernels in every convolution stage
    r = 2                # divisor for the SqueezeExcitation bottleneck width

    # --- training ---
    epoch = 500
    per_batch = 20


hparams = HParams()

class SqueezeExcitation(nn.Module):
    """Squeeze-and-Excitation block (used for channel selection and after each conv stage).

    Squeeze: average over the spatial axes (H, W). Excitation: a bias-free
    bottleneck MLP ``C -> int(C / r) -> C`` with ReLU then sigmoid. The input
    is rescaled channel-wise by the resulting weights.

    Args:
        input_channel: number of channels C of the incoming feature map.
        r: bottleneck reduction ratio. Defaults to the global ``hparams.r``
            when omitted (the original behavior).
    """

    def __init__(self, input_channel, r=None):
        super(SqueezeExcitation, self).__init__()
        if r is None:
            r = hparams.r
        hidden = int(input_channel / r)
        # Keep fc1 before fc2: creation order fixes parameter order and the
        # order in which weights are drawn from the RNG.
        self.fc1 = torch.nn.Linear(input_channel, hidden, bias=False)
        self.fc2 = torch.nn.Linear(hidden, input_channel, bias=False)
        self.sm = torch.nn.Sigmoid()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Rescale ``x`` channel-wise.

        Args:
            x: feature map of shape (B, C, H, W).

        Returns:
            ``(x * w, w)`` where ``w`` has shape (B, C, 1, 1) with sigmoid
            values in [0, 1].
        """
        # Mean over H, W yields (B, C) directly. The previous bare .squeeze()
        # also dropped the batch axis when B == 1, and the later
        # unsqueeze(2) then raised.
        w = x.mean(dim=(2, 3))                        # (B, C)
        w = self.sm(self.fc2(self.relu(self.fc1(w))))
        w = w[:, :, None, None]                       # (B, C, 1, 1)
        # No clone needed: the multiplication already creates a new tensor.
        return x * w, w


class Recognizer(nn.Module):
    """ACS + SE convolutional classifier configured by the global ``hparams``.

    ``forward(x)`` with ``x`` of shape (B, input_channel, H, W) returns
    ``(sp, probs)``:
        sp: channel-selection weights from the ACS block, shape (B, C_in, 1, 1).
        probs: sigmoid probabilities, shape (B,).
    The head assumes the final feature map has H*W == 4 (2x2).
    """

    def __init__(self):
        super(Recognizer, self).__init__()
        c_in = hparams.input_channel
        c_out = hparams.output_channel

        # Keep this module creation order: it fixes parameter order and RNG
        # consumption, which the parameter-by-parameter comparison with
        # JHModel relies on.
        self.acs = SqueezeExcitation(c_in)

        self.conv1 = torch.nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1)
        self.elu1 = torch.nn.ELU()
        self.se1 = SqueezeExcitation(c_out)
        self.conv2 = torch.nn.Conv2d(c_out, c_out, kernel_size=4, stride=4)
        self.elu2 = torch.nn.ELU()
        self.se2 = SqueezeExcitation(c_out)
        self.conv3 = torch.nn.Conv2d(c_out, c_out, kernel_size=4, stride=4)
        self.elu3 = torch.nn.ELU()
        self.se3 = SqueezeExcitation(c_out)

        self.fc = torch.nn.Linear(4, 1)  # mixes the 2x2 = 4 spatial positions
        self.elu4 = torch.nn.ELU()
        # Tied to the conv width rather than a literal 64, so changing
        # hparams.output_channel does not break the head.
        self.ffc = torch.nn.Linear(c_out, 1)
        self.sm = torch.nn.Sigmoid()

    def forward(self, x):
        """Return ``(sp, probs)`` for ``x`` of shape (B, C, H, W)."""
        B = x.shape[0]
        x, sp = self.acs(x)
        stages = (
            (self.conv1, self.elu1, self.se1),
            (self.conv2, self.elu2, self.se2),
            (self.conv3, self.elu3, self.se3),
        )
        for conv, act, se in stages:
            x, _ = se(act(conv(x)))
        # x -> (B, C, 2, 2)
        x = x.reshape(B, -1, 4)                # (B, C, 4)
        # squeeze(-1), not squeeze(): keep the batch axis when B == 1.
        x = self.elu4(self.fc(x).squeeze(-1))  # (B, C)
        x = self.ffc(x)                        # (B, 1)
        return sp, self.sm(x).squeeze(-1)      # probs: (B,)

class SWModel(nn.Module):
    """Wrap ``Recognizer`` and add an L1 sparsity term on its channel-selection weights.

    ``forward(x)`` returns ``(probs, sparse_loss)``:
        probs: the recognizer's sigmoid probabilities (an output, not a loss).
        sparse_loss: L1 norm of the selection weights over all elements,
            divided by the batch size.
    """

    def __init__(self):
        super(SWModel, self).__init__()
        self.recognizer = Recognizer()

    def forward(self, x):
        n_samples = x.shape[0]
        sparse, probs = self.recognizer(x)
        # torch.norm is deprecated. The sum of absolute values over all
        # elements is the same flattened L1 norm it computed.
        sparse_loss = sparse.abs().sum() / n_samples
        return probs, sparse_loss

In [3]:
# Re-seed the global RNG immediately before construction so JHModel's initial
# weights are reproducible and comparable with SWModel built under seed 0.
torch.random.manual_seed(0)
jh_model = JHModel()

In [4]:
# Same seed as JHModel's construction cell, so both models draw their
# parameters from an identical RNG stream.
torch.random.manual_seed(0)
sw_model = SWModel()

In [5]:
# Compare the two implementations parameter by parameter (in registration order).
jh_params = list(jh_model.named_parameters())
sw_params = list(sw_model.named_parameters())
# zip() would silently stop at the shorter list and hide extra/missing parameters.
assert len(jh_params) == len(sw_params), (
    f"JHModel has {len(jh_params)} parameters, SWModel has {len(sw_params)}"
)

mismatched = []
for (jh_name, jh_param), (sw_name, sw_param) in zip(jh_params, sw_params):
    # torch.equal also requires identical shapes; `==` broadcasts and could
    # report "same" for differently shaped tensors.
    same = torch.equal(jh_param, sw_param)
    if not same:
        mismatched.append((jh_name, sw_name))
    print("jh", jh_name)
    print("sw", sw_name)
    print("\tAre they same?", same)
    print()

print(f"{len(mismatched)} mismatched parameter pair(s):", mismatched)

jh acs_layer.excitation.0.weight
sw recognizer.acs.fc1.weight
	Are they same? tensor(True)

jh acs_layer.excitation.2.weight
sw recognizer.acs.fc2.weight
	Are they same? tensor(True)

jh conv_layer1.weight
sw recognizer.conv1.weight
	Are they same? tensor(True)

jh conv_layer1.bias
sw recognizer.conv1.bias
	Are they same? tensor(True)

jh se_block1.excitation.0.weight
sw recognizer.se1.fc1.weight
	Are they same? tensor(True)

jh se_block1.excitation.2.weight
sw recognizer.se1.fc2.weight
	Are they same? tensor(True)

jh conv_layer2.weight
sw recognizer.conv2.weight
	Are they same? tensor(True)

jh conv_layer2.bias
sw recognizer.conv2.bias
	Are they same? tensor(True)

jh se_block2.excitation.0.weight
sw recognizer.se2.fc1.weight
	Are they same? tensor(True)

jh se_block2.excitation.2.weight
sw recognizer.se2.fc2.weight
	Are they same? tensor(True)

jh conv_layer3.weight
sw recognizer.conv3.weight
	Are they same? tensor(True)

jh conv_layer3.bias
sw recognizer.conv3.bias
	Are they same? 