In [None]:
# libraries
import numpy as np
import pandas as pd 
import torch
from transformers import XLNetConfig, XLNetForTokenClassification
import itertools
from tqdm import tqdm

In [None]:
# Sample filters applied to the pseudo-labelled annotations in the final cell
threshold = 0.3        # minimum coverageMod() value a sample must reach
longZerosThresh = 20   # maximum allowed longest run of zero-valued positions
percNansThresh = 0.05  # maximum allowed fraction of NaN positions

In [None]:
# conditions, in the order in which the per-condition copies of the data are built
conditions_list = ['CTRL', 'LEU', 'ILE', 'VAL', 'LEU_ILE', 'LEU_ILE_VAL']
# token id prepended to a codon sequence to encode its condition;
# ids 64-69 sit directly after the 64 codon ids (0-63)
condition_values = dict(CTRL=64, ILE=65, LEU=66, LEU_ILE=67, LEU_ILE_VAL=68, VAL=69)
# codon vocabulary: every nucleotide triplet over A, T, C, G, enumerated in itertools.product order
id_to_codon = dict(enumerate(map(''.join, itertools.product('ATCG', repeat=3))))
codon_to_id = {codon: idx for idx, codon in id_to_codon.items()}

In [None]:
def pseudolabel(ground_truth, mean_preds):
    '''
    Build a pseudo-label vector by filling gaps in the ground truth with predictions.

    Args:
        ground_truth: the literal string 'NA' (no annotation available), a string of
            the form "[v1, v2, ...]" as stored in the CSVs, or a sequence of numbers.
        mean_preds: per-position predictions; its length sets the output length.

    Returns:
        list of Python floats, one per position of `mean_preds`: the ground-truth
        value where it is neither NaN nor 0.0, otherwise abs(mean_preds[j]).

    Raises:
        IndexError if the ground truth has fewer positions than `mean_preds`.
    '''
    n_positions = len(mean_preds)

    # process ground truth
    if isinstance(ground_truth, str) and ground_truth == 'NA':
        # no annotation: every position will be taken from the predictions
        truth = [np.nan] * n_positions
    elif isinstance(ground_truth, str):
        # CSV cells hold the list as text; iterating the raw string would yield
        # single characters ('[', '1', ',' ...) instead of values
        body = ground_truth.strip()
        if body.startswith('[') and body.endswith(']'):
            body = body[1:-1]
        truth = [float(tok) for tok in body.split(',') if tok.strip()]
    else:
        truth = [float(k) for k in ground_truth]

    annot = []
    for j in range(n_positions):
        value = truth[j]
        if np.isnan(value) or value == 0.0:
            # missing / unobserved position -> use the magnitude of the prediction;
            # float() keeps the list free of numpy scalars so str() round-trips
            annot.append(float(np.abs(mean_preds[j])))
        else:
            annot.append(value)

    return annot

# dataset generation functions
def longestZeroSeqLength(a):
    '''
    Length of the longest run of consecutive zero values in `a`.

    Every element is converted with float() first; NaN and any non-zero
    value terminate a run. Returns 0 for an empty input or one without zeros.
    '''
    values = [float(k) for k in a]
    zero_run_lengths = (
        sum(1 for _ in run)
        for is_zero, run in itertools.groupby(values, key=lambda v: v == 0.0)
        if is_zero
    )
    return max(zero_run_lengths, default=0)

def percNans(a):
    '''
    Return the fraction (0..1, despite the name) of NaN entries in `a`.

    Every element is converted with float() first; an empty input raises
    ZeroDivisionError.
    '''
    nan_mask = np.isnan(np.array([float(item) for item in a]))
    return np.count_nonzero(nan_mask) / nan_mask.size

