In [1]:
#!/usr/bin/env python
# Copyright 2023 Z Zhang

# BioSeq2Seq, Version 1.0;
# you may not use this file except in compliance with the License.
# Use of this code requires following originality guidelines
# and declaring the source of the code.
# email:zhichunli@mail.dlut.edu.cn
# =========================================================================
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
print("CUDA_VISIBLE_DEVICES =", os.environ.get("CUDA_VISIBLE_DEVICES"))
from re import I
import numpy as np
from model.model_MambaHM import MambaHM
from utils.functions import split_based_num
from utils.dataloader import RD_dataloader
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import LambdaLR
from tqdm import tqdm
from typing import Optional
from torchmetrics import Metric
from torch.utils.tensorboard import SummaryWriter
import shutil
from einops import rearrange

import torch
DEVICE=torch.device("cuda:0")

print("CUDA_VISIBLE_DEVICES =", os.environ["CUDA_VISIBLE_DEVICES"])
print("Torch sees", torch.cuda.device_count(), "GPUs")
print("Device 0 name:", torch.cuda.get_device_name(0))

CUDA_VISIBLE_DEVICES = 0


  def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
  def backward(ctx, dout):
  def forward(
  def backward(ctx, dout, *args):
  def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
  def backward(ctx, grad_output):
  def forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu",
  def backward(ctx, dout, *args):


CUDA_VISIBLE_DEVICES = 0
Torch sees 1 GPUs
Device 0 name: NVIDIA GeForce RTX 3090


Learning Rate Control

In [2]:
def lr_lambda(epoch):
    if epoch < 10:
        return 1.0  # 前10个epoch不变
    elif epoch < 20:
        return 0.5 ** (epoch - 10)  # 中期慢慢衰减
    else:
        return (0.5 ** 10) * (0.3 ** (epoch - 20))  # 后期快速衰减

Evaluation Metric

In [3]:
class MeanPearsonCorrCoefPerChannel(Metric):
    is_differentiable: Optional[bool] = False
    higher_is_better: Optional[bool] = True
    def __init__(self, n_channels:int, dist_sync_on_step=False):
        """Calculates the mean pearson correlation across channels aggregated over regions"""
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.reduce_dims=(0, 1)
        self.add_state("product", default=torch.zeros(n_channels, dtype=torch.float32), dist_reduce_fx="sum", )
        self.add_state("true", default=torch.zeros(n_channels, dtype=torch.float32), dist_reduce_fx="sum", )
        self.add_state("true_squared", default=torch.zeros(n_channels, dtype=torch.float32), dist_reduce_fx="sum", )
        self.add_state("pred", default=torch.zeros(n_channels, dtype=torch.float32), dist_reduce_fx="sum", )
        self.add_state("pred_squared", default=torch.zeros(n_channels, dtype=torch.float32), dist_reduce_fx="sum", )
        self.add_state("count", default=torch.zeros(n_channels, dtype=torch.float32), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        assert preds.shape == target.shape

        self.product += torch.sum(preds * target, dim=self.reduce_dims)
        self.true += torch.sum(target, dim=self.reduce_dims)
        self.true_squared += torch.sum(torch.square(target), dim=self.reduce_dims)
        self.pred += torch.sum(preds, dim=self.reduce_dims)
        self.pred_squared += torch.sum(torch.square(preds), dim=self.reduce_dims)
        self.count += torch.sum(torch.ones_like(target), dim=self.reduce_dims)

    def compute(self):
        true_mean = self.true / self.count
        pred_mean = self.pred / self.count

        covariance = (self.product
                    - true_mean * self.pred
                    - pred_mean * self.true
                    + self.count * true_mean * pred_mean)

        true_var = self.true_squared - self.count * torch.square(true_mean)
        pred_var = self.pred_squared - self.count * torch.square(pred_mean)
        tp_var = torch.sqrt(true_var) * torch.sqrt(pred_var)
        correlation = covariance / tp_var
        return correlation
    

def convert_resolution(target, window_width, aim_resolution):
    k = aim_resolution // window_width
    target = rearrange(target, 'b (r n) d -> b r n d', n = k)
    target = torch.mean(target, dim=2)
    return target

1. Data Preparation

In [None]:
reference_genome_file = 'hg19.fa'

sequence_data_file = [
                        [
                        ['K562_ATAC.bw'],
                         ]
                        
                        ]

target_seq_file = [
    [   
        'H3K122ac.bigWig',
        'H3K4me1.bigWig', 
        'H3K4me2.bigWig', 
        'H3K4me3.bigWig', 
        'H3K27ac.bigWig', 
        'H3K27me3.bigWig', 
        'H3K36me3.bigWig', 
        'H3K9ac.bigWig', 
        'H3K9me3.bigWig', 
        'H4K20me1.bigWig', 
    ],
    
    ]

include_chr = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8',
              'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17',
              'chr18', 'chr19', 'chr20', 'chr21', 'chrX']
blacklist_file = 'hg19Blacklist.bed'


outdir = 'outdir'
if not os.path.exists(outdir):
    os.mkdir(outdir)
log_dir = os.path.join(outdir, 'log/')
model_save = os.path.join(outdir, 'model.pth')



train_samples = np.loadtxt('MambaHM/samples/train.bed', dtype=str, delimiter='\t')
validation_samples = np.loadtxt('MambaHM/samples/valid.bed', dtype=str, delimiter='\t')



2. Model Building

In [8]:
lr = 1e-4
window_width = 16

batch_size = 4
steps_per_epoch = 20 * batch_size
num_epochs = 1
max_avg_pearson = 0

model = MambaHM(channels=384,
            num_heads=8,
            num_transformer_layers=3,
            pooling_type='max',
            output_channels=len(target_seq_file[0]),
            target_length=7168,
            device=DEVICE
            ).to(DEVICE)

model = torch.nn.DataParallel(model, device_ids=[0])

optimizer = torch.optim.Adam(model.parameters(), lr)
scheduler = LambdaLR(optimizer, lr_lambda)
criterion = torch.nn.MSELoss(reduction='mean')


In [9]:
for epoch_i in range(num_epochs):
    print('epoch: ', epoch_i)
    samples, _ = split_based_num(train_samples, steps_per_epoch)

    train_dataset = RD_dataloader(samples,
                                reference_genome_file,
                                sequence_data_file,
                                target_seq_file,
                                window_width=window_width,
                                extend=40960,
                                nan=0,
                                rc=False,
                                )
    data_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

    metric_train = MeanPearsonCorrCoefPerChannel(n_channels=10)
    loss_train = 0
    data_iter = iter(data_loader)
    model.train()
    print('lr: ', optimizer.state_dict()['param_groups'][0]['lr'])
    for i in tqdm(range(steps_per_epoch // batch_size), ascii=True):
        optimizer.zero_grad()
        data_item = next(data_iter)
        dna = data_item[0].to(DEVICE)
        seq = data_item[1].to(DEVICE)
        target = data_item[2].to(DEVICE)
        seq_id = data_item[-1]

        outputs = model(dna, seq)
        loss = criterion(target, outputs)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.)
        optimizer.step()
        loss_train += loss.item()
        target_t = target
        predicted_t = outputs
        metric_train.update(target_t.detach().cpu(), predicted_t.detach().cpu())
    
    scheduler.step()

epoch:  0
lr:  0.0001


100%|##########| 20/20 [00:37<00:00,  1.89s/it]


3. Validation

In [10]:
eva_data_loader = RD_dataloader(validation_samples,
                            reference_genome_file,
                            sequence_data_file,
                            target_seq_file,
                            window_width=window_width,
                            extend=40960,  
                            nan=0,
                            valid=True,
                            rc=False,
                            )
eva_dataset = DataLoader(dataset=eva_data_loader, batch_size=1, shuffle=True)

print('-' * 50)

loss_test = 0
metric = MeanPearsonCorrCoefPerChannel(n_channels=10)
model.eval()
# model.set_eval()
with torch.no_grad():
    for i, eva_data_item in tqdm(enumerate(eva_dataset), ascii=True):
        if i > len(validation_samples):
            break
        dna = eva_data_item[0].to(DEVICE)
        seq = eva_data_item[1].to(DEVICE)
        target = eva_data_item[2].to(DEVICE)
        predicted = model(dna, seq)
        target_resolution = convert_resolution(target, window_width, 1024).detach().cpu()
        predicted_resolution = convert_resolution(predicted, window_width, 1024).detach().cpu()
        loss_eva = criterion(target, predicted).mean()
        loss_test += loss_eva.item()
        
        target_t = target_resolution
        predicted_t = predicted_resolution
        metric.update(target_t, predicted_t)
print(metric.compute())

--------------------------------------------------


2655it [09:21,  4.73it/s]

tensor([0.4676, 0.4001, 0.5013, 0.4449, 0.4394, 0.2909, 0.1341, 0.4185, 0.1705,
        0.4461])



