In [5]:
import numpy as np
import pandas as pd
import pickle

import re

import os
import gc
import sys

import pysam

import torch
from torch.utils.data import DataLoader, Dataset

from tqdm import tqdm

# Add the directory two levels above the working directory to the import search path.
# NOTE(review): presumably this is where encoding_utils, helpers and models live - confirm.
sys.path.append('../../')

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [6]:
from encoding_utils import sequence_encoders

import helpers.train_eval as train_eval    #train and evaluation
import helpers.misc as misc                #miscellaneous functions
from helpers.metrics import MaskedAccuracy
from helpers.temperature_scaling import ModelWithTemperature

from models.spec_dss import DSSResNet, DSSResNetEmb, SpecAdd

In [7]:
# Root directory that every data path below is built from (absolute, machine-specific path).
datadir = '/s/project/mll/sergey/effect_prediction/MLM/'

In [9]:
class SeqDataset(Dataset):
    """Dataset of masked sequences paired with a species label.

    Sequences are fetched by ``seq_name`` from an indexed FASTA file when
    ``fasta_fa`` is given; otherwise they are read from the ``seq`` column
    of ``seq_df``.
    """

    def __init__(self, fasta_fa, seq_df, transform):
        """
        Args:
            fasta_fa: path passed to ``pysam.FastaFile``, or a falsy value to
                take sequences from ``seq_df.seq`` instead.
            seq_df: DataFrame with a ``species_label`` column plus either
                ``seq_name`` (FASTA mode) or ``seq`` (in-memory mode).
            transform: callable invoked as ``transform(seq, motifs={})`` that
                returns a 5-tuple; the first three items are used as
                (masked_sequence, target_labels_masked, target_labels).
        """
        self.fasta = pysam.FastaFile(fasta_fa) if fasta_fa else None
        self.seq_df = seq_df
        self.transform = transform

    def __len__(self):
        """Number of rows in ``seq_df``."""
        return len(self.seq_df)

    def __getitem__(self, idx):
        """Return ((masked_sequence, species_label), target_labels_masked, target_labels, seq).

        ``seq`` is the upper-cased sequence with '-' characters removed.
        """
        # Look the row up once instead of once per field.
        row = self.seq_df.iloc[idx]

        # Explicit None check: the mode must not depend on the truthiness of the file handle.
        if self.fasta is not None:
            seq = self.fasta.fetch(row.seq_name).upper()
        else:
            seq = row.seq.upper()

        species_label = row.species_label

        seq = seq.replace('-','')

        masked_sequence, target_labels_masked, target_labels, _, _ = self.transform(seq, motifs = {})

        masked_sequence = (masked_sequence, species_label)

        return masked_sequence, target_labels_masked, target_labels, seq

    def close(self):
        """Close the FASTA handle, if one was opened; a no-op in in-memory mode."""
        if self.fasta is not None:
            self.fasta.close()

In [10]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
# (A hard failure when CUDA is missing is intentionally left disabled.)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('\nCUDA device: GPU\n' if device.type == 'cuda' else '\nCUDA device: CPU\n')


CUDA device: CPU



  return torch._C._cuda_getDeviceCount() > 0


In [11]:
# Run a garbage-collection pass, then ask torch to release its cached CUDA memory.
gc.collect()
torch.cuda.empty_cache()

In [12]:
# Run configuration, held in a misc.dotdict so that entries are read and written as attributes.
# NOTE(review): later cells also read input_params.test, val_fraction, train_splits,
# get_embeddings, get_motif_acc, model_weight, optimizer_weight, save_at, validate_every and
# temp_scaling, none of which are set in this cell. What misc.dotdict returns for a missing
# key is not visible here - confirm, and set these explicitly so a fresh-kernel run is reproducible.
input_params = misc.dotdict({})

# text file listing the species; its row order defines the integer species labels below
input_params.species_list = datadir + 'fasta/240_mammals/240_species.txt'

