In [None]:
# libraries
import numpy as np
import pandas as pd 
import torch
import random
from torch.nn.utils.rnn import pad_sequence
from torchmetrics.functional import pearson_corrcoef
from torchmetrics.regression import MeanAbsolutePercentageError, MeanAbsoluteError
from torchmetrics import Metric
import torch.nn as nn
from torch.utils.data import Dataset
from transformers import Trainer
from sklearn.model_selection import train_test_split
import itertools

In [None]:
# model functions
class MAECoef(Metric):
    """Running mean of per-sequence mean absolute error over masked positions.

    ``update`` sums ``preds`` over its last axis (dim 2) and drops the first
    position along dim 1 before comparing it with ``target``; ``compute``
    returns the accumulated error divided by the number of sequences seen.
    """

    def __init__(self):
        super().__init__()
        # Both states are registered with a "sum" distributed reduction.
        self.add_state("mae", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, preds, target, mask):
        """Accumulate the per-sequence MAE of one batch.

        After the reduction described on the class, ``preds`` must have the
        same shape as ``target`` and ``mask`` (checked with assertions).
        ``mask`` selects the positions of each row that are scored.
        """
        preds = preds.sum(dim=2)[:, 1:]
        assert preds.shape == target.shape
        assert preds.shape == mask.shape

        mae_metric = MeanAbsoluteError()
        per_sequence = torch.stack([
            mae_metric(
                torch.masked_select(row_pred, row_mask),
                torch.masked_select(row_true, row_mask),
            )
            for row_pred, row_true, row_mask in zip(preds, target, mask)
        ])

        self.mae += torch.sum(per_sequence)
        self.total += len(per_sequence)

    def compute(self):
        """Return the mean of all per-sequence MAE values accumulated so far."""
        return self.mae / self.total

class CorrCoef(Metric):
    """Running mean of per-sequence Pearson correlation over masked positions.

    ``update`` sums ``preds`` over its last axis (dim 2) and drops the first
    position along dim 1 before comparing it with ``target``; ``compute``
    returns the accumulated correlation divided by the number of sequences.
    """

    def __init__(self):
        super().__init__()
        # Both states are registered with a "sum" distributed reduction.
        self.add_state("corrcoefs", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, preds, target, mask):
        """Accumulate the per-sequence Pearson coefficient of one batch.

        After the reduction described on the class, ``preds`` must have the
        same shape as ``target`` and ``mask`` (checked with assertions).
        ``mask`` selects the positions of each row that are scored.
        """
        preds = preds.sum(dim=2)[:, 1:]
        assert preds.shape == target.shape
        assert preds.shape == mask.shape

        per_sequence = torch.stack([
            pearson_corrcoef(
                torch.masked_select(row_pred, row_mask),
                torch.masked_select(row_true, row_mask),
            )
            for row_pred, row_true, row_mask in zip(preds, target, mask)
        ])

        self.corrcoefs += torch.sum(per_sequence)
        self.total += len(per_sequence)

    def compute(self):
        """Return the mean of all per-sequence correlations accumulated so far."""
        return self.corrcoefs / self.total

class MAPECoef(Metric):
    """Running mean of per-sequence mean absolute percentage error.

    ``update`` sums ``preds`` over its last axis (dim 2) and drops the first
    position along dim 1 before comparing it with ``target``; ``compute``
    returns the accumulated error divided by the number of sequences seen.
    """

    def __init__(self):
        super().__init__()
        # Both states are registered with a "sum" distributed reduction.
        self.add_state("mapecoefs", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, preds, target, mask):
        """Accumulate the per-sequence MAPE of one batch.

        After the reduction described on the class, ``preds`` must have the
        same shape as ``target`` and ``mask`` (checked with assertions).
        ``mask`` selects the positions of each row that are scored.
        """
        # The first position of every prediction row is discarded here.
        preds = preds.sum(dim=2)[:, 1:]
        assert preds.shape == target.shape
        assert preds.shape == mask.shape

        mape_metric = MeanAbsolutePercentageError()
        per_sequence = torch.stack([
            mape_metric(
                torch.masked_select(row_pred, row_mask),
                torch.masked_select(row_true, row_mask),
            )
            for row_pred, row_true, row_mask in zip(preds, target, mask)
        ])

        self.mapecoefs += torch.sum(per_sequence)
        self.total += len(per_sequence)

    def compute(self):
        """Return the mean of all per-sequence MAPE values accumulated so far."""
        return self.mapecoefs / self.total

# collate function
def collate_fn(batch):
    """Pad a list of dataset samples into a single batch dictionary.

    Each sample is a 5-tuple ``(x, y, ctrl_y, gene, transcript)``. The gene
    and transcript entries are unpacked but not forwarded.

    Returns a dict with:
      * ``input_ids``   - inputs padded with 384
      * ``labels``      - targets padded with -1
      * ``lengths``     - unpadded length of every input sequence
      * ``labels_ctrl`` - control targets padded with -1
    """
    inputs, targets, ctrl_targets, _genes, _transcripts = zip(*batch)

    return {
        "input_ids": pad_sequence(inputs, batch_first=True, padding_value=384),
        "labels": pad_sequence(targets, batch_first=True, padding_value=-1),
        # lengths are measured on the inputs, before padding
        "lengths": torch.tensor([len(seq) for seq in inputs]),
        "labels_ctrl": pad_sequence(ctrl_targets, batch_first=True, padding_value=-1),
    }

