In [2]:
%reload_ext autoreload
%autoreload 2
import sys, os
sys.path.append(os.path.abspath(os.path.join('..')))

from dmsensei import DataModule, create_model, Dataset
from dmsensei.config import device
from dmsensei.core.callbacks import WandbFitLogger, KaggleLogger
from lightning.pytorch.callbacks import LearningRateMonitor
from lightning.pytorch import Trainer
from dmsensei.config import device
import sys
import os
from lightning.pytorch.loggers import WandbLogger
import torch, wandb
import numpy as np


dm = DataModule(
    name=["ribonanza"],#, 'sarah_supermodel'],
    data_type=["dms", 'structure'],
    force_download=False,
    batch_size=32,
    num_workers=0,
    train_split=1024,
    valid_split=0,
    predict_split=0,
    overfit_mode=True,
    shuffle_valid=False,
    use_error=False,
    structure_padding_value=0,
)

dm.setup('fit')

Loading dataset from disk
Done!                            


In [3]:
from torch import nn, tensor
import torch
import numpy as np
import pandas as pd

seq2int = {'X': 0, 'A': 1, 'C': 2, 'G': 3, 'U': 4, 'S': 5, 'E': 6}
START_TOKEN = seq2int['S']
END_TOKEN = seq2int['E']
PADDING_TOKEN = seq2int['X']


In [16]:


def load_ct(path):
    return pd.read_csv(path, sep=' ', header=None, names=['i', 'j', 'p'])

# reduce weight initialization variance
import torch.nn.init as init

# Set the global scale for weight initialization
global_scale = 0.1
    

class ConvBlock(nn.Module):
    def __init__(
        self,
        params
    ):
        
        super().__init__()  
        self.conv = nn.Conv2d(params['num_heads'], params['num_heads'], 3, padding=1)
        self.batch2d = nn.BatchNorm2d(7)
        self.gelu = nn.GELU()
        self.gammas = nn.Parameter(torch.ones(params['num_heads']))
        
    def forward(self, structure):
        x = self.conv(structure)
        x = self.batch2d(x)
        x = self.gelu(x)
        x = x + structure
        x = x * self.gammas
        return x
    
    
class FeedForward(nn.Module):
    def __init__(
        self,
        params,
    ):
        super().__init__()  
        self.layer_norm = nn.LayerNorm(params['embed_dim'])
        self.linear1 = nn.Linear(params['embed_dim'], params['hidden_dim'])
        self.gelu = nn.GELU()
        self.linear2 = nn.Linear(params['hidden_dim'], params['embed_dim'])
        
    def forward(self, sequence):
        self.layer_norm(sequence)
        sequence = self.linear1(sequence)
        sequence = self.gelu(sequence)
        sequence = self.linear2(sequence)
        return sequence
        

class Encoder(nn.Module):
    def __init__(
        self,
        params,
    ):
        super().__init__()  
        self.self_attention = SelfAttention(params)
        self.feed_forward = FeedForward(params)
        
    def forward(self, sequence, structure):
        encoded_sequence, encoded_structure = self.self_attention(sequence, structure)
        sequence = sequence + encoded_sequence
        sequence = sequence + self.feed_forward(sequence)
        return sequence, encoded_structure
    

class Ribonanza(nn.Module):
    def __init__(
        self,
        params,
    ):
        super().__init__()  
        self.params = params
        self.ntokens = 7
        self.table_embedding = nn.Embedding(self.ntokens, params['embed_dim'])
        self.table_embedding.weight.data.normal_(mean=0.0, std=0.2)
        self.output_net = nn.Linear(params['embed_dim'], 2)
        params['table_embedding'] = self.table_embedding
        self.encoders_stack = nn.ModuleList([Encoder(params) for _ in range(params['num_encoders'])])
        # Initialize the weights with a reduced scale
        init.xavier_uniform_(self.table_embedding.weight, gain=global_scale)
        init.xavier_uniform_(self.output_net.weight, gain=global_scale)

    def forward(self, batch):
        sequence = self.embed_sequence_batch(batch)
        structure = self.embed_structure_batch(batch)
        for encoder in self.encoders_stack:
            sequence, structure = encoder(sequence, structure)
        x = self.output_net(sequence)
        return x
    
    def embed_sequence_batch(self, batch):
        out = []
        L = max(batch.get('length')) 
        for sequence, length in zip(batch.get('sequence'), batch.get('length')):
            out.append(torch.concat([
                tensor([START_TOKEN], dtype=torch.long).to(device),
                sequence[:length],
                tensor([END_TOKEN], dtype=torch.long).to(device),
                tensor([PADDING_TOKEN] * (L - length), dtype=torch.long).to(device)],
                ))
        
        return self.table_embedding(torch.stack(out))

    def embed_structure_batch(self, batch):
        structure = batch.get('structure')
        batch_size, L, _ = structure.shape
        embedded_matrix = torch.zeros((self.params['num_heads'], batch_size, L+2, L+2), dtype=torch.float32).to(device)
        embedded_matrix[:, :, 1:-1, 1:-1] = structure
        return embedded_matrix.permute(1, 0, 2, 3)
    
    
