In [1]:
from pathlib import Path
from typing import List, Tuple, Any

import numpy as np
import pandas as pd
import torch
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import plotly.express as px
import plotly.graph_objects as go

from ssl_tools.transforms import *

from typing import Any
from torch.nn import TransformerEncoder, TransformerEncoderLayer

  @numba.jit()
  @numba.jit()
  @numba.jit()
  from .autonotebook import tqdm as notebook_tqdm
  @numba.jit()


In [2]:
from typing import Union


class TFCContrastiveDataset(Dataset):
    def __init__(
        self,
        data: torch.Tensor,
        labels: torch.Tensor = None,
        length_alignment: int = 178,
        time_transforms: Union[Transform, List[Transform]] = None,
        frequency_transforms: Union[Transform, List[Transform]] = None,
    ):
        assert len(data) == len(labels), "Data and labels must have the same length"
        
        self.data_time = data
        self.labels = labels
        self.length_alignment = length_alignment
        self.time_transforms = time_transforms or []
        self.frequency_transforms = frequency_transforms or []
        
        if not isinstance(self.time_transforms, list):
            self.time_transforms = [self.time_transforms]
        if not isinstance(self.frequency_transforms, list):
            self.frequency_transforms = [self.frequency_transforms]

        if len(self.data_time.shape) < 3:
            self.data_time = self.data_time.unsqueeze(2)

        if self.data_time.shape.index(min(self.data_time.shape)) != 1:
            self.data_time = self.data_time.permute(0, 2, 1)

        """Align the data to the same length, removing the extra features"""
        self.data_time = self.data_time[:, :1, : self.length_alignment]
        
        """Calculcate the FFT of the data and apply the transforms (if any)"""
        self.data_freq = torch.fft.fft(self.data_time).abs()
        
        # This could be done in the __getitem__ method
        # For now, we do it here to be more similar to the original implementation
        self.data_time_augmented = self.apply_transforms(self.data_time, self.time_transforms)
        self.data_freq_augmented = self.apply_transforms(self.data_freq, self.frequency_transforms)
        
    def apply_transforms(self, x: torch.Tensor, transforms: List[Transform]) -> torch.Tensor:
        for transform in transforms:
            x = transform.fit_transform(x)
        return x
        
    def __len__(self):
        return len(self.data_time)
    
    def __getitem__(self, index):
        # Time processing
        return (
            self.data_time[index].float(),
            self.labels[index],
            self.data_time_augmented[index].float(),
            self.data_freq[index].float(),
            self.data_freq_augmented[index].float(),
        )

# TF-C Pre-train

In [3]:
data_path = Path("data/TFC/SleepEEG")

dataset = torch.load(data_path / "train.pt")
X_train, y_train = dataset["samples"], dataset["labels"]
X_train.shape

torch.Size([371055, 1, 178])

In [4]:
jitter_ratio = 2
length_alignment = 178
drop_last = True
batch_size = 128
num_workers = 10
learning_rate = 3e-4
temperature = 0.2
use_cosine_similarity = True

In [5]:
time_transforms = [
    AddGaussianNoise(std=jitter_ratio)
]

frequency_transforms = [
    AddRemoveFrequency()
]

train_dataset = TFCContrastiveDataset(
    data=X_train,
    labels=y_train,
    time_transforms=time_transforms,
    frequency_transforms=frequency_transforms,
)

len(train_dataset[0])

5

In [6]:
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=drop_last,
    num_workers=num_workers
)

test_batch = next(iter(train_loader))
len(test_batch), test_batch[0].shape, test_batch[1].shape, test_batch[2].shape, test_batch[3].shape, test_batch[4].shape

(5,
 torch.Size([128, 1, 178]),
 torch.Size([128]),
 torch.Size([128, 1, 178]),
 torch.Size([128, 1, 178]),
 torch.Size([128, 1, 178]))

