In [1]:
import os
import argparse
from argparse import Namespace
import pathlib
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import esm
from esm.data import *
from esm.model.esm2_secondarystructure import ESM2 as ESM2_SISS
from esm.model.esm2_supervised import ESM2
from esm import Alphabet, FastaBatchedDataset, ProteinBertModel, pretrained, MSATransformer
from esm.modules import ConvTransformerLayer

import numpy as np
import pandas as pd
import random
import math
import scipy.stats as stats
from scipy.stats import spearmanr, pearsonr
from sklearn.metrics import r2_score, f1_score, roc_auc_score, mean_squared_error, mean_absolute_error
from sklearn import preprocessing
from copy import deepcopy
from tqdm import tqdm, trange
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn import preprocessing
from sklearn.metrics import confusion_matrix
from sklearn.metrics import roc_auc_score, auc
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import classification_report
from sklearn.utils import class_weight
from torch.utils.data import RandomSampler, SequentialSampler

from collections import Counter
# os.chdir('/root/devdata/pansc/github/UTR_Insight')

# NOTE(review): a `global` statement at module level has no effect; it neither
# creates nor initialises these names. Some of them are assigned in later cells.
global layers, heads, embed_dim, batch_toks, cnn_layers, alphabet

# NOTE(review): wildcard import. Names used further down without an explicit
# import (e.g. MultiheadAttention, ESM1bLayerNorm, ESM1LayerNorm) presumably
# come from here; explicit imports would make that visible.
from esm.modules import * 

import warnings
# Suppresses every warning category for the rest of the session.
warnings.filterwarnings("ignore")

In [2]:
# Module-level hyper-parameters shared by later cells.
# NOTE(review): MainModel's own default is embed_dim=128, not the value set here.
layers = 6
heads = 16
embed_dim =64
# batch_toks = 4096*2 #4096
# Expose only GPU 0 to this process (takes effect only if set before CUDA is initialised).
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# Fall back to CPU when no CUDA device is available.
device = torch.device(f'cuda:0' if torch.cuda.is_available() else 'cpu')
# Indices of the backbone representations to request: 0 and the final layer.
repr_layers = [0, layers]

In [6]:
class DynamicFusionGate(nn.Module):
    """Fuse a Transformer branch and a CNN branch with attention-style weights.

    Queries are projected from the Transformer branch, keys from the
    concatenation of both branches.  The resulting soft-max weights ``A`` mix
    the Transformer values, while ``1 - A`` mixes the CNN values; the sum is
    passed through a final linear projection.

    Args:
        embed_dim: Channel size ``C`` of both branches.
        num_heads: Number of heads; must divide ``embed_dim``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``embed_dim``.
    """
    def __init__(self, embed_dim, num_heads=4):
        super().__init__()
        # Fail early with a clear message instead of an opaque `view` shape
        # error (or ZeroDivisionError) later on.
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by a positive num_heads ({num_heads})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Query, Key projections
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(2 * embed_dim, embed_dim)
        # Value projections for both branches
        self.value_attn = nn.Linear(embed_dim, embed_dim)
        self.value_conv = nn.Linear(embed_dim, embed_dim)
        # Output projection
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        # Scale factor
        self.scale = self.head_dim ** -0.5

    def _split_heads(self, t, batch_size, seq_len):
        """Reshape [B, T, C] into per-head layout [B, H, T, D]."""
        return t.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x_attn, x_conv):
        """
        Args:
            x_attn: Transformer branch output [B, T, C]
            x_conv: CNN branch output [B, T, C]
        Returns:
            fused_output: [B, T, C]
            gate_values: [B, T, 1] fusion weights for analysis
        """
        batch_size, seq_len, _ = x_attn.shape

        # Concatenate features for key
        x_cat = torch.cat([x_attn, x_conv], dim=-1)  # [B, T, 2C]
        # Project queries and keys
        q = self._split_heads(self.query(x_attn), batch_size, seq_len)  # [B, H, T, D]
        k = self._split_heads(self.key(x_cat), batch_size, seq_len)  # [B, H, T, D]
        # Compute attention scores (gate weights)
        # NOTE(review): no padding mask is applied here, so padded positions
        # take part in the soft-max.
        attn_scores = (q @ k.transpose(-2, -1)) * self.scale  # [B, H, T, T]
        attn_weights = F.softmax(attn_scores, dim=-1)  # [B, H, T, T]
        # Project values for both branches
        v_attn = self._split_heads(self.value_attn(x_attn), batch_size, seq_len)  # [B, H, T, D]
        v_conv = self._split_heads(self.value_conv(x_conv), batch_size, seq_len)  # [B, H, T, D]
        # Apply attention to values
        attended_attn = (attn_weights @ v_attn).transpose(1, 2).reshape(batch_size, seq_len, self.embed_dim)
        # NOTE(review): each row of (1 - attn_weights) sums to T - 1, not 1, so
        # the CNN contribution grows with sequence length.
        attended_conv = ((1 - attn_weights) @ v_conv).transpose(1, 2).reshape(batch_size, seq_len, self.embed_dim)
        # Combine attended features
        fused_output = attended_attn + attended_conv
        fused_output = self.out_proj(fused_output)
        # Compute average gate values for analysis
        # NOTE(review): the mean over the soft-max axis is always 1 / T, so this
        # tensor is constant and carries no information about the gating.
        gate_values = attn_weights.mean(dim=1).mean(dim=-1, keepdim=True)  # [B, T, 1]
        return fused_output, gate_values


