#### IMPORTS

In [1]:
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from custom_tsne import CustomTSNE
from explainer import Explainer

#### LOADING DATA

In [2]:
filepath = "datasets/column_3C.xls"

# Loading dataset
# NOTE(review): the path ends in .xls but is parsed with read_csv, so the file
# is presumably plain delimited text despite its extension — confirm.

df = pd.read_csv(filepath)

# Build the feature mask once and use it for both the data matrix and the
# feature names, so the two stay aligned wherever the "class" column sits.
is_feature = df.columns != "class"

X = df.loc[:, is_feature].to_numpy().astype(np.float64)
# Standardize the feature columns before handing them to the t-SNE wrapper.
scaler = StandardScaler()
X = scaler.fit_transform(X)
features = df.columns[is_feature].to_list()
targets = df["class"].to_numpy()

# Loading t-SNE data

tsne = CustomTSNE(X, targets)
tsne.load_tsne_data("embeddings/vertebral/perplexity_30")

#### PLOT t-SNE EMBEDDING

In [3]:
# Display the t-SNE embedding held by the CustomTSNE wrapper.
# NOTE(review): presumably points are coloured by `targets`, going by the
# method name — confirm in custom_tsne.py.
tsne.plot_tsne_embedding_classification()

#### COMPUTE GRADIENTS

In [4]:
# Build the explainer from the t-SNE wrapper and the list of feature names.
explainer = Explainer(tsne, features)
# As the last expression in the cell, the returned array is echoed in full as
# the cell output.
# NOTE(review): presumably the gradients are also stored on `explainer` for the
# later plotting calls — confirm in explainer.py; if so, a trailing `;` would
# keep the large array out of the notebook output.
explainer.compute_all_gradients()

array([[[-0.40148258, -0.29301862, -0.39110016, -0.29706508,
         -0.13700055,  0.02683929],
        [-0.00426168, -0.57970697, -0.09709528,  0.42675998,
         -0.21077273, -0.2402893 ]],

       [[-0.1605473 ,  1.60518389, -0.70939658, -1.40299061,
         -1.42344697, -0.10447007],
        [ 0.87380082,  0.15276197,  0.53176123,  1.00814202,
         -0.25004184, -0.22758012]],

       [[ 0.34215867,  0.81791982, -0.49442663, -0.1704812 ,
         -1.72634886,  0.37248164],
        [ 0.47134669, -0.63078673,  0.46692958,  1.07557095,
         -0.23059658,  0.34693585]],

       ...,

       [[ 0.43965384,  3.1310475 ,  0.56690377, -1.76996759,
         -2.44134597, -0.64875358],
        [ 1.94509805, -0.29787292,  0.53666653,  2.71978367,
         -0.46122875, -0.05329938]],

       [[-0.05315716,  0.13602199, -0.96850902, -0.16967731,
         -0.99626126,  0.13289703],
        [ 1.21149196,  0.08846524,  0.18399738,  1.48970891,
         -0.18624362,  0.17421676]],

       

#### EXPLAIN EMBEDDING

##### FEATURE IMPORTANCE RANKING

In [5]:
# Plot the explainer's feature importance ranking.
# NOTE(review): presumably depends on compute_all_gradients() having been run
# first — confirm in explainer.py.
explainer.plot_feature_importance_ranking()

##### VECTOR FIELD FOR EACH FEATURE

In [6]:
# Draw one arrow-field plot per feature. The loop bound comes from `features`
# rather than a hard-coded 6, so it follows the dataset's feature count.
# NOTE(review): assumes the first argument of plot_arrow_fields is the
# positional index into `features` — confirm in explainer.py.
for feature_idx in range(len(features)):
    explainer.plot_arrow_fields(feature_idx, scale=0.2)