In [1]:
import numpy as np
import pandas as pd
import pickle

import re

import os
import gc

import pysam

import torch
from torch.utils.data import DataLoader, IterableDataset

from tqdm import tqdm

%load_ext autoreload
%autoreload 2

In [2]:
from encoding_utils import sequence_encoders, sequence_utils

import helpers.train_eval as train_eval    #train and evaluation
import helpers.misc as misc                #miscellaneous functions
from helpers.metrics import MaskedAccuracy


from models.spec_dss import DSSResNet, DSSResNetEmb, SpecAdd

In [3]:
datadir = '/lustre/groups/epigenereg01/workspace/projects/vale/mlm/'

In [4]:
#motifs = pd.read_csv(datadir + 'motif_predictions/motifs.csv').motif.unique()

#selected_motifs = {motif:motif_idx+1 for motif_idx,motif in enumerate(motifs)} #{'ACCTG':1, 'GGTAA':2}

In [5]:
class SeqDataset(IterableDataset):
    
    def __init__(self, fasta_fa, seq_df, transform, chunk_len=1000000000, overlap_bp=0):
        
        if fasta_fa:
            self.fasta = pysam.FastaFile(fasta_fa)
        else:
             self.fasta = None

        self.transform = transform
        self.seq_df = seq_df
        
        self.start = 0
        self.end = len(self.seq_df)

        self.chunk_len = chunk_len
        self.overlap_bp = overlap_bp
        
    def __len__(self):
        return len(self.seq_df)
                
    def __iter__(self):
        
        #worker_total_num = torch.utils.data.get_worker_info().num_workers
        #worker_id = torch.utils.data.get_worker_info().id
        
        for seq_idx in range(self.start,self.end):
            
            if self.fasta:
                seq = self.fasta.fetch(self.seq_df.iloc[seq_idx].seq_name).upper()
            else:
                seq = self.seq_df.iloc[seq_idx].seq.upper()
    
            species_label = self.seq_df.iloc[seq_idx].species_label
            
            seq = seq.replace('-','')
            
            chunks, left_shift_last_chunk = misc.get_chunks(seq, self.chunk_len, self.overlap_bp)

            for chunk_idx,seq_chunk in enumerate(chunks):

                masked_sequence, target_labels_masked, target_labels, _, _ = self.transform(seq_chunk, motifs = {})
        
                masked_sequence = (masked_sequence, species_label)

                chunk_meta = {'seq_name':self.seq_df.iloc[seq_idx].seq_name,
                             'seq':seq_chunk,
                             'left_shift':left_shift_last_chunk if chunk_idx==len(chunks)-1 else 0}
                        
                yield masked_sequence, target_labels_masked, target_labels, chunk_meta

                        
    def close(self):
        self.fasta.close()

In [6]:
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f'\nCUDA device: {torch.cuda.get_device_name(0)}\n')
else:
    device = torch.device('cpu')
    print('\nCUDA device: CPU\n')
    #raise Exception('CUDA is not found')


CUDA device: NVIDIA A100-PCIE-40GB MIG 3g.20gb



In [7]:
gc.collect()
torch.cuda.empty_cache()

import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

In [8]:
datadir

'/lustre/groups/epigenereg01/workspace/projects/vale/mlm/'

In [15]:
input_params = misc.dotdict({})

input_params.fasta = datadir + 'fasta/241_mammals.shuffled.fa'
#input_params.fasta = datadir + 'fasta/240_species/species/Homo_sapiens.fa'

input_params.species_list = datadir + '../MLM/fasta/240_species/240_species.txt'

input_params.output_dir = './test'

input_params.test = False
input_params.save_probs = True
input_params.mask_at_test = True

input_params.species_agnostic = True

input_params.seq_len = 5000
input_params.overlap_bp = 128

input_params.tot_epochs = 50
input_params.val_fraction = 0.1
input_params.train_splits = 4

input_params.save_at = []
input_params.validate_every = 1

input_params.d_model = 128
input_params.n_layers = 4
input_params.dropout = 0.

input_params.batch_size = 1
input_params.learning_rate = 1e-4
input_params.weight_decay = 0#5e-4

In [16]:
seq_df = pd.read_csv(input_params.fasta + '.fai', header=None, sep='\t', usecols=[0], names=['seq_name'])

seq_df['species_name'] = seq_df.seq_name.apply(lambda x:x.split(':')).apply(lambda x:x[-1] if len(x)==2 else x[1])

