In [1]:
import torch
from torch_geometric.loader import DataLoader
from matplotlib import pyplot as plt
import numpy as np
from multiprocessing import Pool
from rdkit.Chem import Draw, PandasTools
from torch_geometric.loader import DataLoader
import torch
from modules.som_dataset import CustomDataset
from modules.som_models import GNNSOM
from utils import validation
import numpy as np23
import warnings
# Silence the "Sparse CSR tensor support is in beta state" notice.
warnings.filterwarnings('ignore', '.*Sparse CSR tensor support is in beta state.*')
# NOTE(review): the message argument is a regex matched against the start of
# the warning text, so an empty pattern matches every warning. This line
# therefore suppresses ALL warnings and makes the targeted filter above
# redundant -- confirm that a blanket ignore is intended.
warnings.filterwarnings('ignore', '')

In [None]:
from torch_geometric.nn.aggr import MultiAggregation
from torch_geometric.nn import GlobalAttention

In [2]:
def CYP_REACTION(x):
    """Join a row's non-empty CYP reaction annotations into one string.

    Parameters
    ----------
    x : pandas.Series (or any object supporting ``x[list_of_keys].tolist()``)
        A single row that contains the nine ``BOM_*`` columns listed in
        ``cyp_col``; this is what ``DataFrame.apply(..., axis=1)`` passes in.

    Returns
    -------
    str
        The non-empty string annotations, in ``cyp_col`` order, separated by
        newlines. Returns ``''`` when no column holds an annotation.
    """
    cyp_col = ['BOM_1A2', 'BOM_2A6', 'BOM_2B6', 'BOM_2C8', 'BOM_2C9', 'BOM_2C19', 'BOM_2D6', 'BOM_2E1', 'BOM_3A4']
    # Keep only non-empty strings. A plain truthiness test lets NaN through
    # (float('nan') is truthy), and str.join then raises TypeError on the
    # first row that has a missing annotation.
    cyp_reactions = [r for r in x[cyp_col].tolist() if isinstance(r, str) and r]
    return '\n'.join(cyp_reactions)

In [3]:
# Nine isoform-specific 'BOM_<isoform>' names followed by the combined
# 'CYP_REACTION' entry.
cyp_list = [
    *('BOM_' + isoform for isoform in '1A2 2A6 2B6 2C8 2C9 2C19 2D6 2E1 3A4'.split()),
    'CYP_REACTION',
]

In [4]:
class CONFIG:
    """Settings object handed to the dataset and validation code as ``args``.

    All values are plain class attributes; the notebook instantiates the class
    once and may attach further attributes (e.g. ``ckpt``) later.
    """

    # Weights of the individual loss terms (meaning of each term is defined by
    # the code that consumes ``args`` -- not visible in this notebook).
    substrate_loss_weight = 0.33
    bond_loss_weight = 0.33
    atom_loss_weight = 0.33
    som_type_loss_weight = 0.33

    class_type = 2
    n_classes = 5

    # Presumably decision thresholds on predicted probabilities -- TODO confirm
    # against utils.validation.
    th = 0.1
    substrate_th = 0.5

    # Evaluation switches; semantics live in the consuming code.
    adjust_substrate = False
    test_only_reaction_mol = False
    equivalent_bonds_mean = False
    average = 'binary'
    reduction = 'sum'
    metric_mode = 'bond'

    # Use the first GPU when available, otherwise fall back to CPU so the
    # notebook can still run on machines without CUDA.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'


args = CONFIG()



In [5]:
# Build the GNNSOM network and move it to the device configured in `args`,
# instead of repeating a hard-coded 'cuda:0' that could drift out of sync with
# args.device.
# NOTE(review): these hyper-parameters presumably have to match the
# architecture of the checkpoint loaded later -- confirm before changing them.
model = GNNSOM(
    num_layers=2,
    gnn_num_layers=8,
    pooling='sum',
    dropout=0.1,
    cyp_list=cyp_list,
    use_face=True,
    node_attn=True,
    face_attn=True,
    n_classes=args.n_classes,
).to(args.device)

