In [1]:
import os
# Expose only physical GPU 2 to this kernel. This is set before torch is imported
# (next cell), so that GPU is the one addressed as 'cuda:0' when the device is chosen.
os.environ["CUDA_VISIBLE_DEVICES"] = '2'

In [2]:
import esm
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm 
import numpy as np 
import h5py
import json
import re
import shutil
import random
import time

In [3]:
# ESM-2 (36 layers, 3B parameters) plus its alphabet; the batch converter turns
# (label, sequence) tuples into token tensors.
esm_transformer, esm2_alphabet = esm.pretrained.esm2_t36_3B_UR50D()
batch_converter = esm2_alphabet.get_batch_converter()

# With CUDA_VISIBLE_DEVICES set in the first cell, 'cuda:0' is the selected GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Inference only: eval() switches off training-mode behaviour (e.g. dropout), as in
# the ESM README usage examples, so repeated predictions are deterministic.
esm_transformer = esm_transformer.to(device).eval()

In [4]:
# SSE pairs are kept only if their centres are more than dist_range_min and at
# most dist_range_max residues apart (see get_pairs).
dist_range_min = 50
dist_range_max = 100
# Tag used when building the output file name (save_file_path).
expand_type = 'outward'
# Seed Python's global RNG so every random.sample call (pair selection and
# random unmasking) is reproducible across Restart & Run All.
SEED = 0
random.seed(SEED)

In [5]:
# Output HDF5 path: separation range and expansion type, followed by a fixed tag
# presumably describing this run's variant (BOS/EOS kept, random unmasking, 30-residue exclusion).
save_file_path = f'results/pair_{dist_range_min}_{dist_range_max}_{expand_type}_wbos_eos_randommask_no_unmask_within30.hdf5'
save_file_path

'results/pair_50_100_outward_wbos_eos_randommask_no_unmask_within30.hdf5'

In [7]:
def _read_json(path):
    """Read and return the JSON document stored at `path`."""
    with open(path, "r") as handle:
        return json.load(handle)


# Secondary-structure strings keyed by '<protein>.pdb'; full sequences keyed by protein id.
ss_dict = _read_json('data/ss_dict.json')
seq_dict = _read_json('data/full_seq_dict.json')
selected_protein = _read_json('data/selected_protein.json')


In [8]:
def get_segments(input_str):
    """Return half-open (start, end) spans of sufficiently long SSEs in a DSSP-like string.

    Maximal runs of 'E' (strand) longer than 3 residues and runs of 'H' (helix)
    longer than 7 residues are reported, in order of appearance.
    """
    # A run must be strictly longer than this to be kept.
    longer_than = {'E': 3, 'H': 7}
    return [
        (run.start(), run.end())
        for run in re.finditer('E+|H+', input_str)
        if len(run.group()) > longer_than[run.group()[0]]
    ]

In [9]:
def get_contact(seq):
    """Predict the ESM-2 contact map for a single sequence.

    The sequence is tokenized by the module-level ``batch_converter`` (BOS/EOS
    tokens included) and passed to ``esm_transformer.predict_contacts`` on
    ``device`` without gradient tracking. The first (only) batch entry is
    returned as a NumPy array on the CPU.
    """
    _labels, _strs, tokens = batch_converter([(1, seq)])
    with torch.no_grad():
        contacts = esm_transformer.predict_contacts(tokens.to(device))
    return contacts[0].cpu().numpy()

