In [1]:
import torch
import torch.nn as nn
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [23]:
class AE(nn.Module):
    def __init__(self, n_channels, len_sw, n_classes, outdim=128, backbone=True):
        super(AE, self).__init__()

        self.backbone = backbone
        self.len_sw = len_sw

        self.e1 = nn.Linear(n_channels, 8)
        self.e2 = nn.Linear(8 * len_sw, 2 * len_sw)
        self.e3 = nn.Linear(2 * len_sw, outdim)

        self.d1 = nn.Linear(outdim, 2 * len_sw)
        self.d2 = nn.Linear(2 * len_sw, 8 * len_sw)
        self.d3 = nn.Linear(8, n_channels)

        self.out_dim = outdim

        if backbone == False:
            self.classifier = nn.Linear(outdim, n_classes)

    def forward(self, x):
        x_e1 = self.e1(x)
        x_e1 = x_e1.reshape(x_e1.shape[0], -1)
        x_e2 = self.e2(x_e1)
        x_encoded = self.e3(x_e2)

        x_d1 = self.d1(x_encoded)
        x_d2 = self.d2(x_d1)
        x_d2 = x_d2.reshape(x_d2.shape[0], self.len_sw, 8)
        x_decoded = self.d3(x_d2)

        if self.backbone:
            return x_decoded, x_encoded
        else:
            out = self.classifier(x_encoded)
            return out, x_decoded