In [None]:
class KmerSpecificExpert(nn.Module):
    """MoE expert: a Conv1d of width ``kmer_size`` followed by a feed-forward net.

    Both stages are residual; the convolution stage is layer-normalised.

    Args:
        embed_dim: Token feature size.
        ffn_embed_dim: Hidden size of the feed-forward network.
        kmer_size: Kernel width of the convolution.
        dropout: Dropout probability applied at the end of the feed-forward net.
    """
    def __init__(self, embed_dim, ffn_embed_dim, kmer_size, dropout=0.1):
        super().__init__()
        self.kmer_size = kmer_size
        self.embed_dim = embed_dim
        self.ffn_embed_dim = ffn_embed_dim
        # Convolution with padding (kmer_size - 1) // 2.
        half_window = (kmer_size - 1) // 2
        self.conv_proj = nn.Conv1d(embed_dim, embed_dim, kmer_size, padding=half_window)
        # Position-wise feed-forward network.
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ffn_embed_dim),
            nn.GELU(),
            nn.Linear(ffn_embed_dim, embed_dim),
            nn.Dropout(dropout),
        )
        self.layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Transform a flat batch of tokens.

        Args:
            x: Tensor of shape [num_selected_tokens, embed_dim].
        Returns:
            Tensor of shape [num_selected_tokens, embed_dim].
        """
        # Every token becomes its own length-1 sequence: [N, 1, C].
        # NOTE(review): with a length-1 input the kernel only ever sees zero
        # padding around the token, so no neighbouring tokens are mixed in.
        tokens = x.unsqueeze(1)
        # Conv1d wants channels first, hence the transposes around it.
        conv_out = self.conv_proj(tokens.transpose(1, 2)).transpose(1, 2)  # [N, 1, C]
        hidden = self.layer_norm(conv_out + tokens).squeeze(1)  # [N, C]
        # Residual feed-forward stage.
        return self.ffn(hidden) + hidden


class MOELayerWithKmerExperts(nn.Module):
    """Top-k mixture of ``KmerSpecificExpert`` modules with a linear router.

    Each token is routed independently: a linear gate scores the experts, the
    ``top_k`` best are kept, their weights are renormalised to sum to one, and
    the weighted expert outputs are summed, projected and passed through dropout.

    Args:
        embed_dim: Token feature size.
        ffn_embed_dim: Hidden size of every expert's feed-forward network.
        num_experts: Number of experts; expert ``i`` uses kernel width
            ``expert_kmer_sizes[i]``, so at most ``len(expert_kmer_sizes)``.
        top_k: Number of experts evaluated per token (``1 <= top_k <= num_experts``).
        dropout: Dropout probability (experts and layer output).

    Raises:
        ValueError: If ``num_experts`` or ``top_k`` is out of range.
    """
    def __init__(self, embed_dim, ffn_embed_dim, num_experts=4, top_k=2, dropout=0.1):
        super().__init__()
        self.expert_kmer_sizes = [3, 5, 7, 9]  # Different k-mer sizes for each expert
        # Without this check a larger num_experts would give the gate more
        # outputs than there are experts, and tokens routed to the missing
        # experts would silently contribute zeros.
        if not 1 <= num_experts <= len(self.expert_kmer_sizes):
            raise ValueError(
                f"num_experts must be in [1, {len(self.expert_kmer_sizes)}], got {num_experts}"
            )
        if not 1 <= top_k <= num_experts:
            raise ValueError(
                f"top_k must be in [1, num_experts={num_experts}], got {top_k}"
            )
        self.embed_dim = embed_dim
        self.num_experts = num_experts
        self.top_k = top_k
        # Create experts with different k-mer specializations
        self.experts = nn.ModuleList([
            KmerSpecificExpert(embed_dim, ffn_embed_dim, kmer, dropout)
            for kmer in self.expert_kmer_sizes[:num_experts]
        ])
        # Gate to select experts
        self.gate = nn.Linear(embed_dim, num_experts)
        # Output projection
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Route every token of ``x`` ([B, T, C]) through its top-k experts.

        Returns:
            Tensor of shape [B, T, C].
        """
        batch_size, seq_len, _ = x.shape
        # Router probabilities over experts
        gate_weights = F.softmax(self.gate(x), dim=-1)  # [B, T, num_experts]
        # Keep the top-k experts per token and renormalise their weights
        top_k_weights, top_k_indices = torch.topk(
            gate_weights, self.top_k, dim=-1, sorted=False
        )  # [B, T, top_k]
        top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)
        # Flatten batch and sequence; reshape (not view) so non-contiguous
        # inputs are accepted as well.
        flat_x = x.reshape(-1, self.embed_dim)  # [B*T, C]
        flat_top_k_indices = top_k_indices.reshape(-1, self.top_k)  # [B*T, top_k]
        flat_top_k_weights = top_k_weights.reshape(-1, self.top_k)  # [B*T, top_k]

        output = torch.zeros_like(flat_x)
        for i, expert in enumerate(self.experts):
            routed_here = flat_top_k_indices == i  # [B*T, top_k]
            expert_mask = routed_here.any(dim=-1)  # [B*T]
            if not expert_mask.any():
                continue  # no token selected this expert
            expert_output = expert(flat_x[expert_mask])  # [M, C]
            # Weight of this expert for each selected token: zero out the
            # slots that belong to other experts, then collapse the top-k axis.
            expert_weight = flat_top_k_weights[expert_mask] * routed_here[expert_mask].float()
            expert_weight = expert_weight.sum(dim=-1, keepdim=True)  # [M, 1]
            output[expert_mask] += expert_output * expert_weight
        # Restore [B, T, C], project and apply dropout
        output = output.view(batch_size, seq_len, self.embed_dim)
        return self.dropout(self.out_proj(output))

