In [None]:
import os

import h5py
import numpy as np
from pandas import read_csv
from sklearn.manifold import TSNE
from bio_embeddings.visualize import render_3D_scatter_plotly, save_plotly_figure_to_html

In [None]:
# Load the mapping table; the first CSV column becomes the DataFrame index.
# Later cells use that index as the key into the HDF5 embeddings file and
# read an 'original_id' column from this table.
mapping_file = read_csv(filepath_or_buffer='mapping_file.csv', index_col=0)

In [None]:
# Collect one embedding per row of the mapping table, in index order, so that
# position i in `embeddings` corresponds to row i of `mapping_file`.
embeddings = []
with h5py.File('reduced_embeddings_file.h5', 'r') as h5_handle:
    # HDF5 keys are strings, so every index value is stringified before lookup;
    # np.array(...) copies the dataset into memory before the file is closed.
    embeddings.extend(
        np.array(h5_handle[str(entry_id)]) for entry_id in mapping_file.index
    )

In [None]:
# t-SNE configuration: project the embeddings to three components using cosine
# distance, with a fixed random seed so the projection is reproducible.
tsne_params = {
    'n_components': 3,
    'perplexity': 30,
    'random_state': 420,
    'verbose': 1,
    'n_jobs': -1,
    'metric': 'cosine',
}

# scikit-learn 1.5 renamed the iteration-budget argument from `n_iter` to
# `max_iter` (the old name was removed in 1.7). Use whichever keyword the
# installed version exposes so the cell runs on both old and new releases.
tsne_iteration_key = 'max_iter' if 'max_iter' in TSNE().get_params() else 'n_iter'
tsne_params[tsne_iteration_key] = 15000

# Result has one row per embedding and `n_components` columns.
transformed_embeddings = TSNE(**tsne_params).fit_transform(embeddings)

In [None]:
# Attach the three projected coordinates to the mapping table as columns
# 'x', 'y' and 'z' (column 0, 1 and 2 of the t-SNE output respectively).
for axis_position, axis_name in enumerate(('x', 'y', 'z')):
    mapping_file[axis_name] = transformed_embeddings[:, axis_position]

In [None]:
# Persist the mapping table, now including the x/y/z projection columns.
mapping_file.to_csv(path_or_buf='projected_embeddings_file_TSNE.csv')

In [None]:
# Input folder with annotation CSVs and output folder for the HTML figures.
annotations_files_folder = 'annotations/'
figures_files_folder = 'figures/'
# The original (misspelled) name is kept as an alias in case other cells use it.
figures_files_fodler = figures_files_folder

annotation_files = [
    'disprot_2019_09_floats.csv',
    'disprot_2019_09_extreme_ends_0.2vs0.8.csv',
    'disprot_2019_09_extreme_ends_0.3vs0.7.csv',
    'disprot_2019_09_extreme_ends_0.5vs0.5.csv',
    'disprot_2019_09_3classes_0.2_0.8.csv',
]

# Make sure the output folder exists so a fresh run does not depend on it
# having been created by hand.
os.makedirs(figures_files_folder, exist_ok=True)

# The coordinates keyed by original_id are the same for every annotation file,
# so build that lookup table once instead of on every iteration.
coordinates_by_original_id = mapping_file.set_index('original_id')

for annotation_file_path in annotation_files:
    annotation_file = read_csv(
        os.path.join(annotations_files_folder, annotation_file_path), index_col=0
    )

    # Label columns with fewer than three distinct values are cast to strings
    # (presumably so the plot treats them as categories rather than a
    # continuous scale -- confirm against render_3D_scatter_plotly).
    if annotation_file['label'].nunique() < 3:
        annotation_file['label'] = annotation_file['label'].apply(str)

    # Left join on the annotation index: every annotated entry is kept and
    # receives its x/y/z coordinates where a matching original_id exists.
    merged_annotation_file = annotation_file.join(coordinates_by_original_id)

    figure = render_3D_scatter_plotly(merged_annotation_file)
    # Output name is the annotation file name with ".html" appended.
    save_plotly_figure_to_html(
        figure, os.path.join(figures_files_folder, annotation_file_path + ".html")
    )