In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget

In [2]:
import sys
import os

sys.path.append('tools')
import h5py
import pandas as pd
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torch
from tqdm import tqdm
import sklearn

import torchvision.transforms as T
import pytorch_lightning as pl
import pytorch_lightning.loggers as pl_loggers
import pytorch_lightning.callbacks as pl_callbacks

import data_utility
import times
import segmentation
import preprocess
import autoencoder
import visualizer
import kaggle_data_utility
import annotation_utility
import interactive_plot

In [3]:
raw_annotations = pd.read_csv('full_updated_anns_annotTbl_cleaned.csv')
ids = list(np.unique(raw_annotations[raw_annotations['descriptions'].notnull()]['HUP_ID']))
# ids = list(np.unique(raw_annotations['HUP_ID']))
ids

['HUP047',
 'HUP084',
 'HUP096',
 'HUP109',
 'HUP121',
 'HUP129',
 'HUP131',
 'HUP137',
 'HUP147',
 'HUP153',
 'HUP156',
 'HUP159',
 'HUP182',
 'HUP197',
 'HUP199',
 'HUP205',
 'RNS026',
 'RNS029']

In [4]:
data_import = data_utility.read_files(path='data/rns_data', path_data='rns_raw_cache', patientIDs=ids,
                                      verbose=True)  # Import data with annotation

100%|██████████| 18/18 [00:27<00:00,  1.55s/it]


In [5]:
annotations = annotation_utility.read_annotation(annotation_path='full_updated_anns_annotTbl_cleaned.csv',
                                                 data=data_import, n_class=3)

In [6]:
np.random.seed(seed=42)
annot = annotations.annotations
annot_nonseizure = annot[annot['Class_Code'] == 0]
annot_seizure = annot[annot['Class_Code'] == 1]
# patient_list = list(np.unique(annot['Patient_ID']))
patient_list = ['RNS026', 'HUP159', 'HUP129', 'HUP096', 'HUP182']
clip_dict = {}
for p in patient_list:
    seizure_start_index = np.array([])
    seizure_end_index = np.array([])
    nonseizure_start_index = np.array([])
    nonseizure_end_index = np.array([])
    start_index = annot_seizure[annot_seizure['Patient_ID'] == p]['Episode_Start_Index']
    end_index = annot_seizure[annot_seizure['Patient_ID'] == p]['Episode_End_Index']
    annot_start_list = annot_seizure[annot_seizure['Patient_ID'] == p]['Annotation_Start_Index']
    annot_end_list = annot_seizure[annot_seizure['Patient_ID'] == p]['Annotation_End_Index']
    for i, slel in enumerate(zip(annot_start_list, annot_end_list)):
        sl = slel[0]
        el = slel[1]
        annot_array = np.vstack((sl, el))
        test = start_index.iloc[i]
        seizure_start_index = np.hstack((seizure_start_index, annot_array[0, :]))
        seizure_end_index = np.hstack((seizure_end_index, annot_array[1, :]))

        nonseizure_start_index = np.hstack((nonseizure_start_index, start_index.iloc[i]))
        nonseizure_end_index = np.hstack((nonseizure_end_index, annot_array[0, 0]))

        nonseizure_start_index = np.hstack((nonseizure_start_index, annot_array[1, -1]))
        nonseizure_end_index = np.hstack((nonseizure_end_index, end_index.iloc[i]))
        if annot_array.shape[1] > 1:
            test1 = annot_array[0, 1:]
            test2 = annot_array[1, :-1]
            nonseizure_start_index = np.hstack((nonseizure_start_index, annot_array[0, 1:]))
            nonseizure_end_index = np.hstack((nonseizure_end_index, annot_array[1, :-1]))

    nonseizure_valid = np.where(nonseizure_end_index - nonseizure_start_index > 500)
    seizure_valid = np.where(seizure_end_index - seizure_start_index > 500)

    nonseizure_ind_arr = np.vstack(
        (nonseizure_start_index[nonseizure_valid], nonseizure_end_index[nonseizure_valid])).astype(int)
    start_index = annot_nonseizure[annot_nonseizure['Patient_ID'] == p]['Episode_Start_Index']
    end_index = annot_nonseizure[annot_nonseizure['Patient_ID'] == p]['Episode_End_Index']

    print(np.vstack((seizure_start_index[seizure_valid], seizure_end_index[seizure_valid])).astype(int).shape)
    valid = np.where(end_index - start_index > 500)
    nonseizure_ind_arr_eps = np.vstack((start_index.iloc[valid], end_index.iloc[valid])).astype(int)

    if len(valid[0]) and len(seizure_valid[0]) > 0:
        nonseizure_clip_temp = np.hstack((nonseizure_ind_arr, nonseizure_ind_arr_eps))
        seizure_clip_temp = np.vstack((seizure_start_index[seizure_valid], seizure_end_index[seizure_valid])).astype(
            int)

        nonseizure_clip_label = np.zeros(nonseizure_clip_temp.shape[1]).astype(int)
        seizure_clip_label = np.ones(seizure_clip_temp.shape[1]).astype(int)

        seizure_clip = np.vstack((seizure_clip_temp, seizure_clip_label))
        non_seizure_clip = np.vstack((nonseizure_clip_temp, nonseizure_clip_label))

        combined_clip = np.hstack((seizure_clip, non_seizure_clip))

        shuffled_index = np.arange(combined_clip.shape[1])
        np.random.shuffle(shuffled_index)

        clip_dict[p] = combined_clip[:, shuffled_index]



