Script to collect all the .zip archives from my Drive folder, then open them and process all the MSA embeddings they contain.

To Do :
- text

Notes :
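- `np.nanstd` misbehaves with the `dtype` argument on float16 input: `np.std(float16_list, dtype = np.float64)` works, but `np.nanstd(float16_list, dtype = np.float64)` returns inf even when the list contains no NaN. This is apparently because nanstd computes the squared deviations in the input's float16 dtype, which overflows once a deviation exceeds ~256. AF2 MSA embeddings are float16 (and I save the results back to float16), but they contain no NaN or inf. The per-file (micro) loop therefore uses `std` instead of `nanstd`. The across-rank (macro) aggregation still uses `nanstd`, so check whether any inf has come through.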

# Set Up

In [None]:
#@title Mount google drive
# Mount Drive as a filesystem under /content/drive; the output folder in
# `path` (next cells) lives under this mount point.
from google.colab import drive
drive.mount('/content/drive')

# Additionally authenticate a PyDrive client against the Drive API.
# NOTE: the assignment below rebinds `drive` from the google.colab module to a
# GoogleDrive client; no later cell shown here uses that client.
from pydrive.drive import GoogleDrive
from pydrive.auth import GoogleAuth
from google.colab import auth
from oauth2client.client import GoogleCredentials
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)
print("You are logged into Google Drive and are good to go!")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
You are logged into Google Drive and are good to go!


In [None]:
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

import os
#from glob import glob
from zipfile import ZipFile
import json

import time

In [None]:
#@title Paths and global vars
# Drive folder holding the ColabFold output archives (one .zip per complex)
path = "/content/drive/MyDrive/Biotech_Work/Dev_files/ColabFold_runs/IO/output"

# Sequence tables (looked up by their 'id' / 'sequence' columns), read from the
# current working directory, positives first then negatives.
positive_seqs, negative_seqs = (
  pd.read_csv(csv_name)
  for csv_name in ('shaped_positives_reduced.csv', 'shaped_negatives_reduced.csv')
)

# Functions

In [None]:
#@title Functions

#
# both glob.glob and pathlib.glob don't like working with a zip or a list of strings
# > everyone just ends up re-writing stuff, and we don't have a complicated regex anyway
def sort_files(file_list, prefix = ''):
  """Bucket archive member names by the kind of output file they look like.

  Matching is by plain substring, and the checks are independent: a name is
  added to every bucket whose marker it contains.

  Args:
    file_list: iterable of file names (e.g. `ZipFile.namelist()`).
    prefix: string prepended to every returned name.

  Returns:
    (pdb_list, score_list, disto_list, repr_list): the names containing
    '.pdb', 'scores', 'custom_disto' and 'repr' respectively, each kept in
    input order.
  """
  markers = ('.pdb', 'scores', 'custom_disto', 'repr')
  buckets = {marker: [] for marker in markers}

  for name in file_list:
    for marker, bucket in buckets.items():
      if marker in name:
        bucket.append(prefix + name)

  return tuple(buckets[marker] for marker in markers)


#
# don't want to bother with the bytes or the path-like objects
# > so extract, process, delete
def extract_and_process_embeddings(file_list, zip_archive, nb_pep):
  """Summarise the per-rank embedding files of one complex into a flat feature dict.

  Each file in `file_list` is extracted from `zip_archive` into the current
  working directory, loaded with `np.load`, deleted again, transposed to
  (channels, N_AA) and reduced over the residue axis.

  For every file, each combination of
    - aggregation: 'mean', 'std', 'min', 'max'
    - residue selection: 'all_aa', 'pep' (first `nb_pep` residues), 'prot' (the rest)
  gives one vector with one value per channel. These vectors are then combined
  across files (ranks).

  Args:
    file_list: names of the .npy embedding files inside `zip_archive`
      (e.g. the `repr_list` returned by `sort_files`).
    zip_archive: an open `zipfile.ZipFile` containing those files.
    nb_pep: number of leading residues treated as the peptide.

  Returns:
    dict with, for every '<aggregation>_<selection>' feature, float16 arrays:
      'rank_1_<feature>': value for the file whose name contains 'rank_001'
        (or the first file in name order if none does),
      'median_ranks_<feature>', 'mean_ranks_<feature>', 'std_ranks_<feature>':
        NaN-aware statistics across all files,
    plus 'nb_ranks': the number of files processed.

  Raises:
    ValueError: if `file_list` is empty.
  """
  if not file_list:
    raise ValueError("extract_and_process_embeddings: empty file_list, no embeddings to process")

  # Zip member order is not guaranteed to follow the ranks, but the rank_1_*
  # features read index 0: put the 'rank_001' file first, the rest by name.
  ordered_files = sorted(file_list, key = lambda name: ('rank_001' not in name, name))

  aggreg_select = {'mean': np.mean, 'std': np.std, 'min': np.nanmin, 'max': np.nanmax}

  # key : feature name, value : list of per-channel vectors, one per rank
  storage = {}

  for file_name in ordered_files:
    new_local_file = zip_archive.extract(file_name)
    try:
      # allow_pickle kept as before; these archives are our own ColabFold outputs
      embeddings = np.load(new_local_file, allow_pickle = True).T # (256, N_AA)
    finally:
      # np.load has fully read the .npy by now; always remove the extracted copy
      os.remove(new_local_file)

    if 'rank_001' in new_local_file:
      print("embeddings shape :", embeddings.shape)

    # no point in singling out channel 71 or 227 since we reduce over residues
    aa_select = {'all_aa': embeddings, 'pep': embeddings[:, :nb_pep], 'prot': embeddings[:, nb_pep:]}

    for aggreg_method, aggreg_func in aggreg_select.items():
      # mean/std accumulate in float64. np.std is used rather than np.nanstd
      # because nanstd squares deviations in the input's float16 dtype even
      # with dtype=np.float64, which can overflow to inf.
      extra_kwargs = {'dtype': np.float64} if aggreg_method in ('mean', 'std') else {}
      for aa_selection, selected in aa_select.items():
        storage.setdefault(f"{aggreg_method}_{aa_selection}", []).append(
          aggreg_func(selected, axis = -1, **extra_kwargs))

  feature_dict = {}
  for feature, per_rank in storage.items():
    # Stack as float64: min/max vectors are float16, and nanstd computes the
    # squared deviations in the input dtype, which can overflow float16 to inf.
    stacked = np.asarray(per_rank, dtype = np.float64) # (nb_ranks, channels)
    feature_dict['rank_1_'+feature] = np.float16(per_rank[0])
    feature_dict['median_ranks_'+feature] = np.nanmedian(stacked, axis = 0).astype(np.float16)
    feature_dict['mean_ranks_'+feature] = np.nanmean(stacked, axis = 0).astype(np.float16)
    feature_dict['std_ranks_'+feature] = np.nanstd(stacked, axis = 0).astype(np.float16)
  feature_dict['nb_ranks'] = len(ordered_files)

  return feature_dict



# Main Run

In [None]:
#@title Main run

# main run
all_drive_elems = os.listdir(path)
# ColabFold output archives, named '<complex_id>.zip'
temp_zips = [elem for elem in all_drive_elems if elem.endswith(".zip")]

# The Sandbox cells read the global `zip` and hard-code a file from the first
# archive, so keep it on the first archive (the loop uses its own variable).
# NOTE: this shadows the builtin `zip` for the rest of the session.
zip = temp_zips[0] if temp_zips else None

global_storage = {}

for zip_name in temp_zips:
  time.sleep(4) # oof Drive <-> Colab connection
  # worked with 2 sec delays at first, broke w/ 3 sec later on

  complex_id = zip_name.split('.')[0]
  # presumably positive ids are exactly 8 characters, while negative ids
  # (e.g. '6iur_C-6iqj_A') are longer -- TODO confirm
  is_positive_sample = len(complex_id) == 8

  seq_table = positive_seqs if is_positive_sample else negative_seqs
  complex_sequence = seq_table[seq_table['id'] == complex_id]['sequence'].values[0]
  # the first ':'-separated chain is treated as the peptide
  nb_peptide = len(complex_sequence.split(':')[0])

  print("---- " + complex_id)
  print("seq length", len(complex_sequence))

  # `with` releases each archive's file handle once its embeddings are processed
  with ZipFile(os.path.join(path, zip_name), 'r') as archive:
    files = archive.namelist()
    # named `repr_files` so the builtin repr() is not shadowed
    pdb, scores, disto, repr_files = sort_files(files) # name paths
    feature_dict = extract_and_process_embeddings(repr_files, archive, nb_peptide)

  global_storage[complex_id] = feature_dict
#

# print shape
# print boolean if channel 71 is much lower mean, min than global
# print boolean if channel 227 is much higher mean, max than global

---- 6iur_C-6iqj_A
seq length 182
embeddings shape : (256, 181)
---- 6kmj_C-6i42_A
seq length 182
embeddings shape : (256, 181)
---- 6qbb_P-6jfa_A
seq length 183
embeddings shape : (256, 182)
---- 6l7c_S-6i7q_V
seq length 183
embeddings shape : (256, 182)
---- 6jnf_D-6ifc_A
seq length 183
embeddings shape : (256, 182)
---- 6i42_B-6jfa_A
seq length 184
embeddings shape : (256, 183)
---- 6i7q_H-6i42_A
seq length 185
embeddings shape : (256, 184)
---- 6sat_P-6spb_F
seq length 185
embeddings shape : (256, 184)
---- 6qmp_A-6tzc_B
seq length 186
embeddings shape : (256, 185)
---- 6lry_B-6i42_A
seq length 186
embeddings shape : (256, 185)
---- 6i7q_H-6q68_A
seq length 187
embeddings shape : (256, 186)
---- 6o23_E-6p8s_A
seq length 187
embeddings shape : (256, 186)
---- 6vo5_C-6p8s_A
seq length 187
embeddings shape : (256, 186)
---- 6l7c_S-6q68_A
seq length 189
embeddings shape : (256, 188)
---- 6l0v_B-6spb_F
seq length 191
embeddings shape : (256, 190)
---- 6uyo_B-6gc3_A
seq length 191
embedd

In [None]:
# Save all per-complex feature dicts as a 1-element object array wrapping the
# dict; reload with np.load('all_msa_embedding_features_dict.npy', allow_pickle=True)[0]
np.save('all_msa_embedding_features_dict.npy', np.array([global_storage]))

In [None]:
# Sanity check over every stored feature: print its name and shape, and flag
# any feature that contains NaN or inf values.
for complex_key, features in global_storage.items():
  for feature_name, values in features.items():
    if feature_name == 'nb_ranks':
      continue  # plain count of ranks, not an array
    n_nan = np.sum(np.isnan(values))
    n_inf = np.sum(np.isinf(values))
    print(feature_name)
    print(values.shape)
    if n_nan > 0 or n_inf > 0:
      print(feature_name, n_nan, n_inf)

In [None]:
#@title Sandbox


In [None]:
# Sandbox: redo the main-run steps by hand for the archive named by the global `zip`
complex_id = zip.split('.')[0]
is_positive_sample = len(complex_id) == 8

if is_positive_sample:
  complex_sequence = positive_seqs[positive_seqs['id'] == complex_id]['sequence'].values[0]
else:
  complex_sequence = negative_seqs[negative_seqs['id'] == complex_id]['sequence'].values[0]
nb_peptide = len(complex_sequence.split(':')[0])

print(complex_id)
print(is_positive_sample)
print(complex_sequence)
print(nb_peptide)

In [None]:
# NOTE(review): this ZipFile is never closed; acceptable for interactive poking
archive = ZipFile(path + '/' + zip, 'r')
files = archive.namelist()

In [None]:
# NOTE: rebinds the builtin `repr` for the rest of the session
pdb, scores, disto, repr = sort_files(files)

In [None]:
repr

In [None]:
# Extracts into the current working directory. This member only exists in the
# 6iur_C-6iqj_A archive, so this assumes `zip` points at that one.
archive.extract('6iur_C-6iqj_A_single_repr_rank_004_alphafold2_multimer_v3_model_2_seed_000.npy')

In [None]:
# assumes the working directory is /content (Colab default), so this is the
# file extracted above -- TODO confirm
test = np.load('/content/6iur_C-6iqj_A_single_repr_rank_004_alphafold2_multimer_v3_model_2_seed_000.npy', allow_pickle = True)

In [None]:
test.shape

In [None]:
# heatmap with embedding channels on the y axis and residues on the x axis
plt.imshow(test.T)
plt.colorbar()
plt.ylabel('embedding channels')
plt.xlabel('N_AA')

In [None]:
test.T[:, :50].shape

In [None]:
# per-channel max over all residues, for channel indices 225-229
np.max(test.T[225:230], axis = -1)

In [None]:
# length of a hard-coded complex sequence (presumably 'peptide:protein');
# the ':' separator counts as one character
len('STMDWEVERAELQARIAFLQGERKGQENLKKDLVRRIKMLEYALKQERAK:MSWQSYVDDHLMCEVEGNHLTHAAIFGQDGSVWAQSSAFPQLKPAEIAGINKDFEEAGHLAPTGLFLGGEKYMVVQGEAGAVIRGKKGPGGVTIKKTTQALVFGIYDEPMTGGQCNLVVERLGDYLIESGL')

bin_edges
```
array([ 2.312,  2.625,  2.938,  3.25 ,  3.562,  3.875,  4.188,  4.5  ,
        4.812,  5.125,  5.438,  5.75 ,  6.062,  6.375,  6.688,  7.   ,
        7.312,  7.625,  7.938,  8.25 ,  8.56 ,  8.875,  9.19 ,  9.5  ,
        9.81 , 10.125, 10.44 , 10.75 , 11.06 , 11.375, 11.69 , 12.   ,
       12.31 , 12.625, 12.94 , 13.25 , 13.56 , 13.875, 14.19 , 14.5  ,
       14.81 , 15.125, 15.44 , 15.75 , 16.06 , 16.38 , 16.69 , 17.   ,
       17.31 , 17.62 , 17.94 , 18.25 , 18.56 , 18.88 , 19.19 , 19.5  ,
       19.81 , 20.12 , 20.44 , 20.75 , 21.06 , 21.38 , 21.69 ],
      dtype=float16)
```

