Example: computing pharmacophore representations with GVP-GNN

In [1]:
import sys

# Put the parent directory on the import path so the `src.*` packages imported
# further down resolve. Guarded so that re-running this cell does not keep
# appending duplicate entries to sys.path.
if '../' not in sys.path:
    sys.path.append('../')

In [2]:
import pickle
import torch
import torch_geometric
from src.tacogfn.models import gvp_model
from src.pharmaconet import PharmacophoreModel
from utils import get_example_pharmacophore_datalist

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Fetch the example pharmacophore data via the helper imported from `utils`.
# The result is indexed like a list; element 0 is fed to the model below.
data_list = get_example_pharmacophore_datalist()

In [4]:
# Scalar width used in the second dimension tuple passed to GVP_embedding.
embedding_channels = 256

# Dimension tuples for GVP_embedding, in positional order. Each is presumably
# (scalar channels, vector channels) — TODO confirm against gvp_model. The
# first and third line up with the pharmacophore graph built later in this
# notebook: 216 scalar + 1 vector feature per node, 24 scalar + 1 vector per
# edge. The "hidden" names follow the usual GVP-GNN convention and are an
# assumption, not something this notebook shows.
node_input_dims = (216, 1)
node_hidden_dims = (embedding_channels, 16)
edge_input_dims = (24, 1)
edge_hidden_dims = (32, 1)

model = gvp_model.GVP_embedding(
    node_input_dims,
    node_hidden_dims,
    edge_input_dims,
    edge_hidden_dims,
    seq_in=True,
)

In [5]:
# Run the embedding model on the first example graph.
data = data_list[0]

# The model takes node and edge features as (scalar, vector) pairs.
node_inputs = (data.node_s, data.node_v)
edge_inputs = (data.edge_s, data.edge_v)

out = model(node_inputs, data.edge_index, edge_inputs, data.seq)

In [6]:
# Show the model output computed for the first example graph (data_list[0]).
out

tensor([[0.9390, 0.3552, 0.0000,  ..., 0.0000, 0.3677, 0.4050],
        [1.2494, 0.2384, 0.0000,  ..., 0.1526, 0.1072, 1.1338],
        [0.9459, 0.0000, 0.0000,  ..., 0.0294, 0.1247, 1.9413],
        ...,
        [0.5858, 0.2619, 0.0000,  ..., 0.0000, 0.0847, 0.4707],
        [0.7057, 0.6664, 0.0000,  ..., 0.0000, 0.4189, 1.2255],
        [0.7210, 0.8459, 0.0000,  ..., 0.2783, 0.4058, 1.1547]],
       grad_fn=<ReluBackward0>)

In [4]:
# Load the serialized pharmacophore state for entry 1a2d_B.
# SECURITY: pickle.load can execute arbitrary code while unpickling — only use
# it on files from a trusted source.
# The context manager makes sure the file handle is closed after loading
# (the bare open(...) call previously left it open).
with open('../dataset/new_pharmacophores/1a2d_B.pt', 'rb') as pharmacophore_file:
    pharmacophre_dict = pickle.load(pharmacophore_file)

In [51]:
# Instantiate PharmacophoreModel with no arguments, then populate it from the
# object unpickled in the previous cell. `__setstate__` is the hook that pickle
# itself invokes when restoring an instance; here it is called by hand.
pharmacophore = PharmacophoreModel()
pharmacophore.__setstate__(pharmacophre_dict)

In [25]:
from src.pharmaconet.scoring import pharmacophore_model
import torch_cluster
from src.tacogfn.utils import transforms
from src.tacogfn.data.utils import _normalize, _rbf

# Index -> interaction type, numbered in the order given by INTERACTION_TYPES.
id_to_interaction = dict(enumerate(pharmacophore_model.INTERACTION_TYPES))

# Interaction type -> index (the inverse lookup of the mapping above).
interaction_to_id = {
    interaction: idx for idx, interaction in id_to_interaction.items()
}

In [32]:
# NOTE(review): in the visible notebook `hotspot_positions` is first assigned in
# the featurization cell further down, so on a fresh top-to-bottom run this
# cell would raise NameError. Run (or move) that cell first.
hotspot_positions.shape

torch.Size([51, 3])

In [31]:
# NOTE(review): in the visible notebook `radii` is first assigned in the
# featurization cell further down, so on a fresh top-to-bottom run this cell
# would raise NameError. The trailing size-1 axis comes from `.unsqueeze(-1)`.
radii.shape

torch.Size([51, 1])

In [33]:
# NOTE(review): in the visible notebook `scores` is first assigned in the
# featurization cell further down, so on a fresh top-to-bottom run this cell
# would raise NameError. One score per pharmacophore node.
scores.shape

torch.Size([51])

In [53]:
import numpy as np

# Turn the loaded pharmacophore into a torch_geometric graph whose fields
# (seq, node_s, node_v, edge_s, edge_v, edge_index) match what the GVP
# embedding model is called with earlier in this notebook.
nodes = pharmacophore.nodes
top_k = 20  # neighbours per node in the k-NN graph

