In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget

In [2]:
import sys
import os

sys.path.append('../tools')
import h5py
import pandas as pd
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torch
from tqdm import tqdm
import sklearn

import torchvision.transforms as T
import pytorch_lightning as pl
import pytorch_lightning.loggers as pl_loggers
import pytorch_lightning.callbacks as pl_callbacks
from models.rns_dataloader import get_data

import data_utility
import annotation_utility
import interactive_plot

In [3]:
data_dir = "../../../user_data/"
log_folder_root = '../../../user_data/logs/'
ckpt_folder_root = '../../../user_data/checkpoints/'

In [4]:
raw_annotations = pd.read_csv(data_dir + 'full_updated_anns_annotTbl_cleaned.csv')
ids = list(np.unique(raw_annotations[raw_annotations['descriptions'].notnull()]['HUP_ID']))
# ids = list(np.unique(raw_annotations['HUP_ID']))
ids

['HUP047',
 'HUP084',
 'HUP096',
 'HUP109',
 'HUP121',
 'HUP129',
 'HUP131',
 'HUP137',
 'HUP147',
 'HUP153',
 'HUP156',
 'HUP159',
 'HUP182',
 'HUP197',
 'HUP199',
 'HUP205',
 'RNS026',
 'RNS029']

In [5]:
data_import = data_utility.read_files(path=data_dir+'rns_data', path_data=data_dir+'rns_raw_cache', patientIDs=ids,
                                      verbose=True)  # Import data with annotation

100%|██████████| 18/18 [00:12<00:00,  1.40it/s]


In [6]:
annotations = annotation_utility.read_annotation(annotation_path=data_dir +'full_updated_anns_annotTbl_cleaned.csv',
                                                 data=data_import, n_class=3)

In [None]:
annotations

In [7]:
np.random.seed(seed=42)
annot = annotations.annotations
annot_nonseizure = annot[annot['Class_Code'] == 0]
annot_seizure = annot[annot['Class_Code'] == 1]
patient_list = list(np.unique(annot['Patient_ID']))
# patient_list = ['RNS026', 'HUP159', 'HUP129', 'HUP096', 'HUP182']
# patient_list = ['HUP159']
clip_dict = {}
for p in patient_list:
    seizure_start_index = np.array([])
    seizure_end_index = np.array([])
    nonseizure_start_index = np.array([])
    nonseizure_end_index = np.array([])
    start_index = annot_seizure[annot_seizure['Patient_ID'] == p]['Episode_Start_Index']
    end_index = annot_seizure[annot_seizure['Patient_ID'] == p]['Episode_End_Index']
    annot_start_list = annot_seizure[annot_seizure['Patient_ID'] == p]['Annotation_Start_Index']
    annot_end_list = annot_seizure[annot_seizure['Patient_ID'] == p]['Annotation_End_Index']

    episode_index = annot_seizure[annot_seizure['Patient_ID'] == p]['Episode_Start_Index']


    for i, slel in enumerate(zip(annot_start_list, annot_end_list)):
        sl = slel[0]
        el = slel[1]
        annot_array = np.vstack((sl, el))
        test = start_index.iloc[i]
        seizure_start_index = np.hstack((seizure_start_index, annot_array[0, :]))
        seizure_end_index = np.hstack((seizure_end_index, annot_array[1, :]))

        nonseizure_start_index = np.hstack((nonseizure_start_index, start_index.iloc[i]))
        nonseizure_end_index = np.hstack((nonseizure_end_index, annot_array[0, 0]))

        nonseizure_start_index = np.hstack((nonseizure_start_index, annot_array[1, -1]))
        nonseizure_end_index = np.hstack((nonseizure_end_index, end_index.iloc[i]))
        if annot_array.shape[1] > 1:
            test1 = annot_array[0, 1:]
            test2 = annot_array[1, :-1]
            nonseizure_start_index = np.hstack((nonseizure_start_index, annot_array[0, 1:]))
            nonseizure_end_index = np.hstack((nonseizure_end_index, annot_array[1, :-1]))

    nonseizure_valid = np.where(nonseizure_end_index - nonseizure_start_index > 500)
    seizure_valid = np.where(seizure_end_index - seizure_start_index > 500)

    nonseizure_ind_arr = np.vstack(
        (nonseizure_start_index[nonseizure_valid], nonseizure_end_index[nonseizure_valid])).astype(int)
    start_index = annot_nonseizure[annot_nonseizure['Patient_ID'] == p]['Episode_Start_Index']
    end_index = annot_nonseizure[annot_nonseizure['Patient_ID'] == p]['Episode_End_Index']

    print(np.vstack((seizure_start_index[seizure_valid], seizure_end_index[seizure_valid])).astype(int).shape)
    valid = np.where(end_index - start_index > 500)
    nonseizure_ind_arr_eps = np.vstack((start_index.iloc[valid], end_index.iloc[valid])).astype(int)

    if len(valid[0]) and len(seizure_valid[0]) > 0:
        nonseizure_clip_temp = np.hstack((nonseizure_ind_arr, nonseizure_ind_arr_eps))
        seizure_clip_temp = np.vstack((seizure_start_index[seizure_valid], seizure_end_index[seizure_valid])).astype(
            int)

        nonseizure_clip_label = np.zeros(nonseizure_clip_temp.shape[1]).astype(int)
        seizure_clip_label = np.ones(seizure_clip_temp.shape[1]).astype(int)

        seizure_clip = np.vstack((seizure_clip_temp, seizure_clip_label))
        non_seizure_clip = np.vstack((nonseizure_clip_temp, nonseizure_clip_label))

        combined_clip = np.hstack((seizure_clip, non_seizure_clip))

        shuffled_index = np.arange(combined_clip.shape[1])
        np.random.shuffle(shuffled_index)

        clip_dict[p] = combined_clip[:, shuffled_index]



