In [1]:
import os

# Restrict this notebook to a single GPU unless the caller already chose one
# (e.g. `CUDA_VISIBLE_DEVICES=0 jupyter ...`). This must run before torch
# initialises CUDA, i.e. before `import torch` in the next cell; afterwards
# 'cuda:0' refers to whichever physical GPU is listed here.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", '3')

In [2]:
import esm
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm 
import numpy as np 
import h5py
import json
import re
import shutil
import random
import time

In [3]:
# Load the 36-layer, 3B-parameter ESM-2 checkpoint and its token alphabet.
esm_transformer, esm2_alphabet = esm.pretrained.esm2_t36_3B_UR50D()
# Inference only: eval mode disables dropout so contact predictions are
# deterministic (this is what the ESM README does after loading a model).
esm_transformer.eval()
# Converts [(label, sequence), ...] into (labels, strings, token tensor).
batch_converter = esm2_alphabet.get_batch_converter()

In [4]:
# Run on the first visible GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
esm_transformer = esm_transformer.to(device=device)

In [5]:
# Sequence lookup used by the extraction loop: indexed with each entry of
# `selected_protein` and the value passed to get_contact as the sequence.
with open('../data/full_seq_dict.json', "r") as seq_file:
    seq_dict = json.load(seq_file)

# Identifiers of the proteins to process.
with open('../data/selected_protein.json', 'r') as selection_file:
    selected_protein = json.load(selection_file)

In [6]:
def get_contact(seq):
    """Predict an ESM-2 contact map for a single protein sequence.

    The batch converter surrounds ``seq`` with a leading BOS and a trailing
    EOS token. Both are overwritten with the alphabet's ``<mask>`` token
    (index 32 in the ESM-2 alphabet) before the forward pass.
    NOTE(review): presumably a deliberate special-token experiment — confirm.

    Args:
        seq: amino-acid sequence string.

    Returns:
        numpy array of contact predictions for the single batch item.
        Expected to be square with side ``len(seq)``, because ESM's contact
        head trims the BOS/EOS positions — TODO confirm for this checkpoint.
    """
    _, _, batch_tokens = batch_converter([(1, seq)])
    # Replace the first (BOS) and last (EOS) columns in place. The sequence
    # length is unchanged, so the contact head's positional trimming still
    # lines up. Assigning the index avoids the torch.full/cat round-trip,
    # whose dtype depends on torch's fill-value inference.
    batch_tokens[:, 0] = esm2_alphabet.mask_idx
    batch_tokens[:, -1] = esm2_alphabet.mask_idx
    batch_tokens = batch_tokens.to(device)
    with torch.no_grad():
        contacts = esm_transformer.predict_contacts(batch_tokens)[0]
    return contacts.cpu().numpy()

In [8]:
def diverse_sample(lst, k, rng=None):
    """Pick up to ``k`` well-spread values from ``lst`` by greedy farthest-point sampling.

    The first value is drawn at random. Each later pick is the remaining value
    whose distance to its nearest already-picked value is largest. Ties go to
    the smallest such value, because the pool is sorted and ``max`` keeps the
    first maximum.

    Args:
        lst: numeric values to sample from (not modified).
        k: number of values wanted; clamped to ``len(lst)``.
        rng: object with a ``choice`` method (e.g. ``random.Random(seed)``).
            Defaults to the global ``random`` module, so ``random.seed``
            controls the first pick.

    Returns:
        list of distinct picks (by position) in pick order. It is ``[]`` when
        ``lst`` is empty or ``k <= 0``.
    """
    if rng is None:
        rng = random
    pool = sorted(lst)
    k = min(k, len(pool))
    if k <= 0:
        return []
    first = rng.choice(pool)
    # Remove the first pick too, so it can never be chosen again (matters for
    # duplicate values or when k reaches len(lst)).
    pool.remove(first)
    sampled = [first]
    for _ in range(k - 1):
        next_item = max(pool, key=lambda x: min(abs(x - s) for s in sampled))
        pool.remove(next_item)
        sampled.append(next_item)
    return sampled

