In [None]:
# libraries
import numpy as np
import pandas as pd 
import torch
import random
from torch.nn.utils.rnn import pad_sequence
from torchmetrics.functional import pearson_corrcoef
from torchmetrics.regression import MeanAbsolutePercentageError, MeanAbsoluteError
from torchmetrics import Metric
import torch.nn as nn
from torch.utils.data import Dataset
from transformers import Trainer
from sklearn.model_selection import train_test_split
import itertools

In [None]:
# model functions
class MAECoef(Metric):
    '''
    Mean absolute error averaged over the rows (sequences) of a batch.

    State:
        mae: running sum of the per-row mean absolute errors.
        total: running count of rows seen.
    '''
    def __init__(self):
        super().__init__()
        self.add_state("mae", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, preds, target, mask):
        '''
        Add one batch to the running state.

        preds is summed over dim 2 and its first position along dim 1 is dropped;
        the result must have the same shape as target and mask (asserted).
        Only the entries selected by mask contribute to a row's error.
        '''
        summed = torch.sum(preds, dim=2)[:, 1:]
        assert summed.shape == target.shape
        assert summed.shape == mask.shape

        abs_error = MeanAbsoluteError()
        # one MAE value per row, computed on the masked entries only
        per_row = torch.stack([
            abs_error(torch.masked_select(row_pred, row_mask), torch.masked_select(row_true, row_mask))
            for row_pred, row_true, row_mask in zip(summed, target, mask)
        ])

        self.mae += torch.sum(per_row)
        self.total += len(per_row)

    def compute(self):
        '''Return the mean of the per-row errors accumulated so far.'''
        return self.mae / self.total

class CorrCoef(Metric):
    '''
    Pearson correlation coefficient averaged over the rows (sequences) of a batch.

    State:
        corrcoefs: running sum of the per-row correlation coefficients.
        total: running count of rows seen.
    '''
    def __init__(self):
        super().__init__()
        self.add_state("corrcoefs", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, preds, target, mask):
        '''
        Add one batch to the running state.

        preds is summed over dim 2 and its first position along dim 1 is dropped;
        the result must have the same shape as target and mask (asserted).
        Only the entries selected by mask enter a row's correlation.
        '''
        summed = torch.sum(preds, dim=2)[:, 1:]
        assert summed.shape == target.shape
        assert summed.shape == mask.shape

        # one correlation value per row, computed on the masked entries only
        per_row = torch.stack([
            pearson_corrcoef(torch.masked_select(row_pred, row_mask), torch.masked_select(row_true, row_mask))
            for row_pred, row_true, row_mask in zip(summed, target, mask)
        ])

        self.corrcoefs += torch.sum(per_row)
        self.total += len(per_row)

    def compute(self):
        '''Return the mean of the per-row correlations accumulated so far.'''
        return self.corrcoefs / self.total

class MAPECoef(Metric):
    '''
    Mean absolute percentage error averaged over the rows (sequences) of a batch.

    State:
        mapecoefs: running sum of the per-row MAPE values.
        total: running count of rows seen.
    '''
    def __init__(self):
        super().__init__()
        self.add_state("mapecoefs", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, preds, target, mask):
        '''
        Add one batch to the running state.

        preds is summed over dim 2 and its first position along dim 1 is dropped;
        the result must have the same shape as target and mask (asserted).
        Only the entries selected by mask contribute to a row's error.
        '''
        summed = torch.sum(preds, dim=2)[:, 1:]
        assert summed.shape == target.shape
        assert summed.shape == mask.shape

        perc_error = MeanAbsolutePercentageError()
        # one MAPE value per row, computed on the masked entries only
        per_row = torch.stack([
            perc_error(torch.masked_select(row_pred, row_mask), torch.masked_select(row_true, row_mask))
            for row_pred, row_true, row_mask in zip(summed, target, mask)
        ])

        self.mapecoefs += torch.sum(per_row)
        self.total += len(per_row)

    def compute(self):
        '''Return the mean of the per-row MAPE values accumulated so far.'''
        return self.mapecoefs / self.total