# compute metrics
def compute_metrics(pred):
    """Compute Pearson r, MAE and MAPE for an evaluation run.

    Args:
        pred: object exposing ``label_ids`` and ``predictions`` array-likes
            (presumably the ``EvalPrediction`` handed over by the Trainer).

    Returns:
        dict with keys ``"r"``, ``"mae"`` and ``"mape"``.

    A label position is scored only if it is not -100, not -1 and not NaN.
    -1 is the value ``collate_fn`` in this file pads labels with inside a
    batch; without excluding it, padded positions enter the metrics whenever
    the evaluation batch size is larger than one.
    NOTE(review): -100 is presumably the padding added when batches of
    different widths are concatenated - confirm against the Trainer in use.
    """
    labels = torch.as_tensor(pred.label_ids)
    preds = torch.squeeze(torch.as_tensor(pred.predictions), dim=2)

    is_padding = torch.logical_or(labels == -100.0, labels == -1.0)
    mask = torch.logical_not(torch.logical_or(is_padding, torch.isnan(labels)))

    corr_coef = CorrCoef()
    corr_coef.update(preds, labels, mask)

    mae_coef = MAECoef()
    mae_coef.update(preds, labels, mask)

    mape_coef = MAPECoef()
    mape_coef.update(preds, labels, mask)

    return {"r": corr_coef.compute(), "mae": mae_coef.compute(), "mape": mape_coef.compute()}

# compute metrics
def compute_metrics_saved(pred):
    '''
    Compute Pearson r, MAE and MAPE and dump the raw arrays for later analysis.

    Args:
        pred: object exposing ``label_ids``, ``predictions`` and ``inputs``
            (presumably the ``EvalPrediction`` handed over by the Trainer).

    Returns:
        dict with keys "r", "mae" and "mape".

    Side effects:
        writes preds/preds.npy, preds/labels.npy and preds/inputs.npy; the
        ``preds`` directory is created when it does not exist yet.

    A label position is scored only if it is not -100, not -1 and not NaN.
    -1 is the value ``collate_fn`` in this file pads labels with inside a
    batch, so it has to be excluded for batch sizes larger than one.
    '''
    import os

    inputs = pred.inputs
    labels = torch.as_tensor(pred.label_ids)
    preds = torch.squeeze(torch.as_tensor(pred.predictions), dim=2)

    is_padding = torch.logical_or(labels == -100.0, labels == -1.0)
    mask = torch.logical_not(torch.logical_or(is_padding, torch.isnan(labels)))

    mae_coef = MAECoef()
    mae_coef.update(preds, labels, mask)

    corr_coef = CorrCoef()
    corr_coef.update(preds, labels, mask)

    mape_coef = MAPECoef()
    mape_coef.update(preds, labels, mask)

    # save predictions (np.save does not create missing directories)
    os.makedirs("preds", exist_ok=True)
    np.save("preds/preds.npy", preds.cpu().numpy())
    np.save("preds/labels.npy", labels.cpu().numpy())
    np.save("preds/inputs.npy", inputs)

    return {"r": corr_coef.compute(), "mae": mae_coef.compute(), "mape": mape_coef.compute()}

In [None]:
# global variables
# Codon vocabulary: every 3-letter combination of A/T/C/G, numbered in
# itertools.product order, plus the reverse lookup.
id_to_codon = {
    codon_id: ''.join(bases)
    for codon_id, bases in enumerate(itertools.product('ATCG', repeat=3))
}
codon_to_id = {codon: codon_id for codon_id, codon in id_to_codon.items()}

def checkArrayEquality(arr1, arr2):
    '''
    Element-wise equality check for two indexable sequences.

    inputs: two arrays
    outputs: True when both have the same length and no position differs,
             False otherwise
    '''
    size = len(arr1)
    if size != len(arr2):
        return False

    # stop at the first position whose elements differ
    return not any(arr1[pos] != arr2[pos] for pos in range(size))

# dataset generation functions
def longestZeroSeqLength(a):
    '''
    Length of the longest run of consecutive zeros.

    ``a`` is the string form of a list of numbers ("[v0, v1, ...]", values
    separated by ", "). Returns 0 when the sequence contains no zero.
    '''
    values = [float(token) for token in a[1:-1].split(', ')]

    # group consecutive values by "is zero" and measure every zero group
    zero_run_lengths = (
        sum(1 for _ in run)
        for is_zero, run in itertools.groupby(values, key=lambda v: v == 0.0)
        if is_zero
    )
    return max(zero_run_lengths, default=0)

def percNans(a):
    '''
    Share of NaN entries in the sequence, as a fraction between 0 and 1.

    ``a`` is the string form of a list of numbers ("[v0, v1, ...]", values
    separated by ", ").
    '''
    values = np.asarray([float(token) for token in a[1:-1].split(', ')])
    nan_count = np.count_nonzero(np.isnan(values))
    return nan_count / len(values)

