In [13]:
import wfdb
import pickle
from scipy import interpolate

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch.nn as nn
from torchvision import transforms, datasets
from torchvision.models import resnet18, resnet34, resnet50
from tqdm import tqdm
from collections import Counter
from info_nce import InfoNCE
import matplotlib.colors as mcolors

import os
import pandas as pd
import einops
from scipy.stats import mode
import torch.optim as optim
from sklearn.metrics import f1_score
import sys
from torch.utils import data

from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import silhouette_score
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import src.config, src.utils, src.models, src.data
import torchvision

import torch
from torch.utils.data import DataLoader, Dataset, Subset
import random

from sklearn.datasets import load_iris
from sklearn.cluster import KMeans
from sklearn.metrics import davies_bouldin_score
from sklearn.metrics import calinski_harabasz_score
import matplotlib.pyplot as plt
import numpy
from sklearn.metrics import precision_score, recall_score

from torch.nn import GRU, Linear, CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import random_split

print("torchvision version:", torchvision.__version__)

torchvision version: 0.17.0


In [14]:
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.use_deterministic_algorithms(True)

In [15]:
# Local imports
# from deepecg.config.config import DATA_DIR
DATA_DIR = "./data"

In [16]:
afib_dict = {"AFIB":0, "AFL":1, "J":2, "N":3}
flipped_afib_dict = {0: "AFIB", 1: "AFL", 2: "J", 3: "N"}

In [17]:
class AFDB(object):
    """
        The MIT-BIH Atrial Fibrillation Database
        https://physionet.org/physiobank/database/afdb/
        """

    def __init__(self):
        # Set attributes
        self.db_name = 'afdb'
        self.raw_path = os.path.join(DATA_DIR, 'mittecg')
        self.processed_path = os.path.join(DATA_DIR, 'processed')
        self.label_dict = {
                'AFIB': 'atrial fibrillation',
                'ASYS': 'asystole',
                'B': 'ventricular bigeminy',
                'BI': 'first degree heart block',
                'HGEA': 'high grade ventricular ectopic activity',
                'N': 'normal sinus rhythm',
                'NSR': 'normal sinus rhythm',
                'NOD': 'nodal ("AV junctional") rhythm',
                'NOISE': 'noise',
                'PM': 'pacemaker (paced rhythm)',
                'SBR': 'sinus bradycardia',
                'SVTA': 'supraventricular tachyarrhythmia',
                'VER': 'ventricular escape rhythm',
                'VF': 'ventricular fibrillation',
                'VFIB': 'ventricular fibrillation',
                'VFL': 'ventricular flutter',
                'VT': 'ventricular tachycardia'
            }

        self.fs = 300
        self.length = 60
        self.length_sp = self.length * self.fs
        self.record_ids = None
        self.sections = None
        self.samples = None
        self.labels = None

    def generate_db(self):
        """Generate raw and processed databases."""
        # Generate raw database
        self.generate_raw_db()

        # Generate processed database
        self.generate_processed_db()

    def generate_raw_db(self):
        """Generate the raw version of the MIT-BIH Atrial Fibrillation database in the 'raw' folder."""
        # Download database
        if len(os.listdir(self.raw_path))==0:
            print('Generating Raw MIT-BIH Atrial Fibrillation Database ...')
            wfdb.dl_database(self.db_name, self.raw_path)
            print('Complete!\n')

        # Get list of recordings
        self.record_ids = [file.split('.')[0] for file in os.listdir(self.raw_path) if '.dat' in file]

    def generate_processed_db(self):
        """Generate the processed version of the MIT-BIH Atrial Fibrillation database in the 'processed' folder."""
        print('Generating Processed MIT-BIH Atrial Fibrillation Database ...')
        all_signals, all_labels = self._get_sections()

        signal_lens = [len(sig) for sig in all_labels]
        all_signals = np.array([sig[:,:min(signal_lens)] for sig in all_signals])
        all_labels = np.array([sig[:min(signal_lens)] for sig in all_labels])

        

        # Normalize signals
        data_n, data_n = self._normalize(all_signals, all_signals)


        # Save signals to file
        if not os.path.exists(self.processed_path):
            os.mkdir(self.processed_path)
        
        with open(os.path.join(self.processed_path, 'tensor_data.pkl'), 'wb') as f:
            pickle.dump(data_n, f)
        with open(os.path.join(self.processed_path, 'tensor_label.pkl'), 'wb') as f:
            pickle.dump(all_labels, f)
        

    def _normalize(self, train_data, test_data):
        """ Calculate the mean and std of each feature from the training set
        """
        feature_means = np.mean(train_data, axis=(0, 2))
        feature_std = np.std(train_data, axis=(0, 2))
        train_data_n = (train_data - feature_means[np.newaxis, :, np.newaxis]) / \
                       np.where(feature_std == 0, 1, feature_std)[np.newaxis, :, np.newaxis]
        test_data_n = (test_data - feature_means[np.newaxis, :, np.newaxis]) /\
                      np.where(feature_std == 0, 1, feature_std)[np.newaxis, :, np.newaxis]
        return train_data_n, test_data_n

    def _get_sections(self):
        """Collect continuous arrhythmia sections."""
        # Empty dictionary for arrhythmia sections
        all_signals = []
        all_labels = []

        # Loop through records
        for record_id in self.record_ids:
            # Import recording
            record = wfdb.rdrecord(os.path.join(self.raw_path, record_id))

            # Import annotations
            annotation = wfdb.rdann(os.path.join(self.raw_path, record_id), 'atr')

            # Get sample frequency
            fs = record.__dict__['fs']

            # Get waveform
            waveform = record.__dict__['p_signal']  #shape: (length, n_channels=2)

            # labels
            labels = [label[1:] for label in annotation.__dict__['aux_note']]

            # Samples
            sample = annotation.__dict__['sample']

            padded_labels = np.zeros(len(waveform))
            labels = [label.strip('\x00') for label in labels]
            for i,l in enumerate(labels):

                if i==len(labels)-1:
                    padded_labels[sample[i]:] = afib_dict[l]
                else:
                    padded_labels[sample[i]:sample[i+1]] = afib_dict[l]
            padded_labels = padded_labels[sample[0]:]
            all_labels.append(padded_labels)
            all_signals.append(waveform[sample[0]:,:].T)

        return all_signals, all_labels

In [18]:
def feature_stft(x_train, n_fft = 100, hop_length=50, win_length=100, phase = False, stack_axes = True):
    train_tensor = torch.tensor(x_train).transpose(1,2).reshape(-1, 2).transpose(0,1)
    
    x = torch.stft(
        input=train_tensor,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=torch.hann_window(win_length),
        center=False,
        return_complex=True)  # [num_channels, num_bins, num_frames]

    x_cartesian = src.utils.complex_to_cartesian(x)
    x_magnitude = src.utils.complex_to_magnitude(x, expand=True)

    x = x_cartesian if phase else x_magnitude
    if stack_axes:
        # Stack all spectrograms and put time dim first:
        # [num_channels, num_bins, num_frames, stft_parts] ->
        # [num_frames, num_channels x num_bins x stft_parts]
        x = einops.rearrange(x, 'C F T P -> T (C F P)')  # P=2
    else:
        x = einops.rearrange(x, 'C F T P -> T C F P')
    
    return x