In [7]:
class NTXentLoss_poly(torch.nn.Module):
    def __init__(
        self,
        batch_size,
        temperature: float = 0.2,
        use_cosine_similarity: bool = True,
        device: str = "cpu",
    ):
        super(NTXentLoss_poly, self).__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.device = device
        self.softmax = torch.nn.Softmax(dim=-1)
        self.mask_samples_from_same_repr = self._get_correlated_mask().type(torch.bool)
        self.similarity_function = self._get_similarity_function(use_cosine_similarity)
        self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")

    def _get_similarity_function(self, use_cosine_similarity):
        if use_cosine_similarity:
            self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
            return self._cosine_simililarity
        else:
            return self._dot_simililarity

    def _get_correlated_mask(self):
        diag = np.eye(2 * self.batch_size)
        l1 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=-self.batch_size)
        l2 = np.eye((2 * self.batch_size), 2 * self.batch_size, k=self.batch_size)
        mask = torch.from_numpy((diag + l1 + l2))
        mask = (1 - mask).type(torch.bool)
        return mask.to(self.device)

    @staticmethod
    def _dot_simililarity(x, y):
        v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
        # x shape: (N, 1, C)
        # y shape: (1, C, 2N)
        # v shape: (N, 2N)
        return v

    def _cosine_simililarity(self, x, y):
        # x shape: (N, 1, C)
        # y shape: (1, 2N, C)
        # v shape: (N, 2N)
        v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
        return v

    def forward(self, zis, zjs):
        representations = torch.cat([zjs, zis], dim=0)

        similarity_matrix = self.similarity_function(representations, representations)

        # filter out the scores from the positive samples
        l_pos = torch.diag(similarity_matrix, self.batch_size)
        r_pos = torch.diag(similarity_matrix, -self.batch_size)
        positives = torch.cat([l_pos, r_pos]).view(2 * self.batch_size, 1)

        negatives = similarity_matrix[self.mask_samples_from_same_repr].view(
            2 * self.batch_size, -1
        )

        logits = torch.cat((positives, negatives), dim=1)
        logits /= self.temperature

        """Criterion has an internal one-hot function. Here, make all positives as 1 while all negatives as 0. """
        labels = torch.zeros(2 * self.batch_size).to(self.device).long()
        CE = self.criterion(logits, labels)

        onehot_label = (
            torch.cat(
                (
                    torch.ones(2 * self.batch_size, 1),
                    torch.zeros(2 * self.batch_size, negatives.shape[-1]),
                ),
                dim=-1,
            )
            .to(self.device)
            .long()
        )
        # Add poly loss
        pt = torch.mean(onehot_label * torch.nn.functional.softmax(logits, dim=-1))

        epsilon = self.batch_size
        # loss = CE/ (2 * self.batch_size) + epsilon*(1-pt) # replace 1 by 1/self.batch_size
        loss = CE / (2 * self.batch_size) + epsilon * (1 / self.batch_size - pt)
        # loss = CE / (2 * self.batch_size)

        return loss


In [8]:
class TFC(pl.LightningModule):
    def __init__(
        self,
        time_encoder: nn.Module,
        frequency_encoder: nn.Module,
        time_projector: nn.Module,
        frequency_projector: nn.Module,
        nxtent_criterion: nn.Module,
        lr: float = 1e-3,
        loss_lambda: float = 0.2,
    ):
        super().__init__()

        self.time_encoder = time_encoder.to(self.device)
        self.time_projector = time_projector.to(self.device)
        self.frequency_encoder = frequency_encoder.to(self.device)
        self.frequency_projector = frequency_projector.to(self.device)
        self.learning_rate = lr
        self.nxtent_criterion = nxtent_criterion.to(self.device)
        self.loss_lambda = loss_lambda

    def forward(self, x_in_t, x_in_f):
        """Use Transformer"""
        x = self.time_encoder(x_in_t)
        h_time = x.reshape(x.shape[0], -1)

        """Cross-space projector"""
        z_time = self.time_projector(h_time)

        """Frequency-based contrastive encoder"""
        f = self.frequency_encoder(x_in_f)
        h_freq = f.reshape(f.shape[0], -1)

        """Cross-space projector"""
        z_freq = self.frequency_projector(h_freq)

        return h_time, z_time, h_freq, z_freq

    def configure_optimizers(self) -> Any:
        learnable_parameters = (
            list(self.time_encoder.parameters()) +
            list(self.time_projector.parameters()) +
            list(self.frequency_encoder.parameters()) +
            list(self.frequency_projector.parameters())
        )
        optimizer = torch.optim.Adam(learnable_parameters, lr=self.learning_rate)
        return optimizer

    def training_step(self, batch, batch_idx):
        data, labels, aug1, data_f, aug1_f = batch
        
        """Producing embeddings"""
        h_t, z_t, h_f, z_f = self(data, data_f)
        h_t_aug, z_t_aug, h_f_aug, z_f_aug = self(aug1, aug1_f)
        
        """Calculate losses"""
        loss_time = self.nxtent_criterion(h_t, h_t_aug)
        loss_freq = self.nxtent_criterion(h_f, h_f_aug)
        loss_consistency = self.nxtent_criterion(z_t, z_f)
        loss = self.loss_lambda * (loss_time + loss_freq) + loss_consistency
        
        # log loss, only to appear on epoch
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss


In [9]:
time_encoder = TransformerEncoder(
    TransformerEncoderLayer(
        length_alignment, dim_feedforward=2 * length_alignment, nhead=2
    ),
    num_layers=2,
)
frequency_encoder = TransformerEncoder(
    TransformerEncoderLayer(
        length_alignment, dim_feedforward=2 * length_alignment, nhead=2
    ),
    num_layers=2,
)