def coverageMod(a, window_size=30):
    '''
    Modified coverage of a sequence: the fraction of non-zero values among
    the positions that are still observed after every all-zero window of
    ``window_size`` consecutive positions has been turned into NaN.

    Args:
        a: string form of a list of numbers ("[v0, v1, ...]", values
            separated by ", ").
        window_size: length of the all-zero window that gets masked out.

    Returns:
        float in [0, 1]; 0.0 when no observed (non-NaN) position is left.

    The values are held in a NumPy array on purpose: comparing a Python list
    with ``== 0.0`` is always False, which silently disabled the window
    masking. Windows are processed left to right and masked in place, the
    same way ``slidingWindowZeroToNan`` in this file does it.
    '''
    values = np.asarray([float(k) for k in a[1:-1].split(', ')], dtype=float)

    # "+ 1" so that the window ending on the last position is checked as well
    for start in range(len(values) - window_size + 1):
        stop = start + window_size
        if np.all(values[start:stop] == 0.0):
            values[start:stop] = np.nan

    observed = ~np.isnan(values)
    den = int(np.count_nonzero(observed))
    if den == 0:
        # nothing observed: report zero coverage instead of dividing by zero
        return 0.0

    num = int(np.count_nonzero(observed & (values != 0.0)))
    return num / den

def sequenceLength(a):
    '''
    Number of values in the sequence.

    ``a`` is the string form of a list of numbers ("[v0, v1, ...]", values
    separated by ", "); every value is parsed as a float while counting.
    '''
    tokens = a[1:-1].split(', ')
    return len([float(token) for token in tokens])

def mergeAnnotations(annots):
    '''
    Merge the annotations of several transcripts of the same gene.

    Args:
        annots: sequence of strings, each the string form of a list of
            numbers ("[v0, v1, ...]", values separated by ", ").

    Returns:
        list of Python floats with one entry per position of the FIRST
        annotation: the mean of the non-zero, non-NaN values found at that
        position. Transcripts that are too short to have the position are
        skipped. A position without any usable value yields NaN.
        NOTE(review): NaN (rather than 0.0) for positions where every
        transcript is zero is kept from the original behaviour - confirm
        that this is intended.

    Plain floats are returned so that ``str(result)`` stays parseable with
    ``float()``; NumPy scalars inside a list print as ``np.float64(...)``
    under NumPy 2.
    '''
    parsed = [[float(k) for k in a[1:-1].split(', ')] for a in annots]
    if not parsed:
        return []

    merged_annots = []
    for i in range(len(parsed[0])):
        # usable values at position i; shorter transcripts do not contribute
        ith_annots = [
            a[i] for a in parsed
            if i < len(a) and a[i] != 0.0 and not np.isnan(a[i])
        ]
        if ith_annots:
            merged_annots.append(float(np.mean(ith_annots)))
        else:
            merged_annots.append(float('nan'))

    return merged_annots

def uniqueGenes(df):
    '''
    Collapse the frame to one row (one transcript) per gene.

    For every gene with more than one row, the transcript with the longest
    annotation is kept, its ``annotations`` value is replaced by the string
    form of the merged annotations of all transcripts of that gene, and the
    rows of the other transcripts are removed.

    Args:
        df: DataFrame with at least ``gene``, ``transcript`` and
            ``annotations`` columns. It is NOT modified; a copy is used.

    Returns:
        a new DataFrame with unique genes and unique transcripts.

    Raises:
        AssertionError: if genes or transcripts are still duplicated after
            the merge.
    '''
    # work on a copy so the helper column never leaks into the caller's frame
    df = df.copy()
    df['sequence_length'] = df['annotations'].apply(sequenceLength)

    for gene in list(df['gene'].unique()):
        df_gene = df[df['gene'] == gene]
        if len(df_gene) <= 1:
            continue

        # longest transcript first; a stable sort keeps ties in file order
        df_gene = df_gene.sort_values('sequence_length', ascending=False, kind='stable')
        transcripts = df_gene['transcript'].values
        chosen_transcript = transcripts[0]
        other_transcripts = transcripts[1:]

        merged_annotations = mergeAnnotations(df_gene['annotations'].values)

        # copy after filtering so the assignment below targets a real frame
        df = df[~df['transcript'].isin(other_transcripts)].copy()
        df.loc[df['transcript'] == chosen_transcript, 'annotations'] = str(merged_annotations)

    df = df.drop(columns=['sequence_length'])

    assert len(df['gene'].unique()) == len(df['gene'])
    assert len(df['transcript'].unique()) == len(df['transcript'])
    assert len(df['transcript']) == len(df['gene'])

    return df
    
def slidingWindowZeroToNan(a, window_size=30):
    '''
    Replace every all-zero window of ``window_size`` consecutive values with NaN.

    Args:
        a: iterable of numbers (or numeric strings); it is not modified.
        window_size: number of consecutive positions inspected at a time.

    Returns:
        a new float NumPy array of the same length.

    Windows are visited left to right and masked in place, so a window that
    overlaps an already masked region no longer counts as all-zero.
    '''
    values = np.asarray([float(k) for k in a], dtype=float)

    # "+ 1" so that the window ending on the last position is checked as well
    for start in range(len(values) - window_size + 1):
        stop = start + window_size
        if np.all(values[start:stop] == 0.0):
            values[start:stop] = np.nan

    return values

