#### IMPORTS

In [1]:
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from custom_tsne import CustomTSNE
from explainer import Explainer

#### LOADING DATA

In [2]:
filepath = "datasets/column_3C.xls"

# Loading dataset

df = pd.read_csv(filepath)

X = df.loc[:, df.columns != "class"].to_numpy().astype(np.float64)
scaler = StandardScaler()
X = scaler.fit_transform(X)
features = df.columns.to_list()[:-1]
targets = df["class"].to_numpy()

# Loading t-SNE data

tsne = CustomTSNE(X, targets)
tsne.load_tsne_data("embeddings/vertebral/perplexity_30")

#### Plot t-SNE embedding

In [3]:
tsne.plot_tsne_embedding_classification()

#### COMPUTE GRADIENTS

In [4]:
explainer = Explainer(tsne, features)
explainer.compute_all_gradients()

array([[[ 0.14699684,  0.66370871,  0.21858535, -0.30610677,
          0.23293221,  0.22675574],
        [-0.37189417, -0.08169607, -0.32407345, -0.41663394,
         -0.07003407,  0.11381595]],

       [[-0.99303435, -0.9825319 , -1.00487954, -0.54256854,
         -0.77638947,  0.37275186],
        [-0.22212175,  1.3614492 , -0.16239438, -1.30032819,
         -0.52606472,  0.02246377]],

       [[-0.53672547,  0.39903314, -0.33791509, -0.98672683,
          0.77276851, -0.41626336],
        [ 0.18395435,  0.98628939, -0.587439  , -0.49916753,
         -1.53696033,  0.23137852]],

       ...,

       [[-3.30215346, -1.32761319, -0.52226873, -3.25039961,
          1.90830827,  0.12393393],
        [-0.62486654,  1.94884284,  0.77891783, -2.25545391,
         -1.53595766, -0.64327943]],

       [[-1.24518211, -0.68819905, -1.14076422, -1.08580587,
          0.97748295, -0.38291328],
        [ 0.22521178,  0.75142931, -0.31195815, -0.27107618,
         -0.83279133,  0.3760876 ]],

       

#### EXPLAIN EMBEDDING

##### FEATURE IMPORTANCE RANKING

In [5]:
explainer.plot_feature_importance_ranking()

##### VECTOR FIELD FOR EACH FEATURE

In [11]:
for i in range(6):
    explainer.plot_arrow_fields(i, scale=0.2)