In [19]:
# Function to load a pickle file
def load_pickle_file(filepath):
    with open(filepath, 'rb') as file:
        data = pickle.load(file)
    return data

def feature_stft(x_train, n_fft = 100, hop_length=50, win_length=100, phase = False, stack_axes = True):
    
    print(x_train.shape)
    train_tensor = torch.tensor(x_train).transpose(1,2).reshape(-1, 2).transpose(0,1)
    
    x = torch.stft(
        input=train_tensor,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=torch.hann_window(win_length),
        center=False,
        return_complex=True)  # [num_channels, num_bins, num_frames]

    x_cartesian = src.utils.complex_to_cartesian(x)
    x_magnitude = src.utils.complex_to_magnitude(x, expand=True)

    x = x_cartesian if phase else x_magnitude
    if stack_axes:
        # Stack all spectrograms and put time dim first:
        # [num_channels, num_bins, num_frames, stft_parts] ->
        # [num_frames, num_channels x num_bins x stft_parts]
        x = einops.rearrange(x, 'C F T P -> T (C F P)')  # P=2
    else:
        x = einops.rearrange(x, 'C F T P -> T C F P')
    
    return x

def windowed_labels(
    labels,
    num_labels,
    frame_length,
    frame_step=None,
    pad_end=False,
    kind='density',
):
    """Generates labels that correspond to STFTs

    With kind=None we are able to split the given labels
    array into batches. (T, C) -> (B, T', C)

    Parameters
    ----------
    labels : np.array

    Returns
    -------
    np.array
    """
    labels = torch.tensor(labels).view(-1)
    
    # Labels should be a single vector (int-likes) or kind has to be None
    labels = np.asarray(labels)
    
    if kind is not None and not labels.ndim == 1:
        raise ValueError('Labels must be a vector')
    if not (labels >= 0).all():
        raise ValueError('All labels must be >= 0')
    if not (labels < num_labels).all():
        raise ValueError(f'All labels must be < {num_labels} (num_labels)')
    # Kind determines how labels in each window should be processed
    if not kind in {'counts', 'density', 'onehot', 'argmax', None}:
        raise ValueError('`kind` must be in {counts, density, onehot, argmax, None}')
    # Let frame_step default to one full frame_length
    frame_step = frame_length if frame_step is None else frame_step
    # Process labels with a sliding window. TODO: vectorize?
    output = []
    for i in range(0, len(labels), frame_step):
        chunk = labels[i:i+frame_length]
        chunk = chunk.astype(int)
        # Ignore incomplete end chunk unless padding is enabled
        if len(chunk) < frame_length and not pad_end:
            continue
        # Just append the chunk if kind is None
        if kind == None:
            output.append(chunk)
            continue
        # Count the occurences of each label
        counts = np.bincount(chunk, minlength=num_labels)
        # Then process based on kind
        if kind == 'counts':
            output.append(counts)
        elif kind == 'density':
            output.append(counts / len(chunk))
        elif kind == 'onehot':
            one_hot = np.zeros(num_labels)
            one_hot[np.argmax(counts)] = 1
            output.append(one_hot)
        elif kind == 'argmax':
            output.append(np.argmax(counts))
    if pad_end:
        return output
    else:
        return torch.tensor(output)
    
class STFTDataset(Dataset):
    
    def __init__(self, data_path, class_to_exclude=3, n_fft = 250, hop_length=125, win_length=250, seq_length=500, num_labels=4):
        """
        Args:
            x_data (Tensor): The input features, e.g., from STFT.
            y_data (Tensor): The corresponding labels, windowed and processed.
            seq_length (int): The length of each sequence.
        """
        
        # Load each of the pickle files
        tensor_data = load_pickle_file(f'{data_path}/tensor_data.pkl')
        tensor_label = load_pickle_file(f'{data_path}/tensor_label.pkl')
        
        
        x_data = feature_stft(tensor_data, n_fft = n_fft, hop_length=hop_length, win_length=win_length)
        y_data = windowed_labels(labels=tensor_label, num_labels=num_labels, frame_length=n_fft, frame_step=hop_length, kind='argmax')

        self.x_data = x_data
        self.y_data = y_data
        self.seq_length = seq_length
        self.class_to_exclude = class_to_exclude
        
        # Create a mask that filters out the class_to_exclude
        mask = self.y_data != self.class_to_exclude

        # Apply the mask to filter the data
        self.x_data = self.x_data[mask]
        self.y_data = self.y_data[mask]
        
    def __len__(self):
        # Return the number of full sequences in the dataset
        return len(self.x_data) // self.seq_length

    def __getitem__(self, idx):
        """
        Returns a tuple (input, label) for the given index.
        The input is reshaped to (seq_length, features).
        """
        start_idx = idx * self.seq_length
        end_idx = start_idx + self.seq_length

        # Extract the sequence of data and corresponding labels
        x_seq = self.x_data[start_idx:end_idx]
        y_seq = self.y_data[start_idx:end_idx]

        return x_seq, y_seq

In [20]:
dataset = STFTDataset("data/processed", n_fft = 250, seq_length=119, class_to_exclude=7)

(23, 2, 8324850)


In [21]:
a = AFDB()
a.generate_raw_db()
a.generate_processed_db()

Generating Processed MIT-BIH Atrial Fibrillation Database ...


In [22]:
# Function to load a pickle file
def load_pickle_file(filepath):
    with open(filepath, 'rb') as file:
        data = pickle.load(file)
    return data

# Load each of the pickle files
tensor_data = load_pickle_file('./data/processed/tensor_data.pkl')
tensor_label = load_pickle_file('./data/processed/tensor_label.pkl')

In [23]:
x_data = feature_stft(tensor_data, n_fft = 250, hop_length=125, win_length=250)
y_data = windowed_labels(labels=tensor_label, num_labels=4, frame_length=250, frame_step=125, kind='argmax')

(23, 2, 8324850)


In [24]:
x_data.shape

torch.Size([1531771, 252])

In [25]:
y_data.shape

torch.Size([1531771])

In [26]:
# Get unique labels and their counts
unique_labels, counts = np.unique(y_data, return_counts=True)

# Print the results
for label, count in zip(unique_labels, counts):
    print(f"Label {label}: {count} occurrences")

Label 0: 612853 occurrences
Label 1: 4974 occurrences
Label 2: 661 occurrences
Label 3: 913283 occurrences


In [27]:
def count_consecutive_labels(labels):
    if len(labels) == 0:
        return []

    result = []
    current_label = labels[0]
    count = 1

    for i in range(1, len(labels)):
        if labels[i] == current_label:
            count += 1
        else:
            result.append((current_label, count))
            current_label = labels[i]
            count = 1

    # Append the last label count
    result.append((current_label, count))
    return result

