In [1]:
import pysam
import pandas as pd

import torch
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm

from encoding_utils import sequence_encoders

%load_ext autoreload
%autoreload 2

In [2]:
import helpers.models as models            #model architecture
import helpers.train_eval as train_eval    #train and evaluation
import helpers.misc as misc                #miscellaneous functions

In [3]:
from models.spec_dss import DSSResNet, DSSResNetEmb, SpecAdd

In [4]:
data_dir = '/s/project/mll/sergey/effect_prediction/MLM/fasta/'

fasta_fa = data_dir + '240_mammals/240_mammals.fa'
fasta_fai = data_dir + '240_mammals/240_mammals.fa.fai'

In [5]:
class SeqDataset(Dataset):
    
    def __init__(self, fasta_fa, seq_df, transform):
        
        self.fasta = pysam.FastaFile(fasta_fa)
        
        self.seq_df = seq_df
        self.transform = transform
        
    def __len__(self):
        
        return len(self.seq_df)
    
    def __getitem__(self, idx):
        
        seq = self.fasta.fetch(seq_df.iloc[idx].seq_name)
        
        species_label = seq_df.iloc[idx].species_label
        
        masked_sequence, target_labels_masked, target_labels, mask, _ = self.transform(seq, motifs = {})
        
        masked_sequence = (masked_sequence, species_label)
        
        return masked_sequence, target_labels_masked, target_labels, mask
    
    def close(self):
        self.fasta.close()

In [6]:
seq_df = pd.read_csv(fasta_fai, header=None, sep='\t', usecols=[0], names=['seq_name'])
seq_df['species_name'] = seq_df.seq_name.apply(lambda x:x.split(':')[1])

In [7]:
species = seq_df.species_name.drop_duplicates().sort_values().tolist()
species_encoding = {species_name:species_label for species_label, species_name in enumerate(species)}
seq_df['species_label'] = seq_df.species_name.map(species_encoding)

In [8]:
N_train, N_test = 10000, 5000

train_seq, test_seq = seq_df.iloc[:N_train],seq_df.iloc[N_train:N_train+N_test]

In [9]:
seq_transform = sequence_encoders.SequenceDataEncoder(seq_len=200, total_len=200, mask_rate=0.15, split_mask=True,)

In [10]:
train_dataset = SeqDataset(fasta_fa, train_seq, transform = seq_transform)
test_dataset = SeqDataset(fasta_fa, test_seq, transform = seq_transform)

In [11]:
train_dataloader = DataLoader(dataset = train_dataset, batch_size = 16, num_workers = 4, collate_fn = None, shuffle = None)
test_dataloader = DataLoader(dataset = test_dataset, batch_size = 16, num_workers = 4, collate_fn = None, shuffle = None)

In [12]:
species_encoder = SpecAdd(embed = True, encoder = 'label', d_model = 128)

In [13]:
model = DSSResNetEmb(d_input = 5, d_output = 5, d_model = 128, n_layers = 4, 
                     dropout = 0.1, embed_before = True, species_encoder = species_encoder)

In [14]:
criterion = torch.nn.CrossEntropyLoss(reduction="mean")

x, y_mask, y, _  = next(iter(train_dataloader))

In [15]:

logits, embeddings = model(x[0],x[1])


criterion(logits, y_mask)

torch.argmax(logits, dim=1)

  return einsum('chn,hnl->chl', W, S).float(), state                   # [C H L]


tensor([[2, 4, 2,  ..., 4, 4, 0],
        [1, 3, 1,  ..., 4, 0, 0],
        [2, 0, 0,  ..., 0, 0, 0],
        ...,
        [1, 1, 2,  ..., 4, 4, 1],
        [1, 1, 1,  ..., 4, 4, 0],
        [1, 1, 2,  ..., 1, 4, 2]])

In [3]:
fasta = pysam.FastaFile(train_fasta)

In [12]:
seq = fasta.fetch('ENST00000318911.5_utr3_6_0_chr8_144097337_f:Acinonyx_jubatus:LLWD01000002.1:189')

In [14]:
seq = 'CCCTGCCCAACGTCTGCTTGCCGTCTTGCCTGAACAGGCCCGCAAGCCAAGGAGCCACCCTGGACCTGTTCAGGCCTCAGCTGGCCCGCTTGGCCAAGCTCCTCTTTCTTTGGGACAAGAGGGAAAGGGGCAAGAGACCAGGTTCTAGCTCCAGATCCTTCAGCACCCATCATGGAAATAAATTAAGTT'

In [39]:
encoder = sequence_encoders.SequenceDataEncoder(seq_len=200,
                total_len=200,
                mask_rate=0.15,
                split_mask=True,)

In [41]:
masked_sequence, target_labels_masked, target_labels, mask, motif_mask = encoder.__call__(seq, motifs=None)

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 1
 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 0 0
 0 0 0 0 0 1 1 0 1 0 0 0 0 0 1 0 1 0 0 1 0 1 0 0 0 0 0 0 1 0 0 0 0 0 1 0 0
 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0]


In [9]:
target_labels_masked.shape

torch.Size([23])

In [10]:
target_labels_masked

tensor([-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100])

In [11]:
from utils.sequence_operations import *

In [32]:
seq_labels, seq_one_hot = one_hot_encode(seq)

In [37]:
masked_seq, mask = random_masking(seq_one_hot,
                            mask_rate=0.15,
                            split_mask=False,
                            frame=0)

In [38]:
mask

array([0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,
       0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0,
       0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1])