class ConvSE(nn.Module):
    def __init__(
        self,
        params
    ):
        super().__init__()  
        self.conv = nn.Conv2d(params['num_heads'], params['num_heads'], 3, padding=1)
        self.batch2d = nn.BatchNorm2d(params['num_heads'])
        self.gelu = nn.GELU()   
        
        # Squeeze and Excitation
        self.adaptive_average_pooling = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(params['num_heads'], params['num_heads'])
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(params['num_heads'], params['num_heads'])
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, structure):
        x = self.conv(structure)
        x = self.batch2d(x)
        # SE
        y = self.adaptive_average_pooling(x)
        y = y.view(-1, params['num_heads']) # unsure
        y = self.fc1(y)
        y = self.relu(y)
        y = self.fc2(y)
        y = self.sigmoid(y)
        y = y.view(-1, params['num_heads'], 1, 1) # unsure 
        x = x * y
        # end SE
        x = self.gelu(x)
        x = x + structure
        return x
    
class SelfAttention(nn.Module):
    def __init__(
        self,
        params,
    ):
        super().__init__()  
        self.convSE = ConvSE(params)
        self.pos_encoding = DynamicPositionalEncoding(params)
        
    def forward(self, sequence, structure):
        # self attention
        # sequence = sequence.permute(1, 0, 2) # check if this is correct
        structure = self.convSE(structure)
        sequence = sequence.reshape(sequence.shape[0], sequence.shape[1], params['num_heads'], params['dim_per_head']).permute(0, 2, 1, 3)
        attention = torch.matmul(sequence, sequence.permute(0, 1, 3, 2))
        attention =  attention + structure + self.pos_encoding(sequence)
        attention = attention / attention.sum(dim=2, keepdim=True) # unsure about this
        attention = attention @ sequence
        # end self attention
        attention = torch.reshape(attention, (attention.shape[0], attention.shape[2], -1))
        return attention, structure
    
    
class DynamicPositionalEncoding(nn.Module):
    def __init__(
        self,
        params,
    ):
        super().__init__()  
        self.params = params
        self.positional_encoding = nn.Parameter(torch.randn(params['max_len'], params['max_len'], 1))
        self.lin1 = nn.Linear(1, 48)
        self.silu = nn.SiLU()
        self.lin2 = nn.Linear(48, 48)
        self.lin3 = nn.Linear(48, params['num_heads'])
        # init.xavier_uniform_(self.lin1, gain=global_scale)
        # init.xavier_uniform_(self.lin2, gain=global_scale)
        # init.xavier_uniform_(self.lin3, gain=global_scale)
    
    def forward(self, sequence):
        # will this work with batches?
        seq_len = sequence.shape[2]
        x = self.positional_encoding[:seq_len, :seq_len, :]
        x = self.lin1(x)
        x = self.silu(x)
        x = self.lin2(x)
        x = self.silu(x)
        x = self.lin3(x)
        return x.permute(2, 0, 1).reshape(1, -1, seq_len, seq_len)

params = {
    'embed_dim': 192,
    'num_heads': 6,
    'hidden_dim': 768,
    'num_encoders': 12,
    'max_len': 430,
}
params['dim_per_head'] = params['embed_dim'] // params['num_heads']

model = Ribonanza(params).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()
for batch in dm.train_dataloader():
    pred = model.forward(batch)[:, :, 0]
    pairs = [(p.to(device), t) for p, t in zip(pred, batch.get('dms')) if t is not None]
    print(pairs[0][0].device, pairs[0][1].device)
    optimizer.zero_grad()
    loss = torch.tensor(0.)
    for p, t in pairs:
        if t is None:
            continue
        p = p[1:-1]
        mask = t != -1000.
        if not torch.sum(mask):
            continue
        p, t = p[mask], t[mask]
        if p.device != t.device:
            print('mismatch')
        loss += criterion(p, t)
    loss.backward()
    

    break

mps:0 mps:0


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, mps:0 and cpu!

In [5]:
for batch in dm.train_dataloader():
    sequences = batch.get('sequence')
    break

In [6]:
batch.get('structure')

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        ...,

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0., 

In [13]:
from torch import nn, tensor


        

embed_structure_batch(batch)[0]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='mps:0')

In [36]:
len(embed_sequence_batch(batch)[0])

208