labels = y_data
consecutive_counts = count_consecutive_labels(labels)

for label, count in consecutive_counts:
    if label != 3:
        print(f"Label {label}: {count} times")

Label 0: 625 times
Label 0: 1341 times
Label 0: 881 times
Label 1: 955 times
Label 0: 12314 times
Label 1: 2139 times
Label 0: 2133 times
Label 0: 689 times
Label 0: 91 times
Label 0: 168 times
Label 1: 12 times
Label 0: 391 times
Label 0: 732 times
Label 0: 805 times
Label 0: 846 times
Label 0: 471 times
Label 0: 487 times
Label 0: 1073 times
Label 0: 140 times
Label 0: 155 times
Label 0: 829 times
Label 0: 202 times
Label 0: 697 times
Label 0: 262 times
Label 0: 152 times
Label 0: 288 times
Label 0: 149 times
Label 0: 358 times
Label 0: 214 times
Label 0: 140 times
Label 0: 198 times
Label 0: 253 times
Label 0: 376 times
Label 0: 73 times
Label 0: 114 times
Label 0: 330 times
Label 0: 151 times
Label 0: 844 times
Label 0: 408 times
Label 0: 522 times
Label 0: 213 times
Label 0: 524 times
Label 0: 528 times
Label 0: 154 times
Label 0: 460 times
Label 0: 47 times
Label 0: 474 times
Label 0: 1545 times
Label 0: 206 times
Label 0: 87 times
Label 0: 4212 times
Label 0: 177 times
Label 0: 

In [28]:
class CustomDataset(Dataset):
    def __init__(self, x_data, y_data, seq_length, class_to_exclude):
        """
        Args:
            x_data (Tensor): The input features, e.g., from STFT.
            y_data (Tensor): The corresponding labels, windowed and processed.
            seq_length (int): The length of each sequence.
        """
        self.x_data = x_data
        self.y_data = y_data
        self.seq_length = seq_length
        self.class_to_exclude = class_to_exclude
        
        # Create a mask that filters out the class_to_exclude
        mask = self.y_data != self.class_to_exclude

        # Apply the mask to filter the data
        self.x_data = self.x_data[mask]
        self.y_data = self.y_data[mask]
        
    def __len__(self):
        # Return the number of full sequences in the dataset
        return len(self.x_data) // self.seq_length

    def __getitem__(self, idx):
        """
        Returns a tuple (input, label) for the given index.
        The input is reshaped to (seq_length, features).
        """
        start_idx = idx * self.seq_length
        end_idx = start_idx + self.seq_length

        # Extract the sequence of data and corresponding labels
        x_seq = self.x_data[start_idx:end_idx]
        y_seq = self.y_data[start_idx:end_idx]

        return x_seq, y_seq

In [29]:
class SequentialRandomSampler(torch.utils.data.Sampler):
    def __init__(self, data_source, batch_size):
        self.data_source = data_source
        self.batch_size = batch_size


    def __iter__(self):
        

        indices = list(range(len(self.data_source)))
        
        remaining = len(indices) % self.batch_size
        if remaining > 0:
            indices = indices[:-remaining]
        final_indices = np.reshape(indices, (-1, self.batch_size))

        # Shuffle the batches
        np.random.shuffle(final_indices)

        # Flatten the list of batches to get the final order of indices
        final_indices = [idx for batch in final_indices for idx in batch]
        
        return iter(final_indices)

    def __len__(self):
        return len(self.data_source)

In [30]:
# dataset = CustomDataset(x_data, y_data, seq_length=500, class_to_exclude=3)
valid_split = 0.2
            
valid_amount = int(np.floor(len(dataset)*valid_split))
train_amount = len(dataset) - valid_amount


train_indices = list(range(train_amount))
valid_indices = list(range(train_amount, train_amount + valid_amount))

# Create subsets
train_ds = Subset(dataset, train_indices)
valid_ds = Subset(dataset, valid_indices)

# train_ds, valid_ds = random_split(dataset, [train_amount, valid_amount])
def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

batch_size= 32
train_loader = torch.utils.data.DataLoader(
    dataset=train_ds,
    batch_size=batch_size,
#     sampler=SequentialRandomSampler(train_ds, batch_size),
    shuffle = True,
    num_workers=0,
    drop_last = True,
    worker_init_fn=seed_worker
)

valid_loader = torch.utils.data.DataLoader(
    dataset=valid_ds,
    batch_size=batch_size,
#             sampler=SequentialRandomSampler(valid_ds, args['batch_size']),
    shuffle = False,
    num_workers=0,
    drop_last = True,
    worker_init_fn=seed_worker
)

In [31]:
train_amount

10298

In [32]:
valid_amount

2574

In [33]:
len(train_ds)

10298

In [34]:
len(valid_ds)

2574

In [35]:
# Create a DataLoader to iterate over the dataset
data_loader = DataLoader(train_ds, batch_size=512, shuffle=False)

# Initialize a list to store all labels
all_labels = []

# Iterate over the DataLoader to collect all labels
for _, labels in data_loader:
    all_labels.extend(labels.numpy())  # Convert the labels to numpy and extend the list

# Convert the list of labels to a numpy array
tall_labels = np.array(all_labels)

# Get unique labels and their counts
unique_labels, counts = np.unique(tall_labels, return_counts=True)

# Print the results
for label, count in zip(unique_labels, counts):
    print(f"Label {label}: {count} occurrences")

Label 0: 454757 occurrences
Label 1: 4002 occurrences
Label 2: 202 occurrences
Label 3: 766501 occurrences


In [36]:
# Create a DataLoader to iterate over the dataset
data_loader = DataLoader(valid_ds, batch_size=512, shuffle=False)

# Initialize a list to store all labels
all_labels = []

# Iterate over the DataLoader to collect all labels
for _, labels in data_loader:
    all_labels.extend(labels.numpy())  # Convert the labels to numpy and extend the list

# Convert the list of labels to a numpy array
vall_labels = np.array(all_labels)

# Get unique labels and their counts
unique_labels, counts = np.unique(vall_labels, return_counts=True)

# Print the results
for label, count in zip(unique_labels, counts):
    print(f"Label {label}: {count} occurrences")

Label 0: 158093 occurrences
Label 1: 972 occurrences
Label 2: 459 occurrences
Label 3: 146782 occurrences


In [37]:
def count_consecutive_labels(labels):
    if len(labels) == 0:
        return []

    result = []
    current_label = labels[0]
    count = 1

    for i in range(1, len(labels)):
        if labels[i] == current_label:
            count += 1
        else:
            result.append((current_label, count))
            current_label = labels[i]
            count = 1

    # Append the last label count
    result.append((current_label, count))
    return result

labels = tall_labels.reshape(-1)
consecutive_counts = count_consecutive_labels(labels)

for label, count in consecutive_counts:
    print(f"Label {label}: {count} times")

