<a href="https://colab.research.google.com/github/pikanaeri/Extracting-3Di-Embeddings-from-Protein-Sequences/blob/main/phrog-embedding-figures/Category_Category_Distance_Matrix.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title Importing Dependencies
import pickle
from ast import literal_eval
import pandas as pd
import numpy as np
import random
import os

from sklearn import metrics
from sklearn.cluster import SpectralClustering

import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt
import textwrap

from statannotations.Annotator import Annotator
import itertools

import networkx as nx
from scipy.spatial.distance import squareform

# Fixed plotting color for each PHROG functional category
# (category name -> matplotlib color name). Insertion order is preserved.
phrog_palette = dict([
    ('DNA, RNA and nucleotide metabolism', 'red'),
    ('connector', 'blue'),
    ('head and packaging', 'green'),
    ('integration and excision', 'pink'),
    ('lysis', 'gray'),
    ('moron, auxiliary metabolic gene and host takeover', 'brown'),
    ('other', 'purple'),
    ('tail', 'darkorange'),
    ('transcription regulation', 'cyan'),
])

In [None]:
#@title Reading in PHROGs Information
#@markdown * This code reads the PHROG index table and builds a dictionary mapping each PHROG number to its functional category annotation
#@markdown * Download PHROG_index.tsv here: https://storage.googleapis.com/plm-model-comparison/PHROG_index.tsv

#Stores name (phrog #) -> category for each PHROG family
phrog_dict = {}
#Stores PHROG family count per functional category
phrog_count = {}
#Total number of PHROG families read from the index
phrog_cnt = 0

# NOTE(review): absolute, machine-specific path; point it at wherever
# PHROG_index.tsv was downloaded before running on another machine.
# The context manager closes the file even if a row fails to parse.
with open("/home/kellylab/PHROG_index.tsv", "r") as f:
  #First line is the tab-separated header row
  labels = f.readline().strip().split("\t")
  for line in f:
    #Skip blank lines (e.g. a trailing empty line) instead of crashing on them
    if not line.strip():
      continue
    information = line.strip().split("\t")
    #Column 0 looks like "phrog_<number>"; keep only the integer part
    nm = int(information[0].split("phrog_")[1])
    #Column 6 is used as the functional category of the family
    category = information[6]
    phrog_dict[nm] = category
    phrog_count[category] = phrog_count.get(category, 0) + 1
    phrog_cnt += 1

#Report the number of families per category and the overall total
for i in phrog_count:
  print(i, phrog_count[i])
print("total ", phrog_cnt)


In [None]:
#@title Reading in PHROGs Embeddings
#@markdown * This code loads every averaged embedding pickle in the final_average_embeddings folder
#@markdown * `embeddings` maps PHROG number -> embedding; `categories` maps functional category -> list of PHROG numbers
#@markdown * Once the embedding vectors are created and averaged, store them into a final_average_embeddings folder and upload them here

categories = {}
embeddings = {}

# Only change directory when not already inside the embeddings folder, so
# re-running this cell does not fail looking for a nested folder. Later
# relative paths still resolve inside final_average_embeddings, as before.
if os.path.basename(os.getcwd()) != "final_average_embeddings":
  os.chdir("final_average_embeddings")

# Sorted listing makes the member order within each category reproducible.
for i in sorted(os.listdir()):
  if not i.endswith(".pkl"):
    continue
  # File names are reduced to the bare PHROG number by stripping these pieces.
  i2 = i.replace(".pkl", "").replace("phrog_", "").replace("_averaged","")
  num = int(i2)
  # Families without a functional annotation are excluded before touching the file.
  if phrog_dict[num] == "unknown function":
    continue
  # NOTE(review): pickle.load can execute arbitrary code; only load files you produced or trust.
  with open(i, "rb") as f2:
    embeddings[num] = pickle.load(f2)
  categories.setdefault(phrog_dict[num], []).append(num)

# One row per functional category: its name and the list of member PHROG numbers.
phrog_dists = pd.DataFrame({
  'category': list(categories.keys()),
  'phrogs': list(categories.values()),
})
phrog_dists = phrog_dists.sort_values(by='category')
cs = set(phrog_dists['category'])

In [None]:
#@title Creating Distance Matrix

# For every ordered pair of distinct functional categories, compute the cosine
# similarity between all member embeddings and record
#   * per-PHROG mean similarities (rows of cat_vs_cat_dists), and
#   * the overall mean similarity as an edge weight in cat_vs_cat (upper triangle only).

# Fix one category ordering up front: iterating a set of strings can differ
# between kernel sessions, which changed which orientation of a pair was stored.
cat_order = sorted(cs)

# Stack each category's embeddings once instead of rebuilding them per pair.
# assumes every embedding is a 1-D vector of the same length — TODO confirm
cat_vecs = {}
for c in cat_order:
    members = phrog_dists.loc[phrog_dists['category'] == c, 'phrogs'].iloc[0]
    cat_vecs[c] = np.array([embeddings[ph] for ph in members])

cat_vs_cat_dists = []
cat_vs_cat = {}
for i, c in enumerate(cat_order):
    dic = {}
    for j, cc in enumerate(cat_order):
        if i == j:
            continue
        # sims[a, b] = cosine similarity between member a of c and member b of cc
        sims = metrics.pairwise.cosine_similarity(cat_vecs[c], cat_vecs[cc])
        # Mean similarity of each cc member to all c members ...
        m_sims = np.mean(sims, axis=0)
        # ... and the single overall mean for the category pair.
        mm_sims = np.mean(m_sims, axis=0)
        cat_vs_cat_dists.extend(zip([c]*len(m_sims), [cc]*len(m_sims), m_sims))
        # Store each unordered pair once, keyed by the earlier category.
        if j > i:
            dic[cc] = {'weight': mm_sims}
    if len(dic) > 0:
        cat_vs_cat[c] = dic

df_cat_vs_cat_dists = pd.DataFrame(cat_vs_cat_dists, columns=['label1', 'label2', 'similarity'])
df_cat_vs_cat_dists = df_cat_vs_cat_dists.sort_values(by='label1')
# NOTE(review): drop_duplicates() compares the float means exactly, so whether
# the mirrored (A, B)/(B, A) rows collapse depends on rounding, and two unrelated
# pairs with an identical mean would also be merged — confirm this is intended.
df_cat_vs_cat_dists = df_cat_vs_cat_dists.groupby(['label1', 'label2']).mean().drop_duplicates()
df_cat_vs_cat_dists.to_csv("cat_vs_cat_dists.csv")