class ParallelConvTransformerLayer(nn.Module):
    """Parallel self-attention and Conv1d branches, fused, then a feed-forward block.

    ``forward`` computes, for input ``x`` of shape [B, T, C]:

    1. attention branch: layer-norm -> multi-head self-attention;
    2. convolution branch: Conv1d over the sequence axis -> layer-norm ->
       padded positions zeroed;
    3. fusion of both branches (``DynamicFusionGate`` or a sigmoid gate),
       added to the input as a residual;
    4. layer-norm -> feed-forward (plain or ``MOELayerWithKmerExperts``) ->
       dropout, added as a second residual.

    Args:
        embed_dim: Channel size ``C``.
        ffn_embed_dim: Hidden size of the feed-forward block.
        attention_heads: Number of self-attention heads.
        kmer: Convolution kernel width; must be odd so that the padding
            ``(kmer - 1) // 2`` preserves the sequence length.
        dropout: Dropout probability.
        add_bias_kv: Forwarded to ``MultiheadAttention``.
        use_esm1b_layer_norm: Select ``ESM1bLayerNorm`` instead of ``ESM1LayerNorm``.
        use_rotary_embeddings: Stored on the instance only; it is not passed
            to the attention module.
        num_experts, top_k: MoE parameters (used when ``use_moe`` is true).
        use_moe: Use the MoE feed-forward instead of the plain one.
        use_dynamic_fusion: Use ``DynamicFusionGate`` instead of the sigmoid gate.
        fusion_heads: Number of heads of the dynamic fusion gate.

    Raises:
        ValueError: If ``kmer`` is even.
    """
    def __init__(
        self,
        embed_dim,
        ffn_embed_dim,
        attention_heads,
        kmer=7,
        dropout=0.1,
        add_bias_kv=True,
        use_esm1b_layer_norm=True,
        use_rotary_embeddings=False,
        num_experts=4,  # MOE parameters
        top_k=2,       # MOE parameters
        use_moe=False,  # Whether to use MOE
        use_dynamic_fusion=True,  # Whether to use dynamic fusion
        fusion_heads=4  # Number of heads for dynamic fusion
    ):
        super().__init__()
        # An even kernel with padding (kmer - 1) // 2 shortens the sequence by
        # one, which only surfaces later as a shape mismatch against the mask.
        if kmer % 2 == 0:
            raise ValueError(f"kmer must be odd to preserve sequence length, got {kmer}")
        self.embed_dim = embed_dim
        self.ffn_embed_dim = ffn_embed_dim
        self.attention_heads = attention_heads
        self.kmer = kmer
        self.dropout = dropout
        self.use_rotary_embeddings = use_rotary_embeddings
        self.use_moe = use_moe
        self.use_dynamic_fusion = use_dynamic_fusion
        self.fusion_heads = fusion_heads

        BertLayerNorm = ESM1bLayerNorm if use_esm1b_layer_norm else ESM1LayerNorm

        # Transformer branch
        self.self_attn = MultiheadAttention(
            self.embed_dim,
            self.attention_heads,
            add_bias_kv=add_bias_kv,
            add_zero_attn=False,
        )
        self.attn_layer_norm = BertLayerNorm(self.embed_dim)

        # CNN branch
        self.conv = nn.Conv1d(
            self.embed_dim,
            self.embed_dim,
            self.kmer,
            padding=(self.kmer - 1) // 2
        )
        self.conv_layer_norm = BertLayerNorm(self.embed_dim)

        # Fusion mechanism
        if use_dynamic_fusion:
            self.fusion = DynamicFusionGate(embed_dim, num_heads=fusion_heads)
        else:
            self.fusion_gate = nn.Sequential(
                nn.Linear(2 * embed_dim, embed_dim),
                nn.Sigmoid()
            )

        # FFN - either standard or MOE with k-mer experts
        if use_moe:
            self.ffn = MOELayerWithKmerExperts(
                embed_dim,
                ffn_embed_dim,
                num_experts=num_experts,
                top_k=top_k,
                dropout=dropout
            )
        else:
            self.ffn = nn.Sequential(
                nn.Linear(self.embed_dim, self.ffn_embed_dim),
                nn.GELU(),
                nn.Linear(self.ffn_embed_dim, self.embed_dim),
                nn.Dropout(self.dropout)
            )

        self.final_layer_norm = BertLayerNorm(self.embed_dim)
        self.dropout_layer = nn.Dropout(self.dropout)

    def forward(
        self, x, self_attn_mask=None, self_attn_padding_mask=None, need_head_weights=False
    ):
        """Apply the layer.

        Args:
            x: Input of shape [B, T, C].
            self_attn_mask: Optional attention mask forwarded to the attention module.
            self_attn_padding_mask: Optional [B, T] mask, true (or > 0.5 for
                floating-point masks) at padded positions.
            need_head_weights: Forwarded to the attention module.

        Returns:
            Tuple ``(x, attn)``: the [B, T, C] output and the attention weights
            returned by the attention module.
        """
        residual = x

        # Multiplicative mask that is 1 at real positions and 0 at padding,
        # built in the dtype of `x` so the product does not change precision.
        if self_attn_padding_mask is not None:
            if self_attn_padding_mask.is_floating_point():
                padding_mask = self_attn_padding_mask > 0.5
            else:
                padding_mask = self_attn_padding_mask.bool()
            inverse_padding_mask = (~padding_mask).unsqueeze(2).to(x.dtype)  # B*T*1
        else:
            inverse_padding_mask = torch.ones(
                x.shape[0], x.shape[1], 1, device=x.device, dtype=x.dtype
            )

        # Parallel branches
        ## Transformer branch
        x_attn = self.attn_layer_norm(x)
        x_attn = x_attn.permute(1, 0, 2)  # T*B*C
        x_attn, attn = self.self_attn(
            query=x_attn,
            key=x_attn,
            value=x_attn,
            key_padding_mask=self_attn_padding_mask,
            need_weights=True,
            # Honour the caller's request instead of silently dropping it.
            need_head_weights=need_head_weights,
            attn_mask=self_attn_mask,
        )
        x_attn = x_attn.permute(1, 0, 2)  # B*T*C

        ## CNN branch (operates on the un-normalised input)
        x_conv = x.permute(0, 2, 1)  # B*C*T
        x_conv = self.conv(x_conv)  # B*C*T
        x_conv = x_conv.permute(0, 2, 1)  # B*T*C
        x_conv = self.conv_layer_norm(x_conv)
        x_conv = x_conv * inverse_padding_mask

        # Fusion
        if self.use_dynamic_fusion:
            x, _gate_values = self.fusion(x_attn, x_conv)
        else:
            fusion_input = torch.cat([x_attn, x_conv], dim=-1)  # B*T*2C
            gate = self.fusion_gate(fusion_input)  # B*T*C
            x = gate * x_attn + (1 - gate) * x_conv  # Gated fusion

        # Residual
        x = residual + x
        residual = x

        # FFN (either standard or MOE with k-mer experts)
        x = self.final_layer_norm(x)
        x = self.ffn(x)
        # NOTE(review): both FFN variants already end in dropout, so dropout is
        # applied twice on this path during training.
        x = self.dropout_layer(x)
        x = residual + x

        return x, attn


# Smoke test for ParallelConvTransformerLayer: random [5, 100, 128] input,
# an all-False padding mask (no padded positions) and the MoE feed-forward enabled.
x = torch.rand(5, 100, 128).to(device)
pad = torch.zeros_like(x[...,0]).bool().to(device)
# NOTE(review): the feed-forward width uses the global `embed_dim` (64 -> 256),
# not the layer's own width of 128 -- confirm this is intended.
module = ParallelConvTransformerLayer(128, embed_dim*4, 8, kmer=7, dropout=0.1, use_moe=True, num_experts=4, top_k=2).to(device)
# `x` is rebound to the layer output here.
x, attn = module(x, self_attn_padding_mask=pad)
print(x.shape)  # same shape as the input

torch.Size([5, 100, 128])