Label 3: 17897 times
Label 0: 625 times
Label 3: 49 times
Label 0: 1341 times
Label 3: 61 times
Label 0: 881 times
Label 3: 50583 times
Label 1: 955 times
Label 3: 142 times
Label 0: 12314 times
Label 3: 16270 times
Label 1: 2139 times
Label 3: 159 times
Label 0: 2133 times
Label 3: 66 times
Label 0: 689 times
Label 3: 112 times
Label 0: 91 times
Label 3: 52 times
Label 0: 168 times
Label 3: 7470 times
Label 1: 12 times
Label 3: 43610 times
Label 0: 391 times
Label 3: 145 times
Label 0: 732 times
Label 3: 288 times
Label 0: 805 times
Label 3: 101 times
Label 0: 846 times
Label 3: 342 times
Label 0: 471 times
Label 3: 86 times
Label 0: 487 times
Label 3: 121 times
Label 0: 1073 times
Label 3: 82 times
Label 0: 140 times
Label 3: 103 times
Label 0: 155 times
Label 3: 4477 times
Label 0: 829 times
Label 3: 59 times
Label 0: 202 times
Label 3: 94 times
Label 0: 697 times
Label 3: 61 times
Label 0: 262 times
Label 3: 71 times
Label 0: 152 times
Label 3: 43 times
Label 0: 288 times
Label 3: 

In [38]:
class FeatureProjector(nn.Module):
    def __init__(self, input_size=156, output_size=32):
        super(FeatureProjector, self).__init__()
        
        # 1D Convolutional Layers
        self.conv1 = nn.Conv1d(in_channels=input_size, out_channels=128, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=128, out_channels=64, kernel_size=1)
        self.conv3 = nn.Conv1d(in_channels=64, out_channels=output_size, kernel_size=1)
        
        # Batch Normalization
        self.bn1 = nn.BatchNorm1d(128)
        self.bn2 = nn.BatchNorm1d(64)
        self.bn3 = nn.BatchNorm1d(output_size)
        
    def forward(self, x):
        # Input x shape: (batch_size, sequence_length, input_size)
        
        # Permute to match Conv1D input: (batch_size, input_size, sequence_length)
        x = x.float()
        x = x.permute(0, 2, 1)
        
        # First convolutional layer
        x = self.conv1(x)  # Shape: (batch_size, 128, sequence_length)
        x = self.bn1(x)
        x = F.relu(x)
        
        # Second convolutional layer
        x = self.conv2(x)  # Shape: (batch_size, 64, sequence_length)
        x = self.bn2(x)
        x = F.relu(x)
        
        # Third convolutional layer
        x = self.conv3(x)  # Shape: (batch_size, output_size, sequence_length)
        x = self.bn3(x)
        x = F.relu(x)
        
        # Permute back to original order: (batch_size, sequence_length, output_size)
        x = x.permute(0, 2, 1)
        
        return x

In [39]:
class FeatureProjector2(nn.Module):
    def __init__(self, input_size=156, output_size=32):
        super(FeatureProjector2, self).__init__()
        
        # 1D Convolutional Layers with ReLU activations and Batch Normalization
        self.conv1 = nn.Conv1d(in_channels=input_size, out_channels=128, kernel_size=3)
        self.bn1 = nn.BatchNorm1d(128)
        self.relu1 = nn.ReLU()
        
        self.conv2 = nn.Conv1d(in_channels=128, out_channels=64, kernel_size=3)
        self.bn2 = nn.BatchNorm1d(64)
        self.relu2 = nn.ReLU()
        
        self.conv3 = nn.Conv1d(in_channels=64, out_channels=output_size, kernel_size=3)
        self.bn3 = nn.BatchNorm1d(output_size)
        self.relu3 = nn.ReLU()

        
    def forward(self, x):
        # Input x shape: (batch_size, sequence_length, input_size)
        
        # Permute to match Conv1D input: (batch_size, input_size, sequence_length)
        x = x.float()
        x = x.permute(0, 2, 1)
        
        # First convolutional layer
        x = self.conv1(x)  # Shape: (batch_size, 128, sequence_length - 2)
        x = self.bn1(x)
        x = self.relu1(x)
        
        # Second convolutional layer
        x = self.conv2(x)  # Shape: (batch_size, 64, sequence_length - 4)
        x = self.bn2(x)
        x = self.relu2(x)
        
        # Third convolutional layer
        x = self.conv3(x)  # Shape: (batch_size, output_size, sequence_length - 6)
        x = self.bn3(x)
        x = self.relu3(x)
        
        # Permute back to original order: (batch_size, sequence_length - 6, output_size)
        x = x.permute(0, 2, 1)
        
        return x


In [40]:
class ECGEncoder(nn.Module):
    def __init__(self, input_size=102, output_size=64):
        super(ECGEncoder, self).__init__()
        
        self.conv1 = nn.Conv1d(input_size, 32, kernel_size=7, stride=2, padding=3)  # Downsampling by 2
        self.conv2 = nn.Conv1d(32, 64, kernel_size=7, stride=2, padding=3)             # Downsampling by 2
        self.conv3 = nn.Conv1d(64, 128, kernel_size=7, stride=2, padding=3)            # Downsampling by 2
        self.conv4 = nn.Conv1d(128, 256, kernel_size=7, stride=2, padding=3)           # Downsampling by 2
        self.conv5 = nn.Conv1d(256, 512, kernel_size=7, stride=2, padding=3)           # Downsampling by 2
        self.conv6 = nn.Conv1d(512, 1024, kernel_size=7, stride=2, padding=3)          # Downsampling by 2
        
        # Fully connected layer to map to the 64-dimensional output
        self.fc = nn.Linear(1024, output_size)
        
    def forward(self, x):
        x = x.float()
        x = x.permute(0, 2, 1)
        
        # Apply each convolutional layer with ReLU activation
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.relu(self.conv5(x))
        x = F.relu(self.conv6(x))
        
        # Global average pooling (to get a single vector for each channel)
#         x = F.adaptive_avg_pool1d(x, 1).squeeze(-1)
        
        # Final fully connected layer to produce 64-dimensional vector
#         x = self.fc(x)
#         x = x.permute(0, 2, 1)
        
        return x

In [41]:
def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def initialize_model_supervised(seed):
    set_seed(seed)
    model = FeatureProjector(input_size=252, output_size=4)
    return model

In [42]:
# Instantiate the model
seed = 42
attn_model = initialize_model_supervised(seed)

In [None]:
# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(attn_model.parameters(), lr=0.001)  # Example optimizer

# Move model to device
device = torch.device("cpu" if torch.backends.mps.is_available() and torch.backends.mps.is_built() else "cpu")
attn_model.to(device)

