In [1]:
from kinodata.data.dataset import KinodataDocked

  from .autonotebook import tqdm as notebook_tqdm


## Create the dataset
Creating the dataset for the first time triggers the [main data processing](https://github.com/volkamerlab/kinodata-3D-affinity-prediction/blob/b6f795d82b612629ae07e96c8f497a9a73b8d778/kinodata/data/dataset.py#L270);
the result is cached, so subsequent instantiations reuse it instead of reprocessing.

In [2]:
# The first instantiation runs the full data processing; later runs reuse the cached result.
dataset = KinodataDocked()
# Bare last expression -> rich display of the dataset repr (shows the dataset size).
dataset

KinodataDocked(104836)

## Data representation and data loading
Docked complexes are stored and represented as [heterogeneous graph data](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.HeteroData.html?highlight=heterodata) objects.

In [3]:
# Inspect a single, arbitrarily chosen docked complex.
data = dataset[42]

# metadata() returns the node types and the edge types of the heterogeneous graph.
node_types, edge_types = data.metadata()
# Node types are already strings; edge types are tuples and must be stringified first.
print(f"Node types: {', '.join(node_types)}")
print(f"Edge types: {', '.join(map(str, edge_types))}")
# Each node store's feature matrix `x` has one row per node, so size(0) is the node count.
print(f"Number of ligand heavy atoms: {data['ligand'].x.size(0)}")
print(f"Number of pocket heavy atoms: {data['pocket'].x.size(0)}")

Node types: ligand, pocket, pocket_residue
Edge types: ('ligand', 'bond', 'ligand'), ('pocket', 'bond', 'pocket')
Number of ligand heavy atoms: 13
Number of pocket heavy atoms: 652


Typically, you'll want to split the data into train, test (and validation) sets.

In [19]:
from kinodata.data.data_module import make_data_module
from kinodata.data.grouped_split import KinodataKFoldSplit

In [24]:
# Build a 5-fold "random-k-fold" splitter over the dataset and keep the first
# element of the result (presumably the first fold's split) for this walkthrough.
# NOTE(review): no seed is passed here -- confirm whether the random split is
# reproducible across kernel restarts.
kfold_splitter = KinodataKFoldSplit("random-k-fold", k=5)
all_folds = kfold_splitter.split(dataset)
demo_split = all_folds[0]

In [29]:
# Building the data module can take a while (expect ~60 seconds).
# No extra keyword arguments are passed for the training dataset.
data_module = make_data_module(
    split=demo_split,
    dataset_cls=KinodataDocked,
    batch_size=32,
    num_workers=0,  # presumably forwarded to the DataLoader (0 = main-process loading) -- confirm
    train_kwargs={},
)

In [32]:
# Pull a single training batch to see how complexes are collated together.
train_loader = data_module.train_dataloader()
batch = next(iter(train_loader))
batch

HeteroDataBatch(
  kissim_fp=[32, 85, 12],
  y=[32],
  docking_score=[32],
  posit_prob=[32],
  predicted_rmsd=[32],
  pocket_sequence=[32],
  scaffold=[32],
  activity_type=[32],
  ident=[32],
  smiles=[32],
  [1mligand[0m={
    z=[1019],
    x=[1019, 12],
    pos=[1019, 3],
    batch=[1019],
    ptr=[33]
  },
  [1mpocket[0m={
    z=[21026],
    x=[21026, 12],
    pos=[21026, 3],
    batch=[21026],
    ptr=[33]
  },
  [1mpocket_residue[0m={
    x=[2720, 23],
    batch=[2720],
    ptr=[33]
  },
  [1m(ligand, bond, ligand)[0m={
    edge_index=[2, 2260],
    edge_attr=[2260, 4]
  },
  [1m(pocket, bond, pocket)[0m={
    edge_index=[2, 42376],
    edge_attr=[42376, 4]
  }
)

## Make absolute structural information relative

### Complex graph representation

In [33]:
import kinodata.transform as T

To use the complex graph representation from our publication, apply the `TransformToComplexGraph` transform:

In [10]:
# Same dataset, but with a transform that builds the complex graph representation.
# With remove_heterogeneous_representation=True, the separate ligand/pocket node
# types no longer appear in the transformed data (compare the outputs of the two
# inspection cells).
# NOTE(review): this rebinds `dataset`, shadowing the untransformed dataset from above.
complex_graph_transform = T.TransformToComplexGraph(
    remove_heterogeneous_representation=True,
)
dataset = KinodataDocked(transform=complex_graph_transform)

In [15]:
# Inspect the same example complex, now in the complex graph representation.
data = dataset[42]

# metadata() returns the node types and the edge types of the (transformed) graph.
node_types, edge_types = data.metadata()
# Node types are already strings; edge types are tuples and must be stringified first.
print(f"Node types: {', '.join(node_types)}")
print(f"Edge types: {', '.join(map(str, edge_types))}")
# One row of `x` per node in the merged "complex" node store.
print(f"Number of complex heavy atoms: {data['complex'].x.size(0)}")

Node types: pocket_residue, complex
Edge types: ('complex', 'bond', 'complex')
Number of complex heavy atoms: 665


In [34]:
# Building the data module can take a while (expect ~60 seconds).
# The complex-graph transform is supplied through train_kwargs. Unlike the dataset
# above, it uses default arguments (remove_heterogeneous_representation is not set).
# NOTE(review): only train_kwargs receives the transform -- confirm whether the
# validation/test datasets should be transformed as well.
data_module = make_data_module(
    split=demo_split,
    dataset_cls=KinodataDocked,
    batch_size=32,
    num_workers=0,
    train_kwargs={"transform": T.TransformToComplexGraph()},
)

In [43]:
# Take the first training batch from the complex-graph data module.
# NOTE(review): this cell only assigns, so on a fresh Restart & Run All it displays
# nothing. The saved tensor output below (it looks like atom coordinates) came from
# an expression that is no longer in this cell. Re-run the cell to refresh or clear it.
train_loader = data_module.train_dataloader()
demo_batch = next(iter(train_loader))

tensor([[ 8.9015, 14.1521, 51.9454],
        [ 7.9257, 13.2832, 51.2874],
        [ 7.6748, 13.6848, 49.8207],
        ...,
        [-2.4295, 25.8731, 30.9331],
        [-2.9853, 23.7691, 30.9705],
        [-2.3645, 24.7632, 32.7969]])

Our transformer adds edges on the fly using a dedicated torch module, `StructuralInteractions`.

In [37]:
from kinodata.model.complex_transformer import StructuralInteractions

In [39]:
# Module that adds edges on the fly for the transformer.
# NOTE(review): the positional 32 is presumably the hidden/embedding size -- confirm
# against the StructuralInteractions signature.
interaction_module = StructuralInteractions(
    32,
    rbf_size=32,
    max_num_neighbors=16,
    interaction_radius=5.0,  # distance cutoff; units presumably Angstrom -- confirm
)

In [45]:
# Compute the on-the-fly edges for the demo batch. The call returns three values;
# the middle one is not needed here and is discarded.
interaction_result = interaction_module.interactions(demo_batch)
edge_index, _, distances = interaction_result

### Adding proximity-based edges and distances to other representations

In [None]:
# TODO: this section has not been written yet; the bare Ellipsis is only a placeholder.
...