In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget

In [3]:
import sys
import os

sys.path.append('../tools')
import h5py
import pandas as pd
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torch
from tqdm import tqdm
import sklearn

import torchvision.transforms as T
import pytorch_lightning as pl
import pytorch_lightning.loggers as pl_loggers
import pytorch_lightning.callbacks as pl_callbacks

import data_utility
import times
import segmentation
import preprocess
import autoencoder
import visualizer
import kaggle_data_utility
import annotation_utility
import interactive_plot

In [5]:
data_dir = "../../../user_data/"
log_folder_root = '../../../user_data/logs/'
ckpt_folder_root = '../../../user_data/checkpoints/'

In [6]:
raw_annotations = pd.read_csv(data_dir + 'full_updated_anns_annotTbl_cleaned.csv')
ids = list(np.unique(raw_annotations[raw_annotations['descriptions'].notnull()]['HUP_ID']))
# ids = list(np.unique(raw_annotations['HUP_ID']))
ids

['HUP047',
 'HUP084',
 'HUP096',
 'HUP109',
 'HUP121',
 'HUP129',
 'HUP131',
 'HUP137',
 'HUP147',
 'HUP153',
 'HUP156',
 'HUP159',
 'HUP182',
 'HUP197',
 'HUP199',
 'HUP205',
 'RNS026',
 'RNS029']

In [8]:
data_import = data_utility.read_files(path=data_dir+'rns_data', path_data=data_dir+'rns_raw_cache', patientIDs=ids,
                                      verbose=True)  # Import data with annotation

100%|██████████| 18/18 [00:28<00:00,  1.60s/it]


In [9]:
annotations = annotation_utility.read_annotation(annotation_path=data_dir +'full_updated_anns_annotTbl_cleaned.csv',
                                                 data=data_import, n_class=3)

In [10]:
np.random.seed(seed=42)
annot = annotations.annotations
annot_nonseizure = annot[annot['Class_Code'] == 0]
annot_seizure = annot[annot['Class_Code'] == 1]
# patient_list = list(np.unique(annot['Patient_ID']))
# patient_list = ['RNS026', 'HUP159', 'HUP129', 'HUP096', 'HUP182']
patient_list = ['HUP159']
clip_dict = {}
for p in patient_list:
    seizure_start_index = np.array([])
    seizure_end_index = np.array([])
    nonseizure_start_index = np.array([])
    nonseizure_end_index = np.array([])
    start_index = annot_seizure[annot_seizure['Patient_ID'] == p]['Episode_Start_Index']
    end_index = annot_seizure[annot_seizure['Patient_ID'] == p]['Episode_End_Index']
    annot_start_list = annot_seizure[annot_seizure['Patient_ID'] == p]['Annotation_Start_Index']
    annot_end_list = annot_seizure[annot_seizure['Patient_ID'] == p]['Annotation_End_Index']
    for i, slel in enumerate(zip(annot_start_list, annot_end_list)):
        sl = slel[0]
        el = slel[1]
        annot_array = np.vstack((sl, el))
        test = start_index.iloc[i]
        seizure_start_index = np.hstack((seizure_start_index, annot_array[0, :]))
        seizure_end_index = np.hstack((seizure_end_index, annot_array[1, :]))

        nonseizure_start_index = np.hstack((nonseizure_start_index, start_index.iloc[i]))
        nonseizure_end_index = np.hstack((nonseizure_end_index, annot_array[0, 0]))

        nonseizure_start_index = np.hstack((nonseizure_start_index, annot_array[1, -1]))
        nonseizure_end_index = np.hstack((nonseizure_end_index, end_index.iloc[i]))
        if annot_array.shape[1] > 1:
            test1 = annot_array[0, 1:]
            test2 = annot_array[1, :-1]
            nonseizure_start_index = np.hstack((nonseizure_start_index, annot_array[0, 1:]))
            nonseizure_end_index = np.hstack((nonseizure_end_index, annot_array[1, :-1]))

    nonseizure_valid = np.where(nonseizure_end_index - nonseizure_start_index > 500)
    seizure_valid = np.where(seizure_end_index - seizure_start_index > 500)

    nonseizure_ind_arr = np.vstack(
        (nonseizure_start_index[nonseizure_valid], nonseizure_end_index[nonseizure_valid])).astype(int)
    start_index = annot_nonseizure[annot_nonseizure['Patient_ID'] == p]['Episode_Start_Index']
    end_index = annot_nonseizure[annot_nonseizure['Patient_ID'] == p]['Episode_End_Index']

    print(np.vstack((seizure_start_index[seizure_valid], seizure_end_index[seizure_valid])).astype(int).shape)
    valid = np.where(end_index - start_index > 500)
    nonseizure_ind_arr_eps = np.vstack((start_index.iloc[valid], end_index.iloc[valid])).astype(int)

    if len(valid[0]) and len(seizure_valid[0]) > 0:
        nonseizure_clip_temp = np.hstack((nonseizure_ind_arr, nonseizure_ind_arr_eps))
        seizure_clip_temp = np.vstack((seizure_start_index[seizure_valid], seizure_end_index[seizure_valid])).astype(
            int)

        nonseizure_clip_label = np.zeros(nonseizure_clip_temp.shape[1]).astype(int)
        seizure_clip_label = np.ones(seizure_clip_temp.shape[1]).astype(int)

        seizure_clip = np.vstack((seizure_clip_temp, seizure_clip_label))
        non_seizure_clip = np.vstack((nonseizure_clip_temp, nonseizure_clip_label))

        combined_clip = np.hstack((seizure_clip, non_seizure_clip))

        shuffled_index = np.arange(combined_clip.shape[1])
        np.random.shuffle(shuffled_index)

        clip_dict[p] = combined_clip[:, shuffled_index]



