In [1]:
%load_ext autoreload
%autoreload 2
%reload_ext autoreload

import os
import sys

import torch
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

from mpra import MPRA_Collection, MPRA_Paper, MPRA_Dataset

In [2]:
# Root folder handed to MPRA_Collection.
folder_mpra = '/data/tuxm/project/MPRA-collection/data/mpra_test/'
mpra_collection = MPRA_Collection(folder_mpra)

# First list the papers, one per line ...
paper_names = mpra_collection.list_papers()
print('\n'.join(paper_names))

# ... then every entry of list_datasets() as "<key>: <value>".
dataset_lines = []
for paper_name, dataset_names in mpra_collection.list_datasets().items():
    dataset_lines.append(f'{paper_name}: {dataset_names}')
print('\n'.join(dataset_lines))

Nature_2022_Regev
GenomeResearch_2017_Seelig
Nature_2022_Regev: ['train_complex', 'train_defined', 'test_complex', 'test_defined']
GenomeResearch_2017_Seelig: ['random', 'native']


In [3]:
mpra_dataset = mpra_collection.get_dataset('GenomeResearch_2017_Seelig', 'random')

# After load(), mpra_dataset carries:
##   info: dict()
##   data: pd.DataFrame()
mpra_dataset.load()
mpra_dataset.print_info()
print('len(mpra_dataset):', len(mpra_dataset))
print('mpra_dataset.data.shape:', mpra_dataset.data.shape)

# After init_XYobs(), mpra_dataset additionally carries:
##   _X: torch.Tensor()
##   _Y: pd.DataFrame()
##   X: torch.Tensor()
##   Y: torch.Tensor()
mpra_dataset.init_XYobs()
for shape_label, attr_name in (
    ('mpra_dataset._X.shape:', '_X'),
    ('mpra_dataset._Y.shape:', '_Y'),
    ('mpra_dataset.X.shape:', 'X'),
    ('mpra_dataset.Y.shape:', 'Y'),
):
    # Attributes are looked up one at a time, in the same order as they are printed.
    print(shape_label, getattr(mpra_dataset, attr_name).shape)

==== ==== ==== ====
Description: Random 5' UTRs
MPRA Technique: Classic MPRA
Readout Assay: RNA-seq
Regulatory Element: 5' UTR
Sequence Origin: Random
Species: Yeast
==== ==== ==== ====
len(mpra_dataset): 489348
mpra_dataset.data.shape: (489348, 4)
mpra_dataset._X.shape: torch.Size([489348, 4, 64])
mpra_dataset._Y.shape: (489348, 1)
mpra_dataset.X.shape: torch.Size([489348, 4, 64])
mpra_dataset.Y.shape: torch.Size([489348, 1])


In [4]:
# Seeded random split; the three fractions map to train / valid / infer in that order.
split_fracs = [0.8, 0.1, 0.1]
splits = mpra_dataset.split_rand(fracs=split_fracs, seed=20240404)
split_train, split_valid, split_infer = splits

# Layout of every split: index 0 = X, 1 = Y, 2 = obsX, 3 = obsY.
train_X = split_train[0]
valid_Y = split_valid[1]
print('split_train[0].shape: ', train_X.shape)
print('split_valid[1].shape: ', valid_Y.shape)

split_train[0].shape:  torch.Size([391480, 4, 64])
split_valid[1].shape:  torch.Size([48934, 1])


In [5]:
from torch.utils.data import TensorDataset, DataLoader

batch_size = 64
n_workers = 4


def _make_dataloader(split, shuffle):
    """Wrap the X (index 0) and Y (index 1) tensors of ``split`` in a DataLoader."""
    pairs = TensorDataset(split[0], split[1])
    return DataLoader(pairs, batch_size=batch_size, num_workers=n_workers, shuffle=shuffle)


# Only the training loader reshuffles; validation and inference keep their order.
dataloader_train = _make_dataloader(split_train, shuffle=True)
dataloader_valid = _make_dataloader(split_valid, shuffle=False)
dataloader_infer = _make_dataloader(split_infer, shuffle=False)

In [6]:
# Define LightningModule

