In [1]:
import os
# Restrict CUDA to the first GPU. This is set in its own cell, before the cell
# that imports torch, so the restriction is in place when CUDA is initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

In [2]:
import esm
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm 
import numpy as np 
import h5py
import json
import re
import shutil
import random
import time

In [3]:
# Load the pretrained ESM-2 model (esm2_t36_3B_UR50D) together with its alphabet,
# and build the converter that turns (label, sequence) tuples into token batches.
esm_transformer, esm2_alphabet = esm.pretrained.esm2_t36_3B_UR50D()
batch_converter = esm2_alphabet.get_batch_converter()

# Use the first visible GPU when CUDA is available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# The model is used for inference only (contact prediction), so switch it to
# evaluation mode as the ESM README recommends; this disables any train-time
# stochastic behaviour so repeated predictions are deterministic.
esm_transformer = esm_transformer.to(device).eval()

In [4]:
# Earlier setting, kept for reference: a bounded separation window (50, 100].
#dist_range_min = 50
#dist_range_max = 100 
# Minimum separation (exclusive) between two SSE centres for them to count as a pair;
# also embedded in the output file name.
dist_range_min = 100
# Label embedded in the output file name (in the visible cells it is not used
# for anything else).
expand_type = 'outward' 

In [5]:
# Output HDF5 path; the separation threshold and expansion label are part of the name
# so that runs with different settings do not overwrite each other.
save_file_path = f'../results/pair_{dist_range_min}_{expand_type}_wbos_eos_randommask.hdf5'
save_file_path

'results/pair_100_outward_wbos_eos_randommask.hdf5'

In [6]:
# Secondary-structure strings. Later cells look entries up as `<protein> + '.pdb'`
# and scan the value for runs of 'E' / 'H'.
with open('../data/ss_dict.json', "r") as json_file:
    ss_dict = json.load(json_file)

# Full sequences, looked up by protein id in the main loop.
with open('../data/full_seq_dict.json', "r") as json_file:
    seq_dict = json.load(json_file)

# NOTE(review): `selected_protein` is loaded here but not referenced by any of
# the cells visible in this notebook section -- confirm it is still needed.
with open('../data/selected_protein.json', 'r') as file:
    selected_protein = json.load(file)


In [7]:
# get all SSEs
def get_segments(input_str):
    """Return the (start, end) spans of long enough 'E' and 'H' runs.

    ``input_str`` is scanned for maximal runs of the character 'E' or of the
    character 'H'. A run of 'E' is kept when it is longer than 3 characters,
    a run of 'H' when it is longer than 7 characters. Spans are half-open,
    i.e. ``input_str[start:end]`` is the run itself, and are returned in the
    order they occur in the string.
    """
    # Exclusive lower bound on the run length, per run character.
    longer_than = {'E': 3, 'H': 7}
    return [
        run.span()
        for run in re.finditer('E+|H+', input_str)
        if len(run.group()) > longer_than[run.group()[0]]
    ]

In [8]:
#def get_original_contact(seq): 
def get_contact(seq):
    """Predict the contact map of a single sequence with the loaded ESM model.

    ``seq`` is passed to the notebook-level ``batch_converter`` as a
    one-element batch (with the dummy label 1); the resulting tokens are moved
    to ``device`` and fed to ``esm_transformer.predict_contacts`` without
    gradient tracking. The prediction for the single batch element is returned
    as a NumPy array on the CPU (presumably L x L for a sequence of L
    residues -- TODO confirm against the ESM documentation).

    ``seq`` may contain literal ``<mask>`` markers; the callers in this
    notebook rely on that to mask residues.
    """
    # with BOS/EOS
    _, _, tokens = batch_converter([(1, seq)])

    # without BOS/EOS (disabled alternative kept for reference)
    #batch_tokens = torch.cat((torch.full((batch_tokens.shape[0], 1), 32), batch_tokens[:, 1:-1], torch.full((batch_tokens.shape[0], 1), 32)), dim=1)

    with torch.no_grad():
        contacts = esm_transformer.predict_contacts(tokens.to(device))[0]
    return contacts.cpu().numpy()

