### This notebook conducts an all-vs-all structural comparison of proteins within the SCOPe 40% clustered database using CIRPIN and Progres

### Load SCOPe40 embedded databases

In [18]:
# Paths to the serialized SCOPe40 embedding databases, one per model.
# Both files are read with torch.load further down and are indexed with the
# keys 'embeddings' and 'ids'.
# The paths are relative, so they resolve against the kernel's working directory.
fp_prog = '../embedded_databases/ark_scope40_Progres_embed_1_7_25.pt'
fp_cirpin = '../embedded_databases/ark_scope40_CIRPIN_embed_1_7_25.pt'

In [19]:
import torch
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
# `device` is later passed as map_location when the embedding files are loaded.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
# `pg` is only needed by the spot-check cells near the end of the notebook
# (load_trained_model / embed_structure / embedding_similarity).
import progres as pg

Using device: cuda


### Load embeddings and labels

In [20]:
# Deserialize each database exactly once and read both fields from the result.
# (Previously every file was passed through torch.load twice, once for the
# embeddings and once for the ids.)
progres_db = torch.load(fp_prog, map_location=device)
cirpin_db = torch.load(fp_cirpin, map_location=device)

progres_emb = progres_db['embeddings']
cirpin_emb = cirpin_db['embeddings']
print('Loaded embs!')

progres_labels = progres_db['ids']
cirpin_labels = cirpin_db['ids']
print('Loaded labels!')

# Row i of both embedding sets must describe the same domain, otherwise the
# per-pair Progres/CIRPIN score comparison below is meaningless. Fail loudly
# instead of relying on someone reading the displayed value.
assert progres_labels == cirpin_labels, 'Progres and CIRPIN databases list different ids / a different order'
# Kept as the last expression so the cell still displays the check result.
progres_labels == cirpin_labels

Loaded embs!
Loaded embs!


True

### Calculate all v all using Progres/CIRPIN
### Data is chunked to perform all v all calculation

In [21]:
def get_putative_cps(emb_p=None, emb_c=None, labels=None, score_diff_cutoff=0.3, progres_cutoff=0.6, cirpin_cutoff=0.9, chunk_size=1000):
    """Find pairs that score high with CIRPIN but low with Progres.

    The all-vs-all score matrix is built in row chunks so that only a
    ``chunk_size x n`` slice is materialized at a time. For every pair (i, j)
    the score is ``(emb[i] . emb[j] + 1) / 2`` (this maps the dot product to
    [0, 1] assuming unit-norm embeddings -- TODO confirm against how the
    databases were built).

    A pair is reported when all of the following hold:
      * Progres score  <  ``progres_cutoff``
      * CIRPIN score   >  ``cirpin_cutoff`` and is not exactly 1
      * CIRPIN score - Progres score  >  ``score_diff_cutoff``

    Args:
        emb_p: 2-D tensor of Progres embeddings, one row per domain.
            Defaults to the notebook global ``progres_emb``.
        emb_c: 2-D tensor of CIRPIN embeddings, same row order as ``emb_p``.
            Defaults to the notebook global ``cirpin_emb``.
        labels: sequence of domain ids, ``labels[i]`` naming row ``i``.
            Defaults to the notebook global ``progres_labels``.
        score_diff_cutoff: minimum (CIRPIN - Progres) score difference.
        progres_cutoff: exclusive upper bound on the Progres score.
        cirpin_cutoff: exclusive lower bound on the CIRPIN score.
        chunk_size: number of query rows scored per iteration.

    Returns:
        ``(putative_cps, cutoffs)`` where ``putative_cps`` is a list of
        ``[label_i, label_j, progres_score, cirpin_score]`` entries (both
        orientations of a pair appear, since the full matrix is scanned) and
        ``cutoffs`` is a dict recording the three thresholds used.

    Raises:
        ValueError: if ``chunk_size`` is not positive or the inputs disagree
            in length.
    """
    # Resolve defaults at call time. Binding the globals in the signature
    # froze whatever objects existed when the cell defining this function was
    # run, which silently goes stale if the embeddings are reloaded later.
    if emb_p is None:
        emb_p = progres_emb
    if emb_c is None:
        emb_c = cirpin_emb
    if labels is None:
        labels = progres_labels

    if chunk_size < 1:
        raise ValueError(f'chunk_size must be a positive integer, got {chunk_size}')
    n = len(emb_p)
    if len(emb_c) != n or len(labels) != n:
        raise ValueError('emb_p, emb_c and labels must all have the same length')

    putative_cps = []
    num_putative_pairs = 0

    for chunk_num, i_start in enumerate(range(0, n, chunk_size), start=1):
        i_end = min(i_start + chunk_size, n)

        # Score this block of query rows against every domain. The right-hand
        # side must be the *argument*, not the notebook-level global, so that
        # callers passing their own embeddings get their own scores.
        scaled_chunk_p = (emb_p[i_start:i_end] @ emb_p.T + 1) / 2
        scaled_chunk_c = (emb_c[i_start:i_end] @ emb_c.T + 1) / 2

        # Build the selection as one boolean mask instead of zeroing rejected
        # entries in place: with zero-filling, a rejected pair has a difference
        # of exactly 0 and would be reported whenever score_diff_cutoff < 0.
        # CIRPIN == 1 is excluded to drop exact self/identical matches.
        keep = (
            (scaled_chunk_p < progres_cutoff)
            & (scaled_chunk_c > cirpin_cutoff)
            & (scaled_chunk_c != 1)
            & ((scaled_chunk_c - scaled_chunk_p) > score_diff_cutoff)
        )

        # Row indices are local to the chunk; columns already index the full set.
        rows, cols = torch.nonzero(keep, as_tuple=True)

        # Pull all selected scores off the device in one transfer per chunk
        # rather than one .item() call (and device sync) per pair.
        p_scores = scaled_chunk_p[rows, cols].tolist()
        c_scores = scaled_chunk_c[rows, cols].tolist()
        # Shift local rows by the chunk's true start so they index `labels`.
        global_rows = (rows + i_start).tolist()

        for row, col, score_p, score_c in zip(global_rows, cols.tolist(), p_scores, c_scores):
            putative_cps.append([labels[row], labels[col], score_p, score_c])

        # Running total of hits so far, printed once per chunk.
        num_putative_pairs += len(global_rows)
        print(num_putative_pairs)

        if chunk_num % 10 == 0:
            print(f'Processed {chunk_num} chunks!')

    # Returned alongside the pairs so the thresholds can be encoded in the
    # output filename by save_pairs.
    cutoffs = {
        'score_diff': score_diff_cutoff,
        'progres_cutoff': progres_cutoff,
        'cirpin_cutoff': cirpin_cutoff,
    }
    return putative_cps, cutoffs