import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning import LightningModule

def ConvBlock(
    in_channels,
    out_channels,
    kernel_size,
    padding = 'same',
    stride = 1,
    dilation = 1,
    bias = True,
    batchnorm = True,
    activation = 'relu',
):
    """Build an ``nn.Sequential`` of [BatchNorm1d] -> Conv1d -> [activation].

    Args:
        in_channels: Channels of the input; also the BatchNorm1d feature count.
        out_channels: Channels produced by the convolution.
        kernel_size: Convolution kernel size.
        padding: Passed straight to ``nn.Conv1d``.
        stride: Passed straight to ``nn.Conv1d``.
        dilation: Passed straight to ``nn.Conv1d``.
        bias: Whether the convolution has a bias term.
        batchnorm: If true, a BatchNorm1d over ``in_channels`` is placed
            *before* the convolution.
        activation: One of ``'relu'``, ``'leaky_relu'``, ``'gelu'`` or
            ``'none'`` (no activation layer).

    Returns:
        nn.Sequential containing the selected layers in order.

    Raises:
        NotImplementedError: If ``activation`` is not one of the names above.
    """
    stages = []
    if batchnorm:
        stages.append(nn.BatchNorm1d(in_channels))
    stages.append(
        nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride = stride,
            padding = padding,
            dilation = dilation,
            bias = bias,
        )
    )

    # (name, layer class) pairs; None means "append nothing".
    known_activations = (
        ('relu', nn.ReLU),
        ('leaky_relu', nn.LeakyReLU),
        ('gelu', nn.GELU),
        ('none', None),
    )
    for known_name, layer_cls in known_activations:
        if activation == known_name:
            if layer_cls is not None:
                stages.append(layer_cls())
            break
    else:
        # No name matched.
        raise NotImplementedError

    return nn.Sequential(*stages)

# from https://github.com/boxiangliu/enformer-pytorch/blob/main/model/enformer.py
class Residual(nn.Module):
    """Skip-connection wrapper: ``forward(x)`` returns ``x + module(x)``.

    Extra positional and keyword arguments are forwarded to the wrapped
    module unchanged.
    """

    def __init__(self, module):
        super().__init__()
        # Registered as a sub-module under the name `_module`.
        self._module = module

    def forward(self, x, *args, **kwargs):
        branch_out = self._module(x, *args, **kwargs)
        return x + branch_out

