In [1]:
import random
import shutil
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
from tqdm import tqdm
import importlib
import math

import torch
import torch.nn as nn
import torch
import torch.cuda
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim
import torch.nn.functional as F 
from IPython.display import clear_output

import boda

In [2]:
def load_model(artifact_path):
    """Unpack a model artifact and return the model it contains.

    An existing ``./artifacts`` directory is removed first, then
    ``boda.common.utils.unpack_artifact`` is called on ``artifact_path`` and
    the model is built from ``./artifacts`` with ``boda.common.utils.model_fn``.

    Args:
        artifact_path: Artifact location, handed unchanged to
            ``boda.common.utils.unpack_artifact``.

    Returns:
        The loaded model in evaluation mode, moved to CUDA when at least one
        CUDA device is reported.
    """
    artifacts_dir = './artifacts'
    cuda_present = torch.cuda.device_count() >= 1

    # Remove whatever a previous unpack left behind before extracting again.
    if os.path.isdir(artifacts_dir):
        shutil.rmtree(artifacts_dir)
    boda.common.utils.unpack_artifact(artifact_path)

    loaded = boda.common.utils.model_fn(artifacts_dir)
    loaded.eval()
    if cuda_present:
        loaded.cuda()
    return loaded


In [3]:
# Constant flanks added around every generated sequence:
# the first 200 elements of MPRA_DOWNSTREAM go on the right,
# the last 200 elements of MPRA_UPSTREAM go on the left.
right_flank = boda.common.constants.MPRA_DOWNSTREAM[:200]
left_flank = boda.common.constants.MPRA_UPSTREAM[-200:]

# Seed Python's `random` module so the generated sequences are reproducible.
random.seed(42)

def generate_random_sequence(length):
    """Return a string of ``length`` characters drawn from 'ATCG'.

    Each character is picked with one ``random.choice`` call, so the result
    depends on the state of Python's global ``random`` generator.
    """
    bases = []
    for _ in range(length):
        bases.append(random.choice('ATCG'))
    return ''.join(bases)

def generate_random_sequences_with_flanks(num_sequences, sequence_length):
    """Build ``num_sequences`` random sequences wrapped in the constant flanks.

    Args:
        num_sequences: How many sequences to generate.
        sequence_length: Length of the random part of each sequence
            (the flanks are added on top of this).

    Returns:
        A list of strings, each ``left_flank + random_sequence + right_flank``
        using the module-level ``left_flank`` and ``right_flank``.
    """
    return [
        left_flank + generate_random_sequence(sequence_length) + right_flank
        for _ in range(num_sequences)
    ]

# Generation settings
N = 10  # how many random sequences to create
sequence_length = 200  # length of the random part, flanks not included

random_sequences_for_diffusion = generate_random_sequences_with_flanks(N, sequence_length)

# Write a one-column file: a header line, then one flanked sequence per line.
tsv_file_path = "random_sequences_for_diffusion.tsv"
with open(tsv_file_path, "w") as tsv_file:
    tsv_file.write("Sequence\n")
    tsv_file.writelines(f"{sequence}\n" for sequence in random_sequences_for_diffusion)

print(f"Sequences saved to {tsv_file_path}")

Sequences saved to random_sequences_for_diffusion.tsv


## Training code

In [4]:
# Alias for the boda data-module class; it is called further down with the
# train/val/test TSV paths, the flanks and the batch size.
data_module = boda.data.SeqDataModule

In [5]:
# Training configuration shared by the cells below.
mut = (0,301) # limit of mutations from 0 to 300
epochs = 200 # specify the number of training epochs
batch_size = 1024 # number of sequences in a batch
batch_per_epoch = 1000 # number of batches in one epoch 
num_workers = 8
lr = 0.001
# Use the first GPU when CUDA is available and fall back to the CPU otherwise,
# matching load_model(), which only moves the model to CUDA when a device exists.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [7]:
def initialize_weights(m):
    """Re-initialise the parameters of a single layer in place.

    * ``nn.Conv1d``: weights ~ N(0, sqrt(2 / (kernel_size * out_channels))),
      bias (if any) set to 0.
    * ``nn.BatchNorm1d``: weight set to 1, bias set to 0.
    * ``nn.Linear``: weights ~ N(0, 0.001), bias (if any) set to 0.

    Any other module type is left untouched.
    """
    if isinstance(m, nn.BatchNorm1d):
        nn.init.constant_(m.weight.data, 1)
        nn.init.constant_(m.bias.data, 0)
        return

    # Conv1d and Linear differ only in the standard deviation of the weights.
    if isinstance(m, nn.Conv1d):
        fan = m.kernel_size[0] * m.out_channels
        weight_std = math.sqrt(2 / fan)
    elif isinstance(m, nn.Linear):
        weight_std = 0.001
    else:
        return

    m.weight.data.normal_(0, weight_std)
    if m.bias is not None:
        nn.init.constant_(m.bias.data, 0)
            