(2, 99)


In [17]:
window_len = 1
stride = 1
concat_n = 2
for id in tqdm(clip_dict.keys()):
    data_import[id].set_window_parameter(window_length=window_len, window_displacement=stride)
    data_import[id].set_concatenation_parameter(concatenate_window_n=concat_n)
    window_indices, _ = data_import[id].get_windowed_data(clip_dict[id][0], clip_dict[id][1])
    import_label = np.array([])
    for i, ind in enumerate(window_indices):
        import_label = np.hstack((import_label, np.repeat(clip_dict[id][2][i], len(ind))))
    data_import[id].normalize_windowed_data()
    _, concatenated_data = data_import[id].get_concatenated_data(data_import[id].windowed_data, arrange='channel_stack')
    assert import_label.shape[0] == concatenated_data.shape[0]
    np.save(data_dir+'rns_test_cache/' + id + '.npy', {'data': concatenated_data, 'label': import_label})

100%|██████████| 1/1 [00:01<00:00,  1.07s/it]


In [18]:
from models.rns_dataloader import RNS_Downstream
from models.SwaV import SwaV

In [19]:
import torch
import torchvision
from torch import nn

from lightly.data import LightlyDataset, SwaVCollateFunction
from lightly.loss import SwaVLoss
from lightly.loss.memory_bank import MemoryBankModule
from lightly.models.modules import SwaVProjectionHead, SwaVPrototypes


In [20]:
def collate_fn(batch):
    info = list(zip(*batch))
    data = info[0]
    label = info[1]
    return torch.stack(data), torch.stack(label)