In [8]:
# window_len = 1
# stride = 1
# concat_n = 2
# for id in tqdm(clip_dict.keys()):
#     data_import[id].set_window_parameter(window_length=window_len, window_displacement=stride)
#     data_import[id].set_concatenation_parameter(concatenate_window_n=concat_n)
#     window_indices, _ = data_import[id].get_windowed_data(clip_dict[id][0], clip_dict[id][1])
#     import_label = np.array([])
#     for i, ind in enumerate(window_indices):
#         import_label = np.hstack((import_label, np.repeat(clip_dict[id][2][i], len(ind))))
#     data_import[id].normalize_windowed_data()
#     _, concatenated_data = data_import[id].get_concatenated_data(data_import[id].windowed_data, arrange='channel_stack')
#     assert import_label.shape[0] == concatenated_data.shape[0]
#     np.save(data_dir+'rns_test_cache/' + id + '.npy', {'data': concatenated_data, 'label': import_label})

In [9]:
from models.rns_dataloader import RNS_Downstream
from models.SwaV import SwaV

In [10]:
import torch
import torchvision
from torch import nn

from lightly.data import LightlyDataset, SwaVCollateFunction
from lightly.loss import SwaVLoss
from lightly.loss.memory_bank import MemoryBankModule
from lightly.models.modules import SwaVProjectionHead, SwaVPrototypes


In [11]:
def collate_fn(batch):
    info = list(zip(*batch))
    data = info[0]
    label = info[1]
    return torch.stack(data), torch.stack(label)

In [13]:
data_list = os.listdir(data_dir+'rns_test_cache')

# data_list = ['HUP182.npy',   'HUP129.npy',   'HUP109.npy', 'HUP156.npy', 'HUP096.npy', 'RNS026.npy',  'HUP159.npy']
# data_list = ['RNS026.npy', 'HUP159.npy', 'HUP129.npy', 'HUP096.npy', 'HUP182.npy']
train_data, train_label, test_data, test_label = get_data(data_list, split=0.7)
# data, label,_,_ = get_data(data_list, split=1)
# train_data, test_data, train_label, test_label = sklearn.model_selection.train_test_split(data, label, test_size=0.8, random_state=42)

print(train_data.shape)
print(train_label.shape)
print(test_data.shape)
print(test_label.shape)

100%|██████████| 15/15 [00:11<00:00,  1.27it/s]

(73940, 249, 20)
(73940,)
(31695, 249, 20)
(31695,)





In [14]:
test_label.sum()

