In [1]:
import os

# Expose only physical GPU 1 to this process. This has to happen before torch
# initializes CUDA, which is why it lives in the first cell (torch is imported
# in the next one).
os.environ.update(CUDA_VISIBLE_DEVICES='1')

In [2]:
import esm
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm 
import numpy as np 
import h5py
import json
import re
import shutil
import random
import time

In [3]:
# Load the pretrained ESM-2 model (esm2_t36_3B_UR50D) together with its alphabet;
# the alphabet's batch converter turns (label, sequence) pairs into token tensors.
esm_transformer, esm2_alphabet = esm.pretrained.esm2_t36_3B_UR50D()
# This notebook only runs inference: switch to eval mode (as the fair-esm usage
# examples do) so training-time behavior such as dropout is disabled and
# repeated predictions on the same input are deterministic.
esm_transformer.eval()
batch_converter = esm2_alphabet.get_batch_converter()

In [4]:
# Run on the first visible GPU when CUDA is available (visibility is restricted
# by CUDA_VISIBLE_DEVICES in the first cell); otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
esm_transformer = esm_transformer.to(device)

In [40]:
# Load the sequence table (presumably protein id -> sequence string; the
# experiment cell indexes it with each selected protein) and the list of
# proteins to process. JSON text is UTF-8, so decode explicitly instead of
# relying on the platform's default encoding.
with open('data/full_seq_dict.json', "r", encoding="utf-8") as json_file:
    seq_dict = json.load(json_file)

with open('data/selected_protein.json', 'r', encoding="utf-8") as file:
    selected_protein = json.load(file)

# Fail fast, before any expensive model calls, if a selected id has no sequence.
_missing_proteins = [protein for protein in selected_protein if protein not in seq_dict]
assert not _missing_proteins, f"selected proteins missing from seq_dict: {_missing_proteins}"

In [39]:
# Output HDF5 file for the masking experiment. h5py creates the file itself but
# not its parent directory, so make sure the directory exists up front.
save_file_path = 'results/single_seg_w_bos_eos.hdf5'
_save_dir = os.path.dirname(save_file_path)
if _save_dir:
    os.makedirs(_save_dir, exist_ok=True)

In [10]:
def patch_sum(matrix, top_left, size=22):
    """Return the sum of the ``size`` x ``size`` window starting at ``top_left``.

    Args:
        matrix: 2-D array-like (NumPy array or nested sequence of numbers).
        top_left: ``(row, col)`` index of the window's first entry.
        size: side length of the square window. A non-positive size is an
            empty window and sums to 0.

    Returns:
        The sum of all entries in ``matrix[row:row + size, col:col + size]``.

    Raises:
        IndexError: if the window does not lie entirely inside ``matrix``.
    """
    if size <= 0:
        return 0  # empty window
    i, j = top_left
    arr = np.asarray(matrix)
    # Plain slicing would silently truncate an out-of-range window (and negative
    # starts would wrap around), so reject those instead of returning a partial sum.
    if i < 0 or j < 0 or i + size > arr.shape[0] or j + size > arr.shape[1]:
        raise IndexError(
            f"window at {tuple(top_left)} of size {size} exceeds matrix shape {arr.shape}"
        )
    # Vectorized sum instead of a Python-level double loop over elements.
    return arr[i:i + size, j:j + size].sum()

In [11]:
def find_diagonal_patches(matrix, size=22, threshold=15, sample_numb=3, margin=10, rng=None):
    """Randomly pick contact-rich square windows lying on the main diagonal.

    A candidate window has its top-left corner at ``(i, i)`` and covers
    ``size`` x ``size`` entries. It is eligible when both hold:

    * its entry sum (``patch_sum``) is strictly greater than ``threshold``;
    * it leaves more than ``margin`` rows on each side, i.e.
      ``margin < i < n - size - margin`` with ``n = len(matrix)``. This reserves
      room to grow flanking segments around the window later.

    Args:
        matrix: square 2-D contact matrix (NumPy array or nested sequence).
        size: window side length.
        threshold: exclusive lower bound on the window sum.
        sample_numb: maximum number of windows to return.
        margin: exclusive minimum number of rows required before the window
            start and after the window end.
        rng: object with a ``sample`` method (e.g. ``random.Random(seed)``) used
            for reproducible sampling; defaults to the global ``random`` module.

    Returns:
        List of ``(i, i)`` corners, at most ``sample_numb`` long, drawn without
        replacement; empty when no window qualifies.
    """
    if rng is None:
        rng = random
    n = len(matrix)
    # Iterate only over starts that satisfy the margin constraint (and keep the
    # window inside the matrix), so the comparatively expensive window sum is
    # never computed for positions that would be rejected anyway.
    first = max(0, margin + 1)
    stop = min(n - size + 1, n - size - margin)
    patches = [
        (i, i)
        for i in range(first, stop)
        if patch_sum(matrix, (i, i), size) > threshold
    ]
    return rng.sample(patches, min(sample_numb, len(patches)))

In [12]:
def get_contact(seq):
    """Predict a contact map for a single sequence with the loaded ESM-2 model.

    ``seq`` may contain ESM special tokens such as ``'<mask>'`` (the masking
    experiment passes such strings). The sequence is tokenized with the global
    ``batch_converter``, moved to ``device``, and scored by
    ``esm_transformer.predict_contacts`` with gradients disabled.

    Returns:
        The prediction for the single batch element as a NumPy array on the
        CPU — presumably ``(n_residues, n_residues)``; confirm against the ESM
        version in use.
    """
    # The label is not used by this function; only the token tensor is needed.
    _, _, tokens = batch_converter([(1, seq)])
    # Tokens are used exactly as the converter produced them. An earlier variant
    # overwrote the first and last token with id 32 before running the model.
    tokens = tokens.to(device)
    with torch.no_grad():
        contacts = esm_transformer.predict_contacts(tokens)[0].cpu()
    return contacts.numpy()

