In [2]:
# Notebook to train the model
import torch
from torch import nn
from torch import optim


In [27]:
# Residual block
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=(3,2), stride=(1,1), padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=(3,2), stride=(1,1), padding=1),
            nn.BatchNorm2d(out_channels),
        )
    def forward(self, x):
        print(x.shape)
        return x + self.block(x)

class Encoder(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Encoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=(3,2), stride=(1,1), padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(3,2), stride=(2,2)),
        )
    def forward(self, x):
        return self.encoder(x)
    
class Decoder(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Decoder, self).__init__()
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=(3,2), stride=(1,1), padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Upsample(scale_factor=(2,2), mode='nearest'),
        )
    def forward(self, x):
        return self.decoder(x)
    
# Construct a model with 3 conv layers 3 residual blocks and 3 deconv layers using the ResNet architecture
class EMGModel(nn.Module):
    def __init__(self, n_classes=4):
        super(EMGModel, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=(3,2), stride=(1,1), padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(3,2), stride=(2,2)),

            nn.Conv2d(32, 64, kernel_size=(3,2), stride=(1,1), padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(3,2), stride=(2,2)),


            nn.Conv2d(64, 128, kernel_size=(3,2), stride=(1,1), padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(3,2), stride=(2,2)),
        )

        self.resnet = nn.Sequential(
            ResidualBlock(128, 128),
            ResidualBlock(128, 128),
            ResidualBlock(128, 128),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=(3,2), stride=(1,1), padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Upsample(scale_factor=(2,2), mode='nearest'),

            nn.ConvTranspose2d(64, 32, kernel_size=(3,2), stride=(1,1), padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Upsample(scale_factor=(2,2), mode='nearest'),

            nn.ConvTranspose2d(32, 16, kernel_size=(3,2), stride=(1,1), padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Upsample(scale_factor=(2,2), mode='nearest'),
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.resnet(x)
        x = self.decoder(x)
        return x
    


In [28]:
#test model
model = EMGModel()

#dummy data
x = torch.randn(1,1,1000,8)
x.shape

torch.Size([1, 1, 1000, 8])

In [29]:
y = model(x)
y.shape

torch.Size([1, 128, 124, 1])
torch.Size([1, 128, 124, 3])


RuntimeError: The size of tensor a (3) must match the size of tensor b (5) at non-singleton dimension 3