class ResidualCNN(LightningModule):
    """Residual 1D CNN mapping a 4-channel sequence tensor to ``channels_out`` readouts.

    Layout: ``conv_in`` -> ``n_layers_conv`` stages (each ``n_convs_per_layer``
    residual ConvBlocks followed by optional down-sampling) -> flatten ->
    ``linear_in`` -> ``n_layers_linear`` residual linear blocks -> ``linear_out``.

    Args:
        len_seq: Length of the input sequences (last axis of ``x``); needed to
            size ``linear_in``.
        channels_conv: Channel width of every convolution.
        channels_linear: Width of the fully connected layers.
        names_readout: One name per output column; used to label the
            ``valid_corr_<name>`` / ``infer_corr_<name>`` metrics.
        channels_out: Output width of ``linear_out``.
        channels_z: Width of the optional auxiliary input ``z``. ``z`` is only
            picked out of a batch when this is > 0, and ``forward`` currently
            does not use it.
        kernel_size: Kernel size of the residual convolutions (and of the
            strided convolution when ``pooling_type='conv'``).
        kernel_size_in: Kernel size of the input convolution, padded with
            ``(kernel_size_in - 1) // 2``.
        n_layers_conv: Number of convolutional stages.
        n_convs_per_layer: Residual ConvBlocks per stage.
        n_layers_linear: Number of residual linear blocks.
        padding, stride, dilation: Settings of the residual convolutions. They
            must leave the sequence length unchanged, otherwise the residual
            sum ``x + conv(x)`` cannot be formed.
        bias, batchnorm, activation: Forwarded to ``ConvBlock``.
        pooling_type: ``'max'``, ``'avg'``, ``'conv'`` (strided ConvBlock) or
            ``'none'``.
        pooling_size: Pooling kernel size / stride of the strided convolution.
        dropout: Dropout probability inside the residual linear blocks.
        loss_type: ``'mse'`` or ``'cross_entropy'``.
        optimizer_type: Only ``'Adam'`` is implemented.
        lr: Learning rate; the cyclic schedule runs between ``lr / 5`` and ``lr * 5``.
        wd: Weight decay handed to the optimizer.
        scheduler_type: Only ``'Cyclic'`` is implemented.

    Raises:
        NotImplementedError: For an unsupported ``pooling_type``, ``loss_type``,
            ``optimizer_type`` or ``scheduler_type``.
    """

    def __init__(self,
        len_seq,
        channels_conv = 256,
        channels_linear = 256,
        names_readout = [''],
        channels_out = 1,
        # channels_z = 0,
        channels_z = 1,
        kernel_size = 3,
        kernel_size_in = 13,
        n_layers_conv = 3,
        n_convs_per_layer = 3,
        n_layers_linear = 3,
        padding = 1,
        stride = 1,
        dilation = 1,
        bias = True,
        batchnorm = False,
        activation = 'relu',
        pooling_type = 'max',
        pooling_size = 4,
        dropout = 0.1,
        loss_type = 'mse',
        optimizer_type = 'Adam',
        lr = 1e-4,
        wd = 1e-5,
        scheduler_type = 'Cyclic',
    ):
        super().__init__()
        self.save_hyperparameters()

        # Copy so the instance never aliases the shared default list.
        self.names_readout = list(names_readout)
        self.channels_z = channels_z

        self.conv_in = ConvBlock(
            in_channels = 4,
            out_channels = channels_conv,
            kernel_size = kernel_size_in,
            padding = (kernel_size_in - 1) // 2,
            stride = 1,
            dilation = 1,
            bias = bias,
            batchnorm = batchnorm,
            activation = activation,
        )
        self.convs = nn.ModuleList()

        # Sequence length still present after each stage; sizes `linear_in`.
        channels_len = len_seq
        for _ in range(n_layers_conv):
            for _ in range(n_convs_per_layer):
                self.convs.append(
                    Residual(
                        ConvBlock(
                            in_channels = channels_conv,
                            out_channels = channels_conv,
                            kernel_size = kernel_size,
                            padding = padding,
                            stride = stride,
                            dilation = dilation,
                            bias = bias,
                            batchnorm = batchnorm,
                            activation = activation,
                        )
                    )
                )
            if pooling_type == 'max':
                self.convs.append(nn.MaxPool1d(kernel_size=pooling_size, ceil_mode=True))
            elif pooling_type == 'avg':
                self.convs.append(nn.AvgPool1d(kernel_size=pooling_size, ceil_mode=True))
            elif pooling_type == 'conv':
                self.convs.append(
                    ConvBlock(
                        in_channels = channels_conv,
                        out_channels = channels_conv,
                        kernel_size = kernel_size,
                        padding = padding,
                        stride = pooling_size,
                        dilation = dilation,
                        bias = bias,
                        batchnorm = batchnorm,
                        activation = activation,
                    )
                )
            elif pooling_type == 'none':
                pass
            else:
                raise NotImplementedError(
                    'Unsupported pooling_type: {!r}'.format(pooling_type)
                )

            # The residual convolutions keep the length, so only the
            # down-sampling step changes it.
            channels_len = self._pooled_length(
                channels_len, pooling_type, pooling_size, kernel_size, padding, dilation
            )

        self.linear_in = nn.Linear(channels_conv * channels_len, channels_linear)
        self.linears = nn.ModuleList()
        for _ in range(n_layers_linear):
            self.linears.append(
                Residual(
                    nn.Sequential(
                        nn.Linear(channels_linear, channels_linear),
                        nn.ReLU(),
                        nn.Dropout(dropout),
                    )
                )
            )
        self.linear_out = nn.Linear(channels_linear, channels_out)

        if loss_type == 'mse':
            self.criterion = nn.MSELoss()
        elif loss_type == 'cross_entropy':
            self.criterion = nn.CrossEntropyLoss()
        else:
            raise NotImplementedError('Unsupported loss_type: {!r}'.format(loss_type))

        if optimizer_type == 'Adam':
            self.optimizer = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=wd)
        else:
            raise NotImplementedError(
                'Unsupported optimizer_type: {!r}'.format(optimizer_type)
            )

        if scheduler_type == 'Cyclic':
            # Adam has no `momentum` parameter to cycle, hence cycle_momentum=False.
            self.scheduler = torch.optim.lr_scheduler.CyclicLR(
                self.optimizer,
                base_lr = lr / 5,
                max_lr = lr * 5,
                cycle_momentum = False,
            )
        elif scheduler_type == 'Plateau':
            raise NotImplementedError("scheduler_type 'Plateau' is not implemented yet")
        else:
            print('NO Scheduler')
            raise NotImplementedError(
                'Unsupported scheduler_type: {!r}'.format(scheduler_type)
            )

        # Per-batch (y_pred, y_targ) pairs collected during validation / test.
        self.validation_step_outputs = []
        self.test_step_outputs = []
        # Concatenated predictions / targets of the most recent test run.
        self.y_pred = None
        self.y_targ = None

    @staticmethod
    def _pooled_length(length, pooling_type, pooling_size, kernel_size, padding, dilation):
        """Return the sequence length left after one down-sampling step."""
        if pooling_type in ('max', 'avg'):
            # Pooling uses stride == kernel_size, no padding and ceil_mode=True,
            # so a trailing partial window still yields an output: ceil(length / size).
            return (length + pooling_size - 1) // pooling_size
        if pooling_type == 'conv':
            # Conv1d output-length formula with stride == pooling_size.
            return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // pooling_size + 1
        # 'none': nothing is appended, the length is unchanged.
        return length

    def forward(self, x, z = None):
        """Run the network on ``x`` and return the ``linear_out`` activations.

        ``z`` is accepted so callers can keep passing it, but it does not
        take part in the computation.
        """
        x = self.conv_in(x)
        for conv in self.convs:
            x = conv(x)
        # Collapse (channels, length) into one feature axis for the linear head.
        x = x.view(x.shape[0], -1)

        x = self.linear_in(x)
        for linear in self.linears:
            x = linear(x)
        return self.linear_out(x)

    def _forward_batch(self, batch):
        """Predict for a batch laid out as ``(x, [y, [z, ...]])`` or given as a bare tensor."""
        if isinstance(batch, torch.Tensor):
            return self(batch)
        x = batch[0]
        z = batch[2] if len(batch) > 2 and self.channels_z > 0 else None
        return self(x, z=z)

    @staticmethod
    def _gather(step_outputs):
        """Concatenate per-batch ``(y_pred, y_targ)`` pairs along the batch axis."""
        preds, targs = zip(*step_outputs)
        return torch.cat(preds, dim=0), torch.cat(targs, dim=0)

    def _readout_correlations(self, y_pred, y_targ):
        """Return ``[(name, pearson_r), ...]``, one entry per readout column."""
        return [
            (name, torch.corrcoef(torch.stack([y_pred[:, k], y_targ[:, k]]))[0, 1])
            for k, name in enumerate(self.names_readout)
        ]

    def training_step(self, batch, batch_idx):
        """Return the training loss for one batch."""
        y_targ = batch[1]
        y_pred = self._forward_batch(batch)
        return self.criterion(y_pred, y_targ)

    def validation_step(self, batch, batch_idx):
        """Store predictions and targets; metrics are computed at epoch end."""
        self.validation_step_outputs.append((self._forward_batch(batch), batch[1]))

    def on_validation_epoch_end(self):
        """Log ``valid_loss`` and one ``valid_corr_<name>`` per readout."""
        y_pred, y_targ = self._gather(self.validation_step_outputs)
        self.log('valid_loss', self.criterion(y_pred, y_targ))

        for name, corr in self._readout_correlations(y_pred, y_targ):
            self.log('valid_corr_{}'.format(name), corr)

        self.validation_step_outputs = []

    def test_step(self, batch, batch_idx):
        """Store predictions and targets; metrics are computed at epoch end."""
        self.test_step_outputs.append((self._forward_batch(batch), batch[1]))

    def on_test_epoch_end(self):
        """Log ``infer_loss``, per-readout ``infer_corr_<name>`` and their mean ``infer_corr``."""
        y_pred, y_targ = self._gather(self.test_step_outputs)
        # Kept on the module so the notebook can inspect them after trainer.test().
        self.y_pred, self.y_targ = y_pred, y_targ

        dict_infer = {'infer_loss': self.criterion(y_pred, y_targ)}
        correlations = self._readout_correlations(y_pred, y_targ)
        for name, corr in correlations:
            dict_infer['infer_corr_{}'.format(name)] = corr
        if correlations:
            # Skip the mean when there are no readouts (avoids dividing by zero).
            dict_infer['infer_corr'] = sum(corr for _, corr in correlations) / len(correlations)
        self.log_dict(dict_infer)

        self.test_step_outputs = []

    def predict_step(self, batch, batch_idx, dataloader_idx = None):
        """Return predictions for one batch; the batch does not need to contain targets."""
        return self._forward_batch(batch)

    def configure_optimizers(self):
        """Hand Lightning the optimizer and a scheduler that is stepped every batch."""
        return [self.optimizer], [{'scheduler': self.scheduler, 'interval': 'step'}]