(2, 73)
(2, 99)
(2, 60)
(2, 59)
(2, 24)


In [7]:
window_len = 1
stride = 1
concat_n = 4
for id in tqdm(clip_dict.keys()):
    data_import[id].set_window_parameter(window_length=window_len, window_displacement=stride)
    data_import[id].set_concatenation_parameter(concatenate_window_n=concat_n)
    window_indices, _ = data_import[id].get_windowed_data(clip_dict[id][0], clip_dict[id][1])
    import_label = np.array([])
    for i, ind in enumerate(window_indices):
        import_label = np.hstack((import_label, np.repeat(clip_dict[id][2][i], len(ind))))
    data_import[id].normalize_windowed_data()
    _, concatenated_data = data_import[id].get_concatenated_data(data_import[id].windowed_data, arrange='channel_stack')
    assert import_label.shape[0] == concatenated_data.shape[0]
    np.save('rns_test_cache/' + id + '.npy', {'data': concatenated_data, 'label': import_label})

100%|██████████| 5/5 [00:05<00:00,  1.11s/it]


In [8]:
class RNS_Downstream(Dataset):
    def __init__(self, data, label, transform=True, astensor=True):
        self.data = data
        self.label = label
        self.transform = transform
        print('data loaded')

        self.label = self.label[np.newaxis].T

        self.length = len(self.data)

        print(data.shape)
        print(label.shape)

        if astensor:
            self.augmentation = T.Compose([
                T.ToPILImage(),
                T.Resize((256, 256), interpolation=T.InterpolationMode.NEAREST),
                T.RandomApply([T.ColorJitter()], p=0.5),
                T.RandomApply([T.GaussianBlur(kernel_size=(3, 3))], p=0.5),
                T.RandomInvert(p=0.2),
                T.RandomPosterize(4, p=0.2),
                T.ToTensor()
            ])

            self.totensor = T.Compose([
                T.ToPILImage(),
                T.Resize((256, 256), interpolation=T.InterpolationMode.NEAREST),
                T.ToTensor()
            ])
        else:
            self.augmentation = T.Compose([
                T.ToPILImage(),
                T.Resize((256, 256), interpolation=T.InterpolationMode.NEAREST),
                T.RandomApply([T.ColorJitter()], p=0.5),
                T.RandomApply([T.GaussianBlur(kernel_size=(3, 3))], p=0.5),
                T.RandomInvert(p=0.2),
                T.RandomPosterize(4, p=0.2),
            ])

            self.totensor = T.Compose([
                T.ToPILImage(),
                T.Resize((256, 256), interpolation=T.InterpolationMode.NEAREST),
            ])

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        data = self.data[index]
        label = self.label[index]

        if self.transform:
            concat_len = data.shape[1] / 4
            channel_index = np.arange(4)
            np.random.shuffle(channel_index)
            channel_index = channel_index * concat_len + (concat_len - 1) / 2
            channel_index = np.repeat(channel_index, concat_len)
            concate_len_1 = (concat_len - 1) / 2
            a_repeat = np.arange(-concate_len_1, concate_len_1 + 1)[np.newaxis].T
            base_repeat = np.repeat(a_repeat, 4, axis=1).T.flatten()
            channel_index = channel_index + base_repeat
            data = data[channel_index.astype(int)]
            data = torch.from_numpy(data).clone()
            data = data.repeat(3, 1, 1)
            data = self.augmentation(data)

        else:
            concat_len = data.shape[1] / 4
            channel_index = np.arange(4)
            # np.random.shuffle(channel_index)
            channel_index = channel_index * concat_len + (concat_len - 1) / 2
            channel_index = np.repeat(channel_index, concat_len)
            concate_len_1 = (concat_len - 1) / 2
            a_repeat = np.arange(-concate_len_1, concate_len_1 + 1)[np.newaxis].T
            base_repeat = np.repeat(a_repeat, 4, axis=1).T.flatten()
            channel_index = channel_index + base_repeat
            data = data[channel_index.astype(int)]
            data = torch.from_numpy(data).clone()
            data = data.repeat(3, 1, 1)
            data = self.totensor(data)

        return data, torch.from_numpy(label).to(dtype=torch.long), None

