In [None]:
import gzip
import heapq
import io
import json
import os
import shutil
import time
from pathlib import Path
os.environ["CUDA_VISIBLE_DEVICES"]="0"

import kmtools.sci_tools
import numpy as np
import pandas as pd
import proteinsolver
import pyarrow as pa
import pyarrow.parquet as pq
import torch
import torch_geometric
#from IPython.display import HTML, display
from kmbio import PDB
from torch_geometric.data import Batch
from tqdm.notebook import tqdm
from sklearn.preprocessing import OneHotEncoder

# One-letter codes of the 20 standard amino acids in alphabetical order; the
# position of a letter here is the column it gets in the one-hot encoding
# built in run_ps.
acids = list("ACDEFGHIKLMNPQRSTVWY")

@torch.no_grad()
def design_sequence(net, data, random_position=False, value_selection_strategy="map", num_categories=None):
    """Assign a category to every node of ``data``, one node per graph per pass.

    All nodes start out set to ``num_categories`` (used as the "not yet
    assigned" token). On every pass ``net(x, data.edge_index, data.edge_attr)``
    is evaluated, its output is soft-maxed over ``dim=1``, and for each graph
    in the batch exactly one still-unassigned node is filled in. The loop ends
    when no unassigned node is left.

    Args:
        net: Callable invoked as ``net(x, edge_index, edge_attr)``; its output
            is indexed as ``[node, category]``.
        data: Object exposing ``x``, ``edge_index`` and ``edge_attr``, and
            optionally ``batch`` (graph id per node) and ``y``.
        random_position: If true, the node to fill is drawn uniformly from the
            unassigned nodes of the graph; otherwise the unassigned node whose
            best category has the highest probability is taken.
        value_selection_strategy: ``"map"`` takes the most probable category,
            ``"multinomial"`` samples from the predicted distribution, and
            ``"ref"`` copies the category from ``data.y`` (or ``data.x`` when
            ``y`` is missing or ``None``).
        num_categories: Value of the "unassigned" token. Defaults to
            ``data.x.max()`` — presumably this relies on ``data.x`` already
            containing that token; confirm against the caller.

    Returns:
        Tuple ``(x, x_proba)`` of CPU tensors: the chosen category per node
        and the probability the network gave to that choice.

    Raises:
        RuntimeError: If a full pass assigns nothing although unassigned nodes
            remain (e.g. nodes whose ``batch`` id is outside ``0..max``).
    """
    assert value_selection_strategy in ("map", "multinomial", "ref")

    if num_categories is None:
        num_categories = data.x.max().item()

    # A single, un-batched graph may either lack ``batch`` entirely or expose
    # it as ``None``; both mean "one graph".
    batch = getattr(data, "batch", None)
    if batch is not None:
        batch_size = batch.max().item() + 1
    else:
        print("Defaulting to batch size of one.")
        batch_size = 1

    if value_selection_strategy == "ref":
        x_ref = data.y if hasattr(data, "y") and data.y is not None else data.x

    x = torch.ones_like(data.x) * num_categories
    x_proba = torch.zeros_like(x).to(torch.float)
    # Keep the index helper on the same device as ``x`` so it can be indexed
    # with masks derived from ``x``.
    index_array_ref = torch.arange(x.size(0), device=x.device)
    mask_ref = x == num_categories
    while mask_ref.any():
        output = net(x, data.edge_index, data.edge_attr)
        output_proba_ref = torch.softmax(output, dim=1)
        output_proba_max_ref, _ = output_proba_ref.max(dim=1)

        assigned_any = False
        for i in range(batch_size):
            mask = mask_ref
            if batch_size > 1:
                mask = mask & (batch == i)

            # Graphs of different sizes finish at different times; a graph
            # with nothing left to assign is simply skipped.
            if not mask.any():
                continue

            index_array = index_array_ref[mask]
            max_probas = output_proba_max_ref[mask]

            if random_position:
                selected_residue_subindex = torch.randint(0, max_probas.size(0), (1,)).item()
            else:
                selected_residue_subindex = max_probas.argmax().item()
            max_proba_index = index_array[selected_residue_subindex]

            assert x[max_proba_index] == num_categories
            assert x_proba[max_proba_index] == 0
            category_probas = output_proba_ref[max_proba_index]

            if value_selection_strategy == "map":
                chosen_category_proba, chosen_category = category_probas.max(dim=0)
            elif value_selection_strategy == "multinomial":
                chosen_category = torch.multinomial(category_probas, 1).item()
                chosen_category_proba = category_probas[chosen_category]
            else:
                assert value_selection_strategy == "ref"
                chosen_category = x_ref[max_proba_index]
                chosen_category_proba = category_probas[chosen_category]

            assert chosen_category != num_categories
            x[max_proba_index] = chosen_category
            x_proba[max_proba_index] = chosen_category_proba
            assigned_any = True

        if not assigned_any:
            # Without this guard, skipping empty graphs could spin forever on
            # nodes that belong to no graph in ``range(batch_size)``.
            raise RuntimeError("Unassigned nodes remain but none belongs to a graph in the batch.")

        mask_ref = x == num_categories
        del output, output_proba_ref, output_proba_max_ref
    return x.cpu(), x_proba.cpu()
    