In [7]:
# TODO: make these more readable
# Sequence length = size of axis 2 of the raw X tensor.
len_seq = mpra_dataset._X.shape[2]

# Readout names are the Y names minus their first three characters
# (presumably a fixed prefix -- confirm against list_Y_names()).
names_readout = []
for y_name in mpra_dataset.list_Y_names():
    names_readout.append(y_name[3:])

# One conv stage with five residual blocks, two residual linear blocks;
# every other hyperparameter keeps its ResidualCNN default.
model_kwargs = dict(
    len_seq = len_seq,
    names_readout = names_readout,
    channels_out = len(names_readout),
    n_layers_conv = 1,
    n_convs_per_layer = 5,
    n_layers_linear = 2,
    pooling_size = 4,
)
module = ResidualCNN(**model_kwargs)

In [8]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

# Both callbacks watch the `valid_loss` metric and treat lower as better.
callback_ES = EarlyStopping(patience=5, monitor='valid_loss', mode='min')
callback_MC = ModelCheckpoint(save_top_k=1, save_last=True, monitor='valid_loss', mode='min')

trainer_options = dict(
    max_epochs = 10000,
    callbacks = [callback_ES, callback_MC],
    accelerator = 'cuda',
    # NOTE(review): with check_val_every_n_epoch=None the interval is expected to be
    # counted in training batches across epochs -- confirm against the Trainer docs.
    val_check_interval = 100,
    check_val_every_n_epoch = None,
    devices = [2],
    enable_progress_bar = False,
    enable_model_summary = False,
)
trainer = Trainer(**trainer_options)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [9]:
# NOTE(review): `module` was already constructed in an earlier cell, so these
# seeds are set after its weights were initialised.
__SEED__ = 42
for seed_fn in (np.random.seed, torch.manual_seed):
    seed_fn(__SEED__)

print('==== trainer.fit(module, dataloader_train, dataloader_valid) ====')
trainer.fit(module, dataloader_train, dataloader_valid)

# Replace the in-memory model with the checkpoint the ModelCheckpoint callback ranked best.
best_ckpt_path = callback_MC.best_model_path
module = type(module).load_from_checkpoint(best_ckpt_path)

print('==== trainer.test(module, dataloader_infer) ====')
results = trainer.test(module, dataloader_infer)

You are using a CUDA device ('NVIDIA RTX A4000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Missing logger folder: /homes/gws/ylsheng/mpra_test/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


==== trainer.fit(module, dataloader_train, dataloader_valid) ====


You are using a CUDA device ('NVIDIA RTX A4000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


==== trainer.test(module, dataloader_infer) ====
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       infer_corr           0.6433966755867004
     infer_corr_expr        0.6433966755867004
       infer_loss           0.7691899538040161
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