In [9]:
import torch
import torchvision
from torch import nn

from lightly.data import LightlyDataset, SwaVCollateFunction
from lightly.loss import SwaVLoss
from lightly.loss.memory_bank import MemoryBankModule
from lightly.models.modules import SwaVProjectionHead, SwaVPrototypes


class SwaV(nn.Module):
    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone
        self.projection_head = SwaVProjectionHead(2048, 2048, 128)
        self.prototypes = SwaVPrototypes(128, 2048, 1)

        self.start_queue_at_epoch = 35
        self.queues = nn.ModuleList([MemoryBankModule(size=256) for _ in range(2)])

    def forward(self, high_resolution, low_resolution, epoch):
        self.prototypes.normalize()

        high_resolution_features = [self._subforward(x) for x in high_resolution]
        low_resolution_features = [self._subforward(x) for x in low_resolution]

        high_resolution_prototypes = [
            self.prototypes(x, epoch) for x in high_resolution_features
        ]
        low_resolution_prototypes = [
            self.prototypes(x, epoch) for x in low_resolution_features
        ]
        queue_prototypes = self._get_queue_prototypes(high_resolution_features, epoch)

        return high_resolution_prototypes, low_resolution_prototypes, queue_prototypes

    def _subforward(self, input):
        features = self.backbone(input).flatten(start_dim=1)
        features = self.projection_head(features)
        features = nn.functional.normalize(features, dim=1, p=2)
        return features

    @torch.no_grad()
    def _get_queue_prototypes(self, high_resolution_features, epoch):
        if len(high_resolution_features) != len(self.queues):
            raise ValueError(
                f"The number of queues ({len(self.queues)}) should be equal to the number of high "
                f"resolution inputs ({len(high_resolution_features)}). Set `n_queues` accordingly."
            )

        # Get the queue features
        queue_features = []
        for i in range(len(self.queues)):
            _, features = self.queues[i](high_resolution_features[i], update=True)
            # Queue features are in (num_ftrs X queue_length) shape, while the high res
            # features are in (batch_size X num_ftrs). Swap the axes for interoperability.
            features = torch.permute(features, (1, 0))
            queue_features.append(features)

        # If loss calculation with queue prototypes starts at a later epoch,
        # just queue the features and return None instead of queue prototypes.
        if self.start_queue_at_epoch > 0 and epoch < self.start_queue_at_epoch:
            return None

        # Assign prototypes
        queue_prototypes = [self.prototypes(x, epoch) for x in queue_features]
        return queue_prototypes


