In [1]:
%load_ext autoreload
%autoreload 2
# %matplotlib inline
# %matplotlib notebook

In [2]:
import sys
import os

sys.path.append('../tools')

In [3]:
import numpy as np
import matplotlib

from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torch
from tqdm import tqdm
import sklearn

import torch
import torch.nn as nn
from torch import Tensor
import math
from torchsummary import summary
import math
import lightning as L
import sys
import numpy as np
import time
import torch.nn.functional as F
from typing import List, Optional, Tuple, Union
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from model_config import MODEL_CONFIG
from lightly.models.modules import SwaVPrototypes, SwaVProjectionHead

import torchvision.transforms as T
import pytorch_lightning as pl
from lightning.pytorch import callbacks as pl_callbacks
from lightning.pytorch import loggers as pl_loggers

matplotlib.use("nbAgg")

import data_utility

from itertools import groupby
from tqdm import tqdm

In [4]:
import matplotlib.pyplot as plt

In [5]:
# Root of all user data; the log and checkpoint folders live directly beneath it.
data_dir = "../../../user_data/"
log_folder_root = data_dir + 'logs/'
ckpt_folder_root = data_dir + 'checkpoints/'

# Let float32 matmuls use faster, lower-precision internal formats.
torch.set_float32_matmul_precision('medium')

In [6]:
dir_list = os.listdir(data_dir + 'rns_data')
# Keep directory entries whose name contains 'HUP' or 'RNS' (case-insensitive).
# any() lists a name once even if it contains both markers; iterating over the
# markers inside the comprehension would append such a name twice.
# NOTE(review): patientIDs is not used by the visible cells (ids is hard-coded below).
patientIDs = [name for name in dir_list
              if any(marker in name.upper() for marker in ('HUP', 'RNS'))]

In [7]:
# Read every recording under rns_data.
# NOTE(review): annotation_only=False presumably loads signal data as well as annotations — confirm.
data_import = data_utility.read_files(
    path=data_dir + 'rns_data',
    annotation_only=False,
    verbose=True,
)
# All loaded patient keys; the next cell replaces this with a curated list.
ids = list(data_import)

100%|██████████| 28/28 [00:47<00:00,  1.69s/it]


In [8]:
# Windowing parameters handed to set_window_parameter below.
# stride == window_len, so windows presumably do not overlap (units per data_utility — confirm).
window_len = 9
stride = 9
concat_n = 1  # not used in this cell

# Patients included in pretraining. HUP137 and HUP156 are excluded (reason not recorded here).
ids = ['HUP047',
       'HUP059',
       'HUP084',
       'HUP096',
       'HUP101',
       'HUP108',
       'HUP109',
       'HUP121',
       'HUP127',
       'HUP128',
       'HUP129',
       'HUP131',
       'HUP136',
       # 'HUP137',
       'HUP143',
       'HUP147',
       'HUP153',
       # 'HUP156',
       'HUP159',
       'HUP182',
       'HUP192',
       'HUP197',
       'HUP199',
       'HUP205',
       'RNS021',
       'RNS022',
       'RNS026',
       'RNS029']

data_list = []
# The loop variable is `patient_id`, not `id`: a top-level loop variable named `id`
# would shadow the builtin id() for the rest of the kernel session.
for patient_id in ids:
    print(patient_id)
    patient = data_import[patient_id]
    patient.set_window_parameter(window_length=window_len, window_displacement=stride)
    # NOTE(review): if normalize_data() works in place, re-running this cell without
    # reloading data_import normalizes the data twice — confirm.
    patient.normalize_data()
    # Slice windows between each catalogued event's start and end indices.
    _, sliced_data = patient.get_windowed_data(patient.catalog["Event Start idx"],
                                               patient.catalog["Event End idx"])
    data_list.append(sliced_data)


HUP047
HUP059
HUP084
HUP096
HUP101
bla
HUP108
bla
HUP109
bla
HUP121
HUP127
HUP128
bla
HUP129
bla
HUP131
HUP136
bla
HUP143
HUP147
bla
bla
HUP153
HUP159
HUP182
bla
HUP192
bla
HUP197
HUP199
bla
HUP205
RNS021
bla
bla
RNS022
RNS026
bla
RNS029
bla
bla
bla


In [9]:
# next(iter(RNSDataset(sliced_data)))[0][1].size()

In [10]:
# next(iter(unlabeled_dataset))[0].size()