class CNN_AE(nn.Module):
    def __init__(self, n_channels, length, n_classes, out_channels=128, backbone=True):
        super(CNN_AE, self).__init__()

        self.backbone = backbone
        self.n_channels = n_channels
        print(length)
        print(length % 2)
        print(length % 4)
        print(length % 8)
        curr_length = length
        if curr_length % 2 != 0:
            self.padding3 = True
        else:
            self.padding3 = False
        curr_length = curr_length // 2 
        self.e_conv1 = nn.Sequential(nn.Conv1d(n_channels, 32, kernel_size=5, stride=1, bias=False, padding=2),
                                         nn.BatchNorm1d(32),
                                         nn.ReLU())
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2, padding=0, return_indices=True)
        self.dropout = nn.Dropout(0.35)

        if curr_length % 2 != 0:
            self.padding2 = True
        else:
            self.padding2 = False
        curr_length = curr_length // 2
        self.e_conv2 = nn.Sequential(nn.Conv1d(32, 64, kernel_size=5, stride=1, bias=False, padding=2),
                                         nn.BatchNorm1d(64),
                                         nn.ReLU())
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2, padding=0, return_indices=True)

        if curr_length % 2 != 0:
            self.padding1 = True
        else:
            self.padding1 = False
        curr_length = curr_length // 2
        self.e_conv3 = nn.Sequential(nn.Conv1d(64, out_channels, kernel_size=5, stride=1, bias=False, padding=2),
                                         nn.BatchNorm1d(out_channels),
                                         nn.ReLU())
        self.pool3 = nn.MaxPool1d(kernel_size=2, stride=2, padding=0, return_indices=True)

        self.unpool1 = nn.MaxUnpool1d(kernel_size=2, stride=2, padding=0)
        
        # if length % 8 == 1:
        #     padding = 3
        # else:
        #     padding = 2
        self.d_conv1 = nn.Sequential(nn.ConvTranspose1d(out_channels, 64, kernel_size=5, stride=1, bias=False, padding=2),
                                     nn.BatchNorm1d(64),
                                     nn.ReLU())

        # if n_channels == 9: # ucihar
        #     self.lin1 = nn.Linear(33, 34)
        # elif n_channels == 6: # hhar
        #     self.lin1 = nn.Identity()
        # elif n_channels == 3: # shar
        #     self.lin1 = nn.Linear(39, 40)
    
        self.lin1 = nn.Linear(length//8, 40)

        self.unpool2 = nn.MaxUnpool1d(kernel_size=2, stride=2, padding=0)
        self.d_conv2 = nn.Sequential(nn.ConvTranspose1d(64, 32, kernel_size=5, stride=1, bias=False, padding=2),
                                     nn.BatchNorm1d(32),
                                     nn.ReLU())

        self.unpool3 = nn.MaxUnpool1d(kernel_size=2, stride=2, padding=0)
        self.d_conv3 = nn.Sequential(nn.ConvTranspose1d(32, n_channels, kernel_size=5, stride=1, bias=False, padding=2),
                                     nn.BatchNorm1d(n_channels),
                                     nn.ReLU())

        self.lin2 = nn.Linear(length//2, length)

        # if n_channels == 9: # ucihar
        #     self.lin2 = nn.Linear(127, 128)
        #     self.out_dim = 18 * out_channels
        # elif n_channels == 6:  # hhar
        #     self.lin2 = nn.Linear(99, 100)
        #     self.out_dim = 15 * out_channels
        # elif n_channels == 3: # shar
        #     self.out_dim = 21 * out_channels

        # if backbone == False:
        #     self.classifier = nn.Linear(self.out_dim, n_classes)

    def forward(self, x):
        x = x.permute(0, 2, 1)
        print(x.shape)
        x, indice1 = self.pool1(self.e_conv1(x))
        print(x.shape)
        x = self.dropout(x)
        x, indice2 = self.pool2(self.e_conv2(x))
        print(x.shape)
        x_encoded, indice3 = self.pool3(self.e_conv3(x))
        print(x_encoded.shape)
        # return x_encoded
        x = self.d_conv1(self.unpool1(x_encoded, indice3))
        # x = self.unpool1(x_encoded, indice3)
        print(x.shape)
        # x = self.d_conv1(x)
        # print(x.shape)
        # x = self.unpool1(x_encoded, indice3)
        # x = self.d_conv1(x)
        # x = self.lin1(x)
        print(self.padding1)
        print(self.padding2)
        print(self.padding3)
        if self.padding1:
            m = nn.ConstantPad1d((0, 1), 1)
            x  = m(x)
            print('Paddinbg?')
            print(x.shape)
        x = self.d_conv2(self.unpool2(x, indice2))
        
        print('Hereeeeeee?')
        print(x.shape)
        
        
        if self.padding2:
            m = nn.ConstantPad1d((0, 1), 1)
            x  = m(x)
            print('Paddinbg?')
            print(x.shape)
        
        print('Here?')
            
        x = self.d_conv3(self.unpool3(x, indice1))
        if self.padding3:
            m = nn.ConstantPad1d((0, 1), 1)
            x  = m(x)
            print('Paddinbg?')
            print(x.shape)
            
        print(x.shape)
        # if self.n_channels == 9: # ucihar
        #     x_decoded = self.lin2(x)
        # elif self.n_channels == 6 : # hhar
        # elif self.n_channels == 3: # shar
        #     x_decoded = x
        # x_decoded =self.lin2(x)
        # x_decoded = x_decoded.permute(0, 2, 1)
        x_encoded = x_encoded.reshape(x_encoded.shape[0], -1)
        
        return x_encoded

        # if self.backbone:
        #     return x_decoded, x_encoded
        # else:
        #     out = self.classifier(x_encoded)
        #     return out, x_decoded


In [26]:
length = 28
channels = 5
x =  torch.zeros([32, length, channels])
ae = CNN_AE( channels, length, 5, out_channels=128, backbone=True)

28
0
0
4


In [25]:
enc = ae(x)
print(enc.shape)
# 364 % 4 != 0

torch.Size([32, 5, 365])
torch.Size([32, 32, 182])
torch.Size([32, 64, 91])
torch.Size([32, 128, 45])
torch.Size([32, 64, 90])
True
False
True
Paddinbg?
torch.Size([32, 64, 91])
Hereeeeeee?
torch.Size([32, 32, 182])
Here?
Paddinbg?
torch.Size([32, 5, 365])
torch.Size([32, 5, 365])
torch.Size([32, 5760])