### Calculate putative pairs 

In [22]:
# Run the all-vs-all search with the function's default inputs and cutoffs.
# `settings` receives the dict of cutoffs that the function returns.
putative_pairs, settings = get_putative_cps()

2555
4764
7285
9619
12201
14653
17213
19699
22335
24278
Processed 10 chunks!
26624
28908
31690
34053
36317
36652


In [24]:
# Display the cutoffs that were used for the run above.
settings

{'score_diff': 0.3, 'progres_cutoff': 0.6, 'cirpin_cutoff': 0.9}

# Save the putative pairs

In [27]:
import pickle
import os
def save_pairs(list_of_pairs, cutoffs, out_dir='/home/ubuntu/scope40'):
    """Pickle the putative pairs to a file whose name encodes the cutoffs.

    The filename is ``putative_pairs_list_<k>_<v>_..._.pkl`` with one
    ``<key>_<value>_`` segment per entry of ``cutoffs`` in dict order. The
    trailing underscore before ``.pkl`` is intentional: the loading cell below
    opens the file by that exact name.

    Args:
        list_of_pairs: object to serialize (the list returned by
            get_putative_cps).
        cutoffs: mapping of cutoff name to value, used only for the filename.
        out_dir: directory to write into; created if it does not exist.
            Defaults to the location this notebook has always used.

    Returns:
        The full path of the file that was written.
    """
    cutoff_str = ''.join(f'{k}_{v}_' for k, v in cutoffs.items())

    # Make the target directory on demand so a fresh machine does not fail
    # with FileNotFoundError after the expensive search has already run.
    os.makedirs(out_dir, exist_ok=True)
    full_fp = os.path.join(out_dir, f'putative_pairs_list_{cutoff_str}.pkl')
    with open(full_fp, 'wb') as f:
        pickle.dump(list_of_pairs, f)
    return full_fp
    

In [146]:
# Persist the search result; the cutoffs in `settings` become part of the filename.
save_pairs(putative_pairs, settings)

### Load putative pairs list

In [134]:
# Reload a previously saved result. The filename matches what save_pairs
# produces for score_diff=0.3, progres_cutoff=0.6, cirpin_cutoff=0.9
# (including the trailing underscore before the extension).
# NOTE(review): pickle.load can execute arbitrary code; only open files that
# this notebook wrote itself.
# NOTE(review): `l` is read by the flattening cell below, so this cell must be
# run before it on a fresh kernel.
with open('/home/ubuntu/scope40/putative_pairs_list_score_diff_0.3_progres_cutoff_0.6_cirpin_cutoff_0.9_.pkl', 'rb') as f:
        l = pickle.load(f)

### Check if a SCOPe domain of interest is in putative pairs

In [26]:
# Flatten the loaded records into one flat list holding every field of every
# record, in order, so that a plain `in` test can be run against it.
single_list = [field for record in l for field in record]

NameError: name 'l' is not defined

In [136]:
# Linear membership test against the flattened list built in the cell above.
# Matches if the id appears in any field of any loaded record.
domain_of_interest = 'd2hgaa1' 
domain_of_interest in single_list

True

### Check a few examples from putative pairs to make sure function is working right

In [28]:
# Structure files for one example pair: q and t are embedded and compared
# with both models in the two cells below.
# NOTE(review): absolute, machine-specific paths.
q = '/home/ubuntu/scope40/pdbstyle-2.08/d1z05a2.pdb'
t = '/home/ubuntu/scope40/pdbstyle-2.08/d5mu9a2.pdb'

In [29]:
# Check CIRPIN score
model_loaded = pg.load_trained_model(trained_model='/home/ubuntu/progres_link/trained_models/CIRPIN/CIRPIN_model/model_5k_cp_epoch301.pt')
qs = pg.embed_structure(q,model=model_loaded)
ts = pg.embed_structure(t,model=model_loaded)
pg.embedding_similarity(qs,ts)

Using model model_5k_cp_epoch301.pt


tensor(0.9008)

In [30]:
# Check Progres score
model_loaded = pg.load_trained_model(trained_model='/home/ubuntu/progres_link/trained_models/v_0_2_0/trained_model.pt')
# Same pair, now embedded with the model loaded on the line above this block.
# This rebinds `qs` and `ts`, so the CIRPIN embeddings from the previous cell
# are no longer available under those names once this cell has run.
qs = pg.embed_structure(q,model=model_loaded)
ts = pg.embed_structure(t,model=model_loaded)
pg.embedding_similarity(qs,ts)

Using model trained_model.pt


tensor(0.3845)

### Run verify_putative_pairs_scope40.py on saved putative pairs to check pairs using TM-align