In [None]:
import os

import numpy as np
import pandas as pd

from ema import EmbeddingHandler

In [None]:
# Input locations for the HCN1 variant example, relative to the kernel's
# working directory. os.path.join keeps these plain strings but no longer
# relies on DATA_DIR carrying a trailing slash.
DATA_DIR = '../examples/HCN1-variants/'
FP_METADATA = os.path.join(DATA_DIR, 'metadata.csv')
FP_EMB_ESM1b = os.path.join(DATA_DIR, 'esm1b_t33_650M_UR50S-embeddings.npy')
FP_EMB_ESM2 = os.path.join(DATA_DIR, 'esm2_t33_650M_UR50D-embeddings.npy')

# Seed NumPy's global RNG so stochastic steps that fall back to it give the
# same result on every Restart & Run All.
# NOTE(review): this only affects EmbeddingHandler's t-SNE / clustering if they
# do not pass their own random_state — confirm in the `ema` package.
SEED = 42
np.random.seed(SEED)

## Load data

In [None]:
# Load the metadata table and the two precomputed embedding matrices.
metadata = pd.read_csv(FP_METADATA)
emb_esm1b = np.load(FP_EMB_ESM1b)
emb_esm2 = np.load(FP_EMB_ESM2)

# Fail fast if the inputs disagree on the number of samples, rather than
# surfacing a confusing error deep inside a later plotting call.
# NOTE(review): assumes one embedding row per metadata row — confirm against
# EmbeddingHandler.add_emb_space.
assert emb_esm1b.shape[0] == len(metadata), (
    f"esm1b embeddings have {emb_esm1b.shape[0]} rows but metadata has {len(metadata)}"
)
assert emb_esm2.shape[0] == len(metadata), (
    f"esm2 embeddings have {emb_esm2.shape[0]} rows but metadata has {len(metadata)}"
)

print(emb_esm1b.shape, emb_esm2.shape)
metadata.head()

In [None]:
# Build the handler around the sample metadata, then register each embedding
# matrix under the name the later cells use to refer to it.
emb_handler = EmbeddingHandler(metadata)
emb_handler.add_emb_space(emb_space_name="esm1b", embeddings=emb_esm1b)
emb_handler.add_emb_space(emb_space_name="esm2", embeddings=emb_esm2)

## Explore the embedding spaces

In [None]:
# Distribution of embedding values — presumably covering every registered
# embedding space (esm1b, esm2); see EmbeddingHandler.plot_emb_hist.
emb_handler.plot_emb_hist()

In [None]:
# Box plots of the embeddings grouped by the metadata column "sample"
# (assumed to be present in metadata.csv — TODO confirm).
emb_handler.plot_emb_box(group="sample")

### Dimensionality-reduction plots coloured by metadata

#### PCA

In [None]:
# PCA projection of the esm1b space, coloured by binary_disorder_prediction.
emb_handler.visualise_emb_pca(
    emb_space_name="esm1b",
    colour="binary_disorder_prediction",
)

In [None]:

# PCA projection of the esm2 space, coloured by binary_disorder_prediction.
emb_handler.visualise_emb_pca(
    emb_space_name="esm2",
    colour="binary_disorder_prediction",
)

#### t-SNE

In [None]:
# t-SNE projection of the esm1b space, coloured by binary_disorder_prediction.
# NOTE(review): t-SNE is stochastic and no seed is passed here; reproducibility
# depends on EmbeddingHandler's defaults — confirm.
emb_handler.visualise_emb_tsne(
    emb_space_name="esm1b",
    colour="binary_disorder_prediction",
)

In [None]:
# t-SNE projection of the esm2 space, coloured by binary_disorder_prediction.
# NOTE(review): t-SNE is stochastic and no seed is passed here; reproducibility
# depends on EmbeddingHandler's defaults — confirm.
emb_handler.visualise_emb_tsne(
    emb_space_name="esm2",
    colour="binary_disorder_prediction",
)

### Unsupervised clustering compared with metadata

In [None]:
# Re-cluster both embedding spaces with the same number of clusters. A single
# named constant keeps the two calls from drifting apart when it is tuned.
# NOTE(review): the choice of 3 clusters is not motivated in this notebook —
# document the rationale or tune it.
N_CLUSTERS = 3
emb_handler.recalculate_clusters(n_clusters=N_CLUSTERS, emb_space_name="esm1b")
emb_handler.recalculate_clusters(n_clusters=N_CLUSTERS, emb_space_name="esm2")

In [None]:
# Overlap between the esm1b clusters and the binary_disorder_prediction labels.
emb_handler.plot_feature_cluster_overlap(
    feature="binary_disorder_prediction",
    emb_space_name="esm1b",
)

In [None]:
# Overlap between the esm2 clusters and the binary_disorder_prediction labels.
emb_handler.plot_feature_cluster_overlap(
    feature="binary_disorder_prediction",
    emb_space_name="esm2",
)

### Pairwise distances between samples

In [None]:
# Histogram of pairwise Euclidean distances between samples.
emb_handler.plot_emb_dis_hist(
    distance_metric="euclidean",
)

In [None]:
# Pairwise Euclidean-distance heatmaps for both embedding spaces, ordered along
# x and y (order_x / order_y) by the binary_disorder_prediction column.
fig_esm1b = emb_handler.plot_emb_dis_heatmap(
    emb_space_name="esm1b",
    distance_metric="euclidean",
    order_x="binary_disorder_prediction",
    order_y="binary_disorder_prediction",
)
fig_esm2 = emb_handler.plot_emb_dis_heatmap(
    emb_space_name="esm2",
    distance_metric="euclidean",
    order_x="binary_disorder_prediction",
    order_y="binary_disorder_prediction",
)
# Show each figure as its own statement. Ending the cell with the tuple
# `fig_esm1b.show(), fig_esm2.show()` made Jupyter also echo the tuple of
# return values (typically `(None, None)`) under the plots.
fig_esm1b.show()
fig_esm2.show()

#### Including metadata

In [None]:
# Box plots of cosine distances split by binary_disorder_prediction group, for
# the esm1b space. Note: cosine here, whereas the histogram/heatmaps above use
# Euclidean distance.
emb_handler.plot_emb_dis_dif_dis_per_group(
    emb_space_name="esm1b",
    distance_metric="cosine",
    group="binary_disorder_prediction",
    plot_type="box",
)

In [None]:
# Box plots of cosine distances split by binary_disorder_prediction group, for
# the esm2 space. Note: cosine here, whereas the histogram/heatmaps above use
# Euclidean distance.
emb_handler.plot_emb_dis_dif_dis_per_group(
    emb_space_name="esm2",
    distance_metric="cosine",
    group="binary_disorder_prediction",
    plot_type="box",
)

In [None]:
# Pairwise Euclidean distances in esm1b plotted against those in esm2, with
# points coloured via the binary_disorder_prediction column.
# NOTE(review): colour_value_1 is the *string* "True", but pd.read_csv parses a
# pure True/False column as bool — confirm EmbeddingHandler compares the values
# as strings, otherwise no point will match.
emb_handler.plot_emb_dis_scatter(
    emb_space_name_1="esm1b",
    emb_space_name_2="esm2",
    distance_metric="euclidean",
    colour_group="binary_disorder_prediction",
    colour_value_1="True",
)

In [None]:
# Relate pairwise Euclidean distances in the esm1b space to the continuous
# disorder_propensity metadata feature.
emb_handler.plot_emb_dis_continuous_correlation(
    emb_space_name="esm1b",
    distance_metric="euclidean",
    feature="disorder_propensity",
)