# weights, embeddings.pickle and probas.pickle are written under this directory
input_params.output_dir = './test'

# passed as both seq_len and total_len to SequenceDataEncoder
input_params.seq_len = 5000

# last epoch number of the training loop (inclusive)
input_params.tot_epochs = 100

# model size / regularisation passed to DSSResNetEmb and SpecAdd
input_params.d_model = 32
input_params.n_layers = 4
input_params.dropout = 0.

# DataLoader batch size and Adam hyper-parameters
input_params.batch_size = 32
input_params.learning_rate = 1e-4
# weight decay is disabled; the trailing comment records the previously used value
input_params.weight_decay = 0#5e-4

In [13]:
# FASTA file used for this run; its '.fai' index is read below to list the sequence names.
input_params.fasta = datadir + 'aligned/data/3_prime_UTR/366/ENST00000381365.4_utr3_2_0_chr17_4900534_f.fa'

In [14]:
# Sequence names are taken from the first column of the FASTA index (.fai).
seq_df = pd.read_csv(input_params.fasta + '.fai', header=None, sep='\t', usecols=[0], names=['seq_name'])

# The species name is the second ':'-separated field of the sequence name.
seq_df['species_name'] = [name.split(':')[1] for name in seq_df.seq_name]

# Species list row number -> species name, inverted into species name -> integer label.
species_by_idx = pd.read_csv(input_params.species_list, header=None).squeeze().to_dict()
species_encoding = {species: idx for idx, species in species_by_idx.items()}

# Homo_sapiens reuses the Pan_troglodytes label.
species_encoding['Homo_sapiens'] = species_encoding['Pan_troglodytes']

seq_df['species_label'] = seq_df.species_name.map(species_encoding)

# Split into the human rows and all remaining rows.
# NOTE(review): the original comments tag the human subset (test_df) "for training and
# calibration" and the remainder (train_df) "for probas" - confirm against the variable names.
is_human = seq_df.species_name == 'Homo_sapiens'
test_df = seq_df[is_human]
train_df = seq_df[~is_human]

In [70]:
# The reference is the first record listed in the FASTA index. The context manager closes
# the file handle as soon as the sequence has been read, instead of leaving it open.
with pysam.FastaFile(input_params.fasta) as fasta_file:
    refseq = fasta_file.fetch(seq_df.seq_name.iloc[0]).upper()

In [71]:
# The UTR identifier is the part of the first sequence name before the first ':'.
utr_name = seq_df.iloc[0].seq_name.split(':')[0]

# Every single-base substitution of the reference, as (variant name, mutated sequence).
# Variant names follow '<utr_name>:<0-based position>:<ref base>:<alt base>'.
seqs = [
    (f'{utr_name}:{pos}:{ref_base}:{alt_base}', refseq[:pos] + alt_base + refseq[pos + 1:])
    for pos, ref_base in enumerate(refseq)
    for alt_base in 'ACGT'
    if alt_base != ref_base
]

In [72]:
# Background set for embeddings: at most 3000 of the generated variants, drawn with a fixed seed.
background_df = (
    pd.DataFrame(seqs, columns=['seq_name', 'seq'])
      .sample(n=min(3000, len(seqs)), random_state=1)
)

In [73]:
clinvar_fa = datadir + 'clinvar/clinvar.fa'

# Sequence names come from the FASTA index; the last ':'-separated field is the UTR name.
clinvar_df = pd.read_csv(clinvar_fa + '.fai', header=None, sep='\t', usecols=[0], names=['seq_name'])
clinvar_df['utr_name'] = clinvar_df.seq_name.apply(lambda name: name.split(':')[-1])

# Keep only the records that belong to the current UTR.
clinvar_seqs = clinvar_df.loc[clinvar_df.utr_name == utr_name, 'seq_name']

# From here on clinvar_fa is the open FASTA handle rather than the path.
clinvar_fa = pysam.FastaFile(clinvar_fa)

