## CAUTION
You need to add the libraries from a [separate project](https://github.com/skrhakv/pocket-movement-analysis/blob/master/src/utils/biotite_utils.py) and update the `sys.path.append(...)` path so that it points to them.

In [2]:
from biotite.structure import distance
import numpy as np  
import os
import sys
from biotite.structure import superimpose
import biotite
import biotite.database.rcsb as rcsb
import os
from biotite.structure.io.pdbx import get_structure
import biotite.structure.io.pdbx as pdbx
from biotite.sequence import ProteinSequence

sys.path.append('/home/vit/Projects/flexibility-analysis/src/utils')
import biotite_utils
import dataset_utils


In [36]:
# Directory handed to ``rcsb.fetch`` as the target location for the mmCIF files.
CIF_FILES_PATH = '/home/vit/Projects/deeplife-project/data/cif_files'


##########################################################################################################################
#                                                                                                                        #
# This is mostly copied from https://github.com/skrhakv/pocket-movement-analysis/blob/master/src/utils/biotite_utils.py  #
#                                                                                                                        #
##########################################################################################################################

# Lookup table from three-character residue codes to one-letter codes.
# Keys are spelled with the first character upper-cased and the rest
# lower-cased, which is the form `three_to_one` normalizes its input to.
mapping = {'Aba': 'A', 'Ace': 'X', 'Acr': 'X', 'Ala': 'A', 'Aly': 'K', 'Arg': 'R', 'Asn': 'N', 'Asp': 'D', 'Cas': 'C',
           'Ccs': 'C', 'Cme': 'C', 'Csd': 'C', 'Cso': 'C', 'Csx': 'C', 'Cys': 'C', 'Dal': 'A', 'Dbb': 'T', 'Dbu': 'T',
           'Dha': 'S', 'Gln': 'Q', 'Glu': 'E', 'Gly': 'G', 'Glz': 'G', 'His': 'H', 'Hse': 'S', 'Ile': 'I', 'Leu': 'L',
           'Llp': 'K', 'Lys': 'K', 'Men': 'N', 'Met': 'M', 'Mly': 'K', 'Mse': 'M', 'Nh2': 'X', 'Nle': 'L', 'Ocs': 'C',
           'Pca': 'E', 'Phe': 'F', 'Pro': 'P', 'Ptr': 'Y', 'Sep': 'S', 'Ser': 'S', 'Thr': 'T', 'Tih': 'A', 'Tpo': 'T',
           'Trp': 'W', 'Tyr': 'Y', 'Unk': 'X', 'Val': 'V', 'Ycm': 'C', 'Sec': 'U', 'Pyl': 'O', 'Mhs': 'H', 'Snm': 'S',
           'Mis': 'S', 'Seb': 'S', 'Hic': 'H', 'Fme': 'M', 'Asb': 'D', 'Sah': 'C', 'Smc': 'C', 'Tpq': 'Y', 'Onl': 'X',
           'Tox': 'W', '5x8': 'X', 'Ddz': 'A'}


def three_to_one(three_letter_code):
    """Translate a three-letter residue code into a one-letter code.

    The lookup is case-insensitive: the first character is upper-cased and
    the remaining characters are lower-cased before consulting ``mapping``.

    Args:
        three_letter_code: Residue code such as ``'ALA'`` or ``'mse'``.

    Returns:
        The one-letter code from ``mapping``, or ``'X'`` when the code is not
        listed there (this includes the empty string).
    """
    # Slicing with [:1] instead of indexing with [0] keeps an empty code from
    # raising IndexError; it simply falls through to the 'X' default.
    key = three_letter_code[:1].upper() + three_letter_code[1:].lower()
    return mapping.get(key, 'X')

def get_sequence(protein):
    """Build a ``ProteinSequence`` from the residue names found in ``protein``.

    Every element yielded by iterating ``protein`` contributes one letter,
    obtained by passing its ``res_name`` through ``three_to_one``.

    Args:
        protein: Iterable whose elements expose a ``res_name`` attribute.

    Returns:
        A ``ProteinSequence`` with one symbol per element of ``protein``.
    """
    one_letter_codes = []
    for element in protein:
        one_letter_codes.append(three_to_one(element.res_name))
    joined = ''.join(one_letter_codes)
    return ProteinSequence(joined)



def get_protein_backbone(id, indices=None):
    """Load one chain and return a single backbone atom per residue.

    The first four characters of ``id`` are taken as the PDB ID and the rest
    as the chain ID. The mmCIF file is obtained through ``rcsb.fetch`` (with
    ``CIF_FILES_PATH`` as target directory) and model 1 is read from it.

    Args:
        id: PDB ID immediately followed by the chain ID, e.g. ``'1a4uB'``.
        indices: Optional index (e.g. a list or array of integers) applied to
            the per-residue atoms before returning. ``None`` (the default)
            returns all residues; an empty index returns an empty selection.

    Returns:
        The selected atoms of the chain, one atom per residue.

    Raises:
        IndexError: If the structure has no peptide-backbone atoms for the
            requested chain.
    """
    pdb_id = id[:4]
    chain_id = id[4:]

    cif_file_path = rcsb.fetch(pdb_id, "cif", CIF_FILES_PATH)
    cif_file = pdbx.CIFFile.read(cif_file_path)
    protein = get_structure(cif_file, model=1)

    protein_backbone = protein[(protein.chain_id == chain_id) & (biotite.structure.filter_peptide_backbone(protein))]
    if len(protein_backbone) == 0:
        # Without this guard the indexing below fails with an opaque
        # "index 0 is out of bounds" error; keep the exception type but say why.
        raise IndexError(f'{id}: no peptide backbone atoms found for chain {chain_id!r}')

    # the following code taken from the biotite source code:
    # https://github.com/biotite-dev/biotite/blob/v0.41.0/src/biotite/structure/residues.py#L22
    chain_id_changes = (protein_backbone.chain_id[1:] != protein_backbone.chain_id[:-1])
    res_id_changes   = (protein_backbone.res_id[1:]   != protein_backbone.res_id[:-1]  )
    ins_code_changes = (protein_backbone.ins_code[1:] != protein_backbone.ins_code[:-1])
    res_name_changes = (protein_backbone.res_name[1:] != protein_backbone.res_name[:-1])
    residue_change_mask = (
        chain_id_changes |
        res_id_changes |
        ins_code_changes |
        res_name_changes
    )

    # A residue starts at atom 0 and right after every position where any of
    # the residue-identifying annotations changes.
    residue_starts = np.where(residue_change_mask)[0] + 1
    residue_starts = np.concatenate(([0], residue_starts))

    # take the C_alphas if possible (now we have the N atoms, however, regularly the C_alphas are located right behind the N atoms, so increase the indices by 1)
    # The shift is only applied when the last shifted index is still in bounds.
    if len(protein_backbone) > residue_starts[-1] + 1:
        residue_starts = residue_starts + 1

    protein_backbone = protein_backbone[residue_starts]

    # Compare against None explicitly: a truthiness test would treat an empty
    # index list as "no filter" and would raise for multi-element numpy arrays.
    if indices is not None:
        return protein_backbone[indices]

    return protein_backbone

##########################################################################################################################
#                                                                                                                        #
#                                                                                                                        #
#                                                                                                                        #
##########################################################################################################################

def compute_distances():
    """Compute per-residue apo-to-holo distances and save them as ``.npy`` files.

    For every apo/holo pair returned by
    ``dataset_utils.load_main_apo_holo_pairs`` the function:

    1. loads one atom per residue for both chains (``get_protein_backbone``),
    2. aligns the two sequences and keeps only the aligned residue pairs,
    3. superimposes the holo selection onto the apo selection,
    4. re-aligns the filtered sequences and stores the distance of each
       aligned pair at the position of the apo residue in the unfiltered
       apo chain.

    Apo residues without an aligned holo residue keep the value ``-1``.
    The result is written to ``<output_path>/<apo>.npy`` as ``float16``.
    Pairs whose output file already exists are skipped.

    Raises:
        ValueError: If the alignment of a pair contains no aligned residues.
        AssertionError: If the number of distances differs from the length of
            the sequence stored in ``<sequence_path>/<apo>.txt``.
    """
    dataset_path = '../cryptobench-dataset/dataset.json'
    output_path = '../data/residue-distances'
    sequence_path = '../data/sequences'

    # Apo structures skipped by the original author because of a biotite error.
    # (The pairs 8j1kA:3ouiA and 8pfpA:7gquA used to be skipped as well.)
    skipped_apos = {"5lgrB", "8h6pA"}

    dataset = dataset_utils.load_subset(dataset_path)
    apo_holo_pairs = dataset_utils.load_main_apo_holo_pairs(dataset, multichain=False)

    # np.save does not create missing directories.
    os.makedirs(output_path, exist_ok=True)

    for apo, holo in apo_holo_pairs.items():
        print(f'Processing {apo}:{holo}')

        if apo in skipped_apos:
            continue

        output_file = f'{output_path}/{apo}.npy'
        if os.path.exists(output_file):
            continue

        # Load each structure once and derive the filtered views by indexing,
        # instead of fetching and parsing the same file several times.
        original_apo_backbone = get_protein_backbone(apo)
        original_holo_backbone = get_protein_backbone(holo)

        alignment = biotite_utils.align_sequences(
            get_sequence(original_apo_backbone),
            get_sequence(original_holo_backbone),
        )

        # we need to filter out the non-matching residues, otherwise the superimpose function will fail
        # A biotite alignment trace marks gaps with -1; used as an index that
        # would silently select the last residue, so drop those rows.
        trace = np.asarray(alignment[0].trace).reshape(-1, 2)
        matched = trace[(trace != -1).all(axis=1)]
        if len(matched) == 0:
            raise ValueError(f'{apo}:{holo}: alignment contains no aligned residues')
        apo_indices = matched[:, 0]
        holo_indices = matched[:, 1]

        # get the structures with respect to indices
        apo_backbone = original_apo_backbone[apo_indices]
        holo_backbone = original_holo_backbone[holo_indices]

        holo_backbone, _ = superimpose(apo_backbone, holo_backbone)

        # rerun the alignment - the indices might have shifted due to the filtering for the sake of superimposing
        alignment = biotite_utils.align_sequences(
            get_sequence(apo_backbone),
            get_sequence(holo_backbone),
        )

        # -1 marks apo residues that have no counterpart in the holo structure.
        distances = np.full(len(original_apo_backbone), -1, dtype=np.float16)

        for apo_index, holo_index in alignment[0].trace:
            if apo_index == -1 or holo_index == -1:
                # gap in the re-alignment: leave the default value in place
                continue
            # apo_indices[k] is the position of the k-th filtered apo residue
            # in the unfiltered chain, which is what the output is indexed by.
            distances[apo_indices[apo_index]] = distance(apo_backbone[apo_index], holo_backbone[holo_index])

        with open(f'{sequence_path}/{apo}.txt', 'r') as file:
            sequence = file.read()

        # The distances must line up one-to-one with the stored sequence.
        assert len(sequence) == len(distances), f'{apo}: {len(sequence)} != {len(distances)}'
        np.save(output_file, distances)

compute_distances()

# TODO:
# 1. write a report from this meeting (see below)
# 2. why was it not computed for all of them??? only 933 out of 946
# 3. try the multitask learning (mean squared logarithmic loss? sounds better than MSE, but both need to be tried I guess)


# rigid (do the graph-based models ("grafovky") really perform worse? widen the threshold to 8A)
# multitask
# error analysis (how far off the errors are - whether they are close, whether they are on the surface/inside, whether they are at a static binding site, Yana's metric)
# dataset extension
# 

Processing 1a4uB:3rj9C
Processing 1a8dA:1diwA
Processing 1ad1A:6clvC
Processing 1ak1A:2q2oA
Processing 1arlA:3cpaA
Processing 1aylA:6at3B
Processing 1b0iA:1g9hA
Processing 1bfnA:1bybA
Processing 1bhsA:1i5rA
Processing 1bk2A:5ihkA
Processing 1byiA:1dakA
Processing 1bzjA:5qelA
Processing 1c3kA:1c3nA
Processing 1cuzA:1xzcA
Processing 1cwqA:1xjiA
Processing 1dc6A:6utmB
Processing 1dklA:7z2sA
Processing 1dpjA:2jxrA
Processing 1dq2A:1qdcA
Processing 1dqzA:5vnsA
Processing 1dteA:4n8sA
Processing 1e3gA:4k7aA
Processing 1e5lB:1e5qB
Processing 1e6kA:3rvrA
Processing 1eccB:1ecfA
Processing 1efhB:3f3yA
Processing 1eswA:5jiwA
Processing 1evyA:1n1eB
Processing 1ezlC:3iboC
Processing 1f47B:1s1sB
Processing 1f8aB:4tyoB
Processing 1fd9A:8bjdA
Processing 1fdpA:1dfpA
Processing 1ffhA:2cnwA
Processing 1fl1B:4p3hB
Processing 1fvrA:6mweA
Processing 1fwkC:1h72C
Processing 1g24C:2a9kB
Processing 1g59A:1j09A
Processing 1gqnA:1qfeA
Processing 1gqzA:2gkeA
Processing 1h13A:2b4fA
Processing 1h3gB:3edfA
Processing 

## Some single-chain APOs don't have a single-chain main holo structure
For those, we fill the whole chain with the default value (above, `-1` was used whenever a residue wasn't present in the holo structure).

In [22]:
# Write a placeholder distance file (all values -1, one per character of the
# stored sequence) for every sequence file that has no distance file yet.
sequences_dir = '/home/vit/Projects/cryptic-nn/data/sequences'
distances_dir = '/home/vit/Projects/cryptic-nn/data/residue-distances'

for file in os.listdir(sequences_dir):
    distances_file = f'{distances_dir}/{file.replace(".txt", ".npy")}'
    if os.path.exists(distances_file):
        # already computed - nothing to fill in
        continue
    with open(f'{sequences_dir}/{file}', 'r') as f:
        sequence = f.read()
    placeholder = np.full(len(sequence), -1, dtype=np.float16)
    np.save(distances_file, placeholder)
