In [3]:
# Reload imported modules before each cell runs, so edits to local helper
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Plot backends are left disabled here; the import cell below selects the
# backend with matplotlib.use("nbAgg").
# %matplotlib inline
# %matplotlib notebook

In [4]:
import sys
import os

sys.path.append('../tools')

In [5]:
import numpy as np
import matplotlib

from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torch
from tqdm import tqdm
import sklearn

import torch
import torch.nn as nn
from torch import Tensor
import math
from torchsummary import summary
import math
import lightning as L
import sys
import numpy as np
import time
import torch.nn.functional as F
from typing import List, Optional, Tuple, Union
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from model_config import MODEL_CONFIG
from lightly.models.modules import SwaVPrototypes, SwaVProjectionHead

import torchvision.transforms as T
import pytorch_lightning as pl
from lightning.pytorch import callbacks as pl_callbacks
from lightning.pytorch import loggers as pl_loggers

matplotlib.use("nbAgg")

import data_utility

from itertools import groupby
from tqdm import tqdm

In [6]:
import matplotlib.pyplot as plt

In [7]:
# `random` is used by random.seed below and is not imported in the notebook's import cells.
import random

data_dir = "../../../user_data/"
log_folder_root = '../../../user_data/logs/'
ckpt_folder_root = '../../../user_data/checkpoints/'
torch.set_float32_matmul_precision('medium')

# Seed every RNG the notebook touches (Python, torch, numpy) for reproducible runs.
random_seed = 42
random.seed(random_seed)
torch.manual_seed(random_seed)
np.random.seed(random_seed)

if torch.cuda.is_available():
    torch.cuda.manual_seed(random_seed)
    # True makes cuDNN pick deterministic algorithms
    torch.backends.cudnn.deterministic = True
    # torch.set_deterministic(True)
    # False disables the cuDNN autotuner, so the same algorithm is selected on every run
    torch.backends.cudnn.benchmark = False

In [8]:
dir_list = os.listdir(data_dir+'rns_data')
# Patient folders are those whose name contains 'HUP' or 'RNS' (case-insensitive).
# any() lists a name once even if it contains both tags. os.listdir returns entries in
# arbitrary order, so sort to make the patientIDs[:10] selection below reproducible.
patientIDs = sorted(s for s in dir_list if any(tag in s.upper() for tag in ('HUP', 'RNS')))

In [9]:
# Show which per-patient cache files exist.
os.listdir(os.path.join(data_dir, 'rns_cache'))

['HUP047.npy',
 'HUP059.npy',
 'HUP084.npy',
 'HUP096.npy',
 'HUP101.npy',
 'HUP108.npy',
 'HUP109.npy',
 'HUP121.npy',
 'HUP127.npy',
 'HUP128.npy',
 'HUP129.npy',
 'HUP131.npy',
 'HUP136.npy',
 'HUP137.npy',
 'HUP143.npy',
 'HUP147.npy',
 'HUP153.npy',
 'HUP156.npy',
 'HUP159.npy',
 'HUP182.npy',
 'HUP192.npy',
 'HUP197.npy',
 'HUP199.npy',
 'HUP205.npy',
 'RNS021.npy',
 'RNS022.npy',
 'RNS026.npy',
 'RNS029.npy']

In [10]:
# Load the first 10 patients (annotations and raw data) via the project's data_utility helper.
data_import = data_utility.read_files(
    path=data_dir + 'rns_data',
    path_data=data_dir + 'rns_raw_cache',
    patientIDs=patientIDs[:10],
    annotation_only=False,
    verbose=True,
)
ids = list(data_import.keys())

100%|██████████| 10/10 [00:05<00:00,  1.88it/s]


In [11]:
# Windowing parameters passed to set_window_parameter below.
window_len = 10
stride = 10
concat_n = 1

# NOTE(review): `id` shadows the builtin; kept because later cells may read it.
id = ids[0]