class DataloaderWrapper:
    """Expose a dataloader as epochs of exactly ``batch_per_epoch`` batches.

    A single underlying iterator is kept across epochs and restarted whenever
    it runs out, so iteration continues where the previous epoch stopped and
    an epoch is not tied to the dataloader's own length.

    Args:
        dataloader: Any re-iterable object (``iter(dataloader)`` is called
            every time the current iterator is exhausted).
        batch_per_epoch: Number of batches produced by one pass of
            ``__iter__``; also the value returned by ``len()``.
    """

    def __init__(self, dataloader, batch_per_epoch):
        self.batch_per_epoch = batch_per_epoch
        self.dataloader = dataloader
        self.iterator = iter(dataloader)

    def __len__(self):
        return self.batch_per_epoch

    def _next_batch(self):
        """Return the next batch, restarting the dataloader when it is exhausted.

        Raises:
            StopIteration: if the dataloader yields nothing even right after
                a restart (i.e. it is empty).
        """
        try:
            return next(self.iterator)
        except StopIteration:
            # Restart AND fetch from the fresh iterator; only restarting would
            # hand ``None`` to the caller / silently drop one batch per restart.
            self.iterator = iter(self.dataloader)
            return next(self.iterator)

    def __next__(self):
        return self._next_batch()

    def __iter__(self):
        for _ in range(self.batch_per_epoch):
            try:
                batch = self._next_batch()
            except StopIteration:
                # Empty dataloader: end the epoch instead of letting
                # StopIteration escape the generator (a RuntimeError, PEP 479).
                return
            yield batch

In [8]:
# Location of the packed model artifact; load_model() unpacks it into
# ./artifacts and returns the contained model in eval mode.
malinois_path = 'gs://tewhey-public-data/CODA_resources/malinois_model__20211113_021200__287348.tar.gz'
pretrained_model = load_model(malinois_path)

Copying gs://tewhey-public-data/CODA_resources/malinois_model__20211113_021200__287348.tar.gz...
\ [1 files][ 49.3 MiB/ 49.3 MiB]                                                
Operation completed over 1 objects/49.3 MiB.                                     
archive unpacked in ./


Loaded model from 20211113_021200 in eval mode


In [1]:
# Build the data module from the train/test/val TSV files, using the same
# flank slices as the generation cell above.
# NOTE(review): the recorded output of this cell is a NameError for
# `data_module`, i.e. it was last run before the cell that defines
# `data_module` -- re-run the notebook top to bottom.
# NOTE(review): the file paths are absolute and machine-specific.
data = data_module(
    train_file = "/home/ubuntu/boda2/analysis/AR001__rotation/dummy_train.tsv",
    test_file = "/home/ubuntu/boda2/analysis/AR001__rotation/dummy_test.tsv",
    val_file = "/home/ubuntu/boda2/analysis/AR001__rotation/dummy_val.tsv",
    right_flank = boda.common.constants.MPRA_DOWNSTREAM[:200],
    batch_size = batch_size,
    left_flank = boda.common.constants.MPRA_UPSTREAM[-200:]
)