# Training and validation loop
num_epochs = 15
for epoch in range(num_epochs):
    # Training phase
    attn_model.train()  # Set the model to training mode
    train_running_loss = 0.0
    train_correct_predictions = 0
    train_total_samples = 0
    
    all_preds = []
    all_labels = []
    
    for time_series, labels in tqdm(train_loader):
        time_series = time_series.to(device)
        labels = labels.to(device)
        
        # Forward pass
        features = attn_model(time_series)

        # Flatten y_hat to have dimensions [batch_size * sequence_length, num_classes]
        y_hat_flat = features.reshape(-1, features.size(-1))

        # Reshape y to have dimensions [batch_size * sequence_length]
        labels_flat = labels.view(-1)
        
        # Compute training loss
        train_loss = criterion(y_hat_flat, labels_flat)

        # Backward pass and optimization
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        # Update training statistics
        train_running_loss += train_loss.item() * time_series.size(0)
        
        _, predicted = torch.max(y_hat_flat, 1)
        train_correct_predictions += (predicted == labels_flat).sum().item()
        
        #Store the labels for future computation of F1-score
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels_flat.cpu().numpy())
        
        train_total_samples += labels_flat.size(0)
    
    # Calculate average training loss and accuracy for the epoch
    train_epoch_loss = train_running_loss / len(train_loader.dataset)
    train_epoch_accuracy = 100*train_correct_predictions / train_total_samples
    
    f1 = f1_score(all_labels, all_preds,average='weighted')
    
    print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_epoch_loss:.4f},\
          Train Accuracy: {train_epoch_accuracy:.2f}%, F1-score: {f1:.4f}")

    # Validation phase
    attn_model.eval()  # Set the model to evaluation mode
    val_running_loss = 0.0
    val_correct_predictions = 0
    val_total_samples = 0
    
    with torch.no_grad():
        val_preds = []
        val_labels = []
        for time_series, labels in tqdm(valid_loader):
            time_series = time_series.to(device)
            labels = labels.to(device)

            # Forward pass
            features = attn_model(time_series)

            # Flatten y_hat to have dimensions [batch_size * sequence_length, num_classes]
            y_hat_flat = features.reshape(-1, features.size(-1))

            # Reshape y to have dimensions [batch_size * sequence_length]
            labels_flat = labels.view(-1)

            # Compute validation loss
            val_loss = criterion(y_hat_flat, labels_flat)

            # Update validation statistics
            val_running_loss += val_loss.item() * time_series.size(0)
            
            _, predicted = torch.max(y_hat_flat, 1)
            val_correct_predictions += (predicted == labels_flat).sum().item()
            val_total_samples += labels_flat.size(0)
            
            val_preds.extend(predicted.cpu().numpy())
            val_labels.extend(labels_flat.cpu().numpy())
    
    # Calculate average validation loss and accuracy for the epoch
    val_epoch_loss = val_running_loss / len(valid_loader.dataset)
    val_epoch_accuracy = 100*val_correct_predictions / val_total_samples
    
    f1 = f1_score(val_labels, val_preds, average='weighted')
    print(f"Epoch {epoch + 1}/{num_epochs}, Val Loss: {val_epoch_loss:.4f},\
          Val Accuracy: {val_epoch_accuracy:.2f}%, F1-score: {f1:.2f}")

print("Training and validation complete.")

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 321/321 [00:21<00:00, 15.02it/s]


Epoch 1/15, Train Loss: 0.9088,          Train Accuracy: 75.25%, F1-score: 0.7724


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [00:04<00:00, 19.08it/s]


Epoch 1/15, Val Loss: 1.1100,          Val Accuracy: 59.67%, F1-score: 0.58


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 321/321 [00:21<00:00, 15.28it/s]


Epoch 2/15, Train Loss: 0.6410,          Train Accuracy: 85.29%, F1-score: 0.8550


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [00:04<00:00, 18.94it/s]


Epoch 2/15, Val Loss: 1.1634,          Val Accuracy: 58.28%, F1-score: 0.52


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 321/321 [00:21<00:00, 14.96it/s]


Epoch 3/15, Train Loss: 0.5044,          Train Accuracy: 88.74%, F1-score: 0.8882


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [00:04<00:00, 18.68it/s]


Epoch 3/15, Val Loss: 1.2570,          Val Accuracy: 57.90%, F1-score: 0.52


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 321/321 [00:21<00:00, 15.25it/s]


Epoch 4/15, Train Loss: 0.4209,          Train Accuracy: 90.08%, F1-score: 0.9011


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [00:04<00:00, 19.10it/s]


Epoch 4/15, Val Loss: 1.0021,          Val Accuracy: 66.94%, F1-score: 0.66


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 321/321 [00:20<00:00, 15.43it/s]


Epoch 5/15, Train Loss: 0.3419,          Train Accuracy: 92.30%, F1-score: 0.9229


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [00:04<00:00, 19.21it/s]


Epoch 5/15, Val Loss: 1.0164,          Val Accuracy: 62.81%, F1-score: 0.60


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 321/321 [00:20<00:00, 15.34it/s]


Epoch 6/15, Train Loss: 0.2984,          Train Accuracy: 92.93%, F1-score: 0.9292


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [00:04<00:00, 19.45it/s]


Epoch 6/15, Val Loss: 1.0772,          Val Accuracy: 59.25%, F1-score: 0.55


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 321/321 [00:20<00:00, 15.35it/s]


Epoch 7/15, Train Loss: 0.2669,          Train Accuracy: 93.40%, F1-score: 0.9339


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 80/80 [00:04<00:00, 19.21it/s]


Epoch 7/15, Val Loss: 1.0195,          Val Accuracy: 63.02%, F1-score: 0.60


 51%|████████████████████████████████████████████████████████████████████████████                                                                         | 164/321 [00:11<00:10, 15.01it/s]

In [None]:
# Wrapper dataset class to flatten the batches
class FlattenedDataset(Dataset):
    def __init__(self, original_dataset):
        self.original_dataset = original_dataset
        self.num_batches = len(original_dataset)
        self.batch_size = original_dataset[0][0].shape[0]  # Assuming shape [599, 156]

    def __len__(self):
        return self.num_batches * self.batch_size

    def __getitem__(self, idx):
        batch_idx = idx // self.batch_size
        sample_idx = idx % self.batch_size
        data_batch, label_batch = self.original_dataset[batch_idx]
        return data_batch[sample_idx], label_batch[sample_idx]

In [None]:
def load_balanced_dataset(dataset, class_counts):
    # Initialize dictionary to store indices of each class
    class_indices = {label: [] for label in class_counts.keys()}
    
    # Populate class_indices with indices of each class
    for idx in range(len(dataset)):
        _, label = dataset[idx]
        label = label.item()  # Ensure the label is a scalar
        if label in class_indices:
            class_indices[label].append(idx)

    # Ensure each class has the required number of instances
    balanced_indices = []
    for label, count in class_counts.items():
        if len(class_indices[label]) >= count:
            balanced_indices.extend(random.sample(class_indices[label], count))
        else:
            raise ValueError(f"Not enough instances of class {label} to satisfy the requested count")

    # Create a subset of the dataset with the balanced indices
    balanced_subset = Subset(dataset, balanced_indices)
    return balanced_subset