def coverageMod(a, window_size=30):
    '''
    Modified coverage: fraction of informative positions that are non-zero.

    Every window of `window_size` consecutive zeros (found on the input values)
    is masked to NaN, so long zero stretches leave the denominator instead of
    counting as uncovered positions. The result is
    (#positions that are non-zero and non-NaN) / (#non-NaN positions),
    or 0 when every position is NaN after masking.

    Args:
        a: sequence of values convertible with float().
        window_size: length of the all-zero windows that get masked.
    '''
    values = np.asarray([float(k) for k in a], dtype=float)
    # zero positions of the input; windows are tested against this fixed mask so
    # masking one window cannot hide an overlapping all-zero window
    is_zero = values == 0.0
    masked = values.copy()
    # "+ 1" so the last window, starting at len - window_size, is included too
    for start in range(len(values) - window_size + 1):
        if is_zero[start:start + window_size].all():
            masked[start:start + window_size] = np.nan

    observed = ~np.isnan(masked)
    den = int(np.count_nonzero(observed))
    if den == 0:
        return 0

    # num non zero, non nan
    num = int(np.count_nonzero(masked[observed]))
    return num / den

def ntseqtoCodonSeq(seq, condition, add_cond=True):
    """
    Translate a nucleotide string into a list of codon token ids.

    The sequence is read in non-overlapping triplets from position 0; a trailing
    partial codon (the len(seq) % 3 leftover bases) is ignored. Each triplet is
    looked up in `codon_to_id` (KeyError for an unknown triplet). When `add_cond`
    is true, the token `condition_values[condition]` is placed in front.
    """
    usable_length = len(seq) - len(seq) % 3
    token_ids = [codon_to_id[seq[start:start + 3]] for start in range(0, usable_length, 3)]
    if not add_cond:
        return token_ids
    return [condition_values[condition]] + token_ids

def sequenceLength(a):
    '''
    Return the number of values in an annotation / sequence.

    `a` is either a string of the form "[v1, v2, ...]" (how list columns come
    back from read_csv) or a sequence of numbers. Every value must be
    convertible with float(); otherwise ValueError is raised.
    '''
    if isinstance(a, str):
        # count list elements, not the characters of the string
        body = a.strip()
        if body.startswith('[') and body.endswith(']'):
            body = body[1:-1]
        a = [tok for tok in body.split(',') if tok.strip()]
    return len([float(k) for k in a])

def mergeAnnotations(annots):
    '''
    Merge the per-position annotations of several transcripts of one gene.

    Each element of `annots` is either a string of the form "[v1, v2, ...]"
    (as stored in the CSVs) or a sequence of numbers. Position i of the result
    is the mean of the i-th values that are neither 0 nor NaN across all inputs
    that have a position i, or NaN when there is no such value. The result has
    the length of the first input and holds plain Python floats, so str() of it
    reads "[1.0, nan, ...]" and can be parsed back. Empty input gives [].
    '''
    # get the annotations as float lists
    parsed = []
    for annot in annots:
        if isinstance(annot, str):
            body = annot.strip()
            if body.startswith('[') and body.endswith(']'):
                body = body[1:-1]
            annot = [tok for tok in body.split(',') if tok.strip()]
        parsed.append([float(k) for k in annot])

    if not parsed:
        return []

    # merge the annotations
    merged_annots = []
    for i in range(len(parsed[0])):
        # i-th value of every transcript long enough to have one, only non zero and non nan
        valid = [a[i] for a in parsed if i < len(a) and a[i] != 0.0 and not np.isnan(a[i])]
        # explicit NaN avoids numpy's "mean of empty slice" warning
        merged_annots.append(float(np.mean(valid)) if valid else float('nan'))

    return merged_annots