def _loadConditionFrame(csv_path, condition, ctrl_pairs):
    '''Load one condition CSV, tag it with ``condition``, collapse it to one row
    per gene and keep only the (gene, transcript) pairs found in ``ctrl_pairs``.'''
    df_cond = pd.read_csv(csv_path)
    df_cond['condition'] = condition
    df_cond = uniqueGenes(df_cond)
    keep = np.asarray(
        [pair in ctrl_pairs for pair in zip(df_cond['gene'], df_cond['transcript'])],
        dtype=bool,
    )
    return df_cond.loc[keep]


def _splitGenesByRank(genes, test_frac, val_frac):
    '''Assign ranked genes to (train, val, test): even ranks go to test and odd
    ranks go to validation until each quota is filled, everything else to train.'''
    num_test_genes = int(test_frac * len(genes))
    num_val_genes = int(val_frac * len(genes))

    train_genes, val_genes, test_genes = [], [], []
    for rank, gene in enumerate(genes):
        if rank % 2 == 0 and len(test_genes) < num_test_genes:
            test_genes.append(gene)
        elif rank % 2 == 1 and len(val_genes) < num_val_genes:
            val_genes.append(gene)
        else:
            train_genes.append(gene)

    return train_genes, val_genes, test_genes


def RiboDatasetGWS(depr_folder: str, ds: str, threshold: float = 0.6, longZerosThresh: int = 20, percNansThresh: float = 0.1, test_frac: float = 0.2, val_frac: float = 0.1):
    '''
    Build the gene-wise train / validation / test frames from the per-condition CSVs.

    Args:
        depr_folder: prefix the CSV file names are appended to (plain string
            concatenation, so it has to end with a path separator).
        ds: dataset selector; only 'ALL' is supported.
        threshold: minimum ``coverageMod`` value a row needs to be kept.
        longZerosThresh: maximum allowed longest run of zeros, applied to
            both the annotation and its control annotation.
        percNansThresh: maximum allowed NaN fraction, applied to both the
            annotation and its control annotation.
        test_frac: fraction of genes held out for testing (0.2 is the value
            used by the split this function originally carried).
        val_frac: fraction of genes held out for validation.
            NOTE(review): the earlier code defined no validation split and
            returned the VAL-condition frame under the name ``df_val``; the
            0.1 default is a reviewer choice - confirm.

    Returns:
        (df_train, df_val, df_test); the three frames share no gene.

    Raises:
        ValueError: if ``ds`` is not 'ALL'.

    Side effects:
        writes '../../data/extras/liver_transcripts.npz' and 'genes_list.npz'
        and prints progress messages.
    '''
    if ds != 'ALL':
        raise ValueError("RiboDatasetGWS only supports ds='ALL', got %r" % (ds,))

    ctrl_depr_path = depr_folder + 'CTRL_AA.csv'
    liver_path = depr_folder + 'LIVER.csv'
    condition_files = [
        ('ILE', depr_folder + 'ILE_AA.csv'),
        ('LEU', depr_folder + 'LEU_AA.csv'),
        ('VAL', depr_folder + 'VAL_AA.csv'),
        ('LEU_ILE', depr_folder + 'LEU-ILE_AA_remBadRep.csv'),
        ('LEU_ILE_VAL', depr_folder + 'LEU-ILE-VAL_AA.csv'),
    ]

    # both control sources are labelled CTRL
    df_liver = pd.read_csv(liver_path)
    df_liver['condition'] = 'CTRL'

    df_ctrl_depr = pd.read_csv(ctrl_depr_path)
    df_ctrl_depr['condition'] = 'CTRL'

    # keep only the liver transcripts that the control-deprivation data lacks
    df_liver = df_liver[~df_liver['transcript'].isin(df_ctrl_depr['transcript'].unique())]

    # remember which transcripts come from the liver data only
    df_liver_transcripts = list(df_liver['transcript'])
    np.savez('../../data/extras/liver_transcripts.npz', df_liver_transcripts)
    print("Saved liver transcripts")

    # combined control frame, one transcript per gene
    df_ctrldepr_liver = pd.concat([df_liver, df_ctrl_depr], axis=0)
    df_ctrldepr_liver = uniqueGenes(df_ctrldepr_liver)

    # (gene, transcript) pairs available in the control data
    ctrl_pairs = set(zip(df_ctrldepr_liver['gene'], df_ctrldepr_liver['transcript']))

    # every other condition is restricted to those pairs
    condition_frames = [
        _loadConditionFrame(csv_path, condition, ctrl_pairs)
        for condition, csv_path in condition_files
    ]

    # liver + ctrl depr + ile + leu + val + leu ile + leu ile val
    df_full = pd.concat([df_ctrldepr_liver] + condition_frames, axis=0)

    # the frame is expected to have exactly these seven columns at this point
    df_full.columns = ['index_val', 'gene', 'transcript', 'sequence', 'annotations', 'perc_non_zero_annots', 'condition']
    df_full = df_full.drop(columns=['index_val'])

    assert len(df_full['transcript'].unique()) == len(df_full['gene'].unique())

    # apply annot threshold
    df_full['coverage_mod'] = df_full['annotations'].apply(coverageMod)
    df_full = df_full[df_full['coverage_mod'] >= threshold].copy()

    # pair every row with the CTRL annotation of its gene: CTRL rows use their
    # own annotation, other rows the first CTRL row of the gene, 'NA' if none
    annotations_list = list(df_full['annotations'])
    condition_df_list = list(df_full['condition'])
    genes_list = list(df_full['gene'])

    ctrl_annotation_by_gene = {}
    for gene, condition, annotation in zip(genes_list, condition_df_list, annotations_list):
        if condition == 'CTRL':
            ctrl_annotation_by_gene.setdefault(gene, annotation)

    sequences_ctrl = [
        annotation if condition == 'CTRL' else ctrl_annotation_by_gene.get(gene, 'NA')
        for gene, condition, annotation in zip(genes_list, condition_df_list, annotations_list)
    ]

    print(len(sequences_ctrl), len(annotations_list))
    df_full['ctrl_sequence'] = sequences_ctrl

    # remove those rows where the ctrl_sequence is NA
    df_full = df_full[df_full['ctrl_sequence'] != 'NA'].copy()

    # sanity check: a CTRL row must be its own control
    df_ctrl_full = df_full[df_full['condition'] == 'CTRL']
    assert (df_ctrl_full['annotations'] == df_ctrl_full['ctrl_sequence']).all()
    print("Sanity Checked")

    # quality statistics for the annotation and its control
    df_full['longest_zero_seq_length_annotation'] = df_full['annotations'].apply(longestZeroSeqLength)
    df_full['longest_zero_seq_length_ctrl_sequence'] = df_full['ctrl_sequence'].apply(longestZeroSeqLength)
    df_full['perc_nans_annotation'] = df_full['annotations'].apply(percNans)
    df_full['perc_nans_ctrl_sequence'] = df_full['ctrl_sequence'].apply(percNans)

    # apply the threshold for the longest zero sequence length
    df_full = df_full[df_full['longest_zero_seq_length_annotation'] <= longZerosThresh]
    df_full = df_full[df_full['longest_zero_seq_length_ctrl_sequence'] <= longZerosThresh]

    # apply the threshold for the number of nans
    df_full = df_full[df_full['perc_nans_annotation'] <= percNansThresh]
    df_full = df_full[df_full['perc_nans_ctrl_sequence'] <= percNansThresh]

    # rank the genes by their mean coverage_mod, best first
    genes = df_full['gene'].unique()
    gene_mean_coverage_mod = np.asarray(
        [df_full[df_full['gene'] == gene]['coverage_mod'].mean() for gene in genes]
    )
    genes = np.asarray(genes)[np.argsort(gene_mean_coverage_mod)[::-1]]

    # save genes list
    np.savez('genes_list.npz', genes)

    # gene-wise split: no gene appears in more than one of the three frames
    train_genes, val_genes, test_genes = _splitGenesByRank(genes, test_frac, val_frac)

    df_train = df_full[df_full['gene'].isin(train_genes)]
    df_val = df_full[df_full['gene'].isin(val_genes)]
    df_test = df_full[df_full['gene'].isin(test_genes)]

    return df_train, df_val, df_test