In [None]:
class MainModel(nn.Module):
    """ESM2 backbone + stacked ParallelConvTransformerLayer + frame-wise pooling head.

    The sequence representation is reversed, split into three interleaved
    "frames" (every third position), each frame is max- and mean-pooled, and
    the pooled features are combined with a projection of the experiment
    indicator before a small MLP produces one scalar per sequence.

    Args:
        alphabet: Alphabet passed to the ESM2 backbone.
        dropout: Dropout probability (conv-transformer layers and head).
        CovTransformer_layers: Number of ParallelConvTransformerLayer blocks.
        kmer: Convolution kernel width of those blocks.
        layers: Number of backbone layers; the representation of this layer is used.
        embed_dim: Backbone / block feature size.
        nodes: Hidden size of the head.
        heads: Number of attention heads.
    """
    def __init__(self, alphabet, dropout=0.2, CovTransformer_layers=3,
                 kmer=7, layers=6, embed_dim=128, nodes=40, heads=16):
        super().__init__()
        self.embedding_size = embed_dim
        self.nodes = nodes
        # Keep the backbone depth on the instance so forward() does not depend
        # on module-level globals that may disagree with the constructor args.
        self.num_layers = layers
        self.repr_layers = [0, layers]
        self.esm2 = ESM2_SISS(num_layers = layers,
                        embed_dim = embed_dim,
                        attention_heads = heads,
                        alphabet = alphabet)
        # Name must match the attribute iterated in forward().
        self.convtransformer_decoder = nn.ModuleList([
            ParallelConvTransformerLayer(embed_dim, embed_dim*4, heads, kmer, dropout=dropout,
                                         use_moe=True, num_experts=4, top_k=2
                                         )
            for _ in range(CovTransformer_layers)
        ])
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        # Linear layer for processing experimental source information
        self.experiment_dense = nn.Linear(2, self.nodes)  # Handling one-hot experiment indicators
        self.linear = nn.Linear(in_features = 6 * embed_dim, out_features = self.nodes)
        self.linear_2 = nn.Linear(in_features = self.nodes, out_features = self.nodes * 4)
        self.linear_3 = nn.Linear(in_features = self.nodes * 4, out_features = self.nodes)
        self.output = nn.Linear(in_features = self.nodes, out_features = 1)

    @staticmethod
    def _masked_mean(frame, mask):
        """Mean of `frame` over dim 1, counting only positions where `mask` is true."""
        frame_sum = torch.sum(frame * mask, dim=1)
        mask_sum = torch.sum(mask, dim=1) + 1e-8  # Avoid dividing by zero
        return frame_sum / mask_sum

    def forward(self, tokens, experiment_indicator, self_attn_padding_mask=None):
        """Predict one value per sequence.

        Args:
            tokens: Token ids of shape [B, T + 2] (first and last position are
                stripped after the backbone).
            experiment_indicator: Float tensor of shape [B, 2].
            self_attn_padding_mask: Optional boolean [B, T] mask, true at
                padded positions; ``None`` means no padding.

        Returns:
            Tensor of shape [B, 1].
        """
        # ESM embedding
        embeddings = self.esm2(tokens, self.repr_layers, return_representation=True)
        x_o = embeddings["representations"][self.num_layers][:, 1 : -1] #B*(T+2)*E -> B*T*E

        # Stack the layers: each one consumes the previous layer's output.
        # Feeding the backbone embedding to every layer would discard all but
        # the last layer's work (and leave x_o undefined for zero layers).
        for layer in self.convtransformer_decoder:
            x_o, _ = layer(x=x_o, self_attn_padding_mask=self_attn_padding_mask)  # B*T*E

        x = torch.flip(x_o, dims=[1])  # Reverse along the sequence length dimension
        # Validity mask (True = real token). It must be reversed exactly like
        # the features, otherwise the mean below is taken over wrong positions.
        if self_attn_padding_mask is None:
            mask_expanded = torch.ones(x.shape[0], x.shape[1], 1, dtype=torch.bool, device=x.device)
        else:
            mask_expanded = torch.flip(~self_attn_padding_mask, dims=[1]).unsqueeze(2)  # B*T*1

        pooled = []
        for offset in range(3):
            # Frame `offset`: every third position starting at `offset`.
            frame = x[:, offset::3, :]
            # NOTE(review): the max pooling is not masked, so padded positions
            # can still win the maximum.
            pooled.append(torch.max(frame, dim=1)[0])  # B*C
            pooled.append(self._masked_mean(frame, mask_expanded[:, offset::3, :]))  # B*C
        # Order: frame1 max, frame1 avg, frame2 max, frame2 avg, frame3 max, frame3 avg
        pooled_output = torch.cat(pooled, dim=1)  # B*(6*C)
        # Linear layer processing experiment indicator
        experiment_output = self.experiment_dense(experiment_indicator)
        x_pooled = self.flatten(pooled_output)

        # The experiment projection is added to (not concatenated with) the pooled features.
        o_linear = self.linear(x_pooled) + experiment_output
        o_linear_2 = self.linear_2(o_linear)
        o_linear_3 = self.linear_3(o_linear_2)

        o_relu = self.relu(o_linear_3)
        o_dropout = self.dropout(o_relu)
        o = self.output(o_dropout)  # B*1

        return o
    
    
