In [1]:
import kinodata
from kinodata.data import KinodataDocked, Filtered
from kinodata.data.data_module import create_dataset
from kinodata.data.grouped_split import KinodataKFoldSplit
from kinodata.transform import TransformToComplexGraph, FilterDockingRMSD
from kinodata.types import *


import json
from pathlib import Path
from typing import Any

import torch

import kinodata.configuration as cfg
from kinodata.model import ComplexTransformer, DTIModel, RegressionModel
from kinodata.model.complex_transformer import make_model as make_complex_transformer
from kinodata.model.dti import make_model as make_dti_baseline
from kinodata.data.data_module import make_kinodata_module
from kinodata.transform import TransformToComplexGraph

import matplotlib.pyplot as plt
import seaborn as sns

import pandas as pd
import tqdm

# Shell escape: switch Weights & Biases off for this session (the CLI replies "W&B disabled.").
!wandb disabled

  from .autonotebook import tqdm as notebook_tqdm


W&B disabled.


In [2]:
# Instantiate the docked kinodata dataset with its default arguments.
data = KinodataDocked()

In [3]:
# Display the first dataset entry to inspect which attributes and node/edge stores it carries.
data[0]

HeteroData(
  y=[1],
  docking_score=[1],
  posit_prob=[1],
  predicted_rmsd=[1],
  pocket_sequence='KPLGRGAFGQVIEVAVKMLALMSELKILIHIGLNVVNLLGAMVIVEFCKFGNLSTYLRSFLASRKCIHRDLAARNILLICDFGLA',
  scaffold='C1CCC(CC2CCCC(C3CC(C4CCCC4)C4CCCCC34)C2)CC1',
  activity_type='pIC50',
  ident=[1],
  smiles='Nc1ncnc2c1c(-c1cccc(Oc3ccccc3)c1)cn2C1CCCC1',
  [1mligand[0m={
    z=[28],
    x=[28, 12],
    pos=[28, 3]
  },
  [1mpocket[0m={
    z=[652],
    x=[652, 12],
    pos=[652, 3]
  },
  [1mpocket_residue[0m={ x=[85, 23] },
  [1m(ligand, bond, ligand)[0m={
    edge_index=[2, 64],
    edge_attr=[64, 4]
  },
  [1m(pocket, bond, pocket)[0m={
    edge_index=[2, 1308],
    edge_attr=[1308, 4]
  }
)

In [4]:
# Fetch the dataset's underlying data frame.
# NOTE(review): judging by the logged output, accessing `.df` triggers reading and
# de-duplicating the raw SDF file — confirm whether the result is cached on `data`.
df = data.df

Reading data frame from /Users/joschka/projects/kinodata-3D-affinity-prediction/data/raw/kinodata_docked_v2.sdf.gz...
Deduping data frame (current size: 121913)...
119713 complexes remain after deduplication.
Checking for missing pocket mol2 files...


100%|██████████| 3244/3244 [00:00<00:00, 17458.87it/s]


Adding pocket sequences...
(119713, 25)


100%|██████████| 119713/119713 [00:00<00:00, 2088064.59it/s]


Exiting with 3552 cached sequences.
(119713, 26)


In [6]:
from kinodata.data.io.read_klifs_mol2 import read_klifs_mol2

In [11]:
# Look up the pocket mol2 file recorded for the first dataset entry:
# match rows on that entry's `ident`, then take the first hit.
first_ident = data[0].ident.item()
ident_matches = df["ident"] == first_ident
pocket = df.loc[ident_matches, "pocket_mol2_file"].values[0]

In [15]:
# Show the first entry again as a reference point before comparing it with the raw pocket file.
data[0]

HeteroData(
  y=[1],
  docking_score=[1],
  posit_prob=[1],
  predicted_rmsd=[1],
  pocket_sequence='KPLGRGAFGQVIEVAVKMLALMSELKILIHIGLNVVNLLGAMVIVEFCKFGNLSTYLRSFLASRKCIHRDLAARNILLICDFGLA',
  scaffold='C1CCC(CC2CCCC(C3CC(C4CCCC4)C4CCCCC34)C2)CC1',
  activity_type='pIC50',
  ident=[1],
  smiles='Nc1ncnc2c1c(-c1cccc(Oc3ccccc3)c1)cn2C1CCCC1',
  [1mligand[0m={
    z=[28],
    x=[28, 12],
    pos=[28, 3]
  },
  [1mpocket[0m={
    z=[652],
    x=[652, 12],
    pos=[652, 3]
  },
  [1mpocket_residue[0m={ x=[85, 23] },
  [1m(ligand, bond, ligand)[0m={
    edge_index=[2, 64],
    edge_attr=[64, 4]
  },
  [1m(pocket, bond, pocket)[0m={
    edge_index=[2, 1308],
    edge_attr=[1308, 4]
  }
)

In [18]:
# Parse the pocket's KLIFS mol2 file into a per-atom data frame.
# with_bonds=False presumably skips the bond records — TODO confirm against read_klifs_mol2.
pocket_df = read_klifs_mol2(pocket, with_bonds=False)
pocket_df

Unnamed: 0,atom.id,atom.name,atom.x,atom.y,atom.z,atom.type,residue.subst_id,residue.subst_name,atom.charge,atom.status_bit
0,1,N,9.5601,17.745001,49.130402,N.3,1,LYS838,0.00,BACKBONE
1,2,H,8.8146,18.065201,49.731800,H,1,LYS838,0.00,BACKBONE
2,3,CA,9.4738,16.413900,48.548000,C.3,1,LYS838,0.00,BACKBONE
3,4,HA,10.4599,15.951400,48.590099,H,1,LYS838,0.00,BACKBONE
4,5,C,9.0894,16.442699,47.070202,C.2,1,LYS838,0.00,BACKBONE
...,...,...,...,...,...,...,...,...,...,...
1358,1359,O,5.8970,23.331900,25.846500,O.2,85,ALA1050,-0.57,BACKBONE
1359,1360,CB,4.0712,22.360201,27.595400,C.3,85,ALA1050,-0.24,
1360,1361,HB1,4.4936,23.259001,28.044600,H,85,ALA1050,0.08,
1361,1362,HB2,3.4488,21.846100,28.327801,H,85,ALA1050,0.08,


In [21]:
# Boolean row mask: True for every atom whose mol2 atom type is not "H".
non_hydrogen = pocket_df["atom.type"].ne("H")

In [28]:
# Coordinates of the non-hydrogen pocket atoms, selected in a single .loc step.
coordinate_columns = ["atom.x", "atom.y", "atom.z"]
pocket_df.loc[non_hydrogen, coordinate_columns]

Unnamed: 0,atom.x,atom.y,atom.z
0,9.5601,17.745001,49.130402
2,9.4738,16.413900,48.548000
4,9.0894,16.442699,47.070202
5,8.3816,17.337900,46.612099
6,8.4672,15.594100,49.362000
...,...,...,...
1352,5.9957,21.020500,28.309299
1354,5.2053,21.431801,27.146999
1356,6.0524,22.129400,26.084000
1358,5.8970,23.331900,25.846500


In [30]:
# Shape of the pocket position tensor stored on the first entry (rows x coordinate columns).
data[0]["pocket"].pos.shape

torch.Size([652, 3])

In [29]:
# Print the stored pocket positions so they can be compared by eye with the
# non-hydrogen coordinates read from the mol2 file.
data[0]["pocket"].pos

tensor([[ 9.5601, 17.7450, 49.1304],
        [ 9.4738, 16.4139, 48.5480],
        [ 9.0894, 16.4427, 47.0702],
        ...,
        [ 6.0524, 22.1294, 26.0840],
        [ 5.8970, 23.3319, 25.8465],
        [ 4.0712, 22.3602, 27.5954]])