def uniqueGenes(df):
    '''
    Collapse `df` to one row per gene.

    For every gene with several rows, the transcript whose annotation has the
    most values (via `sequenceLength`) is kept, and its 'annotations' cell is
    replaced by the position-wise merge of all of that gene's annotations
    (`mergeAnnotations`, stored as a "[...]" string). The rows of the gene's
    other transcripts are dropped. Row order of the kept rows is preserved.

    The input DataFrame is not modified; a new DataFrame is returned.
    Raises AssertionError if the result is not one row per gene and transcript.
    '''
    # temporary ranking column on a copy, so the caller's frame is left untouched
    df = df.assign(sequence_length=df['annotations'].apply(sequenceLength))

    transcripts_to_drop = set()
    merged_by_transcript = {}
    # iterate through each gene, choose the longest transcript and merge the annotations
    for _, df_gene in df.groupby('gene', sort=False):
        if len(df_gene) < 2:
            continue
        df_gene = df_gene.sort_values('sequence_length', ascending=False)
        chosen_transcript = df_gene['transcript'].values[0]
        transcripts_to_drop.update(df_gene['transcript'].values[1:])
        merged_annotations = mergeAnnotations(df_gene['annotations'].values)
        # plain Python floats: str() of numpy scalars is "np.float64(...)" on numpy >= 2
        merged_by_transcript[chosen_transcript] = str([float(v) for v in merged_annotations])

    # drop the other transcripts and the helper column in one pass
    df = df[~df['transcript'].isin(transcripts_to_drop)].drop(columns=['sequence_length']).copy()

    # change the annotations for the chosen transcripts
    if merged_by_transcript:
        df['annotations'] = [
            merged_by_transcript.get(transcript, annotation)
            for transcript, annotation in zip(df['transcript'], df['annotations'])
        ]

    assert len(df['gene'].unique()) == len(df['gene'])
    assert len(df['transcript'].unique()) == len(df['transcript'])
    assert len(df['transcript']) == len(df['gene'])

    return df

def removeFullGenes(df_mouse, df_full):
    '''
    Remove from `df_mouse` every gene that has a transcript in `df_full`, then
    keep one row per remaining gene (the row with the largest
    `seqLenMouse(sequence)`).

    Transcript ids are compared without their version suffix (everything after
    the first '.'). Transcripts of `df_full` that do not occur in `df_mouse` are
    ignored. Returns a new DataFrame; the inputs are not modified.
    '''
    # versionless transcript ids present in df_full
    full_ids = {tr.split('.')[0] for tr in df_full['transcript'].unique()}

    # genes of df_mouse that own at least one of those transcripts
    mouse_ids = [tr.split('.')[0] for tr in df_mouse['transcript']]
    genes_to_remove = {
        gene for tr_id, gene in zip(mouse_ids, df_mouse['gene']) if tr_id in full_ids
    }
    df_mouse = df_mouse[~df_mouse['gene'].isin(genes_to_remove)]

    # get one transcript per gene, choose the longest one
    # NOTE(review): seqLenMouse is not defined anywhere in this notebook; it has to
    # be provided elsewhere before this function is called.
    df_mouse = df_mouse.assign(sequence_length=df_mouse['sequence'].apply(seqLenMouse))
    df_mouse = df_mouse.sort_values('sequence_length', ascending=False).drop_duplicates('gene')
    df_mouse = df_mouse.drop(columns=['sequence_length'])

    return df_mouse

In [None]:
# model parameters
# architecture of the trained XLNet-DH models (used to build the XLNetConfig)
d_model_val = 512      # hidden size; also passed as d_inner
n_layers_val = 3
n_heads_val = 4
dropout_val = 0.1

# training settings; not referenced elsewhere in the cells shown
lr_val = 0.0001
batch_size_val = 1
loss_fun_name = '4L'   # '4L' or '5L'

# data-filter values; not referenced elsewhere in the cells shown
annot_thresh = 0.3
longZerosThresh_val = 20
percNansThresh_val = 0.05

In [None]:
# model name and output folder path
# checkpoint directories of the five XLNet-DH models that form the ensemble
# (the S<n> suffix is presumably the training seed — TODO confirm)
model_name1 = '../checkpoints/XLNet-DH_S1'
model_name2 = '../checkpoints/XLNet-DH_S2'
model_name3 = '../checkpoints/XLNet-DH_S3'
model_name4 = '../checkpoints/XLNet-DH_S4'
model_name42 = '../checkpoints/XLNet-DH_S42'

class XLNetDH(XLNetForTokenClassification):
    '''
    XLNet token classifier whose head emits two values per token.

    The head width follows `config.d_model`, so a checkpoint is rebuilt with
    the hidden size stored in its own config rather than a notebook global.
    '''
    def __init__(self, config):
        super().__init__(config)
        self.classifier = torch.nn.Linear(config.d_model, 2, bias=True)

