In [1]:
import sys
sys.path.insert(0, '..')
%load_ext autoreload
%autoreload 2

from libs.ssl_task import CPC
from libs.ssl_data import SSLHBNDataModule
import torch
import lightning as L

In [2]:
# load config from runs/config_CPC.yaml
import yaml
with open('../runs/config_CPC.yaml', 'r') as f:
    config = yaml.safe_load(f)

In [3]:
ssl_task = CPC()
config['data']['ssl_task'] = ssl_task
config['data']['window_len_s'] = 20
config['data']['num_workers'] = 2
litDataModule = SSLHBNDataModule(**config['data'])

In [16]:
from libs.ssl_task import LitSSL
from libs.ssl_utils import instantiate_module
from libs.ssl_model import BENDRLSTM, BENDRContextualizer
from libs.evaluation import Regressor, RankMe
from torch.nn import functional as F
import numpy as np
from typing import Any, Optional, Union
class CPCLit(LitSSL):
    # Repurpose from https://github.com/SPOClab-ca/BENDR/blob/ac918abaec111d15fcaa2a8fcd2bd3d8b0d81a10/dn3_ext.py#L232
    def __init__(self, 
                contextualizer_path: str,
                contextualizer_kwargs: Optional[Union[dict[str, Any], dict[str, dict[str, Any]]]] = None, 
                downsampling_factor=96, 
                mask_rate=0.1, mask_span=6, temp=0.1,
                permuted_encodings=False, permuted_contexts=False, enc_feat_l2=0.001,
                unmasked_negative_frac=0.25, num_negatives=20, **kwargs):
        super().__init__(**kwargs)
        # self.contextualizer = BENDRContextualizer(
        #     in_features=self.encoder_emb_size,
        #     start_token=None,
        # )
        self.contextualizer = instantiate_module(contextualizer_path, contextualizer_kwargs)
        # Initialize replacement vector with standard normal
        self.mask_replacement = torch.nn.Parameter(torch.normal(0, self.encoder_emb_size**(-0.5), size=(self.encoder_emb_size,)),
                                                requires_grad=True)

        self.predict_length = mask_span
        self._enc_downsample = downsampling_factor
        self.mask_rate = mask_rate
        self.mask_span = mask_span
        self.temp = temp
        self.permuted_encodings = permuted_encodings
        self.permuted_contexts = permuted_contexts
        self.beta = enc_feat_l2
        self.start_token = getattr(self.contextualizer, 'start_token', None)
        self.unmasked_negative_frac = unmasked_negative_frac
        self.num_negatives = num_negatives

        self.evaluators = [Regressor(projection_head=True)]
    
    def _generate_negatives(self, z):
        """Generate negative samples to compare each sequence location against"""
        batch_size, feat, full_len = z.shape
        with torch.no_grad():
            z_k = z.clone().permute([0, 2, 1]).reshape(-1, feat)
            negative_inds = torch.randint(0, full_len-1, size=(batch_size, full_len * self.num_negatives))
            # From wav2vec 2.0 implementation, I don't understand
            # negative_inds[negative_inds >= candidates] += 1

            for i in range(1, batch_size):
                negative_inds[i] += i * full_len

            z_k = z_k[negative_inds.view(-1)].view(batch_size, full_len, self.num_negatives, feat)
            return z_k, negative_inds

    def _calculate_similarity(self, true_z, c, negatives):
        targets = true_z.permute([0, 2, 1]).unsqueeze(-2)
        # z - (B, seq_len, 1, F) - YT

        if self.start_token:
            c = c[..., 1:].permute([0, 2, 1]).unsqueeze(-2)
        else:
            c = c.permute([0, 2, 1]).unsqueeze(-2)
        # c - (B, seq_len, 1, F). First seq is added start token - YT
        # negatives - (B, seq_len, num_negatives, F) - YT
        predictions = torch.cat([c, negatives], dim=-2)
        # predictions - (B, seq_len, 1+num_negatives, F) - YT

        logits = F.cosine_similarity(targets, predictions, dim=-1) / self.temp # z is being broadcasted in the 3rd dimension - YT
        # logits - (B, seq_len, 1+num_negatives)

        return logits.view(-1, logits.shape[-1]) # flatten B x seq_len. Last dim correspond to torch CrossEntropyLoss C 
                                                    # --> will have true class label 0

    def compute_cross_batch_loss(self, true_z, c):
        B, feat, seq_len = true_z.shape
        assert c.shape == true_z.shape, f"c {c.shape} and true_z {true_z.shape} should be the same shape"
        true_z = true_z.permute([0, 2, 1]) # (B, seq_len, F)
        c = c.permute([0, 2, 1]) # (B, seq_len, F)

        positives = F.cosine_similarity(c, true_z, dim=-1) / self.temp # (B, seq_len)
        # create negative batch by randomize batch elements
        negatives_batch_ind = torch.randint(0, B, (B,), device=true_z.device)
        negatives_batch = true_z[negatives_batch_ind] # (B, seq_len, F)
        assert negatives_batch.shape == true_z.shape, f"negatives_batch {negatives_batch.shape} should be the same shape as true_z {true_z.shape}"
        negatives_seq_ind = torch.randint(0, seq_len, (B, seq_len, self.num_negatives), device=true_z.device)
        negatives_batch = torch.gather(negatives_batch.unsqueeze(2).expand(-1, -1, self.num_negatives, -1),  # Expand to (B, seq_len, num_negative, F)
                                dim=1,  # Sample along the time dimension (originally dimension 1)
                                index=negatives_seq_ind.unsqueeze(-1).expand(-1, -1, -1, feat))
        negatives = F.cosine_similarity(c.unsqueeze(-2), negatives_batch, dim=-1) / self.temp # (B, seq_len, num_negatives)

        # assert positives.shape == negatives.shape, f"positives {positives.shape} and negatives {negatives.shape} should be the same shape"
        # assert not torch.allclose(positives, negatives), f"positives {positives.shape} and negatives {negatives.shape} should not be the same"

        # negatives = negatives.unsqueeze(-1).expand(-1, -1, self.num_negatives)
        assert negatives.shape == (B, seq_len, self.num_negatives), f"negatives {negatives.shape} should be (B, seq_len, num_negatives)"

        positives = positives.unsqueeze(-1) # (B, seq_len, 1)
        assert positives.shape == (B, seq_len, 1), f"positives {positives.shape} should be (B, seq_len, 1)"

        logits = torch.cat([positives, negatives], dim=-1) # (B, seq_len, 1+num_negatives)
        logits = logits * 10

        logits = logits.view(-1, logits.shape[-1]) # flatten B x seq_len. Last dim correspond to torch CrossEntropyLoss C 
                                                    # --> will have true class label 0
        
        labels = torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long)
        # labels - (B x seq_len) all 0s
        # print('labels', labels)

        loss = torch.nn.functional.cross_entropy(logits, labels)

        return loss

    # def calculate_loss(self, inputs, outputs):
    #     logits = outputs[0]
    #     labels = torch.zeros(logits.shape[0], device=logits.device, dtype=torch.long)
    #     # Note the loss_fn here integrates the softmax as per the normal classification pipeline (leveraging logsumexp)
    #     return torch.nn.functional.cross_entropy(logits, labels) + self.beta * outputs[1].pow(2).mean()

    def _make_span_from_seeds(self, seeds, span, total=None):
        inds = list()
        for seed in seeds:
            for i in range(seed, seed + span):
                if total is not None and i >= total:
                    break
                elif i not in inds:
                    inds.append(int(i))
        return np.array(inds)

    def _make_mask(self, shape, p, total, span, allow_no_inds=False):
        # num_mask_spans = np.sum(np.random.rand(total) < p)
        # num_mask_spans = int(p * total)
        mask = torch.zeros(shape, requires_grad=False, dtype=torch.bool)

        for i in range(shape[0]):
            mask_seeds = list()
            while not allow_no_inds and len(mask_seeds) == 0 and p > 0:
                mask_seeds = np.nonzero(np.random.rand(total) < p)[0]

            mask[i, self._make_span_from_seeds(mask_seeds, span, total=total)] = True
        # mask - (B, seq_len)
        return mask

    def generate_negatives_from_batch(self, z):
        """Generate negatives from other samples in the batch"""
        batch_size, feat, full_len = z.shape
        with torch.no_grad():
            z_k = z.clone().permute([0, 2, 1]).reshape(-1, feat)
            negative_inds = torch.randint(0, full_len-1, size=(batch_size, full_len * self.num_negatives))
            # From wav2vec 2.0 implementation, I don't understand
            # negative_inds[negative_inds >= candidates] += 1

            for i in range(1, batch_size):
                negative_inds[i] += i * full_len

            z_k = z_k[negative_inds.view(-1)].view(batch_size, full_len, self.num_negatives, feat)
            return z_k, negative_inds

    def training_step(self, batch, batch_idx):
        z = self.encoder(batch[0])
        batch_size, feat, samples = z.shape
        # z - (B, F, seq_len)

        unmasked_z = z.clone()
        
        mask = None
        mask = self._make_mask((batch_size, samples), self.mask_rate, samples, self.mask_span)
        # make simple mask: only predict the last token
        # mask = torch.zeros((batch_size, samples), dtype=torch.bool)
        # mask[:, -1] = True

        if mask is not None:
            z.transpose(2, 1)[mask] = self.mask_replacement

        c = self.contextualizer(z)
        # c - (B, F, seq_len) 

        loss = self.compute_cross_batch_loss(unmasked_z, c)

        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        print('train_loss', loss)

        return loss
    
    def validation_step(self, batch, batch_idx):
        z = self.encoder(batch[0])
        Y, subjects = batch[1], batch[3]
        batch_size, feat, samples = z.shape
        # z - (B, F, seq_len)

        unmasked_z = z.clone()
        
        mask = None
        # mask = self._make_mask((batch_size, samples), self.mask_rate, samples, self.mask_span)
        # make simple mask: only predict the last token
        # mask = torch.zeros((batch_size, samples), dtype=torch.bool)
        # mask[:, -1] = True

        if mask is not None:
            z.transpose(2, 1)[mask] = self.mask_replacement

        c = self.contextualizer(z)
        
        loss = self.compute_cross_batch_loss(unmasked_z, c)

        self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        for evaluator in self.evaluators:
            c_last = c[:, :, -1]
            evaluator.update((c_last, Y, subjects))