# collate function
def collate_fn(batch):
    '''
    Pad a list of samples into one batch dictionary.

    batch is a list of 5-tuples (x, y, ctrl_y, gene, transcript); gene and
    transcript are unpacked but not returned.

    Returns a dict with:
        input_ids: x sequences padded with 384
        labels: y sequences padded with -1
        lengths: unpadded length of every x sequence
        labels_ctrl: ctrl_y sequences padded with -1
    '''
    tokens, targets, ctrl_targets, _genes, _transcripts = zip(*batch)

    return {
        "input_ids": pad_sequence(tokens, batch_first=True, padding_value=384),
        "labels": pad_sequence(targets, batch_first=True, padding_value=-1),
        # sequence lengths before padding, measured on the input tokens
        "lengths": torch.tensor([len(seq) for seq in tokens]),
        "labels_ctrl": pad_sequence(ctrl_targets, batch_first=True, padding_value=-1),
    }

# compute metrics
def _prepare_eval_tensors(pred):
    '''
    Convert an evaluation prediction object into (preds, labels, mask) tensors.

    The mask is False wherever a label is
      * -100 (the value compared against in the original implementation),
      * -1 (the value collate_fn pads `labels` with inside a batch), or
      * NaN.
    Excluding -1 keeps intra-batch padding from being scored as a real target.
    '''
    labels = torch.tensor(pred.label_ids)
    preds = torch.squeeze(torch.tensor(pred.predictions), dim=2)

    not_padding = torch.logical_and(labels != -100.0, labels != -1.0)
    mask = torch.logical_and(not_padding, torch.logical_not(torch.isnan(labels)))

    return preds, labels, mask

def _masked_metrics(preds, labels, mask):
    '''
    Run CorrCoef, MAECoef and MAPECoef once on the given tensors and return
    their values under the keys "r", "mae" and "mape".
    '''
    corr_coef = CorrCoef()
    corr_coef.update(preds, labels, mask)

    mae_coef = MAECoef()
    mae_coef.update(preds, labels, mask)

    mape_coef = MAPECoef()
    mape_coef.update(preds, labels, mask)

    return {"r": corr_coef.compute(), "mae": mae_coef.compute(), "mape": mape_coef.compute()}

# compute metrics
def compute_metrics(pred):
    '''
    Compute the evaluation metrics for a prediction object.

    pred must expose `label_ids` and `predictions`.
    Returns a dict with keys "r", "mae" and "mape".
    '''
    preds, labels, mask = _prepare_eval_tensors(pred)
    return _masked_metrics(preds, labels, mask)

# compute metrics and save the raw arrays
def compute_metrics_saved(pred):
    '''
    additional function to just save everything to do analysis later

    Computes the same metrics as compute_metrics and additionally writes
    preds/preds.npy, preds/labels.npy and preds/inputs.npy (from pred.inputs).
    '''
    import os  # local import: only needed to create the output folder

    inputs = pred.inputs
    preds, labels, mask = _prepare_eval_tensors(pred)
    metrics = _masked_metrics(preds, labels, mask)

    # save predictions; create the folder first so np.save cannot fail on a fresh checkout
    os.makedirs("preds", exist_ok=True)
    np.save("preds/preds.npy", preds.cpu().numpy())
    np.save("preds/labels.npy", labels.cpu().numpy())
    np.save("preds/inputs.npy", inputs)

    return metrics

In [None]:
# global variables
# integer id -> codon string, enumerating all 3-letter words over A, T, C, G in product order
id_to_codon = dict(enumerate(''.join(triplet) for triplet in itertools.product('ATCG', repeat=3)))
# inverse lookup: codon string -> integer id
codon_to_id = {codon: idx for idx, codon in id_to_codon.items()}

