### Import

In [None]:
import os
from shutil import copy

In [None]:
# os.chdir("/kaggle/input/ecg2mit-bih/datasets") # Switch to the Input directory (read-only!)
# !ls # List the directory contents

In [None]:
# os.chdir("/kaggle/working/") # Switch to the Output directory
# !ls

In [None]:
os.chdir("/kaggle/input/tfc/pytorch/checkpoint1/1/")
source_file = "ckp_last.pt"
target_dir = "/kaggle/working/experiments_logs/ECG_2_MIT-BIH/run1/pre_train_seed_42_2layertransformer/saved_models/"
target_file = os.path.join(target_dir, "ckp_last.pt")
os.makedirs(target_dir, exist_ok=True)
copy(source_file, target_file)
os.chdir("/kaggle/working/") # Перейдем в Output
!ls

In [None]:
import torch
from torch.utils.data import Dataset
import torch.fft as fft
import torch.nn.functional as F
from torch import nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer

from sklearn.metrics import roc_auc_score, classification_report, confusion_matrix, \
    average_precision_score, accuracy_score, precision_score,f1_score,recall_score
from sklearn.metrics import cohen_kappa_score,accuracy_score
from sklearn.neighbors import KNeighborsClassifier
import numpy as np
from datetime import datetime
import argparse
import sys
from tqdm.notebook import tqdm
import random
import pandas as pd
import logging

from transformers import BertModel, BertTokenizer

import warnings
# NOTE(review): "ignore" suppresses *every* Python warning for the rest of the
# session, including deprecation warnings from torch / transformers.
warnings.filterwarnings("ignore")


In [None]:
# Pick the compute device once: CUDA when a GPU is visible, otherwise CPU.
with_gpu = torch.cuda.is_available()
device = torch.device("cuda" if with_gpu else "cpu")

print(f"We are using {device} now.")

### config_files.py

In [None]:
class Configs:
    """Hyper-parameters for the ECG -> MIT-BIH TF-C experiment.

    Groups model, optimizer and data settings together with nested option
    objects for the contextual-contrasting loss, the TC module and the
    time-domain augmentations.
    """

    def __init__(self):
        # ---- model ----------------------------------------------------------
        self.input_channels = 1
        self.kernel_size = 8
        self.stride = 1
        self.final_out_channels = 128
        self.num_classes = 4
        self.dropout = 0.35
        self.features_len = 18

        # ---- training / optimizer ------------------------------------------
        self.num_epoch = 100
        self.beta1, self.beta2 = 0.9, 0.99
        self.lr = 3e-7

        # ---- data -----------------------------------------------------------
        self.drop_last = True
        self.batch_size = 128

        # ---- nested option groups -------------------------------------------
        self.Context_Cont = Context_Cont_configs()
        self.TC = TC()
        self.augmentation = augmentations()

        # ---- TF-C specific --------------------------------------------------
        self.TSlength_aligned = 1500  # series are cropped to this many time steps
        self.lr_f = self.lr
        self.target_batch_size = 41
        self.increased_dim = 1
        self.num_classes_target = 5
        self.features_len_f = self.features_len
        self.CNNoutput_channel = 190


class augmentations:
    """Strengths of the time-domain augmentations."""

    def __init__(self):
        self.jitter_ratio = 0.8        # std of the additive Gaussian noise (jitter)
        self.jitter_scale_ratio = 1.1  # std of the multiplicative factor (scaling)
        self.max_seg = 8               # exclusive upper bound on segments (permutation)


class Context_Cont_configs:
    """Settings for the contextual-contrasting loss.

    NOTE(review): presumably forwarded to NTXentLoss(temperature,
    use_cosine_similarity); the call site is not visible here.
    """

    def __init__(self):
        self.use_cosine_similarity = True
        self.temperature = 0.2


class TC:
    """Settings for the TC module (consumer not visible in this notebook)."""

    def __init__(self):
        self.timesteps = 6
        self.hidden_dim = 100

### augmentations.py

In [None]:
def one_hot_encoding(X):
    """Return a float one-hot matrix for the integer-castable labels in ``X``.

    The width is ``max(X) + 1``, so classes above the largest label present
    in ``X`` get no column.
    """
    indices = [int(value) for value in X]
    identity = np.eye(max(indices) + 1)
    return identity[indices]


def DataTransform(sample, config):
    """Build a (weak, strong) pair of augmented views of ``sample``.

    weak:   random multiplicative scaling (``scaling``).
    strong: random segment permutation followed by additive Gaussian jitter.
    """
    aug_cfg = config.augmentation
    weak_aug = scaling(sample, aug_cfg.jitter_scale_ratio)
    permuted = permutation(sample, max_segments=aug_cfg.max_seg)
    strong_aug = jitter(permuted, aug_cfg.jitter_ratio)
    return weak_aug, strong_aug

# def DataTransform_TD(sample, config):
#     """Weak and strong augmentations"""
#     weak_aug = sample
#     strong_aug = jitter(permutation(sample, max_segments=config.augmentation.max_seg),
# config.augmentation.jitter_ratio) #masking(sample)
#     return weak_aug, strong_aug
#
# def DataTransform_FD(sample, config):
#     """Weak and strong augmentations in Frequency domain """
#     # weak_aug =  remove_frequency(sample, 0.1)
#     strong_aug = add_frequency(sample, 0.1)
#     return weak_aug, strong_aug


def DataTransform_TD(sample, config):
    """Time-domain augmentation used for pre-training: additive Gaussian jitter.

    More augmentations could be added here, but in the TF-C framework the
    choice of augmentation has been observed to have little impact on the
    final transfer performance.
    """
    sigma = config.augmentation.jitter_ratio
    return jitter(sample, sigma)


def DataTransform_TD_bank(sample, config):
    """Augmentation bank: each sample receives exactly one of four augmentations.

    Candidates are jitter, scaling, segment permutation and binomial masking;
    for every sample one of them is chosen uniformly at random. Can be used
    in place of ``DataTransform_TD``.

    Args:
        sample: batch indexed as (batch, channels, length). ``masking``
            requires a torch tensor.
        config: object exposing the ``config.augmentation`` settings.
    Returns:
        The augmented batch, same shape as ``sample``.
    """
    n_augs = 4
    aug_1 = jitter(sample, config.augmentation.jitter_ratio)
    aug_2 = scaling(sample, config.augmentation.jitter_scale_ratio)
    aug_3 = permutation(sample, max_segments=config.augmentation.max_seg)
    # masking() may assign zeros into the tensor it receives; give it a copy
    # so the caller's batch is not corrupted.
    to_mask = sample.clone() if torch.is_tensor(sample) else np.array(sample, copy=True)
    aug_4 = masking(to_mask, keepratio=0.9)

    li = np.random.randint(0, n_augs, size=[sample.shape[0]])
    # Fixed width: one_hot_encoding() sizes its output by max(li) + 1, which is
    # narrower than n_augs whenever the last augmentation is never drawn.
    li_onehot = np.eye(n_augs)[li]
    # Keep, per row, only the augmentation selected for that row (column k
    # gates augmentation k; the rows that are not selected are set to zero).
    aug_1 = aug_1 * li_onehot[:, 0][:, None, None]
    aug_2 = aug_2 * li_onehot[:, 1][:, None, None]
    aug_3 = aug_3 * li_onehot[:, 2][:, None, None]
    aug_4 = aug_4 * li_onehot[:, 3][:, None, None]
    aug_T = aug_1 + aug_2 + aug_3 + aug_4
    return aug_T


def DataTransform_FD(sample, config):
    """Frequency-domain augmentation for pre-training.

    Returns the element-wise sum of two perturbed copies of ``sample``: one
    with roughly 10% of its entries zeroed (``remove_frequency``) and one with
    small noise added to roughly 10% of its entries (``add_frequency``).
    ``config`` is accepted for API symmetry and is not used.

    NOTE(review): because the two copies are summed, entries that are neither
    removed nor perturbed come out at exactly twice their input value;
    confirm this is intended.
    """
    ratio = 0.1
    removed = remove_frequency(sample, pertub_ratio=ratio)
    added = add_frequency(sample, pertub_ratio=ratio)
    return removed + added