In [21]:
def get_data(file_names, split=0.7):
    file_name_temp = file_names[0]
    cache = np.load(data_dir+'rns_test_cache/' + file_name_temp, allow_pickle=True)
    temp_file = cache.item().get('data')

    train_data = np.empty((0, temp_file.shape[1], temp_file.shape[2]))
    train_label = np.array([])
    test_data = np.empty((0, temp_file.shape[1], temp_file.shape[2]))
    test_label = np.array([])

    for name in tqdm(file_names):
        cache = np.load(data_dir+'rns_test_cache/' + name, allow_pickle=True)
        data = cache.item().get('data')
        label = cache.item().get('label')
        split_n = int(data.shape[0] * (split))
        train_data = np.vstack((train_data, data[:split_n]))
        train_label = np.hstack((train_label, label[:split_n]))
        test_data = np.vstack((test_data, data[split_n:]))
        test_label = np.hstack((test_label, label[split_n:]))

    return train_data, train_label, test_data, test_label

In [22]:
data_list = os.listdir(data_dir+'rns_test_cache')

train_data, train_label, test_data, test_label = get_data(data_list, split=0.3)
# data, label,_,_ = get_data(data_list, split=1)
# train_data, test_data, train_label, test_label = sklearn.model_selection.train_test_split(data, label, test_size=0.8, random_state=42)

print(train_data.shape)
print(train_label.shape)
print(test_data.shape)
print(test_label.shape)

100%|██████████| 1/1 [00:00<00:00,  4.16it/s]

(3837, 249, 20)
(3837,)
(8953, 249, 20)
(8953,)





In [23]:
test_label.sum()

2446.0

In [29]:
from models.SupervisedDownstream import SupervisedDownstream

In [30]:
swav = SwaV().load_from_checkpoint(ckpt_folder_root + 'rns_swav_34/rns_swav_159-epoch=148-swav_loss=2.99675.ckpt')
model = SupervisedDownstream(swav.backbone)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

checkpoint_callback = pl_callbacks.ModelCheckpoint(monitor='val_loss',
                                                   filename='rns_swav_34_159_linear_eval-{epoch:02d}-{val_loss:.5f}', save_last=True, save_top_k=-1, dirpath=ckpt_folder_root + 'rns_swav_34_159_linear_eval')
csv_logger = pl_loggers.CSVLogger(log_folder_root, name='rns_swav_34_159_linear_eval')

trainer = pl.Trainer(logger=csv_logger, max_epochs=80, callbacks=[checkpoint_callback], accelerator='gpu', devices=1,precision=16)

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [31]:
train_dataset = RNS_Downstream(train_data, train_label, transform=True, astensor=True)
test_dataset = RNS_Downstream(test_data, test_label, transform=False, astensor=True)

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=256,
    collate_fn=collate_fn,
    shuffle=True,
    drop_last=True,
)

val_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=256,
    collate_fn=collate_fn,
    shuffle=False,
    drop_last=True,
)

trainer.fit(model, train_dataloader, val_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type       | Params
----------------------------------------
0 | backbone | Sequential | 21.3 M
1 | fc1      | Linear     | 262 K 
2 | fc2      | Linear     | 32.8 K
3 | fc3      | Linear     | 520   
4 | fc4      | Linear     | 18    
5 | softmax  | Softmax    | 0     
----------------------------------------
21.6 M    Trainable params
0         Non-trainable params
21.6 M    Total params
43.161    Total estimated model params size (MB)


data loaded
(3837, 249, 20)
(3837,)
data loaded
(8953, 249, 20)
(8953,)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [32]:
test_dataset = RNS_Downstream(test_data, test_label, transform=False, astensor=True)
val_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=128,
    collate_fn=collate_fn,
    shuffle=False,
    drop_last=True,
)

data loaded
(8953, 249, 20)
(8953,)


In [33]:
predictions = trainer.predict(model,val_dataloader,ckpt_path=ckpt_folder_root+'rns_swav_34_159_linear_eval/rns_swav_34_159_linear_eval-epoch=58-val_loss=0.00277.ckpt')

