In [11]:
import sys

import torch
from rdkit import Chem
from torch_geometric.data import HeteroData
from torch_geometric.data.storage import NodeStorage

sys.path.append('../')

In [12]:
from src.tacogfn.envs import frag_mol_env
from src.tacogfn.data import pharmacophore
from src.tacogfn.data.pharmacophore import PharmacoDB
from src.tacogfn.tasks.seh_frag import SOME_MOLS
from src.tacogfn.models import pharmaco_cond_graph_transformer

In [13]:
# Load pharmacophore graphs for a handful of entries (ids look like PDB codes)
# from the LMDB-backed PharmacoDB.
db = PharmacoDB('../misc/pharmacophores.lmdb')
ids = [
    '1a0q', '1a0t', '1a1b', '1a1c', '1a1e',
    '1a2c', '1a3e', '1a4g', '1a4h', '1a4k',
]
# One pharmacophore per id, in the same order as `ids`.
# (`pdb_id` rather than `id` avoids shadowing the builtin.)
pharmacophores = [db.get_pharmacophore(pdb_id) for pdb_id in ids]
pharmacophore_data_list = pharmacophore.PharmacophoreGraphDataset(pharmacophores)

# Fragment-based molecule-building context; used below to featurize molecules
# and to size the model's input dimensions.
ctx = frag_mol_env.FragMolBuildingEnvContext()

In [14]:
# Parse the example SMILES. RDKit returns None (instead of raising) for input it
# cannot parse, so fail loudly here rather than deep inside `ctx.mol_to_graph`.
mols = [Chem.MolFromSmiles(smi) for smi in SOME_MOLS]
bad_smiles = [smi for smi, mol in zip(SOME_MOLS, mols) if mol is None]
assert not bad_smiles, f"unparseable SMILES: {bad_smiles}"

# Featurize each molecule: RDKit Mol -> environment graph -> Data object
# (via `ctx.graph_to_Data`).
graphs = [ctx.mol_to_graph(mol) for mol in mols]
molecule_data_list = [ctx.graph_to_Data(graph) for graph in graphs]

In [15]:
class PharmacoMolPairData(HeteroData):
    """HeteroData holding one molecule ('compound') and one pharmacophore.

    Both graphs' attributes, including ``edge_index``, are stored on *node*
    stores here. Stock ``HeteroData`` only offsets ``*index*`` tensors that live
    on *edge* stores. Without this subclass, ``Batch.from_data_list`` would leave
    e.g. ``data['compound'].edge_index`` un-offset, pointing into the first
    graph's nodes. Newer PyG versions would also concatenate it along dim 0.

    This applies the rule plain ``Data`` uses to node-store keys: any key
    containing ``'index'`` is concatenated along the last dim and offset by that
    store's node count.
    """

    def __cat_dim__(self, key, value, store=None, *args, **kwargs):
        if isinstance(store, NodeStorage) and 'index' in key:
            return -1
        return super().__cat_dim__(key, value, store, *args, **kwargs)

    def __inc__(self, key, value, store=None, *args, **kwargs):
        if isinstance(store, NodeStorage) and 'index' in key:
            return store.num_nodes
        return super().__inc__(key, value, store, *args, **kwargs)


# Pair pharmacophores and molecules position-wise, copying every attribute into
# the corresponding node type. NOTE: zip() stops at the shorter input, so any
# surplus molecules or pharmacophores are silently dropped.
merged_data_list = []
for pharmacophore_data, molecule_data in zip(pharmacophore_data_list, molecule_data_list):
    data = PharmacoMolPairData()

    for key, value in molecule_data.items():
        data['compound'][key] = value

    for key, value in pharmacophore_data.items():
        data['pharmacophore'][key] = value

    merged_data_list.append(data)

In [16]:
import torch_geometric.data as gd

# Collate the per-pair HeteroData objects into a single mini-batch.
batch = gd.Batch.from_data_list(data_list=merged_data_list)

In [17]:
# Small model configuration for a smoke test. Input feature widths are taken
# from the molecule-building context so they match `ctx.graph_to_Data` output.
model_kwargs = dict(
    pharmacophore_dim=64,
    x_dim=ctx.num_node_dim,
    e_dim=ctx.num_edge_dim,
    g_dim=ctx.num_cond_dim,
    num_emb=64,
    num_layers=3,
    num_heads=2,
    ln_type="pre",
)
model = pharmaco_cond_graph_transformer.PharmacophoreConditionalGraphTransformer(**model_kwargs)

In [18]:
# Smoke-test a forward pass with one random conditioning row per graph.
# `batch.num_graphs` states that intent explicitly (on a plain Data object,
# `len()` would count attributes instead). A seeded local generator makes the
# input reproducible without touching torch's global RNG state.
cond = torch.randn(
    batch.num_graphs,
    ctx.num_cond_dim,
    generator=torch.Generator().manual_seed(0),
)
# NOTE(review): output index 1 is assumed to be the per-graph output — confirm
# against the model's forward().
model(batch, cond)[1].shape

torch.Size([10, 128])