# reference configuration of the architecture; from_pretrained below reads the
# config saved next to each checkpoint instead
config = XLNetConfig(vocab_size=71, pad_token_id=70, d_model = d_model_val, n_layer = n_layers_val, n_head = n_heads_val, d_inner = d_model_val, num_labels = 1, dropout=dropout_val) # 64 codon tokens (0-63) + 6 condition tokens (64-69) + 1 padding token (70)

# load model best weights
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# from_pretrained is a classmethod, so no throwaway randomly-initialised model
# is needed to call it
models_list = []
for checkpoint_dir in [model_name1, model_name2, model_name3, model_name4, model_name42]:
    loaded_model = XLNetDH.from_pretrained(checkpoint_dir + "/best_model")
    loaded_model.to(device)
    loaded_model.eval()  # disable dropout for inference
    models_list.append(loaded_model)

model1, model2, model3, model4, model42 = models_list

print("Loaded all the models")

In [None]:
depr_folder = '../data/processed/'  # folder holding one processed CSV per condition

# <depr_folder><NAME>.csv for each dataset
ctrl_depr_path, ile_path, leu_path, val_path, leu_ile_path, leu_ile_val_path, liver_path = (
    f'{depr_folder}{name}.csv'
    for name in ('CTRL', 'ILE', 'LEU', 'VAL', 'LEU_ILE', 'LEU_ILE_VAL', 'LIVER')
)

In [None]:
# load the control liver data
df_liver = pd.read_csv(liver_path)
df_liver['condition'] = 'CTRL'

# load ctrl_aa data
df_ctrl_depr = pd.read_csv(ctrl_depr_path)
df_ctrl_depr['condition'] = 'CTRL'

# keep only the liver transcripts that are NOT already in the CTRL deprivation data
# (a set makes each membership test O(1) instead of a scan of the array)
tr_liver = df_liver['transcript'].unique()
tr_ctrl_depr = set(df_ctrl_depr['transcript'].unique())
tr_to_add = [g for g in tr_liver if g not in tr_ctrl_depr]

df_liver = df_liver[df_liver['transcript'].isin(tr_to_add)]

# CTRL set = liver-only transcripts + CTRL deprivation data
df_ctrldepr_liver = pd.concat([df_liver, df_ctrl_depr], axis=0)

# unique genes
df_ctrldepr_liver = uniqueGenes(df_ctrldepr_liver)

# (gene, transcript) pairs kept for CTRL; the other conditions are restricted to these
ctrl_genes_transcripts = [
    [gene, transcript]
    for gene, transcript in zip(df_ctrldepr_liver['gene'], df_ctrldepr_liver['transcript'])
]
ctrl_pairs = {(gene, transcript) for gene, transcript in ctrl_genes_transcripts}

print("CTRL Done")


def loadConditionDf(path, condition):
    '''
    Read one deprivation-condition CSV, tag every row with `condition`, collapse
    it to one transcript per gene (`uniqueGenes`) and keep only the
    (gene, transcript) pairs selected for the CTRL set. Prints "<condition> Done".
    '''
    df_cond = pd.read_csv(path)
    df_cond['condition'] = condition
    # unique genes
    df_cond = uniqueGenes(df_cond)
    # choose those genes+transcripts that are in ctrl_depr_liver
    in_ctrl = np.array(
        [(gene, transcript) in ctrl_pairs for gene, transcript in zip(df_cond['gene'], df_cond['transcript'])],
        dtype=bool,
    )
    df_cond = df_cond[in_ctrl]
    print(condition + " Done")
    return df_cond


# other conditions
df_ile = loadConditionDf(ile_path, 'ILE')
df_leu = loadConditionDf(leu_path, 'LEU')
df_val = loadConditionDf(val_path, 'VAL')
df_leu_ile = loadConditionDf(leu_ile_path, 'LEU_ILE')
df_leu_ile_val = loadConditionDf(leu_ile_val_path, 'LEU_ILE_VAL')

