In [1]:
import os
from torch_geometric.data import Data

from atomsurf.protein.create_esm import get_esm_embedding_single, get_esm_embedding_batch
from atomsurf.utils.data_utils import AtomBatch, PreprocessDataset, pdb_to_surf, pdb_to_graphs
from atomsurf.utils.python_utils import do_all
#from atomsurf.utils.wrappers import DefaultLoader, get_default_model

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Set up data paths
pdb_dir = "example_data/pdb"
surface_dir = "example_data/surfaces_0.1"
rgraph_dir = "example_data/rgraph"
esm_dir = "example_data/esm_emb"
example_name = "4kt3"

In [3]:
# Individual computation
# Set up paths
pdb_path = os.path.join(pdb_dir, f"{example_name}.pdb")
surface_dump = os.path.join(surface_dir, f"{example_name}.pt")
rgraph_dump = os.path.join(rgraph_dir, f"{example_name}.pt")

In [4]:
# Pre-compute surface, graphs and esm embeddings
pdb_to_surf(pdb_path, surface_dump)
pdb_to_graphs(pdb_path, rgraph_dump=rgraph_dump)
get_esm_embedding_single(pdb_path, esm_dir)

pdb_to_graph failed for :  example_data/pdb/4kt3.pdb can't convert np.ndarray of type numpy.str_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.


tensor([[ 0.0488, -0.0112, -0.1231,  ..., -0.0562,  0.1707,  0.1623],
        [-0.1003, -0.0680, -0.2085,  ..., -0.4624,  0.0685,  0.0481],
        [-0.0185, -0.1279, -0.1999,  ..., -0.2618, -0.0534, -0.0102],
        ...,
        [ 0.0123, -0.1870,  0.0046,  ..., -0.0819, -0.0566, -0.0228],
        [ 0.0018, -0.2486,  0.1032,  ..., -0.1365,  0.0102, -0.0234],
        [-0.1018, -0.1732,  0.0069,  ..., -0.2924, -0.0985,  0.1086]])

In [None]:
# Pre-compute surface, graphs and esm embeddings
pdb_to_surf(pdb_path, surface_dump)
pdb_to_graphs(pdb_path, rgraph_dump=rgraph_dump)
get_esm_embedding_single(pdb_path, esm_dir)

# Do the same but automatically on a directory
dataset = PreprocessDataset(data_dir="example_data")
do_all(dataset, num_workers=2)
get_esm_embedding_batch(in_pdbs_dir=pdb_dir, dump_dir=esm_dir)

# Load precomputed files
default_loader = DefaultLoader(surface_dir=surface_dir, graph_dir=rgraph_dir, embeddings_dir=esm_dir)
surface, graph = default_loader(example_name)

# Artifically group in a container and "batch"
protein = Data(surface=surface, graph=graph)
batch = AtomBatch.from_data_list([protein, protein])
print(batch)

# Instantiate a model, based on the dimensionality of the input
in_dim_surface, in_dim_graph = surface.x.shape[-1], graph.x.shape[-1]
atomsurf_model = get_default_model(in_dim_surface, in_dim_graph, model_dim=12)

# Encode your input batch !
surface, graph = atomsurf_model(graph=batch.graph, surface=batch.surface)
surface.x  # (total_n_verts, hidden_dim)
graph.x  # (total_n_nodes, hidden_dim)
print(graph.x.shape)