# Smoke test for MainModel: 5 hand-written token rows of length 102,
# no padded positions, and the same experiment indicator for every row.
tokens = torch.tensor([[7, 4, 6, 3, 6, 6, 5, 3, 5, 4, 5, 5, 4, 5, 6, 5, 3, 5, 3, 3, 3, 6, 3, 6,
         4, 4, 3, 6, 6, 6, 5, 5, 4, 6, 5, 6, 4, 6, 6, 5, 3, 5, 3, 6, 6, 6, 6, 6,
         3, 5, 3, 4, 4, 4, 5, 5, 3, 5, 5, 6, 3, 5, 5, 5, 5, 3, 6, 4, 6, 3, 4, 3,
         3, 6, 4, 5, 3, 5, 3, 4, 3, 3, 6, 6, 3, 5, 5, 6, 5, 3, 6, 5, 3, 6, 3, 3,
         6, 4, 6, 5, 3, 1],
        [7, 4, 3, 6, 5, 3, 3, 3, 5, 6, 6, 3, 5, 5, 4, 6, 3, 6, 6, 3, 4, 5, 5, 5,
         4, 3, 5, 5, 5, 6, 4, 3, 5, 6, 6, 4, 6, 5, 3, 5, 4, 3, 5, 6, 4, 5, 4, 5,
         4, 5, 5, 3, 4, 3, 5, 6, 5, 5, 5, 5, 5, 5, 5, 4, 4, 6, 4, 3, 4, 3, 6, 6,
         4, 6, 6, 6, 5, 5, 5, 4, 4, 5, 4, 5, 4, 6, 6, 3, 5, 3, 4, 4, 5, 3, 6, 5,
         5, 3, 3, 3, 6, 1],
        [7, 3, 3, 5, 6, 6, 3, 5, 3, 3, 3, 4, 3, 6, 6, 4, 4, 6, 3, 4, 3, 3, 5, 6,
         3, 4, 3, 3, 3, 3, 3, 3, 6, 5, 3, 6, 6, 6, 3, 3, 6, 3, 5, 6, 3, 3, 5, 4,
         4, 3, 3, 3, 3, 3, 5, 6, 5, 6, 5, 3, 6, 5, 5, 5, 6, 3, 3, 4, 5, 5, 5, 3,
         3, 6, 4, 6, 5, 6, 3, 4, 4, 3, 3, 5, 5, 6, 3, 6, 3, 3, 3, 4, 3, 3, 3, 5,
         6, 6, 4, 3, 6, 1],
        [7, 6, 3, 4, 6, 6, 4, 4, 4, 4, 3, 4, 3, 6, 3, 6, 3, 5, 6, 5, 5, 6, 6, 5,
         4, 4, 4, 5, 5, 3, 6, 4, 3, 4, 5, 6, 4, 4, 3, 5, 5, 4, 3, 3, 4, 5, 3, 6,
         4, 5, 3, 6, 3, 5, 4, 6, 6, 6, 6, 6, 4, 6, 3, 4, 3, 6, 5, 4, 6, 6, 5, 6,
         6, 3, 6, 4, 3, 5, 4, 4, 4, 5, 6, 5, 4, 3, 4, 5, 4, 6, 3, 6, 3, 5, 3, 5,
         4, 6, 5, 6, 6, 1],
        [7, 4, 3, 3, 4, 3, 6, 5, 6, 6, 4, 5, 6, 4, 5, 5, 3, 6, 6, 4, 4, 5, 6, 5,
         6, 6, 5, 3, 3, 3, 5, 3, 3, 5, 6, 3, 5, 3, 4, 3, 3, 4, 6, 6, 3, 4, 4, 4,
         4, 4, 5, 6, 3, 5, 4, 5, 5, 4, 5, 5, 4, 4, 5, 4, 5, 4, 3, 3, 6, 4, 3, 3,
         5, 5, 5, 6, 5, 3, 6, 6, 4, 6, 6, 5, 4, 5, 3, 5, 5, 6, 6, 6, 5, 4, 4, 3,
         3, 3, 6, 5, 3, 1]]).to(device)

# One [0., 1.] indicator row per sequence.
experiment_indicator = torch.tensor([[0., 1.]] * tokens.shape[0]).to(device)

# All-False boolean mask of shape [5, 100]: the two positions stripped from
# each token row are not covered by the mask.
self_attn_padding_mask = torch.zeros(
    tokens.shape[0], tokens.shape[1] - 2, dtype=torch.bool
).to(device)

alphabet = Alphabet(mask_prob = 0.0, standard_toks = 'AGCT')

model = MainModel(alphabet).to(device)
model(tokens, experiment_indicator, self_attn_padding_mask=self_attn_padding_mask)


tensor([[ 0.8634],
        [-0.0563],
        [ 2.9372],
        [ 0.6452],
        [ 1.1146]], device='cuda:0', grad_fn=<AddmmBackward0>)

### Build Model

In [7]:
def r2(x,y):
    """Return the squared correlation coefficient of the linear regression of y on x."""
    regression = stats.linregress(x, y)
    return regression.rvalue ** 2

def performances(label, pred):
    """Compute and print regression metrics for predictions against labels.

    Args:
        label: Iterable of ground-truth values.
        pred: Iterable of predicted values, same length as ``label``.

    Returns:
        ``[r, pearson_r, sp_cor, R2, rmse, mae]`` where ``r`` is the squared
        correlation from ``r2`` and ``R2`` is ``sklearn.metrics.r2_score``.
        If a correlation cannot be computed, the sentinel ``-1e-9`` is used
        for it instead.
    """
    label, pred = list(label), list(pred)
    r = r2(label, pred)
    R2 = r2_score(label, pred)
    rmse = np.sqrt(mean_squared_error(label, pred))
    mae = mean_absolute_error(label, pred)
    # Best-effort correlations: fall back to the sentinel on failure, but do
    # not swallow KeyboardInterrupt / SystemExit as a bare `except:` would.
    try:
        pearson_r = pearsonr(label, pred)[0]
    except Exception:
        pearson_r = -1e-9
    try:
        sp_cor = spearmanr(label, pred)[0]
    except Exception:
        sp_cor = -1e-9
    print(f'r-squared = {r:.4f} | pearson r = {pearson_r:.4f} | spearman R = {sp_cor:.4f} | R-squared = {R2:.4f} | RMSE = {rmse:.4f} | MAE = {mae:.4f}')
    return [r, pearson_r, sp_cor, R2, rmse, mae]

def performances_to_pd(performances_list):
    """Turn six metric entries into a DataFrame with one column per metric."""
    metric_names = ['r2', 'PearsonR', 'SpearmanR', 'R2', 'RMSE', 'MAE']
    # Metrics are laid out as rows first, then transposed into columns.
    return pd.DataFrame(performances_list, index=metric_names).T

In [None]:
def generate_dataset_dataloader(e_data, obj_col, lab_col, batch_toks=8192*4, mask_prob = 0.0):
    """Build a FastaBatchedDataset and a token-budgeted DataLoader from a DataFrame.

    Args:
        e_data: Source DataFrame.
        obj_col: Column passed as the first FastaBatchedDataset argument.
        lab_col: Column passed as the second FastaBatchedDataset argument.
        batch_toks: Token budget per batch.
        mask_prob: Forwarded to FastaBatchedDataset.

    Returns:
        Tuple ``(dataset, dataloader, batches)``; ``batches`` holds the batch
        index lists used as the loader's batch sampler.

    Relies on the module-level ``alphabet`` for the collate function.
    """
    first_column = e_data.loc[:, obj_col]
    second_column = e_data.loc[:, lab_col]
    dataset = FastaBatchedDataset(first_column, second_column, mask_prob = mask_prob)
    # Two extra tokens per sequence are budgeted when forming batches.
    batches = dataset.get_batch_indices(toks_per_batch=batch_toks, extra_toks_per_seq=2)
    loader_kwargs = dict(
        collate_fn=alphabet.get_batch_converter(),
        batch_sampler=batches,
        shuffle=False,
    )
    dataloader = torch.utils.data.DataLoader(dataset, **loader_kwargs)
    print(f"{len(dataset)} sequences")
    return dataset, dataloader, batches

