In [1]:
import os
# Expose only physical GPU 1 to this process. This must happen before torch is
# imported in the next cell, so the CUDA runtime sees the restricted device list.
os.environ.update({"CUDA_VISIBLE_DEVICES": '1'})

In [2]:
import esm
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm 
import numpy as np 
import h5py
import json
import re
import shutil
import random
import time

In [3]:
# Load the ESM-2 t36 3B (UR50D) checkpoint together with its token alphabet.
esm_transformer, esm2_alphabet = esm.pretrained.esm2_t36_3B_UR50D()
# The model is only used for inference in this notebook. A freshly loaded
# nn.Module is in training mode, so switch to eval mode explicitly; fair-esm's
# README does this to disable dropout for deterministic results.
esm_transformer.eval()
# Converts [(label, sequence), ...] into (labels, strings, token tensor).
batch_converter = esm2_alphabet.get_batch_converter()

In [4]:
# Prefer the first *visible* GPU (physical GPU 1, given CUDA_VISIBLE_DEVICES
# set in the first cell); fall back to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
esm_transformer = esm_transformer.to(device)

In [5]:
def _read_json(path):
    """Return the parsed contents of the JSON file at ``path``."""
    with open(path, "r") as handle:
        return json.load(handle)


# NOTE(review): these paths are relative to 'data/', while other cells read and
# write '../data' and '../results'. Confirm the intended working directory.
# seq_dict presumably maps protein id -> amino-acid sequence string; verify.
seq_dict = _read_json('data/full_seq_dict.json')
# NOTE(review): selected_protein is not used in the cells shown here.
selected_protein = _read_json('data/selected_protein.json')

In [6]:
# HDF5 output of the flank-scan cell; it is opened in 'w' mode, so each run overwrites it.
save_file_path = "../results/single_seg_w_bos_eos.hdf5"

In [7]:
def get_contact(seq):
    """Predict an ESM-2 contact map for a single sequence.

    Args:
        seq: sequence string. It may contain literal '<mask>' substrings (the
            flank-scan cell builds such strings), which the ESM batch converter
            is expected to tokenize as mask tokens. Confirm against the alphabet.

    Returns:
        numpy array holding the first (and only) entry of
        ``esm_transformer.predict_contacts`` for this sequence. It is
        presumably an (L, L) residue-level map with BOS/EOS removed by ESM.
        It is indexed by residue positions downstream; confirm.

    Note:
        A disabled variant of this function overwrote the first and last
        tokens (presumably BOS/EOS) with token id 32 before prediction.
    """
    _labels, _strs, tokens = batch_converter([(1, seq)])
    tokens = tokens.to(device)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        contacts = esm_transformer.predict_contacts(tokens)[0].cpu()
    return contacts.numpy()

In [8]:
def norm_sum_mult(ori_contact_seg, seg_cross_contact):
    """Overlap of a new contact block with a reference, normalised by the reference.

    Computes sum(ori * new) / sum(ori * ori) element-wise over the arrays. The
    result is 1.0 when both blocks are identical, and scales linearly with
    ``seg_cross_contact``.

    Args:
        ori_contact_seg: reference contact block (array-like).
        seg_cross_contact: contact block to compare, same shape as the reference.

    Returns:
        The normalised overlap as a numpy scalar. A zero-sum reference yields
        nan/inf per numpy division rules.
    """
    cross_term = np.multiply(ori_contact_seg, seg_cross_contact)
    self_term = np.square(ori_contact_seg)
    return np.sum(cross_term) / np.sum(self_term)

In [13]:
# Recover, from an earlier run's HDF5 file, the patch start positions stored as
# group names under each protein: {protein: [patch_start, ...]}.
# Order follows the file's key iteration order.
previous_data = {}

with h5py.File('../results/w_bos_eos.hdf5', 'r') as f:
    for protein, protein_group in f.items():
        previous_data[protein] = [int(name) for name in protein_group.keys()]