In [None]:
# Define the desired count for each class
desired_count_per_class = {0: 1577, 1: 972, 2: 459, 3: 1467}
flattened_data = FlattenedDataset(valid_ds)

# Load balanced dataset
balanced_dataset = load_balanced_dataset(flattened_data, desired_count_per_class)

# Create a DataLoader for the balanced dataset
train_balanced_dataloader = DataLoader(balanced_dataset, batch_size=5000, shuffle=True, worker_init_fn=seed_worker)

# # Print the labels in the balanced dataset to verify
# for features_batch, labels_batch in train_balanced_dataloader:
#     print(features_batch.shape)

In [None]:
# Initialize dictionary to count occurrences of each class
label_counts = {label: 0 for label in flipped_afib_dict}

# Iterate through the dataloader to count label occurrences
for _, labels in train_balanced_dataloader:
    for label in labels:
        label = label.item()  # Convert tensor to scalar
        if label in label_counts:
            label_counts[label] += 1

# Print the counts
print(label_counts)

In [None]:
# Function to apply PCA and visualize the results
def visualize_pca2(images, labels, class_names, model):
    # Flatten the images to vectors (assuming they are 2D images)
#     flattened_images = images.view(images.size(0), -1).numpy()

#     # Standardize the data before applying PCA
#     scaler = StandardScaler()
#     standardized_images = scaler.fit_transform(flattened_images)

    # Apply PCA
    pca = PCA(n_components=2)

    trained_pca = pca.fit(images.view(-1, images.size(-1)))
    reduced_features = trained_pca.transform(images.view(-1, images.size(-1)))
    
    # Plot the results
    plt.figure(figsize=(10, 8))
    for i in range(4):
        indices = labels == i
        count = (labels == i).sum()
        plt.scatter(reduced_features[indices, 0], reduced_features[indices, 1], label=class_names[i])

    plt.title('PCA Visualization of Image Data')
    plt.xlabel('Principal Component 1')
    plt.ylabel('Principal Component 2')
    plt.legend()
    plt.show()
    
    model.eval()
    
    with torch.no_grad():
        standardized_images = model(images)

    trained_pca = pca.fit(standardized_images.view(-1, standardized_images.size(-1)))
    reduced_features = trained_pca.transform(standardized_images.view(-1, standardized_images.size(-1)))

    # Plot the results
    plt.figure(figsize=(10, 8))
    for i in range(4):
        indices = labels == i
        count = (labels == i).sum()
        plt.scatter(reduced_features[indices, 0], reduced_features[indices, 1], label=class_names[i])

    plt.title('PCA Visualization of Image Data (Resnet features)')
    plt.xlabel('Principal Component 1')
    plt.ylabel('Principal Component 2')
    plt.legend()
    plt.show()
    
    # Get explained variance ratio
    explained_variance_ratio = trained_pca.explained_variance_ratio_
    cumulative_explained_variance = np.cumsum(explained_variance_ratio)

    # Print explained variance ratio
    print("Explained variance ratio:", explained_variance_ratio)
    print("Cumulative explained variance:", cumulative_explained_variance)

    # Plot explained variance ratio
    plt.figure(figsize=(8, 6))
    plt.bar(range(1, len(explained_variance_ratio) + 1), explained_variance_ratio, alpha=0.5, align='center', label='Individual explained variance')
    plt.step(range(1, len(cumulative_explained_variance) + 1), cumulative_explained_variance, where='mid', label='Cumulative explained variance')
    plt.ylabel('Explained variance ratio')
    plt.xlabel('Principal components')
    plt.legend(loc='best')
    plt.tight_layout()
    plt.show()

In [None]:
# Assuming your dataset has a 'classes' attribute containing class names
# class_names = dataset.clases

# Assuming you have a DataLoader named 'train_loader'
for batch in train_balanced_dataloader:
    images, labeli = batch
    images = images.view(-1, 1, images.shape[-1])
    
    visualize_pca2(images, labeli, flipped_afib_dict, attn_model)

In [None]:
# Function to apply t-SNE and visualize the results
def visualize_tsne(images, labels, class_names, model):
    # Flatten the images to vectors (assuming they are 2D images)
    # Flatten the images to vectors (assuming they are 2D images)
    flattened_images = images.view(images.size(0), -1).numpy()

    # Standardize the data before applying t-SNE
    scaler = StandardScaler()
    standardized_images = scaler.fit_transform(flattened_images)

    # Apply t-SNE
    tsne = TSNE(n_components=2, random_state=42)
    reduced_features = tsne.fit_transform(standardized_images)

    # Plot the results
    plt.figure(figsize=(10, 8))
    for i in range(4):
        indices = labels == i
        plt.scatter(reduced_features[indices, 0], reduced_features[indices, 1], label=class_names[i])

    plt.title('t-SNE Visualization of Image Data')
    plt.xlabel('t-SNE Component 1')
    plt.ylabel('t-SNE Component 2')
    plt.legend()
    plt.show()

    # Evaluate the model and get features
    model.eval()
    with torch.no_grad():
        model_features = model(images)

    # Standardize model features before applying t-SNE
    standardized_model_features = scaler.fit_transform(model_features.view(-1, model_features.size(-1)).cpu().numpy())

    # Apply t-SNE to model features
    reduced_features_model = tsne.fit_transform(standardized_model_features)

    # Plot the results
    plt.figure(figsize=(10, 8))
    for i in range(4):
        indices = labels == i
        plt.scatter(reduced_features_model[indices, 0], reduced_features_model[indices, 1], label=class_names[i])

    plt.title('t-SNE Visualization of Image Data (ResNet features)')
    plt.xlabel('t-SNE Component 1')
    plt.ylabel('t-SNE Component 2')
    plt.legend()
    plt.show()

In [None]:
# Assuming your dataset has a 'classes' attribute containing class names
# class_names = dataset.clases

# Assuming you have a DataLoader named 'train_loader'
for batch in train_balanced_dataloader:
    images, labeli = batch
    images = images.view(-1, 1, images.shape[-1])
    
    visualize_tsne(images, labeli, flipped_afib_dict, attn_model)

### 1. Dynamic Margin-CL

In [None]:
class HATCL_LOSS(torch.nn.Module):
    def __init__(self, temperature=0.5):
        super(HATCL_LOSS, self).__init__()
        self.temperature = temperature

    def forward(self, features):
        # Normalize the feature vectors
        features_normalized = F.normalize(features, dim=-1, p=2)

        # Calculate the cosine similarity matrix
        similarities = torch.matmul(features_normalized, features_normalized.T)
        
        exp_similarities = torch.exp(similarities / self.temperature)
        
        # Removing the similarity of a window with itself i.e main diagonal
        exp_similarities = exp_similarities - torch.diag(exp_similarities.diag())        

        # Lower diagonal elements represent positive pairs
        positives = torch.diagonal(exp_similarities, offset=-1)

        # The denominator is the sum of the column vectors minus the positives
        denominator = torch.sum(exp_similarities[:,:-1], dim=0) - positives
        
        # Calculate NT-Xent loss
        loss = -torch.log(positives / denominator).mean()

        return loss