Restoring states from the checkpoint path at ../../../user_data/checkpoints/rns_swav_34_159_linear_eval/rns_swav_34_159_linear_eval-epoch=58-val_loss=0.00277.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from checkpoint at ../../../user_data/checkpoints/rns_swav_34_159_linear_eval/rns_swav_34_159_linear_eval-epoch=58-val_loss=0.00277.ckpt
  rank_zero_warn(


Predicting: 6it [00:00, ?it/s]

In [34]:
output_list = []
target_list = []
emb_list = []
m = nn.Softmax(dim=1)
for pred, y, emb in predictions:
    output_list.append(pred)
    target_list.append(y)
    emb_list.append(emb)

In [35]:
pred

tensor([[  1.4619,  -2.8906],
        [  1.4932,  -2.5801],
        [  1.2217,  -2.1660],
        [  1.1758,  -2.1367],
        [  0.7915,  -1.3877],
        [  1.6162,  -2.7012],
        [  1.6152,  -2.9844],
        [  0.5127,  -1.2471],
        [ -1.4141,  -0.4885],
        [ -0.3228,  -0.2288],
        [ -0.8813,  -0.2332],
        [ -3.3652,  -0.7695],
        [-12.9766,  -0.7441],
        [-10.9531,  -0.5576],
        [ -6.2930,  -0.6133],
        [ -8.7109,  -0.1185],
        [ -9.1094,  -0.1694],
        [ -3.2402,  -0.4854],
        [ -9.3359,  -0.8213],
        [ -8.3438,  -0.6147],
        [-16.6250,  -0.9712],
        [-16.0781,  -1.0488],
        [-10.7188,  -1.2754],
        [-13.4375,  -1.0547],
        [-12.1016,  -1.0586],
        [-13.5781,  -0.9375],
        [ -9.8828,  -0.9805],
        [-12.2578,  -1.0479],
        [ -6.0547,  -1.1758],
        [ -5.0195,  -1.1279],
        [ -4.3906,  -1.2793],
        [ -4.1992,  -1.0205],
        [ -3.1914,   0.0316],
        [ 

In [36]:
pred_raw = torch.vstack(output_list)
target = torch.vstack(target_list)
emb = torch.vstack(emb_list)
out = torch.argmax(pred_raw, dim=1)

In [37]:
torch.sum(target)

tensor(2446)

In [38]:
sklearn.metrics.accuracy_score(torch.argmax(pred_raw, dim=1), target)

0.8821624036635765

In [39]:
clf_report = sklearn.metrics.classification_report(torch.argmax(pred_raw, dim=1), target, digits=6)

print(f"Classification Report : \n{clf_report}")

Classification Report : 
              precision    recall  f1-score   support

           0   0.915783  0.921590  0.918677      6466
           1   0.792723  0.779654  0.786134      2487

    accuracy                       0.882162      8953
   macro avg   0.854253  0.850622  0.852406      8953
weighted avg   0.881599  0.882162  0.881859      8953



In [41]:
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

In [42]:
len(emb)

8953

In [43]:
pca_comp_n = 30
batch_size = 32

pca = PCA(n_components=pca_comp_n, copy=True).fit(emb)
p = pca.transform(emb)

# ind = np.random.choice(len(emb), 10000)

tsne = TSNE(n_components=2, verbose=1, perplexity=75, random_state=142, init='pca')
z = tsne.fit_transform(emb)
interictal_inds = np.where(target == 0)[0]
ictal_inds = np.where(target == 1)[0]



[t-SNE] Computing 226 nearest neighbors...
[t-SNE] Indexed 8953 samples in 0.003s...
[t-SNE] Computed neighbors for 8953 samples in 1.903s...
[t-SNE] Computed conditional probabilities for sample 1000 / 8953
[t-SNE] Computed conditional probabilities for sample 2000 / 8953
[t-SNE] Computed conditional probabilities for sample 3000 / 8953
[t-SNE] Computed conditional probabilities for sample 4000 / 8953
[t-SNE] Computed conditional probabilities for sample 5000 / 8953
[t-SNE] Computed conditional probabilities for sample 6000 / 8953
[t-SNE] Computed conditional probabilities for sample 7000 / 8953
[t-SNE] Computed conditional probabilities for sample 8000 / 8953
[t-SNE] Computed conditional probabilities for sample 8953 / 8953
[t-SNE] Mean sigma: 3.155000




[t-SNE] KL divergence after 250 iterations with early exaggeration: 72.321732
[t-SNE] KL divergence after 1000 iterations: 1.255818


In [44]:
spc = z

plt.figure(figsize=(10, 8))
# plt.scatter(spc[interictal_inds,0],spc[interictal_inds,1],c='gold',label= 'interictal')
plt.scatter(spc[ictal_inds, 0], spc[ictal_inds, 1], c='royalblue', label='ictal')
plt.title('Swav Embedding t-SNE')
plt.xlabel('comp 1')
plt.ylabel("comp 2")
plt.legend()
plt.xlim(-67, 74)
plt.ylim(-67, 75)
plt.grid()
plt.show()

<IPython.core.display.Javascript object>

In [45]:
# dt = np.vstack((z[:,0], z[:,1])).T
interactive_plot.interactive_plot(z, ['RNS026', 'HUP159', 'HUP129', 'HUP096'], data_import, color_override=target)

ValueError: all the input array dimensions for the concatenation axis must match exactly, but along dimension 1, the array at index 0 has size 1 and the array at index 1 has size 2

In [None]:
interactive_plot.interactive_plot(z, ['HUP159'], data_import, color_override=target)

In [48]:
from sklearn.metrics import RocCurveDisplay

RocCurveDisplay.from_predictions(
    target,
    out,
    color="darkorange",
)
plt.plot([0, 1], [0, 1], "k--", label="chance level (AUC = 0.5)")
plt.axis("square")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("One-vs-Rest ROC curves:\nVirginica vs (Setosa & Versicolor)")
plt.legend()
plt.show()

<IPython.core.display.Javascript object>

In [None]:
output[:, 1]

In [None]:

output = torch.argmax(output, dim=1)
output = output.detach().cpu().numpy()
target = target.squeeze().detach().cpu().numpy()

In [None]:
import sklearn

clf_report = sklearn.metrics.classification_report(output, target, digits=6)

print(f"Classification Report : \n{clf_report}")

In [None]:
for batch, label in tqdm(val_dataloader):
    batch = batch.to(device)
    label = label.to(device)
    label = F.one_hot(label).squeeze()
    outputs = model(batch)
    print(batch)
    loss = sigmoid_focal_loss(pred.float(), label.float(), alpha=0.5, gamma=8, reduction='mean')
    print(loss)
    break

In [None]:
# import copy
# import torch
# import torchvision
# from torch import nn
#
# from lightly.data import DINOCollateFunction, LightlyDataset
# from lightly.loss import DINOLoss
# from lightly.models.modules import DINOProjectionHead
# from lightly.models.utils import deactivate_requires_grad, update_momentum
# from lightly.utils.scheduler import cosine_schedule
#
#
# class DINO(torch.nn.Module):
#     def __init__(self, backbone, input_dim):
#         super().__init__()
#         self.student_backbone = backbone
#         self.student_head = DINOProjectionHead(
#             input_dim, 512, 64, 2048, freeze_last_layer=1
#         )
#         self.teacher_backbone = copy.deepcopy(backbone)
#         self.teacher_head = DINOProjectionHead(input_dim, 512, 64, 2048)
#         deactivate_requires_grad(self.teacher_backbone)
#         deactivate_requires_grad(self.teacher_head)
#
#     def forward(self, x):
#         y = self.student_backbone(x).flatten(start_dim=1)
#         z = self.student_head(y)
#         return z
#
#     def forward_teacher(self, x):
#         y = self.teacher_backbone(x).flatten(start_dim=1)
#         z = self.teacher_head(y)
#         return z
#
#
# resnet = torchvision.models.resnet18()
# backbone = nn.Sequential(*list(resnet.children())[:-1])
# input_dim = 512
# # instead of a resnet you can also use a vision transformer backbone as in the
# # original paper (you might have to reduce the batch size in this case):
# # backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vits16', pretrained=False)
# # input_dim = backbone.embed_dim
#
# model = DINO(backbone, input_dim)
#
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model.to(device)
#
# # # we ignore object detection annotations by setting target_transform to return 0
# # pascal_voc = torchvision.datasets.VOCDetection(
# #     "datasets/pascal_voc", download=True, target_transform=lambda t: 0
# # )
# # dataset = LightlyDataset.from_torch_dataset(pascal_voc)
# # # or create a dataset from a folder containing images or videos:
# # # dataset = LightlyDataset("path/to/folder")
#
# collate_fn = DINOCollateFunction(solarization_prob = 0, hf_prob = 0,vf_prob = 0,rr_prob=0,cj_prob=0,random_gray_scale=0)
#
# dataloader = torch.utils.data.DataLoader(
#     train_set,
#     batch_size=64,
#     collate_fn=collate_fn,
#     shuffle=True,
#     drop_last=True,
#     num_workers=1,
# )
#
# criterion = DINOLoss(
#     output_dim=2048,
#     warmup_teacher_temp_epochs=5,
# )
# # move loss to correct device because it also contains parameters
# criterion = criterion.to(device)
#
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
#
# epochs = 10
#
# print("Starting Training")
# for epoch in range(epochs):
#     total_loss = 0
#     momentum_val = cosine_schedule(epoch, epochs, 0.996, 1)
#     for views, _, _ in tqdm(dataloader):
#         update_momentum(model.student_backbone, model.teacher_backbone, m=momentum_val)
#         update_momentum(model.student_head, model.teacher_head, m=momentum_val)
#         views = [view.to(device) for view in views]
#         global_views = views[:2]
#         teacher_out = [model.forward_teacher(view) for view in global_views]
#         student_out = [model.forward(view) for view in views]
#         loss = criterion(teacher_out, student_out, epoch=epoch)
#         total_loss += loss.detach()
#         loss.backward()
#         # We only cancel gradients of student head.
#         model.student_head.cancel_last_layer_gradients(current_epoch=epoch)
#         optimizer.step()
#         optimizer.zero_grad()
#
#     avg_loss = total_loss / len(dataloader)
#     print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")

In [None]:
augmentation = T.Compose([
    T.ToPILImage(),
    T.Resize((256, 512), interpolation=T.InterpolationMode.NEAREST),
    T.RandomApply([T.ColorJitter()], p=0.5),
    T.RandomApply([T.GaussianBlur(kernel_size=(3, 3))], p=0.5),
    T.RandomInvert(p=0.2),
    T.RandomPosterize(4, p=0.2),
])

data = ictal_data_X[0]

channel_index = np.arange(data.shape[0])
np.random.shuffle(channel_index)
data = data[channel_index]
data = torch.from_numpy(data).clone()
data = data.repeat(3, 1, 1)
data = augmentation(data)
data

In [None]:
channel_index

In [None]:
data[channel_index]

In [None]:
data

In [None]:
#
# print("Starting Training")
# for epoch in range(50):
#     total_loss = 0
#     i = 0
#     for batch, label in tqdm(dataloader):
#         batch = batch.to(device)
#         # print(type(batch))
#         label = label.to(device)
#         label = F.one_hot(label).squeeze()
#         outputs = model(batch)
#         loss = sigmoid_focal_loss(outputs.float(),label.float(), alpha = 0.25, gamma = 7,reduction = 'mean')
#         total_loss += loss.detach()
#         loss.backward()
#         optimizer.step()
#         optimizer.zero_grad()
#
#     avg_loss = total_loss / len(dataloader)
#     torch.save({
#             'epoch': epoch,
#             'model_state_dict': model.state_dict(),
#             'optimizer_state_dict': optimizer.state_dict(),
#             'loss': avg_loss,
#             }, 'ckpt/checkpoint'+str(epoch)+'.pth')
#
#     print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")