In [29]:
# import pytorch_lightning as pl
import torch
import torchvision
from torch import nn
import os
import random
from lightly.loss import SwaVLoss
from lightly.loss.memory_bank import MemoryBankModule
from lightly.models.modules import SwaVProjectionHead, SwaVPrototypes
from lightly.transforms.swav_transform import SwaVTransform
from lightly.data import SwaVCollateFunction
from model_config import MODEL_CONFIG


class Transpose(nn.Module):
    """Swap two tensor dimensions; lets ``Tensor.transpose`` sit inside ``nn.Sequential``.

    Args:
        dim1: first dimension to swap.
        dim2: second dimension to swap.
    """

    def __init__(self, dim1, dim2):
        super().__init__()
        self.dim1, self.dim2 = dim1, dim2

    def forward(self, x):
        """Return ``x`` with ``dim1`` and ``dim2`` exchanged."""
        return torch.transpose(x, self.dim1, self.dim2)


class SwaV(L.LightningModule):
    """SwaV self-supervised model applied to multichannel 1-D recordings.

    A grouped Conv1d front end (``ft_enc``) ends in an adaptive average pool, so every
    input is mapped to a feature map of length ``config.spatial_transformer_hidden``.
    Random resized crops of that 2-D map (2 at 224x224, 6 at 96x96), repeated to 3
    channels, are passed through a ResNet-50 backbone (classifier removed), a SwaV
    projection head (2048 -> 2048 -> 128) and 20 prototypes, and scored with
    ``SwaVLoss``. Two memory-bank queues, one per high-resolution view, add queued
    prototype targets from epoch ``start_queue_at_epoch`` onward.

    NOTE(review): submodules are registered in a fixed order and ``self.parameters()``
    follows it; optimizer state saved in checkpoints is indexed by that order, so
    reordering this constructor can break resuming from existing checkpoints.

    Args:
        config: model configuration. Reads ``ft_enc_dims``, ``ft_enc_kernel_widths``,
            ``ft_enc_strides``, ``channel_buffer_size`` and ``spatial_transformer_hidden``.
    """

    def __init__(self, config):
        super().__init__()
        # ResNet-50 built without a weights argument, minus its final fc layer.
        resnet = torchvision.models.resnet50()
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])
        # 2048-d backbone features -> 2048 hidden -> 128-d embedding.
        self.projection_head = SwaVProjectionHead(2048, 2048, 128)
        # 20 prototypes over the 128-d embedding.
        self.prototypes = SwaVPrototypes(128, 20)
        # Queues are filled from epoch 0 but only used in the loss from this epoch on.
        self.start_queue_at_epoch = 3
        # One 1024-entry queue per high-resolution view (there are 2, see crop_counts).
        self.queues = nn.ModuleList([MemoryBankModule(size=1024) for _ in range(2)])
        self.criterion = SwaVLoss(sinkhorn_epsilon=0.05)

        self.ft_enc = nn.ModuleList()
        for i, _ in enumerate(config.ft_enc_dims):
            if i == 0:
                # NOTE(review): first layer assumes 4 input channels (the summary cell
                # also uses a (4, 2249) input); in_channels must be divisible by
                # channel_buffer_size because of `groups`.
                self.ft_enc.append(
                    nn.Conv1d(
                        in_channels=4,
                        out_channels=config.ft_enc_dims[i],
                        kernel_size=config.ft_enc_kernel_widths[i],
                        stride=config.ft_enc_strides[i],
                        padding=0,
                        groups=config.channel_buffer_size,
                    )
                )
            else:
                self.ft_enc.append(
                    nn.Conv1d(
                        in_channels=config.ft_enc_dims[i - 1],
                        out_channels=config.ft_enc_dims[i],
                        kernel_size=config.ft_enc_kernel_widths[i],
                        stride=config.ft_enc_strides[i],
                        padding=0,
                        groups=config.channel_buffer_size,
                    )
                )
            # move channels last so LayerNorm normalizes over the channel dimension
            self.ft_enc.append(Transpose(1, 2))
            # layer normalization over ft_enc_dims[i] channels
            self.ft_enc.append(nn.LayerNorm(config.ft_enc_dims[i]))
            # GELU activation
            self.ft_enc.append(nn.GELU())
            # transpose back to channels-first for the next Conv1d
            self.ft_enc.append(Transpose(1, 2))

        # adaptive pool fixes the output length at spatial_transformer_hidden whatever the
        # input length; note the crops are taken afterwards, in training_step
        self.ft_enc.append(nn.AdaptiveAvgPool1d(config.spatial_transformer_hidden))

        # convert the list of modules to a sequential module (keeps the same child indices)
        self.ft_enc = nn.Sequential(*self.ft_enc)

        # Multi-crop recipe: 2 large crops (224 px, 14-100 % of the area) and
        # 6 small crops (96 px, 5-14 % of the area).
        crop_transforms = []
        crop_sizes = [224, 96]
        crop_min_scales = [0.14, 0.05]
        crop_max_scales = [1.0, 0.14]
        crop_counts = [2, 6]
        for i in range(len(crop_sizes)):
            random_resized_crop = T.RandomResizedCrop(crop_sizes[i], scale=(crop_min_scales[i], crop_max_scales[i]))

            # The same transform object is repeated; torchvision samples new crop
            # parameters on every call, so the views still differ.
            crop_transforms.extend([T.Compose([random_resized_crop])] * crop_counts[i])

        # Plain list attribute: these transforms hold no parameters.
        self.crop_transforms = crop_transforms

    def training_step(self, batch, batch_idx):
        """Compute the SwaV loss for one batch and log it as ``train_loss``.

        Args:
            batch: sequence whose first element is the input tensor
                (presumably (batch, 4, time) — TODO confirm against RNSDataset).
            batch_idx: index of the batch (unused).

        Returns:
            The scalar SwaV loss.
        """
        x = batch[0].float()

        # -> (batch, ft_enc_dims[-1], spatial_transformer_hidden)
        x = self.ft_enc(x)

        # Each transform treats the last two dims as an image and resizes the crop to a
        # square; unsqueeze + repeat make a 3-channel image for the ResNet backbone.
        # NOTE(review): one crop is sampled per call and applied to the whole batch.
        views = []
        for tf in self.crop_transforms:
            views.append(tf(x).unsqueeze(1).repeat(1, 3, 1, 1))

        # First two views are the 224 px crops, the rest the 96 px crops.
        high_resolution, low_resolution = views[:2], views[2:]
        self.prototypes.normalize()

        high_resolution_features = [self._subforward(x) for x in high_resolution]
        low_resolution_features = [self._subforward(x) for x in low_resolution]

        high_resolution_prototypes = [
            self.prototypes(x, self.current_epoch) for x in high_resolution_features
        ]
        low_resolution_prototypes = [
            self.prototypes(x, self.current_epoch) for x in low_resolution_features
        ]
        # None until start_queue_at_epoch, in which case the loss ignores the queues.
        queue_prototypes = self._get_queue_prototypes(high_resolution_features)
        loss = self.criterion(
            high_resolution_prototypes, low_resolution_prototypes, queue_prototypes
        )
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        """Adam over all parameters with a fixed learning rate of 1e-3."""
        optim = torch.optim.Adam(self.parameters(), lr=0.001)
        return optim

    def _subforward(self, input):
        """Backbone -> flatten -> projection head -> L2-normalized embedding."""
        features = self.backbone(input).flatten(start_dim=1)
        features = self.projection_head(features)
        features = nn.functional.normalize(features, dim=1, p=2)
        return features

    @torch.no_grad()
    def _get_queue_prototypes(self, high_resolution_features):
        """Enqueue the high-resolution features and return prototype scores of the queues.

        Args:
            high_resolution_features: one embedding tensor per high-resolution view;
                must match the number of queues.

        Returns:
            A list of prototype outputs, one per queue, or None before
            ``start_queue_at_epoch`` (the queues are still updated in that case).

        Raises:
            ValueError: if the number of views differs from the number of queues.
        """
        # NOTE(review): the message mentions `n_queues`, which is not a parameter of this
        # class; the queue count is fixed at 2 in __init__.
        if len(high_resolution_features) != len(self.queues):
            raise ValueError(
                f"The number of queues ({len(self.queues)}) should be equal to the number of high "
                f"resolution inputs ({len(high_resolution_features)}). Set `n_queues` accordingly."
            )

        # Get the queue features
        queue_features = []
        for i in range(len(self.queues)):
            # update=True pushes the current batch's features into the queue.
            _, features = self.queues[i](high_resolution_features[i], update=True)
            # Queue features are in (num_ftrs X queue_length) shape, while the high res
            # features are in (batch_size X num_ftrs). Swap the axes for interoperability.
            features = torch.permute(features, (1, 0))
            queue_features.append(features)

        # If loss calculation with queue prototypes starts at a later epoch,
        # just queue the features and return None instead of queue prototypes.
        if (
                self.start_queue_at_epoch > 0
                and self.current_epoch < self.start_queue_at_epoch
        ):
            return None

        # Assign prototypes
        queue_prototypes = [
            self.prototypes(x, self.current_epoch) for x in queue_features
        ]
        return queue_prototypes

