In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!pip install pytorch-lightning
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pytorch-lightning
  Downloading pytorch_lightning-2.0.3-py3-none-any.whl (720 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m720.6/720.6 kB[0m [31m18.2 MB/s[0m eta [36m0:00:00[0m
Collecting torchmetrics>=0.7.0 (from pytorch-lightning)
  Downloading torchmetrics-0.11.4-py3-none-any.whl (519 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.2/519.2 kB[0m [31m53.2 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities>=0.7.0 (from pytorch-lightning)
  Downloading lightning_utilities-0.8.0-py3-none-any.whl (20 kB)
Collecting aiohttp!=4.0.0a0,!=4.0.0a1 (from fsspec[http]>2021.06.0->pytorch-lightning)
  Downloading aiohttp-3.8.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m57.7 MB/s[0m eta [36m0:00:00[0m
Collecting multidict

In [3]:
#############################################
import numpy as np


#############################################
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics

#############################################
import pytorch_lightning as pl

#############################################
import clip

#############################################
from PIL import Image



In [15]:
class LinearWeightBlock(nn.Module):
    def __init__(self, lenght_sequence, in_channels, n_head = 2) -> None:
        super().__init__()

        ##################################################
        # 1. Convolutional Block
        ##################################################

        self.in_channels = in_channels

        self.conv_11 = nn.Conv1d(
                                 in_channels=in_channels,
                                 out_channels=in_channels,
                                 kernel_size=3,
                                 stride=1,
                                 padding=1,
                                 dilation=1,
                                 groups=1,
                                 bias=True
                            )
        self.conv_12 = nn.Conv1d(
                                 in_channels=in_channels,
                                 out_channels=in_channels,
                                 kernel_size=3,
                                 stride=1,
                                 padding=1,
                                 dilation=1,
                                 groups=1,
                                 bias=True
                            )
        self.conv_11_12 = nn.Conv1d(
                                    in_channels=in_channels,
                                    out_channels=in_channels//2,
                                    kernel_size=3,
                                    stride=2,
                                    padding=1,
                                    dilation=1,
                                    groups=1,
                                    bias=True
                            )


        # layer norm
        self.lenght_sequence = lenght_sequence

        self.layer_norm_1 = nn.LayerNorm(self.lenght_sequence)
        self.layer_norm_11_12 = nn.LayerNorm(self.lenght_sequence//2+1)


        ##################################################
        # 2. Linear Weighting Block
        ##################################################

        self.dim_embedding = self.lenght_sequence//2 + 1
        c = 1

        self.linear = nn.Linear(self.dim_embedding, self.dim_embedding//8 + c)

        ##################################################
        # 3. Transformer Encoder Block
        ##################################################

        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=self.dim_embedding//8 + c, nhead=n_head), num_layers=2
        )

        ##################################################
        # 4. Residual Bilinear Block
        ##################################################

        self.bilinear = nn.Bilinear(self.dim_embedding//8 + c, self.dim_embedding//8 + c, self.dim_embedding//8 + c)

        self.dropout = nn.Dropout(0.1)


        ##################################################
        # - Layer Norm
        # - GELU
        ##################################################

        self.gelu = nn.GELU()

        self.layer_norm = nn.LayerNorm(self.dim_embedding//8 + c)

    def forward(self, x):

        # down sampling 1
        x_1 = self.gelu(self.layer_norm_1(self.conv_11(x)))
        x_2 = self.gelu(self.layer_norm_1(self.conv_12(x)))

        x = self.layer_norm_11_12(self.conv_11_12(x_1 + x_2))


        # linear weighting block
        x = self.linear(x)

        x = self.layer_norm(x)

        # transformer encoder block
        x_0 = self.transformer_encoder(x)

        # residual bilinear block
        x = self.bilinear(x_0 + self.dropout(x_0), x)
        x = self.gelu(x)

        return x


class Classifier(nn.Module):
    def __init__(self, lenght_sequence = 601, in_channels = 306, n_head = 2, n_class = 92) -> None:
        super().__init__()

        self.in_channels = in_channels
        self.n_class = n_class

        self.linear_weight_block = LinearWeightBlock(lenght_sequence, in_channels, n_head)

        self.flatten = nn.Flatten(start_dim=1, end_dim=- 1)

        c = 1

        self.linear_transformation_1 = nn.Linear(((lenght_sequence//2)//8 + c) * (in_channels//2), lenght_sequence//32)
        self.linear_transformation_2 = nn.Linear(lenght_sequence//32, 512)

        self.layer_norm = nn.LayerNorm(512)

        self.gelu = nn.GELU()

        self.dropout = nn.Dropout(0.1)

        self.mlp = nn.Sequential(
            nn.Linear(512, lenght_sequence//16),
            nn.GELU(),
            nn.LayerNorm(lenght_sequence//16),
            nn.Linear(lenght_sequence//16, n_class)
        )

    def forward(self, x, x_feat):
        x = self.linear_weight_block(x)
        x = self.flatten(x)
        x = self.linear_transformation_1(x)
        x = self.gelu(x)
        x = self.linear_transformation_2(x)
        x = self.gelu(self.layer_norm(x))
        x = self.mlp(x_feat + self.dropout(x))
        return x


class LitClassifier(pl.LightningModule):
    def __init__(self, lenght_sequence = 601, in_channels = 306, n_head = 2, n_class = 92) -> None:
        super().__init__()

        self.model = Classifier(lenght_sequence, in_channels, n_head, n_class)

        self.loss = nn.CrossEntropyLoss()
        self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=n_class)
        self.f1 = torchmetrics.F1Score(task="multiclass",num_classes=n_class)
        self.precision = torchmetrics.Precision(task="multiclass",num_classes=n_class)
        self.recall = torchmetrics.Recall(task="multiclass",num_classes=n_class)


    def forward(self, x, subject_info):
        return self.model(x, subject_info)

    def training_step(self, batch, batch_idx):
        x, subject_info, y = batch

        # y = y.long()

        y_hat = self(x, subject_info)
        loss = self.loss(y_hat, y)

        y = torch.argmax(y, dim=-1)

        self.log("train_loss", loss)
        self.log("train_acc", self.accuracy(y_hat, y))
        self.log("train_f1", self.f1(y_hat, y))
        self.log("train_precision", self.precision(y_hat, y))
        self.log("train_recall", self.recall(y_hat, y))

        return loss

    def validation_step(self, batch, batch_idx):
        x, subject_info, y = batch

        # y = y.long()

        y_hat = self(x, subject_info)
        loss = self.loss(y_hat, y)

        y = torch.argmax(y, dim=-1)

        self.log("val_loss", loss)
        self.log("val_acc", self.accuracy(y_hat, y))
        self.log("val_f1", self.f1(y_hat, y))
        self.log("val_precision", self.precision(y_hat, y))
        self.log("val_recall", self.recall(y_hat, y))

        return loss

    # def test_step(self, batch, batch_idx):
    #     x, subject_info, y = batch

    #     # y = y.long()

    #     y_hat = self(x, subject_info)
    #     loss = self.loss(y_hat, y)

    #     y = torch.argmax(y, dim=-1)

    #     self.log("test_loss", loss)
    #     self.log("test_acc", self.accuracy(y_hat, y))
    #     self.log("test_f1", self.f1(y_hat, y))
    #     self.log("test_precision", self.precision(y_hat, y))
    #     self.log("test_recall", self.recall(y_hat, y))

    #     return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
        return [optimizer], [scheduler]

In [13]:
import torch
import pickle

X = torch.load('/content/drive/MyDrive/Neuro/visual_stimuli_preprocessed_data/visual_stimuli_ica_preprocessed.pt')

with open("/content/drive/MyDrive/Neuro/visual_stimuli_preprocessed_data/events_label.pkl", "rb") as f:
    y = pickle.load(f)

print('y: ', y.unique())

# with open("/content/drive/MyDrive/Neuro/visual_stimuli_preprocessed_data/visual_stimuli.pickle", "rb") as f:
#     visual_stimuli = pickle.load(f)

# print(visual_stimuli.keys())
# print('number of stimuli: ',len(visual_stimuli['visual_stimuli'][0]))

visual_stimuli = torch.load('/content/drive/MyDrive/Neuro/visual_stimuli_preprocessed_data/clip_vit32B_visual_stimuli.pt')

# one hot encoding
y = torch.nn.functional.one_hot(y.type(torch.int64), num_classes=92).type(torch.float32)

from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.dataset import random_split

BACH_SIZE = 64

dataset = TensorDataset(X, visual_stimuli, y)

train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42))

train_loader = DataLoader(train_dataset, batch_size=BACH_SIZE, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=BACH_SIZE, shuffle=False, num_workers=4)


y:  tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
        14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27.,
        28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38., 39., 40., 41.,
        42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55.,
        56., 57., 58., 59., 60., 61., 62., 63., 64., 65., 66., 67., 68., 69.,
        70., 71., 72., 73., 74., 75., 76., 77., 78., 79., 80., 81., 82., 83.,
        84., 85., 86., 87., 88., 89., 90., 91.])


In [None]:
net = LitClassifier().cuda()

from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

early_stopping = EarlyStopping('val_loss', patience=150, verbose=True, mode='min')
modelCheckPoint = ModelCheckpoint(monitor='val_loss', save_top_k=1, mode='min', dirpath='/content/drive/MyDrive/Neuro/visual_stimuli_preprocessed_data/architecture/preprocessed_shift/', filename='model-preprocessed_shift-{epoch:02d}-{val_loss:.2f}')

trainer = pl.Trainer(max_epochs=400, accelerator='auto', callbacks=[early_stopping, modelCheckPoint])

trainer.fit(net, train_loader, val_loader)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name      | Type                | Params
--------------------------------------------------
0 | model     | Classifier          | 1.2 M 
1 | loss      | CrossEntropyLoss    | 0     
2 | accuracy  | MulticlassAccuracy  | 0     
3 | f1        | MulticlassF1Score   | 0     
4 | precision | MulticlassPrecision | 0     
5 | recall    | MulticlassRecall    | 0     
--------------------------------------------------
1.2 M     Trainable params
0         Non-trainable params
1.2 M     Total params
4.948     Total estimated model params size (MB

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_loss improved. New best score: 4.589


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_loss improved by 0.043 >= min_delta = 0.0. New best score: 4.546


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 4.545


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [None]:
# %load_ext tensorboard
%reload_ext tensorboard
%tensorboard --logdir lightning_logs/