df_full = pd.concat([df_ctrldepr_liver, df_ile, df_leu, df_val, df_leu_ile, df_leu_ile_val], axis=0) # liver + ctrl depr + ile + leu + val + leu ile + leu ile val

# NOTE(review): positional rename — assumes the CSVs have exactly these columns in this order
df_full.columns = ['index_val', 'gene', 'transcript', 'sequence', 'annotations', 'perc_non_zero_annots', 'condition']

# drop index_val column
df_full = df_full.drop(columns=['index_val'])

In [None]:
# every distinct (gene, transcript, sequence) of df_full, once
df_set1 = df_full[['gene', 'transcript', 'sequence']].drop_duplicates()

len_sf_set1 = len(df_set1)

# one copy of the set per condition; the i-th block of len_sf_set1 rows gets
# conditions_list[i] (tied to the list length instead of a hard-coded 6)
df_set1 = pd.concat([df_set1] * len(conditions_list), ignore_index=True)
df_set1['condition'] = [condition for condition in conditions_list for _ in range(len_sf_set1)]

print(df_set1)

# save this dataframe
df_set1.to_csv("../data/plabel/df_set1.csv", index=False)

In [None]:
final_mean_preds_list_set1 = []
final_stds_preds_list_set1 = []
final_conditions_list_set1 = []
final_genes_list_set1 = []
final_transcripts_list_set1 = []
final_sequence_list_set1 = []
final_annots_list_set1 = []

# load genes file
sequences_df_set1 = list(df_set1['sequence'])
genes_df_set1 = list(df_set1['gene'])
transcripts_df_set1 = list(df_set1['transcript'])
conditions_df_set1 = list(df_set1['condition'])

# (condition, transcript) -> annotations of the first matching row of df_full.
# Built once, so the loop below does a dict lookup instead of filtering
# the whole df_full for every sample.
annotations_lookup = {}
for cond_key, transcript_key, annotation in zip(df_full['condition'], df_full['transcript'], df_full['annotations']):
    annotations_lookup.setdefault((cond_key, transcript_key), annotation)

# make predictions on all the sequences, using the five models
for j in tqdm(range(len(sequences_df_set1))):
    condition = conditions_df_set1[j]
    transcript = transcripts_df_set1[j]

    # the sequence is stored as "[id, id, ...]"; prepend the condition token
    codon_ids = [int(k) for k in sequences_df_set1[j][1:-1].split(', ')]
    tokens = [condition_values[condition]] + codon_ids
    # batch of one, built once and shared by all models
    input_ids = torch.tensor(tokens, dtype=torch.int32, device=device).unsqueeze(0)

    preds_list_sample = []
    with torch.no_grad():
        for model_chosen in models_list:
            logits = model_chosen(input_ids)["logits"]
            # sum over the head-output dimension, drop the batch dimension and
            # the first token, which is the condition token
            y_pred = torch.sum(logits, dim=2).squeeze(0)[1:]
            preds_list_sample.append(y_pred.cpu().numpy())

    # mean and std of the ensemble predictions over each codon
    preds_list_sample = np.asarray(preds_list_sample)
    mean_preds = np.mean(preds_list_sample, axis=0)
    stds_preds = np.std(preds_list_sample, axis=0)

    # annotation of this transcript under this condition, or 'NA' if there is none
    final_annots_list_set1.append(annotations_lookup.get((condition, transcript), 'NA'))

    final_mean_preds_list_set1.append(mean_preds)
    final_stds_preds_list_set1.append(stds_preds)
    final_conditions_list_set1.append(condition)
    final_genes_list_set1.append(genes_df_set1[j])
    final_transcripts_list_set1.append(transcript)
    final_sequence_list_set1.append(sequences_df_set1[j])

# create a dataframe with the final predictions
df_final_preds = pd.DataFrame({'gene': final_genes_list_set1, 'transcript': final_transcripts_list_set1, 'sequence': final_sequence_list_set1, 'mean_preds': final_mean_preds_list_set1, 'stds_preds': final_stds_preds_list_set1, 'condition': final_conditions_list_set1, 'annotations': final_annots_list_set1})