In [None]:
config['model']['init_args']['contextualizer_path'] = 'libs.ssl_model.BENDRContextualizer'
config['model']['init_args']['contextualizer_kwargs'] = {'in_features': 512, 'start_token': None}
config['model']['init_args']['learning_rate'] = 0.00005
cpc_model = CPCLit(downsampling_factor=96, **config['model']['init_args'])

config['trainer']['callbacks'] = None
config['trainer']['logger'] = None
config['trainer']['overfit_batches'] = 0.0
config['trainer']['fast_dev_run'] = False
config['trainer']['max_epochs'] = 1
config['trainer']['detect_anomaly'] = False # default: False
config['trainer']['gradient_clip_val'] = 0 # default: 0
trainer = L.Trainer(**config['trainer'])
trainer.fit(model=cpc_model, datamodule=litDataModule)

You have turned on `Trainer(detect_anomaly=True)`. This will significantly slow down compute speed and is recommended only for model debugging.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Using datasets: ['ds005505', 'ds005506', 'ds005507', 'ds005508', 'ds005509', 'ds005511', 'ds005512', 'ds005514', 'ds005515', 'ds005516']
Validation release: ds005505


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type                | Params | Mode 
---------------------------------------------------------------
0 | encoder        | ConvEncoderBENDR    | 4.1 M  | train
1 | contextualizer | BENDRContextualizer | 153 M  | train
  | other params   | n/a                 | 512    | n/a  
---------------------------------------------------------------
157 M     Trainable params
0         Non-trainable params
157 M     Total params
629.230   Total estimated model params size (MB)
129       Modules in train mode
0         Modules in eval mode


learning rate 5e-05


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Number of datasets: 2567
Number of examples: 51234


Training: |          | 0/? [00:00<?, ?it/s]

train_loss tensor(7.6740, device='cuda:0', grad_fn=<NllLossBackward0>)
train_loss tensor(7.4982, device='cuda:0', grad_fn=<NllLossBackward0>)
train_loss tensor(3.7361, device='cuda:0', grad_fn=<NllLossBackward0>)
train_loss tensor(3.3492, device='cuda:0', grad_fn=<NllLossBackward0>)
train_loss tensor(2.6545, device='cuda:0', grad_fn=<NllLossBackward0>)
train_loss tensor(1.7920, device='cuda:0', grad_fn=<NllLossBackward0>)
train_loss tensor(1.5229, device='cuda:0', grad_fn=<NllLossBackward0>)
train_loss tensor(1.2385, device='cuda:0', grad_fn=<NllLossBackward0>)
train_loss tensor(0.9747, device='cuda:0', grad_fn=<NllLossBackward0>)



Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined