## Summary

**Parameters**

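- `SEQUENCE_GENERATION_METHOD`: search strategy, read from the environment; one of `astar`, `expectimax`, `randexpectimax`, `root2expectimax`, `root10expectimax` (default: `expectimax`).
- `STRUCTURE_ID`: identifier of the input structure, read from the environment (default: `4z8jA00`).
- `SLURM_ARRAY_TASK_ID`: set by SLURM for array jobs; the first output file index is this value times 1000 (default: 0).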


**Notes:**

- The `astar` method needs at least 64 GB of memory to generate 200k sequences.
- The `astar` method cannot be run in parallel.

**SLURM scripts**

```bash
export STRUCTURE_ID="4beuA02"

SEQUENCE_GENERATION_METHOD="astar" sbatch --mem 64G --time 72:00:00 ./scripts/run_notebook_gpu.sh $(realpath notebooks/10_generate_protein_sequences.ipynb)

SEQUENCE_GENERATION_METHOD="expectimax" sbatch --mem 32G --time 24:00:00 --array=1-3 ./scripts/run_notebook_gpu.sh $(realpath notebooks/10_generate_protein_sequences.ipynb)

SEQUENCE_GENERATION_METHOD="randexpectimax" sbatch --mem 32G --time 24:00:00 --array=1-3 ./scripts/run_notebook_gpu.sh $(realpath notebooks/10_generate_protein_sequences.ipynb)
```

----

## Imports

In [None]:
import gzip
import heapq
import io
import json
import os
import shutil
import time
from pathlib import Path

import kmtools.sci_tools
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import proteinsolver
import pyarrow as pa
import pyarrow.parquet as pq
import torch
import torch_geometric
from IPython.display import HTML, display
from kmbio import PDB
from torch_geometric.data import Batch
from tqdm.notebook import tqdm

## Properties

In [None]:
# Name of the directory (see NOTEBOOK_PATH) in which this notebook stores its outputs.
NOTEBOOK_NAME = "generate_protein_sequences"

In [None]:
# Output directory, resolved relative to the current working directory and
# created on first use (its parent directory must already exist).
NOTEBOOK_PATH = Path(NOTEBOOK_NAME).resolve()
NOTEBOOK_PATH.mkdir(exist_ok=True)
NOTEBOOK_PATH

In [None]:
# Identifier of the trained model: selects the entry in BEST_STATE_FILES and
# the model script under protein_train/<UNIQUE_ID>/.
UNIQUE_ID = "191f05de"

In [None]:
# Maps a model identifier (see UNIQUE_ID) to the path of the saved state file
# that is loaded into the network.
BEST_STATE_FILES = {
    #
    "191f05de": "protein_train/191f05de/e53-s1952148-d93703104.state"
}

In [None]:
# Structure to design sequences for. The STRUCTURE_ID environment variable
# takes precedence; the commented-out lines are alternative defaults.
# STRUCTURE_ID = os.getenv("STRUCTURE_ID", "5vli02")
# STRUCTURE_ID = os.getenv("STRUCTURE_ID", "1n5uA03")
STRUCTURE_ID = os.getenv("STRUCTURE_ID", "4z8jA00")
# STRUCTURE_ID = os.getenv("STRUCTURE_ID", "4unuA00")
# STRUCTURE_ID = os.getenv("STRUCTURE_ID", "4beuA02")

STRUCTURE_ID

In [None]:
# Path of the input PDB file. The STRUCTURE_FILE environment variable overrides
# the default, which is "<two levels above NOTEBOOK_PATH>/proteinsolver/data/inputs/<STRUCTURE_ID>.pdb".
STRUCTURE_FILE = Path(
    os.getenv(
        "STRUCTURE_FILE",
        NOTEBOOK_PATH.parent.parent / "proteinsolver" / "data" / "inputs" / f"{STRUCTURE_ID}.pdb",
    )
).resolve()

STRUCTURE_FILE

In [None]:
# Per-structure presets for MIN_EXPECTED_PROBA; structures without a preset
# fall back to 0.15.
# NOTE(review): presumably the per-residue probability cutoff for the search
# methods -- the consumer is not visible in this part of the notebook; confirm.
min_expected_proba_preset = {
    #
    "1n5uA03": 0.20,
    "4z8jA00": 0.29,
    "4unuA00": 0.25,
    "4beuA02": 0.25,
}
MIN_EXPECTED_PROBA = min_expected_proba_preset.get(STRUCTURE_ID, 0.15)

MIN_EXPECTED_PROBA

In [None]:
# Sequence generation strategy, read from the environment (default: "expectimax").
SUPPORTED_SEQUENCE_GENERATION_METHODS = (
    "astar",
    "expectimax",
    "randexpectimax",
    "root2expectimax",
    "root10expectimax",
)
SEQUENCE_GENERATION_METHOD = os.getenv("SEQUENCE_GENERATION_METHOD", "expectimax")

# Validate explicitly rather than with `assert`: the value comes from the
# environment, and assertions are skipped when Python runs with -O.
if SEQUENCE_GENERATION_METHOD not in SUPPORTED_SEQUENCE_GENERATION_METHODS:
    raise ValueError(
        f"Unsupported SEQUENCE_GENERATION_METHOD '{SEQUENCE_GENERATION_METHOD}'; "
        f"expected one of {SUPPORTED_SEQUENCE_GENERATION_METHODS}."
    )
SEQUENCE_GENERATION_METHOD

In [None]:
# First output file index for this job: each SLURM array task gets its own
# range of indices starting at task_id * 1000 (0 when not run as an array job).
START_FILE_INDEX = int(os.getenv("SLURM_ARRAY_TASK_ID", 0)) * 1000

START_FILE_INDEX

In [None]:
# Show which GPUs are visible to this process. Use .get() so the notebook also
# runs where the variable is unset (e.g. on a CPU-only machine).
os.environ.get("CUDA_VISIBLE_DEVICES")

In [None]:
# Run on the first visible GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

## Helper functions

In [None]:
@torch.no_grad()
def design_sequence(net, data, random_position=False, value_selection_strategy="map", num_categories=None):
    """Assign a category to every position of ``data.x``, one position per graph per network pass.

    All positions start out as the placeholder value ``num_categories``. On each
    pass the network is evaluated on the partially filled ``x`` and, for every
    graph in the batch, one still-unassigned position is filled in.

    Args:
        net: Callable invoked as ``net(x, edge_index, edge_attr)``; its output is
            soft-maxed over ``dim=1`` to obtain per-position category probabilities.
        data: Object providing ``x``, ``edge_index`` and ``edge_attr``. An optional
            ``batch`` tensor assigns each position to a graph; an optional ``y``
            provides the reference categories for the ``"ref"`` strategy.
        random_position: If true, the position to fill is drawn uniformly from the
            unassigned positions; otherwise the position whose most likely
            category has the highest probability is chosen.
        value_selection_strategy: ``"map"`` picks the most probable category,
            ``"multinomial"`` samples from the predicted distribution and ``"ref"``
            copies the category from ``data.y`` (or ``data.x`` if ``y`` is missing).
        num_categories: Placeholder value marking unassigned positions. Defaults
            to ``data.x.max()``.

    Returns:
        A tuple ``(x, x_proba)`` of CPU tensors: the assigned categories and the
        probability the network gave to each assigned category.
    """
    assert value_selection_strategy in ("map", "multinomial", "ref")

    if num_categories is None:
        num_categories = data.x.max().item()

    # `batch` may be missing altogether or present but set to None.
    batch = getattr(data, "batch", None)
    if batch is not None:
        batch_size = batch.max().item() + 1
    else:
        print("Defaulting to batch size of one.")
        batch_size = 1

    if value_selection_strategy == "ref":
        x_ref = data.y if hasattr(data, "y") and data.y is not None else data.x

    x = torch.ones_like(data.x) * num_categories
    x_proba = torch.zeros_like(x).to(torch.float)
    # Keep the index array on the same device as the masks that index into it.
    index_array_ref = torch.arange(x.size(0), device=x.device)
    mask_ref = x == num_categories
    while mask_ref.any():
        output = net(x, data.edge_index, data.edge_attr)
        output_proba_ref = torch.softmax(output, dim=1)
        output_proba_max_ref, _ = output_proba_ref.max(dim=1)

        for i in range(batch_size):
            mask = mask_ref
            if batch_size > 1:
                mask = mask & (batch == i)
            if not mask.any():
                # This graph is already complete; larger graphs in the batch are not.
                continue

            index_array = index_array_ref[mask]
            max_probas = output_proba_max_ref[mask]

            if random_position:
                selected_residue_subindex = torch.randint(0, max_probas.size(0), (1,)).item()
            else:
                selected_residue_subindex = max_probas.argmax().item()
            max_proba_index = index_array[selected_residue_subindex]

            assert x[max_proba_index] == num_categories
            assert x_proba[max_proba_index] == 0
            category_probas = output_proba_ref[max_proba_index]

            if value_selection_strategy == "map":
                chosen_category_proba, chosen_category = category_probas.max(dim=0)
            elif value_selection_strategy == "multinomial":
                chosen_category = torch.multinomial(category_probas, 1).item()
                chosen_category_proba = category_probas[chosen_category]
            else:
                assert value_selection_strategy == "ref"
                chosen_category = x_ref[max_proba_index]
                chosen_category_proba = category_probas[chosen_category]

            assert chosen_category != num_categories
            x[max_proba_index] = chosen_category
            x_proba[max_proba_index] = chosen_category_proba
        mask_ref = x == num_categories
        # Release the per-pass tensors before the next forward pass.
        del output, output_proba_ref, output_proba_max_ref
    return x.cpu(), x_proba.cpu()

In [None]:
from dataclasses import dataclass
from dataclasses import field
from typing import Any


def _read_heap_file(path):
    """Read one parquet heap dump into a list of ``PrioritizedItem`` objects."""
    pfile = pq.ParquetFile(path)
    heap = []
    # Read one row group at a time so only one chunk is held as a DataFrame.
    for row_group in range(pfile.num_row_groups):
        df = pfile.read_row_group(row_group).to_pandas()
        heap.extend(
            proteinsolver.utils.PrioritizedItem(
                tup.p, torch.tensor(tup.x, dtype=torch.int8), tup.total_proba, tup.total_logproba
            )
            for tup in df.itertuples()
        )
    return heap


def load_heap_dump(heap_file):
    """Load a heap dump written by ``update_heap_dump``.

    The primary file is tried first. If it cannot be read, the backup copy with
    the ``.parquet.bak`` suffix is tried. Failures are reported on stdout rather
    than raised, so a missing or corrupt dump does not abort the notebook.

    Args:
        heap_file: ``pathlib.Path`` of the primary parquet heap dump.

    Returns:
        The list of ``PrioritizedItem`` objects in the order they were stored,
        or ``None`` if neither the primary file nor the backup could be loaded.
    """
    for candidate in (heap_file, heap_file.with_suffix(".parquet.bak")):
        try:
            return _read_heap_file(candidate)
        except Exception as e:  # Best effort: fall through to the next candidate.
            print(f"Encountered error loading heap file '{candidate}': '{e}'.")
    return None


def update_heap_dump(heap_file, heap):
    """Write ``heap`` to ``heap_file`` as parquet, keeping the previous dump as a backup.

    The existing file (if any) is first copied to the same path with the
    ``.parquet.bak`` suffix. The heap is then written in chunks of 100,000 rows,
    one parquet row group per chunk.

    Args:
        heap_file: ``pathlib.Path`` of the parquet file to (over)write.
        heap: Iterable of items exposing ``p``, ``x``, ``total_proba`` and
            ``total_logproba``; ``x`` is stored as a list.
    """
    try:
        shutil.copy2(heap_file, heap_file.with_suffix(".parquet.bak"))
    except FileNotFoundError:
        # First dump: there is nothing to back up yet.
        pass

    df = pd.DataFrame(
        [
            {"p": pi.p, "x": pi.x.data.tolist(), "total_proba": pi.total_proba, "total_logproba": pi.total_logproba}
            for pi in heap
        ]
    )

    if df.empty:
        # No rows means no schema can be inferred; write an explicit empty table
        # so the dump on disk reflects the (empty) heap instead of going stale.
        empty_schema = pa.schema(
            [
                ("p", pa.float64()),
                ("x", pa.list_(pa.int64())),
                ("total_proba", pa.float64()),
                ("total_logproba", pa.float64()),
            ]
        )
        pq.write_table(empty_schema.empty_table(), heap_file)
        return

    chunk_size = 100_000
    writer = None
    try:
        for start in range(0, len(df), chunk_size):
            df_chunk = df.iloc[start : start + chunk_size]
            table = pa.Table.from_pandas(df_chunk, preserve_index=False)
            if writer is None:
                # The schema of the first chunk is used for the whole file.
                writer = pq.ParquetWriter(heap_file, table.schema)
            writer.write_table(table)
    finally:
        # Close the writer even if a chunk fails, so the file handle is released.
        if writer is not None:
            writer.close()


def get_descendents(net, x, total_proba, total_logproba, edge_index, edge_attr, cutoff, num_categories=20):
    """Expand a partially assigned sequence into one child per category.

    The network is evaluated on ``x``. Among the unassigned positions (those equal
    to ``num_categories``), the one whose most likely category has the highest
    probability is selected. One child is produced for every category of that
    position.

    Args:
        net: Callable invoked as ``net(x, edge_index, edge_attr)``; its output is
            soft-maxed over ``dim=1``.
        x: 1-D tensor of categories containing at least one unassigned position.
        total_proba: Score of ``x`` in which every unassigned position counts as ``cutoff``.
        total_logproba: Log-score of ``x`` in which every unassigned position
            counts as ``log(cutoff)``; must be ``<= 0``.
        edge_index: Passed through to ``net``.
        edge_attr: Passed through to ``net``.
        cutoff: Value credited to each unassigned position in the running totals.
        num_categories: Placeholder value marking unassigned positions.

    Returns:
        A list of ``(x_child, total_proba_child, total_logproba_child)`` tuples,
        one per category, where the selected position's ``cutoff`` contribution
        is replaced by the probability of the assigned category.
    """
    assert total_logproba <= 0, total_logproba

    with torch.no_grad():
        output = net(x, edge_index, edge_attr).cpu()
        output = torch.softmax(output, dim=1)

    # `output` and the index array live on the CPU, so the mask must as well.
    mask = (x == num_categories).cpu()
    output = output[mask]
    index_array = torch.arange(x.size(0))[mask]

    _, max_index = output.max(dim=1)[0].max(dim=0)
    row_with_max_proba = output[max_index]
    position = index_array[max_index].item()

    children = []
    for i, p in enumerate(row_with_max_proba):
        proba = p.item()
        x_clone = x.clone()
        assert x_clone[position] == num_categories
        x_clone[position] = i
        total_proba_clone = total_proba - cutoff + proba
        total_logproba_clone = total_logproba - np.log(cutoff) + np.log(proba)
        children.append((x_clone, total_proba_clone, total_logproba_clone))
    return children


def design_sequence_astar(
    net, data, cutoff, num_categories=20, max_results=5_000, max_heap_size=100_000_000, heap=None
):
    """Enumerate complete sequences in best-first order using a priority queue.

    Items are ordered by ``-total_logproba``, so the item with the highest
    log-score is expanded first. Fully assigned items are collected as results;
    partially assigned items are expanded with ``get_descendents``.

    Args:
        net: Network passed on to ``get_descendents``.
        data: Object providing ``x``, ``edge_index`` and ``edge_attr``; ``x`` seeds
            the search when no ``heap`` is given.
        cutoff: Score credited to every position that is not assigned yet.
        num_categories: Placeholder value marking unassigned positions; must be
            below 128 so sequences fit into ``int8``.
        max_results: Stop once this many complete sequences were found.
        max_heap_size: When the heap grows beyond this size it is cut in half.
        heap: Optional heap from a previous run to resume from. It is modified
            in place.

    Returns:
        A tuple ``(results, heap)``: ``results`` is a list of
        ``(x, total_proba, total_logproba)`` tuples and ``heap`` holds the items
        that were not expanded.
    """
    # TODO: keep only total probabilities and log-probabilities instead of the entire array

    assert num_categories < 128  # So that we can store x as int8
    if heap is None:
        num_positions = data.x.size(0)
        initial_proba = cutoff * num_positions
        initial_logproba = np.log(cutoff) * num_positions
        heap = [
            proteinsolver.utils.PrioritizedItem(
                -initial_logproba, data.x.cpu().to(torch.int8), initial_proba, initial_logproba
            )
        ]
    results = []

    # The network receives x together with data.edge_index / data.edge_attr, so
    # x has to be moved to the device those tensors are on.
    target_device = data.edge_index.device

    # The context manager closes the progress bar when the search ends or fails.
    with tqdm(total=max_results) as pbar:
        while heap and len(results) < max_results:
            item = heapq.heappop(heap)
            if not (item.x == num_categories).any():
                assert item.x.dtype == torch.int8
                results.append((item.x.data, item.total_proba, item.total_logproba))
                pbar.update(1)
            else:
                # NOTE(review): `num_categories` is not forwarded to `get_descendents`;
                # values other than the default may be inconsistent -- confirm.
                children = get_descendents(
                    net,
                    item.x.to(torch.long).to(target_device),
                    item.total_proba,
                    item.total_logproba,
                    data.edge_index,
                    data.edge_attr,
                    cutoff,
                )
                for child_x, child_proba, child_logproba in children:
                    heapq.heappush(
                        heap,
                        proteinsolver.utils.PrioritizedItem(
                            -child_logproba, child_x.cpu().to(torch.int8), child_proba, child_logproba
                        ),
                    )
            if len(heap) > max_heap_size:
                # NOTE(review): this keeps the first half of the heap's backing list,
                # which is not necessarily the best-scoring half -- confirm intent.
                heap = heap[: len(heap) // 2]
                heapq.heapify(heap)
    return results, heap

## Load structure

In [None]:
# Structure 5vli02 is designed on chain C; every other structure uses chain A.
chain_id = "C" if STRUCTURE_ID in ["5vli02"] else "A"

# Load the whole file, then keep only the selected chain of the first model.
structure_all = PDB.load(STRUCTURE_FILE)
selected_chain = structure_all[0].extract(chain_id)
structure = PDB.Structure(STRUCTURE_FILE.name + chain_id, selected_chain)
assert len(list(structure.chains)) == 1

In [None]:
# Build a viewer for the extracted chain and display it as the cell output.
view = PDB.view_structure(structure)

view

## Load model

In [None]:
# Run the model script of the selected model in this notebook's namespace
# (presumably this is what defines `Net`, used below -- TODO confirm).
%run protein_train/{UNIQUE_ID}/model.py

In [None]:
# Network dimensions used to construct the model.
num_features = 20
adj_input_size = 2
hidden_size = 128

# Remaining settings.
# NOTE(review): presumably carried over from the training setup -- their use is
# not visible in this part of the notebook; confirm.
batch_size = 1
frac_present = 0.5
frac_present_valid = frac_present
info_size = 1024

In [None]:
# Saved state file of the selected model.
state_file = BEST_STATE_FILES[UNIQUE_ID]
state_file

In [None]:
# Build the network (one extra input category for the "unassigned" placeholder),
# restore the saved state, then switch to evaluation mode on the target device.
net = Net(
    x_input_size=num_features + 1,
    adj_input_size=adj_input_size,
    hidden_size=hidden_size,
    output_size=num_features,
)
net.load_state_dict(torch.load(state_file, map_location=device))
net = net.eval().to(device)

## Design pipeline

### Load protein sequence and geometry

In [None]:
# Extract the sequence and adjacency information of the selected chain.
pdata = proteinsolver.utils.extract_seq_and_adj(structure, chain_id)
# print(pdata)

In [None]:
# Reference sequence of the structure; used below to compute sequence identity.
sequence_ref = pdata.sequence
print(len(sequence_ref), sequence_ref)

### Convert data to suitable format

In [None]:
# Convert the extracted row into the data object consumed by the network, then
# apply the edge-attribute transform.
data = proteinsolver.datasets.protein.row_to_data(pdata)
data = proteinsolver.datasets.protein.transform_edge_attr(data)

### Basic model statistics

In [None]:
# Summary statistics of the model on this structure; filled in by the cells below.
model_stats = {}

In [None]:
# Greedy design: always assign the most probable category at the most confident position.
residues, residue_probas = design_sequence(
    net,
    data.to(device),
    random_position=False,
    value_selection_strategy="map",
    num_categories=20,
)

# Number of positions at which the designed residue equals the reference residue.
num_identical = sum(proteinsolver.utils.AMINO_ACIDS[r] == sequence_ref[i] for (i, r) in enumerate(residues))

model_stats["map_sequence_identity"] = num_identical / len(sequence_ref)
model_stats["map_proba"] = residue_probas.mean().item()
model_stats["map_logproba"] = residue_probas.log().mean().item()

In [None]:
# Score the reference: assign the reference category at each selected position
# and record the probability the network gave to it.
residues, residue_probas = design_sequence(
    net,
    data.to(device),
    random_position=False,
    value_selection_strategy="ref",
    num_categories=20,
)

# Number of positions at which the assigned residue equals the reference residue.
num_identical = sum(proteinsolver.utils.AMINO_ACIDS[r] == sequence_ref[i] for (i, r) in enumerate(residues))

model_stats["ref_sequence_identity"] = num_identical / len(sequence_ref)
model_stats["ref_proba"] = residue_probas.mean().item()
model_stats["ref_logproba"] = residue_probas.log().mean().item()

In [None]:
# Display the collected statistics.
model_stats

In [None]:
# Persist the statistics as JSON next to the other outputs of this notebook.
model_stats_file = NOTEBOOK_PATH / f"stats-{UNIQUE_ID}-{STRUCTURE_FILE.stem}.json"
model_stats_file.write_text(json.dumps(model_stats))

model_stats_file

### Run protein design using expectimax search

In [None]:
# Lookup from category index to amino-acid letter, used to decode designed sequences.
amino_acids = proteinsolver.utils.AMINO_ACIDS

In [None]:
def get_output_file(file_index):
    """Return the path inside NOTEBOOK_PATH of the designs parquet file with the given index."""
    file_name = f"designs-{UNIQUE_ID}-{SEQUENCE_GENERATION_METHOD}-{STRUCTURE_FILE.stem}-{file_index}.parquet"
    return NOTEBOOK_PATH / file_name

In [None]:
# Find the first index, starting at START_FILE_INDEX, for which no output file exists yet.
file_index = START_FILE_INDEX
while True:
    if not get_output_file(file_index).is_file():
        break
    file_index += 1

file_index

In [None]:
# Time the generation of designs with the configured method.
start_time = time.perf_counter()
random_position = SEQUENCE_GENERATION_METHOD.startswith("rand")
print(f"random_position: {random_position}")

# Batch-size estimate scaled by graph size and total GPU memory. It is only
# printed for information and can only be computed when running on a GPU.
# NOTE(review): 512, 3586 and 12_650_217_472 presumably come from a reference
# run (batch size, graph size, GPU memory in bytes) -- confirm.
if device.type == "cuda":
    batch_size = int(
        512
        * (3586 / (data.x.size(0) + data.edge_attr.size(0)))
        * (torch.cuda.get_device_properties(device).total_memory / 12_650_217_472)
    )
    print(f"batch_size: {batch_size}")

# The measurement itself uses a batch of one.
batch_size = 1

data_batch = Batch.from_data_list([data.clone() for _ in range(batch_size)]).to(device)
# Start from a fully unassigned sequence (20 is the placeholder category).
data_batch.x = torch.ones_like(data_batch.x) * 20

batch_values, batch_probas = design_sequence(
    net, data_batch, random_position=random_position, value_selection_strategy="multinomial"
)
# design_sequence returns CPU tensors, so the per-graph mask must be on the CPU too.
batch_index = data_batch.batch.cpu()
for i in range(batch_size):
    in_graph = batch_index == i
    values = batch_values[in_graph]
    probas = batch_probas[in_graph]
    sequence = "".join(amino_acids[aa] for aa in values)
    probas_sum = probas.sum().item()
    probas_log_sum = probas.log().sum().item()

print(f"Elapsed time: {time.perf_counter() - start_time}.")

In [None]:
# Five recorded timings per structure.
# NOTE(review): presumably elapsed seconds printed by the timing cell above -- confirm.
timing_samples = {
    "5vli02": [
        0.13965036603622139,
        0.1378761320374906,
        0.14166183699853718,
        0.1354912610258907,
        0.1416573489550501,
    ],
    "1n5uA03": [
        0.2606568201445043,
        0.273853771854192,
        0.274543966865167,
        0.25143068190664053,
        0.2526147000025958,
    ],
    "4z8jA00": [
        0.2858221740461886,
        0.2860491331666708,
        0.28511124989017844,
        0.2742918790318072,
        0.2711744031403214,
    ],
    "4unuA00": [
        0.36467301612719893,
        0.34314795304089785,
        0.33564279321581125,
        0.37832364114001393,
        0.3413601068314165,
    ],
    "4beuA02": [
        1.2793075828813016,
        1.2771926030982286,
        1.2803262548986822,
        1.276441911002621,
        1.2801529061980546,
    ],
}

# Mean timing per structure, followed by the mean over all structures.
timing_stats = {structure_id: np.mean(samples) for structure_id, samples in timing_samples.items()}
print(timing_stats)
print(np.mean(list(timing_stats.values())))