def get_experiment_indicator_for_batch(data_combine, batch_idx):
    """Return the 'experiment_indicator' entries of the given rows as a float32 tensor.

    Args:
        data_combine: DataFrame with an 'experiment_indicator' column.
        batch_idx: Positional row indices of the batch.

    Returns:
        float32 tensor on the module-level ``device``.
    """
    batch_rows = data_combine.iloc[batch_idx]
    indicator_values = batch_rows['experiment_indicator'].values.tolist()
    indicator_tensor = torch.tensor(indicator_values, dtype=torch.float32)
    return indicator_tensor.to(device)

def shuffle_data_fn(in_data, random_state=None):
    """Return a row-shuffled copy of ``in_data`` with a fresh 0..n-1 index.

    Args:
        in_data: DataFrame to shuffle; it is not modified.
        random_state: Optional seed / random state forwarded to
            ``DataFrame.sample`` for a reproducible order. ``None`` (default)
            keeps the shuffle unseeded.

    Returns:
        Shuffled DataFrame with the old index dropped.
    """
    # sample(frac=1) draws every row exactly once, i.e. a permutation.
    shuffle_data = in_data.sample(frac=1, random_state=random_state).reset_index(drop=True)
    return shuffle_data

In [8]:
def train_step(train_dataloader, train_shuffle_combine, train_shuffle_batch, model, epoch):
    """Run one training epoch and report metrics.

    Args:
        train_dataloader: Loader yielding 6-tuples; the first element holds
            the labels and the fourth the token tensor.
        train_shuffle_combine: DataFrame with the 'experiment_indicator' column.
        train_shuffle_batch: Per-batch row indices, indexed by the loader's batch number.
        model: Model to train.
        epoch: Epoch number, used for logging only.

    Returns:
        Tuple ``(metrics, loss_epoch)``: the list from ``performances`` and the
        mean of the per-batch losses.

    Relies on the module-level ``device``, ``alphabet``, ``criterion``,
    ``optimizer`` and ``num_epochs``.
    """
    model.train()
    targets_all, predictions_all, batch_losses = [], [], []

    for step, batch in enumerate(tqdm(train_dataloader)):
        batch_labels, _, _, tok_batch, _, _ = batch
        tok_batch = tok_batch.to(device)
        # Padding mask without the first and last token position.
        pad_mask = tok_batch.eq(alphabet.padding_idx)[:, 1:-1]
        targets = torch.FloatTensor(batch_labels).to(device).reshape(-1, 1)
        indicators = get_experiment_indicator_for_batch(train_shuffle_combine, train_shuffle_batch[step])

        outputs = model(tok_batch, indicators, self_attn_padding_mask=pad_mask)
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.cpu().detach())
        targets_all.extend(targets.cpu().reshape(-1).tolist())
        predictions_all.extend(outputs.reshape(-1).cpu().detach().tolist())

    # Unweighted mean over batches.
    loss_epoch = float(torch.Tensor(batch_losses).mean())
    print(f'Train: Epoch-{epoch}/{num_epochs} | Loss = {loss_epoch:.4f} | ', end = '')

    metrics = performances(targets_all, predictions_all)
    return metrics, loss_epoch


def eval_step(test_dataloader, test_combine, test_batch, model, epoch):
    """Evaluate ``model`` on ``test_dataloader`` without gradient tracking.

    Relies on the notebook-level ``criterion``, ``alphabet``, ``device``,
    ``num_epochs`` and ``performances``.

    Args:
        test_dataloader: Iterable yielding 6-tuples
            ``(labels, strs, masked_strs, toks, masked_toks, _)``.
        test_combine: DataFrame the batches were built from; used to look up
            the ``'experiment_indicator'`` rows of each batch.
        test_batch: Sequence of row-index collections; entry ``i`` is used for
            the ``i``-th batch yielded by the dataloader.
        model: Network called as
            ``model(toks, indicators, self_attn_padding_mask=...)``.
        epoch: Current epoch number, used only in the progress printout.

    Returns:
        ``(metrics, loss_epoch, e_pred)``: the result of ``performances``, the
        mean per-batch loss as a float, and a DataFrame with columns
        ``'utr'``, ``'y_true'`` and ``'y_pred'`` (one row per sequence).
    """
    model.eval()
    sequences = []
    batch_losses = []
    targets_all = []
    preds_all = []

    with torch.no_grad():
        for batch_no, batch in enumerate(tqdm(test_dataloader)):
            labels, strs, masked_strs, toks, masked_toks, _ = batch
            sequences.extend(strs)

            toks = toks.to(device)
            # Mask is sliced to columns 1..-2, dropping the first and last
            # token positions (presumably cls/eos -- TODO confirm).
            padding_mask = toks.eq(alphabet.padding_idx)[:, 1:-1]
            targets = torch.FloatTensor(labels).to(device).reshape(-1, 1)
            indicators = get_experiment_indicator_for_batch(test_combine, test_batch[batch_no])

            preds = model(toks, indicators, self_attn_padding_mask=padding_mask)
            batch_losses.append(criterion(preds, targets).cpu().detach())

            preds_all.extend(preds.reshape(-1).cpu().detach().tolist())
            targets_all.extend(targets.cpu().reshape(-1).tolist())

    # Unweighted mean over batches (each batch counts equally).
    loss_epoch = float(torch.Tensor(batch_losses).mean())
    print(f'Test: Epoch-{epoch}/{num_epochs} | Loss = {loss_epoch:.4f} | ', end = '')
    metrics = performances(targets_all, preds_all)

    # Three parallel lists become three rows, then are transposed into columns.
    e_pred = pd.DataFrame(
        [sequences, targets_all, preds_all],
        index = ['utr', 'y_true', 'y_pred'],
    ).T

    return metrics, loss_epoch, e_pred


In [None]:
print('====Load Data====')
train_data_50_random = pd.read_csv('Data/train_test/4.1_train_data_GSM3130435_egfp_unmod_1_BiologyFeatures.csv')
train_data_vary_random = pd.read_csv('Data/train_test/VaryLengthRandomTrain_sequence.csv')
test_data_50_random = pd.read_csv('Data/train_test/4.1_test_data_GSM3130435_egfp_unmod_1.csv')
test_data_vary_random = pd.read_csv('Data/train_test/VaryLengthRandomTest_sequence_num7600.csv')
eval_data_vary_human = pd.read_csv('Data/train_test/VaryLengthHumanAll_sequence_num15555.csv')


def _select_with_indicator(frame, indicator):
    """Return a copy of ``frame[['rl', 'utr_100']]`` tagged with ``indicator``.

    The explicit ``.copy()`` makes the result an independent frame, so adding
    the new column does not trigger pandas' SettingWithCopyWarning. Every row
    receives the same ``indicator`` list in ``'experiment_indicator'``.
    """
    selected = frame[['rl', 'utr_100']].copy()
    selected['experiment_indicator'] = [indicator] * len(selected)
    return selected