class GWSDatasetFromPandas(Dataset):
    """Dataset over a DataFrame with ``sequence``, ``annotations``,
    ``ctrl_sequence``, ``condition``, ``gene`` and ``transcript`` columns.

    ``sequence``, ``annotations`` and ``ctrl_sequence`` hold the string form
    of a list ("[v0, v1, ...]", values separated by ", ").
    """

    def __init__(self, df):
        self.df = df
        self.counts = list(df['annotations'])
        self.sequences = list(df['sequence'])
        self.condition_lists = list(df['condition'])
        # token prepended to the input for every condition
        condition_names = ['CTRL', 'ILE', 'LEU', 'LEU_ILE', 'LEU_ILE_VAL', 'VAL']
        self.condition_values = {name: token for token, name in enumerate(condition_names, start=64)}

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(X, y, ctrl_y, gene, transcript)`` for row ``idx``.

        ``X`` is a long tensor holding the condition token followed by the
        parsed sequence. ``y`` and ``ctrl_y`` are float tensors holding
        log(1 + value) of the annotation / control annotation after
        ``slidingWindowZeroToNan`` has been applied to them.
        """
        def parse_cell(column, cast):
            # "[v0, v1, ...]" -> [cast(v0), cast(v1), ...]
            raw = self.df[column].iloc[idx]
            return [cast(token) for token in raw[1:-1].split(', ')]

        codon_ids = parse_cell('sequence', int)

        # no min max scaling, only log(1 + count)
        y = np.log(1 + slidingWindowZeroToNan(parse_cell('annotations', float)))
        ctrl_y = np.log(1 + slidingWindowZeroToNan(parse_cell('ctrl_sequence', float)))

        # the condition token goes in front of the sequence
        cond_token = self.condition_values[self.condition_lists[idx]]
        X = np.concatenate(([cond_token], np.array(codon_ids)))

        X = torch.from_numpy(X).long()
        y = torch.from_numpy(y).float()
        ctrl_y = torch.from_numpy(ctrl_y).float()

        gene = self.df['gene'].iloc[idx]
        transcript = self.df['transcript'].iloc[idx]

        return X, y, ctrl_y, gene, transcript

In [None]:
# loss functions
class MaskedPearsonLoss(nn.Module):
    """1 - Pearson correlation between predictions and targets on masked positions.

    The logic lives in ``forward`` (instead of overriding ``__call__``) so
    that the regular ``nn.Module`` call path, including hooks, is used;
    calling the module instance works exactly as before.
    """

    def __init__(self):
        super().__init__()

    def forward(self, y_pred, y_true, mask, eps=1e-6):
        """
        Args:
            y_pred, y_true: tensors of the same shape as ``mask``.
            mask: boolean tensor, True where a position is scored.
            eps: passed to ``nn.CosineSimilarity`` to avoid division by zero.

        Returns:
            scalar tensor; the correlation is computed as the cosine
            similarity of the mean-centred selected values.
        """
        y_pred_mask = torch.masked_select(y_pred, mask)
        y_true_mask = torch.masked_select(y_true, mask)
        cos = nn.CosineSimilarity(dim=0, eps=eps)
        return 1 - cos(
            y_pred_mask - y_pred_mask.mean(),
            y_true_mask - y_true_mask.mean(),
        )

class MaskedL1Loss(nn.Module):
    """Square root of the mean absolute error on masked positions.

    The logic lives in ``forward`` (instead of overriding ``__call__``) so
    that the regular ``nn.Module`` call path, including hooks, is used.
    NOTE(review): taking the square root of an L1 mean is unusual; it is
    kept unchanged here - confirm it is intended.
    """

    def __init__(self):
        super().__init__()

    def forward(self, y_pred, y_true, mask):
        """
        Args:
            y_pred, y_true: tensors of the same shape as ``mask``.
            mask: boolean tensor, True where a position is scored.

        Returns:
            scalar tensor ``sqrt(mean(|y_pred - y_true|))`` over the
            selected positions.
        """
        y_pred_mask = torch.masked_select(y_pred, mask).float()
        y_true_mask = torch.masked_select(y_true, mask).float()

        loss = nn.functional.l1_loss(y_pred_mask, y_true_mask, reduction="none")
        return torch.sqrt(loss.mean())

class MaskedNormMAELoss(nn.Module):
    """Square root of the masked mean absolute error, divided by the largest
    selected target value.

    The logic lives in ``forward`` (instead of overriding ``__call__``) so
    that the regular ``nn.Module`` call path, including hooks, is used.
    """

    def __init__(self):
        super().__init__()

    def forward(self, y_pred, y_true, mask):
        """
        Args:
            y_pred, y_true: tensors of the same shape as ``mask``.
            mask: boolean tensor, True where a position is scored.

        Returns:
            scalar tensor ``sqrt(mean(|y_pred - y_true|)) / max(y_true)``;
            the division is skipped when that maximum is 0. NaN targets are
            ignored when taking the maximum.
        """
        y_pred_mask = torch.masked_select(y_pred, mask).float()
        y_true_mask = torch.masked_select(y_true, mask).float()

        loss = nn.functional.l1_loss(y_pred_mask, y_true_mask, reduction="none")

        # The normaliser is a constant: detach so targets that carry gradients
        # can be converted, and use a plain float as the divisor.
        y_true_max = float(np.nanmax(y_true_mask.detach().cpu().numpy()))

        if y_true_max == 0:
            return torch.sqrt(loss.mean())
        else:
            return torch.sqrt(loss.mean()) / y_true_max

class MaskedCombinedFourDH(nn.Module):
    """Combined loss for a two-channel prediction (control, deprivation difference).

    Terms:
      * control:    Pearson loss of relu(channel 0) against ``labels_ctrl``
      * difference: Pearson loss of channel 1 against ``labels - labels_ctrl``
                    (skipped when ``condition_`` is 64)
      * full:       Pearson + L1 loss of the channel sum against ``labels``

    64 is the token ``GWSDatasetFromPandas`` assigns to the CTRL condition.
    The logic lives in ``forward`` (instead of overriding ``__call__``) so
    that the regular ``nn.Module`` call path, including hooks, is used.
    """

    def __init__(self):
        super().__init__()
        self.pearson = MaskedPearsonLoss()
        self.l1 = MaskedL1Loss()

    def forward(self, y_pred, labels, labels_ctrl, mask_full, mask_ctrl, condition_):
        """
        Args:
            y_pred: predictions indexed as ``[batch, position, channel]``.
            labels, labels_ctrl: targets for the condition and its control.
            mask_full, mask_ctrl: boolean masks of the scored positions.
            condition_: condition token; 64 disables the difference term.

        Returns:
            scalar tensor, the sum of the active terms.
        """
        # relu on ctrl prediction; the channel sum below uses the raw outputs
        y_pred_ctrl = torch.relu(y_pred[:, :, 0])
        y_pred_full = torch.sum(y_pred, dim=2)

        loss_ctrl = self.pearson(y_pred_ctrl, labels_ctrl, mask_ctrl)
        loss_full = self.pearson(y_pred_full, labels, mask_full) + self.l1(y_pred_full, labels, mask_full)

        if condition_ == 64:
            return loss_ctrl + loss_full

        # the difference is only defined where both label sets are valid
        mask_diff = mask_full & mask_ctrl
        loss_depr_diff = self.pearson(y_pred[:, :, 1], labels - labels_ctrl, mask_diff)

        return loss_ctrl + loss_depr_diff + loss_full

class MaskedCombinedFiveDH(nn.Module):
    """Combined loss for a two-channel prediction (control, deprivation difference).

    Terms:
      * control:    Pearson + L1 loss of channel 0 against ``labels_ctrl``
      * difference: L1 loss of channel 1 against ``labels - labels_ctrl``
      * full:       Pearson + L1 loss of the channel sum against ``labels``

    The logic lives in ``forward`` (instead of overriding ``__call__``) so
    that the regular ``nn.Module`` call path, including hooks, is used.
    """

    def __init__(self):
        super().__init__()
        self.pearson = MaskedPearsonLoss()
        self.l1 = MaskedL1Loss()

    def forward(self, y_pred, labels, labels_ctrl, mask_full, mask_ctrl):
        """
        Args:
            y_pred: predictions indexed as ``[batch, position, channel]``.
            labels, labels_ctrl: targets for the condition and its control.
            mask_full, mask_ctrl: boolean masks of the scored positions.

        Returns:
            scalar tensor, the sum of the three terms.
        """
        y_pred_ctrl = y_pred[:, :, 0]
        y_pred_depr_diff = y_pred[:, :, 1]
        y_pred_full = torch.sum(y_pred, dim=2)

        labels_diff = labels - labels_ctrl

        # the difference is only defined where both label sets are valid
        mask_diff = mask_full & mask_ctrl

        loss_ctrl = self.pearson(y_pred_ctrl, labels_ctrl, mask_ctrl) + self.l1(y_pred_ctrl, labels_ctrl, mask_ctrl)
        loss_depr_diff = self.l1(y_pred_depr_diff, labels_diff, mask_diff)
        loss_full = self.pearson(y_pred_full, labels, mask_full) + self.l1(y_pred_full, labels, mask_full)

        return loss_ctrl + loss_depr_diff + loss_full
    
class MaskedCombinedSixDH(nn.Module):
    """Combined loss for a two-channel prediction (control, deprivation difference).

    Terms:
      * control:    Pearson + L1 loss of channel 0 against ``labels_ctrl``
      * difference: L1 loss of channel 1 against ``labels - labels_ctrl``,
                    plus its Pearson loss unless ``condition_`` is 64
      * full:       Pearson + L1 loss of the channel sum against ``labels``

    64 is the token ``GWSDatasetFromPandas`` assigns to the CTRL condition.
    The logic lives in ``forward`` (instead of overriding ``__call__``) so
    that the regular ``nn.Module`` call path, including hooks, is used.
    """

    def __init__(self):
        super().__init__()
        self.pearson = MaskedPearsonLoss()
        self.l1 = MaskedL1Loss()

    def forward(self, y_pred, labels, labels_ctrl, mask_full, mask_ctrl, condition_):
        """
        Args:
            y_pred: predictions indexed as ``[batch, position, channel]``.
            labels, labels_ctrl: targets for the condition and its control.
            mask_full, mask_ctrl: boolean masks of the scored positions.
            condition_: condition token; 64 disables the Pearson part of the
                difference term.

        Returns:
            scalar tensor, the sum of the three terms.
        """
        y_pred_ctrl = y_pred[:, :, 0]
        y_pred_depr_diff = y_pred[:, :, 1]
        y_pred_full = torch.sum(y_pred, dim=2)

        labels_diff = labels - labels_ctrl

        # the difference is only defined where both label sets are valid
        mask_diff = mask_full & mask_ctrl

        loss_ctrl = self.pearson(y_pred_ctrl, labels_ctrl, mask_ctrl) + self.l1(y_pred_ctrl, labels_ctrl, mask_ctrl)

        loss_depr_diff = self.l1(y_pred_depr_diff, labels_diff, mask_diff)
        if condition_ != 64:
            loss_depr_diff = loss_depr_diff + self.pearson(y_pred_depr_diff, labels_diff, mask_diff)

        loss_full = self.pearson(y_pred_full, labels, mask_full) + self.l1(y_pred_full, labels, mask_full)

        return loss_ctrl + loss_depr_diff + loss_full

In [None]:
# custom Four regression trainer
class RegressionTrainerFour(Trainer):
    """Trainer that computes its loss with ``MaskedCombinedFourDH``."""

    def __init__(self, **kwargs,):
        super().__init__(**kwargs)

    @staticmethod
    def _valid_label_mask(labels, lengths):
        """True where a label lies inside the sequence, is not NaN and is not padding."""
        positions = torch.arange(labels.shape[1])[None, :].to(lengths)
        mask = positions < lengths[:, None]
        mask = torch.logical_and(mask, torch.logical_not(torch.isnan(labels)))
        # ``collate_fn`` in this file measures ``lengths`` on inputs that carry
        # one extra leading condition token, so the length test alone admits the
        # first padded label of shorter sequences; labels are padded with -1.
        return torch.logical_and(mask, labels != -1)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Run the model and return the combined masked loss.

        ``inputs`` must hold ``input_ids``, ``lengths``, ``labels`` and
        ``labels_ctrl``; the two label entries are popped before the model is
        called. ``num_items_in_batch`` is accepted because newer versions of
        the base Trainer pass it; it is not used.

        NOTE(review): the condition token is read from the first sample only,
        so every sample of a batch is assumed to share one condition.
        """
        labels = inputs.pop("labels")
        condition_ = inputs['input_ids'][0][0]
        labels_ctrl = inputs.pop("labels_ctrl")
        outputs = model(**inputs)
        logits = torch.squeeze(outputs.logits, dim=2)
        # remove the first output cause that corresponds to the condition token
        logits = logits[:, 1:, :]
        lengths = inputs['lengths']

        mask_full = self._valid_label_mask(labels, lengths)
        mask_ctrl = self._valid_label_mask(labels_ctrl, lengths)

        loss_fnc = MaskedCombinedFourDH()
        loss = loss_fnc(logits, labels, labels_ctrl, mask_full, mask_ctrl, condition_)

        return (loss, outputs) if return_outputs else loss