# TODO(review): confirm the optimizer choice -- AdamW differs from the reference code.
# NOTE(review): the optimizer is built over `pretrained_model.parameters()`;
# verify this is the model that is meant to be trained.
optimizer = torch.optim.AdamW(pretrained_model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()  # sequence reconstruction loss
score_criterion=nn.MSELoss()  # loss on predicted vs. target scores

dl_train = data.train_dataloader()
# TODO: check the tensor dimensions of the batches produced by dl_train.

NameError: name 'data_module' is not defined

In [18]:
## Not edited, but pretrained typo corrected
class Trainer:
    """Runs the train/validate loop for ``model`` and records loss histories.

    Every epoch trains on ``train_dataloader``, validates on
    ``test_dataloader`` (additionally scoring the predicted sequences with
    ``pretrained_model``), plots and saves the histories and writes a
    checkpoint into ``../saved_model/model_epochs_{epochs}``.

    ``validate`` reads two module-level tensors, ``left_s2t`` and
    ``right_s2t``, which must exist before validation runs (they are not
    defined in the part of the notebook visible next to this class).

    Args:
        model: Network being trained; it is called on the mutated batch.
        pretrained_model: Scoring model; used in eval mode under ``no_grad``
            during validation only.
        train_dataloader: Yields ``(target_seq, mutated_seq, _)`` batches.
        test_dataloader: Yields batches with the same layout for validation.
        criterion: Loss between the model output and ``target_seq``.
        loss_criterion: Loss between predicted and target scores; stored as
            ``self.score_criterion``.
        optimizer: Optimizer stepped once per training batch.
        epochs: Number of epochs run by :meth:`training`.
        batch_size: Stored on the instance; not used by the methods here.
        batch_per_epoch: Stored on the instance; not used by the methods here.
        device: Device the batches are moved to.
    """

    def __init__(self,
            model: torch.nn.Module,
            pretrained_model: torch.nn.Module,
            train_dataloader: torch.utils.data.DataLoader,
            test_dataloader: torch.utils.data.DataLoader,
            criterion: torch.nn.Module,
            loss_criterion: torch.nn.Module,
            optimizer: torch.optim.Optimizer,
            epochs: int,
            batch_size: int = 1024,
            batch_per_epoch: int = 1000,
            device = torch.device("cuda:0")
            ):
        self.optimizer = optimizer
        self.criterion = criterion
        self.score_criterion = loss_criterion
        self.model = model
        self.pretrained_model = pretrained_model
        self.train_dl = train_dataloader
        self.test_dl = test_dataloader
        self.epochs = epochs
        self.batch_per_epoch = batch_per_epoch
        self.device = device
        self.batch_size = batch_size
        self.score_cor_mean = []  # per-epoch mean Pearson correlation of scores
        self.score = []           # per-epoch mean score loss

    def train(self, epoch):
        """Run one training pass over ``self.train_dl``; return the mean batch loss."""
        print(f'start training, epoch = {epoch}')
        self.model.train()
        ltr = []
        for _, data in tqdm(enumerate(self.train_dl), mininterval=60):
            target_seq, mutated_seq, _ = data
            target_seq = target_seq.float().to(self.device)
            mutated_seq = mutated_seq.float().to(self.device)
            pred = self.model(mutated_seq)
            loss = self.criterion(pred, target_seq)
            ltr.append(loss.item())
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

        mean_loss = np.mean(ltr)
        return mean_loss

    def validate(self, epoch):
        """Run one validation pass; return the mean sequence loss.

        Also appends the epoch's mean score loss to ``self.score`` and the
        mean Pearson correlation to ``self.score_cor_mean``.
        """
        print(f'start validating, epoch = {epoch}')
        with torch.no_grad():
            self.model.eval()
            self.pretrained_model.eval()
            lte = []
            score_losses = []
            score_cores = []

            for _, data in tqdm(enumerate(self.test_dl), mininterval=60):
                target_seq_val, mutated_seq_val, _ = data
                mutated_seq_val = mutated_seq_val.float().to(self.device)
                target_seq_val = target_seq_val.float().to(self.device)

                # Target score is read from channel 4, position 1 of the input.
                target_score = mutated_seq_val[:, 4, 1].clone()

                # Repeat the constant flanks once per sequence in the batch.
                n_seqs = mutated_seq_val.shape[0]
                left_batch = torch.broadcast_to(
                    left_s2t, (n_seqs, left_s2t.shape[0], left_s2t.shape[1])
                ).to(self.device)
                right_batch = torch.broadcast_to(
                    right_s2t, (n_seqs, right_s2t.shape[0], right_s2t.shape[1])
                ).to(self.device)

                pred = self.model(mutated_seq_val)
                # The loss is taken on the raw model output, exactly as in
                # train(); feeding softmaxed values into the criterion made
                # the validation loss incomparable with the training loss.
                loss = self.criterion(pred, target_seq_val)
                lte.append(loss.item())

                # Probabilities are only needed as input for the scoring model;
                # two all-zero channels are appended before adding the flanks.
                pred_seq = torch.softmax(pred, dim=1)
                seqs = torch.concat(
                    (pred_seq,
                     torch.zeros(pred_seq.shape[0], 2, pred_seq.shape[2], device=self.device)),
                    dim=1,
                )
                long_pred = torch.concat((left_batch, seqs, right_batch), dim=2)
                # NOTE(review): `[1]` indexes the FIRST axis of the scorer's
                # output; confirm the scorer returns a (…, scores) pair and
                # not a (batch, n_outputs) tensor.
                pred_score = self.pretrained_model(long_pred)[1]

                score_loss = self.score_criterion(pred_score, target_score)
                score_losses.append(score_loss.item())
                score_cor = stats.pearsonr(pred_score.cpu().numpy(), target_score.cpu().numpy())[0]
                score_cores.append(score_cor)

            self.score.append(np.mean(score_losses))
            self.score_cor_mean.append(np.mean(score_cores))
            mean_loss_val = np.mean(lte)
            return mean_loss_val

    def training(self):
        """Run all epochs; return ``(train_losses, test_losses, self.score)``."""
        self.save_dir = f"../saved_model/model_epochs_{self.epochs}"
        os.makedirs(self.save_dir, exist_ok=True)
        train_losses = []
        test_losses = []
        for epoch in tqdm(range(self.epochs)):
            tr_loss = self.train(epoch)
            train_losses.append(tr_loss)

            test_loss = self.validate(epoch)
            test_losses.append(test_loss)

            self.plotter(train_losses, test_losses, epoch)
            self.save_model(epoch, train_losses)
        return train_losses, test_losses, self.score

    def _save_histories(self, loss_train, loss_val):
        """Write the four history arrays as ``.npy`` files into ``self.save_dir``."""
        histories = {
            'train_loss.npy': loss_train,
            'test_loss.npy': loss_val,
            'score_loss.npy': self.score,
            # Own file: this history used to be written to score_loss.npy,
            # overwriting the score-loss history saved just before it.
            'score_cor.npy': self.score_cor_mean,
        }
        for file_name, values in histories.items():
            np.save(os.path.join(self.save_dir, file_name), np.array(values))

    def plotter(self, loss_train, loss_val, epoch):
        """Plot the histories, save the figure as PNG and dump the raw arrays."""
        fig, (ax1, ax2, ax3, ax4, ax5) = plt.subplots(5, 1, figsize=(7, 7))

        ax1.plot(loss_train, color='red')
        ax3.plot(loss_train, color='red')
        ax2.plot(loss_val, color='blue')
        ax3.plot(loss_val, color='blue')
        ax4.plot(self.score, color='black')
        ax5.plot(self.score_cor_mean, color='black')
        ax1.grid(axis='x')
        ax2.grid(axis='x')
        ax3.grid(axis='x')
        ax2.set_xlabel('Epoch')
        ax1.set_ylabel('Train Loss')
        ax3.set_ylabel('Train and val Loss')
        ax2.set_ylabel('Val Loss')
        ax4.set_ylabel('Score MSE')
        ax5.set_ylabel('Pearson cor')

        suptitle_string = f'epoch={epoch}'
        fig.suptitle(suptitle_string, y=1.05, fontsize=10)

        pic_test_name = os.path.join(self.save_dir, f"lossestrainandtest_epoch={epoch}.png")
        plt.tight_layout()
        fig.savefig(pic_test_name)
        # Show the figure now and release it; one figure is created per epoch
        # and leaving them all open accumulates memory over a long run.
        plt.show()
        plt.close(fig)

        self._save_histories(loss_train, loss_val)

    def save_model(self, epoch, losseshist):
        """Save model/optimizer state and the loss history as ``model_{epoch}.pth``."""
        PATH = os.path.join(self.save_dir, f"model_{epoch}.pth")

        torch.save({
            'epoch' : epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'loss': losseshist
            }, PATH)

        print(f'---------------  SAVED MODEL {PATH}-------------------')

In [18]:
## Trying to edit
class Trainer:
    """Runs the train/validate loop for ``model`` and records loss histories.

    Every epoch trains on ``train_dataloader``, validates on
    ``test_dataloader`` (additionally scoring the predicted sequences with
    ``pretrained_model``), plots and saves the histories and writes a
    checkpoint into ``../saved_model/model_epochs_{epochs}``.

    ``validate`` reads two module-level tensors, ``left_s2t`` and
    ``right_s2t``, which must exist before validation runs (they are not
    defined in the part of the notebook visible next to this class).

    Args:
        model: Network being trained; it is called on the mutated batch.
        pretrained_model: Scoring model; used in eval mode under ``no_grad``
            during validation only.
        train_dataloader: Yields ``(target_seq, mutated_seq, _)`` batches.
        test_dataloader: Yields batches with the same layout for validation.
        criterion: Loss between the model output and ``target_seq``.
        loss_criterion: Loss between predicted and target scores; stored as
            ``self.score_criterion``.
        optimizer: Optimizer stepped once per training batch.
        epochs: Number of epochs run by :meth:`training`.
        batch_size: Stored on the instance; not used by the methods here.
        batch_per_epoch: Stored on the instance; not used by the methods here.
        device: Device the batches are moved to.
    """

    def __init__(self,
            model: torch.nn.Module,
            pretrained_model: torch.nn.Module,
            train_dataloader: torch.utils.data.DataLoader,
            test_dataloader: torch.utils.data.DataLoader,
            criterion: torch.nn.Module,
            loss_criterion: torch.nn.Module,
            optimizer: torch.optim.Optimizer,
            epochs: int,
            batch_size: int = 1024,
            batch_per_epoch: int = 1000,
            device = torch.device("cuda:0")
            ):
        self.optimizer = optimizer
        self.criterion = criterion
        self.score_criterion = loss_criterion
        self.model = model
        self.pretrained_model = pretrained_model
        self.train_dl = train_dataloader
        self.test_dl = test_dataloader
        self.epochs = epochs
        self.batch_per_epoch = batch_per_epoch
        self.device = device
        self.batch_size = batch_size
        self.score_cor_mean = []  # per-epoch mean Pearson correlation of scores
        self.score = []           # per-epoch mean score loss

    def train(self, epoch):
        """Run one training pass over ``self.train_dl``; return the mean batch loss."""
        print(f'start training, epoch = {epoch}')
        self.model.train()
        ltr = []
        for _, data in tqdm(enumerate(self.train_dl), mininterval=60):
            target_seq, mutated_seq, _ = data
            target_seq = target_seq.float().to(self.device)
            mutated_seq = mutated_seq.float().to(self.device)
            pred = self.model(mutated_seq)
            loss = self.criterion(pred, target_seq)
            ltr.append(loss.item())
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

        mean_loss = np.mean(ltr)
        return mean_loss

    def validate(self, epoch):
        """Run one validation pass; return the mean sequence loss.

        Also appends the epoch's mean score loss to ``self.score`` and the
        mean Pearson correlation to ``self.score_cor_mean``.
        """
        print(f'start validating, epoch = {epoch}')
        with torch.no_grad():
            self.model.eval()
            self.pretrained_model.eval()
            lte = []
            score_losses = []
            score_cores = []

            for _, data in tqdm(enumerate(self.test_dl), mininterval=60):
                target_seq_val, mutated_seq_val, _ = data
                mutated_seq_val = mutated_seq_val.float().to(self.device)
                target_seq_val = target_seq_val.float().to(self.device)

                # Target score is read from channel 4, position 1 of the input.
                target_score = mutated_seq_val[:, 4, 1].clone()

                # Repeat the constant flanks once per sequence in the batch.
                n_seqs = mutated_seq_val.shape[0]
                left_batch = torch.broadcast_to(
                    left_s2t, (n_seqs, left_s2t.shape[0], left_s2t.shape[1])
                ).to(self.device)
                right_batch = torch.broadcast_to(
                    right_s2t, (n_seqs, right_s2t.shape[0], right_s2t.shape[1])
                ).to(self.device)

                pred = self.model(mutated_seq_val)
                # The loss is taken on the raw model output, exactly as in
                # train(); feeding softmaxed values into the criterion made
                # the validation loss incomparable with the training loss.
                loss = self.criterion(pred, target_seq_val)
                lte.append(loss.item())

                # Probabilities are only needed as input for the scoring model;
                # two all-zero channels are appended before adding the flanks.
                pred_seq = torch.softmax(pred, dim=1)
                seqs = torch.concat(
                    (pred_seq,
                     torch.zeros(pred_seq.shape[0], 2, pred_seq.shape[2], device=self.device)),
                    dim=1,
                )
                long_pred = torch.concat((left_batch, seqs, right_batch), dim=2)
                # NOTE(review): `[1]` indexes the FIRST axis of the scorer's
                # output; confirm the scorer returns a (…, scores) pair and
                # not a (batch, n_outputs) tensor.
                pred_score = self.pretrained_model(long_pred)[1]

                score_loss = self.score_criterion(pred_score, target_score)
                score_losses.append(score_loss.item())
                score_cor = stats.pearsonr(pred_score.cpu().numpy(), target_score.cpu().numpy())[0]
                score_cores.append(score_cor)

            self.score.append(np.mean(score_losses))
            self.score_cor_mean.append(np.mean(score_cores))
            mean_loss_val = np.mean(lte)
            return mean_loss_val

    def training(self):
        """Run all epochs; return ``(train_losses, test_losses, self.score)``."""
        self.save_dir = f"../saved_model/model_epochs_{self.epochs}"
        os.makedirs(self.save_dir, exist_ok=True)
        train_losses = []
        test_losses = []
        for epoch in tqdm(range(self.epochs)):
            tr_loss = self.train(epoch)
            train_losses.append(tr_loss)

            test_loss = self.validate(epoch)
            test_losses.append(test_loss)

            self.plotter(train_losses, test_losses, epoch)
            self.save_model(epoch, train_losses)
        return train_losses, test_losses, self.score

    def _save_histories(self, loss_train, loss_val):
        """Write the four history arrays as ``.npy`` files into ``self.save_dir``."""
        histories = {
            'train_loss.npy': loss_train,
            'test_loss.npy': loss_val,
            'score_loss.npy': self.score,
            # Own file: this history used to be written to score_loss.npy,
            # overwriting the score-loss history saved just before it.
            'score_cor.npy': self.score_cor_mean,
        }
        for file_name, values in histories.items():
            np.save(os.path.join(self.save_dir, file_name), np.array(values))

    def plotter(self, loss_train, loss_val, epoch):
        """Plot the histories, save the figure as PNG and dump the raw arrays."""
        fig, (ax1, ax2, ax3, ax4, ax5) = plt.subplots(5, 1, figsize=(7, 7))

        ax1.plot(loss_train, color='red')
        ax3.plot(loss_train, color='red')
        ax2.plot(loss_val, color='blue')
        ax3.plot(loss_val, color='blue')
        ax4.plot(self.score, color='black')
        ax5.plot(self.score_cor_mean, color='black')
        ax1.grid(axis='x')
        ax2.grid(axis='x')
        ax3.grid(axis='x')
        ax2.set_xlabel('Epoch')
        ax1.set_ylabel('Train Loss')
        ax3.set_ylabel('Train and val Loss')
        ax2.set_ylabel('Val Loss')
        ax4.set_ylabel('Score MSE')
        ax5.set_ylabel('Pearson cor')

        suptitle_string = f'epoch={epoch}'
        fig.suptitle(suptitle_string, y=1.05, fontsize=10)

        pic_test_name = os.path.join(self.save_dir, f"lossestrainandtest_epoch={epoch}.png")
        plt.tight_layout()
        fig.savefig(pic_test_name)
        # Show the figure now and release it; one figure is created per epoch
        # and leaving them all open accumulates memory over a long run.
        plt.show()
        plt.close(fig)

        self._save_histories(loss_train, loss_val)

    def save_model(self, epoch, losseshist):
        """Save model/optimizer state and the loss history as ``model_{epoch}.pth``."""
        PATH = os.path.join(self.save_dir, f"model_{epoch}.pth")

        torch.save({
            'epoch' : epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'loss': losseshist
            }, PATH)

        print(f'---------------  SAVED MODEL {PATH}-------------------')

## Generation code

In [9]:
## Mutagenesis code (copied from LegNet)

def mutagenesis(seqs, maxmut):
    """Randomly rearrange the channel values at selected positions of ``seqs``, in place.

    For every sequence in the batch, ``maxmut`` positions are drawn (with
    possible repeats, so at most ``maxmut`` distinct positions are marked).
    At each marked position the channel values are re-ordered according to a
    randomly chosen row of the module-level ``ALLPERM``.

    Args:
        seqs: Tensor indexed as (batch, channel, position); it is modified in
            place. The ``// 4`` below assumes 4 channels -- TODO confirm for
            all callers.
        maxmut: Number of position draws per sequence.

    Returns:
        None. The result is written into ``seqs``.

    NOTE(review): ``ALLPERM`` is not defined in the visible part of the
    notebook; presumably each row holds a re-ordering of the indices 0..3.
    """
    batchsize = seqs.shape[0]
    seqlen = seqs.shape[2]
    # Every sequence gets the same budget, so `muts > i` is True for all i
    # in range(maxmut) and every draw below marks a position.
    muts = torch.full((batchsize,), maxmut)
    index = torch.arange(batchsize)
    mut_positions = torch.zeros(batchsize, seqlen, dtype=bool)
    for i in range(maxmut):
        # One random position per sequence in this round.
        single_positions = torch.randint(high=seqlen, size=(batchsize,))
        mut_positions[index, single_positions] |= muts > i

    # Expand the (batch, position) mask over the channel axis.
    mut_positions = mut_positions[:,None,:].broadcast_to(seqs.shape)
    # With channels moved last, the selected values come out as a flat vector
    # in which the channel values of one marked position are adjacent.
    x = seqs.permute(2, 0, 1)[mut_positions.permute(2, 0, 1)]
    mut_number = x.shape[0] // 4
    
    # Pick one ALLPERM row per marked position and shift it by 4*k so it
    # indexes the k-th group of four values inside x.
    myperm = torch.randint(high=ALLPERM.shape[0], size=(mut_number,))
    myperm = (ALLPERM[myperm] + torch.arange(mut_number)[:,None] * 4).ravel()
    # Write the re-ordered values back through the permuted view of seqs.
    seqs.permute(2, 0, 1)[mut_positions.permute(2, 0, 1)] = x[myperm]

In [10]:
# diffusion-like sampling

def predict_float(dl_test, mut_interval, intensities, start, end):
    """Iteratively mutate and denoise sequences towards random target scores.

    For every batch a target score per sequence is drawn uniformly from
    ``[start, end)``. Then, for each ``(intensity, n_mutations)`` pair, the
    current sequences are mutated with ``mutagenesis``, extended with a
    target-score channel and an intensity channel, passed through the
    module-level ``diffusion_model`` and soft-maxed over the channel axis.

    Uses the module-level ``diffusion_model``, ``device`` and ``mutagenesis``.

    Args:
        dl_test: Iterable of batch tensors indexed as
            (batch, channel, position); channels 0-3 are used as the
            sequence and channel 4 only provides the shape of the
            intensity channel.
        mut_interval: Number of mutation draws for each step.
        intensities: Value that fills the intensity channel at each step;
            paired element-wise with ``mut_interval``.
        start: Lower bound of the uniform target-score distribution.
        end: Upper bound of the uniform target-score distribution.

    Returns:
        ``(seqs_batches, target_score, scores_batches)`` where
        ``seqs_batches`` and ``scores_batches`` are lists with one numpy array
        per batch and ``target_score`` is the target tensor of the LAST batch
        (``None`` when ``dl_test`` yields no batches).
    """
    seqs_batches = []
    scores_batches = []
    target_score = None  # stays None if the loader is empty
    with torch.no_grad():
        diffusion_model.eval()
        for data in dl_test:
            seq_batch = data.float().to(device)
            score_chanels = seq_batch[:, 4:5, :].clone().to(device)
            seq_batch = seq_batch[:, :4, :]
            target_score = torch.FloatTensor(seq_batch.shape[0], 1, 1).uniform_(start, end).to(device)
            for intens, muts in zip(intensities, mut_interval):
                mutagenesis(seq_batch, muts)
                # Spread each target score along the actual sequence length
                # (previously hard-coded to 80, which only worked for
                # 80-position inputs).
                tmp = torch.broadcast_to(
                    target_score, (target_score.shape[0], 1, seq_batch.shape[2])
                )
                seq_batch = torch.concat(
                    (seq_batch.to(device),
                     tmp.to(device),
                     torch.full_like(score_chanels, intens).to(device)),
                    dim=1,
                )
                seq_batch = diffusion_model(seq_batch)
                seq_batch = torch.softmax(seq_batch, dim=1)
            seqs_batches.append(seq_batch.cpu().numpy())
            # NOTE: squeeze() yields a 0-d array for a batch of one sequence.
            scores_batches.append((target_score.squeeze()).cpu().numpy())
        return seqs_batches, target_score, scores_batches