In [None]:
# load training and testing original sets
df_test_orig = pd.read_csv('../data/orig/test.csv')
df_val_orig = pd.read_csv('../data/orig/val.csv')

orig_test_genes = set(df_test_orig['gene'])
orig_test_transcripts = set(df_test_orig['transcript'])

orig_val_genes = set(df_val_orig['gene'])
orig_val_transcripts = set(df_val_orig['transcript'])

# remove every sample whose gene or transcript is in the original test or val split
in_held_out_split = (
    df_final_preds['gene'].isin(orig_test_genes | orig_val_genes)
    | df_final_preds['transcript'].isin(orig_test_transcripts | orig_val_transcripts)
)
df_final_preds = df_final_preds[~in_held_out_split].copy()

# impute NaN / zero ground-truth positions with the ensemble mean prediction.
# Rows without an annotation carry 'NA' and are filled entirely from predictions,
# so every row ends up with a list here and none is dropped at this step.
df_final_preds['annotations'] = [
    pseudolabel(ground_truth_sample, mean_preds_sample)
    for ground_truth_sample, mean_preds_sample in tqdm(
        zip(df_final_preds['annotations'], df_final_preds['mean_preds']),
        total=len(df_final_preds),
    )
]

# coverage threshold
df_final_preds['coverage_mod'] = df_final_preds['annotations'].apply(coverageMod)
df_final_preds = df_final_preds[df_final_preds['coverage_mod'] >= threshold].copy()

# longest zero run and fraction of NaNs of each annotation
df_final_preds['longest_zero_seq_length_annotation'] = df_final_preds['annotations'].apply(longestZeroSeqLength)
df_final_preds['perc_nans_annotation'] = df_final_preds['annotations'].apply(percNans)

# apply the thresholds for the longest zero run and the fraction of nans
df_final_preds = df_final_preds[
    (df_final_preds['longest_zero_seq_length_annotation'] <= longZerosThresh)
    & (df_final_preds['perc_nans_annotation'] <= percNansThresh)
].copy()

print("Added Thresholds on all the factors")

# pair every sample with the CTRL annotations of its transcript (the first CTRL
# row of that transcript); CTRL samples are paired with themselves
is_ctrl = df_final_preds['condition'] == 'CTRL'
ctrl_annotations_by_transcript = {}
for transcript, annotation in zip(df_final_preds.loc[is_ctrl, 'transcript'], df_final_preds.loc[is_ctrl, 'annotations']):
    ctrl_annotations_by_transcript.setdefault(transcript, annotation)

sequences_ctrl = [
    annotation if condition == 'CTRL' else ctrl_annotations_by_transcript.get(transcript)
    for transcript, condition, annotation in zip(
        df_final_preds['transcript'], df_final_preds['condition'], df_final_preds['annotations']
    )
]
df_final_preds['ctrl_sequence'] = sequences_ctrl

# remove the samples whose transcript has no CTRL row left after filtering
has_ctrl = np.array([seq is not None for seq in sequences_ctrl], dtype=bool)
df_final_preds = df_final_preds[has_ctrl].copy()

# sanity check: CTRL rows must be paired with their own annotations
df_ctrl_full = df_final_preds[df_final_preds['condition'] == 'CTRL']
for own_annotation, paired_annotation in zip(df_ctrl_full['annotations'], df_ctrl_full['ctrl_sequence']):
    assert own_annotation == paired_annotation

print("Sanity Checked")

# store list-like columns as plain Python float lists before writing the CSV:
# str() of a numpy array with more than 1000 elements is summarised with "..."
# (and has no commas), and on numpy >= 2 numpy scalars inside a list print as
# "np.float32(...)" — neither can be parsed back
for list_column in ['annotations', 'ctrl_sequence', 'mean_preds', 'stds_preds']:
    df_final_preds[list_column] = [[float(v) for v in values] for values in df_final_preds[list_column]]

out_train_path = '../data/plabel/plabel_train_trial.csv'
df_final_preds.to_csv(out_train_path)