In [1]:
import sys
sys.path.append('../..')

In [2]:
import logging
import numpy as np

import torch
from emgrep.datasets.EMGRepDataloader import EMGRepDataloader

In [3]:
# Show INFO-level log records in the notebook (data-loading progress and the encoder's shape logs).
logging.basicConfig(level=logging.INFO)

In [4]:
# Recordings to load: every (subject, day, time) combination of the values below.
SUBJECTS = [1, 2]
DAYS = [1, 2]
TIMES = [1, 2]
DATA_PATH = '../../data/01_raw/'

data_selection = [
    (subject, day, time) for subject in SUBJECTS for day in DAYS for time in TIMES
]

# positive_mode='subject' presumably draws positive pairs from the same subject — see EMGRepDataloader.
emgrep_ds = EMGRepDataloader(
    data_path=DATA_PATH,
    train_data=data_selection,
    positive_mode='subject',
)

In [5]:
# get_dataloaders() returns three loaders; only the first is kept (presumably the train loader,
# since only train_data was configured — TODO confirm the return order).
emgrep_dl, _, _ = emgrep_ds.get_dataloaders()

INFO:root:Loading train dataset...
100%|██████████| 8/8 [00:12<00:00,  1.59s/it]


In [6]:
# Length of the loader (for a torch DataLoader this is the number of batches — confirm for EMGRepDataloader).
len(emgrep_dl)

3400

In [7]:
# Pull a single batch to inspect its structure.
batch = next(iter(emgrep_dl))

In [8]:
# First two batch entries: x is the signal tensor, y presumably the per-sample labels (TODO confirm).
x, y = batch[0], batch[1]

In [9]:
# x is laid out as (batch, K, num_blocks, block_len, channels) — the layout the encoders below unpack.
x.shape, y.shape

(torch.Size([1, 2, 10, 300, 16]), torch.Size([1, 2, 10, 300, 1]))

In [10]:
# Simple CNN encoder network: it takes in a sequence of blocks and maps each block to a
# single vector containing that block's encoded representation.

# Every block of shape block_len x F (block_len up to 512) is mapped to one vector of size H,
# i.e. an input of shape N x K x num_blocks x block_len x F becomes N x K x num_blocks x H,
# where H is the hidden size of the encoder network.

from torch import nn

class EncoderNetwork(nn.Module):
    """CNN encoder that maps every block of a block sequence to one feature vector.

    The input is expected as (BS, K, NB, L, C): batch, K views, NB blocks, block length L and
    C channels. The output is (BS, K, NB, H) with H = 2 ** floor(log2(out_channels)) when
    out_channels >= 32 (otherwise no conv stages are built and H = in_channels).
    """

    def __init__(self, in_channels, out_channels):
        """Build the conv stages.

        Args:
            in_channels: Number of input channels (trailing C axis of the input).
            out_channels: Requested feature size; rounded down to a power of 2.
        """
        super().__init__()

        # Highest power of 2 less than or equal to out_channels. np.log2 is exact for powers
        # of two, whereas the log(x) / log(2) ratio is subject to floating-point rounding.
        max_power = int(np.log2(out_channels))
        channels = [in_channels] + [2**f for f in range(5, max_power + 1)]

        # Two blocks per stage: one keeping c_in channels, then one mapping c_in -> c_out.
        self.convs = nn.ModuleList([
            block
            for c_in, c_out in zip(channels[:-1], channels[1:])
            for block in [self.block(c_in, c_in), self.block(c_in, c_out)]
        ])

        # Collapse the remaining time axis to length 1.
        self.output_conv = nn.AdaptiveAvgPool1d(1)

    def block(self, in_channels, out_channels):
        """Conv1d(k=3, padding=1) -> ReLU -> BatchNorm1d -> MaxPool1d(2, padding=1).

        The pooling maps a length-L input to length L // 2 + 1.
        """
        return nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm1d(out_channels),
            nn.MaxPool1d(kernel_size=2, padding=1),
        )

    def forward(self, x):
        """Encode x of shape (BS, K, NB, L, C) into features of shape (BS, K, NB, H)."""
        BS, K, NB, L, C = x.shape
        logging.info(f"Input:   {x.shape}")
        # Flatten to (BS * K * NB, L, C), then swap to the (batch, channels, length) layout
        # Conv1d expects. Reshaping straight to (-1, C, L) would reinterpret memory and mix
        # time steps and channels instead of transposing them.
        x = x.reshape(-1, L, C).transpose(1, 2)
        logging.info(f"Reshape: {x.shape}")

        for i, conv in enumerate(self.convs):
            x = conv(x)
            logging.info(f"Conv {i}:  {x.shape}")

        x = self.output_conv(x)
        logging.info(f"Out:     {x.shape}")

        # reshape to have shape BS x K x NB x H
        x = x.reshape(BS, K, NB, -1)
        logging.info(f"Reshape: {x.shape}")

        return x