In [12]:
# Drop the raw per-patient recordings to free memory; the following cells only use data_list.
# NOTE(review): memory is reclaimed only if data_list holds no references into
# data_import's arrays — confirm get_windowed_data returns copies.
del data_import

In [13]:
from models.rns_dataloader import RNSDataset

# One Dataset over the windowed segments from every patient in data_list.
# NOTE(review): transform=False presumably disables RNSDataset's own augmentation — confirm
# in models/rns_dataloader.
unlabeled_dataset = RNSDataset(data_list, transform=False)


In [14]:
# next(iter(unlabeled_dataset))[0].size()

In [15]:
# Shuffled batches of 128 windows. The last partial batch is dropped so every step sees a
# full batch, and the 10 worker processes are kept alive between epochs.
dataloader = DataLoader(
    dataset=unlabeled_dataset,
    batch_size=128,
    shuffle=True,
    drop_last=True,
    num_workers=10,
    persistent_workers=True,
)

In [16]:
# np.vstack(data_list).shape

In [30]:
model = SwaV(MODEL_CONFIG)

# Fall back to CPU when CUDA is unavailable. The Trainer now actually uses this value;
# previously it hard-coded 'gpu', so the fallback never took effect.
accelerator = "gpu" if torch.cuda.is_available() else "cpu"
# 16-bit AMP targets CUDA; use full 32-bit precision on CPU.
precision = 16 if accelerator == "gpu" else 32