8548.0

In [15]:
from models.SupervisedDownstream import SupervisedDownstream

In [25]:
# ckpt = torch.load(ckpt_folder_root+ 'rns_checkpoints/rns_swav_ckpt_5_patients/checkpoint31.pth')
# resnet = torchvision.models.resnet50()
# backbone = nn.Sequential(*list(resnet.children())[:-1])
# swav = SwaV()
# # swav.load_state_dict(ckpt['model_state_dict'])
# model = SupervisedDownstream(swav.backbone)
#

swav = SwaV().load_from_checkpoint(ckpt_folder_root + 'rns_swav_50_12/rns_swav-epoch=32-swav_loss=2.25595.ckpt')
model = SupervisedDownstream(swav.backbone)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

checkpoint_callback = pl_callbacks.ModelCheckpoint(monitor='val_loss',
                                                   filename='rns_swav_50_all_linear_eval-{epoch:02d}-{val_loss:.5f}', save_last=True, save_top_k=-1, dirpath=ckpt_folder_root + 'rns_swav_50_all_linear_eval')
csv_logger = pl_loggers.CSVLogger(log_folder_root, name='rns_swav_50_all_linear_eval')

trainer = pl.Trainer(logger=csv_logger, max_epochs=80, callbacks=[checkpoint_callback], accelerator='gpu', devices=1,precision=16, check_val_every_n_epoch=5)

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
train_dataset = RNS_Downstream(train_data, train_label, transform=True, astensor=True)
test_dataset = RNS_Downstream(test_data, test_label, transform=False, astensor=True)

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=256,
    collate_fn=collate_fn,
    shuffle=True,
    drop_last=True,
    # num_workers = 2
)

val_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=256,
    collate_fn=collate_fn,
    shuffle=False,
    drop_last=True,
)

trainer.fit(model, train_dataloader, val_dataloader)

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type       | Params
----------------------------------------
0 | backbone | Sequential | 23.5 M
1 | fc1      | Linear     | 1.0 M 
2 | dp       | Dropout1d  | 0     
3 | fc2      | Linear     | 32.8 K
4 | fc3      | Linear     | 520   
5 | fc4      | Linear     | 18    
6 | softmax  | Softmax    | 0     
----------------------------------------
24.6 M    Trainable params
0         Non-trainable params
24.6 M    Total params
49.181    Total estimated model params size (MB)


data loaded
(73940, 249, 20)
(73940,)
data loaded
(31695, 249, 20)
(31695,)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [18]:
test_dataset = RNS_Downstream(test_data, test_label, transform=False, astensor=True)
val_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=256,
    collate_fn=collate_fn,
    shuffle=False,
    drop_last=True,
)

data loaded
(31695, 249, 20)
(31695,)


