# Inference

SPDX-FileCopyrightText: Copyright (c) <year> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: LicenseRef-NvidiaProprietary

NVIDIA CORPORATION, its affiliates and licensors retain all intellectual property and proprietary rights in and to this material, related documentation and any modifications thereto. Any use, reproduction, disclosure or distribution of this material and related documentation without an express license agreement from NVIDIA CORPORATION or its affiliates is strictly prohibited.

## Prerequisites

**Before diving in, ensure you have all [necessary prerequisites](https://docs.nvidia.com/bionemo-framework/latest/pre-reqs.html). If this is your first time using BioNeMo, we recommend following the [quickstart guide](https://docs.nvidia.com/bionemo-framework/latest/quickstart-fw.html) first.** 

Additionally, this notebook assumes you have started a [local inference server](https://docs.nvidia.com/bionemo-framework/latest/inference-triton-fw.html) using a pretrained [MegaMolBART](https://docs.nvidia.com/bionemo-framework/latest/models/megamolbart.html) model.

```python
import warnings

from rdkit import Chem
from bionemo.triton.inference_wrapper import new_inference_wrapper

# Silence Python warnings to keep the notebook output readable.
warnings.filterwarnings('ignore')
```

```
INFO:datasets:PyTorch version 2.1.0a0+32f93b1 available.
[NeMo I 2024-05-21 21:19:37 megatron_hiddens:110] Registered hidden transform sampled_var_cond_gaussian at bionemo.model.core.hiddens_support.SampledVarGaussianHiddenTransform
[NeMo I 2024-05-21 21:19:37 megatron_hiddens:110] Registered hidden transform interp_var_cond_gaussian at bionemo.model.core.hiddens_support.InterpVarGaussianHiddenTransform
```


## Setup and test data

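```python
# Connect to the Triton inference server over gRPC (Triton's default gRPC port is 8001).
connection = new_inference_wrapper("grpc://localhost:8001")

# Test molecules: naphthalene and prazosin.
smis = [
    'c1ccc2ccccc2c1',
    'COc1cc2nc(N3CCN(C(=O)c4ccco4)CC3)nc(N)c2cc1OC',
]
```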


## SMILES to hidden state

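`seqs_to_hidden` encodes each SMILES string into the model's latent space. It returns the hidden states, with shape `(batch, positions, hidden_size)`, together with a padding mask of shape `(batch, positions)` that indicates which positions are valid rather than padding. For these two molecules, each hidden state consists of a single 512-dimensional position.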

```python
# Encode the SMILES into hidden states plus a padding mask.
hidden_states, pad_masks = connection.seqs_to_hidden(smis)
print(f"{hidden_states.shape=}")
print(f"{pad_masks.shape=}")

# 2 molecules x 1 position x 512 hidden dimensions
assert tuple(hidden_states.shape) == (2, 1, 512)
assert tuple(pad_masks.shape) == (2, 1)
```

```
hidden_states.shape=torch.Size([2, 1, 512])
pad_masks.shape=torch.Size([2, 1])
```
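To make the role of a padding mask concrete, here is a minimal, framework-free sketch of how a batch of variable-length token sequences is commonly padded to a shared length. The helper `pad_batch` and the pad token value `0` are hypothetical. The convention that `True` marks a real (non-padded) position is an assumption; this notebook does not confirm it for `seqs_to_hidden`.

```python
from typing import List, Tuple


def pad_batch(
    token_ids: List[List[int]], pad_id: int = 0
) -> Tuple[List[List[int]], List[List[bool]]]:
    """Right-pad sequences to a common length and build a padding mask.

    Hypothetical helper: mask[i][j] is True where position j of sequence i
    holds a real token and False where it holds padding.
    """
    max_len = max(len(seq) for seq in token_ids)
    padded = [seq + [pad_id] * (max_len - len(seq)) for seq in token_ids]
    mask = [[j < len(seq) for j in range(max_len)] for seq in token_ids]
    return padded, mask


padded, mask = pad_batch([[5, 7, 9], [4, 2]])
print(padded)  # [[5, 7, 9], [4, 2, 0]]
print(mask)    # [[True, True, True], [True, True, False]]
```

In this notebook each molecule occupies a single position, so the mask returned by the server has shape `(2, 1)` and padding does not come into play.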


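## Hidden state to embeddings

`hiddens_to_embedding` pools the hidden states into one fixed-size vector per molecule. It takes the padding mask as well, so padded positions do not contribute to the result.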

```python
# Pool the hidden states into one 512-dimensional embedding per molecule.
embeddings = connection.hiddens_to_embedding(hidden_states, pad_masks)
print(f"{embeddings.shape=}")
assert tuple(embeddings.shape) == (2, 512)
```

```
embeddings.shape=torch.Size([2, 512])
```
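One common way to pool hidden states under a padding mask is a masked mean: average only the positions that the mask marks as valid. The sketch below is plain Python with a hypothetical `masked_mean_pool` helper. Whether `hiddens_to_embedding` uses exactly this reduction is an assumption, not something this notebook shows.

```python
from typing import List, Sequence


def masked_mean_pool(
    hidden_states: Sequence[Sequence[Sequence[float]]],
    pad_mask: Sequence[Sequence[bool]],
) -> List[List[float]]:
    """Average each sequence's hidden vectors over unmasked positions.

    hidden_states: [batch][positions][hidden_size]
    pad_mask:      [batch][positions], True marks a real (non-padded) position
    Returns:       [batch][hidden_size]
    """
    pooled = []
    for states, mask in zip(hidden_states, pad_mask):
        valid = [vec for vec, keep in zip(states, mask) if keep]
        if not valid:
            raise ValueError("sequence has no unmasked positions")
        hidden_size = len(valid[0])
        pooled.append(
            [sum(vec[d] for vec in valid) / len(valid) for d in range(hidden_size)]
        )
    return pooled


# Two sequences, 3 positions, hidden size 2; the second has one padded position.
states = [
    [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
    [[2.0, 0.0], [4.0, 2.0], [100.0, 100.0]],
]
mask = [[True, True, True], [True, True, False]]
print(masked_mean_pool(states, mask))  # [[3.0, 4.0], [3.0, 1.0]]
```

The padded position `[100.0, 100.0]` has no effect on the second embedding, which is the point of passing the mask alongside the hidden states.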


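## SMILES to embeddings

`seqs_to_embedding` goes from SMILES directly to embeddings in a single call. It returns the same `(2, 512)` shape as running `seqs_to_hidden` followed by `hiddens_to_embedding`.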

```python
# Single call: SMILES -> one embedding per molecule.
embedding = connection.seqs_to_embedding(smis)
print(f"{embedding.shape=}")
assert tuple(embedding.shape) == (2, 512)
```

```
embedding.shape=torch.Size([2, 512])
```
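Conceptually, the one-step call can be viewed as the composition of the two previous calls. The sketch below expresses that data flow with a hypothetical `StubConnection` class that returns dummy values, so it runs without a server. It illustrates the composition under that assumption and is not BioNeMo's implementation.

```python
from typing import List, Tuple


class StubConnection:
    """Hypothetical stand-in for the inference wrapper, for illustration only."""

    hidden_size = 4

    def seqs_to_hidden(self, seqs: List[str]) -> Tuple[list, list]:
        # One position per sequence, filled with the SMILES length as a dummy value.
        hidden = [[[float(len(s))] * self.hidden_size] for s in seqs]
        mask = [[True] for _ in seqs]
        return hidden, mask

    def hiddens_to_embedding(self, hidden: list, mask: list) -> list:
        # Keep the first unmasked position of each sequence (enough for one position).
        return [next(vec for vec, keep in zip(h, m) if keep) for h, m in zip(hidden, mask)]

    def seqs_to_embedding(self, seqs: List[str]) -> list:
        # Composition of the two steps above.
        return self.hiddens_to_embedding(*self.seqs_to_hidden(seqs))


conn = StubConnection()
example_smis = ['c1ccc2ccccc2c1', 'CCO']
two_step = conn.hiddens_to_embedding(*conn.seqs_to_hidden(example_smis))
one_step = conn.seqs_to_embedding(example_smis)
print(one_step == two_step)  # True
```

With the real server, the analogous check would compare the `embedding` and `embeddings` tensors from the cells above (for example with `torch.allclose`) rather than using `==`.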


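## Hidden state to SMILES

`hidden_to_seqs` decodes hidden states back into SMILES. Because the same molecule can be written as many different SMILES strings, the decoded strings are canonicalized with RDKit before they are compared with the inputs.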

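```python
def canonicalize_smiles(smiles: str) -> str:
    """Return the RDKit canonical form of a SMILES string.

    SMILES that RDKit cannot parse are returned unchanged.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:  # MolFromSmiles returns None for invalid SMILES
        return smiles
    return Chem.MolToSmiles(mol, canonical=True)
```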

```python
# Decode the hidden states back into SMILES, then canonicalize them.
inferred_smis = connection.hidden_to_seqs(hidden_states, pad_masks)
canon_inferred_smis = [canonicalize_smiles(s) for s in inferred_smis]
print(f"Reconstructed SMILES:\n{canon_inferred_smis}")
assert len(canon_inferred_smis) == 2
```

```
Reconstructed SMILES:
['c1ccc2ccccc2c1', 'COc1cc2nc(N3CCN(C(=O)c4ccco4)CC3)nc(N)c2cc1OC']
```
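A natural follow-up is to measure how many inputs survive the encode/decode round trip. The sketch below computes that fraction with a hypothetical `round_trip_accuracy` helper. It takes the canonicalization function as a parameter, so it can run here with a trivial stand-in instead of RDKit; in the notebook you would pass `canonicalize_smiles`.

```python
from typing import Callable, Sequence


def round_trip_accuracy(
    inputs: Sequence[str],
    reconstructed: Sequence[str],
    canonicalize: Callable[[str], str] = lambda s: s,
) -> float:
    """Fraction of inputs whose canonical form matches the reconstruction's."""
    if len(inputs) != len(reconstructed):
        raise ValueError("inputs and reconstructed must have the same length")
    if not inputs:
        return 0.0
    matches = sum(
        canonicalize(a) == canonicalize(b) for a, b in zip(inputs, reconstructed)
    )
    return matches / len(inputs)


print(round_trip_accuracy(["CCO", "c1ccccc1"], ["CCO", "c1ccncc1"]))  # 0.5
```

Called with the input SMILES, the decoded SMILES, and `canonicalize_smiles`, this would give `1.0` for the two molecules above, since both reconstructions match their inputs.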


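## Sampling

`sample_seqs` generates new molecules conditioned on each input SMILES. It returns one list of samples per input; with the settings used here, each list holds a single sample.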

```python
# Sample new SMILES conditioned on each input molecule.
samples = connection.sample_seqs(seqs=smis)
print(f"Generated {len(samples)} samples")
assert len(samples) == 2
for i, s in enumerate(samples, start=1):
    print(f"Sample #{i} (length: {len(s)}):\n{s}\n-----------------------")
    assert len(s) == 1
```

```
Generated 2 samples
Sample #1 (length: 1):
['c1ccc2c(c1)C[C@H](CN[C@H]1COC3(CCC3)C1)O2']
-----------------------
Sample #2 (length: 1):
['COc1cc2nc(N3CCN(C(=O)C4CC4)CC3)nc(N(C)C)c2cc1OC']
-----------------------
```
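Because `sample_seqs` returns a list of lists, a typical next step is to discard duplicates and samples that simply reproduce their input. The hypothetical helper `novel_samples` below does this in plain Python. It compares raw strings, so in practice you would canonicalize inputs and samples first (for example with `canonicalize_smiles`) so that different spellings of the same molecule count as equal.

```python
from typing import List, Sequence


def novel_samples(
    inputs: Sequence[str], samples: Sequence[Sequence[str]]
) -> List[List[str]]:
    """For each input, keep its unique samples that differ from the input.

    The output stays aligned with `inputs`, and the order of first appearance
    is preserved. Strings are compared exactly as given.
    """
    result = []
    for smiles, group in zip(inputs, samples):
        seen = set()
        kept = []
        for s in group:
            if s != smiles and s not in seen:
                seen.add(s)
                kept.append(s)
        result.append(kept)
    return result


print(novel_samples(["CCO", "CCN"], [["CCO", "CCCO", "CCCO"], ["CCNC"]]))
# [['CCCO'], ['CCNC']]
```

Returning a list aligned with the inputs, rather than a dict keyed by SMILES, keeps the results correct even if the same input appears twice in a batch.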