# Left-pad the fixed-length UTRs with 50 '<pad>' tokens to create 'utr_100'.
# (The vary-length and human files are expected to already provide 'utr_100'.)
train_data_50_random['utr_100'] = '<pad>'*50 + train_data_50_random['utr']
test_data_50_random['utr_100'] = '<pad>'*50 + test_data_50_random['utr']

# train data
# Keep 'rl' and 'utr_100', and tag each source with its experiment indicator:
# [1, 0] for the 50-nt random library, [0, 1] for the vary-length library.
train_data_50_selected = _select_with_indicator(train_data_50_random, [1, 0])
train_data_vary_selected = _select_with_indicator(train_data_vary_random, [0, 1])

# Merge two datasets
train_combine = pd.concat([train_data_50_selected, train_data_vary_selected], ignore_index=True)

# test data
test_data_50_selected = _select_with_indicator(test_data_50_random, [1, 0])
test_data_vary_selected = _select_with_indicator(test_data_vary_random, [0, 1])

# Merge two datasets
test_combine = pd.concat([test_data_50_selected, test_data_vary_selected], ignore_index=True)

# eval data (human vary-length set uses the same indicator as the vary-length library)
eval_human_selected = _select_with_indicator(eval_data_vary_human, [0, 1])

====Load Data====


In [10]:
# Row counts of the five loaded frames, one per line, in load order.
print(
    *map(len, (
        train_data_50_random,
        train_data_vary_random,
        test_data_50_random,
        test_data_vary_random,
        eval_data_vary_human,
    )),
    sep='\n',
)

260000
76319
20000
7600
15555


In [11]:
# Nucleotide alphabet with masking disabled (mask_prob = 0.0).
alphabet = Alphabet(standard_toks = 'AGCT', mask_prob = 0.0)
print(alphabet.tok_to_idx)
# Guard against a changed token order: the checks below pin every index.
assert alphabet.tok_to_idx == {
    '<pad>': 0,
    '<eos>': 1,
    '<unk>': 2,
    'A': 3,
    'G': 4,
    'C': 5,
    'T': 6,
    '<cls>': 7,
    '<mask>': 8,
    '<sep>': 9,
}

{'<pad>': 0, '<eos>': 1, '<unk>': 2, 'A': 3, 'G': 4, 'C': 5, 'T': 6, '<cls>': 7, '<mask>': 8, '<sep>': 9}


In [14]:
# Inspect the training frame, including the newly added 'utr_100' column
# (the notebook repr elides the middle rows and columns of a large frame).
train_data_50_random

Unnamed: 0.2,Unnamed: 0.1,Unnamed: 0,utr,0,1,2,3,4,5,6,...,codon_P,codon_R,codon_V,codon_W,codon_N,uORF,CGperc,utrlen_m80,ATratio,utr_100
0,0,20000,CCGGCTATAGCGCGCAGTGCTCGGATGGCAAGGCGTTCAACCGTGA...,9.131133e-06,9.043781e-06,1.146974e-05,1.399193e-05,1.112620e-05,7.725591e-06,4.959113e-06,...,0.0625,0.2500,0.0625,0.0000,0.0000,0.0,0.68,30.0,0.222222,<pad><pad><pad><pad><pad><pad><pad><pad><pad><...
1,1,20001,AATGGGTACCGTATGTCAAAGGAACCATAGGCGCGGTCGTAGTGGA...,1.675790e-05,1.469614e-05,1.625375e-05,9.543211e-06,6.772468e-06,4.363158e-06,3.449818e-06,...,0.0625,0.1250,0.0000,0.0625,0.0625,2.0,0.56,30.0,0.363636,<pad><pad><pad><pad><pad><pad><pad><pad><pad><...
2,2,20002,CCGCGGCATGATCAAACGGTCTAGATCATAATGCGAGGTCAGTACC...,1.109030e-05,1.347871e-05,1.423318e-05,1.198283e-05,8.126962e-06,6.484693e-06,4.168530e-06,...,0.0625,0.1250,0.0625,0.0000,0.0000,1.0,0.54,30.0,0.454545,<pad><pad><pad><pad><pad><pad><pad><pad><pad><...
3,3,20003,GTGTAAGGCTCTGGTGGGCACACGTCGGTTTGCAGTAATGACAGAC...,1.448387e-05,1.843540e-05,1.509489e-05,9.220320e-06,9.723329e-06,6.004345e-06,2.515492e-06,...,0.0000,0.0625,0.1250,0.0000,0.0625,1.0,0.58,30.0,0.076923,<pad><pad><pad><pad><pad><pad><pad><pad><pad><...
4,4,20004,CCGGAAAATCGCGCAAGTCTCCCTTACGAATAAAGCTCAAAGGGAA...,1.399407e-06,2.087026e-06,1.753147e-06,4.628098e-06,5.272850e-06,6.965040e-06,8.121447e-06,...,0.1250,0.0625,0.0000,0.0000,0.0625,1.0,0.54,30.0,1.000000,<pad><pad><pad><pad><pad><pad><pad><pad><pad><...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
259995,259995,279995,TCGCTTTGAGGCGCACGTGAAACGCAGCAACCCCCTCGAAGTTGTC...,3.148667e-07,5.217566e-07,3.862867e-07,2.511371e-07,4.837477e-07,1.040753e-06,7.546477e-07,...,0.1250,0.1250,0.0000,0.0000,0.0000,1.0,0.58,30.0,0.272727,<pad><pad><pad><pad><pad><pad><pad><pad><pad><...
259996,259996,279996,CGCTAGAACCAGGAGCAACGTAGCGTAAGGGGAATATGAAGGAAAA...,2.448963e-07,0.000000e+00,0.000000e+00,2.152604e-07,1.064245e-06,9.606952e-07,1.006197e-06,...,0.0000,0.2500,0.0625,0.0000,0.0625,2.0,0.50,30.0,2.142857,<pad><pad><pad><pad><pad><pad><pad><pad><pad><...
259997,259997,279997,GGTGAACGACGAGACAATAACGTGGGCCTTTAAAGGAATAGGGCAA...,0.000000e+00,0.000000e+00,0.000000e+00,3.228906e-07,0.000000e+00,1.120811e-06,6.827765e-07,...,0.0000,0.2500,0.0625,0.0000,0.1875,1.0,0.48,30.0,1.333333,<pad><pad><pad><pad><pad><pad><pad><pad><pad><...
259998,259998,279998,CCGGAGTATTACATCACACATGAGCCCGGTGGGGATGTTGGGGAGC...,1.784244e-06,2.347905e-06,1.782862e-06,1.435069e-06,2.418739e-07,0.000000e+00,5.030985e-07,...,0.1250,0.0000,0.0625,0.0000,0.0000,0.0,0.62,30.0,0.300000,<pad><pad><pad><pad><pad><pad><pad><pad><pad><...


