# Inference Sample

Copyright (c) 2022, NVIDIA CORPORATION. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0 

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

### Prerequisites
* Linux OS
* Pascal, Volta, Turing, or an NVIDIA Ampere architecture-based GPU.
* NVIDIA Driver
* Docker

### Import
Components for inference are part of the BioNeMo MegaMolBART source code. This notebook demonstrates how to use them.

`MegaMolBARTInference` implements the following functions, each of which is used below:
* `seq_to_hiddens`
* `hiddens_to_embedding`
* `seq_to_embeddings`
* `hiddens_to_seq`
* `sample`

Note that gRPC limits request size to 4MB.
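
When many SMILES are sent over gRPC, one way to respect the size limit is to split the input into batches by byte count. The following is a minimal sketch, not part of BioNeMo: `batch_by_bytes` and the overhead margin are hypothetical, and the real serialized request size depends on the message format, so the margin should be treated as a guess.

```python
from typing import Iterator, List

# Assumption: 4 MB request cap taken from the note above; the margin for
# protocol overhead is an illustrative guess, not a measured value.
MAX_REQUEST_BYTES = 4 * 1024 * 1024
OVERHEAD_MARGIN_BYTES = 64 * 1024


def batch_by_bytes(
    smiles: List[str],
    max_bytes: int = MAX_REQUEST_BYTES - OVERHEAD_MARGIN_BYTES,
) -> Iterator[List[str]]:
    """Yield consecutive batches whose UTF-8 payload stays within max_bytes."""
    batch: List[str] = []
    batch_bytes = 0
    for smi in smiles:
        size = len(smi.encode("utf-8"))
        if size > max_bytes:
            raise ValueError(f"Single SMILES of {size} bytes exceeds the limit of {max_bytes}")
        # Start a new batch when adding this string would exceed the budget.
        if batch and batch_bytes + size > max_bytes:
            yield batch
            batch, batch_bytes = [], 0
        batch.append(smi)
        batch_bytes += size
    if batch:
        yield batch


# Tiny limit chosen only to make the splitting visible.
batches = list(batch_by_bytes(["c1ccc2ccccc2c1", "CCO", "CC(=O)O"], max_bytes=16))
print(batches)
```

Each batch can then be sent as its own request and the results concatenated in order.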

In [None]:
import warnings

warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')

In [None]:
from typing import List
from pathlib import Path
import os

try:
    BIONEMO_HOME: Path = Path(os.environ['BIONEMO_HOME']).absolute()
except KeyError:
    print("Must have BIONEMO_HOME set in the environment! See docs for instructions.")
    raise

config_path = BIONEMO_HOME / "examples" / "molecule" / "megamolbart" / "conf"
print(f"Using model configuration at: {config_path}")
assert config_path.is_dir()

### Setup and Test Data

The cells below define two test SMILES strings, load the model configuration, and use `load_model_for_inference` to create the inference object.

In [None]:
smis = [
    'c1ccc2ccccc2c1',
    'COc1cc2nc(N3CCN(C(=O)c4ccco4)CC3)nc(N)c2cc1OC',
]

In [None]:
from bionemo.triton.utils import load_model_config

cfg = load_model_config(config_path, config_name="infer.yaml")

In [None]:
from bionemo.triton.utils import load_model_for_inference
from bionemo.model.molecule.megamolbart.infer import MegaMolBARTInference

inferer = load_model_for_inference(cfg, interactive=True)

print(f"Loaded a {type(inferer)}")
assert isinstance(inferer, MegaMolBARTInference)

### SMILES to hidden state

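`seq_to_hiddens` returns the model's latent-space representation of the input SMILES together with a padding mask. `hiddens_to_embedding` then reduces the hidden states to one embedding per molecule.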

In [None]:
hidden_states, pad_masks = inferer.seq_to_hiddens(smis)
print(f"{hidden_states.shape=}")
print(f"{pad_masks.shape=}")

assert tuple(hidden_states.shape) == (2, 45, 512)
assert tuple(pad_masks.shape) == (2, 45)
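
The second dimension of both shapes is the padded sequence length: shorter inputs are padded up to the longest one in the batch, and the mask records which positions hold real tokens. The sketch below illustrates the idea in plain Python. It assumes one token per character and a mask that is `True` for real tokens; the model's actual tokenizer and mask convention may differ, and `pad_batch` and `PAD` are hypothetical names.

```python
from typing import List, Tuple

PAD = "<PAD>"  # hypothetical padding token used only in this sketch


def pad_batch(smiles: List[str]) -> Tuple[List[List[str]], List[List[bool]]]:
    """Pad character-level token lists to a common length and build a mask."""
    # Assumption: one token per character; the real tokenizer may differ.
    tokens = [list(smi) for smi in smiles]
    max_len = max((len(t) for t in tokens), default=0)
    padded = [t + [PAD] * (max_len - len(t)) for t in tokens]
    # True marks a real token, False marks padding.
    mask = [[i < len(t) for i in range(max_len)] for t in tokens]
    return padded, mask


padded, mask = pad_batch([
    "c1ccc2ccccc2c1",
    "COc1cc2nc(N3CCN(C(=O)c4ccco4)CC3)nc(N)c2cc1OC",
])
print(len(padded), len(padded[0]))  # batch size and padded length
print([sum(row) for row in mask])   # number of real tokens per row
```

For these two inputs the character-level padded length is 45, which matches the sequence dimension asserted above.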

In [None]:
embeddings = inferer.hiddens_to_embedding(hidden_states, pad_masks)
print(f"{embeddings.shape=}")
assert tuple(embeddings.shape) == (2, 512)

### SMILES to Embedding

`seq_to_embeddings` queries the model to fetch the encoder embedding for the input SMILES.

In [None]:
embedding = inferer.seq_to_embeddings(smis)
print(f"{embedding.shape=}")
assert tuple(embedding.shape) == (2, 512)

Note that this is equivalent to first producing the hidden representation with `seq_to_hiddens`, then passing the hidden states and padding masks to `hiddens_to_embedding`.
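
One common way to turn per-token hidden states and a padding mask into a single vector per molecule is masked mean pooling: average the hidden vectors over the unmasked positions only. The sketch below shows that reduction on a toy batch in plain Python. That `hiddens_to_embedding` uses exactly this pooling is an assumption here, and `masked_mean_pool` is a hypothetical helper, not a BioNeMo function.

```python
from typing import List

Vector = List[float]


def masked_mean_pool(hidden: List[List[Vector]], mask: List[List[bool]]) -> List[Vector]:
    """Average token vectors per sequence, skipping positions where mask is False."""
    embeddings: List[Vector] = []
    for seq, seq_mask in zip(hidden, mask):
        kept = [vec for vec, keep in zip(seq, seq_mask) if keep]
        if not kept:
            raise ValueError("Sequence has no unmasked tokens")
        dim = len(kept[0])
        embeddings.append([sum(vec[d] for vec in kept) / len(kept) for d in range(dim)])
    return embeddings


# Toy batch: 2 sequences, 3 positions, hidden size 2.
hidden = [
    [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]],  # last position is padding
    [[0.0, 0.0], [2.0, 2.0], [4.0, 10.0]],
]
mask = [[True, True, False], [True, True, True]]
pooled = masked_mean_pool(hidden, mask)
print(pooled)  # one vector per sequence
```

The padded position in the first sequence does not affect its pooled vector, which is why the mask has to accompany the hidden states.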

### Hidden state to SMILES

`hiddens_to_seq` decodes the latent space representation back to SMILES.

In [None]:
from rdkit import Chem


def canonicalize_smiles(smiles: str) -> str:
    """Canonicalize input SMILES"""
    mol = Chem.MolFromSmiles(smiles)
    # MolFromSmiles returns None when the string cannot be parsed.
    if mol is None:
        raise ValueError(f"Invalid SMILES: {smiles!r}")
    return Chem.MolToSmiles(mol, canonical=True)

In [None]:
inferred_smis = inferer.hiddens_to_seq(hidden_states, pad_masks)
canon_inferred_smis = list(map(canonicalize_smiles, inferred_smis))
print(f"Reconstructed SMILES:\n{canon_inferred_smis}")
assert len(canon_inferred_smis) == 2
for i, (original, reconstructed) in enumerate(zip(smis, canon_inferred_smis)):
    assert original == reconstructed, f"Failure to reconstruct on #{i+1}: {original=}, {reconstructed=}"

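### Sampling: Generate SMILES

`sample` generates new SMILES strings from an input SMILES. The cell below requests 10 samples per input using the `greedy-perturbate` sampling method.
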
In [None]:
def sample(smile: str) -> List[str]:
    return inferer.sample(num_samples=10, return_embedding=False, sampling_method="greedy-perturbate", smis=[smile])

samples = [sample(smile) for smile in smis]
print(f"Generated samples for {len(samples)} input SMILES")

assert len(samples) == 2
for i, s in enumerate(samples):
    print(f"Sample #{i+1} (length: {len(s)}):\n{s}\n-----------------------")
    assert len(s) == 10
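
A quick sanity check on generated samples is to count how many are distinct and how many differ from the input. The sketch below does this with plain string comparison; `summarize_samples` is a hypothetical helper and the example strings are made up, not model output.

```python
from typing import Dict, List


def summarize_samples(source: str, generated: List[str]) -> Dict[str, int]:
    """Count total, unique, and novel (different from the source) strings."""
    unique = set(generated)
    return {
        "total": len(generated),
        "unique": len(unique),
        "novel": len(unique - {source}),
    }


# Made-up strings standing in for model output.
example = summarize_samples("CCO", ["CCO", "CCN", "CCN", "CCC"])
print(example)
```

Because the comparison is textual, two different spellings of the same molecule count as distinct; canonicalizing both the input and the samples first (for example with `canonicalize_smiles` above) gives a chemically meaningful count.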