resnet = torchvision.models.resnet50()

backbone = nn.Sequential(*list(resnet.children())[:-1])
model = SwaV(backbone)

In [10]:
import torch


def sigmoid_focal_loss(
        inputs: torch.Tensor,
        targets: torch.Tensor,
        alpha: float = 0.25,
        gamma: float = 2,
        reduction: str = "none",
) -> torch.Tensor:
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

    Args:
        inputs (Tensor): A float tensor of arbitrary shape.
                The predictions for each example.
        targets (Tensor): A float tensor with the same shape as inputs. Stores the binary
                classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
        alpha (float): Weighting factor in range (0,1) to balance
                positive vs negative examples or -1 for ignore. Default: ``0.25``.
        gamma (float): Exponent of the modulating factor (1 - p_t) to
                balance easy vs hard examples. Default: ``2``.
        reduction (string): ``'none'`` | ``'mean'`` | ``'sum'``
                ``'none'``: No reduction will be applied to the output.
                ``'mean'``: The output will be averaged.
                ``'sum'``: The output will be summed. Default: ``'none'``.
    Returns:
        Loss tensor with the reduction option applied.
    """
    # Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py

    p = torch.sigmoid(inputs)

    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = p * targets + (1 - p) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    # Check reduction option and return loss accordingly
    if reduction == "none":
        pass
    elif reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()
    else:
        raise ValueError(
            f"Invalid Value for arg 'reduction': '{reduction} \n Supported reduction modes: 'none', 'mean', 'sum'"
        )
    return loss

In [11]:
import torch.nn.functional as F


class SupervisedDownstream(pl.LightningModule):
    def __init__(self, backbone, unfreeze_backbone_at_epoch=100):
        super().__init__()
        self.backbone = backbone
        self.fc1 = nn.Linear(2048, 512)
        self.fc2 = nn.Linear(512, 64)
        self.fc3 = nn.Linear(64, 8)
        self.fc4 = nn.Linear(8, 2)
        self.softmax = nn.Softmax(dim=1)
        self.alpha = 0.5
        self.gamma = 8
        self.unfreeze_backbone_at_epoch = unfreeze_backbone_at_epoch

    def training_step(self, batch, batch_idx):
        x, y = batch
        if self.current_epoch < self.unfreeze_backbone_at_epoch:
            self.backbone.eval()
            x = self.backbone(x)
            with torch.no_grad():
                x = x.view(-1, 2048)
        else:
            x = self.backbone(x)
            x = x.view(-1, 2048)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        pred = self.fc4(x)
        pred = self.softmax(pred)
        label = F.one_hot(y, num_classes=2).squeeze()
        loss = sigmoid_focal_loss(pred.float(), label.float(), alpha=self.alpha, gamma=self.gamma, reduction='mean')
        # Logging to TensorBoard (if installed) by default
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        x = self.backbone(x)
        x = x.view(-1, 2048)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        pred = self.fc4(x)
        pred = self.softmax(pred)
        label = F.one_hot(y, num_classes=2).squeeze()
        loss = sigmoid_focal_loss(pred.float(), label.float(), alpha=self.alpha, gamma=self.gamma, reduction='mean')
        out = torch.argmax(pred, dim=1)
        out = out.detach().cpu().numpy()
        target = y.squeeze().detach().cpu().numpy()
        precision, recall, fscore, support = sklearn.metrics.precision_recall_fscore_support(out, target,labels = [0,1],zero_division=0)
        acc = sklearn.metrics.accuracy_score(out, target)
        # print(acc)
        # print(precision)
        # print(recall)
        # print(fscore)
        # Logging to TensorBoard (if installed) by default
        self.log("val_loss", loss)
        self.log("val_acc", acc)
        self.log("val_precision", precision[1])
        self.log("val_recall", recall[1])
        return pred, label

    def predict_step(self, batch, batch_idx):
        x, y = batch
        emb = self.backbone(x)
        emb = emb.view(-1, 2048)
        x = F.relu(self.fc1(emb))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        pred = self.fc4(x)
        # Logging to TensorBoard (if installed) by default
        return pred, y, emb

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-2)
        return optimizer

In [12]:
def collate_fn(batch):
    info = list(zip(*batch))
    data = info[0]
    label = info[1]
    return torch.stack(data), torch.stack(label)

In [13]:
os.listdir('rns_test_cache')

['HUP047.npy',
 'HUP084.npy',
 'HUP096.npy',
 'HUP109.npy',
 'HUP121.npy',
 'HUP129.npy',
 'HUP131.npy',
 'HUP137.npy',
 'HUP147.npy',
 'HUP156.npy',
 'HUP159.npy',
 'HUP182.npy',
 'HUP199.npy',
 'RNS026.npy',
 'RNS029.npy']

In [14]:
def get_data(file_names, split=0.7):
    file_name_temp = file_names[0]
    cache = np.load('rns_test_cache/' + file_name_temp, allow_pickle=True)
    temp_file = cache.item().get('data')

    train_data = np.empty((0, temp_file.shape[1], temp_file.shape[2]))
    train_label = np.array([])
    test_data = np.empty((0, temp_file.shape[1], temp_file.shape[2]))
    test_label = np.array([])

    for name in tqdm(file_names):
        cache = np.load('rns_test_cache/' + name, allow_pickle=True)
        data = cache.item().get('data')
        label = cache.item().get('label')
        split_n = int(data.shape[0] * (split))
        train_data = np.vstack((train_data, data[:split_n]))
        train_label = np.hstack((train_label, label[:split_n]))
        test_data = np.vstack((test_data, data[split_n:]))
        test_label = np.hstack((test_label, label[split_n:]))

    return train_data, train_label, test_data, test_label

In [15]:
data_list = os.listdir('rns_test_cache')

train_data, train_label, test_data, test_label = get_data(data_list, split=0.3)
# data, label,_,_ = get_data(data_list, split=1)
# train_data, test_data, train_label, test_label = sklearn.model_selection.train_test_split(data, label, test_size=0.8, random_state=42)

print(train_data.shape)
print(train_label.shape)
print(test_data.shape)
print(test_label.shape)

100%|██████████| 15/15 [00:17<00:00,  1.16s/it]

(31685, 249, 36)
(31685,)
(73950, 249, 36)
(73950,)





In [16]:
test_label.sum()

21221.0

In [17]:
ckpt = torch.load("rns_ckpt/checkpoint31.pth")
resnet = torchvision.models.resnet50()
backbone = nn.Sequential(*list(resnet.children())[:-1])
swav = SwaV(backbone)
swav.load_state_dict(ckpt['model_state_dict'])
model = SupervisedDownstream(backbone)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

checkpoint_callback = pl_callbacks.ModelCheckpoint(monitor='val_loss',
                                                   filename='swav_pretrained-{epoch:02d}-{val_loss:.5f}',
                                                   dirpath='rns_linear_checkpoints',every_n_epochs = 5)
csv_logger = pl_loggers.CSVLogger("rns_linear_logs", name="logger")

trainer = pl.Trainer(logger=csv_logger, max_epochs=80, callbacks=[checkpoint_callback], accelerator='gpu', devices=1)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [18]:
train_dataset = RNS_Downstream(train_data, train_label, transform=True, astensor=True)
test_dataset = RNS_Downstream(test_data, test_label, transform=False, astensor=True)

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=128,
    collate_fn=collate_fn,
    shuffle=True,
    drop_last=True,
)

val_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=128,
    collate_fn=collate_fn,
    shuffle=False,
    drop_last=True,
)

trainer.fit(model, train_dataloader, val_dataloader)

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type       | Params
----------------------------------------
0 | backbone | Sequential | 23.5 M
1 | fc1      | Linear     | 1.0 M 
2 | fc2      | Linear     | 32.8 K
3 | fc3      | Linear     | 520   
4 | fc4      | Linear     | 18    
5 | softmax  | Softmax    | 0     
----------------------------------------
24.6 M    Trainable params
0         Non-trainable params
24.6 M    Total params
98.362    Total estimated model params size (MB)


data loaded
(31685, 249, 36)
(31685,)
data loaded
(73950, 249, 36)
(73950,)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [19]:
test_dataset = RNS_Downstream(test_data, test_label, transform=False, astensor=True)
val_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=128,
    collate_fn=collate_fn,
    shuffle=False,
    drop_last=True,
)

data loaded
(73950, 249, 36)
(73950,)


In [20]:
predictions = trainer.predict(model,val_dataloader,ckpt_path='rns_linear_checkpoints/swav_pretrained-epoch=05-val_loss=0.00179_linear_eval_all.ckpt')

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(


Predicting: 19it [00:00, ?it/s]

In [21]:
output_list = []
target_list = []
emb_list = []
m = nn.Softmax(dim=1)
for pred, y, emb in predictions:
    output_list.append(pred)
    target_list.append(y)
    emb_list.append(emb)

In [22]:
pred

tensor([[ 0.1967,  0.0204],
        [ 0.1967,  0.0204],
        [ 0.1967,  0.0204],
        [ 0.3138, -0.1657],
        [ 0.7194, -0.8103],
        [ 0.7422,  0.0987],
        [ 0.7526,  0.3495],
        [ 0.8382,  0.5140],
        [ 0.9884, -1.2379],
        [ 1.0389, -1.3182],
        [ 1.0004, -1.2570],
        [ 0.9462, -1.1708],
        [ 1.1320, -1.4662],
        [ 0.8774, -1.0615],
        [ 0.6491, -0.6987],
        [ 0.7734, -0.8962],
        [ 0.9767, -1.2193],
        [ 1.0257, -1.2973],
        [ 1.1124, -1.4350],
        [ 1.1101, -1.4314],
        [ 1.1618, -1.5135],
        [ 1.3328, -1.7853],
        [ 1.2815, -1.7038],
        [ 1.2903, -1.7177],
        [ 1.3007, -1.7343],
        [ 1.0876, -1.3955],
        [ 0.8402, -1.0024],
        [ 0.7077, -0.7918],
        [ 0.6734, -0.7373],
        [ 0.5488, -0.5391],
        [ 0.4584, -0.3956],
        [ 0.4799, -0.4297],
        [ 0.4111, -0.3203],
        [ 0.2787, -0.1098],
        [ 0.5651, -0.5651],
        [ 0.5445, -0

In [23]:
pred_raw = torch.vstack(output_list)
target = torch.vstack(target_list)
emb = torch.vstack(emb_list)
out = torch.argmax(pred_raw, dim=1)

In [24]:
torch.sum(target)

tensor(21221)

In [25]:
sklearn.metrics.accuracy_score(torch.argmax(pred_raw, dim=1), target)

0.8544962812711291

In [None]:
clf_report = sklearn.metrics.classification_report(torch.argmax(pred_raw, dim=1), target, digits=6)

print(f"Classification Report : \n{clf_report}")

In [None]:
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

In [None]:
len(emb)

In [None]:
pca_comp_n = 30
batch_size = 32

pca = PCA(n_components=pca_comp_n, copy=True).fit(emb)
p = pca.transform(emb)

# ind = np.random.choice(len(emb), 10000)

tsne = TSNE(n_components=2, verbose=1, perplexity=75, random_state=142, init='pca')
z = tsne.fit_transform(emb)
interictal_inds = np.where(target == 0)[0]
ictal_inds = np.where(target == 1)[0]

In [None]:
spc = z

plt.figure(figsize=(10, 8))
# plt.scatter(spc[interictal_inds,0],spc[interictal_inds,1],c='gold',label= 'interictal')
plt.scatter(spc[ictal_inds, 0], spc[ictal_inds, 1], c='royalblue', label='ictal')
plt.title('Swav Embedding t-SNE')
plt.xlabel('comp 1')
plt.ylabel("comp 2")
plt.legend()
plt.xlim(-67, 74)
plt.ylim(-67, 75)
plt.grid()
plt.show()

In [None]:
# dt = np.vstack((z[:,0], z[:,1])).T
interactive_plot.interactive_plot(z, ['RNS026', 'HUP159', 'HUP129', 'HUP096'], data_import, color_override=target)

In [None]:
interactive_plot.interactive_plot(z, ['HUP159'], data_import, color_override=target)

In [None]:
from sklearn.metrics import RocCurveDisplay

RocCurveDisplay.from_predictions(
    target,
    output[:, 1],
    color="darkorange",
)
plt.plot([0, 1], [0, 1], "k--", label="chance level (AUC = 0.5)")
plt.axis("square")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("One-vs-Rest ROC curves:\nVirginica vs (Setosa & Versicolor)")
plt.legend()
plt.show()

In [None]:
output[:, 1]

In [None]:

output = torch.argmax(output, dim=1)
output = output.detach().cpu().numpy()
target = target.squeeze().detach().cpu().numpy()

In [None]:
import sklearn

clf_report = sklearn.metrics.classification_report(output, target, digits=6)

print(f"Classification Report : \n{clf_report}")

In [None]:
for batch, label in tqdm(val_dataloader):
    batch = batch.to(device)
    label = label.to(device)
    label = F.one_hot(label).squeeze()
    outputs = model(batch)
    print(batch)
    loss = sigmoid_focal_loss(pred.float(), label.float(), alpha=0.5, gamma=8, reduction='mean')
    print(loss)
    break

In [None]:
# import copy
# import torch
# import torchvision
# from torch import nn
#
# from lightly.data import DINOCollateFunction, LightlyDataset
# from lightly.loss import DINOLoss
# from lightly.models.modules import DINOProjectionHead
# from lightly.models.utils import deactivate_requires_grad, update_momentum
# from lightly.utils.scheduler import cosine_schedule
#
#
# class DINO(torch.nn.Module):
#     def __init__(self, backbone, input_dim):
#         super().__init__()
#         self.student_backbone = backbone
#         self.student_head = DINOProjectionHead(
#             input_dim, 512, 64, 2048, freeze_last_layer=1
#         )
#         self.teacher_backbone = copy.deepcopy(backbone)
#         self.teacher_head = DINOProjectionHead(input_dim, 512, 64, 2048)
#         deactivate_requires_grad(self.teacher_backbone)
#         deactivate_requires_grad(self.teacher_head)
#
#     def forward(self, x):
#         y = self.student_backbone(x).flatten(start_dim=1)
#         z = self.student_head(y)
#         return z
#
#     def forward_teacher(self, x):
#         y = self.teacher_backbone(x).flatten(start_dim=1)
#         z = self.teacher_head(y)
#         return z
#
#
# resnet = torchvision.models.resnet18()
# backbone = nn.Sequential(*list(resnet.children())[:-1])
# input_dim = 512
# # instead of a resnet you can also use a vision transformer backbone as in the
# # original paper (you might have to reduce the batch size in this case):
# # backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vits16', pretrained=False)
# # input_dim = backbone.embed_dim
#
# model = DINO(backbone, input_dim)
#
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model.to(device)
#
# # # we ignore object detection annotations by setting target_transform to return 0
# # pascal_voc = torchvision.datasets.VOCDetection(
# #     "datasets/pascal_voc", download=True, target_transform=lambda t: 0
# # )
# # dataset = LightlyDataset.from_torch_dataset(pascal_voc)
# # # or create a dataset from a folder containing images or videos:
# # # dataset = LightlyDataset("path/to/folder")
#
# collate_fn = DINOCollateFunction(solarization_prob = 0, hf_prob = 0,vf_prob = 0,rr_prob=0,cj_prob=0,random_gray_scale=0)
#
# dataloader = torch.utils.data.DataLoader(
#     train_set,
#     batch_size=64,
#     collate_fn=collate_fn,
#     shuffle=True,
#     drop_last=True,
#     num_workers=1,
# )
#
# criterion = DINOLoss(
#     output_dim=2048,
#     warmup_teacher_temp_epochs=5,
# )
# # move loss to correct device because it also contains parameters
# criterion = criterion.to(device)
#
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
#
# epochs = 10
#
# print("Starting Training")
# for epoch in range(epochs):
#     total_loss = 0
#     momentum_val = cosine_schedule(epoch, epochs, 0.996, 1)
#     for views, _, _ in tqdm(dataloader):
#         update_momentum(model.student_backbone, model.teacher_backbone, m=momentum_val)
#         update_momentum(model.student_head, model.teacher_head, m=momentum_val)
#         views = [view.to(device) for view in views]
#         global_views = views[:2]
#         teacher_out = [model.forward_teacher(view) for view in global_views]
#         student_out = [model.forward(view) for view in views]
#         loss = criterion(teacher_out, student_out, epoch=epoch)
#         total_loss += loss.detach()
#         loss.backward()
#         # We only cancel gradients of student head.
#         model.student_head.cancel_last_layer_gradients(current_epoch=epoch)
#         optimizer.step()
#         optimizer.zero_grad()
#
#     avg_loss = total_loss / len(dataloader)
#     print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")

In [None]:
augmentation = T.Compose([
    T.ToPILImage(),
    T.Resize((256, 512), interpolation=T.InterpolationMode.NEAREST),
    T.RandomApply([T.ColorJitter()], p=0.5),
    T.RandomApply([T.GaussianBlur(kernel_size=(3, 3))], p=0.5),
    T.RandomInvert(p=0.2),
    T.RandomPosterize(4, p=0.2),
])

data = ictal_data_X[0]

channel_index = np.arange(data.shape[0])
np.random.shuffle(channel_index)
data = data[channel_index]
data = torch.from_numpy(data).clone()
data = data.repeat(3, 1, 1)
data = augmentation(data)
data

In [None]:
channel_index

In [None]:
data[channel_index]

In [None]:
data

In [None]:
#
# print("Starting Training")
# for epoch in range(50):
#     total_loss = 0
#     i = 0
#     for batch, label in tqdm(dataloader):
#         batch = batch.to(device)
#         # print(type(batch))
#         label = label.to(device)
#         label = F.one_hot(label).squeeze()
#         outputs = model(batch)
#         loss = sigmoid_focal_loss(outputs.float(),label.float(), alpha = 0.25, gamma = 7,reduction = 'mean')
#         total_loss += loss.detach()
#         loss.backward()
#         optimizer.step()
#         optimizer.zero_grad()
#
#     avg_loss = total_loss / len(dataloader)
#     torch.save({
#             'epoch': epoch,
#             'model_state_dict': model.state_dict(),
#             'optimizer_state_dict': optimizer.state_dict(),
#             'loss': avg_loss,
#             }, 'ckpt/checkpoint'+str(epoch)+'.pth')
#
#     print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")