In [69]:
in_channels = 16  # matches the trailing (channel) axis of x
hidden_dim = 256  # per-block feature size produced by the encoder

encoder = EncoderNetwork(in_channels, hidden_dim)

# Count only trainable parameters; formatted like the other parameter-count cells.
n_params = sum(p.numel() for p in encoder.parameters() if p.requires_grad)
logging.info(f"Number of trainable parameters: {n_params / 1e6:.3f} M")

INFO:root:Number of trainable parameters: 0.198 M


In [70]:
# Smoke test with the largest block length the encoder is meant to handle (512 samples).
random_input = torch.randn(3, 2, 10, 512, in_channels)
features = encoder(random_input)

In [71]:
# Same check with the block length of the real data (300 samples, see x.shape above).
random_input = torch.randn(3, 2, 10, 300, in_channels)
features = encoder(random_input)

In [72]:
# Refactored version of the encoder
import math

import torch.nn as nn

class CPCEncoder(nn.Module):
    """Encoder network for CPC."""

    def __init__(self, in_channels: int, hidden_dim: int):
        """Encoder network that maps every block of a sequence to a single feature vector.

        Args:
            in_channels (int): Number of input channels (the trailing F axis of the input).
            hidden_dim (int): Feature dimension of the output vector for each block. Will be
                rounded down to the largest power of 2 that is <= hidden_dim. The channel
                schedule starts at 32, so for hidden_dim < 32 no conv stages are built and
                the output dimension equals in_channels.
        """
        super().__init__()

        # Channel schedule: in_channels -> 32 -> 64 -> ... -> 2 ** floor(log2(hidden_dim)).
        # math.log2 is more accurate than math.log(x, 2).
        max_power = int(math.log2(hidden_dim))
        channels = [in_channels] + [2**f for f in range(5, max_power + 1)]

        # Each stage: a channel-preserving conv block followed by a channel-changing one.
        # Every MaxPool1d(kernel_size=2, padding=1) maps length L to L // 2 + 1.
        self.convs = nn.Sequential(
            *[
                nn.Sequential(
                    nn.Conv1d(c_in, c_in, kernel_size=3, padding=1),
                    nn.ReLU(),
                    nn.BatchNorm1d(c_in),
                    nn.MaxPool1d(kernel_size=2, padding=1),
                    nn.Conv1d(c_in, c_out, kernel_size=3, padding=1),
                    nn.ReLU(),
                    nn.BatchNorm1d(c_out),
                    nn.MaxPool1d(kernel_size=2, padding=1),
                )
                for c_in, c_out in zip(channels[:-1], channels[1:])
            ]
        )

        # output_conv is used to map the time dimension to a single value
        # -> each block will be mapped to a feature with dimension H (see hidden_dim above)
        self.output_conv = nn.AdaptiveAvgPool1d(1)

    def forward(self, x):
        """Forward pass.

        Args:
            x (torch.Tensor): Input tensor of shape (N, K, num_blocks, block_len, F).

        Returns:
            torch.Tensor: Per-block features of shape (N, K, num_blocks, H), where H is the
            channel count of the last conv stage.
        """
        N, K, num_blocks, block_len, F = x.shape
        # Flatten the leading axes, then swap time and channel axes: Conv1d expects
        # (batch, channels, length). A plain view to (..., F, block_len) would reinterpret
        # memory and scramble time steps and channels instead of transposing them.
        x = x.reshape(N * K * num_blocks, block_len, F).transpose(1, 2)

        x = self.convs(x)
        x = self.output_conv(x)  # (N * K * num_blocks, H, 1)

        x = x.reshape(N, K, num_blocks, -1)

        return x