def remove_frequency(x, pertub_ratio=0.0):
    """Randomly zero out about ``pertub_ratio`` of the entries of ``x``.

    Args:
        x: torch tensor of any shape.
        pertub_ratio: probability that an entry is set to zero.
    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    # Draw the mask directly on x's device. The previous
    # torch.cuda.FloatTensor constructor is deprecated and fails outright on
    # CPU-only machines.
    keep = torch.rand(x.shape, device=x.device) > pertub_ratio
    return x * keep


def add_frequency(x, pertub_ratio=0.0):
    """Add random noise to about ``pertub_ratio`` of the entries of ``x``.

    Each selected entry is offset by ``U[0, 1) * 0.1 * x.max()``.

    Args:
        x: torch tensor of any shape.
        pertub_ratio: probability that an entry is perturbed.
    Returns:
        ``x`` plus the sparse perturbation (same shape as ``x``).
    """
    # Create both the mask and the noise on x's device. The old code used the
    # CUDA-only torch.cuda.FloatTensor constructor (it crashes without a GPU)
    # and always drew the noise on the CPU regardless of x's device.
    mask = torch.rand(x.shape, device=x.device) > (1 - pertub_ratio)
    max_amplitude = x.max()
    random_am = torch.rand(mask.shape, device=x.device) * (max_amplitude * 0.1)
    pertub_matrix = mask * random_am
    return x + pertub_matrix


def generate_binomial_mask(B, T, D, p=0.5):
    """Return a (B, T, D) boolean tensor whose entries are True with probability ``p``."""
    draws = np.random.binomial(1, p, size=(B, T, D))
    return torch.from_numpy(draws).bool()


def masking(x, keepratio=0.9, mask='binomial'):
    """Randomly zero out elements of ``x`` (binomial masking).

    Any vector along the last axis that contains a NaN is zeroed first; then
    each element is kept with probability ``keepratio``.

    Args:
        x: torch tensor indexed as (B, T, D) by the mask generator.
        keepratio: probability that an element is kept.
        mask: masking strategy; only 'binomial' is implemented.
    Returns:
        A masked copy of ``x``; the input tensor is left untouched.
    Raises:
        ValueError: if ``mask`` is not 'binomial'.
    """
    # Previously an unknown mode silently reused a stale module-level
    # ``mask_id`` from an earlier call (or raised NameError on first use).
    if mask != 'binomial':
        raise ValueError(f"Unsupported mask type: {mask!r}")

    x = x.clone()  # do not write into the caller's tensor
    nan_mask = ~x.isnan().any(axis=-1)
    x[~nan_mask] = 0

    mask_id = generate_binomial_mask(x.size(0), x.size(1), x.size(2), p=keepratio).to(x.device)
    x[~mask_id] = 0
    return x


def jitter(x, sigma=0.8):
    """Jittering augmentation (https://arxiv.org/pdf/1706.00527.pdf).

    Returns ``x`` plus i.i.d. Gaussian noise with mean 0 and standard
    deviation ``sigma``, drawn from NumPy's global RNG with ``x``'s shape.
    """
    noise = np.random.normal(loc=0., scale=sigma, size=x.shape)
    return x + noise


def scaling(x, sigma=1.1):
    """Scaling augmentation (https://arxiv.org/pdf/1706.00527.pdf).

    A factor ~ N(2, sigma^2) is drawn for every (sample, time step) pair and
    shared by all channels of that sample.

    Args:
        x: array-like indexed as (batch, channels, length).
        sigma: standard deviation of the scaling factor.
    Returns:
        NumPy array of the same shape as ``x``.
    """
    factor = np.random.normal(loc=2., scale=sigma, size=(x.shape[0], x.shape[2]))
    # One broadcast multiply over the channel axis replaces the per-channel
    # loop + np.concatenate (which also raised on inputs with zero channels).
    return np.asarray(x) * factor[:, np.newaxis, :]


def permutation(x, max_segments=5, seg_mode="random"):
    """Split each sample's time axis into segments and shuffle the segments.

    The number of segments per sample is drawn from
    ``np.random.randint(1, max_segments)`` (upper bound exclusive); samples
    that draw 1 are copied unchanged. With ``seg_mode="random"`` the split
    points are random, otherwise the axis is cut into near-equal pieces.
    All channels of a sample are reordered with the same time permutation.

    Args:
        x: array-like indexed as (batch, channels, length).
        max_segments: exclusive upper bound on the number of segments.
        seg_mode: "random" for random split points, anything else for equal splits.
    Returns:
        torch.Tensor with the same shape as ``x``.
    """
    orig_steps = np.arange(x.shape[2])
    num_segs = np.random.randint(1, max_segments, size=(x.shape[0]))

    ret = np.zeros_like(x)
    for i, pat in enumerate(x):
        if num_segs[i] > 1:
            if seg_mode == "random":
                split_points = np.random.choice(x.shape[2] - 2, num_segs[i] - 1, replace=False)
                split_points.sort()
                splits = np.split(orig_steps, split_points)
            else:
                splits = np.array_split(orig_steps, num_segs[i])
            # Shuffle segment *indices*: np.random.permutation(splits) builds an
            # array from segments of unequal length, which NumPy >= 1.24 rejects
            # ("inhomogeneous shape").
            order = np.random.permutation(len(splits))
            warp = np.concatenate([splits[k] for k in order])
            # Reorder every channel (indexing channel 0 only copied channel 0
            # into all channels of the output).
            ret[i] = pat[:, warp]
        else:
            ret[i] = pat
    return torch.from_numpy(ret)


### dataloader.py

In [None]:
def generate_freq(dataset, config):
    """Shuffle a dataset, keep at most 10000 samples and add FFT magnitudes.

    Args:
        dataset: mapping with "samples" and "labels"; their elements are
            combined with ``torch.stack``.
        config: object exposing ``TSlength_aligned``.
    Returns:
        (X_train, y_train, x_data_f): samples cropped to
        (n, 1, TSlength_aligned), their labels, and ``|fft(X_train)|``
        along the last axis.
    """
    pairs = list(zip(dataset["samples"], dataset['labels']))
    np.random.shuffle(pairs)
    pairs = pairs[:10000]  # cap the number of samples
    samples, labels = zip(*pairs)
    X_train = torch.stack(list(samples), dim=0)
    y_train = torch.stack(list(labels), dim=0)

    # Guarantee a channel axis, then move the smallest dimension (assumed to
    # be the channel axis) into position 1.
    if X_train.dim() < 3:
        X_train = X_train.unsqueeze(2)
    if X_train.shape.index(min(X_train.shape)) != 1:
        X_train = X_train.permute(0, 2, 1)

    # Keep the first channel and the first TSlength_aligned time steps.
    X_train = X_train[:, :1, :int(config.TSlength_aligned)]

    # torch.stack always yields a tensor, so no NumPy conversion is needed.
    # fft.fft keeps the window length (rfft would halve it).
    x_data_f = fft.fft(X_train).abs()
    return (X_train, y_train, x_data_f)


class Load_Dataset(Dataset):
    """Torch dataset over a {"samples", "labels"} dict with time and frequency views.

    Samples are shuffled, reduced to the first channel, cropped to
    ``config.TSlength_aligned`` time steps and paired with ``|fft(x)|``.
    Items are 5-tuples ``(x, y, x_aug, x_f, x_f_aug)``. The augmented views
    are real augmentations only in "pre_train" mode; in every other mode the
    un-augmented tensors are repeated in their place.
    """

    def __init__(self, dataset, config, training_mode, target_dataset_size=64, subset=False):
        super().__init__()
        self.training_mode = training_mode

        pairs = list(zip(dataset["samples"], dataset["labels"]))
        np.random.shuffle(pairs)
        samples, labels = zip(*pairs)
        X_train = torch.stack(list(samples), dim=0)
        y_train = torch.stack(list(labels), dim=0)

        # Guarantee a channel axis, then move the smallest dimension (assumed
        # to be the channel axis) into position 1.
        if X_train.dim() < 3:
            X_train = X_train.unsqueeze(2)
        if X_train.shape.index(min(X_train.shape)) != 1:
            X_train = X_train.permute(0, 2, 1)

        # Keep the first channel and the first TSlength_aligned time steps.
        X_train = X_train[:, :1, :int(config.TSlength_aligned)]

        if subset:
            # Debug subset: ten batches' worth of samples.
            subset_size = target_dataset_size * 10
            X_train = X_train[:subset_size]
            y_train = y_train[:subset_size]
            print('Using subset for debugging, the datasize is:', y_train.shape[0])

        # torch.stack above always produces tensors; no NumPy conversion needed.
        self.x_data = X_train
        self.y_data = y_train

        # fft.fft keeps the window length (rfft would halve it).
        self.x_data_f = fft.fft(self.x_data).abs()
        self.len = X_train.shape[0]

        # Augmentations are only materialized for pre-training.
        if training_mode == "pre_train":
            self.aug1 = DataTransform_TD(self.x_data, config)
            self.aug1_f = DataTransform_FD(self.x_data_f, config)

    def __getitem__(self, index):
        x_time = self.x_data[index]
        x_freq = self.x_data_f[index]
        label = self.y_data[index]
        if self.training_mode != "pre_train":
            return x_time, label, x_time, x_freq, x_freq
        return x_time, label, self.aug1[index], x_freq, self.aug1_f[index]

    def __len__(self):
        return self.len


def data_generator(sourcedata_path, targetdata_path, configs, training_mode, subset=True):
    """Load the source/target ``.pt`` splits and wrap them in DataLoaders.

    Reads ``train.pt`` from ``sourcedata_path`` (pre-training data) and
    ``train.pt`` / ``test.pt`` from ``targetdata_path`` (fine-tuning / test).
    ``subset`` applies to the first two splits only; the test split is
    always used in full and never drops its last partial batch.

    Returns:
        (train_loader, finetune_loader, test_loader)
    """
    # NOTE(review): torch.load is pickle-based; only load trusted files.
    train_raw = torch.load(os.path.join(sourcedata_path, "train.pt"))
    finetune_raw = torch.load(os.path.join(targetdata_path, "train.pt"))
    test_raw = torch.load(os.path.join(targetdata_path, "test.pt"))

    # In pre_train mode the augmentations are computed inside Load_Dataset.
    train_set = Load_Dataset(train_raw, configs, training_mode,
                             target_dataset_size=configs.batch_size,
                             subset=subset)
    finetune_set = Load_Dataset(finetune_raw, configs, training_mode,
                                target_dataset_size=configs.target_batch_size,
                                subset=subset)
    test_set = Load_Dataset(test_raw, configs, training_mode,
                            target_dataset_size=configs.target_batch_size,
                            subset=False)

    make_loader = torch.utils.data.DataLoader
    train_loader = make_loader(dataset=train_set, batch_size=configs.batch_size,
                               shuffle=True, drop_last=configs.drop_last,
                               num_workers=0)
    finetune_loader = make_loader(dataset=finetune_set, batch_size=configs.target_batch_size,
                                  shuffle=True, drop_last=configs.drop_last,
                                  num_workers=0)
    test_loader = make_loader(dataset=test_set, batch_size=configs.target_batch_size,
                              shuffle=True, drop_last=False,
                              num_workers=0)

    return train_loader, finetune_loader, test_loader


### loss.py

In [None]:
class NTXentLoss(torch.nn.Module):
    """NT-Xent (normalized temperature-scaled cross-entropy) contrastive loss.

    ``zis`` and ``zjs`` are two views of the same ``batch_size`` samples.
    Row k of one view is the positive for row k of the other; every other row
    of both views is a negative. Similarity is cosine when
    ``use_cosine_similarity`` is true, a dot product otherwise. Returns the
    summed cross-entropy divided by ``2 * batch_size``.
    """

    def __init__(self, device, batch_size, temperature, use_cosine_similarity):
        super().__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.device = device
        self.softmax = torch.nn.Softmax(dim=-1)  # not used by forward(); kept as public attribute
        self.mask_samples_from_same_repr = self._get_correlated_mask().type(torch.bool)
        self.similarity_function = self._get_similarity_function(use_cosine_similarity)
        self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")

    def _get_similarity_function(self, use_cosine_similarity):
        """Return the pairwise similarity callable selected by the flag."""
        if not use_cosine_similarity:
            return self._dot_simililarity
        self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
        return self._cosine_simililarity

    def _get_correlated_mask(self):
        """(2N, 2N) boolean mask that is True only at negative pairs.

        The main diagonal (self pairs) and the diagonals at offsets +N and -N
        (the cross-view positive pairs) are False.
        """
        idx = torch.arange(2 * self.batch_size)
        offset = (idx[:, None] - idx[None, :]).abs()
        negatives = (offset != 0) & (offset != self.batch_size)
        return negatives.to(self.device)

    @staticmethod
    def _dot_simililarity(x, y):
        # (2N, C) x (C, 2N) -> (2N, 2N) dot products
        return torch.matmul(x, y.T)

    def _cosine_simililarity(self, x, y):
        # (2N, 1, C) vs (1, 2N, C) broadcast -> (2N, 2N) cosine similarities
        return self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))

    def forward(self, zis, zjs):
        n = self.batch_size
        representations = torch.cat([zjs, zis], dim=0)
        similarity_matrix = self.similarity_function(representations, representations)

        # Positive scores sit on the diagonals at offsets +N and -N.
        upper = torch.diag(similarity_matrix, n)
        lower = torch.diag(similarity_matrix, -n)
        positives = torch.cat([upper, lower]).view(2 * n, 1)
        negatives = similarity_matrix[self.mask_samples_from_same_repr].view(2 * n, -1)

        # Column 0 holds the positive, so every target index is 0.
        logits = torch.cat((positives, negatives), dim=1) / self.temperature
        labels = torch.zeros(2 * n, dtype=torch.long, device=self.device)
        return self.criterion(logits, labels) / (2 * n)


class NTXentLoss_poly(torch.nn.Module):
    """NT-Xent contrastive loss plus a Poly-1 style correction term.

    Pairing is the same as in plain NT-Xent: row k of ``zjs`` and row k of
    ``zis`` are positives, all other rows are negatives. The returned value is
    ``CE / (2N) + N * (1/N - pt)``, where ``CE`` is the summed cross-entropy
    and ``pt`` is the mean of (one-hot positive mask * softmax(logits)) over
    the whole (2N, 2N-1) logit matrix.
    """

    def __init__(self, device, batch_size, temperature, use_cosine_similarity):
        super().__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.device = device
        self.softmax = torch.nn.Softmax(dim=-1)  # not used by forward(); kept as public attribute
        self.mask_samples_from_same_repr = self._get_correlated_mask().type(torch.bool)
        self.similarity_function = self._get_similarity_function(use_cosine_similarity)
        self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")

    def _get_similarity_function(self, use_cosine_similarity):
        """Return the pairwise similarity callable selected by the flag."""
        if not use_cosine_similarity:
            return self._dot_simililarity
        self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
        return self._cosine_simililarity

    def _get_correlated_mask(self):
        """(2N, 2N) boolean mask that is True only at negative pairs."""
        idx = torch.arange(2 * self.batch_size)
        offset = (idx[:, None] - idx[None, :]).abs()
        negatives = (offset != 0) & (offset != self.batch_size)
        return negatives.to(self.device)

    @staticmethod
    def _dot_simililarity(x, y):
        # (2N, C) x (C, 2N) -> (2N, 2N) dot products
        return torch.matmul(x, y.T)

    def _cosine_simililarity(self, x, y):
        # (2N, 1, C) vs (1, 2N, C) broadcast -> (2N, 2N) cosine similarities
        return self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))

    def forward(self, zis, zjs):
        n = self.batch_size
        representations = torch.cat([zjs, zis], dim=0)
        similarity_matrix = self.similarity_function(representations, representations)

        upper = torch.diag(similarity_matrix, n)
        lower = torch.diag(similarity_matrix, -n)
        positives = torch.cat([upper, lower]).view(2 * n, 1)
        negatives = similarity_matrix[self.mask_samples_from_same_repr].view(2 * n, -1)

        # Column 0 holds the positive, so every target index is 0.
        logits = torch.cat((positives, negatives), dim=1) / self.temperature
        labels = torch.zeros(2 * n, dtype=torch.long, device=self.device)
        ce = self.criterion(logits, labels)

        # One-hot marker of the positive column, same layout as ``logits``.
        onehot_label = torch.zeros(2 * n, negatives.shape[-1] + 1, dtype=torch.long, device=self.device)
        onehot_label[:, 0] = 1
        pt = torch.mean(onehot_label * torch.nn.functional.softmax(logits, dim=-1))

        # The usual Poly-1 "1 - pt" uses 1/batch_size instead of 1 here.
        epsilon = self.batch_size
        return ce / (2 * n) + epsilon * (1 / self.batch_size - pt)


class hierarchical_contrastive_loss(torch.nn.Module):
    """Hierarchical contrastive loss between two views of per-timestep representations.

    At each level the instance-wise and temporal contrastive losses are added
    (weighted by ``alpha`` and ``1 - alpha``). Then both views are max-pooled
    along the time axis with kernel size 2. This repeats until a single time
    step remains. The result is the accumulated loss divided by the number of
    levels.

    NOTE(review): ``z1``/``z2`` are indexed as B x T x C (see the shape
    comments below); confirm callers pass (batch, time, channels).
    """

    def __init__(self, device):
        super(hierarchical_contrastive_loss, self).__init__()
        # Device on which the loss accumulator is created in forward().
        self.device = device

    def instance_contrastive_loss(self, z1, z2):
        """Contrast samples against each other at every time step.

        At time step t, sample i of ``z1`` and sample i of ``z2`` are positives;
        every other sample of either view at the same t is a negative.
        Returns a scalar tensor (0 when the batch holds a single sample).
        """
        B, T = z1.size(0), z1.size(1)
        if B == 1:
            return z1.new_tensor(0.)
        z = torch.cat([z1, z2], dim=0)  # 2B x T x C
        z = z.transpose(0, 1)  # T x 2B x C
        sim = torch.matmul(z, z.transpose(1, 2))  # T x 2B x 2B
        # Remove the self-similarity diagonal: entries left of the diagonal keep
        # their column index, entries right of it shift left by one.
        logits = torch.tril(sim, diagonal=-1)[:, :, :-1]  # T x 2B x (2B-1)
        logits += torch.triu(sim, diagonal=1)[:, :, 1:]
        logits = -F.log_softmax(logits, dim=-1)

        i = torch.arange(B)
        # Row i (view 1) has its positive at original column B+i -> index B+i-1
        # after the diagonal removal; row B+i (view 2) has it at column i.
        loss = (logits[:, i, B + i - 1].mean() + logits[:, B + i, i].mean()) / 2
        return loss

    def temporal_contrastive_loss(self, z1, z2):
        """Contrast time steps against each other within each sample.

        For each sample, time step t of ``z1`` and time step t of ``z2`` are
        positives; every other time step of either view is a negative.
        Returns a scalar tensor (0 when there is a single time step).
        """
        B, T = z1.size(0), z1.size(1)
        if T == 1:
            return z1.new_tensor(0.)
        z = torch.cat([z1, z2], dim=1)  # B x 2T x C
        sim = torch.matmul(z, z.transpose(1, 2))  # B x 2T x 2T
        # Same diagonal-removal trick as in instance_contrastive_loss.
        logits = torch.tril(sim, diagonal=-1)[:, :, :-1]  # B x 2T x (2T-1)
        logits += torch.triu(sim, diagonal=1)[:, :, 1:]
        logits = -F.log_softmax(logits, dim=-1)

        t = torch.arange(T)
        loss = (logits[:, t, T + t - 1].mean() + logits[:, T + t, t].mean()) / 2
        return loss

    def forward(self, z1, z2, alpha=0.5, temporal_unit=0):
        """Average the hierarchical loss over all pooling levels.

        Args:
            z1, z2: the two views (indexed B x T x C).
            alpha: weight of the instance loss; ``1 - alpha`` weights the temporal loss.
            temporal_unit: the temporal loss is only added at levels ``d >= temporal_unit``.
        Returns:
            Scalar tensor: accumulated loss divided by the number of levels ``d``.

        NOTE(review): if ``z1.size(1) == 0`` no level runs and ``d`` stays 0,
        so the final division is by zero.
        """
        loss = torch.tensor(0., device=self.device)  # , device=z1.device
        d = 0
        while z1.size(1) > 1:
            if alpha != 0:
                loss += alpha * self.instance_contrastive_loss(z1, z2)
            if d >= temporal_unit:
                if 1 - alpha != 0:
                    loss += (1 - alpha) * self.temporal_contrastive_loss(z1, z2)
            d += 1
            # Halve the time resolution (max over pairs of steps) for the next level.
            z1 = F.max_pool1d(z1.transpose(1, 2), kernel_size=2).transpose(1, 2)
            z2 = F.max_pool1d(z2.transpose(1, 2), kernel_size=2).transpose(1, 2)
        # Final single-step level: only the instance loss applies.
        if z1.size(1) == 1:
            if alpha != 0:
                loss += alpha * self.instance_contrastive_loss(z1, z2)
            d += 1
        return loss / d


### model.py

In [None]:
BERT_PRETRAIN_PATH = "./BERT_pretrain/"  # cache_dir for the Hugging Face BERT weights/tokenizer
# Global device used by the text-model classes: CUDA when available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class TFC(nn.Module):
    """Time-Frequency Consistency (TF-C) backbone.

    Two parallel branches (time domain and frequency domain). Each branch is
    a 2-layer Transformer encoder followed by a projector MLP
    (TSlength_aligned -> 256 -> 128).
    """
    def __init__(self, configs):
        super(TFC, self).__init__()

        # Pools the last (time) axis to exactly TSlength_aligned points
        # (identity when the input already has that length).
        self.adaptive_avgpool = nn.AdaptiveAvgPool1d(configs.TSlength_aligned)

        # batch_first=True: inputs are (batch, channels, length), so each sample
        # is its own sequence. With the default batch_first=False, dim 0 (the
        # batch) was treated as the sequence axis; samples attended to each
        # other, and a sample's embedding depended on which other samples
        # shared its batch. The flag adds no parameters, so existing
        # checkpoints still load. NOTE(review): weights pre-trained under the
        # old behaviour may need re-training or fine-tuning.
        encoder_layers_t = TransformerEncoderLayer(configs.TSlength_aligned,
                                                   dim_feedforward=2*configs.TSlength_aligned,
                                                   nhead=2,
                                                   batch_first=True)
        self.transformer_encoder_t = TransformerEncoder(encoder_layers_t, 2)

        self.projector_t = nn.Sequential(
            nn.Linear(configs.TSlength_aligned, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, 128)
        )

        encoder_layers_f = TransformerEncoderLayer(configs.TSlength_aligned,
                                                   dim_feedforward=2*configs.TSlength_aligned,
                                                   nhead=2,
                                                   batch_first=True)
        self.transformer_encoder_f = TransformerEncoder(encoder_layers_f, 2)

        self.projector_f = nn.Sequential(
            nn.Linear(configs.TSlength_aligned, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, 128)
        )

    def forward(self, x_in_t, x_in_f):
        """Encode the time and frequency views.

        Args:
            x_in_t: time-domain input, (batch, channels, length). After pooling
                and flattening, the projector needs channels * TSlength_aligned
                == TSlength_aligned, i.e. a single channel.
            x_in_f: frequency-domain input, same layout.
        Returns:
            (h_time, z_time, h_freq, z_freq): flattened encoder outputs of shape
            (batch, TSlength_aligned) and their 128-d projections.
        """
        # Adaptive average pooling
        x_in_t = self.adaptive_avgpool(x_in_t)
        x_in_f = self.adaptive_avgpool(x_in_f)

        # Time-based contrastive encoder
        x = self.transformer_encoder_t(x_in_t)
        h_time = x.reshape(x.shape[0], -1)

        # Cross-space projector
        z_time = self.projector_t(h_time)

        # Frequency-based contrastive encoder
        f = self.transformer_encoder_f(x_in_f)
        h_freq = f.reshape(f.shape[0], -1)

        # Cross-space projector
        z_freq = self.projector_f(h_freq)

        return h_time, z_time, h_freq, z_freq


# Downstream classifier only used in finetuning
# class target_classifier(nn.Module):
#     def __init__(self, configs):
#         super(target_classifier, self).__init__()
#         self.logits = nn.Linear(2 * 128, 64)
#         self.logits_simple = nn.Linear(64, configs.num_classes_target)

#     def forward(self, emb):
#         emb_flat = emb.reshape(emb.shape[0], -1)
#         emb = torch.sigmoid(self.logits(emb_flat))
#         pred = self.logits_simple(emb)
#         return pred

class FrozenLanguageModel(nn.Module):
    """
    Description:
        Bio_ClinicalBERT encoder with frozen weights, followed by a trainable
        768 -> 256 linear projection. The BERT part is always kept in eval
        mode so its dropout stays disabled, even when the parent module is
        switched to training mode.
    """
    def __init__(self):
        super(FrozenLanguageModel, self).__init__()
        self.language_model = BertModel.from_pretrained(
            'emilyalsentzer/Bio_ClinicalBERT',
            cache_dir=BERT_PRETRAIN_PATH
        ).to(device)
        for param in self.language_model.parameters():
            param.requires_grad = False
        self.language_model.eval()
        self.dimension_reducer = nn.Linear(768, 256).to(device)

    def train(self, mode: bool = True):
        """Set training mode for the trainable parts; the frozen BERT stays in eval mode."""
        super().train(mode)
        # nn.Module.train() recurses into every child; undo that for the frozen
        # encoder, otherwise its dropout makes the "frozen" embeddings random.
        self.language_model.eval()
        return self

    def forward(self, input_ids, attention_mask) -> torch.Tensor:
        """
        Description:
            Forward pass of the frozen language model.
        Args:
            input_ids: token ids for the language model.
            attention_mask: attention mask for the language model.
        Returns:
            The first-token (typically [CLS]) hidden state, projected to 256 dims.
        """
        # No BERT parameter requires grad, so skip building its autograd graph;
        # gradients still reach dimension_reducer below.
        with torch.no_grad():
            outputs = self.language_model(input_ids=input_ids, attention_mask=attention_mask)
        sentence_representation = outputs.last_hidden_state[:, 0, :]
        reduced_representation = self.dimension_reducer(sentence_representation)
        return reduced_representation


class TargetClassifier(nn.Module):
    """Zero-shot-style classifier: scores ECG embeddings against one text prompt per class.

    Each class index is turned into a prompt via ``get_diagnostic_string`` and
    encoded by ``FrozenLanguageModel`` (256-d output). Every sample embedding
    is then compared with every class embedding by cosine similarity, and the
    resulting (batch, num_classes) matrix is returned as logits.
    """
    def __init__(self, configs):
        super(TargetClassifier, self).__init__()
        # NOTE(review): ``logits`` / ``logits_simple`` are not used by forward();
        # they are kept so existing state_dicts continue to load.
        self.logits = nn.Linear(2 * 128, 64)
        self.logits_simple = nn.Linear(64, configs.num_classes_target)
        self.text_encoder = FrozenLanguageModel().to(device)
        self.embedding_dim = self.text_encoder.language_model.config.hidden_size
        self.tokenizer = BertTokenizer.from_pretrained('emilyalsentzer/Bio_ClinicalBERT',
                                                       cache_dir=BERT_PRETRAIN_PATH)

    @staticmethod
    def get_diagnostic_string(label: int):
        """Return the text prompt for class index ``label`` ("Invalid label" if unknown)."""
        # NOTE(review): the trailing comments list MIT-BIH beat classes, while the
        # strings are ECG diagnostic superclasses; confirm which label set the
        # target dataset uses. "Conducion" is a typo, left unchanged because
        # editing the prompt changes the text embeddings.
        class_names = {
            0: "Normal ECG",  # "Normal beat"
            1: "Myocardial Infarction",  # "Supraventricular premature beat"
            2: "ST/T change",  # "Premature ventricular contraction"
            3: "Hypertrophy",  # "Fusion of ventricular and normal beat"
            4: "Conducion Disturbance"  # "Unclassifiable beat"
        }

        if label in class_names:
            diagnostic_type = class_names[label]
            return f"The ECG of {diagnostic_type}, a type of diagnostic"
        else:
            return "Invalid label"

    def zero_shot_process_text(self, labels) -> torch.Tensor:
        """
        Description:
            Encode one prompt per class for zero-shot classification.

        Args:
            labels: accepted for backward compatibility and ignored. The class
                embeddings must not depend on the batch's ground truth; the old
                code encoded the batch labels' prompts and zipped them onto the
                category list, which leaked labels and produced the wrong number
                of rows whenever the batch held fewer than 5 samples.

        Returns:
            torch.Tensor: (num_classes, 256); row k embeds the prompt of class k.
        """
        categories = [
            "Normal ECG",
            "Myocardial Infarction",
            "ST/T change",
            "Hypertrophy",
            "Conducion Disturbance"
        ]

        prompts = [self.get_diagnostic_string(idx) for idx in range(len(categories))]
        tokens = self.tokenizer(prompts, padding=True, truncation=True, return_tensors='pt', max_length=100)

        input_ids = tokens['input_ids'].to(device)
        attention_mask = tokens['attention_mask'].to(device)
        # Re-encoded on every call because the text projection
        # (dimension_reducer) is trainable.
        return self.text_encoder(input_ids, attention_mask)

    def similarity_classify(self, fea_concat: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """
        Description:
            Cosine similarity between each sample embedding and each class embedding.

        Args:
            fea_concat: sample embeddings, (batch, D); D must match the text
                embedding width (256).
            labels: forwarded to ``zero_shot_process_text`` (ignored there).

        Returns:
            torch.Tensor: (batch, num_classes) similarity scores.
        """
        class_text_rep_tensor = self.zero_shot_process_text(labels)

        # (batch, 1, D) vs (1, K, D) broadcast -> (batch, K) in a single call.
        similarities = F.cosine_similarity(fea_concat.unsqueeze(1),
                                           class_text_rep_tensor.unsqueeze(0),
                                           dim=-1)
        return similarities.to(device)

    def forward(self, fea_concat: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Return (batch, num_classes) similarity logits for ``fea_concat``."""
        pred = self.similarity_classify(fea_concat, labels)
        return pred


### trainer.py

In [None]:
def Trainer(model,  model_optimizer, classifier,
            classifier_optimizer, train_dl, valid_dl,
            test_dl, device, logger,
            config, experiment_log_dir, training_mode):
    """
    Description:
        The main training function.
        This function trains the model and the classifier.
        Trainer is divided into three stages:
            1) pretrain
            2) finetune
            3) test

    Args:
        model: The model used for training.
        model_optimizer: The optimizer used for training.
        classifier: The classifier used for training.
        classifier_optimizer: The optimizer used for training.
        train_dl: The training dataloader.
        valid_dl: The validation dataloader.
        test_dl: The test dataloader.
        device: The device used for training.
        logger: The logger used for logging.
        config: The configuration dictionary.
        experiment_log_dir: The directory where the experiment logs will be stored.
        training_mode: The training mode.

    Returns:
        None
    """
    model = model.to(device)
    classifier = classifier.to(device)
    # Start training
    logger.debug("Training started ....")

    criterion = nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(model_optimizer, 'min')
    if training_mode == 'pre_train':
        print('Pretraining on source dataset')
        for epoch in range(1, config.num_epoch + 1):
            # Train and validate
            """Train. In fine-tuning, this part is also trained???"""
            train_loss = model_pretrain(model, model_optimizer, criterion, train_dl, config, device, training_mode)
            logger.debug(f"\nPre-training Epoch : {epoch}/{config.num_epoch}, Train Loss : {train_loss.item():.4f}")

        # Save pretrained model
        os.makedirs(os.path.join(experiment_log_dir, "saved_models"), exist_ok=True)
        chkpoint = {'model_state_dict': model.state_dict()}
        torch.save(chkpoint, os.path.join(experiment_log_dir, "saved_models", "ckp_last.pt"))
        print(f"Pretrained model is stored at folder:{experiment_log_dir+'/saved_models'+'/ckp_last.pt'}")

    # Fine-tuning and Test
    if training_mode != 'pre_train':
        # fine-tune
        print('Fine-tune on Fine-tuning set')
        performance_list = []
        total_f1 = []
        KNN_f1 = []
        global emb_finetune, label_finetune, emb_test, label_test

        for epoch in range(1, config.num_epoch + 1):
            logger.debug(f'\nEpoch : {epoch}')

            valid_loss, emb_finetune, label_finetune, F1 = model_finetune(model, model_optimizer, valid_dl,
                                                                          config, device, training_mode,
                                                                          classifier=classifier,
                                                                          classifier_optimizer=classifier_optimizer)
            scheduler.step(valid_loss)

            # save best fine-tuning model
            global arch
            arch = 'ecg2mit-bih'
            if len(total_f1) == 0 or F1 > max(total_f1):
                print('update fine-tuned model')
                os.makedirs('experiments_logs/finetunemodel/', exist_ok=True)
                torch.save(model.state_dict(), 'experiments_logs/finetunemodel/' + arch + '_model.pt')
                torch.save(classifier.state_dict(), 'experiments_logs/finetunemodel/' + arch + '_classifier.pt')
            total_f1.append(F1)

            # evaluate on the test set
            # Testing set
            logger.debug('Test on Target datasts test set')
            model.load_state_dict(torch.load('experiments_logs/finetunemodel/' + arch + '_model.pt'))
            classifier.load_state_dict(torch.load('experiments_logs/finetunemodel/' + arch + '_classifier.pt'))
            test_loss, test_acc, test_auc, test_prc, emb_test, label_test, performance = model_test(
                model, test_dl, config, device, training_mode,
                classifier=classifier, classifier_optimizer=classifier_optimizer
            )
            performance_list.append(performance)

#             # Use KNN as another classifier; it's an alternation of the MLP classifier in function model_test.
#             # Experiments show KNN and MLP may work differently in different settings, so here we provide both.
#             # train classifier: KNN
#             neigh = KNeighborsClassifier(n_neighbors=5)
#             neigh.fit(emb_finetune, label_finetune)
#             knn_acc_train = neigh.score(emb_finetune, label_finetune)
#             # print('KNN finetune acc:', knn_acc_train)
#             representation_test = emb_test.detach().cpu().numpy()

#             knn_result = neigh.predict(representation_test)
#             knn_result_score = neigh.predict_proba(representation_test)
#             one_hot_label_test = one_hot_encoding(label_test)
#             # print(classification_report(label_test, knn_result, digits=4))
#             # print(confusion_matrix(label_test, knn_result))
#             knn_acc = accuracy_score(label_test, knn_result)
#             precision = precision_score(label_test, knn_result, average='macro', )
#             recall = recall_score(label_test, knn_result, average='macro', )
#             F1 = f1_score(label_test, knn_result, average='macro')
#             auc = roc_auc_score(one_hot_label_test, knn_result_score, average="macro", multi_class="ovr")
#             prc = average_precision_score(one_hot_label_test, knn_result_score, average="macro")
#             print('KNN Testing: Acc=%.4f| Precision = %.4f | Recall = %.4f | F1 = %.4f | AUROC= %.4f | AUPRC=%.4f' %
#                   (knn_acc, precision, recall, F1, auc, prc))
#             KNN_f1.append(F1)

        logger.debug("\n################## Best testing performance! #########################")
        performance_array = np.array(performance_list)
        best_performance = performance_array[np.argmax(performance_array[:, 0], axis=0)]
        # print('Best Testing Performance: Acc=%.4f| Precision = %.4f | Recall = %.4f | F1 = %.4f | AUROC= %.4f '
        #       '| AUPRC=%.4f' % (best_performance[0], best_performance[1], best_performance[2], best_performance[3],
        #                         best_performance[4], best_performance[5]))
        # print('Best KNN F1', max(KNN_f1))

        logger.debug('Best Testing Performance: Acc=%.4f | Precision = %.4f | Recall = %.4f | F1 = %.4f | AUROC= %.4f | AUPRC=%.4f' %
                     (best_performance[0], best_performance[1], best_performance[2], best_performance[3],
                      best_performance[4], best_performance[5]))

        logger.debug('Best KNN F1: %.4f' % max(KNN_f1))

    logger.debug("\n################## Training is Done! #########################")


def model_pretrain(model, model_optimizer, criterion, train_loader, config, device, training_mode,):
    """
    Description:
        Run one epoch of TF-C self-supervised pre-training.

        For every batch the model embeds the original sample and its augmentation in both the time
        and the frequency view. The optimized loss is ``lam * (loss_t + loss_f) + l_TF``, where each
        term is an ``NTXentLoss_poly`` contrastive loss:
          * ``loss_t`` / ``loss_f``: original vs. augmented sample in the time / frequency view;
          * ``l_TF``: time-view vs. frequency-view projection (``z_t`` vs ``z_f``) of the original.

    Args:
        model: the model to be trained; called as ``model(x, x_f)`` and unpacked as
            ``(h_t, z_t, h_f, z_f)``.
        model_optimizer: the optimizer of the model.
        criterion: not used by this function (kept for interface compatibility).
        train_loader: iterable yielding ``(data, labels, aug1, data_f, aug1_f)`` batches.
        config: supplies ``batch_size`` and ``Context_Cont.temperature`` /
            ``Context_Cont.use_cosine_similarity``.
        device: device the batches are moved to.
        training_mode: not used by this function.

    Returns:
        torch.Tensor: 0-dim tensor holding the mean loss over all batches of the epoch.
    """
    total_loss = []
    model.train()
    # The last batch's values are exposed as module-level globals.
    global loss, loss_t, loss_f, l_TF, loss_c, data_test, data_f_test

    lam = 0.2  # weight of the intra-view (time / frequency) contrastive terms

    for batch_idx, (data, labels, aug1, data_f, aug1_f) in enumerate(tqdm(train_loader)):
        data, labels = data.float().to(device), labels.long().to(device)  # e.g. data: [128, 1, 178], labels: [128]
        aug1 = aug1.float().to(device)
        data_f, aug1_f = data_f.float().to(device), aug1_f.float().to(device)

        # Reset gradients for every batch; otherwise gradients of all earlier batches would
        # accumulate and each optimizer step would apply their running sum.
        model_optimizer.zero_grad()

        # Produce embeddings of the original sample and of its augmentation.
        h_t, z_t, h_f, z_f = model(data, data_f)
        h_t_aug, z_t_aug, h_f_aug, z_f_aug = model(aug1, aug1_f)

        # NT-Xent: normalized temperature-scaled cross entropy loss (from SimCLR).
        nt_xent_criterion = NTXentLoss_poly(device, config.batch_size, config.Context_Cont.temperature,
                                            config.Context_Cont.use_cosine_similarity)

        loss_t = nt_xent_criterion(h_t, h_t_aug)  # time view: original vs. augmentation
        loss_f = nt_xent_criterion(h_f, h_f_aug)  # frequency view: original vs. augmentation
        l_TF = nt_xent_criterion(z_t, z_f)  # time vs. frequency projection of the original

        # Consistency term against the augmented projections. It is NOT part of the optimized loss
        # below; it is only exposed through the `loss_c` global.
        l_1 = nt_xent_criterion(z_t, z_f_aug)
        l_2 = nt_xent_criterion(z_t_aug, z_f)
        l_3 = nt_xent_criterion(z_t_aug, z_f_aug)
        loss_c = (1 + l_TF - l_1) + (1 + l_TF - l_2) + (1 + l_TF - l_3)

        loss = lam * (loss_t + loss_f) + l_TF

        total_loss.append(loss.item())
        loss.backward()
        model_optimizer.step()

    print(f"Pretraining: overall loss: {loss:.4f}, l_t: {loss_t:.4f}, l_f: {loss_f:.4f}, l_TF: {l_TF:.4f}")

    ave_loss = torch.tensor(total_loss).mean()

    return ave_loss


def model_finetune(model, model_optimizer, val_dl,
                   config, device, training_mode,
                   classifier=None, classifier_optimizer=None):
    """
    Description:
        Run one epoch of supervised fine-tuning of the encoder together with the classifier head.

        The per-batch loss is ``loss_p + l_TF + lam * (loss_t + loss_f)``: the classifier's
        cross-entropy plus the NT-Xent contrastive terms also used in pre-training. Both optimizers
        take one step per batch.

    Args:
        model: the encoder to be fine-tuned; ``model(x, x_f)`` is unpacked as ``(h_t, z_t, h_f, z_f)``.
        model_optimizer: the optimizer of the encoder.
        val_dl: loader yielding ``(data, labels, aug1, data_f, aug1_f)`` batches used for fine-tuning.
        config: supplies ``target_batch_size`` and ``Context_Cont.temperature`` /
            ``Context_Cont.use_cosine_similarity``.
        device: device the batches are moved to.
        training_mode: not used by this function.
        classifier: head called as ``classifier(features, labels)``; required despite the ``None`` default.
        classifier_optimizer: optimizer of the classifier; required despite the ``None`` default.

    Returns:
        tuple: ``(ave_loss, feas, trgs, F1)``:
            * ``ave_loss``: 0-dim tensor, mean loss over batches;
            * ``feas``: ``[N, D]`` float64 array of the flattened ``[z_t, z_f]`` embeddings;
            * ``trgs``: ``[N]`` float64 array of the true labels;
            * ``F1``: macro F1 over all N samples of the epoch.
    """
    # The last batch's labels / scores / embeddings are exposed as module-level globals.
    global labels, pred_numpy, fea_concat_flat
    model.train()
    classifier.train()

    total_loss = []
    total_acc = []
    total_auc = []
    total_prc = []

    criterion = nn.CrossEntropyLoss()
    lam = 0.1  # weight of the intra-view (time / frequency) contrastive terms
    outs = []  # per-batch predicted class indices
    trgs = []  # per-batch true labels
    feas = []  # per-batch flattened embeddings

    for data, labels, aug1, data_f, aug1_f in tqdm(val_dl):
        data, labels = data.float().to(device), labels.long().to(device)
        data_f = data_f.float().to(device)
        aug1 = aug1.float().to(device)
        aug1_f = aug1_f.float().to(device)

        model_optimizer.zero_grad()
        classifier_optimizer.zero_grad()

        # Time / frequency embeddings of the sample and of its augmentation.
        h_t, z_t, h_f, z_f = model(data, data_f)
        h_t_aug, z_t_aug, h_f_aug, z_f_aug = model(aug1, aug1_f)

        nt_xent_criterion = NTXentLoss_poly(device, config.target_batch_size, config.Context_Cont.temperature,
                                            config.Context_Cont.use_cosine_similarity)
        loss_t = nt_xent_criterion(h_t, h_t_aug)
        loss_f = nt_xent_criterion(h_f, h_f_aug)
        l_TF = nt_xent_criterion(z_t, z_f)

        # Supervised head on the concatenated projections (the same head is used in testing).
        fea_concat = torch.cat((z_t, z_f), dim=1)
        predictions = classifier(fea_concat, labels)
        fea_concat_flat = fea_concat.reshape(fea_concat.shape[0], -1)
        loss_p = criterion(predictions, labels)

        loss = loss_p + l_TF + lam * (loss_t + loss_f)

        # Per-batch metrics.
        pred_idx = predictions.detach().argmax(dim=1)
        pred_numpy = predictions.detach().cpu().numpy()
        labels_numpy = labels.detach().cpu().numpy()
        # Size the one-hot targets from the classifier output so they always match pred_numpy's columns.
        onehot_numpy = F.one_hot(labels, num_classes=predictions.shape[1]).cpu().numpy()
        acc_bs = labels.eq(pred_idx).float().mean().item()
        try:
            auc_bs = roc_auc_score(onehot_numpy, pred_numpy, average="macro", multi_class="ovr")
        except ValueError:
            # e.g. a class absent from this batch; recorded as 0 and still averaged in.
            auc_bs = float(0)
        prc_bs = average_precision_score(onehot_numpy, pred_numpy)

        total_acc.append(acc_bs)
        total_auc.append(auc_bs)
        total_prc.append(prc_bs)
        total_loss.append(loss.item())
        loss.backward()
        model_optimizer.step()
        classifier_optimizer.step()

        outs.append(pred_idx.cpu().numpy())
        trgs.append(labels_numpy)
        feas.append(fea_concat_flat.detach().cpu().numpy())

    y_true = np.concatenate(trgs)
    y_pred = np.concatenate(outs)
    feas = np.concatenate(feas).astype(np.float64)  # [N, D] learned embeddings
    trgs = y_true.astype(np.float64)

    # Keep the `pred_numpy` global as the class indices of the last batch.
    pred_numpy = np.argmax(pred_numpy, axis=1)

    # Epoch-level metrics over every fine-tuning sample, not just the final batch.
    precision = precision_score(y_true, y_pred, average='macro', )
    recall = recall_score(y_true, y_pred, average='macro', )
    F1 = f1_score(y_true, y_pred, average='macro', )
    ave_loss = torch.tensor(total_loss).mean()
    ave_acc = torch.tensor(total_acc).mean()
    ave_auc = torch.tensor(total_auc).mean()
    ave_prc = torch.tensor(total_prc).mean()

    print(' Finetune: loss = %.4f| Acc=%.4f | Precision = %.4f | Recall = %.4f | F1 = %.4f| AUROC=%.4f | AUPRC = %.4f'
          % (ave_loss, ave_acc * 100, precision * 100, recall * 100, F1 * 100, ave_auc * 100, ave_prc * 100))

    return ave_loss, feas, trgs, F1


def model_test(model, test_dl, config,  device, training_mode, classifier=None, classifier_optimizer=None):
    """
    Description:
        Evaluate the encoder plus classifier head on ``test_dl`` without gradient tracking.

    Args:
        model: encoder; ``model(x, x_f)`` is unpacked as ``(h_t, z_t, h_f, z_f)``.
        test_dl: loader yielding ``(data, labels, aug1, data_f, aug1_f)``; the augmented views are ignored.
        config: not used by this function.
        device: device the batches are moved to.
        training_mode: not used by this function.
        classifier: head called as ``classifier(features, labels)``; required despite the ``None`` default.
        classifier_optimizer: not used by this function.

    Returns:
        tuple: ``(total_loss, total_acc, total_auc, total_prc, emb_test_all, trgs, performance)``:
            * the first four are 0-dim tensors averaged over batches;
            * ``emb_test_all``: concatenated ``[N, D]`` embedding tensor (on ``device``);
            * ``trgs``: ``[N]`` float64 array of true labels;
            * ``performance``: ``[acc, precision, recall, F1, AUROC, AUPRC]`` in percent. Accuracy,
              precision, recall and F1 are computed over all samples; AUROC and AUPRC are batch
              averages.
    """
    model.eval()
    classifier.eval()

    total_loss = []
    total_acc = []
    total_auc = []
    total_prc = []

    criterion = nn.CrossEntropyLoss()  # the loss for downstream classifier
    labels_chunks = []  # per-batch true labels
    preds_chunks = []  # per-batch predicted class indices
    emb_test_all = []

    with torch.no_grad():
        for data, labels, _, data_f, _ in tqdm(test_dl):
            data, labels = data.float().to(device), labels.long().to(device)
            data_f = data_f.float().to(device)

            # Same supervised head as in fine-tuning, applied to the concatenated projections.
            h_t, z_t, h_f, z_f = model(data, data_f)
            fea_concat = torch.cat((z_t, z_f), dim=1)
            predictions_test = classifier(fea_concat, labels)
            emb_test_all.append(fea_concat.reshape(fea_concat.shape[0], -1))

            loss = criterion(predictions_test, labels)
            acc_bs = labels.eq(predictions_test.argmax(dim=1)).float().mean().item()

            scores = predictions_test.cpu().numpy()
            labels_numpy = labels.cpu().numpy()
            # Size the one-hot targets from the classifier output so they always match scores' columns.
            onehot_numpy = F.one_hot(labels, num_classes=predictions_test.shape[1]).cpu().numpy()
            try:
                auc_bs = roc_auc_score(onehot_numpy, scores, average="macro", multi_class="ovr")
            except ValueError:
                # e.g. a class absent from this batch; recorded as 0 and still averaged in.
                auc_bs = float(0)
            prc_bs = average_precision_score(onehot_numpy, scores, average="macro")

            total_acc.append(acc_bs)
            total_auc.append(auc_bs)
            total_prc.append(prc_bs)
            total_loss.append(loss.item())
            labels_chunks.append(labels_numpy)
            preds_chunks.append(np.argmax(scores, axis=1))

    labels_numpy_all = np.concatenate(labels_chunks)
    pred_numpy_all = np.concatenate(preds_chunks)

    precision = precision_score(labels_numpy_all, pred_numpy_all, average='macro', )
    recall = recall_score(labels_numpy_all, pred_numpy_all, average='macro', )
    F1 = f1_score(labels_numpy_all, pred_numpy_all, average='macro', )
    acc = accuracy_score(labels_numpy_all, pred_numpy_all, )

    total_loss = torch.tensor(total_loss).mean()
    total_acc = torch.tensor(total_acc).mean()
    total_auc = torch.tensor(total_auc).mean()
    total_prc = torch.tensor(total_prc).mean()

    # Plain Python floats so the caller's np.array(performance_list) is a clean float matrix.
    performance = [acc * 100, precision * 100, recall * 100, F1 * 100,
                   float(total_auc) * 100, float(total_prc) * 100]
    print('MLP Testing: Acc=%.4f| Precision = %.4f | Recall = %.4f | F1 = %.4f | AUROC= %.4f | AUPRC=%.4f'
          % tuple(performance))
    emb_test_all = torch.cat(emb_test_all)
    trgs = labels_numpy_all.astype(np.float64)
    return total_loss, total_acc, total_auc, total_prc, emb_test_all, trgs, performance

### utils.py

In [None]:
def set_requires_grad(model, dict_, requires_grad=True):
    """Set ``requires_grad`` on each parameter of ``model`` whose name is in ``dict_``.

    Args:
        model: module whose ``named_parameters()`` are scanned.
        dict_: container of parameter names (e.g. a state dict); membership is tested with ``in``.
        requires_grad: value assigned to the selected parameters' ``requires_grad`` flag.
    """
    selected = (tensor for name, tensor in model.named_parameters() if name in dict_)
    for tensor in selected:
        tensor.requires_grad = requires_grad


def fix_randomness(SEED):
    """Seed Python, NumPy and PyTorch RNGs and put cuDNN into deterministic mode.

    Args:
        SEED: integer seed applied to every generator.
    """
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True
    # With benchmarking enabled, cuDNN autotunes and may select different algorithms from run to
    # run, which defeats the deterministic flag above.
    torch.backends.cudnn.benchmark = False


def epoch_time(start_time, end_time):
    """Split the time elapsed between two numeric timestamps (seconds) into minutes and seconds.

    Both parts are truncated toward zero with ``int``.

    Returns:
        tuple[int, int]: ``(minutes, seconds)``.
    """
    duration = end_time - start_time
    whole_minutes = int(duration / 60)
    leftover_seconds = int(duration - 60 * whole_minutes)
    return whole_minutes, leftover_seconds


def _calc_metrics(pred_labels, true_labels, log_dir, home_path):
    """Persist predictions/targets and write a classification report and confusion matrix.

    Files written:
      * ``<log_dir>/labels/predicted_labels.npy`` and ``<log_dir>/labels/true_labels.npy``;
      * ``<home_path>/<log_dir>/<exp>_<mode>_classification_report.xlsx``: the sklearn
        classification report plus ``cohen`` (Cohen's kappa) and ``accuracy`` columns, with every
        value multiplied by 100 (including the ``support`` row);
      * ``<home_path>/<log_dir>/<exp>_<mode>_confusion_matrix.torch``: the confusion matrix,
        written with ``torch.save``.

    ``<exp>`` is the name of the directory containing ``log_dir`` and ``<mode>`` is the basename
    of ``log_dir``.
    """
    y_pred = np.array(pred_labels).astype(int)
    y_true = np.array(true_labels).astype(int)

    # Raw label arrays next to the experiment logs.
    label_dir = os.path.join(log_dir, "labels")
    os.makedirs(label_dir, exist_ok=True)
    for npy_name, values in (("predicted_labels.npy", y_pred), ("true_labels.npy", y_true)):
        np.save(os.path.join(label_dir, npy_name), values)

    report = pd.DataFrame(classification_report(y_true, y_pred, digits=6, output_dict=True))
    report["cohen"] = cohen_kappa_score(y_true, y_pred)
    report["accuracy"] = accuracy_score(y_true, y_pred)
    report = report * 100

    exp_name = os.path.split(os.path.dirname(log_dir))[-1]
    mode = os.path.basename(log_dir)
    out_dir = os.path.join(home_path, log_dir)
    report.to_excel(os.path.join(out_dir, f"{exp_name}_{mode}_classification_report.xlsx"))

    conf_mat = confusion_matrix(y_true, y_pred)
    torch.save(conf_mat, os.path.join(out_dir, f"{exp_name}_{mode}_confusion_matrix.torch"))


def _logger(logger_name, level=logging.DEBUG):
    """
    Method to return a custom logger with the given name and level
    """
    logger = logging.getLogger(logger_name)
    logger.setLevel(level)
    # format_string = ("%(asctime)s — %(name)s — %(levelname)s — %(funcName)s:"
    #                 "%(lineno)d — %(message)s")
    format_string = "%(message)s"
    log_format = logging.Formatter(format_string)
    # Creating and adding the console handler
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(log_format)
    logger.addHandler(console_handler)
    # Creating and adding the file handler
    file_handler = logging.FileHandler(logger_name, mode='a')
    file_handler.setFormatter(log_format)
    logger.addHandler(file_handler)
    return logger


def copy_Files(destination, data_type):
    """Snapshot the project's source files into ``<destination>/model_files``.

    Args:
        destination: experiment directory, e.g. ``'experiments_logs/Exp1/run1'``.
        data_type: dataset name selecting ``config_files/<data_type>_Configs.py``.

    Source paths are relative to the current working directory. Files are copied in a fixed
    order; if a source file is missing, ``copy`` raises and the files already copied remain.
    """
    target_dir = os.path.join(destination, "model_files")
    os.makedirs(target_dir, exist_ok=True)
    config_name = f"{data_type}_Configs.py"
    sources_and_names = (
        ("main.py", "main.py"),
        ("trainerfun/trainer.py", "trainerfun.py"),
        (f"config_files/{config_name}", config_name),
        ("dataloader/augmentations.py", "augmentations.py"),
        ("dataloader/dataloader.py", "dataloader.py"),
        ("models/model.py", "model.py"),
        ("models/loss.py", "loss.py"),
        ("models/TC.py", "TC.py"),
    )
    for source, target_name in sources_and_names:
        copy(source, os.path.join(target_dir, target_name))


### main.py

In [None]:
# Args selections
start_time = datetime.now()
parser = argparse.ArgumentParser()

# Model parameters
home_dir = os.getcwd()


# Set up command line arguments and create parser
parser.add_argument('--run_description', default='run1', type=str,
                    help='Experiment Description')
parser.add_argument('--seed', default=42, type=int,
                    help='seed value')
parser.add_argument('--training_mode', default='pre_train', type=str,
                    help='pre_train, fine_tune_test')
parser.add_argument('--pretrain_dataset', default='ECG', type=str,
                    help='Dataset of choice: ECG')
parser.add_argument('--target_dataset', default='MIT-BIH', type=str,
                    help='Dataset of choice: EMG, MIT-BIH, PTB-XL-Superclass, PTB-XL-Form, PTB-XL-Rhythm')
parser.add_argument('--logs_save_dir', default='./experiments_logs', type=str,
                    help='saving directory')
parser.add_argument('--device', default='cuda', type=str,
                    help='cpu or cuda')
parser.add_argument('--home_path', default=home_dir, type=str,
                    help='Project home directory')

# parse_known_args (rather than parse_args) returns unrecognized argv in `unknown` instead of
# exiting, so stray arguments passed to the notebook process do not abort the run.
# NOTE: args.device is parsed but not used below; `device` comes from the device-setup cell.
args, unknown = parser.parse_known_args()



# Set up paths, experiment description and loggers
pretrain_dataset = args.pretrain_dataset
target_data = args.target_dataset
experiment_description = str(pretrain_dataset) + '_2_' + str(target_data)

method = 'TF-C'
training_mode = args.training_mode
run_description = args.run_description
logs_save_dir = args.logs_save_dir
os.makedirs(logs_save_dir, exist_ok=True)

# Use ECG_Configs (the Configs class from the config_files.py cell)
configs = Configs()

# fix random seeds for reproducibility. cuDNN determinism is left disabled, so GPU runs may
# still differ slightly between executions.
SEED = args.seed
random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = False
np.random.seed(SEED)

# Set up experiment log directory and initialize logger
# e.g. 'experiments_logs/ECG_2_MIT-BIH/run1/pre_train_seed_42_2layertransformer'
experiment_log_dir = os.path.join(logs_save_dir, experiment_description, run_description,
                                  training_mode + f"_seed_{SEED}_2layertransformer")
os.makedirs(experiment_log_dir, exist_ok=True)

# loop through domains (these counters are not used further in this cell)
counter = 0
src_counter = 0

# Logging
log_file_name = os.path.join(experiment_log_dir, f"logs_{datetime.now().strftime('%d_%m_%Y_%H_%M_%S')}.log")
# e.g. '.../pre_train_seed_42_2layertransformer/logs_14_04_2022_15_13_12.log'
logger = _logger(log_file_name)
logger.debug("=" * 45)
logger.debug('Pre-training Dataset: %s', pretrain_dataset)
logger.debug('Target (fine-tuning) Dataset: %s', target_data)
logger.debug('Method:  %s', method)
logger.debug('Mode: %s', training_mode)
logger.debug("=" * 45)

# Load datasets. The dataset paths are relative to the read-only Kaggle input directory, so
# switch there to load them and always switch back (even if loading fails): the log, checkpoint
# and fine-tuned-model paths used afterwards are relative to /kaggle/working/.
os.chdir("/kaggle/input/ecg2mit-bih")
try:
    sourcedata_path = f"./datasets/{pretrain_dataset}"
    target_data_path = f"./datasets/{target_data}"
    subset = False  # if subset=True, use a subset for debugging.
    train_dl, valid_dl, test_dl = data_generator(sourcedata_path, target_data_path,
                                                 configs, training_mode, subset=subset)
    logger.debug("Data loaded ...")
finally:
    os.chdir("/kaggle/working/")

# Load Model
# TF-C encoder and the downstream classifier head; no temporal contrastive model is used.
TFC_model = TFC(configs).to(device)
classifier = TargetClassifier(configs).to(device)
temporal_contr_model = None

# Load the pre-trained checkpoint of this experiment in every mode: 'pre_train' continues
# pre-training from it and 'fine_tune_test' fine-tunes it.
load_from = os.path.join(logs_save_dir, experiment_description, run_description,
                         f"pre_train_seed_{SEED}_2layertransformer", "saved_models")
print("The loading file path", load_from)
chkpoint = torch.load(os.path.join(load_from, "ckp_last.pt"), map_location=device)
pretrained_dict = chkpoint["model_state_dict"]
TFC_model.load_state_dict(pretrained_dict)

model_optimizer = torch.optim.Adam(TFC_model.parameters(), lr=configs.lr,
                                   betas=(configs.beta1, configs.beta2), weight_decay=3e-4)
classifier_optimizer = torch.optim.Adam(classifier.parameters(), lr=configs.lr,
                                        betas=(configs.beta1, configs.beta2), weight_decay=3e-4)

# Trainer
Trainer(TFC_model, model_optimizer, classifier,
        classifier_optimizer, train_dl, valid_dl,
        test_dl, device, logger,
        configs, experiment_log_dir, training_mode)

logger.debug("Training time is : %s", datetime.now() - start_time)