time_projector = nn.Sequential(
    nn.Linear(length_alignment, 256),
    nn.BatchNorm1d(256),
    nn.ReLU(),
    nn.Linear(256, 128),
)
frequency_projector = nn.Sequential(
    nn.Linear(length_alignment, 256),
    nn.BatchNorm1d(256),
    nn.ReLU(),
    nn.Linear(256, 128),
)

nxtent = NTXentLoss_poly(
    batch_size=batch_size,
    temperature=temperature,
    use_cosine_similarity=use_cosine_similarity,
    device="cuda",
)

tfc_model = TFC(
    time_encoder=time_encoder,
    frequency_encoder=frequency_encoder,
    time_projector=time_projector,
    frequency_projector=frequency_projector,
    nxtent_criterion=nxtent,
    lr=learning_rate,
)




In [10]:
trainer = pl.Trainer(max_epochs=1, accelerator="gpu", devices=1, limit_train_batches=10)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [11]:
trainer.fit(tfc_model, train_loader)

You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name                | Type               | Params
-----------------------------------------------------------
0 | time_encoder        | TransformerEncoder | 510 K 
1 | time_projector      | Sequential         | 79.2 K
2 | frequency_encoder   | TransformerEncoder | 510 K 
3 | frequency_projector | Sequential         | 79.2 K
4 | nxtent_criterion    | NTXentLoss_poly    | 0     
-----------------------------------------------------------
1.2 M     Trainable params
0         Non-trainable params
1.2 M     Total params
4.721     Total estimated model params size (MB)
  rank_zero_warn(


Epoch 0: 100%|██████████| 10/10 [00:01<00:00,  7.14it/s, v_num=94, train_loss_step=8.520, train_loss_epoch=8.870]

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 10/10 [00:01<00:00,  6.32it/s, v_num=94, train_loss_step=8.520, train_loss_epoch=8.870]


# TF-C Fine-Tune 

In [12]:
data_path = Path("data/TFC/Epilepsy")

dataset_train = torch.load(data_path / "train.pt")
X_train, y_train = dataset["samples"], dataset["labels"]

dataset_validation = torch.load(data_path / "val.pt")
X_validation, y_validation = dataset["samples"], dataset["labels"]

dataset_test = torch.load(data_path / "test.pt")
X_test, y_test = dataset["samples"], dataset["labels"]

In [13]:
batch_size = 60
n_classes = 5

In [14]:
train_dataset = TFCContrastiveDataset(
    data=X_train,
    labels=y_train,
    time_transforms=None,
    frequency_transforms=None,
)

validation_dataset = TFCContrastiveDataset(
    data=X_validation,
    labels=y_validation,
    time_transforms=None,
    frequency_transforms=None,
)

test_dataset = TFCContrastiveDataset(
    data=X_test,
    labels=y_test,
    time_transforms=None,
    frequency_transforms=None,
)

In [15]:
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=drop_last,
    num_workers=num_workers
)

validation_loader = DataLoader(
    validation_dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=drop_last,
    num_workers=num_workers
)

test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=drop_last,
    num_workers=num_workers
)

In [16]:

from typing import Any

from torchmetrics.functional import accuracy


class SimpleClassifier(torch.nn.Module):
    def __init__(self, num_classes: int = 2):
        super().__init__()
        self.num_classes = num_classes
        self.fc = torch.nn.Linear(2 * 128, 64)
        self.fc2 = torch.nn.Linear(64, num_classes)

    def forward(self, x):
        emb_flatten = x.reshape(x.shape[0], -1)
        x = self.fc(emb_flatten)
        x = torch.sigmoid(x)
        y = self.fc2(x)
        return y


