In [11]:
import sys
# Put the parent of the current working directory first on the import path so the
# project's `libs` package (imported below) can be found from this notebook.
sys.path.insert(0, '..')
# Auto-reload edited project modules before executing each cell, so changes to
# `libs/*.py` are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

from libs.ssl_task import CPC
from libs.ssl_data import SSLHBNDataModule
import torch
import lightning as L

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:
# Experiment configuration for CPC pre-training. The cells below read its
# 'data', 'model' and 'trainer' sections and override individual entries.
import yaml

with open('../runs/config_CPC.yaml', 'r') as config_file:
    config = yaml.safe_load(config_file)

In [18]:
# Seed python `random`, numpy and torch in one call (torch gets the same seed as a plain
# `torch.manual_seed(0)`). `workers=True` additionally asks Lightning to give DataLoader
# worker processes reproducible per-worker seeds when the Trainer builds the loaders.
L.seed_everything(0, workers=True)

ssl_task = CPC()
# Data-section overrides for this experiment: the CPC task object, 2-second windows,
# and 2 DataLoader worker processes.
config['data']['ssl_task'] = ssl_task
config['data']['window_len_s'] = 2
config['data']['num_workers'] = 2
litDataModule = SSLHBNDataModule(**config['data'])
litDataModule.setup(stage='fit')

In [23]:
from libs.ssl_task import LitSSL
from libs.ssl_model import BENDRLSTM, BENDRContextualizer
from torch.nn import functional as F
import numpy as np
class CPCLit(LitSSL):
    """Contrastive (CPC / InfoNCE-style) self-supervised Lightning module.

    Repurposed from
    https://github.com/SPOClab-ca/BENDR/blob/ac918abaec111d15fcaa2a8fcd2bd3d8b0d81a10/dn3_ext.py#L232

    ``self.encoder`` maps a batch to encodings ``z`` of shape (B, F, seq_len); a
    ``BENDRLSTM`` contextualizer then produces ``c`` of the same shape. Each context
    vector ``c[b, :, t]`` is trained to identify its own encoding ``z[b, :, t]`` among
    ``num_negatives`` encodings drawn from other (batch, time) positions.
    NOTE(review): ``self.encoder`` and ``self.encoder_emb_size`` are assumed to be set
    up by ``LitSSL.__init__`` — confirm in libs/ssl_task.py.
    """

    def __init__(self, downsampling_factor=96, mask_rate=0.1, mask_span=6, temp=0.1,
                 permuted_encodings=False, permuted_contexts=False, enc_feat_l2=0.001,
                 unmasked_negative_frac=0.25, num_negatives=5, logit_scale=10.0, **kwargs):
        """
        Args:
            downsampling_factor: stored as ``_enc_downsample``; not used by this class.
            mask_rate: stored; no masking is currently applied.
            mask_span: stored (also as ``predict_length``); no masking is currently applied.
            temp: temperature dividing every cosine similarity.
            permuted_encodings: stored; not used by this class.
            permuted_contexts: stored; not used by this class.
            enc_feat_l2: stored as ``beta``; not added to the loss by this class.
            unmasked_negative_frac: stored; not used by this class.
            num_negatives: number of negatives sampled per (batch, time) position.
            logit_scale: multiplier applied to the logits after dividing by ``temp``,
                i.e. the effective temperature is ``temp / logit_scale``. The default 10.0
                reproduces the previously hard-coded factor.
            **kwargs: forwarded unchanged to ``LitSSL``.
        """
        super().__init__(**kwargs)
        # Alternative contextualizer:
        # BENDRContextualizer(in_features=self.encoder_emb_size, start_token=None)
        self.contextualizer = BENDRLSTM(
            in_features=self.encoder_emb_size,
        )
        self.predict_length = mask_span
        self._enc_downsample = downsampling_factor
        self.mask_rate = mask_rate
        self.mask_span = mask_span
        self.temp = temp
        self.permuted_encodings = permuted_encodings
        self.permuted_contexts = permuted_contexts
        self.beta = enc_feat_l2
        self.start_token = getattr(self.contextualizer, 'start_token', None)
        self.unmasked_negative_frac = unmasked_negative_frac
        self.num_negatives = num_negatives
        self.logit_scale = logit_scale

    def _calculate_similarity_cross_batch(self, true_z, c):
        """Build InfoNCE logits of each context vector against its target and sampled negatives.

        Args:
            true_z: target encodings of shape (B, F, seq_len).
            c: context vectors; must have exactly the same shape as ``true_z``.

        Returns:
            Tensor of shape (B * seq_len, 1 + num_negatives). Column 0 is the positive
            (``c[b, :, t]`` vs ``true_z[b, :, t]``), so the target class is always 0.

        Negatives are drawn uniformly, with replacement, from every *other* (batch, time)
        position of ``true_z``. The anchor's own position is excluded, so the positive can
        never appear as one of its own negatives. (The previous sampler picked one random
        partner sequence per row, possibly the row itself, and a random time step within
        it. That hit the positive with probability 1/(B * seq_len) per draw, which is
        significant for the small batches used when overfitting.)

        Raises:
            AssertionError: if shapes disagree or there are fewer than two positions.
        """
        B, feat, seq_len = true_z.shape
        assert c.shape == true_z.shape, f"c {c.shape} and true_z {true_z.shape} should be the same shape"
        n_positions = B * seq_len
        assert n_positions > 1, "need at least two (batch, time) positions to sample negatives"

        true_z = true_z.permute(0, 2, 1)  # (B, seq_len, F)
        c = c.permute(0, 2, 1)  # (B, seq_len, F)

        positives = F.cosine_similarity(c, true_z, dim=-1) / self.temp  # (B, seq_len)

        # Flat index of (b, t) is b * seq_len + t. Draw from n_positions - 1 slots and shift
        # every draw at or past the anchor's own index up by one. This gives a uniform sample
        # over all positions except the anchor.
        flat_z = true_z.reshape(n_positions, feat)
        anchor_idx = torch.arange(n_positions, device=true_z.device).view(B, seq_len, 1)
        neg_idx = torch.randint(0, n_positions - 1, (B, seq_len, self.num_negatives), device=true_z.device)
        neg_idx = neg_idx + (neg_idx >= anchor_idx).long()
        negatives_z = flat_z[neg_idx]  # (B, seq_len, num_negatives, F)

        negatives = F.cosine_similarity(c.unsqueeze(-2), negatives_z, dim=-1) / self.temp
        assert negatives.shape == (B, seq_len, self.num_negatives), f"negatives {negatives.shape} should be (B, seq_len, num_negatives)"

        logits = torch.cat([positives.unsqueeze(-1), negatives], dim=-1)  # (B, seq_len, 1 + num_negatives)
        # Extra sharpening on top of 1/temp: effective temperature is temp / logit_scale.
        logits = logits * self.logit_scale

        # Flatten (B, seq_len) so the last dim is the class axis expected by cross_entropy.
        return logits.reshape(-1, logits.shape[-1])

    def training_step(self, batch, batch_idx):
        """Encode, contextualize, and classify each true encoding among sampled negatives.

        Assumes ``batch[0]`` is the input tensor accepted by ``self.encoder`` — TODO confirm
        against ``SSLHBNDataModule``. Logs and returns the cross-entropy loss.
        """
        z = self.encoder(batch[0])  # (B, F, seq_len)
        # Separate copy used as the prediction target. Kept in case the contextualizer
        # modifies `z` in place (not verified either way).
        unmasked_z = z.clone()

        c = self.contextualizer(z, None)  # shape must equal z's; asserted in the helper below

        logits = self._calculate_similarity_cross_batch(unmasked_z, c)
        # The positive always sits in column 0 of the logits.
        labels = torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long)
        loss = F.cross_entropy(logits, labels)

        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Intentional no-op: no validation metric is computed for this experiment."""
        pass

    def on_validation_epoch_end(self):
        """Intentional no-op (no validation metric to aggregate)."""
        pass

In [None]:
# Overrides must go into `init_args`: only that dict is passed to the constructor below,
# so the old top-level `config['model']['learning_rate']` assignment was silently ignored.
# NOTE(review): assumes LitSSL accepts a `learning_rate` keyword — confirm in libs/ssl_task.py.
config['model']['init_args']['learning_rate'] = 0.005
cpc_model = CPCLit(downsampling_factor=96, **config['model']['init_args'])
cpc_model

CPCLit(
  (encoder): ConvEncoderBENDR(
    (encoder): Sequential(
      (Encoder_0): Sequential(
        (0): Conv1d(129, 512, kernel_size=(3,), stride=(3,), padding=(1,))
        (1): Dropout1d(p=0.1, inplace=False)
        (2): GroupNorm(256, 512, eps=1e-05, affine=True)
        (3): GELU(approximate='none')
      )
      (Encoder_1): Sequential(
        (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), padding=(1,))
        (1): Dropout1d(p=0.1, inplace=False)
        (2): GroupNorm(256, 512, eps=1e-05, affine=True)
        (3): GELU(approximate='none')
      )
      (Encoder_2): Sequential(
        (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), padding=(1,))
        (1): Dropout1d(p=0.1, inplace=False)
        (2): GroupNorm(256, 512, eps=1e-05, affine=True)
        (3): GELU(approximate='none')
      )
      (Encoder_3): Sequential(
        (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), padding=(1,))
        (1): Dropout1d(p=0.1, inplace=False)
        (2): GroupNor

In [25]:
# Debug run: drop the configured callbacks/logger and repeatedly fit a small fixed
# fraction (overfit_batches=0.01) of the training data for 100 epochs, as a sanity
# check that the contrastive loss can be driven down.
config['trainer'].update(
    callbacks=None,
    logger=None,
    overfit_batches=0.01,
    fast_dev_run=False,
    max_epochs=100,
)
trainer = L.Trainer(**config['trainer'])
trainer.fit(model=cpc_model, datamodule=litDataModule)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Using datasets: ['ds005505', 'ds005506', 'ds005507', 'ds005508', 'ds005509', 'ds005510', 'ds005511', 'ds005512', 'ds005514', 'ds005515', 'ds005516']
Validation release: ds005505


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type             | Params | Mode 
------------------------------------------------------------
0 | encoder        | ConvEncoderBENDR | 4.1 M  | train
1 | contextualizer | BENDRLSTM        | 6.3 M  | train
------------------------------------------------------------
10.4 M    Trainable params
0         Non-trainable params
10.4 M    Total params
41.775    Total estimated model params size (MB)
34        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Number of datasets: 2698
Number of examples: 549299


/home/dung/eeg-ssl/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:252: You requested to overfit but enabled train dataloader shuffling. We are turning off the train dataloader shuffling for you.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.
