# Inference Sample for OpenFold

SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \
SPDX-License-Identifier: LicenseRef-NvidiaProprietary

NVIDIA CORPORATION, its affiliates and licensors retain all intellectual property and proprietary rights in and to this material, related documentation and any modifications thereto. Any use, reproduction, disclosure or distribution of this material and related documentation without an express license agreement from NVIDIA CORPORATION or its affiliates is strictly prohibited.

### Prerequisites

- Linux OS
- Pascal, Volta, Turing, or an NVIDIA Ampere architecture-based GPU.
- NVIDIA Driver
- Docker

#### Import

The components used for inference are part of the BioNeMo OpenFold source code. This notebook demonstrates how to use them to run structure prediction through a Triton inference server.


In [None]:
import os
import numpy as np
import py3Dmol
from pytriton.client import ModelClient
from bionemo.triton.utils import read_bytes_from_filepaths
from typing import List

In [2]:
BIONEMO_HOME = os.environ['BIONEMO_HOME']
TEST_DATA_DIR = os.path.join(BIONEMO_HOME, 'examples/tests/test_data/openfold_data')
MSA_DIR = os.path.join(TEST_DATA_DIR, 'inference', 'msas')

# Start inference server
Before starting the inference server, download the OpenFold NeMo checkpoints with `download_artifacts.py`:

`python download_artifacts.py --models openfold_finetuning_inhouse --model_dir ${BIONEMO_HOME}/models`

Next, start the inference server:

`python examples/protein/openfold/nbs/infer_server.py`

If you get an error similar to

`PermissionError: [Errno 13] Permission denied: '/workspace/bionemo/.cache/pytriton/workspace_* .... UserWarning: The version_base parameter is not specified. Please specify a compatability version level, or None. Will assume defaults for version 1.1 with initialize(config_path=config_path)`

run `rm -rf <path to bionemo repo clone>/.cache/pytriton` from a shell outside the running container, and then retry.

# Inputs
## Set input sequences and optional MSAs
For inference with MSAs, precomputed MSAs for the example sequences are provided in `examples/tests/test_data/openfold_data/inference/msas`.
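The MSA files are organized by chain (for example `7dnu_A/`), each directory holding `.a3m` files. Rather than listing every file by hand, a small helper can collect all `.a3m` files for a chain. This is a minimal sketch; `collect_a3m_paths` is a hypothetical helper, not part of BioNeMo.

```python
import os
from typing import List


def collect_a3m_paths(msa_dir: str, chain_id: str) -> List[str]:
    """Hypothetical helper: return the sorted paths of all .a3m files for one chain."""
    chain_dir = os.path.join(msa_dir, chain_id)
    if not os.path.isdir(chain_dir):
        raise FileNotFoundError(f'No MSA directory for chain {chain_id!r}: {chain_dir}')
    # Keep only A3M files; sorting makes the order deterministic
    return sorted(
        os.path.join(chain_dir, name)
        for name in os.listdir(chain_dir)
        if name.endswith('.a3m')
    )
```

For example, `collect_a3m_paths(MSA_DIR, '7dnu_A')` returns every `.a3m` file in that chain's directory in alphabetical order, which could replace a hand-written list in the MSA cell below.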

In [3]:
sequences = [
    # Both entries are the same chain (7DNU_A); they are paired with different MSA inputs below
    'MDTAMQLKTSIGLITCRMNTQNNQIETILVQKRYSLAFSEFIHCHYSINANQGHLIKMFNNMTINERLLVKTLDFDRMWYHIWIETPVYELYHKKYQKFRKNWLLPDNGKKLISLINQAKGSGTLLWEIPKGKPKEDESDLTCAIREFEEETGITREYYQILPEFKKSMSYFDGKTEYKHIYFLAMLCKSLEEPNMNLSLQYENRIAEISKISWQNMEAVRFISKRQSFNLEPMIGPAFNFIKNYLRYKH', # 7DNU_A
    'MDTAMQLKTSIGLITCRMNTQNNQIETILVQKRYSLAFSEFIHCHYSINANQGHLIKMFNNMTINERLLVKTLDFDRMWYHIWIETPVYELYHKKYQKFRKNWLLPDNGKKLISLINQAKGSGTLLWEIPKGKPKEDESDLTCAIREFEEETGITREYYQILPEFKKSMSYFDGKTEYKHIYFLAMLCKSLEEPNMNLSLQYENRIAEISKISWQNMEAVRFISKRQSFNLEPMIGPAFNFIKNYLRYKH', # 7DNU_A
]

### Prepare for optional MSA inputs

In [4]:
# Sample MSA inputs: the first sample uses only the BFD/UniClust hits, the second uses all three MSA files
msa_a3m_filepaths: List[List[str]] = [
    [
        os.path.join(MSA_DIR, '7dnu_A', 'bfd_uniclust_hits.a3m'),
        # os.path.join(MSA_DIR, '7dnu_A', 'mgnify_hits.a3m'),
        # os.path.join(MSA_DIR, '7dnu_A', 'uniref90_hits.a3m'),
    ],
    [
        os.path.join(MSA_DIR, '7dnu_A', 'bfd_uniclust_hits.a3m'),
        os.path.join(MSA_DIR, '7dnu_A', 'mgnify_hits.a3m'),
        os.path.join(MSA_DIR, '7dnu_A', 'uniref90_hits.a3m'),
    ],
]

# Alternatively, skip MSA inputs by passing one empty list per sequence:
# msa_a3m_filepaths: List[List[str]] = [[], []]
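A3M is a FASTA-like alignment format used by HH-suite: each record starts with a `>` header line, and lowercase letters in an aligned sequence mark insertions relative to the query, so they do not occupy query columns. Before sending MSAs to the server, it can help to check how deep each file is and that every record spans the query. This is a minimal sketch; `parse_a3m` and `aligned_length` are hypothetical helpers, not BioNeMo functions, and they assume plain-text A3M input.

```python
from typing import List, Optional, Tuple


def parse_a3m(text: str) -> List[Tuple[str, str]]:
    """Hypothetical helper: split A3M text into (header, sequence) pairs."""
    records: List[Tuple[str, str]] = []
    header: Optional[str] = None
    chunks: List[str] = []
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith('#'):
            continue  # skip blank lines and optional comment lines
        if line.startswith('>'):
            if header is not None:
                records.append((header, ''.join(chunks)))
            header, chunks = line[1:], []
        else:
            chunks.append(line)  # sequences may be wrapped over several lines
    if header is not None:
        records.append((header, ''.join(chunks)))
    return records


def aligned_length(a3m_seq: str) -> int:
    """Number of query columns: lowercase insertions are not counted."""
    return sum(1 for c in a3m_seq if not c.islower())
```

For example, `len(parse_a3m(text))` gives the MSA depth of a file read into `text`, and `aligned_length` of every record should equal the length of the query sequence.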

In [5]:
# Check that every sequence has a matching (possibly empty) list of MSA files
if len(sequences) != len(msa_a3m_filepaths):
    raise ValueError(
        f'Sequences and msa_a3m_filepaths have inconsistent numbers of inference samples. '
        f'Got {len(sequences)} sequences and {len(msa_a3m_filepaths)} MSA lists.'
    )

# Read the MSA file contents and pad every sample to the same number of entries for Triton
max_length = max(len(msa_a3m_filepaths_) for msa_a3m_filepaths_ in msa_a3m_filepaths)

if max_length:
    msa_a3m_file_contents = []
    for msa_a3m_filepaths_ in msa_a3m_filepaths:
        msa_a3m_file_contents_: List[bytes] = read_bytes_from_filepaths(*msa_a3m_filepaths_)  # read file contents
        msa_a3m_file_contents_ += [''] * (max_length - len(msa_a3m_file_contents_))  # pad with empty entries
        msa_a3m_file_contents.append(msa_a3m_file_contents_)
else:
    # No sample has MSAs: send a single empty entry per sample
    msa_a3m_file_contents = [[''],] * len(msa_a3m_filepaths)

# Convert input for Triton inference
sequences = np.array(sequences)
sequences = np.char.encode(sequences, 'utf-8')
msa_a3m_file_contents = np.array(msa_a3m_file_contents)
msa_a3m_file_contents = np.char.encode(msa_a3m_file_contents, 'utf-8')

# Add a leading batch dimension: one request containing all samples
sequences_batch = sequences[np.newaxis, ...]
msa_a3m_file_contents_batch = msa_a3m_file_contents[np.newaxis, ...]
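To see the array shapes that reach the server, the same padding and batching steps can be run on toy data. The strings below are placeholders rather than real A3M content, and they are plain `str` values instead of the file contents read above. With two samples whose MSA lists have lengths 1 and 3, padding yields a rectangular `(2, 3)` array, and `np.newaxis` adds a leading batch dimension of size 1.

```python
import numpy as np

# Placeholder inputs: two query sequences with 1 and 3 MSA "files" respectively
toy_sequences = ['MKV', 'MKV']
toy_msa_contents = [['>q\nMKV\n'], ['>q\nMKV\n', '>q\nMKV\n', '>q\nMKV\n']]

# Pad each sample's MSA list with empty strings up to the longest list
max_len = max(len(m) for m in toy_msa_contents)
padded = [m + [''] * (max_len - len(m)) for m in toy_msa_contents]

# Encode to UTF-8 bytes, as the notebook does before calling Triton
seq_arr = np.char.encode(np.array(toy_sequences), 'utf-8')
msa_arr = np.char.encode(np.array(padded), 'utf-8')

# Add the leading batch axis
seq_batch = seq_arr[np.newaxis, ...]
msa_batch = msa_arr[np.newaxis, ...]

print(seq_batch.shape, msa_batch.shape)  # (1, 2) (1, 2, 3)
```

Padding is what makes the ragged per-sample MSA lists representable as a single NumPy array; empty entries are simply extra slots with no content.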

### Use ModelClient to run inference on the server

The OpenFold inference configuration is `examples/protein/openfold/conf/infer.yaml`.


In [6]:
# Run inference through the Triton server
with ModelClient("localhost", "bionemo_openfold", inference_timeout_s=180) as client:
    output_pdb_strings = client.infer_batch(
        sequences_batch=sequences_batch,
        msa_a3m_file_contents_batch=msa_a3m_file_contents_batch,
    )
    # Decode the returned bytes into one PDB string per input sequence
    output_pdb_strings = output_pdb_strings['output_pdb_string']
    output_pdb_strings = np.char.decode(output_pdb_strings.astype('bytes'), 'utf-8')
    output_pdb_strings = output_pdb_strings.tolist()[0]
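`infer_batch` returns a dictionary of NumPy arrays keyed by output name, and Triton clients typically return string outputs as `object` arrays of bytes, which is why the cell above calls `astype('bytes')` before decoding. The decoding can be checked offline with a mocked response. This is a sketch: `fake_response` and `decode_pdb_outputs` are hypothetical, and the `(1, num_samples)` shape is an assumption that mirrors the indexing used in this notebook.

```python
import numpy as np

# Mocked server response: one request in the batch holding two placeholder PDB strings
fake_response = {
    'output_pdb_string': np.array(
        [[b'REMARK sample 0\nEND\n', b'REMARK sample 1\nEND\n']], dtype=object
    ),
}


def decode_pdb_outputs(response):
    """Decode the PDB strings of the first (only) request in the batch."""
    raw = response['output_pdb_string']
    decoded = np.char.decode(raw.astype('bytes'), 'utf-8')  # object -> bytes -> str
    return decoded.tolist()[0]  # drop the batch dimension


pdbs = decode_pdb_outputs(fake_response)
print(len(pdbs), pdbs[1].splitlines()[0])  # 2 REMARK sample 1
```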


## Visualize the prediction results

In [7]:
view = py3Dmol.view(width=400, height=300)
# Show the prediction for the second sample (the one with all three MSA files)
view.addModelsAsFrames(output_pdb_strings[1])
view.setStyle({'model': -1}, {"cartoon": {'color': 'purple'}})
view.zoomTo()
view.show()
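To inspect the predictions in an external viewer such as PyMOL or ChimeraX, each PDB string can be written to its own file. This is a minimal sketch; `save_pdb_strings` and the `prediction_<index>.pdb` naming scheme are hypothetical choices, not part of BioNeMo.

```python
import os
from typing import List


def save_pdb_strings(pdb_strings: List[str], out_dir: str, prefix: str = 'prediction') -> List[str]:
    """Write each PDB string to out_dir/<prefix>_<index>.pdb and return the file paths."""
    os.makedirs(out_dir, exist_ok=True)
    paths = []
    for i, pdb in enumerate(pdb_strings):
        path = os.path.join(out_dir, f'{prefix}_{i}.pdb')
        with open(path, 'w') as f:
            f.write(pdb)
        paths.append(path)
    return paths
```

For example, `save_pdb_strings(output_pdb_strings, 'openfold_predictions')` would write one file per input sequence, in the same order as `sequences`.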