# dataset generation functions
def longestZeroSeqLength(a):
    '''
    length of the longest run of consecutive zeros

    a is a string holding a bracketed list of numbers separated by ', '
    (for example '[0.0, 1.5, 0.0]'); the first and last characters are dropped
    before parsing.
    '''
    tokens = a[1:-1].split(', ')
    # encode zeros as '0' and everything else as a blank, so that splitting on
    # whitespace leaves exactly the runs of zeros
    encoded = ''.join('0' if float(tok) == 0.0 else ' ' for tok in tokens)
    return max(map(len, encoded.split()), default=0)

def percNans(a):
    '''
    returns the share of nans in the sequence, as a fraction between 0 and 1

    a is a string holding a bracketed list of numbers separated by ', ';
    the first and last characters are dropped before parsing.
    '''
    values = np.asarray([float(tok) for tok in a[1:-1].split(', ')])
    nan_count = int(np.isnan(values).sum())

    return nan_count / len(values)

def slidingWindowZeroToNan(a, window_size=30):
    '''
    use a sliding window, if all the values in the window are 0, then replace them with nan

    Args:
        a: sequence of numbers (list or array).
        window_size: number of consecutive positions inspected at once.

    Returns:
        A new float numpy array; the input is never modified.

    Windows are visited left to right and replaced as they are found, so a
    window that overlaps an already replaced one contains nan and is skipped.
    '''
    # always work on a float copy: nan cannot be stored in an integer array and
    # the caller's data must not be changed behind its back
    a = np.array(a, dtype=float)
    # + 1 so that the last window, a[len(a) - window_size:], is inspected too
    for i in range(len(a) - window_size + 1):
        if np.all(a[i:i+window_size] == 0.0):
            a[i:i+window_size] = np.nan

    return a