In [19]:
predictions = trainer.predict(model,val_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(


Predicting: 88it [00:00, ?it/s]

In [20]:
output_list = []
target_list = []
emb_list = []
m = nn.Softmax(dim=1)
for pred, y, emb in predictions:
    output_list.append(pred)
    target_list.append(y)
    emb_list.append(emb)

TypeError: 'NoneType' object is not iterable

In [None]:
len(target_list)

In [None]:
pred_raw = torch.vstack(output_list)
target = torch.vstack(target_list)
emb = torch.vstack(emb_list)
out = torch.argmax(pred_raw, dim=1)

In [None]:
m(pred_raw.float())[:,0]

In [None]:
emb.shape

In [None]:
target.shape

In [None]:
sklearn.metrics.accuracy_score(torch.argmax(pred_raw, dim=1), target)

In [None]:
clf_report = sklearn.metrics.classification_report(torch.argmax(pred_raw, dim=1), target, digits=6)

print(f"Classification Report : \n{clf_report}")

In [None]:
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

In [None]:
len(emb)

In [None]:
pca_comp_n = 30
batch_size = 32

pca = PCA(n_components=pca_comp_n, copy=True).fit(emb)
p = pca.transform(emb)

# ind = np.random.choice(len(emb), 10000)
#
tsne = TSNE(n_components=2, verbose=1, perplexity=75, random_state=142, init='pca')
z = tsne.fit_transform(emb)
interictal_inds = np.where(target == 0)[0]
ictal_inds = np.where(target == 1)[0]

In [None]:
pca.explained_variance_ratio_

In [None]:
len(interictal_inds)

In [None]:
len(ictal_inds)

In [None]:
# spc = p
#
# fig = plt.figure(figsize=(10, 8))
# ax = fig.add_subplot(projection='3d')
# ax.scatter(spc[interictal_inds,0],spc[interictal_inds,1],spc[interictal_inds,2],c='gold',label= 'interictal')
# ax.scatter(spc[ictal_inds, 0], spc[ictal_inds, 1],spc[ictal_inds, 2], c='royalblue', label='ictal')
# # plt.title('Swav Embedding t-SNE')
# ax.set_xlabel('comp 1')
# ax.set_ylabel("comp 2")
# ax.set_zlabel("comp 2")
# plt.legend()
# # plt.xlim(-67, 74)
# # plt.ylim(-67, 75)
# plt.grid()
# # plt.show()

In [None]:
spc = z
plt.figure(figsize=(10, 8))
plt.scatter(spc[interictal_inds,0],spc[interictal_inds,1],c='gold',label= 'interictal')
plt.scatter(spc[ictal_inds, 0], spc[ictal_inds, 1], c='royalblue', label='ictal',s=0.5)
# plt.title('Swav Embedding t-SNE')
plt.xlabel('comp 1')
plt.ylabel("comp 2")
plt.legend()
# plt.xlim(-67, 74)
# plt.ylim(-67, 75)
plt.grid()
plt.show()

In [None]:
# dt = np.vstack((z[:,0], z[:,1])).T
interactive_plot.interactive_plot(z, ['RNS026', 'HUP159', 'HUP129', 'HUP096'], data_import, color_override=target)

In [None]:
interactive_plot.interactive_plot(z, ['HUP159'], data_import, color_override=target)

In [None]:
from sklearn.metrics import RocCurveDisplay

RocCurveDisplay.from_predictions(
    target,
    m(pred_raw.float())[:,1],
    color="darkorange",
)
plt.plot([0, 1], [0, 1], "k--", label="chance level (AUC = 0.5)")
plt.axis("square")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("One-vs-Rest ROC curves:\nVirginica vs (Setosa & Versicolor)")
plt.legend()
plt.show()

In [None]:
output[:, 1]

In [None]:

output = torch.argmax(output, dim=1)
output = output.detach().cpu().numpy()
target = target.squeeze().detach().cpu().numpy()

In [None]:
import sklearn

clf_report = sklearn.metrics.classification_report(output, target, digits=6)

print(f"Classification Report : \n{clf_report}")

In [None]:
for batch, label in tqdm(val_dataloader):
    batch = batch.to(device)
    label = label.to(device)
    label = F.one_hot(label).squeeze()
    outputs = model(batch)
    print(batch)
    loss = sigmoid_focal_loss(pred.float(), label.float(), alpha=0.5, gamma=8, reduction='mean')
    print(loss)
    break

In [None]:
# import copy
# import torch
# import torchvision
# from torch import nn
#
# from lightly.data import DINOCollateFunction, LightlyDataset
# from lightly.loss import DINOLoss
# from lightly.models.modules import DINOProjectionHead
# from lightly.models.utils import deactivate_requires_grad, update_momentum
# from lightly.utils.scheduler import cosine_schedule
#
#
# class DINO(torch.nn.Module):
#     def __init__(self, backbone, input_dim):
#         super().__init__()
#         self.student_backbone = backbone
#         self.student_head = DINOProjectionHead(
#             input_dim, 512, 64, 2048, freeze_last_layer=1
#         )
#         self.teacher_backbone = copy.deepcopy(backbone)
#         self.teacher_head = DINOProjectionHead(input_dim, 512, 64, 2048)
#         deactivate_requires_grad(self.teacher_backbone)
#         deactivate_requires_grad(self.teacher_head)
#
#     def forward(self, x):
#         y = self.student_backbone(x).flatten(start_dim=1)
#         z = self.student_head(y)
#         return z
#
#     def forward_teacher(self, x):
#         y = self.teacher_backbone(x).flatten(start_dim=1)
#         z = self.teacher_head(y)
#         return z
#
#
# resnet = torchvision.models.resnet18()
# backbone = nn.Sequential(*list(resnet.children())[:-1])
# input_dim = 512
# # instead of a resnet you can also use a vision transformer backbone as in the
# # original paper (you might have to reduce the batch size in this case):
# # backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vits16', pretrained=False)
# # input_dim = backbone.embed_dim
#
# model = DINO(backbone, input_dim)
#
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model.to(device)
#
# # # we ignore object detection annotations by setting target_transform to return 0
# # pascal_voc = torchvision.datasets.VOCDetection(
# #     "datasets/pascal_voc", download=True, target_transform=lambda t: 0
# # )
# # dataset = LightlyDataset.from_torch_dataset(pascal_voc)
# # # or create a dataset from a folder containing images or videos:
# # # dataset = LightlyDataset("path/to/folder")
#
# collate_fn = DINOCollateFunction(solarization_prob = 0, hf_prob = 0,vf_prob = 0,rr_prob=0,cj_prob=0,random_gray_scale=0)
#
# dataloader = torch.utils.data.DataLoader(
#     train_set,
#     batch_size=64,
#     collate_fn=collate_fn,
#     shuffle=True,
#     drop_last=True,
#     num_workers=1,
# )
#
# criterion = DINOLoss(
#     output_dim=2048,
#     warmup_teacher_temp_epochs=5,
# )
# # move loss to correct device because it also contains parameters
# criterion = criterion.to(device)
#
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
#
# epochs = 10
#
# print("Starting Training")
# for epoch in range(epochs):
#     total_loss = 0
#     momentum_val = cosine_schedule(epoch, epochs, 0.996, 1)
#     for views, _, _ in tqdm(dataloader):
#         update_momentum(model.student_backbone, model.teacher_backbone, m=momentum_val)
#         update_momentum(model.student_head, model.teacher_head, m=momentum_val)
#         views = [view.to(device) for view in views]
#         global_views = views[:2]
#         teacher_out = [model.forward_teacher(view) for view in global_views]
#         student_out = [model.forward(view) for view in views]
#         loss = criterion(teacher_out, student_out, epoch=epoch)
#         total_loss += loss.detach()
#         loss.backward()
#         # We only cancel gradients of student head.
#         model.student_head.cancel_last_layer_gradients(current_epoch=epoch)
#         optimizer.step()
#         optimizer.zero_grad()
#
#     avg_loss = total_loss / len(dataloader)
#     print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")

In [None]:
augmentation = T.Compose([
    T.ToPILImage(),
    T.Resize((256, 512), interpolation=T.InterpolationMode.NEAREST),
    T.RandomApply([T.ColorJitter()], p=0.5),
    T.RandomApply([T.GaussianBlur(kernel_size=(3, 3))], p=0.5),
    T.RandomInvert(p=0.2),
    T.RandomPosterize(4, p=0.2),
])

data = ictal_data_X[0]

channel_index = np.arange(data.shape[0])
np.random.shuffle(channel_index)
data = data[channel_index]
data = torch.from_numpy(data).clone()
data = data.repeat(3, 1, 1)
data = augmentation(data)
data

In [None]:
channel_index

In [None]:
data[channel_index]

In [None]:
data

In [None]:
#
# print("Starting Training")
# for epoch in range(50):
#     total_loss = 0
#     i = 0
#     for batch, label in tqdm(dataloader):
#         batch = batch.to(device)
#         # print(type(batch))
#         label = label.to(device)
#         label = F.one_hot(label).squeeze()
#         outputs = model(batch)
#         loss = sigmoid_focal_loss(outputs.float(),label.float(), alpha = 0.25, gamma = 7,reduction = 'mean')
#         total_loss += loss.detach()
#         loss.backward()
#         optimizer.step()
#         optimizer.zero_grad()
#
#     avg_loss = total_loss / len(dataloader)
#     torch.save({
#             'epoch': epoch,
#             'model_state_dict': model.state_dict(),
#             'optimizer_state_dict': optimizer.state_dict(),
#             'loss': avg_loss,
#             }, 'ckpt/checkpoint'+str(epoch)+'.pth')
#
#     print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")