In [2]:
import sys
sys.path.append('./')
sys.path.append('../')
sys.path.append('../..')

import os
import pandas as pd
from sklearn import preprocessing
from typing import Sequence, Tuple, List, Union
from tqdm import tqdm
import fm
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data.sampler import RandomSampler
import numpy as np
import random

def seed_torch(seed=0):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

seed_torch(2024)

## 1. Load Model

### (1) define utr_function_predictor

In [3]:
class CNNModel(nn.Module):
    def __init__(self, in_planes, out_planes):
        super(CNNModel, self).__init__()
        main_planes = 64
        dropout = 0.2

        self.conv1 = nn.Conv1d(in_planes, main_planes, kernel_size=3, padding=1)
        self.resblock1 = ResBlock(main_planes, main_planes, stride=2, dilation=1, conv_layer=nn.Conv1d,
                     norm_layer=nn.BatchNorm1d)
        self.resblock2 = ResBlock(main_planes, main_planes, stride=1, dilation=1, conv_layer=nn.Conv1d,
                     norm_layer=nn.BatchNorm1d)
        self.resblock3 = ResBlock(main_planes, main_planes, stride=2, dilation=1, conv_layer=nn.Conv1d,
                     norm_layer=nn.BatchNorm1d)
        self.resblock4 = ResBlock(main_planes, main_planes, stride=1, dilation=1, conv_layer=nn.Conv1d,
                     norm_layer=nn.BatchNorm1d)
        self.resblock5 = ResBlock(main_planes, main_planes, stride=2, dilation=1, conv_layer=nn.Conv1d,
                     norm_layer=nn.BatchNorm1d)
        self.resblock6 = ResBlock(main_planes, main_planes, stride=1, dilation=1, conv_layer=nn.Conv1d,
                     norm_layer=nn.BatchNorm1d)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.flatten = nn.Flatten()
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(main_planes, out_planes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.resblock1(x)
        x = self.resblock2(x)
        x = self.resblock3(x)
        x = self.resblock4(x)
        x = self.resblock5(x)
        x = self.resblock6(x)
        feature_maps = x  # Save the feature maps before pooling
        x = self.pool(x)
        x = self.flatten(x)
        x = self.dropout(x)
        out = self.fc(x)
        return out, feature_maps

class Human5PrimeUTRPredictor(torch.nn.Module):
    """
    contact predictor with inner product
    """
    def __init__(self, task="rgs", arch="cnn", input_types=["seq", "emb-rnafm"]):
        """
        :param depth_reduction: mean, first
        """       
        super().__init__()        
        self.task = task
        self.arch = arch
        self.input_types = input_types        
        self.padding_mode = "right" #"left"
        self.token_len = 100
        self.out_plane = 1
        self.in_channels = 0
        if "seq" in self.input_types:
            self.in_channels = self.in_channels + 4

        if "emb-rnafm" in self.input_types:
            self.reductio_module = nn.Linear(640, 32)
            self.in_channels = self.in_channels + 32  

        # if self.arch == "cnn":
        #     self.predictor = self.create_1dcnn_for_emd(in_planes=self.in_channels, out_planes=1)
        if self.arch == "cnn" and self.in_channels != 0:
            self.predictor = CNNModel(in_planes=self.in_channels, out_planes=1)
        else:
            raise Exception("Wrong Arch Type")

    def forward(self, tokens, inputs):
        ensemble_inputs = []
        if "seq" in inputs:
            nest_token = (tokens[:, 1:-1] - 4)

            # 100 right padding
            nest_token = torch.nn.functional.pad(nest_token, (0, self.token_len - nest_token.shape[1]), value=-1)
            token_padding_mask = nest_token.ge(0).long()
            one_hot_tokens = torch.nn.functional.one_hot((nest_token * token_padding_mask),
                                                         num_classes=4).float() * token_padding_mask.unsqueeze(-1)
            if self.arch == "cnn":     # B, L, 4
                one_hot_tokens = one_hot_tokens.permute(0, 2, 1)
            else:
                one_hot_tokens = one_hot_tokens.view(one_hot_tokens.shape[0], -1)
            ensemble_inputs.append(one_hot_tokens)

        if "emb-rnafm" in inputs:
            embeddings = inputs["emb-rnafm"]
            # remove auxiliary tokens
            embeddings, padding_masks = self.remove_pend_tokens_1d(tokens, embeddings)
            if len(embeddings.size()) == 3:  # for pure seq
                batch_size, seqlen, hiddendim = embeddings.size()
            elif len(embeddings.size()) == 4:  # for msa
                batch_size, depth, seqlen, hiddendim = embeddings.size()
                # reduction
                embeddings = self.msa_depth_reduction(embeddings, padding_masks)
            else:
                raise Exception("Unknown Embedding Type!")

            # 100 right padding
            embeddings = torch.nn.functional.pad(embeddings, (0, 0, 0, self.token_len - embeddings.shape[1]))
            embeddings = self.reductio_module(embeddings)

            if self.arch == "cnn":
                embeddings = embeddings.permute(0, 2, 1)
            else:
                embeddings = embeddings.reshape(embeddings.shape[0], -1)

            ensemble_inputs.append(embeddings)        

        if len(ensemble_inputs) > 0:
            ensemble_inputs = torch.cat(ensemble_inputs, dim=1)

        if self.padding_mode == "left":
            # 100 left padding
            nest_token = (tokens[:, 1:-1] - 4)
            nest_token = torch.nn.functional.pad(nest_token, (0, self.token_len - nest_token.shape[1]), value=-1)
            token_padding_mask = nest_token.ge(0).long()

            left_ensembles = []
            for i in range(nest_token.shape[0]):
                length = token_padding_mask[i].sum().item()
                left_ensemble = torch.cat([ensemble_inputs[i, :, length:], ensemble_inputs[i, :, 0: length]], dim=-1)
                left_ensembles.append(left_ensemble)
            ensemble_inputs = torch.stack(left_ensembles, dim=0)

        if isinstance(ensemble_inputs, list) != True:
            output, feature_maps = self.predictor(ensemble_inputs)
            output = output.squeeze(-1)
            return output, feature_maps
        else:
            output = 0
        return output
 
    def create_1dcnn_for_emd(self, in_planes, out_planes):
        main_planes = 64
        dropout = 0.2
        emb_cnn = nn.Sequential(
            nn.Conv1d(in_planes, main_planes, kernel_size=3, padding=1),  ## 3
            ResBlock(main_planes * 1, main_planes * 1, stride=2, dilation=1, conv_layer=nn.Conv1d,
                     norm_layer=nn.BatchNorm1d), 
            ResBlock(main_planes * 1, main_planes * 1, stride=1, dilation=1, conv_layer=nn.Conv1d,
                     norm_layer=nn.BatchNorm1d),  
            ResBlock(main_planes * 1, main_planes * 1, stride=2, dilation=1, conv_layer=nn.Conv1d,
                     norm_layer=nn.BatchNorm1d), 
            ResBlock(main_planes * 1, main_planes * 1, stride=1, dilation=1, conv_layer=nn.Conv1d,
                     norm_layer=nn.BatchNorm1d),  
            ResBlock(main_planes * 1, main_planes * 1, stride=2, dilation=1, conv_layer=nn.Conv1d,
                     norm_layer=nn.BatchNorm1d), 
            ResBlock(main_planes * 1, main_planes * 1, stride=1, dilation=1, conv_layer=nn.Conv1d,
                     norm_layer=nn.BatchNorm1d),       
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Dropout(dropout),
            nn.Linear(main_planes * 1, out_planes),
        )
        return emb_cnn
    
    def remove_pend_tokens_1d(self, tokens, seqs):
        """
        :param tokens:
        :param seqs: must be shape of [B, ..., L, E]    # seq: [B, L, E]; msa: [B, D, L, E]
        :return:
        """
        self.pad_idx = 1
        self.eos_idx = 2
        self.append_eos = True
        self.prepend_bos = True
        
        padding_masks = tokens.ne(self.pad_idx)

        # remove eos token  （suffix first）
        if self.append_eos:     # default is right
            eos_masks = tokens.ne(self.eos_idx)
            eos_pad_masks = (eos_masks & padding_masks).to(seqs)
            seqs = seqs * eos_pad_masks.unsqueeze(-1)
            seqs = seqs[:, ..., :-1, :]
            padding_masks = padding_masks[:, ..., :-1]

        # remove bos token
        if self.prepend_bos:    # default is left
            seqs = seqs[:, ..., 1:, :]
            padding_masks = padding_masks[:, ..., 1:]

        if not padding_masks.any():
            padding_masks = None

        return seqs, padding_masks
    

class ResBlock(nn.Module):
    def __init__(
        self,
        in_planes,
        out_planes,
        stride=1,
        dilation=1,
        conv_layer=nn.Conv2d,
        norm_layer=nn.BatchNorm2d,
    ):
        super(ResBlock, self).__init__()        
        self.bn1 = norm_layer(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = conv_layer(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, bias=False)       
        self.bn2 = norm_layer(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = conv_layer(out_planes, out_planes, kernel_size=3, padding=dilation, bias=False)

        if stride > 1 or out_planes != in_planes: 
            self.downsample = nn.Sequential(
                conv_layer(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                norm_layer(out_planes),
            )
        else:
            self.downsample = None

        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.bn1(x)
        out = self.relu1(out)
        out = self.conv1(out)        
        out = self.bn2(out)
        out = self.relu2(out)
        out = self.conv2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity

        return out
    
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.normal_(m.weight, std=0.001) #std=0.001)
        if isinstance(m.bias, nn.Parameter):
            nn.init.constant_(m.bias, 0.0)
    elif classname.find('BasicConv') != -1: 
        pass
    elif classname.find('Conv') != -1:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif classname.find('BatchNorm') != -1:
        if m.affine:
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0.0)

### (2) create RNA-FM backbone

In [4]:
device="cuda"   # "cpu"
os.environ['CUDA_VISIBLE_DEVICES'] = "0"  
backbone, alphabet = fm.pretrained.rna_fm_t12()
backbone.to(device)
print("create RNA-FM_backbone sucessfully")

create RNA-FM_backbone sucessfully


### (3) create UTR function downstream predictor

In [5]:
task="rgs"
arch="cnn"
input_items = ["seq","emb-rnafm"]   # ["seq"], ["emb-rnafm"], ["seq","emb-rnafm"] 
model_name = arch.upper() + "_" + "_".join(input_items) 
utr_func_predictor = Human5PrimeUTRPredictor(
    task=task, arch=arch, input_types=input_items    
)
utr_func_predictor.apply(weights_init)
utr_func_predictor.to(device)
print("create utr_func_predictor sucessfully")
print(utr_func_predictor)

create utr_func_predictor sucessfully
Human5PrimeUTRPredictor(
  (reductio_module): Linear(in_features=640, out_features=32, bias=True)
  (predictor): CNNModel(
    (conv1): Conv1d(36, 64, kernel_size=(3,), stride=(1,), padding=(1,))
    (resblock1): ResBlock(
      (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu1): ReLU(inplace=True)
      (conv1): Conv1d(64, 64, kernel_size=(3,), stride=(2,), padding=(1,), bias=False)
      (bn2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU(inplace=True)
      (conv2): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
      (downsample): Sequential(
        (0): Conv1d(64, 64, kernel_size=(1,), stride=(2,), bias=False)
        (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (resblock2): ResBlock(
      (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_r

### (4) define loss function and optimizer

In [6]:
criterion = nn.MSELoss(reduction="none")
optimizer = optim.Adam(utr_func_predictor.parameters(), lr=0.001)

## 2. Load Data
You should download data from gdrive link: https://drive.google.com/file/d/10zCfOHOaOa__J2AIuZyidZ9sVJ9L11rI/view?usp=sharing and place it in the tutorials/utr-function-prediction

### (1) define utr dataset

In [7]:

class Human_5Prime_UTR_VarLength(object):
    def __init__(self, root, set_name="train", input_items=["seq", "emb-rnafm"], label_items=["rl"], use_cache=0):
        """
        :param root: root path of dataset - CATH. however not all of stuffs under this root path
        :param data_type: seq, msa
        :param label_type: 1d, 2d
        :param set_name: "train", "valid", "test"
        """
        self.root = root
        self.theme = "rna"
        self.data_type = "seq"
        self.set_name = set_name          
        self.use_cache = 1
        self.cache = {}

        self.input_items = input_items
        if "seq" not in self.input_items:
            self.input_items.append("seq")  # for generate RNA-FM embedding
        self.label_items = label_items
        self.data_items = input_items + label_items  # combine two items
    
        # 1. Create Paths
        self.src_scv_path = os.path.join(self.root, "data", "GSM4084997_varying_length_25to100.csv") 
        
        # 2. Create Data INFO        
        self.data_avatars, self.stats = self.__dataset_info(self.src_scv_path, self.data_items)
        self.labels = ['rl']

    def __getitem__(self, index):
        inputs = {}
        inputs["seq"] = (self.data_avatars["name"][index], self.data_avatars["seq"][index][0].replace("T", "U"))
        labels = {"rl": self.data_avatars["rl"][index],}
        return inputs, labels

    def __len__(self):
        return self.stats.shape[0]

    def __dataset_info(self, src_csv_path, data_items):
        """
        :param name_path: txt record name list for specific set_name
        :param data_dir:
        :param msa_dir:
        :param ann1d_dir:
        :param ann2d_dir:
        :return:
        """
        seq_col = ['utr100']
        label_col = ['rl']  # need scale

        src_df = pd.read_csv(src_csv_path)   #, index_col=0)
        src_df.loc[:, "ori_index"] = src_df.index
        random_df = src_df[src_df['set'] == 'random']
        ## Filter out UTRs with too few less reads
        random_df = random_df[random_df['total_reads'] >= 10]    # 87000 -> 83919
        random_df['utr100'] = random_df['utr'] # + 75 * 'N'  # left padding from reference code
        #random_df['utr100'] = random_df['utr100'].str[-100:]
        random_df.sort_values('total_reads', inplace=True, ascending=False)
        random_df.reset_index(inplace=True, drop=True)

        # human set for evaluation too
        human_df = src_df[src_df['set'] == 'human']
        ## Filter out UTRs with too few less reads
        human_df = human_df[human_df['total_reads'] >= 10]   # 16739 -> 15555
        human_df['utr100'] = human_df['utr'] #+ 75 * 'N'  # left padding from reference code, here we use right padding
        #human_df['utr100'] = human_df['utr100'].str[-100:]
        human_df.sort_values('total_reads', inplace=True, ascending=False)
        human_df.reset_index(inplace=True, drop=True)

        random_df_test = pd.DataFrame(columns=random_df.columns)
        for i in range(25, 101):
            tmp = random_df[random_df['len'] == i]
            tmp.sort_values('total_reads', inplace=True, ascending=False)
            tmp.reset_index(inplace=True, drop=True)
            random_df_test = random_df_test.append(tmp.iloc[:100])

        human_df_test = pd.DataFrame(columns=human_df.columns)
        for i in range(25, 101):
            tmp = human_df[human_df['len'] == i]
            tmp.sort_values('total_reads', inplace=True, ascending=False)
            tmp.reset_index(inplace=True, drop=True)
            human_df_test = human_df_test.append(tmp.iloc[:100])

        train_df = pd.concat([random_df, random_df_test]).drop_duplicates(keep=False)  # 去重后 76319
        self.scaler = preprocessing.StandardScaler()
        self.scaler.fit(train_df.loc[:, label_col].values.reshape(-1, 1))
        train_df.loc[:,'scaled_rl'] = self.scaler.transform(train_df.loc[:, label_col].values.reshape(-1, 1))
        random_df_test.loc[:, 'scaled_rl'] = self.scaler.transform(random_df_test.loc[:, label_col].values.reshape(-1, 1))
        human_df_test.loc[:, 'scaled_rl'] = self.scaler.transform(human_df_test.loc[:, label_col].values.reshape(-1, 1))

        if self.set_name == "train":
            set_df = train_df
        elif self.set_name == "valid":
            set_df = random_df_test
        else:
            set_df = human_df_test

        seq = set_df[seq_col].values
        scaled_rl = set_df['scaled_rl'].values
        selected_indices = set_df["ori_index"].values

        self.name = selected_indices

        data_paths = {"name": self.name}
        for itemname in data_items:
            # input
            if itemname == "seq":
                data_paths[itemname] = seq
            elif itemname == "emb-rnafm":
                pass
            # label
            elif itemname == "rl":
                data_paths[itemname] = scaled_rl
            else:
                raise Exception("Unknown data item name {}".format(itemname))


        print("{} Dataset Num Samples: {} ".format(self.set_name, set_df["len"].shape[0]))
        #print("Length-Frequency Table")
        #print(set_df["len"].describe())  # value_counts())
        return data_paths, set_df
    
    
class BatchConverter(object):
    """
    Callable to convert an unprocessed (labels + strings) batch to a processed (labels + tensor) batch.
    """
    def __init__(self, alphabet, data_type="seq",):
        """
        :param alphabet:
        :param data_type: seq, msa
        """
        self.alphabet = alphabet
        self.data_type = data_type.split("+")

    def __call__(self, raw_data, raw_anns=None):
        """
        :param raw_data: each element in raw data should contain (description, seq)
        :param raw_anns:
        :return:
        """
        # creat a new batch of data tensors
        data = {}
        for key in raw_data.keys():
            if key == "seq":
                labels, strs, tokens = self.__call_seq__(raw_data["seq"])

                data["description"] = labels
                data["string"] = strs
                data["token"] = tokens
                data["depth"] = [1] * len(strs)
                data["length"] = [len(s) for s in strs]
            else:
                if isinstance(raw_data[key][0], str):
                    data[key] = raw_data[key]
                elif isinstance(raw_data[key][0], np.ndarray):
                    try:   # same length
                        data[key] = torch.Tensor(raw_data[key])
                    except:
                        # here we padding them with 0 for consistance with cnn's padding, which is different with ann's padding
                        data[key] = torch.Tensor(self.__padding_numpy_matrix__(raw_data[key], data["length"], pad_idx=0))
                elif isinstance(raw_data[key][0], float) or isinstance(raw_data[key][0], int):
                    data[key] = torch.Tensor(raw_data[key])


        # creat a new batch of ann tensors
        if raw_anns is not None:
            anns = {}
            for key in raw_anns.keys():
                if isinstance(raw_anns[key][0], str):
                    anns[key] = raw_anns[key]
                elif isinstance(raw_anns[key][0], np.ndarray):
                    try:   # same length
                        anns[key] = torch.Tensor(raw_anns[key])
                    except:
                        anns[key] = torch.Tensor(self.__padding_numpy_matrix__(raw_anns[key], data["length"]))
                elif isinstance(raw_anns[key][0], float) or isinstance(raw_anns[key][0], int):
                    anns[key] = torch.Tensor(raw_anns[key])
        else:
            anns = None

        return data, anns

    def __call_seq__(self, raw_batch: Sequence[Tuple[str, str]]):
        # RoBERTa uses an eos token, while ESM-1 does not.
        batch_size = len(raw_batch)
        max_len = max(len(seq_str) for _, seq_str in raw_batch)
        tokens = torch.empty(
            (
                batch_size,
                max_len
                + int(self.alphabet.prepend_bos)
                + int(self.alphabet.append_eos),
            ),
            dtype=torch.int64,
        )
        tokens.fill_(self.alphabet.padding_idx)
        labels = []
        strs = []

        for i, (label, seq_str) in enumerate(raw_batch):
            labels.append(label)
            strs.append(seq_str)
            if self.alphabet.prepend_bos:
                tokens[i, 0] = self.alphabet.cls_idx
            seq = torch.tensor(
                [self.alphabet.get_idx(s) for s in seq_str], dtype=torch.int64
            )
            tokens[
            i,
            int(self.alphabet.prepend_bos): len(seq_str)
                                            + int(self.alphabet.prepend_bos),
            ] = seq
            if self.alphabet.append_eos:
                tokens[
                    i, len(seq_str) + int(self.alphabet.prepend_bos)
                ] = self.alphabet.eos_idx

        return labels, strs, tokens
    
def LofD_to_DofL(raw_batch):
    """
    list of dict to dict of list
    :param raw_batch:
    :return:
    """
    batch_size = len(raw_batch)
    example = raw_batch[0]
    new_batch = {}
    for key in example.keys():
        new_batch[key] = []
        for i in range(batch_size):
            new_batch[key].append(raw_batch[i][key])
    return new_batch

def build_collate_fn(alphabet, data_type):
    batch_converter = BatchConverter(alphabet, data_type)
    def collate_fn(batch):
        if len(batch[0]) == 1:
            data = zip(*batch)
            data = LofD_to_DofL(data)
            data, anns = batch_converter(data)
            anns = None
        elif len(batch[0]) == 2:
            data, anns = zip(*batch)
            data = LofD_to_DofL(data)
            anns = LofD_to_DofL(anns)
            data, anns = batch_converter(data, anns)
        else:
            raise Exception("Unexpected Num of Components in a Batch")
        return data, anns

    return collate_fn


### (2) generate dataloaders

In [8]:
from torch.utils.data import Subset
root_path = "./"
subset_size = 300  # Define the size of your subset

# Create indices for the subset
train_indices = list(range(subset_size))

train_set =  Human_5Prime_UTR_VarLength(root=root_path, set_name="train", input_items=input_items, )
# train_set = Subset(train_set, train_indices)
# test the first batch
# train_set = train_set[:300]
val_set =  Human_5Prime_UTR_VarLength(root=root_path, set_name="valid", input_items=input_items)
test_set =  Human_5Prime_UTR_VarLength(root=root_path, set_name="test", input_items=input_items)

collate_fn = build_collate_fn(alphabet, train_set.data_type)
# collate_fn = build_collate_fn(alphabet, train_set.dataset.data_type)

num_workers = 4
train_batch_size = 64
train_loader = DataLoader(
    train_set, batch_size=train_batch_size, sampler=RandomSampler(train_set, replacement=False),
    num_workers=num_workers, collate_fn=collate_fn, drop_last=False
)

val_batch_size = train_batch_size #* 4
val_loader = DataLoader(
    val_set, batch_size=val_batch_size, sampler=RandomSampler(val_set, replacement=False),
    num_workers=num_workers, collate_fn=collate_fn, drop_last=False
)

test_batch_size = train_batch_size #* 4
test_loader = DataLoader(
    test_set, batch_size=test_batch_size, sampler=RandomSampler(test_set, replacement=False),
    num_workers=num_workers, collate_fn=collate_fn, drop_last=False
)

# scaler = train_set.dataset.scaler
scaler = train_set.scaler


A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  return func(*args, **kwargs)


train Dataset Num Samples: 76319 


A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  return func(*args, **kwargs)


valid Dataset Num Samples: 7600 


A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  return func(*args, **kwargs)


test Dataset Num Samples: 7600 


## 3. Training Model

### (1) define eval function

In [11]:
def model_eval(data_loader, i_epoch, set_name="unknown"):
    
    all_losses = []
    all_cams = []
    pbar = tqdm(data_loader, desc="Epoch {}, {} Set - real MSE: {}".format(i_epoch,set_name, "NaN"), ncols=100)
    for index, (data, anns) in enumerate(pbar):
    # for index, (data, anns) in enumerate(data_loader):
        backbone.eval()
        utr_func_predictor.eval()
        with torch.no_grad():
            x = data["token"].to(device)
            true_rl = anns["rl"].to(device)            
            inputs = {}
            results = {}
            if "seq" in input_items:
                inputs["seq"] = x         
            if "emb-rnafm" in input_items:
                results = backbone(x, need_head_weights=False, repr_layers=[12], return_contacts=False)
                inputs["emb-rnafm"] = results["representations"][12]    
            results["rl"], feature_maps = utr_func_predictor(x, inputs)
            fc_weights = utr_func_predictor.predictor.fc.weight.detach().cpu().numpy()
            batch_cams = []
            for i in range(feature_maps.size(0)):
                fmap = feature_maps[i].detach().cpu().numpy()
                weights = fc_weights[0]
                cam = np.dot(weights, fmap)
                batch_cams.append(cam)
            all_cams.extend(batch_cams)
            # After model's forward pass
            pds = scaler.inverse_transform(results["rl"].detach().cpu().numpy())
            gts = scaler.inverse_transform(true_rl.detach().cpu().numpy())
            pds = torch.Tensor(pds)
            gts = torch.Tensor(gts)            
        
            losses = criterion(pds, gts)  
            # losses = criterion(results["rl"], true_rl)  
            all_losses.append(losses.detach().cpu())     

    avg_loss = torch.cat(all_losses, dim=0).mean()
    print("Epoch {}, Evaluation on {} Set - MSE loss: {:.3f}".format(i_epoch, set_name, avg_loss))
    
    return avg_loss, all_cams

### (2) training process

In [12]:
n_epoches = 2
best_mse = 10
best_epoch = 0

for i_e in range(1, n_epoches+1):
    all_losses = []
    n_sample = 0
    n_iter = len(train_loader)

    pbar = tqdm(train_loader, desc="Epoch {}, Train Set - MSE loss: {}".format(i_e, "NaN"), ncols=100)
    for index, (data, anns) in enumerate(pbar):
        backbone.eval()
        utr_func_predictor.train()
        x = data["token"].to(device)
        true_rl = anns["rl"].to(device)        
        
        inputs = {}
        results = {}
        if "seq" in input_items:
            inputs["seq"] = x         
        if "emb-rnafm" in  input_items:            
            with torch.no_grad():
                results = backbone(x, need_head_weights=False, repr_layers=[12], return_contacts=False)            
            inputs["emb-rnafm"] = results["representations"][12]                
        results["rl"], feature_maps = utr_func_predictor(x, inputs) 
        losses = criterion(results["rl"], true_rl)
        batch_loss = losses.mean()
        batch_loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    
        all_losses.append(losses.detach().cpu())
        current_avg_loss = torch.cat(all_losses, dim=0).mean()
        
        pbar.set_description("Epoch {}, Train Set - MSE loss: {:.3f}".format(i_e, current_avg_loss))
        #if index % 1 == 0:
        #    print("epoch: {}, iter:{}/{} MSE loss: {}".format(i_e, index+1, iters, all_loss/n_sample ))
    
        #if index+1 > 10:
        #    break    
    random_mse, all_cams = model_eval(val_loader, i_e, set_name="Random")
    
    
    if random_mse < best_mse:
        best_epoch = i_e
        best_mse = random_mse
        torch.save(utr_func_predictor.state_dict(), "result/{}_best_utr_predictor.pth".format(model_name))
    print("--------- Model: {}, Best Epoch {}, Best MSE {:.3f}".format(model_name, best_epoch, best_mse))

Epoch 1, Train Set - MSE loss: 0.334: 100%|█████████████████████| 1193/1193 [02:07<00:00,  9.34it/s]
Epoch 1, Random Set - real MSE: NaN:   1%|▏              | 1/119 [11:21:10<1339:38:27, 40870.40s/it]


KeyboardInterrupt: 

In [12]:
import matplotlib.pyplot as plt

sequence = val_loader[0]  # The RNA sequence string
cam = all_cams[0]           # The corresponding CAM

# Normalize the CAM for visualization
cam = (cam - cam.min()) / (cam.max() - cam.min())

plt.figure(figsize=(15, 5))
plt.bar(range(len(sequence)), cam, tick_label=list(sequence))
plt.title('Class Activation Map')
plt.xlabel('Sequence Position')
plt.ylabel('Activation')
plt.show()

TypeError: 'DataLoader' object is not subscriptable