# Run name shared by the checkpoint directory and the CSV log directory.
run_name = 'rns_swav_p20'

# save_top_k=-1 keeps a checkpoint for every epoch, so `monitor` does not prune anything;
# it also supplies train_loss for the filename.
checkpoint_callback = pl_callbacks.ModelCheckpoint(monitor='train_loss',
                                                   filename='model_epoch-{epoch:02d}-{train_loss:.5f}',
                                                   save_top_k=-1,
                                                   every_n_epochs=1,
                                                   dirpath=ckpt_folder_root + run_name)

# Stop after 15 epochs without a decrease in train_loss.
early_stopping = pl_callbacks.EarlyStopping(monitor="train_loss",
                                            mode="min",
                                            patience=15)
csv_logger = pl_loggers.CSVLogger(log_folder_root + run_name,
                                  name="log")

trainer = L.Trainer(log_every_n_steps=100,
                    logger=csv_logger,
                    max_epochs=500,
                    callbacks=[checkpoint_callback, early_stopping],
                    accelerator=accelerator,
                    devices=1,
                    precision=precision)


Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [31]:
# Resume the full training state (weights, optimizer, epoch counter) from this run's epoch-0 checkpoint.
resume_ckpt = ckpt_folder_root + 'rns_swav_p20/' + 'model_epoch-epoch=00-train_loss=1.43917.ckpt'
trainer.fit(
    model=model,
    train_dataloaders=dataloader,
    ckpt_path=resume_ckpt,
)

Restoring states from the checkpoint path at ../../../user_data/checkpoints/rns_swav_p20/model_epoch-epoch=00-train_loss=1.43917.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name            | Type               | Params
-------------------------------------------------------
0 | backbone        | Sequential         | 23.5 M
1 | projection_head | SwaVProjectionHead | 4.5 M 
2 | prototypes      | SwaVPrototypes     | 2.6 K 
3 | queues          | ModuleList         | 0     
4 | criterion       | SwaVLoss           | 0     
5 | ft_enc          | Sequential         | 22.7 K
-------------------------------------------------------
28.0 M    Trainable params
0         Non-trainable params
28.0 M    Total params
111.976   Total estimated model params size (MB)
Restored all states from the checkpoint at ../../../user_data/checkpoints/rns_swav_p20/model_epoch-epoch=00-train_loss=1.43917.ckpt


Training: |          | 0/? [00:00<?, ?it/s]

In [18]:
from torchsummary import summary