In [19]:
# Shuffle the merged training rows once, then build dataset / dataloader /
# batch-index triples for the train, test and human-evaluation frames.
train_shuffle_data = shuffle_data_fn(train_combine)
train_shuffle_dataset, train_shuffle_dataloader, train_shuffle_batch = generate_dataset_dataloader(
    train_shuffle_data, 'rl', 'utr_100', mask_prob = 0.0
)
test_dataset, test_dataloader, test_batch = generate_dataset_dataloader(
    test_combine, 'rl', 'utr_100', mask_prob = 0.0
)
eval_dataset, eval_dataloader, eval_batch = generate_dataset_dataloader(
    eval_human_selected, 'rl', 'utr_100', mask_prob = 0.0
)

# Checkpoint whose weights initialise the model's `esm2` sub-module.
esm2_modelfile = 'Model/utr_lm/ESM2SISS_FS4.22_fiveSpeciesCao_6layers_16heads_128embedsize_4096batchToks_lr1e-05_supervisedweight1.0_structureweight1.0_MLMLossMin_epoch115.pkl'

model = MainModel(alphabet).to(device)
state_dict = torch.load(esm2_modelfile, map_location=device)
# Remove any 'module.' fragment from the parameter names before loading
# (such a prefix is typically left by a DataParallel wrapper -- TODO confirm).
model.esm2.load_state_dict(
    {name.replace('module.', ''): weights for name, weights in state_dict.items()}
)

# Optimisation hyper-parameters.
num_epochs = 300
learning_rate = 1e-4  # candidate values noted: 1e-4, 1e-05

optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate, momentum = 0.9, weight_decay = 1e-4)
criterion = torch.nn.HuberLoss()

# Best-so-far trackers and per-epoch history containers.
loss_best = np.inf
ep_best = -1
r2_best = -1

loss_train_dict, loss_test_dict, loss_valid_dict = {}, {}, {}
metrics_train_dict, metrics_test_dict, metrics_valid_dict = {}, {}, {}


336319 sequences
27600 sequences
15555 sequences


In [None]:
# Output directory that receives the checkpoints and the loss / metric CSVs.
train_info = 'train_info'
folder_path = f"{train_info}/"
if os.path.exists(folder_path):
    print(f"The folder already exists: {folder_path}")
else:
    # First run: create the directory (including any missing parents).
    os.makedirs(folder_path)
    print(f"Folder created: {folder_path}")

In [16]:
# Release cached, unused GPU blocks, then report allocated and reserved bytes
# (one value per line, allocated first).
torch.cuda.empty_cache()

print(torch.cuda.memory_allocated(), torch.cuda.memory_reserved(), sep='\n')

108298240
117440512


In [None]:
for epoch in range(1, num_epochs+1):
    # One optimisation pass over the (pre-shuffled) training batches.
    metrics_train, loss_train = train_step(train_shuffle_dataloader, train_shuffle_data, train_shuffle_batch, model, epoch)
    loss_train_dict[epoch] = loss_train
    metrics_train_dict[epoch] = metrics_train

    # Evaluate on the combined random-library test set.
    metrics_test, loss_test, _ = eval_step(test_dataloader, test_combine, test_batch, model, epoch)
    loss_test_dict[epoch] = loss_test
    metrics_test_dict[epoch] = metrics_test

    # Checkpoint whenever the first test metric (tracked as r2_best) improves.
    if metrics_test[0] > r2_best:
        path_saver = f'{train_info}/model_epoch{epoch}.pkl'
        r2_best, ep_best = metrics_test[0], epoch
        # model.eval() only switches the mode; train_step re-enables train mode
        # at the start of the next epoch.
        torch.save(model.eval().state_dict(), path_saver)
        print(f'****Saving model in {path_saver}: Best epoch = {ep_best} | Train Loss = {loss_train:.4f} |  Val Loss = {loss_test:.4f} | R2_best = {r2_best:.4f}')

    # Evaluate on the human set and record ITS loss. The previous code stored
    # loss_test here, so the saved human loss curve duplicated the test curve.
    metrics_human, loss_human, _ = eval_step(eval_dataloader, eval_human_selected, eval_batch, model, epoch)
    loss_valid_dict[epoch] = loss_human
    metrics_valid_dict[epoch] = metrics_human

In [None]:
# Per-epoch losses: the dict keys (epochs) become rows after transposing,
# leaving a single named loss column per frame.
loss_train_df = pd.DataFrame(loss_train_dict, index = ['train_loss']).transpose()
loss_test_df = pd.DataFrame(loss_test_dict, index = ['test_loss']).transpose()
loss_valid_df = pd.DataFrame(loss_valid_dict, index = ['valid_loss']).transpose()

# Per-epoch metrics: six values per epoch, labelled in the order in which
# `performances` returns them.
metrics_train_df = pd.DataFrame(
    metrics_train_dict,
    index = ['Train_r2', 'Train_PearsonR', 'Train_SpearmanR', 'Train_R2', 'Train_RMSE', 'Train_MAE'],
).transpose()
metrics_test_df = pd.DataFrame(
    metrics_test_dict,
    index = ['test_r2', 'test_PearsonR', 'test_SpearmanR', 'test_R2', 'test_RMSE', 'test_MAE'],
).transpose()
metrics_valid_df = pd.DataFrame(
    metrics_valid_dict,
    index = ['valid_r2', 'valid_PearsonR', 'valid_SpearmanR', 'valid_R2', 'valid_RMSE', 'valid_MAE'],
).transpose()

# Persist everything under the run folder; the epoch index is written as the
# first CSV column (to_csv writes the index by default).
loss_train_df.to_csv(f'{train_info}/train_loss.csv')
loss_test_df.to_csv(f'{train_info}/test_loss.csv')
loss_valid_df.to_csv(f'{train_info}/human_loss.csv')

metrics_train_df.to_csv(f'{train_info}/train_metrics.csv')
metrics_test_df.to_csv(f'{train_info}/test_metrics.csv')
metrics_valid_df.to_csv(f'{train_info}/human_metrics.csv')