In [1]:
import numpy as np
import pandas as pd

import pickle
import os
import gc

import pysam

import torch
from torch.utils.data import DataLoader, Dataset

from tqdm.notebook import tqdm

%load_ext autoreload
%autoreload 2

In [2]:
from encoding_utils import sequence_encoders

import helpers.train_eval as train_eval    #train and evaluation
import helpers.misc as misc                #miscellaneous functions
from helpers.metrics import MaskedAccuracy

from models.spec_dss import DSSResNet, DSSResNetEmb, SpecAdd

In [3]:
class SeqDataset(Dataset):
    
    def __init__(self, fasta_fa, seq_df, transform):
        
        self.fasta = pysam.FastaFile(fasta_fa)
        
        self.seq_df = seq_df
        self.transform = transform
        
    def __len__(self):
        
        return len(self.seq_df)
    
    def __getitem__(self, idx):
        
        seq = self.fasta.fetch(seq_df.iloc[idx].seq_name).upper()
                
        species_label = seq_df.iloc[idx].species_label
                
        masked_sequence, target_labels_masked, target_labels, mask, _ = self.transform(seq, motifs = {})
        
        masked_sequence = (masked_sequence, species_label)
        
        return masked_sequence, target_labels_masked, target_labels
    
    def close(self):
        self.fasta.close()

In [4]:
if torch.cuda.is_available():
    device = torch.device('cuda')
    print('\nCUDA device: GPU\n')
else:
    device = torch.device('cpu')
    print('\nCUDA device: CPU\n')
    #raise Exception('CUDA is not found')


CUDA device: GPU



In [5]:
gc.collect()
torch.cuda.empty_cache()

In [27]:
input_params = misc.dotdict({})

input_params.fasta = '/s/project/mll/sergey/effect_prediction/MLM/fasta/240_mammals/240_mammals.shuffled.fa'
input_params.species_list = '/s/project/mll/sergey/effect_prediction/MLM/fasta/240_mammals/240_species.txt'

input_params.tot_epochs = 50

input_params.output_dir = './test'

input_params.train = True
input_params.val_fraction = 0.1

input_params.train_splits = 4

input_params.save_at = []
input_params.validate_every = 1

In [19]:
seq_df = pd.read_csv(input_params.fasta + '.fai', header=None, sep='\t', usecols=[0], names=['seq_name'])
seq_df['species_name'] = seq_df.seq_name.apply(lambda x:x.split(':')[1])

#seq_df['seq_len'] = seq_df.seq_name.apply(lambda x:int(x.split(':')[-1]))
#seq_df = seq_df[seq_df.seq_len>60]

species_encoding = pd.read_csv(input_params.species_list, header=None).squeeze().to_dict()
species_encoding = {species:idx for idx,species in species_encoding.items()}
species_encoding['Homo_sapiens'] = species_encoding['Pan_troglodytes']

seq_df['species_label'] = seq_df.species_name.map(species_encoding)

#seq_df = seq_df.sample(frac = 1., random_state = 1) #DO NOT SHUFFLE, otherwise too slow

In [8]:
seq_df = seq_df.iloc[:2000]

In [22]:
seq_transform = sequence_encoders.SequenceDataEncoder(seq_len = 2000, total_len = 2000, 
                                                      mask_rate = 0.15, split_mask = True)