def RiboDatasetGWS(depr_folder: str, ds: str, threshold: float = 0.6, longZerosThresh: int = 20, percNansThresh: float = 0.1):
    '''
    Build the gene-wise train/test split from the per-condition CSV files.

    Args:
        depr_folder: prefix prepended to every input file name (LIVER.csv,
            CTRL_AA.csv, ILE_AA.csv, LEU_AA.csv, VAL_AA.csv, LEU-ILE_AA.csv,
            LEU-ILE-VAL_AA.csv); it must end with a path separator.
        ds: dataset selector; only 'ALL' is implemented.
        threshold: minimum value of 'perc_non_zero_annots' for a row to be kept.
        longZerosThresh: maximum allowed longest run of zeros, applied to both
            the annotations and the control profile.
        percNansThresh: maximum allowed fraction of nans, applied to both
            the annotations and the control profile.

    Returns:
        (df_train, df_test), re-read from the CSV files written under data/dh/.

    Raises:
        ValueError: if ds is not 'ALL'.
    '''
    import os  # local import: only needed to create the output folder

    if ds != 'ALL':
        raise ValueError("RiboDatasetGWS only supports ds='ALL', got " + repr(ds))

    # load the control data
    df_liver = pd.read_csv(depr_folder + 'LIVER.csv')
    df_liver['condition'] = 'CTRL'

    # load ctrl_aa data
    df_ctrl_depr = pd.read_csv(depr_folder + 'CTRL_AA.csv')
    df_ctrl_depr['condition'] = 'CTRL'

    # keep only the liver transcripts that CTRL_AA does not already provide
    tr_ctrl_depr = set(df_ctrl_depr['transcript'].unique())
    tr_to_add = [tr for tr in df_liver['transcript'].unique() if tr not in tr_ctrl_depr]
    df_liver = df_liver[df_liver['transcript'].isin(tr_to_add)]

    # every control transcript must now appear exactly once
    df_ctrldepr_liver = pd.concat([df_liver, df_ctrl_depr], axis=0)
    assert len(df_ctrldepr_liver['transcript'].unique()) == len(df_ctrldepr_liver['transcript'])

    # other conditions, in the order in which they are stacked below
    condition_files = [
        ('ILE', 'ILE_AA.csv'),
        ('LEU', 'LEU_AA.csv'),
        ('VAL', 'VAL_AA.csv'),
        ('LEU_ILE', 'LEU-ILE_AA.csv'),
        ('LEU_ILE_VAL', 'LEU-ILE-VAL_AA.csv'),
    ]
    condition_frames = []
    for condition, file_name in condition_files:
        df_condition = pd.read_csv(depr_folder + file_name)
        df_condition['condition'] = condition
        condition_frames.append(df_condition)

    # liver + ctrl depr + ile + leu + val + leu ile + leu ile val
    df_full = pd.concat([df_liver, df_ctrl_depr] + condition_frames, axis=0)

    df_full.columns = ['index_val', 'gene', 'transcript', 'sequence', 'annotations', 'perc_non_zero_annots', 'condition']

    # drop index_val column
    df_full = df_full.drop(columns=['index_val'])

    # apply annot threshold
    df_full = df_full[df_full['perc_non_zero_annots'] >= threshold]

    # transcript -> annotations of its first CTRL row; built once so that the
    # per-row lookup below does not have to scan the whole frame every time
    df_ctrl_rows = df_full[df_full['condition'] == 'CTRL']
    ctrl_by_transcript = {}
    for transcript, annotation in zip(df_ctrl_rows['transcript'], df_ctrl_rows['annotations']):
        ctrl_by_transcript.setdefault(transcript, annotation)

    # CTRL rows are their own control; every other row gets the CTRL profile of
    # its transcript, or 'NA' when the transcript has no CTRL row
    sequences_ctrl = [
        annotation if condition == 'CTRL' else ctrl_by_transcript.get(transcript, 'NA')
        for annotation, condition, transcript in zip(df_full['annotations'], df_full['condition'], df_full['transcript'])
    ]

    # add the sequences_ctrl to the df
    print(len(sequences_ctrl), len(df_full))
    df_full = df_full.assign(ctrl_sequence=sequences_ctrl)

    # remove those rows where the ctrl_sequence is NA
    df_full = df_full[df_full['ctrl_sequence'] != 'NA']

    # sanity check: a CTRL row must carry its own annotations as control
    df_ctrl_full = df_full[df_full['condition'] == 'CTRL']
    assert list(df_ctrl_full['annotations']) == list(df_ctrl_full['ctrl_sequence'])

    print("Sanity Checked")

    # longest zero run and nan fraction for the annotations and the control profile
    annotations_list = list(df_full['annotations'])
    sequences_ctrl = list(df_full['ctrl_sequence'])
    df_full = df_full.assign(
        longest_zero_seq_length_annotation=[longestZeroSeqLength(seq) for seq in annotations_list],
        longest_zero_seq_length_ctrl_sequence=[longestZeroSeqLength(seq) for seq in sequences_ctrl],
        perc_nans_annotation=[percNans(seq) for seq in annotations_list],
        perc_nans_ctrl_sequence=[percNans(seq) for seq in sequences_ctrl],
    )

    # apply the threshold for the longest zero sequence length
    df_full = df_full[df_full['longest_zero_seq_length_annotation'] <= longZerosThresh]
    df_full = df_full[df_full['longest_zero_seq_length_ctrl_sequence'] <= longZerosThresh]

    # apply the threshold for the number of nans
    df_full = df_full[df_full['perc_nans_annotation'] <= percNansThresh]
    df_full = df_full[df_full['perc_nans_ctrl_sequence'] <= percNansThresh]

    # GWS: split on genes so that no gene appears in both train and test
    genes = df_full['gene'].unique()
    genes_train, genes_test = train_test_split(genes, test_size=0.2, random_state=42)

    # split the dataframe
    df_train = df_full[df_full['gene'].isin(genes_train)]
    df_test = df_full[df_full['gene'].isin(genes_test)]

    file_suffix = str(threshold) + '_NZ_' + str(longZerosThresh) + '_PercNan_' + str(percNansThresh) + '.csv'
    out_train_path = 'data/dh/train_' + file_suffix
    out_test_path = 'data/dh/test_' + file_suffix

    # make sure the output folder exists before writing
    os.makedirs('data/dh', exist_ok=True)
    df_train.to_csv(out_train_path, index=False)
    df_test.to_csv(out_test_path, index=False)

    # return what is on disk
    df_train = pd.read_csv(out_train_path)
    df_test = pd.read_csv(out_test_path)

    return df_train, df_test