In [14]:
def find_diagonal_patches(matrix, size=11, sample_numb=3, gap=5, threshold=10):
    """Find off-diagonal square windows with a high total and sample a few.

    For each start index ``i`` the scored window is
    ``matrix[i:i + size, i + size + gap:i + 2 * size + gap]``: rows ``i`` to
    ``i + size - 1`` against a column block that starts ``gap`` indices after
    those rows end. A window qualifies when its sum is strictly greater than
    ``threshold``. Up to ``sample_numb`` qualifying starts are then chosen with
    ``diverse_sample`` so that the returned patches are spread out.

    NOTE(review): NumPy truncates the column slice at the matrix edge. Near the
    end, windows narrower than ``size`` (or empty) are therefore still scored
    and may be returned. Callers that need a full ``size`` x ``size`` block
    should check ``i + 2 * size + gap <= len(matrix)``.

    Args:
        matrix: 2-D NumPy array, assumed square (its row count bounds the scan).
        size: side length of the window.
        sample_numb: maximum number of start indices to return.
        gap: indices skipped between the row block and the column block.
        threshold: a window's sum must exceed this value to qualify.

    Returns:
        list of start indices ``i``; empty when no window qualifies.
    """
    patches = [
        i
        for i in range(len(matrix) - size + 1)
        if np.sum(matrix[i:i + size, i + size + gap:i + 2 * size + gap]) > threshold
    ]
    if not patches:
        return []
    return diverse_sample(patches, min(sample_numb, len(patches)))

In [15]:
def patch_plotting(ori_contact, size=11):
    """Show a contact map with its sampled diagonal patches outlined.

    The map is drawn in greys on a fixed [0, 1] scale. A blue line marks the
    diagonal offset by ``gap`` rows. Every patch returned by
    ``find_diagonal_patches(ori_contact, size=size)`` is outlined in orange.

    Both the line and the boxes are drawn below the diagonal: each box is the
    transpose of the window that was actually scored. This is equivalent when
    the map is symmetric (assumed for ESM contact maps — confirm).

    Args:
        ori_contact: square 2-D array of contact values.
        size: patch side length, forwarded to ``find_diagonal_patches``.
    """
    gap = 5  # must match the column offset used by find_diagonal_patches
    n = ori_contact.shape[0]

    fig, ax = plt.subplots()
    image = ax.imshow(ori_contact, cmap="Greys", vmin=0, vmax=1)

    # Clip the offset line to rows that exist. Otherwise it extends past the
    # image and autoscaling stretches the axes.
    cols = np.arange(max(n - gap, 0))
    ax.plot(cols, cols + gap, color='blue', linewidth=1)

    for start in find_diagonal_patches(ori_contact, size=size):
        # imshow centres pixel (row r, col c) at (x=c, y=r), so cell edges sit
        # on half-integers. Shift by 0.5 so the box frames whole cells.
        corner = (start - 0.5, start + size + gap - 0.5)
        ax.add_patch(plt.Rectangle(corner, size, size, linewidth=1,
                                   edgecolor='orange', facecolor='none'))

    ax.set(xlabel="residue index", ylabel="residue index")
    fig.colorbar(image, ax=ax)
    plt.show()

In [16]:
# diverse_sample draws its first pick from the global `random` module. Seed it
# so the patches written to the *_reproduce.json file are reproducible.
SEED = 0
random.seed(SEED)

# Set True to inspect each protein visually. Note that patch_plotting re-runs
# find_diagonal_patches and so consumes RNG draws, which changes the sampled
# patches.
PLOT_PATCHES = False

patch_info = {}

for protein in tqdm(selected_protein):
    ori_contact = get_contact(seq_dict[protein])
    if PLOT_PATCHES:
        patch_plotting(ori_contact)
    patches = find_diagonal_patches(ori_contact)
    # Record only proteins with at least one qualifying patch.
    if patches:
        patch_info[protein] = patches

100%|██████████| 1431/1431 [03:53<00:00,  6.14it/s]


NameError: name 'previous_data' is not defined

In [17]:
# Persist {protein id: [patch start indices]} for downstream experiments.
with open('../data/revision_single_seg_reproduce.json', 'w') as out_file:
    json.dump(patch_info, out_file)