# The geometry of hidden representations in protein language models

### Import libraries

In [1]:
import sys
sys.path.insert(0,'src/')

import numpy as np
import os
import plotly.graph_objects as go
from plotly.graph_objects import Layout
from intrinsic_dimension import block_analysis, save_ID_results, plot_curve_ID

In [2]:
# Per-model metadata, keyed by model identifier. Each value is
#   [display name, plotly 'rgb(...)' colour string, indices of the representations]
mapping = {
    'esm1b': ['ESM-1b', 'rgb(220,20,60)', np.arange(34)],
    'esm1v': ['ESM-1v', 'rgb(0,139,139)', np.arange(34)],
    'ProtBert': ['ProtBert', 'rgb(136, 78, 160)', np.arange(31)],
    'ProtT5': ['ProtT5-XL-U50', 'rgb(245, 148, 0)', np.arange(25)],
    'esm2-15B': ['ESM-2 (15B)', 'rgb(68, 1, 84)', np.arange(49)],
    'esm2-3B': ['ESM-2 (3B)', 'rgb(65, 68, 135)', np.arange(37)],
    'esm2-650M': ['ESM-2 (650M)', 'rgb(42, 120, 142)', np.arange(34)],
    'esm2-150M': ['ESM-2 (150M)', 'rgb(34, 168, 132)', np.arange(31)],
    'esm2-35M': ['ESM-2 (35M)', 'rgb(122, 209, 81)', np.arange(13)],
    'esm2-8M': ['ESM-2 (8M)', 'rgb(253, 231, 37)', np.arange(7)],
    'esm-MSA': ['ESM-MSA-1b', 'rgb(66,145,31)', np.arange(13)],
}

## 1. Intrinsic Dimension

### Download ProteinNet data and estimate ID

In [6]:
# Model to analyse; must be one of the keys of `mapping`
# ('esm1b', 'esm1v', 'ProtBert', 'ProtT5', 'esm2-...', ...).
model = 'esm1b'

# [display name, colour, representation indices] of the selected model
map = mapping[model]

# One (initially empty) list of results per representation index
reps_id = dict((rep_idx, []) for rep_idx in map[2])

# Input/output locations depend on the machine the notebook runs on.
# NOTE(review): `res_path` is only assigned on the non-'orfeo' branch -- confirm
# whether the 'orfeo' setup needs an output path as well.
device = 'lucrezia'
if device != 'orfeo':
    input_path = '/Users/lucreziavaleriani/Desktop/'
    res_path = '/Users/lucreziavaleriani/Desktop/'
else:
    input_path = '/Users/lucreziavaleriani/Desktop/mount_orfeo/data_repo/pdist/' + model


In [7]:
# Load, for every representation index in `reps_id`, the array stored in
# 'pdist_pnet_rep<index>.npy' under `input_path` (presumably a pairwise
# distance matrix -- confirm against the files).
# The ID estimation itself is currently disabled: the commented lines keep the
# two variants that were tried (dadapy 2NN estimator, and block analysis whose
# mean ID would be appended to `reps_id`).
from dadapy import data

for rep in reps_id:
    dist_mat = np.load(os.path.join(input_path, f'pdist_pnet_rep{rep}.npy'))

    # --- variant 1: dadapy two-NN estimator ---
    # d = data.Data(dist_mat)
    # print('start')
    # id_twoNN_noisy, _, r_noisy = d.compute_id_2NN()

    # --- variant 2: block analysis, storing the mean ID per representation ---
    #dim,std,n_point = block_analysis(dist_mat, blocks=list(range(1, 21)), fraction=0.9)
    #esm1b = save_ID_results('id_0.csv',dim,std,n_point)

    #reps_id[rep].append(np.mean(dim))
    


start


### Plot curve(s)

In [24]:
# Figure with a transparent plot background
layout = Layout(plot_bgcolor='rgba(0,0,0,0)')
fig = go.Figure(layout=layout)

# Add the curve for the selected model (colour = map[1], name = map[0])
plot_curve_ID(
    fig,
    len(map[2]) / (len(map[2]) - 1),
    reps_id.values(),
    map[1],
    map[0],
)