class GWSDatasetFromPandas(Dataset):
    '''
    Dataset over a dataframe with the columns 'sequence', 'annotations',
    'ctrl_sequence', 'condition', 'gene' and 'transcript'.

    'sequence', 'annotations' and 'ctrl_sequence' hold bracketed lists stored as
    strings, with items separated by ', '.
    '''
    def __init__(self, df):
        self.df = df
        self.counts = list(self.df['annotations'])
        self.sequences = list(self.df['sequence'])
        self.condition_lists = list(self.df['condition'])
        # token id placed in front of the input sequence for every condition
        self.condition_values = {'CTRL': 64, 'ILE': 65, 'LEU': 66, 'LEU_ILE': 67, 'LEU_ILE_VAL': 68, 'VAL': 69}

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _log_counts(raw):
        '''
        Parse a stringified list of floats, turn all-zero windows of length 30
        into nan and return log(1 + value) as a numpy array (no min max scaling).
        '''
        values = [float(tok) for tok in raw[1:-1].split(', ')]
        windowed = slidingWindowZeroToNan(values, window_size=30)
        return np.log([1+v for v in windowed])

    def __getitem__(self, idx):
        '''
        Return (X, y, ctrl_y, gene, transcript) for row idx.

        X is a long tensor: the condition token followed by the parsed sequence.
        y and ctrl_y are float tensors built from 'annotations' and 'ctrl_sequence'.
        '''
        row_sequence = self.df['sequence'].iloc[idx]
        token_ids = np.array([int(tok) for tok in row_sequence[1:-1].split(', ')])

        y = np.array(self._log_counts(self.df['annotations'].iloc[idx]))
        ctrl_y = self._log_counts(self.df['ctrl_sequence'].iloc[idx])

        # prepend the condition token to the input ids
        cond_token = self.condition_values[self.condition_lists[idx]]
        token_ids = np.insert(token_ids, 0, cond_token)

        X = torch.from_numpy(token_ids).long()
        y = torch.from_numpy(y).float()
        ctrl_y = torch.from_numpy(ctrl_y).float()

        gene = self.df['gene'].iloc[idx]
        transcript = self.df['transcript'].iloc[idx]

        return X, y, ctrl_y, gene, transcript

In [None]:
# loss functions
class MaskedPearsonLoss(nn.Module):
    '''
    1 - Pearson correlation between the masked predictions and targets,
    computed as the cosine similarity of the mean-centred vectors.
    '''
    def __init__(self):
        super().__init__()

    def __call__(self, y_pred, y_true, mask, eps=1e-6):
        # keep only the entries selected by the mask (flattened to 1-D)
        selected_pred = torch.masked_select(y_pred, mask)
        selected_true = torch.masked_select(y_true, mask)

        centred_pred = selected_pred - selected_pred.mean()
        centred_true = selected_true - selected_true.mean()

        similarity = nn.CosineSimilarity(dim=0, eps=eps)
        return 1 - similarity(centred_pred, centred_true)

class MaskedL1Loss(nn.Module):
    '''
    Square root of the mean absolute error over the entries selected by the mask.
    '''
    def __init__(self):
        super().__init__()

    def __call__(self, y_pred, y_true, mask):
        selected_pred = torch.masked_select(y_pred, mask).float()
        selected_true = torch.masked_select(y_true, mask).float()

        abs_errors = nn.functional.l1_loss(selected_pred, selected_true, reduction="none")
        mean_abs_error = abs_errors.mean()
        # note: the returned value is sqrt(MAE), not the plain MAE
        return torch.sqrt(mean_abs_error)