if input_params.test:
    seq_df = seq_df[seq_df.species_name=='Homo_sapiens']
else:
    seq_df = seq_df[seq_df.species_name!='Homo_sapiens']

species_encoding = pd.read_csv(input_params.species_list, header=None).squeeze().to_dict()

if not input_params.species_agnostic:
    species_encoding = {species:idx for idx,species in species_encoding.items()}
else:
    species_encoding = {species:0 for _,species in species_encoding.items()}
    
species_encoding['Homo_sapiens'] = species_encoding['Pan_troglodytes']

seq_df['species_label'] = seq_df.species_name.map(species_encoding)

#seq_df = seq_df.sample(frac = 1., random_state = 1) #DO NOT SHUFFLE, otherwise too slow

In [17]:
if not input_params.test:

    #Train and Validate

    seq_transform = sequence_encoders.SequenceDataEncoder(seq_len = input_params.seq_len, total_len = input_params.seq_len,
                                                      mask_rate = 0.15, split_mask = True)

    N_train = int(len(seq_df)*(1-input_params.val_fraction))
    train_df, test_df = seq_df.iloc[:N_train], seq_df.iloc[N_train:]

    train_fold = np.repeat(list(range(input_params.train_splits)),repeats = N_train // input_params.train_splits + 1 )
    train_df['train_fold'] = train_fold[:N_train]

    train_dataset = SeqDataset(input_params.fasta, train_df, transform = seq_transform, chunk_len=input_params.seq_len, overlap_bp=input_params.overlap_bp)
    
    train_dataloader = DataLoader(dataset = train_dataset, batch_size = input_params.batch_size, num_workers = 1, worker_init_fn=misc.worker_init_fn, collate_fn = None, shuffle = False)

    test_dataset = SeqDataset(input_params.fasta, test_df, transform = seq_transform, chunk_len=input_params.seq_len)
    
    test_dataloader = DataLoader(dataset = test_dataset, batch_size = input_params.batch_size, num_workers = 1, worker_init_fn=misc.worker_init_fn, collate_fn = None, shuffle = False)

elif input_params.save_probs:

    if input_params.mask_at_test:
        seq_transform = sequence_encoders.RollingMasker(mask_stride = 50, frame = 0)
    else:
        seq_transform = sequence_encoders.PlainOneHot(frame = 0, padding = 'none')

    test_dataset = SeqDataset(input_params.fasta, seq_df, transform = seq_transform, chunk_len=input_params.seq_len)
    
    test_dataloader = DataLoader(dataset = test_dataset, batch_size = 1, num_workers = 1, collate_fn = None, worker_init_fn=misc.worker_init_fn, shuffle = False)

else:

    #Test

    seq_transform = sequence_encoders.SequenceDataEncoder(seq_len = input_params.seq_len, total_len = input_params.seq_len,
                                                      mask_rate = 0.15, split_mask = True, frame = 0)

    test_dataset = SeqDataset(input_params.fasta, seq_df, transform = seq_transform)
    test_dataloader = DataLoader(dataset = test_dataset, batch_size = input_params.batch_size, num_workers = 1, worker_init_fn=misc.worker_init_fn, collate_fn = None, shuffle = False)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  train_df['train_fold'] = train_fold[:N_train]


In [18]:
species_encoder = SpecAdd(embed = True, encoder = 'label', d_model = input_params.d_model)

model = DSSResNetEmb(d_input = 5, d_output = 5, d_model = input_params.d_model, n_layers = input_params.n_layers, 
                     dropout = input_params.dropout, embed_before = True, species_encoder = species_encoder)

model = model.to(device) 

model_params = [p for p in model.parameters() if p.requires_grad]

optimizer = torch.optim.Adam(model_params, lr = input_params.learning_rate, weight_decay = input_params.weight_decay)

In [19]:
def get_n_params(model):
    pp=0
    for p in list(model.parameters()):
        nn=1
        for s in list(p.size()):
            nn = nn*s
        pp += nn
    return pp

In [20]:
get_n_params(model)

1006853

In [34]:
last_epoch = 0

if input_params.model_weight:

    if torch.cuda.is_available():
        #load on gpu
        model.load_state_dict(torch.load(input_params.model_weight))
        if input_params.optimizer_weight:
            optimizer.load_state_dict(torch.load(input_params.optimizer_weight))
    else:
        #load on cpu
        model.load_state_dict(torch.load(input_params.model_weight, map_location=torch.device('cpu')))
        if input_params.optimizer_weight:
            optimizer.load_state_dict(torch.load(input_params.optimizer_weight, map_location=torch.device('cpu')))

    last_epoch = int(input_params.model_weight.split('_')[-3]) #infer previous epoch from input_params.model_weight

weights_dir = os.path.join(input_params.output_dir, 'weights') #dir to save model weights at save_at epochs

if input_params.save_at:
    os.makedirs(weights_dir, exist_ok = True)

In [35]:
def metrics_to_str(metrics):
    loss, total_acc, masked_acc = metrics
    return f'loss: {loss:.4}, total acc: {total_acc:.3f}, masked acc: {masked_acc:.3f}'

In [None]:
from IPython.display import clear_output

clear_output()

#from utils.misc import print    #print function that displays time

if not input_params.test:

    for epoch in range(last_epoch+1, input_params.tot_epochs+1):

        print(f'EPOCH {epoch}: Training...')

        train_dataset.seq_df = train_df[train_df.train_fold == (epoch-1) % input_params.train_splits]
        train_dataset.end = len(train_dataset.seq_df)

        print(f'using train samples: {list(train_dataset.seq_df.index[[0,-1]])}')

        train_metrics = train_eval.model_train(model, optimizer, train_dataloader, device,
                            silent = False)

        print(f'epoch {epoch} - train, {metrics_to_str(train_metrics)}')

        if epoch in input_params.save_at: #save model weights

            misc.save_model_weights(model, optimizer, weights_dir, epoch)

        if input_params.val_fraction>0 and ( epoch==input_params.tot_epochs or
                            (input_params.validate_every and epoch%input_params.validate_every==0)):

            print(f'EPOCH {epoch}: Validating...')

            val_metrics, *_ =  train_eval.model_eval(model, optimizer, test_dataloader, device,
                    silent = False)

            print(f'epoch {epoch} - validation, {metrics_to_str(val_metrics)}')
            
else:

    print(f'EPOCH {last_epoch}: Test/Inference...')

    test_metrics, test_embeddings, motif_probas =  train_eval.model_eval(model, optimizer, test_dataloader, device, 
                                                          get_embeddings = input_params.get_embeddings, 
                                                          silent = False)
    
    

    print(f'epoch {last_epoch} - test, {metrics_to_str(test_metrics)}')

    if input_params.get_embeddings:
        
        os.makedirs(input_params.output_dir, exist_ok = True)

        with open(input_params.output_dir + '/embeddings.pickle', 'wb') as f:
            #test_embeddings = np.vstack(test_embeddings)
            #np.save(f,test_embeddings)
            pickle.dump(test_embeddings,f)
            #pickle.dump(seq_df.seq_name.tolist(),f)
            
print()
print(f'peak GPU memory allocation: {round(torch.cuda.max_memory_allocated(device)/1024/1024)} Mb')
print('Done')

EPOCH 11: Test/Inference...


acc: 0.96, masked acc: 0.44, loss: 1.194:  18%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     | 15623/84822 [06:45<30:12, 38.17it/s]

In [38]:
input_params.output_dir + '/embeddings.npy'

'/lustre/groups/epigenereg01/workspace/projects/vale/MLM/perbase_pred/embeddings/rna/Species-agnostic//embeddings.npy'

In [None]:
seq_transform = sequence_encoders.RollingMasker(mask_stride = 50, frame = 0)
     
val_fasta = datadir + 'fasta/240_mammals/240_mammals.shuffled.fa'
val_df = pd.read_csv(val_fasta + '.fai', header=None, sep='\t', usecols=[0], names=['seq_name'])
val_df['species_name'] = val_df.seq_name.apply(lambda x:x.split(':')[1])
val_df['species_label'] = 181

val_df = val_df[val_df.species_name=='Pan_troglodytes'].sample(n=1000, random_state=1)

val_dataset = SeqDataset(val_fasta, val_df, transform = seq_transform)
val_dataloader = DataLoader(dataset = val_dataset, batch_size = 1, num_workers = 1, collate_fn = None, shuffle = False)

In [19]:
from helpers.temperature_scaling import ModelWithTemperature

scaled_model = ModelWithTemperature(model)
scaled_model.set_temperature(val_dataloader);

  return einsum('chn,hnl->chl', W, S).float(), state                   # [C H L]


Before temperature - NLL: 1.664, ECE: 0.248
Optimal temperature: 2.710
After temperature - NLL: 1.414, ECE: 0.041


In [22]:
input_params.temperature = scaled_model.temperature