class TFC_classifier(pl.LightningModule):
    def __init__(
        self,
        tfc_model: torch.nn.Module,
        classifier: torch.nn.Module,
        nxtent_criterion: nn.Module,
        lr: float = 1e-3,
        loss_lambda: float = 0.1,
        n_classes: int = 2,
    ):
        super().__init__()
        self.tfc_model = tfc_model
        self.classifier = classifier
        self.nxtent_criterion = nxtent_criterion
        self.learning_rate = lr
        self.n_classes = n_classes
        self.loss_lambda = loss_lambda
        self.loss_func = torch.nn.CrossEntropyLoss()

    def configure_optimizers(self) -> Any:
        learnable_parameters = list(self.tfc_model.parameters()) + list(
            self.classifier.parameters()
        )
        optimizer = torch.optim.Adam(learnable_parameters, lr=self.learning_rate)
        return optimizer

    def forward(self, x_in_t, x_in_f):
        return self.tfc_model(x_in_t, x_in_f)

    def training_step(self, batch, batch_idx):
        data, labels, aug1, data_f, aug1_f = batch

        """Producing embeddings"""
        h_t, z_t, h_f, z_f = self(data, data_f)
        h_t_aug, z_t_aug, h_f_aug, z_f_aug = self(aug1, aug1_f)

        """Add supervised loss"""
        fea_concat = torch.cat((z_t, z_f), dim=1)
        predictions = self.classifier(fea_concat)
        # fea_concat_flat = fea_concat.reshape(fea_concat.shape[0], -1)

        """Calculate losses"""
        loss_time = self.nxtent_criterion(h_t, h_t_aug)
        loss_freq = self.nxtent_criterion(h_f, h_f_aug)
        loss_consistency = self.nxtent_criterion(z_t, z_f)
        loss_p = self.loss_func(predictions, labels)
        loss = loss_p + self.loss_lambda * (loss_time + loss_freq) + loss_consistency

        self.log(
            "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
        )
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc = self._shared_eval_step(batch, batch_idx)
        self.log(
            "val_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
        )
        self.log(
            "val_acc", acc, on_step=True, on_epoch=True, prog_bar=True, logger=True
        )
        return {"val_loss": loss, "val_acc": acc}

    def test_step(self, batch, batch_idx):
        loss, acc = self._shared_eval_step(batch, batch_idx)
        self.log(
            "test_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
        )
        self.log(
            "test_acc", acc, on_step=True, on_epoch=True, prog_bar=True, logger=True
        )
        return {"test_loss": loss, "test_acc": acc}

    def _shared_eval_step(self, batch, batch_idx):
        data, labels, aug1, data_f, aug1_f = batch

        """Producing embeddings"""
        h_t, z_t, h_f, z_f = self(data, data_f)
        h_t_aug, z_t_aug, h_f_aug, z_f_aug = self(aug1, aug1_f)
        
        # print(h_t.shape, z_t.shape, h_f.shape, z_f.shape, h_t_aug.shape, z_t_aug.shape, h_f_aug.shape, z_f_aug.shape)
        loss_time = self.nxtent_criterion(h_t, h_t_aug)
        loss_freq = self.nxtent_criterion(h_f, h_f_aug)
        loss_consistency = self.nxtent_criterion(z_t, z_f)

        """Add supervised loss"""
        fea_concat = torch.cat((z_t, z_f), dim=1)
        predictions = self.classifier(fea_concat)
        loss_p = self.loss_func(predictions, labels)
        
        loss = loss_p + self.loss_lambda * (loss_time + loss_freq) + loss_consistency

        acc = accuracy(
            torch.argmax(predictions, dim=1),
            labels,
            task="multiclass",
            num_classes=self.n_classes,
        )

        return loss, acc

In [17]:
classifier = SimpleClassifier(num_classes=n_classes)
nxtent = NTXentLoss_poly(
    batch_size=batch_size,
    temperature=temperature,
    use_cosine_similarity=use_cosine_similarity,
    device="cuda",
)

tfc_classifier = TFC_classifier(
    tfc_model=tfc_model,
    classifier=classifier,
    nxtent_criterion=nxtent,
    lr=learning_rate,
    n_classes=n_classes,
)

In [18]:
trainer = pl.Trainer(max_epochs=1, accelerator="gpu", devices=1, limit_train_batches=10, limit_test_batches=10, limit_val_batches=10)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [19]:
trainer.fit(tfc_classifier, train_loader, validation_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name             | Type             | Params
------------------------------------------------------
0 | tfc_model        | TFC              | 1.2 M 
1 | classifier       | SimpleClassifier | 16.8 K
2 | nxtent_criterion | NTXentLoss_poly  | 0     
3 | loss_func        | CrossEntropyLoss | 0     
------------------------------------------------------
1.2 M     Trainable params
0         Non-trainable params
1.2 M     Total params
4.788     Total estimated model params size (MB)


Epoch 0: 100%|██████████| 10/10 [00:02<00:00,  4.41it/s, v_num=95, train_loss_step=7.970, val_loss_step=8.830, val_acc_step=0.000, val_loss_epoch=8.430, val_acc_epoch=0.500, train_loss_epoch=8.050]

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 10/10 [00:02<00:00,  4.08it/s, v_num=95, train_loss_step=7.970, val_loss_step=8.830, val_acc_step=0.000, val_loss_epoch=8.430, val_acc_epoch=0.500, train_loss_epoch=8.050]


In [20]:
trainer.test(tfc_classifier, test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Testing DataLoader 0: 100%|██████████| 10/10 [00:00<00:00, 37.32it/s]


[{'test_loss_epoch': 8.430386543273926, 'test_acc_epoch': 0.5}]