In [9]:
# get centers of SSEs
def get_ss_cents(segments):
    """Return the integer midpoint ``(start + end) // 2`` of every segment.

    ``segments`` is an iterable of indexable items whose first two entries are
    the start and end of a segment; the result preserves the input order.
    """
    return [(seg[0] + seg[1]) // 2 for seg in segments]

In [10]:
# select pairs of SSEs separated by certain distance
def get_pairs(arr, min_dist=None):
    """Return all pairs of values in ``arr`` that are more than ``min_dist`` apart.

    Parameters
    ----------
    arr : sequence of numbers
        Candidate positions (e.g. SSE centres).
    min_dist : number, optional
        Exclusive lower bound on ``abs(a - b)`` for a pair to be kept. When
        omitted, the notebook-level ``dist_range_min`` is used, which keeps
        existing ``get_pairs(arr)`` calls working unchanged.

    Returns
    -------
    list of tuple
        ``(arr[i], arr[j])`` for every ``i < j`` whose values are separated by
        more than ``min_dist``, in index order.
    """
    if min_dist is None:
        # Backward-compatible default: fall back to the global configuration value.
        min_dist = dist_range_min
    # An upper bound (dist_range_max) was used in an earlier variant:
    #   dist_range_min < abs(arr[i] - arr[j]) <= dist_range_max
    return [
        (arr[i], arr[j])
        for i in range(len(arr))
        for j in range(i + 1, len(arr))
        if min_dist < abs(arr[i] - arr[j])
    ]

In [11]:
# take 5 res on both sides
# select segment pairs with enough contacts + leaves enough distances for exploring different flanking lengths
def select_pairs(cent_pairs, matrix, cutoff, seq_len):
    """Pick up to three centre pairs with enough contact signal and enough flanking room.

    For every ``(c1, c2)`` in ``cent_pairs`` the 11 x 11 patch
    ``matrix[c1-5:c1+6, c2-5:c2+6]`` is summed. A pair qualifies when

    * the patch sum is greater than ``cutoff``, and
    * both the number of residues before the first window (``c1 - 5``) and
      ``seq_len - (c2 + 6) - 1`` are greater than 10.

    At most three qualifying pairs are returned, drawn without replacement
    via ``random.sample`` (so the order and choice depend on the state of the
    global ``random`` generator).
    """
    def _qualifies(pair):
        first_cent, second_cent = pair[0], pair[1]
        window_sum = np.sum(
            matrix[(first_cent - 5): (first_cent + 6), (second_cent - 5): (second_cent + 6)]
        )
        room_before = first_cent - 5
        room_after = seq_len - (second_cent + 5 + 1) - 1
        # enough contact between the two SSEs, and enough sequence on both
        # sides to expand into when checking recovery
        return (window_sum > cutoff) and (min(room_before, room_after) > 10)

    candidates = [pair for pair in cent_pairs if _qualifies(pair)]
    return random.sample(candidates, min(len(candidates), 3))

In [12]:
# get the masked sequence and then the contact map
def get_seg_contact(sequence, frag1_start, frag1_end, frag2_start, frag2_end):
    """Predict contacts for ``sequence`` with everything but two fragments masked.

    The residues in ``sequence[frag1_start:frag1_end]`` and
    ``sequence[frag2_start:frag2_end]`` are kept; every other position
    (before the first fragment, between the two fragments, and after the
    second fragment) is replaced by a ``<mask>`` marker. The masked sequence
    is then passed to ``get_contact``.

    Parameters
    ----------
    sequence : str
        Full, unmasked sequence.
    frag1_start, frag1_end : int
        Half-open range of the first fragment to keep.
    frag2_start, frag2_end : int
        Half-open range of the second fragment to keep; expected to start at
        or after ``frag1_end``.

    Returns
    -------
    Whatever ``get_contact`` returns for the masked sequence.
    """
    kept_first = sequence[frag1_start:frag1_end]
    kept_second = sequence[frag2_start:frag2_end]
    gap_len = frag2_start - frag1_end
    # Use the length of the *argument*: the original read the notebook-global
    # `seq`, which gives a wrong tail length (or a NameError) whenever the
    # function is called with a sequence other than the current global one.
    tail_len = len(sequence) - frag2_end
    full_seq = (
        frag1_start * '<mask>'
        + kept_first
        + gap_len * '<mask>'
        + kept_second
        + tail_len * '<mask>'
    )
    return get_contact(full_seq)


In [13]:
def norm_sum_mult(ori_contact_seg, seg_cross_contact):
    """Return sum(ori * new) / sum(ori * ori), computed element-wise.

    ``ori_contact_seg`` is the reference patch and ``seg_cross_contact`` the
    patch to compare against it; both must be broadcast-compatible. The value
    is 1.0 when the two patches are identical. No guard is applied for an
    all-zero reference patch (the denominator is then zero).
    """
    overlap = np.multiply(ori_contact_seg, seg_cross_contact).sum()
    self_overlap = np.multiply(ori_contact_seg, ori_contact_seg).sum()
    return overlap / self_overlap

In [14]:
# Rebuild the protein -> list of [pos1, pos2] mapping from an earlier results file.
# Each top-level group is a protein; each of its child keys is parsed as
# "<int>_<int>" (only the first two underscore-separated fields are read).
# NOTE(review): the next cell reassigns `previous_data` from a JSON file, so the
# mapping built here is discarded when the notebook is run top to bottom.
previous_data = {}

with h5py.File('../high_gremline_mask_100_outward_wbos_eos.hdf5', 'r') as f:
    for protein in f.keys():
        selected_pairs = [
            [int(pair.split('_')[0]), int(pair.split('_')[1])]
            for pair in f[protein].keys()
        ]
        previous_data[protein] = selected_pairs

In [15]:
# Load the protein -> list of [pos1, pos2] pairs to evaluate. This replaces any
# `previous_data` built earlier in the notebook.
# NOTE(review): this path is relative to `data/`, whereas the other inputs in
# this notebook are read from `../data/` -- confirm the intended working directory.
with open('data/reproduce_pair_recovery_100_w_bos_eos.json', 'r') as json_file:
    previous_data = json.load(json_file)

In [16]:
# Inspect the loaded mapping: protein id -> list of [pos1, pos2] position pairs.
previous_data

{'1BRTA': [[119, 221]],
 '1BS0A': [[90, 276]],
 '1DYPA': [[116, 228]],
 '1E6UA': [[19, 219]],
 '1ELUA': [[61, 281]],
 '1EU8A': [[21, 322]],
 '1F0KA': [[203, 315]],
 '1FNNA': [[48, 162]],
 '1FVIA': [[23, 163]],
 '1GA8A': [[16, 265]],
 '1GVFA': [[33, 266]],
 '1IN4A': [[55, 174]],
 '1JX6A': [[64, 296]],
 '1K7JA': [[19, 131]],
 '1MJ5A': [[128, 240]],
 '1MO9A': [[47, 172], [47, 178]],
 '1NNLA': [[17, 175]],
 '1NOXA': [[64, 165]],
 '1NZYA': [[84, 232]],
 '1O97D': [[61, 172]],
 '1OOEA': [[19, 203]],
 '1OZHA': [[227, 328], [377, 523]],
 '1P3DA': [[341, 455]],
 '1PVGA': [[101, 202]],
 '1Q0RA': [[122, 242]],
 '1Q6ZA': [[221, 322], [359, 505]],
 '1RKQA': [[30, 264]],
 '1RM6B': [[71, 210]],
 '1S0AA': [[67, 321]],
 '1SZWA': [[153, 343]],
 '1UJ2A': [[25, 135]],
 '1URSA': [[145, 268]],
 '1YARA': [[47, 213]],
 '1YFQA': [[17, 300]],
 '1YKIA': [[83, 189]],
 '1Z0SA': [[127, 259]],
 '1Z5ZA': [[31, 218]],
 '1ZMTA': [[18, 219]],
 '1ZR6A': [[233, 340]],
 '2A35A': [[20, 187]],
 '2AQJA': [[35, 182]],
 '2B61A':

In [17]:
# Small subset of `previous_data` for quick trial runs of the main loop;
# ids that are not present in `previous_data` are silently skipped.
keys_for_test = ['3CKCA', '1PVGA']
test_dict = {
    protein_id: previous_data[protein_id]
    for protein_id in keys_for_test
    if protein_id in previous_data
}
print(test_dict)

{'3CKCA': [[224, 420]], '1PVGA': [[101, 202]]}


In [19]:
# Random-unmasking experiment.
#
# For every protein and every [pos1, pos2] pair in `previous_data`:
#   * take an 11-residue window centred on each position (pos - 5 .. pos + 5);
#   * for flank_len = 0, 1, ..., keep both windows unmasked, additionally unmask
#     2 * flank_len randomly chosen residues outside the windows, mask everything
#     else with '<mask>', and predict the contact map of that sequence;
#   * compare the window-vs-window patch of the masked prediction with the same
#     patch of the unmasked prediction and store the results in the HDF5 file
#     under '<protein>/<pos1>_<pos2>/<flank_len>/...'.

# Seed the RNG so the randomly unmasked positions are reproducible on a fresh
# "Restart & Run All".
RANDOM_MASK_SEED = 0
random.seed(RANDOM_MASK_SEED)

# NOTE: mode 'w' truncates any existing file at `save_file_path`.
with h5py.File(save_file_path, 'w') as f:
    for i, protein in enumerate(tqdm(previous_data.keys())):
    #for i, protein in enumerate(tqdm(test_dict)):
        seq = seq_dict[protein]
        seq_len = len(seq)
        # Reference prediction on the full, unmasked sequence.
        contact_ori = get_contact(seq)
        selected_pairs = previous_data[protein]

        # leave=False clears each finished inner bar instead of flooding the output.
        for position in tqdm(selected_pairs, leave=False):
            ss1_start = position[0] - 5
            ss1_end = position[0] + 5 + 1
            ss2_start = position[1] - 5
            ss2_end = position[1] + 5 + 1

            ori_contact_seg = contact_ori[ss1_start:ss1_end, ss2_start:ss2_end]

            # Number of flank lengths to try, bounded by the room before the first
            # window and after the second one. When this is <= 0 (window too close
            # to a sequence end) the pair produces no datasets.
            flank_len_range = min(ss1_start, seq_len - ss2_end - 1)

            # Positions outside both windows; these are the candidates for random
            # unmasking. They do not depend on flank_len, so build the list once.
            potential_positions = (
                list(range(0, ss1_start))
                + list(range(ss1_end, ss2_start))
                + list(range(ss2_end, seq_len))
            )

            pair_group = f'{protein}/{position[0]}_{position[1]}'

            for flank_len in range(flank_len_range):
                random_unmask_positions = random.sample(potential_positions, 2 * flank_len)

                # Start fully masked, then reveal the sampled residues and both windows.
                seq_mask = ['<mask>'] * seq_len
                for pos in random_unmask_positions:
                    seq_mask[pos] = seq[pos]
                seq_mask[ss1_start:ss1_end] = seq[ss1_start:ss1_end]
                seq_mask[ss2_start:ss2_end] = seq[ss2_start:ss2_end]
                seq_mask = ''.join(seq_mask)

                mask_contact_full = get_contact(seq_mask)

                # Same window-vs-window patch as `ori_contact_seg`, from the masked prediction.
                seg_cross_contact = mask_contact_full[ss1_start:ss1_end, ss2_start:ss2_end]

                sum_diff_value = np.sum(seg_cross_contact) - np.sum(ori_contact_seg)
                norm_sum_mult_value = norm_sum_mult(ori_contact_seg, seg_cross_contact)

                # Dataset names are unchanged from the original layout. Note that the
                # full unmasked map is written again for every flank_len.
                group = f'{pair_group}/{flank_len}'
                f.create_dataset(f'{group}/ori_contact_full', data=contact_ori)
                f.create_dataset(f'{group}/seg_contact', data=mask_contact_full)
                f.create_dataset(f'{group}/seg_cross_contact', data=seg_cross_contact)
                f.create_dataset(f'{group}/sum_diff', data=sum_diff_value)
                f.create_dataset(f'{group}/sum_mult', data=norm_sum_mult_value)
                f.create_dataset(f'{group}/masked_seq', data=seq_mask)

        # Periodic backup of the partially written file (currently disabled):
        # if i % 300 == 0:
        #     f.flush()
        #     shutil.copy(save_file_path, save_file_path.split('.')[0] + '_' + str(i) + '.hdf5')
        #     time.sleep(20)

# Release cached GPU memory held by PyTorch once the run is finished.
torch.cuda.empty_cache()

  0%|          | 0/266 [00:00<?, ?it/s]
  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:06<00:00,  6.88s/it][A
  0%|          | 1/266 [00:07<31:01,  7.02s/it]
  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:16<00:00, 17.00s/it][A
  1%|          | 2/266 [00:24<57:15, 13.01s/it]
  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:05<00:00,  5.02s/it][A
  1%|          | 3/266 [00:29<41:19,  9.43s/it]
  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:02<00:00,  2.45s/it][A
  2%|▏         | 4/266 [00:32<29:27,  6.75s/it]
  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:11<00:00, 11.25s/it][A
  2%|▏         | 5/266 [00:43<36:43,  8.44s/it]
  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:03<00:00,  3.27s/it][A
  2%|▏         | 6/266 [00:46<29:16,  6.76s/it]
  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:07<00:00,  7.66s/it][A
  3%|▎         | 7/266 [00:54<30:42,  7.11