# Assumes data_import[id] returns the same object on each lookup (true for a dict).
patient = data_import[id]
patient.set_window_parameter(window_length=window_len, window_displacement=stride)
patient.normalize_data()
_, sliced_data = patient.get_windowed_data(
    patient.catalog["Event Start idx"],
    patient.catalog["Event End idx"],
)

In [12]:
class RNSDataset(Dataset):
    """Multi-crop dataset over windowed RNS recordings for SwaV-style training.

    Each item is ``(crops, idx)``, where ``crops`` is the list returned by
    :meth:`multicrop` for ``sliced_data[idx]``.

    Args:
        sliced_data: array-like of windows; wrapped with ``torch.tensor``. Each window is
            transposed with ``transpose(1, 0)`` before cropping, so it is presumably laid out
            as (time, channels) — TODO confirm against ``get_windowed_data``.
        transform: if True, ``tensor_transform`` is applied to every crop.
    """

    def __init__(self, sliced_data, transform=False):
        # load data
        self.data = torch.tensor(sliced_data)
        self.transform = transform

        # NOTE(review): these are torchvision *image* transforms. For float tensors,
        # torchvision's brightness/contrast blending clamps results to [0, 1], which would
        # clip normalized signals (negative values) — confirm before using transform=True.
        self.tensor_transform = T.Compose([
            T.RandomApply([T.ColorJitter(brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2)], p=0.8),
            T.RandomApply([T.GaussianBlur(kernel_size=(5, 5))], p=0.6),
        ])

    def get_cropped_data(self, tensor, cropped_len, resized_len=2000):
        """Take one random contiguous crop along time and stretch it by sample repetition.

        Args:
            tensor: one window; it is transposed with ``transpose(1, 0)`` and the crop is
                taken along the last dimension of the transposed tensor.
            cropped_len: number of consecutive samples to keep (1 .. window length).
            resized_len: target length. Each kept sample is repeated
                ``resized_len // cropped_len`` times, so the output length equals
                ``resized_len`` only when ``cropped_len`` divides it.

        Returns:
            Tensor of shape ``(tensor.size(1), cropped_len * (resized_len // cropped_len))``.

        Raises:
            ValueError: if ``cropped_len`` is not in ``[1, window length]``.
        """
        tensor = tensor.transpose(1, 0)
        n_samples = tensor.size(-1)
        if not 0 < cropped_len <= n_samples:
            raise ValueError(f"cropped_len must be in [1, {n_samples}], got {cropped_len}")

        # randint's `high` is exclusive: +1 allows the crop to end on the last sample and
        # makes a full-length crop (cropped_len == n_samples) possible.
        start = int(torch.randint(low=0, high=n_samples - cropped_len + 1, size=(1,)))
        cropped_tensor = tensor[:, start:start + cropped_len]

        return cropped_tensor.repeat_interleave(resized_len // cropped_len, dim=1)

    def multicrop(self, tensor):
        """Return six time-domain crops of one window (2 x 2000 and 4 x 500 samples).

        All channels are preserved and every crop is stretched to 2000 samples by
        :meth:`get_cropped_data`; ``tensor_transform`` is applied when ``self.transform``.
        """
        size_list = [2000, 2000, 500, 500, 500, 500]
        if self.transform:
            data_list = [self.tensor_transform(self.get_cropped_data(tensor, sl).unsqueeze(0)).squeeze(0)
                         for sl in size_list]
        else:
            data_list = [self.get_cropped_data(tensor, sl) for sl in size_list]

        return data_list

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(list_of_crops, idx)`` for window ``idx``."""
        sample_data = self.data[idx]

        sample_data = self.multicrop(sample_data)

        return sample_data, idx

In [13]:
# next(iter(RNSDataset(sliced_data)))[0][2].size()

In [14]:
# tt = torch.tensor([[1,2,3,4,5,6],[1,2,3,4,5,6]])
# tt.repeat_interleave(3,dim= 1)

In [15]:
# tt.size()

In [16]:
# from models.rns_dataloader import RNS_Raw
# # unlabeled_dataset = RNS_Raw(file_list, transform=True,astensor = False)
# from lightly.data import SwaVCollateFunction
#
# dir_list = os.listdir(data_dir+'rns_cache')
# # file_list = ['HUP084.npy','HUP131.npy','HUP096.npy']
# # if self.current_epoch == 0:
# #     file_list = ['HUP101.npy']
# file_list = ['RNS026.npy', 'HUP159.npy', 'HUP129.npy', 'HUP096.npy']
# unlabeled_dataset = RNS_Raw(file_list, transform=True, astensor=True)

In [17]:
# next(iter(unlabeled_dataset))[0].size()

In [20]:
class FeedForward(nn.Module):
    """Pre-norm MLP: LayerNorm -> Linear(dim, hidden_dim) -> GELU -> Linear(hidden_dim, dim).

    Args:
        dim: size of the last (feature) dimension of input and output.
        hidden_dim: width of the hidden layer.
    """

    def __init__(self, dim, hidden_dim):
        super().__init__()
        # One nn.Sequential keeps the state_dict keys at net.0 / net.1 / net.3.
        layers = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out


class Attention(nn.Module):
    """Pre-norm multi-head self-attention with no dropout and no projection biases.

    Args:
        dim: feature size of the input tokens.
        heads: number of attention heads.
        dim_head: per-head size; q, k and v each have ``heads * dim_head`` features.
    """

    def __init__(self, dim, heads=8, dim_head=64):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim=-1)

        qkv_width = heads * dim_head
        self.to_qkv = nn.Linear(dim, 3 * qkv_width, bias=False)
        self.to_out = nn.Linear(qkv_width, dim, bias=False)

    def forward(self, x):
        normed = self.norm(x)

        # Split the fused projection into q, k, v and move heads next to batch: (b, h, n, d).
        q, k, v = (
            rearrange(chunk, 'b n (h d) -> b h n d', h=self.heads)
            for chunk in self.to_qkv(normed).chunk(3, dim=-1)
        )

        scores = (q @ k.transpose(-1, -2)) * self.scale
        weights = self.attend(scores)

        merged = rearrange(weights @ v, 'b h n d -> b n (h d)')
        return self.to_out(merged)


class Transformer(nn.Module):
    """``depth`` pre-norm residual blocks (Attention, then FeedForward) and a final LayerNorm.

    Args:
        dim: token feature size.
        depth: number of blocks.
        heads, dim_head: forwarded to :class:`Attention`.
        mlp_dim: hidden width of :class:`FeedForward`.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        blocks = []
        for _ in range(depth):
            block = nn.ModuleList([
                Attention(dim, heads=heads, dim_head=dim_head),
                FeedForward(dim, mlp_dim),
            ])
            blocks.append(block)
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        h = x
        for attention, feed_forward in self.layers:
            h = attention(h) + h
            h = feed_forward(h) + h
        return self.norm(h)


# %%
class Transpose(nn.Module):
    """Module wrapper around ``torch.transpose`` so axis swaps can sit inside ``nn.Sequential``."""

    def __init__(self, dim1, dim2):
        super().__init__()
        self.dim1, self.dim2 = dim1, dim2

    def forward(self, x):
        return torch.transpose(x, self.dim1, self.dim2)


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding (sin on even features, cos on odd) plus dropout.

    Args:
        d_model: embedding size; odd sizes are supported.
        dropout: dropout probability applied after adding the encoding.
        max_len: longest sequence that can be encoded.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # There are d_model // 2 odd feature slots, one fewer than div_term has entries
        # when d_model is odd, so trim div_term to fit.
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Buffer: saved in the state_dict and moved by .to(), but not trained.
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """Add the encoding of the first ``x.size(0)`` positions, then apply dropout.

        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

# %%
class MultivarWav2Vec2(nn.Module):
    """
    Multichannel 1D-convolutional encoder with SwaV heads (work in progress).

    ``__init__`` builds:
      * ``ft_enc``: for each entry of ``config.ft_enc_dims`` a grouped ``Conv1d``
        (``groups=config.channel_buffer_size``), then Transpose -> LayerNorm -> GELU ->
        Transpose, and finally ``AdaptiveAvgPool1d(config.spatial_transformer_hidden)`` so
        crops of different lengths come out with the same time length;
      * spatial and temporal ``Transformer`` encoders, a temporal ``PositionalEncoding``,
        a learnable class token, a ``SwaVProjectionHead`` and ``SwaVPrototypes``;
      * ``crop_transforms``: 2 ``RandomResizedCrop(224)`` and 6 ``RandomResizedCrop(96)``.

    The current ``forward`` only uses ``ft_enc`` and ``crop_transforms``; the transformer /
    projection path is commented out there.
    """

    def __init__(self, config):
        super(MultivarWav2Vec2, self).__init__()

        #  Feature encoder: a series of grouped 1D convolutions, each followed by
        #  Transpose(1, 2) -> LayerNorm(out_channels) -> GELU -> Transpose(1, 2), so the
        #  LayerNorm normalizes over the convolution's output channels.
        self.ft_enc = nn.ModuleList()
        for i, _ in enumerate(config.ft_enc_dims):
            if i == 0:
                # NOTE(review): in_channels is hard-coded to 4 (the summary cell feeds
                # (4, 2000) inputs). Conv1d requires in_channels divisible by groups, so this
                # presumably assumes channel_buffer_size divides 4 — TODO confirm, or read
                # the input channel count from config.
                self.ft_enc.append(
                    nn.Conv1d(
                        in_channels=4,
                        out_channels=config.ft_enc_dims[i],
                        kernel_size=config.ft_enc_kernel_widths[i],
                        stride=config.ft_enc_strides[i],
                        padding=0,
                        groups=config.channel_buffer_size,
                    )
                )
            else:
                self.ft_enc.append(
                    nn.Conv1d(
                        in_channels=config.ft_enc_dims[i - 1],
                        out_channels=config.ft_enc_dims[i],
                        kernel_size=config.ft_enc_kernel_widths[i],
                        stride=config.ft_enc_strides[i],
                        padding=0,
                        groups=config.channel_buffer_size,
                    )
                )
            # move channels last so LayerNorm sees (batch, time, channels)
            self.ft_enc.append(Transpose(1, 2))
            # layer normalization over the channel dimension
            self.ft_enc.append(nn.LayerNorm(config.ft_enc_dims[i]))
            # GELU activation
            self.ft_enc.append(nn.GELU())
            # move channels back to dim 1 for the next Conv1d
            self.ft_enc.append(Transpose(1, 2))

        # add an adaptive pool so different sized crops have the same length afterward
        self.ft_enc.append(nn.AdaptiveAvgPool1d(config.spatial_transformer_hidden))

        # convert the list of modules to a sequential module
        self.ft_enc = nn.Sequential(*self.ft_enc)

        # NOTE(review): fc_layer is not used by forward() below.
        self.fc_layer = nn.Linear(
            in_features=config.ft_enc_dims[-1] * config.channel_buffer_size,
            out_features=1,
        )

        # positional encoding
        self.temporal_pos_encoder = PositionalEncoding(config.temporal_transformer_hidden, config.dropout)

        # add learnable class token for embeddings
        self.class_token = nn.Parameter(torch.randn(1, 1, config.temporal_transformer_hidden))

        self.projection_head = SwaVProjectionHead(config.temporal_transformer_hidden, 256, config.prototype_dim)

        self.prototypes = SwaVPrototypes(config.prototype_dim, n_prototypes=config.prototype_n,
                                         n_steps_frozen_prototypes=0)

        self.spatial_transformer = Transformer(config.spatial_transformer_hidden,
                                               config.spatial_transformer_blocks,
                                               config.spatial_transformer_heads,
                                               config.spatial_transformer_inner_heads,
                                               config.spatial_transformer_mlp_dim)

        self.temporal_transformer = Transformer(config.temporal_transformer_hidden,
                                                config.temporal_transformer_blocks,
                                                config.temporal_transformer_heads,
                                                config.temporal_transformer_inner_heads,
                                                config.temporal_transformer_mlp_dim)
        self.config = config

        # Image-style multi-crop: crop_counts[i] copies of RandomResizedCrop(crop_sizes[i]).
        # The same transform object is repeated in the list; it samples new crop
        # parameters on each call.
        crop_transforms = []
        crop_sizes=[224, 96]
        crop_min_scales = [0.14, 0.05]
        crop_max_scales = [1.0, 0.14]
        crop_counts = [2, 6]
        for i in range(len(crop_sizes)):
            random_resized_crop = T.RandomResizedCrop(crop_sizes[i], scale=(crop_min_scales[i], crop_max_scales[i]))

            crop_transforms.extend([T.Compose([random_resized_crop, ])] * crop_counts[i])
        # NOTE(review): the bare expression below is a no-op (notebook inspection leftover).
        crop_transforms

        self.crop_transforms = crop_transforms
    def forward(self, x):
        """Encode ``x`` with ``ft_enc`` and return one random resized crop of the result.

        NOTE(review): this looks like a debugging state. All eight ``crop_transforms`` are
        applied, but only ``views[2]`` (the first 96-size crop) is returned; the other seven
        are computed and discarded. torchvision's RandomResizedCrop treats the trailing two
        dims as (H, W), so the output is presumably (batch, 96, 96) — TODO confirm this is
        what ``prototypes`` / ``SwaVLoss`` expect.
        """

        # x goes straight into ft_enc (an nn.Sequential starting with Conv1d), so it must be
        # a tensor; the summary cell uses a per-sample input shape of (4, 2000).

        # computing eeg_channel wise convolution, reshape to put eeg_channels in batches so it can run faster
        # x = torch.reshape(x, (-1, x.size()[-1])).unsqueeze(1)

        # print(x.size())
        x = self.ft_enc(x)

        views = []
        for tf in self.crop_transforms:
            views.append(tf(x))

        x = views[2]


        # print(x.size())
        # x = x.reshape((-1, self.config.spatial_transformer_hidden)).reshape(
        #     (-1, self.config.temporal_transformer_hidden, self.config.spatial_transformer_hidden))
        # print(x.size())
        # print('=======================')


        # print(x.size())

        # Disabled pipeline (kept for reference): spatial transformer -> class token ->
        # temporal positional encoding -> temporal transformer -> SwaV projection head.
        # pass to spatial encoder
        # x = self.spatial_transformer(x)
        #
        # x = x.transpose(1, 2)
        #
        # # add class tokens
        # cls_tokens = repeat(self.class_token, '1 1 d -> b 1 d', b=x.size(0))
        # x = torch.cat((cls_tokens, x), dim=1)
        # # add temporal positional encoding
        # x = self.temporal_pos_encoder(x.transpose(0, 1)).transpose(0, 1)
        # # pass to temporal encoder
        # x = self.temporal_transformer(x)
        # # recover class token
        # x = x[:, 0]
        # # x = x.flatten(1)
        #
        # # pass to swav projection head
        # x = self.projection_head(x)
        # x = nn.functional.normalize(x, dim=1, p=2)

        return x

In [21]:
from torchsummary import summary

# Summarize on whichever device exists instead of forcing .cuda(), which fails on
# CPU-only machines. torchsummary builds its dummy inputs on the `device` it is given.
summary_device = "cuda" if torch.cuda.is_available() else "cpu"
model = MultivarWav2Vec2(MODEL_CONFIG).to(summary_device)
summary(model, (4, 2000), device=summary_device)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv1d-1              [-1, 16, 666]              64
         Transpose-2              [-1, 666, 16]               0
         LayerNorm-3              [-1, 666, 16]              32
              GELU-4              [-1, 666, 16]               0
         Transpose-5              [-1, 16, 666]               0
            Conv1d-6              [-1, 64, 332]             832
         Transpose-7              [-1, 332, 64]               0
         LayerNorm-8              [-1, 332, 64]             128
              GELU-9              [-1, 332, 64]               0
        Transpose-10              [-1, 64, 332]               0
           Conv1d-11             [-1, 128, 166]           4,224
        Transpose-12             [-1, 166, 128]               0
        LayerNorm-13             [-1, 166, 128]             256
             GELU-14             [-1, 1

In [18]:
import torch
import torch.nn as nn
from torch import Tensor
import math

import lightning as L
import torch.optim as optim
import torch.nn.functional as F
from lightly.loss import SwaVLoss
from lightly.models import utils
from typing import Optional, Sequence, Tuple, Union
from model_config import MODEL_CONFIG
from torch.nn import Module


class MemoryBankModule(Module):
    """Memory bank implementation

    This is a parent class to all loss functions implemented by the lightly
    Python package. This way, any loss can be used with a memory bank if
    desired.

    Attributes:
        size:
            Size of the memory bank as (num_features, dim) tuple. If num_features is 0
            then the memory bank is disabled. Deprecated: If only a single integer is
            passed, it is interpreted as the number of features and the feature
            dimension is inferred from the first batch stored in the memory bank.
            Leaving out the feature dimension might lead to errors in distributed
            training.
        gather_distributed:
            If True then negatives from all gpus are gathered before the memory bank
            is updated. This results in more frequent updates of the memory bank and
            keeps the memory bank contents independent of the number of gpus. But it has
            the drawback that synchronization between processes is required and
            diversity of the memory bank content is reduced.
        feature_dim_first:
            If True, the memory bank returns features with shape (dim, num_features).
            If False, the memory bank returns features with shape (num_features, dim).

    Examples:
        >>> class MyLossFunction(MemoryBankModule):
        >>>
        >>>     def __init__(self, memory_bank_size: Tuple[int, int] = (2 ** 16, 128)):
        >>>         super().__init__(memory_bank_size)
        >>>
        >>>     def forward(self, output: Tensor, labels: Optional[Tensor] = None):
        >>>         output, negatives = super().forward(output)
        >>>
        >>>         if negatives is not None:
        >>>             # evaluate loss with negative samples
        >>>         else:
        >>>             # evaluate loss without negative samples

    """

    def __init__(
            self,
            size: Union[int, Sequence[int]] = 65536,
            gather_distributed: bool = False,
            feature_dim_first: bool = True,
    ):
        super().__init__()
        size_tuple = (size,) if isinstance(size, int) else tuple(size)

        if any(x < 0 for x in size_tuple):
            raise ValueError(
                f"Illegal memory bank size {size}, all entries must be non-negative."
            )

        self.size = size_tuple
        self.gather_distributed = gather_distributed
        self.feature_dim_first = feature_dim_first
        # Non-persistent buffers: they follow .to()/.cuda() but are not saved in checkpoints.
        self.bank: Tensor
        self.register_buffer(
            "bank",
            tensor=torch.empty(size=size_tuple, dtype=torch.float),
            persistent=False,
        )
        self.bank_ptr: Tensor
        self.register_buffer(
            "bank_ptr",
            tensor=torch.empty(1, dtype=torch.long),
            persistent=False,
        )

        if isinstance(size, int) and size > 0:
            # Imported locally: `warnings` is not imported at notebook level.
            import warnings

            warnings.warn(
                (
                    f"Memory bank size 'size={size}' does not specify feature "
                    "dimension. It is recommended to set the feature dimension with "
                    "'size=(n, dim)' when creating the memory bank. Distributed "
                    "training might fail if the feature dimension is not set."
                ),
                UserWarning,
            )
        elif len(size_tuple) > 1:
            self._init_memory_bank(size=size_tuple)

    @torch.no_grad()
    def _init_memory_bank(self, size: Tuple[int, ...]) -> None:
        """Initialize the memory bank with L2-normalized random features.

        Args:
            size:
                Size of the memory bank as (num_features, dim) tuple.

        """
        # type_as keeps the bank's dtype/device after .to()/.cuda() moved the buffers.
        self.bank = torch.randn(size).type_as(self.bank)
        self.bank = torch.nn.functional.normalize(self.bank, dim=-1)
        self.bank_ptr = torch.zeros(1).type_as(self.bank_ptr)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, batch: Tensor) -> None:
        """Dequeue the oldest batch and add the latest one

        When the batch does not fit before the end of the bank, only the part that
        fits is written and the pointer wraps to 0; the remaining samples are dropped.

        Args:
            batch:
                The latest batch of keys to add to the memory bank.

        """
        if self.gather_distributed:
            batch = utils.concat_all_gather(batch)

        batch_size = batch.shape[0]
        ptr = int(self.bank_ptr)
        if ptr + batch_size >= self.size[0]:
            self.bank[ptr:] = batch[: self.size[0] - ptr].detach()
            self.bank_ptr.zero_()
        else:
            self.bank[ptr: ptr + batch_size] = batch.detach()
            self.bank_ptr[0] = ptr + batch_size

    def forward(
            self,
            output: Tensor,
            labels: Optional[Tensor] = None,
            update: bool = False,
    ) -> Tuple[Tensor, Union[Tensor, None]]:
        """Query memory bank for additional negative samples

        Args:
            output:
                The output of the model.
            labels:
                Should always be None, will be ignored.
            update:
                If True, the memory bank will be updated with the current output.

        Returns:
            ``(output, None)`` if the memory bank is of size 0, otherwise ``output``
            and a detached copy of the bank taken before this call's update. Entries
            have shape (dim, num_features) if feature_dim_first is True and
            (num_features, dim) otherwise.

        """

        # no memory bank, return the output
        if self.size[0] == 0:
            return output, None

        # Initialize the memory bank if it is not already done.
        if self.bank.ndim == 1:
            dim = output.shape[1:]
            self._init_memory_bank(size=(*self.size, *dim))

        # query and update memory bank
        bank = self.bank.clone().detach()
        if self.feature_dim_first:
            # swap bank size and feature dimension for backwards compatibility
            bank = bank.transpose(0, -1)

        # only update memory bank if we later do backward pass (gradient)
        if update:
            self._dequeue_and_enqueue(output)

        return output, bank

In [19]:
class WrapperWithQueue(L.LightningModule):
    """LightningModule that trains ``model`` with ``SwaVLoss`` and per-view feature queues.

    The first ``n_queues`` crops of each batch are the high-resolution views; the rest are
    low-resolution views. Features of each high-resolution view are pushed into its own
    ``MemoryBankModule``. From ``start_queue_at_epoch`` on, the queued features are also
    assigned to prototypes and passed to the loss.

    Args:
        model: network exposing ``prototypes`` (with ``normalize()`` and a
            ``(features, step)`` call) and returning features from ``forward``.
        n_queues: number of queues, which is also the number of high-resolution views.
        queue_length: number of features kept in each queue.
        queue_feature_dim: feature size stored in each queue; must match the width of the
            model's output features — TODO confirm against MODEL_CONFIG.
        start_queue_at_epoch: first epoch at which queue prototypes enter the loss.
    """

    def __init__(self, model, n_queues=2, queue_length=3084, queue_feature_dim=128,
                 start_queue_at_epoch=50):
        super().__init__()
        self.model = model
        self.criterion = SwaVLoss()

        self.start_queue_at_epoch = start_queue_at_epoch
        self.queues = nn.ModuleList(
            [MemoryBankModule(size=(queue_length, queue_feature_dim)) for _ in range(n_queues)]
        )

    def training_step(self, batch, batch_idx):
        # normalize prototypes so they lie on the unit sphere
        self.model.prototypes.normalize()
        signals, _ = batch

        # one queue per high-resolution view, so the queue count fixes the split
        n_high_resolution = len(self.queues)
        high_resolution = signals[:n_high_resolution]
        low_resolution = signals[n_high_resolution:]

        high_resolution_features = [self.model(x.float().to(self.device)) for x in high_resolution]
        low_resolution_features = [self.model(x.float().to(self.device)) for x in low_resolution]

        high_resolution_prototypes = [self.model.prototypes(x, self.current_epoch) for x in high_resolution_features]
        low_resolution_prototypes = [self.model.prototypes(x, self.current_epoch) for x in low_resolution_features]

        queue_prototypes = self._get_queue_prototypes(high_resolution_features)

        loss = self.criterion(high_resolution_prototypes, low_resolution_prototypes, queue_prototypes)

        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        return loss

    @torch.no_grad()
    def _get_queue_prototypes(self, high_resolution_features):
        """Enqueue high-res features and return prototype scores of the previous queue contents.

        Returns None before ``start_queue_at_epoch`` (features are still enqueued).

        Raises:
            ValueError: if the number of feature tensors differs from the number of queues.
        """
        if len(high_resolution_features) != len(self.queues):
            raise ValueError(
                f"The number of queues ({len(self.queues)}) should be equal to the number of high "
                f"resolution inputs ({len(high_resolution_features)}). Set `n_queues` accordingly."
            )

        # Get the queue features
        queue_features = []
        for queue, features_in in zip(self.queues, high_resolution_features):
            _, features = queue(features_in, update=True)
            # Queue features are in (num_ftrs X queue_length) shape, while the high res
            # features are in (batch_size X num_ftrs). Swap the axes for interoperability.
            queue_features.append(torch.permute(features, (1, 0)))

        # If loss calculation with queue prototypes starts at a later epoch,
        # just queue the features and return None instead of queue prototypes.
        if 0 < self.start_queue_at_epoch and self.current_epoch < self.start_queue_at_epoch:
            return None

        # Assign prototypes
        return [self.model.prototypes(x, self.current_epoch) for x in queue_features]

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=MODEL_CONFIG.learning_rate)

        return optimizer

In [20]:
# transform defaults to False, so views differ only by their random crops.
unlabeled_dataset = RNSDataset(sliced_data)

# drop_last=True: with fewer than 64 windows this loader yields no batches at all.
dataloader = DataLoader(
    dataset=unlabeled_dataset,
    batch_size=64,
    shuffle=True,
    drop_last=True,
)

In [21]:
model_nn = MultivarWav2Vec2(MODEL_CONFIG)
model = WrapperWithQueue(model_nn)

# Use the GPU when present, otherwise fall back to CPU; passed to the Trainer below.
accelerator = "gpu" if torch.cuda.is_available() else "cpu"
# "16-mixed" is the non-deprecated spelling of precision=16. Lightning does not support
# fp16 AMP on CPU, so train in full precision there.
precision = "16-mixed" if accelerator == "gpu" else "32-true"

checkpoint_callback = pl_callbacks.ModelCheckpoint(monitor='train_loss',
                                                   filename='model_epoch-{epoch:02d}-{train_loss:.5f}',
                                                   save_top_k=-1,
                                                   every_n_epochs=5,
                                                   # enable_version_counter=True,
                                                   dirpath= 'checkpoints')

early_stopping = pl_callbacks.EarlyStopping(monitor="train_loss",
                                            mode="min",
                                            patience=15)
csv_logger = pl_loggers.CSVLogger('logs',
                                  name="log")

trainer = L.Trainer(log_every_n_steps=5,
                    logger=csv_logger,
                    max_epochs=500,
                    callbacks=[checkpoint_callback, early_stopping],
                    accelerator=accelerator,
                    devices=1,
                    precision=precision)
trainer.fit(model=model, train_dataloaders=dataloader)

C:\Users\Patrick Xu\AppData\Local\Programs\Python\Python310\lib\site-packages\lightning\fabric\connector.py:565: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
C:\Users\Patrick Xu\AppData\Local\Programs\Python\Python310\lib\site-packages\lightning\pytorch\callbacks\model_checkpoint.py:630: Checkpoint directory C:\Users\Patrick Xu\Desktop\RNS_Annotation-Pipeline\scripts\RNS_LITT_ANNOTATION_PIPELINE\rns_scripts\checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params
-----------------------------------------------
0 | model     | MultivarWav2Vec2 | 14.4 M
1 | criterion | SwaVLoss         | 0     
2 | queues    | ModuleList       | 0     
--------------

Training: |          | 0/? [00:00<?, ?it/s]

C:\Users\Patrick Xu\AppData\Local\Programs\Python\Python310\lib\site-packages\lightning\pytorch\trainer\call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
