## Environment setup

```sh
conda create -n pred pandas pytorch matplotlib numpy pyyaml tqdm einops orjson
```

In [1]:
from torch.utils.data import Dataset
import torch
import numpy as np
import pandas as pd
import csv
import os
import orjson
import yaml
from Network import *
from tqdm import tqdm
import matplotlib.pyplot as plt


## Modules

In [2]:
class RNADatasetRN(Dataset):
    """Map-style dataset of RNA sequences for RibonanzaNet with resumable CSV output.

    Rows of ``dataset_csv`` whose ``SeqID`` already appears in ``output_csv``
    are dropped, and the output file is opened for appending, so an
    interrupted prediction run can be resumed. If the output file does not
    exist yet (or is empty) it is created with a header row.

    Args:
        dataset_csv: Input CSV read with ``pd.read_csv``; the columns used are
            ``SeqID``, ``SQ`` (nucleotide string) and ``RT`` (JSON-encoded
            reactivity list).
        output_csv: Prediction CSV to create or append to.

    Raises:
        KeyError: if an existing ``output_csv`` (or ``dataset_csv``) has no
            ``SeqID`` column. Existing results are never truncated in that case.

    Call :meth:`close` when finished to release the output file handle.
    """

    def __init__(self, dataset_csv, output_csv):
        self.data = pd.read_csv(dataset_csv)
        # Token ids follow the order of "ACGU"; any other character raises
        # KeyError in __getitem__.
        self.tokens = {nt: i for i, nt in enumerate("ACGU")}

        try:
            output_data = pd.read_csv(output_csv)
        except (FileNotFoundError, pd.errors.EmptyDataError):
            # No previous results: start a fresh file with a header row.
            self.output = open(output_csv, "w", newline="")
            self.writer = csv.writer(self.output)
            self.writer.writerow(["SeqID", "RT", "RibonanzaNetPrediction", "RibonanzaNetMAE"])
            self.output.flush()
            return

        # Resume: keep only sequences that have no prediction yet. Other read
        # problems propagate instead of reopening the file with "w", which
        # would silently wipe the existing results.
        already_done = self.data["SeqID"].isin(output_data["SeqID"])
        self.data = self.data[~already_done].reset_index(drop=True)
        self.output = open(output_csv, "a", newline="")
        self.writer = csv.writer(self.output)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(seq_id, tokens, reactivity)`` for row ``index``.

        ``tokens`` is a 1-D ``torch.long`` tensor of nucleotide ids and
        ``reactivity`` is the decoded ``RT`` JSON value.

        Raises:
            KeyError: if the sequence contains a character outside "ACGU".
        """
        seq_id = self.data.loc[index, "SeqID"]
        sequence = self.data.loc[index, "SQ"]
        # Explicit long dtype: np.array's default int width is platform dependent.
        tokens = torch.tensor([self.tokens[nt] for nt in sequence], dtype=torch.long)
        reactivity = orjson.loads(self.data.loc[index, "RT"])
        return seq_id, tokens, reactivity

    def __repr__(self):
        return str(self.data)

    def write(self, seq_id, reactivity, prediction, MAE):
        """Append one result row; lists are stored as JSON strings.

        The file is flushed after every row so a crash loses at most the
        current sequence, which the resume logic in ``__init__`` relies on.
        """
        row = [
            seq_id,
            orjson.dumps(reactivity).decode(),
            orjson.dumps(prediction).decode(),
            MAE,
        ]
        self.writer.writerow(row)
        self.output.flush()

    def close(self):
        """Close the output CSV handle (safe to call more than once)."""
        self.output.close()

class Config:
    """Expose the keys of a configuration mapping as instance attributes.

    Every keyword argument becomes an attribute (``Config(a=1).a == 1``), and
    the original mapping is also kept unchanged in ``self.entries``.
    """

    def __init__(self, **entries):
        vars(self).update(entries)
        self.entries = entries

    def print(self):
        """Write the raw configuration mapping to stdout."""
        print(self.entries)

    @classmethod
    def from_yaml(cls, file_path):
        """Build a Config from the top-level mapping of a YAML file.

        The file is parsed with ``yaml.safe_load``, and its keys become the
        constructor's keyword arguments.
        """
        with open(file_path, "r") as handle:
            loaded = yaml.safe_load(handle)
        return cls(**loaded)

def mean_absolute_error(prediction, truth):
    """Mean absolute difference between two numeric sequences.

    Only positions present in both sequences are compared (``zip`` stops at
    the shorter one), and the mean is taken over exactly those positions.

    Args:
        prediction: Iterable of predicted values.
        truth: Iterable of reference values.

    Returns:
        The mean of ``|p - t|`` over the paired positions, or ``nan`` when
        there are no positions to compare.
    """
    errors = [abs(p - t) for p, t in zip(prediction, truth)]
    if not errors:
        # MAE is undefined for zero positions; avoid ZeroDivisionError.
        return float("nan")
    # Divide by the number of compared pairs, not len(prediction), so that
    # mismatched lengths do not skew the result.
    return sum(errors) / len(errors)

## Configs

### Zebrafish

In [3]:
# Zebrafish sample run: input CSV, prediction/figure outputs and RibonanzaNet assets.
# NOTE: the Neural config cell also assigns `args`; whichever of the two cells
# runs last determines the configuration used below.
args = dict(
    dataset_csv="/home/learn/xwt/pred/data/zebrafish/sample/sample.csv",
    output_csv="/home/learn/xwt/pred/data/zebrafish/sample/data/predictions.csv",
    figure_dir="/home/learn/xwt/pred/data/zebrafish/sample/figures",
    config="/home/learn/xwt/pred/models/RibonanzaNet/configs/pairwise.yaml",
    weights_pt="/home/learn/xwt/pred/models/RibonanzaNet/weights/RibonanzaNet.pt",
    num_cores=8,
)

### Neural

In [3]:
# Neural sample run: input CSV, prediction/figure outputs and RibonanzaNet assets.
# NOTE: the Zebrafish config cell also assigns `args`; whichever of the two
# cells runs last determines the configuration used below.
args = dict(
    dataset_csv="/home/learn/xwt/pred/data/neural/sample/sample.csv",
    output_csv="/home/learn/xwt/pred/data/neural/sample/data/predictions.csv",
    figure_dir="/home/learn/xwt/pred/data/neural/sample/figures",
    config="/home/learn/xwt/pred/models/RibonanzaNet/configs/pairwise.yaml",
    weights_pt="/home/learn/xwt/pred/models/RibonanzaNet/weights/RibonanzaNet.pt",
    num_cores=8,
)

## Load data

In [4]:
# Rows whose SeqID already appears in the output CSV are filtered out by the
# dataset, so re-running after an interrupted inference resumes where it stopped.
dataset = RNADatasetRN(dataset_csv=args["dataset_csv"], output_csv=args["output_csv"])
print(repr(dataset))

                SeqID                                                 SQ  \
0  D7|ENST00000439929  AUGGGUCACCAGCAGCUGUACUGGAGCCACGCGCGAAAAUUCGGCC...   
1  D7|ENST00000432323  AUGAGCAAAGCUCACCCUCCCGAGCUGAAAAAAUUUAUGGACAAGA...   
2  D8|ENST00000473748  AUGGACACCAGCCGUGUGCAGCCUAUCAAGCUGGCCAGGGUCACCA...   
3  D8|ENST00000384674  ACUCUCUCGGCUCUGCAUAGUUGCACUUGGCUUCACCCGUGUGACU...   
4  D7|ENST00000447303  AUGCCUCGGAAAAUUGAGGAAAUCAAGGAUUUUCUGCUCACAGUCC...   
5  D7|ENST00000384581  AUCCUCCUGAUCCCUUUCCCAUCGGAUCUGAACACUGGUCUUGGUG...   
6  D0|ENST00000706951  GAUUCCCUGCAGUAAACGGACUUUUCAUUUAUUUAAUCAUUCAAAC...   
7  D0|ENST00000402089  AUGUCUGACAAACCCGAUAUGGCUGAGAUCGAGAAAUUCGAUAAGC...   
8  D8|ENST00000520566  AUGUCCGGCCGCGAAGGUGGCAAGAAGAAGCCACUGAAACAGCCCA...   
9  D7|ENST00000497342  AUGUCGCACAAACAAAUUUACUAUUCGGACAAAUACGAUGACAAGG...   

                                                  RD  \
0  [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0...   
1  [0.03795079675705898,0.3063321218898518,0.3363..

## Load model

In [5]:
model = RibonanzaNet(Config.from_yaml(args["config"]))
# NOTE(review): torch.load unpickles the checkpoint; pass weights_only=True
# (torch >= 1.13) unless the weights file comes from a fully trusted source.
state_dict = torch.load(args["weights_pt"], map_location=torch.device("cpu"))
model.load_state_dict(state_dict)
# Inference only: eval() puts any dropout / batch-norm layers into inference
# mode so predictions are deterministic (the model starts in training mode).
model.eval()
torch.set_num_threads(args["num_cores"])

constructing 9 ConvTransformerEncoderLayers


## Infer

In [6]:
# Predict reactivity one sequence at a time and append each result to the
# output CSV immediately, so an interrupted run can be resumed.
skipped = []
for index in tqdm(range(len(dataset)), desc="Predict Reactivity"):
    try:
        seq_id, sequence, reactivity = dataset[index]
    except KeyError:
        # Tokenization failed, most likely a nucleotide outside ACGU.
        # Record it instead of skipping silently.
        skipped.append(dataset.data.loc[index, "SeqID"])
        continue

    batch = sequence.unsqueeze(0)
    with torch.no_grad():
        output = model(batch, torch.ones_like(batch))
    # Index the batch axis explicitly instead of .squeeze(): squeeze would
    # also drop the position axis for length-1 sequences. Channel 0 is the
    # column the original analysis compares against RT.
    prediction = output[0, :, 0].cpu().tolist()
    MAE = mean_absolute_error(prediction, reactivity)
    dataset.write(seq_id, reactivity, prediction, MAE)

if skipped:
    print(f"Skipped {len(skipped)} sequence(s) that could not be tokenized: {skipped}")

Predict Reactivity: 100%|██████████| 10/10 [07:04<00:00, 42.46s/it]


## Plot results

In [9]:
# Plot prediction vs. measured reactivity for every row in the output CSV
# (including rows written by earlier, resumed runs) and save one PDF per sequence.
os.makedirs(args["figure_dir"], exist_ok=True)
results_data = pd.read_csv(args["output_csv"])
for _, row in results_data.iterrows():
    seq_id = row["SeqID"]
    reactivity = orjson.loads(row["RT"])
    prediction = orjson.loads(row["RibonanzaNetPrediction"])

    # Width scales with sequence length (5 positions per inch), with a floor
    # so very short or empty sequences still get a valid, readable figure.
    fig, ax = plt.subplots(figsize=(max(len(reactivity) / 5, 6), 5))
    ax.plot(prediction, label="Prediction")
    ax.plot(reactivity, label="Truth")
    ax.set_xlabel("Position")
    ax.set_ylabel("Reactivity")
    ax.set_title(f"{seq_id} (MAE = {row['RibonanzaNetMAE']:.3f})")
    ax.legend()
    fig.savefig(f"{args['figure_dir']}/{seq_id}.pdf", format="pdf", bbox_inches="tight")
    # Close each figure explicitly; one figure per row would otherwise accumulate.
    plt.close(fig)