In [None]:
class LS_HATCL_LOSS(torch.nn.Module):
    def __init__(self, temperature=0.5):
        super(LS_HATCL_LOSS, self).__init__()
        self.temperature = temperature

    def forward(self, features):
        
        # Normalize the feature vectors
        features_normalized = torch.nn.functional.normalize(features, p=2, dim=-1)

        # Calculate the cosine similarity matrix
        similarities = torch.matmul(features_normalized, features_normalized.T)

        
        exp_similarities = torch.exp(similarities / self.temperature)
        
        # Removing the similarity of a window with itself i.e main diagonal
        exp_similarities = exp_similarities - torch.diag(exp_similarities.diag())        

        # Lower diagonal elements represent positive pairs
        lower_diag = torch.diagonal(exp_similarities, offset=-1)
        
        # The numerator is the sum of shifted left and right of the positive pairs
        numerator = lower_diag[1:] + lower_diag[:-1]
        
        # The denominator is the sum of the column vectors minus the positives
        denominator = torch.sum(exp_similarities[:,:-2], dim=0) - lower_diag[:-1]\
                + (torch.sum(exp_similarities[:,1:-1], dim=0)  - (lower_diag[1:] + lower_diag[:-1]))
        
        
        # Calculate NT-Xent loss
        loss = -torch.log(numerator / denominator).mean()
        
#         print("Similarities: ", similarities)
#         print("Exp Similarities: ", exp_similarities)
#         print("Numerator: ", numerator)
#         print("Denominator: ", denominator)
        
        return loss

In [None]:
def initialize_model(seed):
    set_seed(seed)
    model = FeatureProjector(input_size=252, output_size=32)
    return model

In [None]:
# Instantiate the model
seed = 42
attn_model = initialize_model(seed)

In [None]:
time_series.shape

In [None]:
# Define loss function and optimizer
cl_loss = LS_HATCL_LOSS(temperature=0.5)
optimizer = optim.AdamW(attn_model.parameters(), lr=0.01)  # Example optimizer

# Move model to device
device = torch.device("cpu" if torch.backends.mps.is_available() and torch.backends.mps.is_built() else "cpu")
attn_model.to(device)
l = 0.4

# Training and validation loop
num_epochs = 5

for epoch in range(num_epochs):
    # Training phase
    attn_model.train()  # Set the model to training mode
    train_running_loss = 0.0

    for batch_idx, (time_series, labels) in enumerate(tqdm(train_loader)):
        time_series = time_series.to(device)
        
        
        
        # Forward pass
        features = attn_model(time_series)
        
        linear_layer = nn.Linear(in_features=time_series.shape[-1], out_features=32)

        scaled_timeseries = linear_layer(time_series.float())
        mse_loss = F.mse_loss(scaled_timeseries, features)

    
        # Flatten features to have dimensions [batch_size * sequence_length, feature dim]
        features = features.reshape(-1, features.size(-1))

        # Compute training loss
        contrast_loss = cl_loss(features)
        
#         print(mse_loss.item())
#         print(contrast_loss.item())
        
        train_loss = l*mse_loss + (1-l)*contrast_loss

        # Backward pass and optimization
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()



        # Update training statistics
        train_running_loss += train_loss.item() * time_series.size(0)

#         # Log training loss to Wandb
#         if config.WANDB and batch_idx % 10 == 0:
#             wandb.log({'Train Loss': train_running_loss /(batch_idx + 1), 'Epoch': epoch})


    train_epoch_loss = train_running_loss / len(train_loader.dataset)
    print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_epoch_loss:.4f}")

In [None]:
# Calculate Davies-Bouldin Index
db_index = davies_bouldin_score(images.squeeze(), labeli)
ch_index = calinski_harabasz_score(images.squeeze(), labeli)
slh_index = silhouette_score(images.squeeze(), labeli)

print(f"Davies-Bouldin Index: {db_index}")
print(f"Calinski Harabasz Index: {ch_index}")
print(f"Silhouette Index: {slh_index}")

In [None]:
attn_model.eval()
with torch.no_grad():
    features = attn_model(images)

In [None]:
kmeans = KMeans(n_clusters=4, random_state=1).fit(features.detach().squeeze())
cluster_labels = kmeans.labels_

# Calculate Davies-Bouldin Index
db_index2 = davies_bouldin_score(features.detach().squeeze(), cluster_labels)
ch_index2 = calinski_harabasz_score(features.detach().squeeze(), cluster_labels)
slh_index2 = silhouette_score(features.detach().squeeze(), cluster_labels)

print(f"Davies-Bouldin Index Features: {db_index2}")
print(f"Calinski Harabasz Index Features: {ch_index2}")
print(f"Silhouette Index Features: {slh_index2}")

In [None]:
visualize_tsne(images, labeli, flipped_afib_dict, attn_model)

## Linear Evaluation

In [None]:
# Example: Define a frozen backbone and a linear classifier
class LinearEvaluation(nn.Module):
    def __init__(self, backbone, num_classes):
        super(LinearEvaluation, self).__init__()
        self.backbone = backbone
        self.backbone.requires_grad_(False)
        self.classifier = nn.Linear(32, num_classes)  # Add linear classifier
    
    def forward(self, x):
        with torch.no_grad():  # Ensure backbone is not updated
            features = self.backbone(x)  # Extract features using frozen backbone
            
        return self.classifier(features)  # Feed features to linear classifier

In [None]:
def is_backbone_frozen(model):
    frozen = True
    for param in model.parameters():
        if param.requires_grad:
            frozen = False
            break
    return frozen

In [None]:
def linear_evaluation(train_loader, valid_loader, algorithms, seeds, num_epochs):
    diction_algs = {}
        
    
    for alg in algorithms:
        seed_acc = []
        seed_f1 = []
        seed_prec = []
        seed_recall = []
        
        for seed in seeds:
            set_seed(seed)
            frozen_backbone = FeatureProjector(input_size=252, output_size=32)
            frozen_backbone.load_state_dict(torch.load(f'models/my_models/ecg{seed}_{alg}_model_epoch_500.pth',
                                            map_location=torch.device('cpu')))
            
            
            num_activities = 4
            mine_model = LinearEvaluation(frozen_backbone, num_classes=num_activities)
            # Define loss function and optimizer
            criterion = nn.CrossEntropyLoss()
            optimizer = optim.Adam(mine_model.parameters(), lr=0.001)  # Example optimizer

            # Move model to device
            device = torch.device("cpu" if torch.backends.mps.is_available() and torch.backends.mps.is_built() else "cpu")
            mine_model.to(device)

            # Training and validation loop
            num_epochs = num_epochs
            for epoch in tqdm(range(num_epochs)):
                # Training phase
                mine_model.train()  # Set the model to training mode
                train_running_loss = 0.0
                train_correct_predictions = 0
                train_total_samples = 0

                all_preds = []
                all_labels = []

                for time_series, labels in (train_loader):
                    time_series = time_series.to(device)
                    labels = labels.to(device)

                    # Forward pass
                    features = mine_model(time_series)
                    # Flatten y_hat to have dimensions [batch_size * sequence_length, num_classes]
                    y_hat_flat = features.reshape(-1, features.size(-1))

                    # Reshape y to have dimensions [batch_size * sequence_length]
                    labels_flat = labels.view(-1)

                    # Compute training loss
                    train_loss = criterion(y_hat_flat, labels_flat)

                    # Backward pass and optimization
                    optimizer.zero_grad()
                    train_loss.backward()
                    optimizer.step()

                    # Update training statistics
                    train_running_loss += train_loss.item() * time_series.size(0)

                    _, predicted = torch.max(y_hat_flat, 1)
                    train_correct_predictions += (predicted == labels_flat).sum().item()

                    #Store the labels for future computation of F1-score
                    all_preds.extend(predicted.cpu().numpy())
                    all_labels.extend(labels_flat.cpu().numpy())

                    train_total_samples += labels_flat.size(0)

                # Calculate average training loss and accuracy for the epoch
                train_epoch_loss = train_running_loss / len(train_loader.dataset)
                train_epoch_accuracy = 100*train_correct_predictions / train_total_samples

                f1 = f1_score(all_labels, all_preds,average='weighted')