In [13]:
def norm_sum_mult(ori_contact_seg, seg_cross_contact):
    """Overlap of a contact window with a reference window, normalized by the reference.

    Computes ``sum(ori * other) / sum(ori * ori)`` element-wise, so the result
    is 1.0 when both windows are identical (and the reference is not all
    zeros; an all-zero reference yields nan/inf with a NumPy warning).

    Args:
        ori_contact_seg: reference window (array-like).
        seg_cross_contact: window to compare; must broadcast against the reference.
    """
    reference = np.asarray(ori_contact_seg)
    overlap = np.multiply(reference, seg_cross_contact).sum()
    self_overlap = np.square(reference).sum()
    return overlap / self_overlap

In [45]:
# Masking experiment. For every selected protein:
#   1. predict the contact map of the full sequence;
#   2. sample up to N_PATCHES contact-rich PATCH_SIZE x PATCH_SIZE windows on
#      its diagonal (find_diagonal_patches);
#   3. for each window, keep the window plus `flank_len` residues on each side,
#      replace every other residue with '<mask>', predict contacts again, and
#      score the overlap with the original window (norm_sum_mult).
#
# HDF5 layout, one group per (protein, window start, flank length):
#   {protein}/{patch_start}/{flank_len}/mask_contact_full    prediction for the masked sequence
#   {protein}/{patch_start}/{flank_len}/ori_contact_seg      window cut from the unmasked prediction
#   {protein}/{patch_start}/{flank_len}/mask_contact_seg     same window cut from the masked prediction
#   {protein}/{patch_start}/{flank_len}/norm_sum_mult_value  norm_sum_mult(ori_contact_seg, mask_contact_seg)
# The file is opened with mode 'w', so an existing file at save_file_path is overwritten.
PATCH_SIZE = 22          # window side length
CONTACT_THRESHOLD = 15   # a window's summed contact score must exceed this
N_PATCHES = 3            # maximum windows sampled per protein
SEED = 0                 # makes the random window sampling reproducible

random.seed(SEED)

with h5py.File(save_file_path, 'w') as f:
    for protein in tqdm(selected_protein):
        seq = seq_dict[protein]
        ori_contact = get_contact(seq)

        patches = find_diagonal_patches(
            ori_contact, size=PATCH_SIZE, threshold=CONTACT_THRESHOLD, sample_numb=N_PATCHES
        )

        for patch_start, _ in tqdm(patches, desc='patches', leave=False):
            patch_end = patch_start + PATCH_SIZE
            # The unmasked reference window does not depend on flank_len.
            ori_contact_seg = ori_contact[patch_start:patch_end, patch_start:patch_end]

            # Grow the flanks symmetrically until the nearer sequence end is
            # reached; the largest flank exposes that side completely
            # (seg_start == 0 or seg_end == len(seq)). The previous bound,
            # min(patch_start, len(seq) - patch_end + 1), could reach the right
            # end but always stopped one residue short of the left end.
            max_flank = min(patch_start, len(seq) - patch_end)
            for flank_len in range(max_flank + 1):
                seg_start = patch_start - flank_len
                seg_end = patch_end + flank_len
                seq_mask = (
                    seg_start * '<mask>'
                    + seq[seg_start:seg_end]
                    + (len(seq) - seg_end) * '<mask>'
                )

                mask_contact_full = get_contact(seq_mask)
                mask_contact_seg = mask_contact_full[patch_start:patch_end, patch_start:patch_end]
                norm_sum_mult_value = norm_sum_mult(ori_contact_seg, mask_contact_seg)

                group = f'{protein}/{patch_start}/{flank_len}'
                f.create_dataset(f'{group}/mask_contact_full', data=mask_contact_full)
                f.create_dataset(f'{group}/ori_contact_seg', data=ori_contact_seg)
                f.create_dataset(f'{group}/mask_contact_seg', data=mask_contact_seg)
                f.create_dataset(f'{group}/norm_sum_mult_value', data=norm_sum_mult_value)

  0%|          | 0/3 [00:00<?, ?it/s]
  0%|          | 0/3 [00:00<?, ?it/s][A
 33%|███▎      | 1/3 [00:22<00:44, 22.16s/it][A
 67%|██████▋   | 2/3 [00:26<00:11, 11.52s/it][A
100%|██████████| 3/3 [00:35<00:00, 11.83s/it][A
 33%|███▎      | 1/3 [00:35<01:11, 35.71s/it]
  0%|          | 0/3 [00:00<?, ?it/s][A
 33%|███▎      | 1/3 [00:20<00:41, 20.99s/it][A
 67%|██████▋   | 2/3 [00:46<00:23, 23.40s/it][A
100%|██████████| 3/3 [00:57<00:00, 19.09s/it][A
 67%|██████▋   | 2/3 [01:33<00:48, 48.53s/it]
  0%|          | 0/3 [00:00<?, ?it/s][A
 33%|███▎      | 1/3 [00:09<00:19,  9.54s/it][A
 67%|██████▋   | 2/3 [00:14<00:06,  6.99s/it][A
100%|██████████| 3/3 [00:19<00:00,  6.58s/it][A
100%|██████████| 3/3 [01:53<00:00, 37.70s/it]
