In [None]:
import numpy as np
import pandas as pd

In [None]:
from ema import EmbeddingHandler

In [None]:
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# experiment      : sub-folder of ../examples holding the embeddings and metadata.csv
# patho_or_benign : "all", "benign" (data_source == "gnomad") or
#                   "patho" (data_source == "clinvar"); only applied when
#                   "1" is part of the experiment name.
# wt              : if False, rows whose variant_id contains "WT" are removed and
#                   pairwise 3D distances are loaded; only applied when "1" is
#                   part of the experiment name.
experiment = "ion-channel-proteins"
patho_or_benign = "all"
wt = True

assert patho_or_benign in {"all", "benign", "patho"}, (
    f"unknown patho_or_benign value: {patho_or_benign!r}"
)

FP_AGG_EMB_0 = f'../examples/{experiment}/esm1b_t33_650M_UR50S-embeddings.npy'
FP_AGG_EMB_1 = f'../examples/{experiment}/esm2_t33_650M_UR50D-embeddings.npy'
# Alternative embedding files:
# FP_AGG_EMB_0 = f'../examples/{experiment}/embeddings/full.npy'
# FP_AGG_EMB_1 = f'../examples/{experiment}/embeddings/chopped.npy'
# FP_AGG_EMB_1 = f"../examples/{experiment}/t5_u50.npy"
# FP_AGG_EMB_0 = f'../examples/{experiment}/esm1v_t33_650M_UR90S_1-embeddings.npy'
FP_META_DATA = f"../examples/{experiment}/metadata.csv"

emb_0 = np.load(FP_AGG_EMB_0)
emb_1 = np.load(FP_AGG_EMB_1)
# emb_2 = np.load(FP_AGG_EMB_2)

df_meta_data = pd.read_csv(FP_META_DATA)

if "1" in experiment:

    if not wt:
        FP_pw_3d_distance = f"../examples/{experiment}/pw_3d_distance.csv"
        pw_3d_distance = pd.read_csv(FP_pw_3d_distance, index_col=0)

        # Drop the wild-type sample(s) from the metadata and both embeddings.
        # A positional boolean mask is applied to all three, and the metadata
        # index is reset so that any later positional indexing of the
        # embeddings stays aligned with the metadata rows.
        is_wt = df_meta_data["variant_id"].str.contains("WT", na=False).to_numpy()
        df_meta_data = df_meta_data.loc[~is_wt].reset_index(drop=True)
        emb_0 = emb_0[~is_wt]
        emb_1 = emb_1[~is_wt]

    # Keep only benign (gnomAD) or pathogenic (ClinVar) variants if requested.
    data_source_by_subset = {"benign": "gnomad", "patho": "clinvar"}
    if patho_or_benign in data_source_by_subset:
        keep = (
            df_meta_data["data_source"] == data_source_by_subset[patho_or_benign]
        ).to_numpy()
        df_meta_data = df_meta_data.loc[keep].reset_index(drop=True)
        emb_0 = emb_0[keep]
        emb_1 = emb_1[keep]

        if not wt:
            # Restrict the pairwise 3D distance matrix to the remaining samples,
            # identified by the first metadata column.
            sample_names = df_meta_data.iloc[:, 0]
            pw_3d_distance = pw_3d_distance.loc[
                pw_3d_distance.index.isin(sample_names),
                pw_3d_distance.columns.isin(sample_names),
            ]

# Every embedding row must correspond to exactly one metadata row.
assert emb_0.shape[0] == emb_1.shape[0] == len(df_meta_data), (
    f"row mismatch: esm1b={emb_0.shape[0]}, esm2={emb_1.shape[0]}, "
    f"metadata={len(df_meta_data)}"
)

## Inspect the loaded metadata

In [None]:
# Preview the first rows of the (possibly filtered) sample metadata.
df_meta_data.head(n=5)

## Initialise ema embedding object

In [None]:
# Build the handler around the sample metadata; embedding rows are expected
# to be in the same order as the metadata rows.
emb = EmbeddingHandler(sample_meta_data=df_meta_data)

# Register each embedding array under the short name used throughout the notebook.
emb.add_emb_space(embeddings=emb_0, emb_space_name="esm1b")
emb.add_emb_space(embeddings=emb_1, emb_space_name="esm2")
# emb.add_emb_space(embeddings=emb_2, emb_space_name="ESM1v")

# pw_3d_distance is only loaded in the configuration cell under this same
# condition, so it is attached only then.
if "1" in experiment and not wt:
    emb.add_pw_metadata(pw_3d_distance, "3d_distance")

## Explore embedding spaces

In [None]:
# Compare how raw embedding values are distributed in each registered
# embedding space.
fig = emb.plot_emb_hist()

fig.show()

In [None]:
# Box plots of embedding values grouped per sample, to see how value
# distributions vary between samples.
fig = emb.plot_emb_box(
    group="sample",
)
fig.show()

In [None]:
# Presumably lists the metadata columns usable as `group=` below — confirm.
available_groups = emb.get_groups()
available_groups

In [None]:
# Embedding value distributions grouped by protein "family".
fig = emb.plot_emb_box(
    group="family",
)
fig.show()

In [None]:
# NOTE(review): this cell used to repeat the "family" box plot from the
# previous cell verbatim, rendering the same figure twice. The duplicate call
# has been removed; if a different grouping was intended here, add it back
# with the appropriate `group=` value (see `emb.get_groups()`).

### Overlap between metadata groups and unsupervised clusters

In [None]:
# Recompute unsupervised clusters with the same cluster count in every space.
# NOTE(review): if the clustering is stochastic, consider seeding it so that
# cluster labels (e.g. "1", "2", "3" used below) are reproducible.
N_CLUSTERS = 5
for space_name in ("esm1b", "esm2"):
    emb.recalculate_clusters(n_clusters=N_CLUSTERS, emb_space_name=space_name)

In [None]:
# How the esm1b clusters line up with the "family" annotation.
fig = emb.plot_feature_cluster_overlap(
    emb_space_name="esm1b",
    feature="family",
)
fig.show()

In [None]:
# How the esm2 clusters line up with the "family" annotation.
fig = emb.plot_feature_cluster_overlap(
    emb_space_name="esm2",
    feature="family",
)
fig.show()

In [None]:
# Samples whose "cluster_esm1b" value is "1" (column presumably created by
# recalculate_clusters above — confirm).
cluster_1_samples = emb.get_samples_per_group_value(group="cluster_esm1b", group_value="1")
cluster_1_samples

### Visualise correlation of individual dimensions with meta data

In [None]:
# Discrete metadata: per-dimension correlation of esm1b with "family".
emb.plot_emb_cor_per_dim(
    emb_space_name="esm1b",
    feature="family",
)

In [None]:
# Continuous metadata: per-dimension correlation of esm2 with
# "disorder_propensity". (This cell previously passed the discrete "family"
# column, duplicating the discrete example above despite its label.)
emb.plot_emb_cor_per_dim(emb_space_name="esm2", feature="disorder_propensity")

### Visualise embeddings with dimensionality reduction

#### UMAP

In [None]:
# UMAP of the esm1b space, coloured by "family" so it can be compared directly
# with the esm2 UMAP below (an empty colour string was passed previously).
fig = emb.visualise_emb_umap(emb_space_name="esm1b", colour="family")
fig.show()

In [None]:
# UMAP projection of the esm2 space, coloured by "family".
fig = emb.visualise_emb_umap(
    emb_space_name="esm2",
    colour="family",
)
fig.show()

#### PCA

In [None]:
# PCA of esm1b, coloured by its own clusters.
fig = emb.visualise_emb_pca(
    emb_space_name="esm1b",
    colour="cluster_esm1b",
)
fig.show()

In [None]:
# PCA of esm2, coloured by the esm1b clusters to compare the two spaces.
fig = emb.visualise_emb_pca(
    emb_space_name="esm2",
    colour="cluster_esm1b",
)
fig.show()

#### t-SNE

In [None]:
# t-SNE of esm1b, coloured by its own clusters.
fig = emb.visualise_emb_tsne(
    emb_space_name="esm1b",
    colour="cluster_esm1b",
)
fig.show()

In [None]:
# t-SNE of esm2, coloured by the esm1b clusters to compare the two spaces.
fig = emb.visualise_emb_tsne(
    emb_space_name="esm2",
    colour="cluster_esm1b",
)
fig.show()

## Explore sample distances

In [None]:
# Sample-by-sample distance heatmaps for both spaces. Both use the same
# "family" ordering on each axis (previously only esm1b did), so the two
# heatmaps can be compared side by side.
for space_name in ("esm1b", "esm2"):
    fig = emb.plot_emb_dis_heatmap(
        emb_space_name=space_name,
        distance_metric="adjusted_cosine",
        order_x="family",
        order_y="family",
    )
    fig.show()

In [None]:
# Histogram of pairwise "adjusted_cosine" distances.
distance_metric_name = "adjusted_cosine"
fig = emb.plot_emb_dis_hist(distance_metric_name)
fig.show()

In [None]:
# Presumably lists the continuous metadata columns (e.g. "disorder_propensity"
# used below) — confirm against the ema docs.
continuous_columns = emb.get_col_continuous()
continuous_columns

In [None]:
# Relate euclidean distances in the esm2 space to the continuous
# "disorder_propensity" metadata column.
fig = emb.plot_emb_dis_continuous_correlation(
    emb_space_name="esm2",
    feature="disorder_propensity",
    distance_metric="euclidean",
)
fig.show()

In [None]:
# Euclidean sample distances summarised per "family" group.
fig = emb.plot_emb_dis_box(
    distance_metric="euclidean",
    group="family",
)
fig.show()

In [None]:
# Number of samples per "family" value.
family_counts = emb.get_value_count_per_group("family")
family_counts

In [None]:
# Per-group distance plot for esm1b (euclidean), grouped by "family".
# Optional extras: group_value="Long QT syndrome", plot_type="box".
emb.plot_emb_dis_dif_dis_per_group(
    emb_space_name="esm1b",
    group="family",
    distance_metric="euclidean",
)

In [None]:
# Per-group distance plot for esm2 (euclidean), grouped by the esm1b clusters.
# Optional extras: group_value="0", plot_type="box".
emb.plot_emb_dis_dif_dis_per_group(
    emb_space_name="esm2",
    group="cluster_esm1b",
    distance_metric="euclidean",
)

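### Distribution of sample distances
Plot the distribution of pairwise sample distances (e.g. Euclidean distance, cosine similarity) within an embedding space, and compare these distances between the two embedding spaces.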

In [None]:
# Histogram of esm2 euclidean distances with fitted functions overlaid;
# `rank` presumably selects which fit is emphasised — confirm in the ema docs.
fig = emb.plot_emb_dis_his_with_fitted_functions(
    emb_space_name="esm2",
    rank="bimodal_dis",
    distance_metric="euclidean",
)
fig.show()

In [None]:
# Scatter of euclidean distances in esm1b vs esm2, highlighting the esm1b
# cluster values "3" and "2". Optional extra: rank="normal_dis".
fig = emb.plot_emb_dis_scatter(
    emb_space_name_1="esm1b",
    emb_space_name_2="esm2",
    distance_metric="euclidean",
    colour_group="cluster_esm1b",
    colour_value_1="3",
    colour_value_2="2",
)
fig.show()

## Explore difference between sample distances between embedding spaces

In [None]:
# Heatmap of distance differences between esm1b and esm2.
# Other metrics used in this notebook: seuclidean, cosine, sqeuclidean, rank,
# knn, adjusted_cosine.
fig = emb.plot_emb_dis_dif_heatmap(
    emb_space_name_1="esm1b",
    emb_space_name_2="esm2",
    distance_metric="adjusted_cosine",
)
fig.show()

In [None]:
# Box plot of per-"family" distance differences between the two registered
# spaces. (Previously referenced "full_length"/"chopped", which are never
# registered on `emb` in this notebook.)
fig = emb.plot_emb_dis_dif_box(
    emb_space_name_1="esm1b",
    emb_space_name_2="esm2",
    distance_metric="cityblock_scaled",
    group="family",
)
fig.show()

In [None]:
# Heatmap of cityblock distance differences between the two registered spaces.
# (Previously referenced "full_length"/"chopped", which are never registered
# on `emb` in this notebook.) Alternatives: seuclidean, cosine, sqeuclidean, rank.
fig = emb.plot_emb_dis_dif_heatmap(
    emb_space_name_1="esm1b",
    emb_space_name_2="esm2",
    distance_metric="cityblock",
)
fig.show()

In [None]:
# Box plot of per-"family" knn distance differences between the two registered
# spaces. (Previously referenced "full_length"/"chopped", which are never
# registered on `emb` in this notebook.)
fig = emb.plot_emb_dis_dif_box(
    emb_space_name_1="esm1b",
    emb_space_name_2="esm2",
    distance_metric="knn",
    group="family",
)
fig.show()

**Caution:** this is not accurate, as the distance matrix is not symmetric.

In [None]:
# Percentile comparison of adjusted-cosine distances between esm1b and esm2:
# first over all samples, then for the subset patho_effect == "1" with each
# `compare_subset_to` mode. (Previously referenced "bfd"/"u50", which are
# never registered on `emb` in this notebook.)
# NOTE(review): "patho_effect" is not used anywhere else for this experiment —
# confirm the column exists in metadata.csv before running.
fig = emb.plot_emb_dist_dif_percentiles(
    emb_space_name_1="esm1b",
    emb_space_name_2="esm2",
    distance_metric="adjusted_cosine",
)
fig.show()

for compare_mode in ("within_group", "outside_group", None):
    fig = emb.plot_emb_dist_dif_percentiles(
        emb_space_name_1="esm1b",
        emb_space_name_2="esm2",
        distance_metric="adjusted_cosine",
        subset_group="patho_effect",
        subset_group_value="1",
        compare_subset_to=compare_mode,
    )
    fig.show()