#                 print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_epoch_loss:.4f},\
#                       Train Accuracy: {train_epoch_accuracy:.2f}%, F1-score: {f1:.4f}")
                
            # Validation phase
            mine_model.eval()  # Set the model to evaluation mode
            val_running_loss = 0.0
            val_correct_predictions = 0
            val_total_samples = 0

            with torch.no_grad():
                val_preds = []
                val_labels = []
                for time_series, labels in (valid_loader):
                    time_series = time_series.to(device)
                    labels = labels.to(device)

                    # Forward pass
                    features = mine_model(time_series)

                    # Flatten y_hat to have dimensions [batch_size * sequence_length, num_classes]
                    y_hat_flat = features.reshape(-1, features.size(-1))

                    # Reshape y to have dimensions [batch_size * sequence_length]
                    labels_flat = labels.view(-1)

                    # Compute validation loss
                    val_loss = criterion(y_hat_flat, labels_flat)

                    # Update validation statistics
                    val_running_loss += val_loss.item() * time_series.size(0)

                    _, predicted = torch.max(y_hat_flat, 1)
                    val_correct_predictions += (predicted == labels_flat).sum().item()
                    val_total_samples += labels_flat.size(0)

                    val_preds.extend(predicted.cpu().numpy())
                    val_labels.extend(labels_flat.cpu().numpy())

            # Calculate average validation loss and accuracy for the epoch
            val_epoch_loss = val_running_loss / len(valid_loader.dataset)
            val_epoch_accuracy = 100*val_correct_predictions / val_total_samples

            # Precision and recall using sklearn
            precision = precision_score(val_labels, val_preds, average='macro')
            recall = recall_score(val_labels, val_preds, average='macro')

            f1 = f1_score(val_labels, val_preds, average='weighted')
#             print(f"Epoch {epoch + 1}/{num_epochs}, Val Loss: {val_epoch_loss:.4f},\
#                   Val Accuracy: {val_epoch_accuracy:.2f}%, F1-score: {f1:.2f},\
#                   Precision: {precision:.2f}, Recall: {recall:.2f}")
            
            
            seed_acc.append((round(val_epoch_accuracy,2)))
            seed_f1.append((round(f1,2)))
            seed_prec.append((round(precision,2)))
            seed_recall.append((round(recall,2)))
            
            
        diction_algs[f'{alg}'] = [(round(np.mean(seed_acc),2), round(np.std(seed_acc),2)),
                                   (round(np.mean(seed_f1),2), round(np.std(seed_f1),2)),
                                   (round(np.mean(seed_prec),2), round(np.std(seed_prec),2)),
                                   (round(np.mean(seed_recall),2), round(np.std(seed_recall),2))]
        
    return diction_algs

In [None]:
# algorithms = ['vanilla', 'marginCL', 'tnc', 'cpc', 'ts2vec', 'infoTS']
# seeds = [42, 53, 64, 75]
# num_epoch = 50

In [None]:
algorithms = ['cost']
seeds = [42, 53, 64, 75]
num_epoch = 50

In [None]:
metric_data = linear_evaluation(train_loader, valid_loader, algorithms, seeds, num_epoch)

In [None]:
metric_data

In [None]:
# Function to apply t-SNE and visualize the results
def visualize_tsne123(images, colors, labels, class_names, model):

    # Evaluate the model and get features
    model.eval()
    with torch.no_grad():
        model_features = model(images)

    # Standardize the data before applying t-SNE
    scaler = StandardScaler()
    tsne = TSNE(n_components=2, init='random', learning_rate='auto')

    # Standardize model features before applying t-SNE
    standardized_model_features = scaler.fit_transform(model_features.view(-1, model_features.size(-1)).cpu().numpy())

    # Apply t-SNE to model features
    reduced_features_model = tsne.fit_transform(standardized_model_features)

    # Plot the results
    plt.figure(figsize=(6, 6))
    for i, val in enumerate([0, 1, 2]):
        indices = labels == val
        plt.scatter(reduced_features_model[indices, 0], reduced_features_model[indices, 1], color =colors[i], label=class_names[val])

#     plt.title('t-SNE Visualization of Image Data (ResNet features)')
#     plt.xlabel('t-SNE Component 1')
#     plt.ylabel('t-SNE Component 2')
    plt.legend()
    
    plt.legend(prop={'size': 14, 'family': 'Tahoma'})
    plt.show()

In [None]:
model = FeatureProjector(input_size=252, output_size=32)

In [None]:
model.load_state_dict(torch.load(f'models/my_models/ecg53_cost_model_epoch_500.pth',
                                            map_location=torch.device('cpu')))

In [None]:
algorithms = ['vanilla', 'tnc', 'cpc', 'ts2vec', 'infoTS', 'triplet', 'doubleCL', 'marginCL', 'recons', 'monoselfPAB']

In [None]:
flipped_afib_dict = {0: "AFIB", 1: "AFL", 2: "J", 3: "N"}

In [None]:
# Assuming your dataset has a 'classes' attribute containing class names
# class_names = dataset.clases

# Assuming you have a DataLoader named 'train_loader'
for batch in train_balanced_dataloader:
    images, labeli = batch
    images = images.view(-1, 1, images.shape[-1])

In [None]:
# Define your class names and corresponding colors for 5 classes
class_names = ['Class 0', 'Class 1', 'Class 2', 'Class 3']
colors = ['#1F64A1', '#A9A9A9', '#8B4513', '#2E8B57']  # Blue, Light Gray, Brown, Green

# Assume `images`, `labels`, and `model` are already defined
# `labeli` corresponds to the labels for the images, `class_dict` is a dictionary mapping labels to class names
visualize_tsne123(images, colors, labeli, flipped_afib_dict, model)