In [25]:
# Read the test molecules from the SDF file into a DataFrame.
test_df = PandasTools.LoadSDF('data/test_0628.sdf')
# Row-wise: merge the per-isoform annotation columns into 'CYP_REACTION'.
test_df['CYP_REACTION'] = test_df.apply(CYP_REACTION, axis=1)
# Identifier column: 'TEST' followed by the row label zero-padded to 4 chars.
test_df['POS_ID'] = 'TEST' + test_df.index.astype(str).str.zfill(4)

test_dataset = CustomDataset(df=test_df, cyp_list=cyp_list, args=args)
# Fixed order (no shuffling) so results line up with test_df rows.
test_loader = DataLoader(
    test_dataset,
    batch_size=8,
    shuffle=False,
    num_workers=2,
)

loss_fn_ce = torch.nn.CrossEntropyLoss()
loss_fn_bce = torch.nn.BCEWithLogitsLoss()

Processing...
Done!


In [7]:
# Restore the fine-tuned weights (deserialised onto CPU first) and evaluate
# the model on the test loader.
args.ckpt = 'finetune/best.pt'
state_dict = torch.load(args.ckpt, map_location='cpu')
model.load_state_dict(state_dict)
test_scores = validation(model, test_loader, loss_fn_ce, loss_fn_bce, args)
# The returned dict exposes the validator object under 'validator'.
validator = test_scores['validator']

In [8]:
# Show the top-level keys of validator.y_prob, then the keys stored under its
# 'bond_som' entry.
print(validator.y_prob.keys())
print(validator.y_prob['bond_som'].keys())

dict_keys(['subs', 'bond_som', 'atom_som', 'atom_spn', 'dea', 'epo', 'oxi', 'dha', 'dhy', 'rdc'])
dict_keys(['BOM_1A2', 'BOM_2A6', 'BOM_2B6', 'BOM_2C8', 'BOM_2C9', 'BOM_2C19', 'BOM_2D6', 'BOM_2E1', 'BOM_3A4', 'CYP_REACTION'])


In [9]:
# 'bond_som' values for BOM_1A2 as one NumPy array, taken from validator.y_prob
# before validator.unbatch() is called.
y_bond_som_1A2 = np.array(validator.y_prob['bond_som']['BOM_1A2'])

In [10]:
# Shape of the pre-unbatch BOM_1A2 'bond_som' array.
print(y_bond_som_1A2.shape)

(6338,)


### Unbatch

In [11]:
# The cells below read validator.y_prob_unbatch / validator.y_true_unbatch,
# which presumably are filled in by this call -- confirm in the validator class.
validator.unbatch()

In [44]:
# BOM_1A2 entries of validator.y_prob_unbatch for 'bond_som' and 'atom_som';
# later indexed by test_df row position (presumably one element per molecule).
y_prob_bond_som_1A2 = validator.y_prob_unbatch['bond_som']['BOM_1A2']
y_prob_atom_som_1A2 = validator.y_prob_unbatch['atom_som']['BOM_1A2']

In [45]:
# BOM_1A2 entries of validator.y_true_unbatch for 'bond_som' and 'atom_som'
# (presumably the labels matching the y_prob_* values -- TODO confirm).
y_true_bond_som_1A2 = validator.y_true_unbatch['bond_som']['BOM_1A2']
y_true_atom_som_1A2 = validator.y_true_unbatch['atom_som']['BOM_1A2']

In [46]:
# Number of entries in the unbatched BOM_1A2 bond predictions. The previous
# code measured `y_bond_som_1A2`, the flat pre-unbatch array of shape (6338,),
# which cannot produce the per-molecule count this cell is meant to show.
print(len(y_prob_bond_som_1A2))

135


In [49]:
# Preview only the first molecule (nothing is printed for an empty frame):
# its id and the lengths of its atom and bond prediction arrays.
# NOTE(review): the label reads "bond atom" but the values are printed
# atom-first, then bond -- confirm which order is intended.
for idx in range(test_df.shape[0])[:1]:
    test_id = test_df.loc[idx, 'POS_ID']
    y_prob_atom_mol = y_prob_atom_som_1A2[idx]
    y_prob_bond_mol = y_prob_bond_som_1A2[idx]
    print(
        f'id : {test_id}',
        f'len bond atom : {y_prob_atom_mol.shape[0]} {y_prob_bond_mol.shape[0]}',
    )



id : TEST0000 len bond atom : 24 23