In [10]:
def get_ss_cents(segments):
    """Return the integer midpoint (start + end) // 2 of each (start, end) segment."""
    return [(seg[0] + seg[1]) // 2 for seg in segments]

In [11]:
def get_pairs(arr, min_sep=None, max_sep=None):
    """Return all (arr[i], arr[j]) with i < j whose separation is in (min_sep, max_sep].

    Args:
        arr: sequence of positions (e.g. SSE centres).
        min_sep: exclusive lower bound on |arr[i] - arr[j]|; defaults to the
            module-level ``dist_range_min``.
        max_sep: inclusive upper bound; defaults to the module-level ``dist_range_max``.

    Returns:
        List of pairs in (i, j) lexicographic order.
    """
    # Fall back to the notebook-wide config only when no explicit bound is given.
    lo = dist_range_min if min_sep is None else min_sep
    hi = dist_range_max if max_sep is None else max_sep
    return [
        (arr[i], arr[j])
        for i in range(len(arr))
        for j in range(i + 1, len(arr))
        if lo < abs(arr[i] - arr[j]) <= hi
    ]

In [12]:
def select_pairs(cent_pairs, matrix, cutoff, seq_len,
                 half_window=5, min_flank=10, max_pairs=3, rng=None):
    """Randomly pick up to `max_pairs` SSE-centre pairs that are in contact and leave room to expand.

    For each pair (c1, c2) the contact patch
    ``matrix[c1-half_window : c1+half_window+1, c2-half_window : c2+half_window+1]``
    is summed. A pair is eligible when more than `min_flank` residues remain
    before the first window and after the second window, and the patch sum
    exceeds `cutoff`.

    Args:
        cent_pairs: iterable of (c1, c2) centre pairs.
        matrix: 2-D contact map indexed by residue.
        cutoff: minimum (exclusive) patch sum.
        seq_len: length of the sequence the map belongs to.
        half_window: residues taken on each side of a centre (default 5).
        min_flank: required free residues outside the windows (default 10).
        max_pairs: maximum number of pairs returned (default 3).
        rng: object with a ``sample`` method, e.g. ``random.Random(seed)``;
            defaults to the global ``random`` module.

    Returns:
        A random sample (in random order) of the eligible pairs.
    """
    sampler = random if rng is None else rng
    eligible = []
    for pair in cent_pairs:
        c1, c2 = pair[0], pair[1]
        ss1_start = c1 - half_window
        ss2_end = c2 + half_window + 1
        # Check flanking room first: it is cheap, and it guarantees ss1_start > min_flank,
        # so the slice below never starts at a negative (wrap-around) index.
        if min(ss1_start, seq_len - ss2_end - 1) <= min_flank:
            continue
        patch_sum = np.sum(matrix[ss1_start:c1 + half_window + 1, c2 - half_window:ss2_end])
        if patch_sum > cutoff:
            eligible.append(pair)
    return sampler.sample(eligible, min(len(eligible), max_pairs))

In [13]:
def get_seg_contact(sequence, frag1_start, frag1_end, frag2_start, frag2_end, contact_fn=None):
    """Predict contacts for `sequence` with everything outside two fragments masked.

    Residues in ``sequence[frag1_start:frag1_end]`` and
    ``sequence[frag2_start:frag2_end]`` are kept. Every other position becomes a
    ``'<mask>'`` token, so the masked sequence has one token per residue of `sequence`.

    Args:
        sequence: amino-acid string.
        frag1_start, frag1_end: half-open bounds of the first kept fragment.
        frag2_start, frag2_end: half-open bounds of the second kept fragment.
            Bounds must satisfy
            0 <= frag1_start <= frag1_end <= frag2_start <= frag2_end <= len(sequence).
        contact_fn: callable applied to the masked sequence; defaults to ``get_contact``.

    Returns:
        The result of `contact_fn` on the masked sequence (the contact map with
        the default ``get_contact``).

    Raises:
        ValueError: if the fragment bounds are out of order or exceed the sequence.
    """
    if not 0 <= frag1_start <= frag1_end <= frag2_start <= frag2_end <= len(sequence):
        raise ValueError(
            f"fragment bounds out of order for sequence of length {len(sequence)}: "
            f"({frag1_start}, {frag1_end}), ({frag2_start}, {frag2_end})"
        )
    if contact_fn is None:
        contact_fn = get_contact
    mask = '<mask>'
    # Trailing mask length must come from `sequence` itself, not from any outer variable.
    masked_seq = (
        frag1_start * mask
        + sequence[frag1_start:frag1_end]
        + (frag2_start - frag1_end) * mask
        + sequence[frag2_start:frag2_end]
        + (len(sequence) - frag2_end) * mask
    )
    return contact_fn(masked_seq)


In [None]:
def norm_sum_mult(ori_contact_seg, seg_cross_contact):
    """Overlap of a new contact patch with the original, normalised by the original's self-overlap.

    Computes sum(ori * new) / sum(ori * ori). Identical patches give 1.0, and the
    value scales linearly with `seg_cross_contact`. If `ori_contact_seg` is all
    zeros the ratio is 0/0, which yields nan together with a NumPy RuntimeWarning.
    """
    numerator = np.multiply(ori_contact_seg, seg_cross_contact).sum()
    denominator = np.multiply(ori_contact_seg, ori_contact_seg).sum()
    return numerator / denominator

In [15]:
# Reuse the SSE-centre pairs chosen in an earlier run. Each group name under a
# protein is presumably '<c1>_<c2>' (the layout the main loop below writes); the
# two leading fields are parsed as ints.
previous_data = {}

with h5py.File('../high_gremline_mask_50_100_outward_wbos_eos.hdf5', 'r') as f:
    for protein, protein_group in f.items():
        pairs = []
        for pair_name in protein_group:
            parts = pair_name.split('_')
            pairs.append([int(parts[0]), int(parts[1])])
        previous_data[protein] = pairs

In [16]:
# Display the loaded mapping: protein id -> list of [centre1, centre2] SSE pairs.
previous_data

{'1A8LA': [[138, 197], [26, 83]],
 '1ATGA': [[108, 161]],
 '1B63A': [[239, 292], [57, 146]],
 '1BRTA': [[26, 94]],
 '1BYIA': [[34, 112]],
 '1C1DA': [[178, 235], [25, 78]],
 '1C3PA': [[163, 251]],
 '1DJ0A': [[113, 186], [152, 219]],
 '1DP4A': [[151, 210]],
 '1DXYA': [[109, 160], [148, 201]],
 '1DYPA': [[101, 195]],
 '1E0CA': [[164, 226], [26, 84]],
 '1E6UA': [[102, 159], [199, 270], [85, 142]],
 '1EFDN': [[148, 205]],
 '1ELUA': [[104, 158]],
 '1EQ2A': [[112, 164], [188, 248], [89, 146]],
 '1ES9A': [[42, 97]],
 '1EVYA': [[18, 89]],
 '1F0KA': [[186, 258]],
 '1FC6A': [[232, 287]],
 '1FNNA': [[76, 128]],
 '1FVIA': [[192, 271]],
 '1G3QA': [[35, 115]],
 '1GA6A': [[31, 100], [31, 130]],
 '1GA8A': [[99, 157]],
 '1GK9B': [[349, 428]],
 '1GTVA': [[101, 176], [136, 192], [32, 91]],
 '1GU7A': [[171, 249], [37, 114]],
 '1H2BA': [[190, 258], [45, 141]],
 '1H99A': [[23, 82]],
 '1HQSA': [[217, 281]],
 '1I24A': [[131, 200], [150, 215], [253, 316]],
 '1I9ZA': [[59, 116]],
 '1II5A': [[96, 159], [96, 187]]

In [17]:
# Spot-check subset: keep only those of these proteins that have stored pairs.
keys_for_test = ['3CKCA', '1PVGA']
test_dict = {}
for key in keys_for_test:
    if key in previous_data:
        test_dict[key] = previous_data[key]
print(test_dict)

{}


In [None]:
# Half-width of the window around each SSE centre: window = centre ± 5 (11 residues).
HALF_WINDOW = 5
# Residues within this distance of either SSE window are never unmasked
# (the "within30" tag in save_file_path).
EXCLUDE_MARGIN = 30

with h5py.File(save_file_path, 'w') as f:
    # To debug on the spot-check subset, iterate over test_dict instead of previous_data.
    for protein in tqdm(previous_data.keys()):
        seq = seq_dict[protein]
        seq_len = len(seq)
        contact_ori = get_contact(seq)  # reference map of the fully unmasked sequence

        for position in tqdm(previous_data[protein], leave=False):
            ss1_start = position[0] - HALF_WINDOW
            ss1_end = position[0] + HALF_WINDOW + 1
            ss2_start = position[1] - HALF_WINDOW
            ss2_end = position[1] + HALF_WINDOW + 1

            ori_contact_seg = contact_ori[ss1_start:ss1_end, ss2_start:ss2_end]

            # Candidates for random unmasking: every residue outside both SSE windows
            # padded by EXCLUDE_MARGIN. This does not depend on flank_len, so build it
            # once per pair instead of on every flank_len iteration.
            exclude_range1 = range(max(0, ss1_start - EXCLUDE_MARGIN), min(seq_len, ss1_end + EXCLUDE_MARGIN))
            exclude_range2 = range(max(0, ss2_start - EXCLUDE_MARGIN), min(seq_len, ss2_end + EXCLUDE_MARGIN))
            potential_positions = [pos for pos in range(seq_len)
                                   if pos not in exclude_range1 and pos not in exclude_range2
                                   and not (ss1_start <= pos < ss1_end)
                                   and not (ss2_start <= pos < ss2_end)]

            # Upper bound on flank_len keeps the "outward" limit (room before SSE 1 and
            # after SSE 2). At each step, 2 * flank_len candidates are unmasked at random.
            flank_len_range = min(ss1_start, seq_len - ss2_end - 1)

            for flank_len in range(flank_len_range):
                if len(potential_positions) < 2 * flank_len:
                    break  # not enough eligible residues to unmask this many

                random_unmask_positions = random.sample(potential_positions, 2 * flank_len)

                # Mask everything, then reveal the sampled residues and both SSE windows.
                tokens = ['<mask>'] * seq_len
                for pos in random_unmask_positions:
                    tokens[pos] = seq[pos]
                tokens[ss1_start:ss1_end] = seq[ss1_start:ss1_end]
                tokens[ss2_start:ss2_end] = seq[ss2_start:ss2_end]
                seq_mask = ''.join(tokens)

                mask_contact_full = get_contact(seq_mask)
                seg_cross_contact = mask_contact_full[ss1_start:ss1_end, ss2_start:ss2_end]

                sum_diff_value = np.sum(seg_cross_contact) - np.sum(ori_contact_seg)
                norm_sum_mult_value = norm_sum_mult(ori_contact_seg, seg_cross_contact)

                # Dataset names and creation order match the existing file layout.
                prefix = f'{protein}/{position[0]}_{position[1]}/{flank_len}'
                for name, data in (
                    ('ori_contact_full', contact_ori),
                    ('seg_contact', mask_contact_full),
                    ('seg_cross_contact', seg_cross_contact),
                    ('sum_diff', sum_diff_value),
                    ('sum_mult', norm_sum_mult_value),
                    ('masked_seq', seq_mask),
                ):
                    f.create_dataset(f'{prefix}/{name}', data=data)

torch.cuda.empty_cache()

  0%|          | 0/821 [00:00<?, ?it/s]
  0%|          | 0/2 [00:00<?, ?it/s][A
 50%|█████     | 1/2 [00:02<00:02,  2.46s/it][A
100%|██████████| 2/2 [00:04<00:00,  2.40s/it][A
  0%|          | 1/821 [00:05<1:09:01,  5.05s/it]
  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:06<00:00,  6.03s/it][A
  0%|          | 2/821 [00:11<1:17:46,  5.70s/it]
  0%|          | 0/2 [00:00<?, ?it/s][A
 50%|█████     | 1/2 [00:05<00:05,  5.94s/it][A
100%|██████████| 2/2 [00:15<00:00,  7.52s/it][A
  0%|          | 3/821 [00:26<2:16:54, 10.04s/it]
  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:02<00:00,  2.92s/it][A
  0%|          | 4/821 [00:29<1:39:13,  7.29s/it]
  0%|          | 0/1 [00:00<?, ?it/s][A
100%|██████████| 1/1 [00:03<00:00,  3.22s/it][A
  1%|          | 5/821 [00:32<1:19:43,  5.86s/it]
  0%|          | 0/2 [00:00<?, ?it/s][A
 50%|█████     | 1/2 [00:20<00:20, 20.31s/it][A
100%|██████████| 2/2 [00:23<00:00, 11.95s/it][A
  1%|          | 6/821 