def run_ps(path_to_assemblies:Path, dataset_list:Path):
    #load the model
    state_file = '/home/s1706179/Proteinsolver/e53-s1952148-d93703104.state'
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    %run /home/s1706179/Proteinsolver/model.py

    batch_size = 512
    num_features = 20
    adj_input_size = 2
    hidden_size = 128
    frac_present = 0.5
    frac_present_valid = frac_present
    info_size= 1024

    net = Net(
        x_input_size=num_features + 1, adj_input_size=adj_input_size, hidden_size=hidden_size, output_size=num_features
    )
    net.load_state_dict(torch.load(state_file, map_location=device))
    net.eval()
    net = net.to(device)
    #run predictions
    with open(dataset_list,'r') as file:
        structures = [x.strip("\n") for x in file.readlines()]
    results={}
    for protein in structures:
        STRUCTURE_FILE = path_to_assemblies/(protein[:4]+'.pdb')
        chain_id=protein[-1]
        try:
            structure_all = PDB.load(STRUCTURE_FILE)
            structure = PDB.Structure(STRUCTURE_FILE.name + chain_id, structure_all[0].extract(chain_id))
            pdata = proteinsolver.utils.extract_seq_and_adj(structure, chain_id)
            data = proteinsolver.datasets.protein.row_to_data(pdata)
            data = proteinsolver.datasets.protein.transform_edge_attr(data)
            residues, residue_probas = design_sequence(
                net, data.to(device), random_position=False, value_selection_strategy="map", num_categories=20
            )
            results[protein] = "".join(proteinsolver.utils.AMINO_ACIDS[i] for i in residues)
        except ValueError:
            continue
            

    enc=OneHotEncoder(categories=[acids],sparse=False)
    predicted_sequences = []
    with open('/home/s1706179/Proteinsolver/proteinsolver_nmr.txt','w') as file:
        file.write(f"ignore_uncommon False\ninclude_pdbs\n##########\n")
        for chain in results:
            predicted_sequences+=list(results[chain])
            file.write(f"{chain} {len(results[chain])}\n")
    arr=enc.fit_transform(np.array(predicted_sequences).reshape(-1, 1))
    pd.DataFrame(arr).to_csv("/home/s1706179/Proteinsolver/proteinsolver_nmr.csv", header=None, index=None)
    
# Run the design pipeline on the NMR chain list against the stripped backbones.
run_ps(
    path_to_assemblies=Path("/home/s1706179/Rosetta/empty_nmr_backbones/"),
    dataset_list=Path("/home/s1706179/Rosetta/data/nmr_set.txt"),
)