In [73]:
encoder = CPCEncoder(in_channels, hidden_dim)

# Count only trainable parameters (requires_grad=True).
n_params = sum(p.numel() for p in encoder.parameters() if p.requires_grad)

print(f"Number of trainable parameters: {n_params / 1e6:.3f} M")

Number of trainable parameters: 0.198 M


In [74]:
random_input = torch.randn(3, 2, 10, 300, in_channels)

out = encoder(random_input)

# One hidden_dim-sized vector per block (hidden_dim is a power of 2 here, so no rounding applies).
assert out.shape == (3, 2, 10, hidden_dim), out.shape

out.shape

torch.Size([3, 2, 10, 256])

In [75]:
# print model summary (layer-by-layer output shapes and parameter counts)
# NOTE(review): torchsummary is an extra dependency; consider moving this import to the top import cell.
from torchsummary import summary

# input_size omits the leading batch axis (shown as -1 in the summary).
summary(encoder, input_size=(2, 10, 300, in_channels), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv1d-1              [-1, 16, 300]             784
              ReLU-2              [-1, 16, 300]               0
       BatchNorm1d-3              [-1, 16, 300]              32
         MaxPool1d-4              [-1, 16, 151]               0
            Conv1d-5              [-1, 32, 151]           1,568
              ReLU-6              [-1, 32, 151]               0
       BatchNorm1d-7              [-1, 32, 151]              64
         MaxPool1d-8               [-1, 32, 76]               0
            Conv1d-9               [-1, 32, 76]           3,104
             ReLU-10               [-1, 32, 76]               0
      BatchNorm1d-11               [-1, 32, 76]              64
        MaxPool1d-12               [-1, 32, 39]               0
           Conv1d-13               [-1, 64, 39]           6,208
             ReLU-14               [-1,

In [76]:
class CPCAR(nn.Module):
    """GRU-based autoregressive context network for CPC."""

    def __init__(self, dimEncoded: int, dimOutput: int, numLayers: int):
        """Build the recurrent context model.

        Args:
            dimEncoded (int): Size of each encoded block vector fed to the GRU.
            dimOutput (int): Hidden size of the GRU, i.e. the context dimension.
            numLayers (int): Number of stacked GRU layers.
        """
        super().__init__()
        self.gru = nn.GRU(
            dimEncoded,
            dimOutput,
            num_layers=numLayers,
            batch_first=True,
        )

    def forward(self, x):
        """Run the GRU along the block axis.

        Args:
            x (torch.Tensor): Encoded blocks of shape (N, K, num_blocks, H).

        Returns:
            torch.Tensor: GRU outputs of shape (N, K, num_blocks, dimOutput), one per step;
            the final hidden state is not returned.
        """
        n, k, steps, feat = x.shape
        # Every (n, k) entry becomes an independent sequence of `steps` vectors.
        sequences = x.view(n * k, steps, feat)
        context = self.gru(sequences)[0]
        return context.view(n, k, steps, -1)
    

# Two-layer GRU mapping hidden_dim-sized block features to hidden_dim-sized context vectors.
ar = CPCAR(hidden_dim, hidden_dim, 2)

# Count only trainable parameters (requires_grad=True).
n_params = sum(p.numel() for p in ar.parameters() if p.requires_grad)

print(f"Number of trainable parameters: {n_params / 1e6:.3f} M")

Number of trainable parameters: 0.790 M


In [77]:
sample = torch.randn(3, 2, 10, hidden_dim)

out = ar(sample)

# The AR model keeps the (N, K, num_blocks) axes and emits dimOutput (= hidden_dim) features per step.
assert out.shape == (3, 2, 10, hidden_dim), out.shape

out.shape

torch.Size([3, 2, 10, 256])

In [78]:
class CPCModel(nn.Module):
    """CPC model: an encoder followed by an autoregressive context network."""

    def __init__(self, encoder: CPCEncoder, ar: CPCAR):
        """Wrap an encoder and an autoregressive model into one module.

        Args:
            encoder (CPCEncoder): Maps the input batch to per-block encodings z.
            ar (CPCAR): Maps the encodings z to context vectors c.
        """
        super().__init__()
        # Submodule attribute names determine state_dict keys; keep gEnc / gAR unchanged.
        self.gEnc = encoder
        self.gAR = ar

    def forward(self, batch):
        """Encode batch, then run the autoregressive model over the encodings.

        Returns:
            tuple: (z, c) — the encoder output and the autoregressive output.
        """
        encoded = self.gEnc(batch)
        return encoded, self.gAR(encoded)
    
# Combine the encoder and AR model defined above (the model shares these instances).
model = CPCModel(encoder, ar)

# Count only trainable parameters (requires_grad=True).
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Number of trainable parameters: {n_params / 1e6:.3f} M")

Number of trainable parameters: 0.988 M


In [80]:
sample = torch.randn(3, 2, 10, 300, in_channels)

# Run the two stages by hand with the same names CPCModel.forward returns:
# z = per-block encodings, c = autoregressive context vectors.
z = encoder(sample)

c = ar(z)

z.shape, c.shape

(torch.Size([3, 2, 10, 256]), torch.Size([3, 2, 10, 256]))

In [81]:
# Same computation through the combined model; it returns (z, c) = (encodings, context).
z, c = model(sample)

z.shape, c.shape

(torch.Size([3, 2, 10, 256]), torch.Size([3, 2, 10, 256]))

In [82]:
# Layer-by-layer summary of the full model (input_size omits the leading batch axis).
summary(model, input_size=(2, 10, 300, in_channels), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv1d-1              [-1, 16, 300]             784
              ReLU-2              [-1, 16, 300]               0
       BatchNorm1d-3              [-1, 16, 300]              32
         MaxPool1d-4              [-1, 16, 151]               0
            Conv1d-5              [-1, 32, 151]           1,568
              ReLU-6              [-1, 32, 151]               0
       BatchNorm1d-7              [-1, 32, 151]              64
         MaxPool1d-8               [-1, 32, 76]               0
            Conv1d-9               [-1, 32, 76]           3,104
             ReLU-10               [-1, 32, 76]               0
      BatchNorm1d-11               [-1, 32, 76]              64
        MaxPool1d-12               [-1, 32, 39]               0
           Conv1d-13               [-1, 64, 39]           6,208
             ReLU-14               [-1,

In [83]:
# Switch to the packaged implementation; this import shadows the notebook-defined
# CPCModel / CPCEncoder / CPCAR classes above.
# NOTE(review): confirm the packaged CPCEncoder.forward transposes the input from
# (..., block_len, F) to (..., F, block_len) before Conv1d — a plain view/reshape to that
# shape scrambles time steps and channels instead of transposing them.
from emgrep.models.cpc_model import CPCModel, CPCEncoder, CPCAR

encoder = CPCEncoder(in_channels, hidden_dim)
ar = CPCAR(hidden_dim, hidden_dim, 2)
model = CPCModel(encoder, ar)

# Count only trainable parameters (requires_grad=True).
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Number of trainable parameters: {n_params / 1e6:.3f} M")

Number of trainable parameters: 0.988 M


In [86]:
# Shape check of the packaged model on a random batch with the real block length (300).
sample = torch.randn(3, 2, 10, 300, in_channels)

z, c = model(sample)

print(f"sample.shape: {sample.shape}")
print(f"z.shape:      {z.shape}")
print(f"c.shape:      {c.shape}")

sample.shape: torch.Size([3, 2, 10, 300, 16])
z.shape:      torch.Size([3, 2, 10, 256])
c.shape:      torch.Size([3, 2, 10, 256])