# custom Five regression trainer
class RegressionTrainerFive(Trainer):
    """Trainer that computes its loss with ``MaskedCombinedFiveDH``."""

    def __init__(self, **kwargs,):
        super().__init__(**kwargs)

    @staticmethod
    def _valid_label_mask(labels, lengths):
        """True where a label lies inside the sequence, is not NaN and is not padding."""
        positions = torch.arange(labels.shape[1])[None, :].to(lengths)
        mask = positions < lengths[:, None]
        mask = torch.logical_and(mask, torch.logical_not(torch.isnan(labels)))
        # ``collate_fn`` in this file measures ``lengths`` on inputs that carry
        # one extra leading condition token, so the length test alone admits the
        # first padded label of shorter sequences; labels are padded with -1.
        return torch.logical_and(mask, labels != -1)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Run the model and return the combined masked loss.

        ``inputs`` must hold ``input_ids``, ``lengths``, ``labels`` and
        ``labels_ctrl``; the two label entries are popped before the model is
        called. ``num_items_in_batch`` is accepted because newer versions of
        the base Trainer pass it; it is not used.
        """
        labels = inputs.pop("labels")
        labels_ctrl = inputs.pop("labels_ctrl")
        outputs = model(**inputs)
        logits = torch.squeeze(outputs.logits, dim=2)
        # remove the first output cause that corresponds to the condition token
        logits = logits[:, 1:, :]
        lengths = inputs['lengths']

        mask_full = self._valid_label_mask(labels, lengths)
        mask_ctrl = self._valid_label_mask(labels_ctrl, lengths)

        loss_fnc = MaskedCombinedFiveDH()
        loss = loss_fnc(logits, labels, labels_ctrl, mask_full, mask_ctrl)

        return (loss, outputs) if return_outputs else loss

# custom Six regression trainer
class RegressionTrainerSix(Trainer):
    """Trainer that computes its loss with ``MaskedCombinedSixDH``."""

    def __init__(self, **kwargs,):
        super().__init__(**kwargs)

    @staticmethod
    def _valid_label_mask(labels, lengths):
        """True where a label lies inside the sequence, is not NaN and is not padding."""
        positions = torch.arange(labels.shape[1])[None, :].to(lengths)
        mask = positions < lengths[:, None]
        mask = torch.logical_and(mask, torch.logical_not(torch.isnan(labels)))
        # ``collate_fn`` in this file measures ``lengths`` on inputs that carry
        # one extra leading condition token, so the length test alone admits the
        # first padded label of shorter sequences; labels are padded with -1.
        return torch.logical_and(mask, labels != -1)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Run the model and return the combined masked loss.

        ``inputs`` must hold ``input_ids``, ``lengths``, ``labels`` and
        ``labels_ctrl``; the two label entries are popped before the model is
        called. ``num_items_in_batch`` is accepted because newer versions of
        the base Trainer pass it; it is not used.

        NOTE(review): the condition token is read from the first sample only,
        so every sample of a batch is assumed to share one condition.
        """
        labels = inputs.pop("labels")
        condition_ = inputs['input_ids'][0][0]
        labels_ctrl = inputs.pop("labels_ctrl")
        outputs = model(**inputs)
        logits = torch.squeeze(outputs.logits, dim=2)
        # remove the first output cause that corresponds to the condition token
        logits = logits[:, 1:, :]
        lengths = inputs['lengths']

        mask_full = self._valid_label_mask(labels, lengths)
        mask_ctrl = self._valid_label_mask(labels_ctrl, lengths)

        loss_fnc = MaskedCombinedSixDH()
        loss = loss_fnc(logits, labels, labels_ctrl, mask_full, mask_ctrl, condition_)

        return (loss, outputs) if return_outputs else loss