# Inference Sample for ESM2

Copyright (c) 2022, NVIDIA CORPORATION. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.


### Prerequisites

- Linux OS
- Pascal, Volta, Turing, or an NVIDIA Ampere architecture-based GPU.
- NVIDIA Driver
- Docker

Components for inference are part of the BioNeMo ESM source code. This notebook demonstrates the use of these components.

To run this notebook, you must first launch the PyTriton-based model server script. See `bionemo.triton.inference_wrapper` for instructions.

In [None]:
import warnings

from bionemo.triton.inference_wrapper import new_inference_wrapper

# Suppress all warnings to keep the notebook output readable.
warnings.simplefilter('ignore')

### Setup and Test Data

`new_inference_wrapper` creates a client that communicates with the Triton model server.

Note that the batch size (the number of sequences submitted in one embedding request), and therefore the inference throughput, may be limited by the compute capacity of the inference node hosting the model.

In [None]:
connection = new_inference_wrapper("grpc://localhost:8001")

seqs = [
    'MSLKRKNIALIPAAGIGVRFGADKPKQYVEIGSKTVLEHVL',
    'MIQSQINRNIRLDLADAILLSKAKKDLSFAEIADGTGLA',
]
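
Because the batch size may be bounded by the server's capacity, a long list of sequences can be split into smaller requests. The following is a minimal sketch of such chunking in plain Python. The helper name `iter_batches`, the example sequences, and the batch size of 2 are illustrative assumptions and are not part of the BioNeMo API.

```python
from typing import Iterator, List, Sequence


def iter_batches(items: Sequence[str], batch_size: int) -> Iterator[List[str]]:
    """Yield consecutive slices of `items` holding at most `batch_size` entries."""
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    for start in range(0, len(items), batch_size):
        yield list(items[start:start + batch_size])


# Hypothetical workload: five short sequences, at most two per request.
example_seqs = ["MSLKRK", "MIQSQ", "GADKPK", "LSKAKK", "EIADG"]
batches = list(iter_batches(example_seqs, batch_size=2))
print([len(b) for b in batches])  # number of sequences in each request
```

Each batch could then be passed to the client one request at a time. A suitable batch size depends on the hardware and is best found by experiment.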

### Sequence to Hidden States

__`seqs_to_hidden`__ queries the model for the encoder hidden states of the input protein sequences. The encoder mask (`enc_mask`, named `pad_masks` in the cell below) is returned together with `hidden_states` and contains the padding information.

In [None]:
hidden_states, pad_masks = connection.seqs_to_hidden(seqs)
print(f"{hidden_states.shape=}")
print(f"{pad_masks.shape=}")
# Expected shapes: (batch, padded length, hidden size) and (batch, padded length).
assert tuple(hidden_states.shape) == (2, 43, 1280)
assert tuple(pad_masks.shape) == (2, 43)
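
To illustrate how the padding information can be used, the sketch below recovers the unpadded hidden states of each sequence from a mask. It runs on small synthetic NumPy arrays rather than on the server output. It assumes that nonzero mask entries mark real tokens and that the returned arrays can be converted to NumPy, neither of which this notebook confirms.

```python
import numpy as np

# Synthetic stand-ins with a tiny hidden size; the real arrays would come
# from connection.seqs_to_hidden(seqs).
rng = np.random.default_rng(0)
hidden = rng.normal(size=(2, 6, 4))                 # (batch, padded length, hidden size)
mask = np.array([[1, 1, 1, 1, 1, 1],
                 [1, 1, 1, 1, 0, 0]], dtype=bool)   # assumption: True marks a real token

# Number of real (non-padding) positions per sequence.
token_counts = mask.sum(axis=1)

# Boolean indexing drops the padded positions of each sequence.
per_sequence = [hidden[i, mask[i]] for i in range(hidden.shape[0])]

print(token_counts)
print([h.shape for h in per_sequence])
```

In the cell above, the padded length of 43 is consistent with the longest input (41 residues) plus two extra token positions, although this notebook does not state how the tokenizer adds them.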

### Hidden States to Embedding

__`hiddens_to_embedding`__ computes one embedding vector per sequence by averaging `hidden_states` along the sequence dimension, with `pad_masks` supplying the padding information.

In [None]:
embeddings = connection.hiddens_to_embedding(hidden_states, pad_masks)
print(f"{embeddings.shape=}")
assert tuple(embeddings.shape) == (2, 1280)
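
The averaging step can be pictured as a masked mean over the sequence dimension. The NumPy sketch below shows that idea on a tiny hand-made example. It is not the BioNeMo implementation; the function name `masked_mean` and the convention that nonzero mask entries mark real tokens are assumptions made for illustration.

```python
import numpy as np


def masked_mean(hidden: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Average hidden states over the positions where `mask` is nonzero."""
    weights = mask.astype(hidden.dtype)[..., None]     # (batch, length, 1)
    summed = (hidden * weights).sum(axis=1)            # (batch, hidden size)
    counts = np.clip(weights.sum(axis=1), 1.0, None)   # (batch, 1); guards against division by zero
    return summed / counts


# Two sequences, three positions, hidden size 2. The last position of the
# first sequence is padding and must not affect the average.
hidden = np.array([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]],
                   [[2.0, 2.0], [4.0, 6.0], [6.0, 10.0]]])
mask = np.array([[1, 1, 0],
                 [1, 1, 1]])

print(masked_mean(hidden, mask))
```

The first row averages only its two real positions, giving `[2.0, 3.0]`, while the second averages all three, giving `[4.0, 6.0]`.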

### Sequence to Embedding

__`seqs_to_embedding`__ queries the model for the encoder hidden states and computes the embedding vectors in a single call.

In [None]:
embeddings = connection.seqs_to_embedding(seqs)
print(f"{embeddings.shape=}")
assert tuple(embeddings.shape) == (2, 1280)