class MaskedNormMAELoss(nn.Module):
    '''
    Square root of the masked mean absolute error, divided by the largest
    masked target value whenever that value is not 0.
    '''
    def __init__(self):
        super().__init__()

    def __call__(self, y_pred, y_true, mask):
        selected_pred = torch.masked_select(y_pred, mask).float()
        selected_true = torch.masked_select(y_true, mask).float()

        abs_errors = nn.functional.l1_loss(selected_pred, selected_true, reduction="none")
        root_mae = torch.sqrt(abs_errors.mean())

        # largest target value, ignoring nans; computed with numpy on the host
        target_max = np.nanmax(selected_true.cpu().numpy())

        # a maximum of 0 would divide by zero, so the unscaled value is returned
        return root_mae if target_max == 0 else root_mae / target_max

class MaskedCombinedFiveDH(nn.Module):
    '''
    Sum of three masked loss terms for a two-channel prediction:
      * channel 0 against the control labels (Pearson + L1),
      * channel 1 against labels - labels_ctrl (L1 only),
      * the sum of both channels against the labels (Pearson + L1).
    '''
    def __init__(self):
        super().__init__()
        self.pearson = MaskedPearsonLoss()
        self.l1 = MaskedL1Loss()

    def __call__(self, y_pred, labels, labels_ctrl, mask_full, mask_ctrl):
        # split the last axis of the prediction into its two channels
        ctrl_head = y_pred[:, :, 0]
        diff_head = y_pred[:, :, 1]
        combined = torch.sum(y_pred, dim=2)

        # the difference target is only defined where both label sets are valid
        diff_target = labels - labels_ctrl
        diff_mask = mask_full & mask_ctrl

        ctrl_term = self.pearson(ctrl_head, labels_ctrl, mask_ctrl) + self.l1(ctrl_head, labels_ctrl, mask_ctrl)
        diff_term = self.l1(diff_head, diff_target, diff_mask)
        full_term = self.pearson(combined, labels, mask_full) + self.l1(combined, labels, mask_full)

        return ctrl_term + diff_term + full_term

In [None]:
# custom Five regression trainer
class RegressionTrainerFive(Trainer):
    '''
    Trainer that optimises MaskedCombinedFiveDH on batches carrying the keys
    "labels", "labels_ctrl" and "lengths".
    '''
    def __init__(self, **kwargs,):
        super().__init__(**kwargs)

    @staticmethod
    def _valid_label_mask(labels, label_lengths):
        '''True where a position lies inside the sequence and its label is not nan.'''
        positions = torch.arange(labels.shape[1])[None, :].to(label_lengths)
        in_sequence = positions < label_lengths[:, None]
        return torch.logical_and(in_sequence, torch.logical_not(torch.isnan(labels)))

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        '''
        Compute the combined loss for one batch.

        Args:
            model: called as model(**inputs) after "labels" and "labels_ctrl"
                have been popped from inputs; its output must expose `.logits`.
            inputs: batch dict; "labels" and "labels_ctrl" are removed from it.
            return_outputs: when True, return (loss, outputs) instead of loss.
            num_items_in_batch: accepted so that Trainer versions which pass this
                keyword can call the method; it is not used.
        '''
        labels = inputs.pop("labels")
        labels_ctrl = inputs.pop("labels_ctrl")
        outputs = model(**inputs)
        logits = torch.squeeze(outputs.logits, dim=2)
        # remove the first output cause that corresponds to the condition token
        logits = logits[:, 1:, :]

        # `lengths` counts the input tokens, condition token included, so the
        # labels of a sequence are one position shorter; without the - 1 the first
        # padded label of every shorter sequence would be treated as a real target
        label_lengths = inputs['lengths'] - 1

        loss_fnc = MaskedCombinedFiveDH()

        mask_full = self._valid_label_mask(labels, label_lengths)
        mask_ctrl = self._valid_label_mask(labels_ctrl, label_lengths)

        loss = loss_fnc(logits, labels, labels_ctrl, mask_full, mask_ctrl)

        return (loss, outputs) if return_outputs else loss