with torch.no_grad():
    # ---- Raw per-node attributes ------------------------------------------
    # Integer id of each node's interaction type.
    seq = torch.as_tensor(
        [interaction_to_id[n.interaction_type] for n in nodes], dtype=torch.long
    )
    centroids = torch.tensor([n.center for n in nodes])
    hotspot_positions = torch.tensor([n.hotspot_position for n in nodes])
    # Trailing axis added so radii is (num_nodes, 1).
    radii = torch.tensor([n.radius for n in nodes]).unsqueeze(-1)
    scores = torch.tensor([n.score for n in nodes])
    features = torch.tensor(np.array([n.feature for n in nodes]))

    # ---- Derived node quantities ------------------------------------------
    # Despite the name, this is the displacement vector centroid -> hotspot;
    # its norm (taken below) is the actual distance.
    dist_to_hotspot = hotspot_positions - centroids
    unit_vector_to_hotspot = _normalize(dist_to_hotspot)

    radii_rbf = _rbf(radii.squeeze(-1), D_min=0, D_max=2, D_count=8)
    dist_to_hotspot_rbf = _rbf(
        dist_to_hotspot.norm(dim=-1), D_min=0, D_max=8, D_count=8
    )
    scores_therometer = transforms.thermometer(scores, n_bins=8, vmin=0.0, vmax=1.0)

    # ---- Connectivity and per-edge quantities -----------------------------
    edge_index = torch_cluster.knn_graph(centroids, k=top_k)
    row, col = edge_index[0], edge_index[1]

    # Combined radius of the two endpoints: sqrt(r_i^2 + r_j^2).
    covariance_dists = torch.sqrt(radii[row] ** 2 + radii[col] ** 2)
    # Displacement vector between the two endpoint centroids.
    pharmacophore_dists = centroids[row] - centroids[col]
    unit_vector_to_pharmacophore = _normalize(pharmacophore_dists)

    pharmacophore_dists_rbf = _rbf(
        pharmacophore_dists.norm(dim=-1), D_min=0, D_max=20, D_count=16
    )
    covariance_dists_rbf = _rbf(
        covariance_dists.squeeze(-1), D_min=0, D_max=2, D_count=8
    )

    # ---- Assemble scalar / vector feature blocks --------------------------
    # Scalar node features: 8 + 8 + 8 + 192 = 216 columns.
    node_s = torch.cat(
        (
            radii_rbf,            # 8
            scores_therometer,    # 8
            dist_to_hotspot_rbf,  # 8
            features,             # 192
        ),
        dim=-1,
    )
    # One vector channel per node / edge, hence the inserted axis.
    node_v = unit_vector_to_hotspot.unsqueeze(-2)

    # Scalar edge features: 16 + 8 = 24 columns.
    edge_s = torch.cat(
        (
            pharmacophore_dists_rbf,  # 16
            covariance_dists_rbf,     # 8
        ),
        dim=-1,
    )
    edge_v = unit_vector_to_pharmacophore.unsqueeze(-2)

    data = torch_geometric.data.Data(
        seq=seq,
        node_s=node_s,
        node_v=node_v,
        edge_s=edge_s,
        edge_v=edge_v,
        edge_index=edge_index,
    )

In [55]:
# Inspect the encoded radius of the first node: one value per `_rbf` bin
# (D_count=8 in the featurization cell).
radii_rbf[0]

tensor([2.5281e-07, 5.0643e-04, 7.4430e-02, 8.0259e-01, 6.3498e-01, 3.6859e-02,
        1.5698e-04, 4.9052e-08])

In [56]:
# Sanity check: the last dimension should be 216 (8 + 8 + 8 + 192), matching
# the first dims tuple the GVP embedding model was constructed with.
node_s.shape

torch.Size([51, 216])

In [62]:
# Run the embedding model on the pharmacophore graph assembled in the
# featurization cell. Without this, a fresh top-to-bottom run would display the
# stale `out` computed from data_list[0], because no visible cell recomputes it
# for the new `data`.
out = model((data.node_s, data.node_v), data.edge_index, (data.edge_s, data.edge_v), data.seq)
out

tensor([[0.0000, 0.0000, 0.0284,  ..., 0.0000, 0.0000, 0.3058],
        [0.0000, 0.0000, 0.0267,  ..., 0.2727, 0.0000, 0.2274],
        [0.0000, 0.0000, 0.3440,  ..., 0.3255, 0.0000, 0.6017],
        ...,
        [0.0000, 0.0000, 0.9428,  ..., 0.8119, 0.0000, 0.1741],
        [0.0000, 0.0000, 0.9559,  ..., 0.1160, 0.0000, 0.0000],
        [0.0000, 0.0000, 1.2372,  ..., 0.0514, 0.0000, 0.2088]],
       grad_fn=<ReluBackward0>)