In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import os
import torch
from data import RNA_dataset, Molecule_dataset, RNA_dataset_independent, Molecule_dataset_independent, WordVocab
from model import RNA_feature_extraction, GNN_molecule, mole_seq_model, cross_attention
from torch_geometric.loader import DataLoader
import torch.optim as optim
from scipy.stats import pearsonr,spearmanr
from torch.autograd import Variable
import numpy as np
import os
import torch.nn as nn
from sklearn.metrics import mean_squared_error
import random
torch.set_printoptions(profile="full")

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
hidden_dim = 16

EPOCH = 200
RNA_type = 'Viral_RNA_independent'
rna_dataset = RNA_dataset(RNA_type)
molecule_dataset = Molecule_dataset(RNA_type)

rna_dataset_in = RNA_dataset_independent()
molecule_dataset_in = Molecule_dataset_independent()

seed = 1



# set random seed
def set_seed(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_printoptions(precision=20)
set_seed(seed)

# combine two pyg dataset
class CustomDualDataset(Dataset):
    def __init__(self, dataset1, dataset2):
        self.dataset1 = dataset1
        self.dataset2 = dataset2

        assert len(self.dataset1) == len(self.dataset2)

    def __getitem__(self, index):
        return self.dataset1[index], self.dataset2[index]

    def __len__(self):
        return len(self.dataset1)  



def average_multiple_lists(lists):
    return [sum(item)/len(lists) for item in zip(*lists)]





# DeepRSMA architecture
class DeepRSMA(nn.Module):
    def __init__(self):
        super(DeepRSMA, self).__init__()
        # RNA graph + seq
        self.rna_graph_model = RNA_feature_extraction(hidden_dim)
        
        # Mole graph
        self.mole_graph_model = GNN_molecule(hidden_dim)
        # Mole seq
        self.mole_seq_model = mole_seq_model(hidden_dim)

        # Cross fusion module
        self.cross_attention = cross_attention(hidden_dim)
        
        self.line1 = nn.Linear(hidden_dim*2, 1024)
        self.line2 = nn.Linear(1024, 512)
        self.line3 = nn.Linear(512, 1)
        self.dropout = nn.Dropout(0.2)
        
        self.rna1 = nn.Linear(hidden_dim, hidden_dim*4)
        self.mole1 = nn.Linear(hidden_dim, hidden_dim*4)
        
        self.rna2 = nn.Linear(hidden_dim*4, hidden_dim)
        self.mole2 = nn.Linear(hidden_dim*4, hidden_dim)
        
        self.relu = nn.ReLU()
    
    def forward(self, rna_batch, mole_batch):
        rna_out_seq,rna_out_graph, rna_mask_seq, rna_mask_graph, rna_seq_final, rna_graph_final = self.rna_graph_model(rna_batch, device)
        
        mole_graph_emb, mole_graph_final = self.mole_graph_model(mole_batch)
        
        mole_seq_emb, _, mole_mask_seq = self.mole_seq_model(mole_batch, device)
        
        mole_seq_final = (mole_seq_emb[-1]*(mole_mask_seq.to(device).unsqueeze(dim=2))).mean(dim=1).squeeze(dim=1)


        # mole graph
        flag = 0
        mole_out_graph = []
        mask = []
        for i in mole_batch.graph_len:
            count_i = i
            x = mole_graph_emb[flag:flag+count_i]
            temp = torch.zeros((128-x.size()[0]), hidden_dim).to(device)
            x = torch.cat((x, temp),0)
            mole_out_graph.append(x)
            mask.append([] + count_i * [1] + (128 - count_i) * [0])
            flag += count_i
        mole_out_graph = torch.stack(mole_out_graph).to(device)
        mole_mask_graph = torch.tensor(mask, dtype=torch.float)
        
        context_layer, attention_score = self.cross_attention([rna_out_seq, rna_out_graph, mole_seq_emb[-1], mole_out_graph], [rna_mask_seq.to(device), rna_mask_graph.to(device), mole_mask_seq.to(device), mole_mask_graph.to(device)], device)

        
        out_rna = context_layer[-1][0]
        out_mole = context_layer[-1][1]
        
        # Affinity Prediction Module
        rna_cross_seq = ((out_rna[:, 0:512]*(rna_mask_seq.to(device).unsqueeze(dim=2))).mean(dim=1).squeeze(dim=1) + rna_seq_final ) / 2
        rna_cross_stru = ((out_rna[:, 512:]*(rna_mask_graph.to(device).unsqueeze(dim=2))).mean(dim=1).squeeze(dim=1) + rna_graph_final) / 2        

        rna_cross = (rna_cross_seq + rna_cross_stru) / 2
        rna_cross = self.rna2(self.dropout((self.relu(self.rna1(rna_cross)))))

        
        mole_cross_seq = ((out_mole[:,0:128]*(mole_mask_seq.to(device).unsqueeze(dim=2))).mean(dim=1).squeeze(dim=1) + mole_seq_final) / 2
        mole_cross_stru = ((out_mole[:,128:]*(mole_mask_graph.to(device).unsqueeze(dim=2))).mean(dim=1).squeeze(dim=1) + mole_graph_final) / 2
        
        mole_cross = (mole_cross_seq + mole_cross_stru) / 2
        mole_cross = self.mole2(self.dropout((self.relu(self.mole1(mole_cross)))))   
        
        out = torch.cat((rna_cross, mole_cross),1)
        out = self.line1(out)
        out = self.dropout(self.relu(out))
        out = self.line2(out)
        out = self.dropout(self.relu(out))
        out = self.line3(out)
        

        return out


# use viral RNA to train
train_dataset = CustomDualDataset(rna_dataset, molecule_dataset)
# independent test
test_dataset = CustomDualDataset(rna_dataset_in, molecule_dataset_in)


train_loader = DataLoader(
    train_dataset, batch_size=8, num_workers=0, drop_last=False, shuffle=False
)
test_loader = DataLoader(
    test_dataset, batch_size=1, num_workers=0, drop_last=False, shuffle=False
)



model = DeepRSMA()
model.to(device)

y_pred_all = []
max_p = -1

optimizer = optim.Adam(model.parameters(), lr=6e-5 , weight_decay=1e-5)
optimal_loss = 1e10
loss_fct = torch.nn.MSELoss()
for epoch in range(0,EPOCH):
    train_loss = 0

    for step, batch in enumerate(train_loader):
        optimizer.zero_grad()
        pre = model(batch[0].to(device), batch[1].to(device))

        y = batch[0].y
        
        loss = loss_fct(pre.squeeze(dim=1), y.float())
        loss.backward()
        optimizer.step()
        train_loss = train_loss + loss
    with torch.set_grad_enabled(False):
        model.eval()
        y_label = []
        y_pred = []
        for step, (batch_v) in enumerate(test_loader):
            label = Variable(torch.from_numpy(np.array(batch_v[0].y))).float()
            score = model(batch_v[0].to(device), batch_v[1].to(device))

            logits = torch.squeeze(score).detach().cpu().numpy()
            label_ids = label.to('cpu').numpy()

            y_label = y_label + label_ids.flatten().tolist()
            y_pred = y_pred + logits.flatten().tolist()

        p = pearsonr(y_label, y_pred)
        s = spearmanr(y_label, y_pred)
        rmse = np.sqrt(mean_squared_error(y_label, y_pred))
        print( 'epo:',epoch, 'pcc:',p[0],'scc: ',s[0], 'rmse:',rmse)

        if max_p < p[0]:
            max_p = p[0]
            print(' ')
            print('Best:', 'epo:',epoch, 'pcc:',p[0],'scc: ',s[0],'rmse:',rmse)

            torch.save(model.state_dict(), 'save/' + 'model_independent_'+str(seed)+'.pth')

        
        model.train()


RuntimeError: The size of tensor a (128) must match the size of tensor b (16) at non-singleton dimension 2

In [60]:
        rna_dataset_in = RNA_dataset_independent()
        molecule_dataset_in = Molecule_dataset_independent()

In [61]:
from ablation_utils import target_swap
test_set = target_swap(rna_dataset_in, molecule_dataset_in, seed=seed)

AttributeError: 'GlobalStorage' object has no attribute 't_id'

In [64]:
    rna_initial_mapping = [None for rna in rna_dataset_in]
    rna_index = 0
    reference_rnas = []
    for i, rna in enumerate(rna_dataset_in):
        if rna_initial_mapping[i] is None:
            reference_t_id = rna.t_id
            reference_rnas.append(i)
            for j in range(i, len(rna_dataset_in)):
                if rna_dataset[j].t_id == reference_t_id:
                    rna_initial_mapping[j] = rna_index
            rna_index += 1
        
    random.seed(seed)
    inds = random.sample(reference_rnas, rna_index)
    RNA_swap = [inds[rna_index] for rna_index in rna_initial_mapping]

AttributeError: 'GlobalStorage' object has no attribute 't_id'

In [69]:
rna_dataset_in[0].edge_index

tensor([[ 0,  0,  0,  0,  0,  1,  1,  1,  1,  1,  1,  2,  2,  2,  2,  2,  2,  2,
          3,  3,  3,  3,  3,  3,  3,  4,  4,  4,  4,  4,  4,  4,  5,  5,  5,  5,
          5,  5,  5,  5,  6,  6,  6,  6,  6,  6,  7,  7,  8,  8,  9,  9,  9,  9,
         10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 13, 13,
         13, 13, 14, 14, 14, 15, 15, 16, 16, 17, 17, 18, 18, 19, 19, 20, 20, 21,
         21, 22, 22, 23, 23, 24, 24, 25, 25, 26, 26, 27],
        [ 1,  2, 26, 27, 28,  2,  3, 25, 26, 27, 28,  3,  4, 24, 25, 26, 27, 28,
          4,  5, 23, 24, 25, 26, 27,  5,  6, 22, 23, 24, 25, 26,  6,  7, 20, 21,
         22, 23, 24, 25,  7,  8, 21, 22, 23, 24,  8,  9,  9, 10, 10, 11, 21, 22,
         11, 12, 20, 21, 22, 12, 13, 19, 20, 21, 22, 13, 14, 19, 20, 21, 14, 15,
         19, 20, 15, 16, 17, 16, 17, 17, 18, 18, 19, 19, 20, 20, 21, 21, 22, 22,
         23, 23, 24, 24, 25, 25, 26, 26, 27, 27, 28, 28]])

In [53]:
molecule_dataset[125].e_id

tensor([2378])

In [46]:
counter=0
for i in range(len(molecule_dataset_in)):
    for j in range(i+1,len(molecule_dataset)):
        if molecule_dataset[j].smiles_ori == molecule_dataset[i].smiles_ori:
            counter+=1
            print(f'{i} and {j}')
        

7 and 125
7 and 126
7 and 127
7 and 128
7 and 129
8 and 27
9 and 29
11 and 90
11 and 100
11 and 111
12 and 13
12 and 85
12 and 86
13 and 85
13 and 86
53 and 54
55 and 56
57 and 58
78 and 79
78 and 80
79 and 80
85 and 86
90 and 100
90 and 111
91 and 101
91 and 112
92 and 102
92 and 113
93 and 104
93 and 115
94 and 99
94 and 105
94 and 110
94 and 116
94 and 121
95 and 96
95 and 106
95 and 107
95 and 117
95 and 118
96 and 106
96 and 107
96 and 117
96 and 118
97 and 98
97 and 108
97 and 109
97 and 119
97 and 120
98 and 108
98 and 109
98 and 119
98 and 120
99 and 105
99 and 110
99 and 116
99 and 121
100 and 111
101 and 112
102 and 113
103 and 114
104 and 115
105 and 110
105 and 116
105 and 121
106 and 107
106 and 117
106 and 118
107 and 117
107 and 118
108 and 109
108 and 119
108 and 120
109 and 119
109 and 120
110 and 116
110 and 121
116 and 121
117 and 118
119 and 120
125 and 126
125 and 127
125 and 128
125 and 129
126 and 127
126 and 128
126 and 129
127 and 128
127 and 129
128 and 129


In [57]:
list(set(rna.t_id for rna in rna_dataset))

['Target_432',
 'Target_68',
 'Target_217',
 'Target_130',
 'Target_374',
 'Target_433',
 'Target_409',
 'Target_408',
 'Target_80',
 'Target_219',
 'Target_66',
 'Target_222',
 'Target_257',
 'Target_431',
 'Target_329',
 'Target_434',
 'Target_131',
 'Target_220',
 'Target_435',
 'Target_188',
 'Target_146',
 'Target_218',
 'Target_373',
 'Target_109',
 'Target_67']

In [59]:
rna_dataset[0].y

tensor([6.14266729354858398438])

In [10]:
import os
import torch
from torch_geometric.data import DataLoader
from data import RNA_dataset_independent, Molecule_dataset_independent
from model.deeprsma import DeepRSMA
from ablation_utils import identity, target_swap, ligand_swap



model_path = f"save/model_independent_1.pth"
ablations = {
    "target-swap": target_swap,
    "ligand-swap": ligand_swap,
    "none": identity,
}

rows = []
seeds = [0, 1, 2]

ablation = target_swap
rna_dataset_in = RNA_dataset_independent()
molecule_dataset_in = Molecule_dataset_independent()

for seed in seeds:

    test_set_target_swap = target_swap(rna_dataset_in, molecule_dataset_in, seed=seed)
    test_set_id = identity(rna_dataset_in, molecule_dataset_in, seed=seed)
    test_loader_target_swap = DataLoader(
        test_set_target_swap, batch_size=1, num_workers=0, drop_last=False, shuffle=False
    )
    test_loader_id = DataLoader(
        test_set_id, batch_size=1, num_workers=0, drop_last=False, shuffle=False
    )
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = DeepRSMA(hidden_dim=128)

    if os.path.exists(model_path):

        pretrained_dict = torch.load(model_path,map_location="cuda:0" if torch.cuda.is_available() else "cpu")
        model_dict = model.state_dict()
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

    model = model.to(device)



In [16]:
for i in range(len(test_set_id)):
    print(f"RNAs identical: {(test_set_id[i][0].x==test_set_target_swap[i][0].x).all()}")
    print(f"Mols identical: {test_set_id[i][1].smiles_ori==test_set_target_swap[i][1].smiles_ori}")

RNAs identical: True
Mols identical: True
RNAs identical: True
Mols identical: True
RNAs identical: True
Mols identical: True
RNAs identical: True
Mols identical: True
RNAs identical: True
Mols identical: True
RNAs identical: True
Mols identical: True
RNAs identical: True
Mols identical: True
RNAs identical: True
Mols identical: True
RNAs identical: True
Mols identical: True
RNAs identical: True
Mols identical: True
RNAs identical: True
Mols identical: True
RNAs identical: True
Mols identical: True
RNAs identical: True
Mols identical: True
RNAs identical: True
Mols identical: True
RNAs identical: True
Mols identical: True
RNAs identical: True
Mols identical: True
RNAs identical: True
Mols identical: True
RNAs identical: True
Mols identical: True
RNAs identical: True
Mols identical: True
RNAs identical: True
Mols identical: True
RNAs identical: True
Mols identical: True
RNAs identical: True
Mols identical: True
RNAs identical: True
Mols identical: True
RNAs identical: True
Mols identica

In [21]:

import random
rna_dataset = rna_dataset_in
mol_dataset = molecule_dataset_in
rna_initial_mapping = [None for rna in rna_dataset]
rna_index = 0
reference_rnas = []
for i, rna in enumerate(rna_dataset):
    if rna_initial_mapping[i] is None:
        reference_x = rna.x
        reference_rnas.append(i)
        for j in range(i, len(rna_dataset)):
            if rna_dataset[j].x.shape == reference_x.shape:
                if (rna_dataset[j].x == reference_x).all():
                    rna_initial_mapping[j] = rna_index
        rna_index += 1
    
random.seed(seed)
inds = random.sample(reference_rnas, rna_index)
RNA_swap = [inds[rna_ind] for rna_ind in rna_initial_mapping]

In [39]:
from data import RNA_dataset, Molecule_dataset
import pandas as pd
from data.vocab import WordVocab  # Add this import
from data import RNA_dataset, Molecule_dataset


RNA_type = 'All_sf'
rna_dataset = RNA_dataset(RNA_type)
molecule_dataset = Molecule_dataset(RNA_type)
cold_type = 'rna'
all_df = pd.read_csv('data/RSM_data/' + 'All_sf' + '_dataset_v1.csv', delimiter='\t')  
df1 = pd.read_csv('data/blind_test/cold_' + cold_type +'1.csv', delimiter=',')
df2 = pd.read_csv('data/blind_test/cold_' + cold_type +'2.csv', delimiter=',')
df3 = pd.read_csv('data/blind_test/cold_' + cold_type +'3.csv', delimiter=',')
df4 = pd.read_csv('data/blind_test/cold_' + cold_type +'4.csv', delimiter=',')
df5 = pd.read_csv('data/blind_test/cold_' + cold_type +'5.csv', delimiter=',')
df = [df1, df2, df3, df4, df5]

df_f = df1

test_id = df_f['Entry_ID'].tolist()
test_id = all_df[all_df['Entry_ID'].isin(test_id)].index.tolist()

train_id = []
for j, other_id in enumerate(df):
    if j != i:
        train_id.extend(other_id['Entry_ID'].tolist())

train_id = all_df[all_df['Entry_ID'].isin(train_id)].index.tolist()

rna_dataset = rna_dataset[test_id]

for rna in rna_dataset:
    if rna.x.shape==rna_dataset[0].x.shape:
        if (rna.x==rna_dataset[0].x).all():
            print("Identical")
        else:
            print("not identical")
    else:
        print("not identical")


Identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
Identical
not identical
not identical
not identical
not identical
not identical
Identical
not identical
not identical
Identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not iden

In [41]:
mol_dataset = molecule_dataset[test_id]
for mol in mol_dataset:
    if mol.smiles_ori==mol_dataset[0].smiles_ori:
        print("Identical")
    else:
        print("not identical")
    

Identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identical
not identi

In [23]:
set(rna_initial_mapping)

{0}