# One row per ClinVar record: '<utr_name>:<first field of the record name>' and its sequence.
clinvar_vars = pd.DataFrame(
    [(f"{utr_name}:{name.split(':')[0]}", clinvar_fa.fetch(name)) for name in clinvar_seqs],
    columns=['seq_name', 'seq'],
)

In [74]:
# The reference sequence is added as its own one-row frame and the index is rebuilt.
# background_df keeps the row labels of the sampled variants and clinvar_vars starts again
# at 0, so the combined index is neither contiguous nor unique; assigning to
# .loc[len(embedding_df)] could therefore hit an existing label and overwrite that row
# instead of appending a new one.
reference_row = pd.DataFrame([{'seq_name':f'{utr_name}:0:ref:ref', 'seq':refseq}]) #reference embedding

embedding_df = pd.concat([background_df, clinvar_vars, reference_row], ignore_index=True)

# All rows are labelled as Homo_sapiens and carry the Pan_troglodytes species label.
embedding_df['species_label'] = species_encoding['Pan_troglodytes']
embedding_df['species_name'] = 'Homo_sapiens'

In [284]:
# Build the datasets and dataloaders for the selected mode.
# Note: in training mode this replaces the train_df / test_df defined in an earlier cell.
if not input_params.test:
    
    #Train and Validate
    
    seq_transform = sequence_encoders.SequenceDataEncoder(seq_len = input_params.seq_len, total_len = input_params.seq_len, 
                                                      mask_rate = 0.15, split_mask = True)
    
    # Ordered split: the first N_train rows are used for training, the rest for validation.
    N_train = int(len(seq_df)*(1-input_params.val_fraction))
    # .copy() makes train_df an independent frame, so adding the train_fold column below
    # does not write into a slice of seq_df (avoids SettingWithCopyWarning).
    train_df, test_df = seq_df.iloc[:N_train].copy(), seq_df.iloc[N_train:]

    # Contiguous fold ids 0..train_splits-1, each repeated often enough to cover N_train rows.
    train_fold = np.repeat(np.arange(input_params.train_splits), repeats = N_train // input_params.train_splits + 1 )
    train_df['train_fold'] = train_fold[:N_train]

    train_dataset = SeqDataset(input_params.fasta, train_df, transform = seq_transform)
    train_dataloader = DataLoader(dataset = train_dataset, batch_size = input_params.batch_size, num_workers = 2, collate_fn = None, shuffle = False)

    test_dataset = SeqDataset(input_params.fasta, test_df, transform = seq_transform)
    test_dataloader = DataLoader(dataset = test_dataset, batch_size = input_params.batch_size, num_workers = 2, collate_fn = None, shuffle = False)

elif input_params.get_embeddings or input_params.get_motif_acc:
    
    #Test and get sequence embeddings (MPRA)
    
    seq_transform = sequence_encoders.RollingMasker(mask_stride = 50, frame = 0)
        
    # One sequence per batch in this mode.
    test_dataset = SeqDataset(input_params.fasta, seq_df, transform = seq_transform)
    test_dataloader = DataLoader(dataset = test_dataset, batch_size = 1, num_workers = 1, collate_fn = None, shuffle = False)
    
else:
    
    #Test
    
    seq_transform = sequence_encoders.SequenceDataEncoder(seq_len = input_params.seq_len, total_len = input_params.seq_len, 
                                                      mask_rate = 0.15, split_mask = True, frame = 0)
    
    test_dataset = SeqDataset(input_params.fasta, seq_df, transform = seq_transform)
    test_dataloader = DataLoader(dataset = test_dataset, batch_size = input_params.batch_size, num_workers = 2, collate_fn = None, shuffle = False)

In [271]:
# Species encoder handed to the model (label encoder, embed=True, same width as the model).
species_encoder = SpecAdd(
    embed = True,
    encoder = 'label',
    d_model = input_params.d_model,
)

# Model with 5 input and 5 output channels, moved to the selected device.
model = DSSResNetEmb(
    d_input = 5,
    d_output = 5,
    d_model = input_params.d_model,
    n_layers = input_params.n_layers,
    dropout = input_params.dropout,
    embed_before = True,
    species_encoder = species_encoder,
).to(device)

# Only parameters that require gradients are given to the optimizer.
model_params = list(filter(lambda param: param.requires_grad, model.parameters()))

optimizer = torch.optim.Adam(
    model_params,
    lr = input_params.learning_rate,
    weight_decay = input_params.weight_decay,
)

In [272]:
# Epoch to resume from; stays 0 when no checkpoint is given.
last_epoch = 0

if input_params.model_weight:

    # map_location=device places the checkpoint tensors on the device selected earlier in the
    # notebook, so one code path serves both GPU and CPU runs (and checkpoints saved on a
    # different device than the one available now).
    # NOTE(review): torch.load unpickles the file - only load checkpoints from trusted sources.
    model.load_state_dict(torch.load(input_params.model_weight, map_location=device))
    if input_params.optimizer_weight:
        optimizer.load_state_dict(torch.load(input_params.optimizer_weight, map_location=device))

    # The epoch number is the third-from-last '_'-separated field of the checkpoint path.
    last_epoch = int(input_params.model_weight.split('_')[-3]) #infer previous epoch from input_params.model_weight

weights_dir = os.path.join(input_params.output_dir, 'weights') #dir to save model weights at save_at epochs

# The weights directory is only created when there are epochs to save at.
if input_params.save_at:
    os.makedirs(weights_dir, exist_ok = True)

In [273]:
def metrics_to_str(metrics):
    """Format a (loss, total accuracy, masked accuracy) triple as one log string.

    The loss is shown with 4 significant digits, both accuracies with 3 decimals.
    """
    loss, total_acc, masked_acc = metrics
    parts = (
        f'loss: {loss:.4}',
        f'total acc: {total_acc:.3f}',
        f'masked acc: {masked_acc:.3f}',
    )
    return ', '.join(parts)

In [285]:
from IPython.display import clear_output

clear_output()

#from utils.misc import print    #print function that displays time

if not input_params.test:

    # Training mode: epochs continue from the checkpoint epoch up to tot_epochs (inclusive).
    for epoch in range(last_epoch+1, input_params.tot_epochs+1):

        print(f'EPOCH {epoch}: Training...')

        # Each epoch trains on a single fold; folds are cycled in order.
        train_dataset.seq_df = train_df[train_df.train_fold == (epoch-1) % input_params.train_splits]
        print(f'using train samples: {list(train_dataset.seq_df.index[[0,-1]])}')

        train_metrics = train_eval.model_train(model, optimizer, train_dataloader, device,
                            silent = False)

        print(f'epoch {epoch} - train, {metrics_to_str(train_metrics)}')

        # save_at is optional (the weights directory is only created when it is set),
        # so guard the membership test against an unset value.
        if input_params.save_at and epoch in input_params.save_at: #save model weights

            misc.save_model_weights(model, optimizer, weights_dir, epoch)

        # Validate on the last epoch and, if requested, every validate_every epochs.
        if input_params.val_fraction>0 and ( epoch==input_params.tot_epochs or
                            (input_params.validate_every and epoch%input_params.validate_every==0)):

            print(f'EPOCH {epoch}: Validating...')

            val_metrics, *_ =  train_eval.model_eval(model, optimizer, test_dataloader, device,
                    silent = False)

            print(f'epoch {epoch} - validation, {metrics_to_str(val_metrics)}')

    if input_params.temp_scaling:

        # NOTE(review): the temperature is fitted on train_dataloader here, whereas a later
        # cell of this notebook fits it on val_dataloader - confirm which one is intended.
        scaled_model = ModelWithTemperature(model)
        scaled_model.set_temperature(train_dataloader);
        model = scaled_model

else:

    # Inference mode: a single evaluation pass over test_dataloader.
    print(f'EPOCH {last_epoch}: Test/Inference...')

    test_metrics, test_embeddings, motif_probas =  train_eval.model_eval(model, optimizer, test_dataloader, device, 
                                                          get_embeddings = input_params.get_embeddings, 
                                                          get_motif_acc = input_params.get_motif_acc, 
                                                          silent = False)

    print(f'epoch {last_epoch} - test, {metrics_to_str(test_metrics)}')

    if input_params.get_embeddings:

        os.makedirs(input_params.output_dir, exist_ok = True)

        # Two consecutive pickles in one file: the embeddings, then the matching sequence names.
        with open(input_params.output_dir + '/embeddings.pickle', 'wb') as f:
            #test_embeddings = np.vstack(test_embeddings)
            pickle.dump(test_embeddings,f)
            pickle.dump(seq_df.seq_name.tolist(),f)

    if input_params.get_motif_acc:

        os.makedirs(input_params.output_dir, exist_ok = True)

        with open(input_params.output_dir + '/probas.pickle', 'wb') as f:
            pickle.dump(motif_probas, f) #seq_index,motif,motif_start,avg_target_proba

        #seq_df.seq_name.to_csv(input_params.output_dir + '/seq_index.csv') #save index-to-sequence matching for 1st column of motif_probas

print()
# CUDA memory statistics are only defined for CUDA devices; skip the report on CPU runs.
if device.type == 'cuda':
    print(f'peak GPU memory allocation: {round(torch.cuda.max_memory_allocated(device)/1024/1024)} Mb')
print('Done')

EPOCH 11: Test/Inference...


acc: 0.79, masked acc: 0.27, loss: 1.577: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 18134/18134 [21:03<00:00, 14.35it/s]


epoch 11 - test, loss: 1.577, total acc: 0.790, masked acc: 0.267

peak GPU memory allocation: 26613 Mb
Done


In [275]:
# NOTE(review): ModelWithTemperature is already imported near the top of the notebook;
# this second import is redundant.
from helpers.temperature_scaling import ModelWithTemperature

# Wrap the model and fit its temperature on val_dataloader.
# NOTE(review): val_dataloader is created by a cell that appears further down in this file
# (execution count 274, this cell is 275), so a top-to-bottom run on a fresh kernel fails here.
scaled_model = ModelWithTemperature(model)
scaled_model.set_temperature(val_dataloader);

Before temperature - NLL: 2.382, ECE: 0.418
Optimal temperature: 21.870
After temperature - NLL: 1.607, ECE: 0.055


In [276]:
# From here on `model` refers to the temperature-scaled wrapper instead of the raw model.
model=scaled_model

In [274]:
# Validation data for temperature scaling: Pan_troglodytes records from the shuffled FASTA.
seq_transform = sequence_encoders.RollingMasker(mask_stride = 50, frame = 0)

val_fasta = datadir + 'fasta/240_mammals/240_mammals.shuffled.fa'
val_df = pd.read_csv(val_fasta + '.fai', header=None, sep='\t', usecols=[0], names=['seq_name'])
val_df['species_name'] = val_df.seq_name.apply(lambda x:x.split(':')[1])
# NOTE(review): 181 is presumably species_encoding['Pan_troglodytes'] - confirm and replace
# the literal with that lookup.
val_df['species_label'] = 181

# Fixed-seed sample of at most 1000 records; min() keeps this from raising when fewer
# Pan_troglodytes records are available (same guard as for the background sample).
val_df = val_df[val_df.species_name=='Pan_troglodytes']
val_df = val_df.sample(n=min(1000, len(val_df)), random_state=1)

val_dataset = SeqDataset(val_fasta, val_df, transform = seq_transform)
val_dataloader = DataLoader(dataset = val_dataset, batch_size = 1, num_workers = 1, collate_fn = None, shuffle = False)