In [23]:
if input_params.train:
    
    N_train = int(len(seq_df)*(1-input_params.val_fraction))       
    train_df, test_df = seq_df.iloc[:N_train], seq_df.iloc[N_train:]
                  
    train_fold = np.repeat(list(range(input_params.train_splits)),repeats = N_train // input_params.train_splits + 1 )
    train_df['train_fold'] = train_fold[:N_train]

    train_dataset = SeqDataset(input_params.fasta, train_df, transform = seq_transform)
    train_dataloader = DataLoader(dataset = train_dataset, batch_size = 512, num_workers = 16, collate_fn = None, shuffle = None)

else:
                  
    test_df = seq_df
                  
test_dataset = SeqDataset(input_params.fasta, test_df, transform = seq_transform)
test_dataloader = DataLoader(dataset = test_dataset, batch_size = 512, num_workers = 16, collate_fn = None, shuffle = None)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  train_df['train_fold'] = train_fold[:N_train]


In [24]:
species_encoder = SpecAdd(embed = True, encoder = 'label', d_model = 128)

model = DSSResNetEmb(d_input = 5, d_output = 5, d_model = 128, n_layers = 4, 
                     dropout = 0., embed_before = True, species_encoder = species_encoder)

model = model.to(device) 

model_params = [p for p in model.parameters() if p.requires_grad]

optimizer = torch.optim.Adam(model_params, lr = 1e-4, weight_decay = 5e-4)

In [28]:
last_epoch = 0

if input_params.model_weight:

    if torch.cuda.is_available():
        #load on gpu
        model.load_state_dict(torch.load(input_params.model_weight))
        if input_params.optimizer_weight:
            optimizer.load_state_dict(torch.load(input_params.optimizer_weight))
    else:
        #load on cpu
        model.load_state_dict(torch.load(input_params.model_weight, map_location=torch.device('cpu')))
        if input_params.optimizer_weight:
            optimizer.load_state_dict(torch.load(input_params.optimizer_weight, map_location=torch.device('cpu')))

    last_epoch = int(input_params.model_weight.split('_')[-3]) #infer previous epoch from input_params.model_weight

predictions_dir = os.path.join(input_params.output_dir, 'predictions') #dir to save predictions
weights_dir = os.path.join(input_params.output_dir, 'weights') #dir to save model weights at save_at epochs

if input_params.save_at:
    os.makedirs(weights_dir, exist_ok = True)

In [29]:
def metrics_to_str(metrics):
    loss, total_acc, masked_acc = metrics
    return f'loss: {loss:.4}, total acc: {total_acc:.3f}, masked acc: {masked_acc:.3f}'

In [30]:
#from utils.misc import print    #print function that displays time

if input_params.train:

    for epoch in range(last_epoch+1, input_params.tot_epochs+1):

        print(f'EPOCH {epoch}: Training...')

        train_dataset.seq_df = train_df[train_df.train_fold == (epoch-1) % input_params.train_splits]
        print(f'using train samples: {list(train_dataset.seq_df.index[[0,-1]])}')

        train_metrics = train_eval.model_train(model, optimizer, train_dataloader, device,
                            silent = False)

        print(f'epoch {epoch} - train, {metrics_to_str(train_metrics)}')

        if epoch in input_params.save_at: #save model weights

            misc.save_model_weights(model, optimizer, weights_dir, epoch)

        if input_params.val_fraction>0 and ( epoch==input_params.tot_epochs or
                            (input_params.validate_every and epoch%input_params.validate_every==0)):

            print(f'EPOCH {epoch}: Validating...')

            val_metrics, _ =  train_eval.model_eval(model, optimizer, test_dataloader, device,
                    silent = False)

            print(f'epoch {epoch} - validation, {metrics_to_str(val_metrics)}')

else:

    print(f'EPOCH {last_epoch}: Test/Inference...')

    test_metrics, test_embeddings =  train_eval.model_eval(model, optimizer, test_dataloader, device, 
                                                          save_embeddings = True, silent = False)

    print(f'epoch {last_epoch} - test, {metrics_to_str(test_metrics)}')

    os.makedirs(predictions_dir, exist_ok = True)

    with open(predictions_dir + '/test_embeddings.pickle', 'wb') as f:
        pickle.dump(test_embeddings, f)

print()
print(f'peak GPU memory allocation: {round(torch.cuda.max_memory_allocated(device)/1024/1024)} Mb')
print('Done')

EPOCH 1: Training...
using train samples: [0, 847255]


  0%|                                                                                                         …



epoch 1 - train, loss: 1.209, total acc: 0.860, masked acc: 0.430
EPOCH 1: Validating...


  0%|                                                                                                         …

KeyboardInterrupt: 

In [17]:
test_embeddings[0].shape

(512, 128, 2000)

In [None]:
test_embeddings

In [21]:
torch.cuda.max_memory_allocated(device)

44971320832

45.46423006057739

In [62]:
criterion = torch.nn.CrossEntropyLoss(reduction="mean")

metric = MaskedAccuracy()



In [86]:
input = torch.randn(3, 5, requires_grad=True)
target = torch.tensor([-100,-100,-100], dtype=torch.long)
criterion(input, target)

tensor(nan, grad_fn=<NllLossBackward0>)

In [265]:
list(model.parameters())[-1].grad.shape

torch.Size([240, 128])

In [289]:
logits.min()

tensor(-2.6490, grad_fn=<MinBackward1>)

In [274]:
itr = iter(train_dataloader)

In [298]:
loss

tensor(1.3295, grad_fn=<NllLoss2DBackward0>)

In [330]:
f=0

for idx, data in enumerate(train_dataloader):
    
    

    (masked_sequence, species_label), targets_masked, targets = data

    logits, embeddings = model(masked_sequence, species_label)


    loss = criterion(logits, targets_masked)

    if torch.isnan(loss):
        print('Loss nan')
        break
        
    optimizer.zero_grad()
    
    #loss.register_hook(lambda grad: print(grad))
    
    loss.backward()

            #if max_abs_grad:
            #    torch.nn.utils.clip_grad_value_(model.parameters(), max_abs_grad)

    optimizer.step()
    
    for name,param in model.named_parameters():
        
        if param.isnan().sum():
            print(idx,name,param.shape,param.isnan().sum())
            f = 1
    if f:
        break
        
    


32 s4_layers.0.kernel.log_dt torch.Size([128, 2]) tensor(2)
32 s4_layers.0.kernel.Lambda torch.Size([64, 2]) tensor(2)
32 s4_layers.0.kernel.W torch.Size([2, 128, 64, 2]) tensor(2)


In [327]:
optimizer.step()

for name,param in model.named_parameters():
        
        if param.grad.max()>1e10:
            print(idx,name,param.shape,param.isnan().sum())
            f = 1

0 encoder.weight torch.Size([128, 5, 15]) tensor(0)
0 encoder.bias torch.Size([128]) tensor(0)
0 s4_layers.0.D torch.Size([1, 128]) tensor(0)
0 s4_layers.0.kernel.log_dt torch.Size([128, 2]) tensor(0)
0 s4_layers.0.kernel.Lambda torch.Size([64, 2]) tensor(0)
0 s4_layers.0.kernel.W torch.Size([2, 128, 64, 2]) tensor(0)
0 s4_layers.0.output_linear.0.weight torch.Size([256, 128]) tensor(0)
0 s4_layers.0.output_linear.0.bias torch.Size([256, 1]) tensor(0)
0 s4_layers.1.kernel.log_dt torch.Size([128, 2]) tensor(0)
0 s4_layers.1.kernel.Lambda torch.Size([64, 2]) tensor(0)
0 s4_layers.1.kernel.W torch.Size([2, 128, 64, 2]) tensor(0)
0 norms.0.weight torch.Size([128]) tensor(0)
0 norms.0.bias torch.Size([128]) tensor(0)
0 resnet_layer.0.conv1.weight torch.Size([128, 128, 7]) tensor(0)
0 resnet_layer.0.conv1.bias torch.Size([128]) tensor(0)
0 resnet_layer.0.bn1.weight torch.Size([128]) tensor(0)
0 resnet_layer.0.bn1.bias torch.Size([128]) tensor(0)
0 resnet_layer.0.conv2.weight torch.Size([128,

In [328]:
for name,param in model.named_parameters():
    print(name, param.grad.max())

encoder.weight tensor(3.3170e+30)
encoder.bias tensor(1.1146e+30)
s4_layers.0.D tensor(2.9693e+19)
s4_layers.0.kernel.log_dt tensor(7.0755e+31)
s4_layers.0.kernel.Lambda tensor(1.8249e+30)
s4_layers.0.kernel.W tensor(2.0817e+30)
s4_layers.0.output_linear.0.weight tensor(2.3028e+21)
s4_layers.0.output_linear.0.bias tensor(1.2498e+20)
s4_layers.1.D tensor(146.5361)
s4_layers.1.kernel.log_dt tensor(9.7293e+22)
s4_layers.1.kernel.Lambda tensor(3.9747e+20)
s4_layers.1.kernel.W tensor(9.7611e+19)
s4_layers.1.output_linear.0.weight tensor(3326.9224)
s4_layers.1.output_linear.0.bias tensor(415.3416)
s4_layers.2.D tensor(0.0226)
s4_layers.2.kernel.log_dt tensor(1.7157)
s4_layers.2.kernel.Lambda tensor(0.0852)
s4_layers.2.kernel.W tensor(0.0082)
s4_layers.2.output_linear.0.weight tensor(3386.7385)
s4_layers.2.output_linear.0.bias tensor(561.5174)
s4_layers.3.D tensor(321.0021)
s4_layers.3.kernel.log_dt tensor(12325.0898)
s4_layers.3.kernel.Lambda tensor(4139.0215)
s4_layers.3.kernel.W tensor(314

In [309]:
model.s4_layers[0].kernel.Lambda.grad

tensor([[-6.4045e+29, -1.2555e+30],
        [ 6.1969e+30, -1.8450e+31],
        [-9.0364e+30, -3.4811e+30],
        [ 1.6769e+32,  5.5701e+31],
        [-4.0765e+30, -3.2897e+31],
        [-2.4014e+32, -1.0221e+32],
        [-7.3908e+31, -6.7461e+31],
        [-1.3097e+32, -1.5515e+32],
        [ 7.0436e+31,  3.6052e+31],
        [-3.8711e+32,  1.7960e+32],
        [-4.5004e+31, -1.9284e+32],
        [-1.1951e+32,  3.9861e+32],
        [ 9.5394e+31, -3.6152e+31],
        [-1.8462e+32, -1.8041e+33],
        [-6.7417e+31,  4.2997e+31],
        [ 2.2285e+32,  1.3209e+31],
        [-6.3038e+31,  3.3885e+31],
        [ 1.5946e+32, -1.2147e+32],
        [-1.0125e+33,  4.7673e+32],
        [-7.3206e+31,  1.6535e+32],
        [-2.7117e+32,  2.6849e+32],
        [-4.5600e+31, -6.9701e+31],
        [ 3.9817e+32,  7.6175e+31],
        [ 4.6965e+32,  1.9915e+32],
        [-1.2158e+31, -3.4869e+31],
        [-7.9421e+31,  1.9504e+32],
        [ 7.8557e+31, -1.7855e+32],
        [ 1.1458e+32,  2.195

In [283]:
print(model)

DSSResNetEmb(
  (encoder): Conv1d(5, 128, kernel_size=(15,), stride=(1,), padding=(7,))
  (s4_layers): ModuleList(
    (0-3): 4 x DSS(
      (kernel): DSSKernel()
      (activation): GELU(approximate='none')
      (dropout): Dropout2d(p=0.1, inplace=False)
      (output_linear): Sequential(
        (0): TransposedLinear()
        (1): GLU(dim=-2)
      )
    )
  )
  (norms): ModuleList(
    (0-3): 4 x LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  )
  (dropouts): ModuleList(
    (0-3): 4 x Dropout2d(p=0.1, inplace=False)
  )
  (decoder): Conv1d(128, 5, kernel_size=(15,), stride=(1,), padding=(7,))
  (resnet_layer): Sequential(
    (0): L1Block(
      (conv1): Conv1d(128, 128, kernel_size=(7,), stride=(1,), padding=(3,))
      (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv1d(128, 128, kernel_size=(7,), stride=(1,), padding=(3,))
      (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=Tr

In [None]:
        metric(preds, targets_masked).detach() # compute only on masked nucleotides
        metric(preds, targets).detach()

In [203]:
species_label

tensor([ 52,  14, 205, 181])

In [None]:
fasta = pysam.FastaFile(train_fasta)

In [None]:
seq = fasta.fetch('ENST00000318911.5_utr3_6_0_chr8_144097337_f:Acinonyx_jubatus:LLWD01000002.1:189')

In [None]:
seq = 'CCCTGCCCAACGTCTGCTTGCCGTCTTGCCTGAACAGGCCCGCAAGCCAAGGAGCCACCCTGGACCTGTTCAGGCCTCAGCTGGCCCGCTTGGCCAAGCTCCTCTTTCTTTGGGACAAGAGGGAAAGGGGCAAGAGACCAGGTTCTAGCTCCAGATCCTTCAGCACCCATCATGGAAATAAATTAAGTT'

In [None]:
encoder = sequence_encoders.SequenceDataEncoder(seq_len=200,
                total_len=200,
                mask_rate=0.15,
                split_mask=True,)

In [None]:
masked_sequence, target_labels_masked, target_labels, mask, motif_mask = encoder.__call__(seq, motifs=None)

In [None]:
target_labels_masked.shape

In [None]:
target_labels_masked

In [None]:
from utils.sequence_operations import *

In [None]:
seq_labels, seq_one_hot = one_hot_encode(seq)

In [None]:
masked_seq, mask = random_masking(seq_one_hot,
                            mask_rate=0.15,
                            split_mask=False,
                            frame=0)

In [None]:
mask