# x axis: relative depth, ticks every 0.2
fig.update_xaxes(
    showline=True,
    linewidth=1,
    linecolor='black',
    showticklabels=True,
    tickmode='array',
    tickvals=[0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
    range=[-0.01, 1.05],
    title='relative depth',
    ticks='outside',
)

# y axis: ID, ticks every 2 units
fig.update_yaxes(
    showline=True,
    linewidth=1,
    linecolor='black',
    showticklabels=True,
    tickmode='array',
    tickvals=list(range(0, 26, 2)),
    title='ID',
    range=[0, 24],
    ticks='outside',
)

# Overall size, fonts, and a horizontal legend centred above the plot
fig.update_layout(
    width=900,
    height=700,
    font={'size': 20},
    legend={
        'orientation': "h",
        'yanchor': "top",
        'y': 1.12,
        'xanchor': "center",
        'x': 0.5,
        'font': {'size': 25},
    },
)
#pio.write_image(fig, '/u/area/lvaleriani/scratch/esm2/curve/svgplots/pnet_esm1.png',scale=5,width=900, height=700)

SyntaxError: unexpected EOF while parsing (2246470315.py, line 19)

## Neighborhood Overlap - Superfamily

In [3]:
import numpy as np 
from neighborhood_overlap import create_ds, add_data, overlap_mean, plot_no

In [5]:
# Directory containing the label files
label_path = '/Users/lucreziavaleriani/Desktop/mount_orfeo/fold_rh/data/new_label/'

# Filename prefix of the per-representation neighbour arrays
# (files are named '<path><i>-neigh.npy' and '<path><i>-idx.npy')
path = '/Users/lucreziavaleriani/Desktop/mount_orfeo/fold_rh/hn/newscope/1b/sp_all/rep'

# Model to analyse; must be one of the keys of `mapping`
# ('esm1b', 'esm1v', 'ProtBert', 'ProtT5', 'esm2-...', ...).
model = 'esm1b'

# [display name, colour, representation indices] of the selected model
map = mapping[model]

In [6]:
def get_ds(path, layer, ng=100):
    """Build the dataset for representations ``0..layer`` (inclusive).

    For every index ``i`` the arrays ``<path><i>-neigh.npy`` and
    ``<path><i>-idx.npy`` are loaded and handed, together with the labels,
    to ``add_data``.

    Parameters
    ----------
    path : str
        Filename prefix; the index and the ``-neigh.npy`` / ``-idx.npy``
        suffix are appended directly, without any separator.
    layer : int
        Index of the last representation to load; also passed to ``create_ds``.
    ng : int, optional
        Forwarded unchanged to ``add_data`` (default 100).

    Returns
    -------
    The object returned by the last ``add_data`` call, or the result of
    ``create_ds(layer)`` if there is nothing to load.

    Notes
    -----
    Labels are read from ``sp_lab.txt`` inside the module-level ``label_path``.
    """
    # The labels do not depend on the layer: read them once, and close the
    # file deterministically instead of leaking one handle per iteration.
    # NOTE(review): the same array is now shared by all `add_data` calls --
    # assumes `add_data` does not modify it in place; confirm.
    with open(label_path + 'sp_lab.txt') as label_file:
        label = np.array([line.strip('\n') for line in label_file])

    ds = create_ds(layer)
    for i in np.arange(0, layer + 1):
        neig = np.load(path + str(i) + '-neigh.npy')
        idx = np.load(path + str(i) + '-idx.npy')
        ds = add_data(ds, i, neig, idx, label, ng)
    return ds


In [7]:
# Load the neighbour data for representations 0..33 and compute the overlap
ds = get_ds(path, layer=33, ng=100)
overlap1 = np.array(
    overlap_mean(ds, 10)
)

## PLOT

In [47]:
# Figure with a transparent plot background
layout = Layout(plot_bgcolor='rgba(0,0,0,0)')
fig = go.Figure(layout=layout)  #.set_subplots(1, 1, horizontal_spacing=0.1,vertical_spacing=0.02)

# Add the overlap curve for the selected model (colour = map[1], name = map[0])
plot_no(
    fig,
    len(map[2]),
    overlap1,
    map[1],
    map[0],
)

# x axis: relative depth, ticks every 0.2
fig.update_xaxes(
    showline=True,
    linewidth=1,
    linecolor='black',
    showticklabels=True,
    tickmode='array',
    tickvals=[0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
    range=[-0.01, 1.05],
    title='relative depth',
    ticks='outside',
)

# y axis: overlap in [0, 1], ticks every 0.1
fig.update_yaxes(
    showline=True,
    linewidth=1,
    linecolor='black',
    tickvals=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
    range=[0, 1],
    title=r'$\chi^{gt}$',
    ticks='outside',
)

# Overall size, fonts, and a horizontal legend centred above the plot
fig.update_layout(
    width=900,
    height=700,
    font={'size': 20},
    legend={
        'orientation': "h",
        'yanchor': "top",
        'y': 1.12,
        'xanchor': "center",
        'x': 0.5,
        'font': {'size': 25},
    },
)

[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29 30 31 32 33 34]