# Layer-by-layer summary of the convolutional front end for a (4, 2249) input.
# NOTE(review): 2249 presumably matches the length of one windowed segment — confirm.
# torchsummary defaults to device="cuda" and the module must live on the same device,
# so both are chosen by availability; this keeps the cell runnable on CPU-only machines.
# .to() moves model.ft_enc itself, not a copy.
summary_device = "cuda" if torch.cuda.is_available() else "cpu"
summary(model.ft_enc.to(summary_device), (4, 2249), device=summary_device)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv1d-1             [-1, 16, 1124]              64
         Transpose-2             [-1, 1124, 16]               0
         LayerNorm-3             [-1, 1124, 16]              32
              GELU-4             [-1, 1124, 16]               0
         Transpose-5             [-1, 16, 1124]               0
            Conv1d-6              [-1, 64, 561]             832
         Transpose-7              [-1, 561, 64]               0
         LayerNorm-8              [-1, 561, 64]             128
              GELU-9              [-1, 561, 64]               0
        Transpose-10              [-1, 64, 561]               0
           Conv1d-11             [-1, 128, 280]           4,224
        Transpose-12             [-1, 280, 128]               0
        LayerNorm-13             [-1, 280, 128]             256
             GELU-14             [-1, 2

In [19]:
# Toy batch for eyeballing the crop transforms: ten identical 256x256 grids holding
# 0 .. 65535 in row-major order.
test_a = torch.arange(256 * 256).view(256, 256).repeat(10, 1, 1)

In [20]:
# Scratch check of the SwaV multi-crop recipe on test_a: 2 large crops (224 px,
# 14-100 % of the area) and 6 small crops (96 px, 5-14 % of the area).
crop_sizes = [224, 96]
crop_min_scales = [0.14, 0.05]
crop_max_scales = [1.0, 0.14]
crop_counts = [2, 6]

crop_transforms = []
for size, min_scale, max_scale, count in zip(crop_sizes, crop_min_scales, crop_max_scales, crop_counts):
    # One transform object per resolution, listed `count` times; it re-samples its crop on each call.
    crop_transforms += [T.Compose([T.RandomResizedCrop(size, scale=(min_scale, max_scale))])] * count

views = [tf(test_a) for tf in crop_transforms]

In [21]:
from models.rns_dataloader import RNSDataset

In [22]:
# Sanity check: repeating a 1-D ramp 4 times along a new leading dim gives a (4, 1000) tensor.
torch.arange(1000).repeat(4,1).size()

torch.Size([4, 1000])

In [23]:
class RandomCrop(object):
    """Randomly crop a contiguous time segment and zero-pad it back to full length.

    A segment length is drawn from ``slice_lengths`` (default 1250, 1500, 1750 or
    2000 samples). A start offset is drawn uniformly from every position at which
    the segment fits inside ``out_length`` samples. The segment is then written
    either flush right or flush left (50/50) into an all-zero tensor shaped like
    the input. All channels share the same crop. Randomness comes from NumPy's
    global RNG.

    Args:
        out_length: number of time samples the start offset is drawn against;
            expected to equal ``sample.shape[-1]``.
        slice_lengths: optional candidate segment lengths; defaults to
            ``np.arange(5, 9) * 250``. Each must be <= ``out_length``.
    """

    def __init__(self, out_length, slice_lengths=None):
        self.out_length = out_length
        self.slice_lengths = (np.arange(5, 9) * 250 if slice_lengths is None
                              else np.asarray(slice_lengths))

    def __call__(self, sample):
        """Return a zero-padded random crop of ``sample`` (``(..., channels, time)``)."""
        slice_length = np.random.choice(self.slice_lengths)
        # randint's `high` is exclusive: +1 lets a segment end exactly at out_length
        # (and allows slice_length == out_length, which otherwise raises ValueError).
        start_slice = np.random.randint(low=0, high=self.out_length - slice_length + 1)
        padded_tensor = torch.zeros_like(sample)
        # Crop along time only; every channel is kept, whatever the channel count.
        segment = sample[..., start_slice:start_slice + slice_length]

        if np.random.choice([True, False]):
            padded_tensor[..., padded_tensor.size(-1) - segment.size(-1):] = segment
        else:
            padded_tensor[..., :segment.size(-1)] = segment

        return padded_tensor

In [24]:
# Smoke test: four identical 0..2248 ramps through RandomCrop; the zero run at one end is padding.
sample = torch.arange(2249).unsqueeze(0).repeat(4, 1)
RandomCrop(out_length=2249)(sample)

tensor([[   0,    0,    0,  ..., 1857, 1858, 1859],
        [   0,    0,    0,  ..., 1857, 1858, 1859],
        [   0,    0,    0,  ..., 1857, 1858, 1859],
        [   0,    0,    0,  ..., 1857, 1858, 1859]])