In [16]:
# Save the recovered {protein: [patch_start, ...]} mapping as JSON.
recovery_json_path = '../data/reproduce_single_recovery_w_bos_eos.json'
with open(recovery_json_path, 'w') as out_file:
    json.dump(previous_data, out_file)

In [20]:
# for a quick test
keys_for_test = ['1A8LA', '1BYIA']
test_dict = {key: previous_data[key] for key in keys_for_test if key in previous_data}
print(test_dict)

{'1A8LA': [31, 48, 92], '1BYIA': [136, 177, 80]}


In [25]:
# Flank-scan experiment. For every stored patch of PATCH_LEN residues, keep the
# window [patch_start - flank_len, patch_end + flank_len) visible and replace
# every residue outside it with '<mask>'. Then re-predict contacts and compare
# the patch's contact block against the one predicted from the full sequence.
PATCH_LEN = 22  # patch length in residues
USE_FULL_SET = False  # True: every protein in previous_data; False: quick-test subset (test_dict)

with h5py.File(save_file_path, 'w') as f:
    proteins = previous_data.keys() if USE_FULL_SET else test_dict.keys()
    for protein in tqdm(proteins, desc='proteins'):
        seq = seq_dict[protein]
        ori_contact = get_contact(seq)

        for patch_start in tqdm(previous_data[protein], desc=protein, leave=False):
            patch_end = patch_start + PATCH_LEN
            # Reference block from the unmasked prediction; independent of flank_len.
            ori_contact_seg = ori_contact[patch_start:patch_end, patch_start:patch_end]

            # Largest flank that keeps the window inside the sequence on BOTH
            # sides, inclusive. The previous bound,
            # min(patch_start, len(seq) - patch_end + 1), had two problems:
            # it never reached seg_start == 0, and it skipped patches starting
            # at position 0 entirely. A patch running past the sequence end
            # gives a negative bound, so no iterations.
            max_flank = min(patch_start, len(seq) - patch_end)

            for flank_len in range(max_flank + 1):
                seg_start = patch_start - flank_len
                seg_end = patch_end + flank_len
                # Visible window in the middle; one '<mask>' per hidden residue on each side.
                seq_mask = seg_start * '<mask>' + seq[seg_start:seg_end] + len(seq[seg_end:]) * '<mask>'

                mask_contact_full = get_contact(seq_mask)
                mask_contact_seg = mask_contact_full[patch_start:patch_end, patch_start:patch_end]
                norm_sum_mult_value = norm_sum_mult(ori_contact_seg, mask_contact_seg)

                # Layout: <protein>/<patch_start>/<flank_len>/<name>. ori_contact_full is
                # identical for every patch/flank of a protein; it is stored repeatedly
                # only to keep the existing file layout for downstream readers.
                prefix = f'{protein}/{patch_start}/{flank_len}'
                f.create_dataset(f'{prefix}/ori_contact_full', data=ori_contact)
                f.create_dataset(f'{prefix}/mask_contact_full', data=mask_contact_full)
                f.create_dataset(f'{prefix}/ori_contact_seg', data=ori_contact_seg)
                f.create_dataset(f'{prefix}/mask_contact_seg', data=mask_contact_seg)
                f.create_dataset(f'{prefix}/norm_sum_mult_value', data=norm_sum_mult_value)

  0%|          | 0/2 [00:00<?, ?it/s]
  0%|          | 0/3 [00:00<?, ?it/s][A
 33%|███▎      | 1/3 [00:03<00:06,  3.45s/it][A
 67%|██████▋   | 2/3 [00:08<00:04,  4.56s/it][A
100%|██████████| 3/3 [00:19<00:00,  6.34s/it][A
 50%|█████     | 1/2 [00:19<00:19, 19.15s/it]
  0%|          | 0/3 [00:00<?, ?it/s][A
 33%|███▎      | 1/3 [00:07<00:14,  7.44s/it][A
 67%|██████▋   | 2/3 [00:10<00:04,  4.76s/it][A
100%|██████████| 3/3 [00:19<00:00,  6.40s/it][A
100%|██████████| 2/2 [00:38<00:00, 19.24s/it]
