In [2]:
import numpy as np
from jax import numpy as jnp
from plotly import express as px
from sklearn.datasets import load_digits
from sklearn.manifold import TSNE

import jaxsne

In [3]:
# Load the scikit-learn digits dataset: a feature matrix plus its integer class labels.
_bunch = load_digits()
digits, digit_class = _bunch.data, _bunch.target
del _bunch  # keep the notebook namespace limited to the two arrays used below

In [4]:
# Reference embedding: scikit-learn t-SNE (default Barnes-Hut method) on every digit.
# random_state is fixed so the layout is identical on every Restart & Run All.
xf, yf = TSNE(random_state=0).fit_transform(X=digits).T
# Labels are cast to strings so plotly uses a discrete legend, not a continuous colorscale.
px.scatter(
    x=xf,
    y=yf,
    color=digit_class.astype("U"),
    title="scikit-learn t-SNE (all digits)",
).update_layout(
    width=600,
    height=600,
    xaxis_scaleanchor="y",  # lock x/y units so distances are not stretched
    plot_bgcolor="rgba(0,0,0,0)",
    xaxis_visible=False,
    yaxis_visible=False,
)

In [5]:
# Draw a reproducible random subsample of the digits. The exact (O(N^2)) t-SNE and the
# jaxsne runs below all operate on this subset, so keeping it small keeps them fast.
SAMPLE_SEED = 0  # fixed seed: the subset (and every embedding of it) is reproducible
N_SAMPLES = 400  # subset size; must not exceed the number of digits

rng = np.random.default_rng(SAMPLE_SEED)
nd, _ = digits.shape
samp_inds = rng.choice(nd, size=N_SAMPLES, replace=False)  # distinct row indices
data = digits[samp_inds]
# String labels make plotly treat the class as a discrete category when colouring.
data_class = digit_class[samp_inds].astype("U")

In [6]:
# Baseline for the jaxsne runs: exact-gradient scikit-learn t-SNE on the same subsample.
# random_state is fixed so the embedding is reproducible across kernel restarts.
xd, yd = TSNE(method="exact", random_state=0).fit_transform(data).T
px.scatter(
    x=xd,
    y=yd,
    color=data_class,
    title="scikit-learn exact t-SNE (subsample)",
).update_layout(
    width=600,
    height=600,
    xaxis_scaleanchor="y",  # lock x/y units so distances are not stretched
    plot_bgcolor="rgba(0,0,0,0)",
    xaxis_visible=False,
    yaxis_visible=False,
)

In [28]:
# jaxsne with its default output metric (presumably Euclidean -- confirm in jaxsne),
# run on the same subsample as the exact t-SNE cell for a side-by-side comparison.
res = jaxsne.sne(data)
x, y = res.T
(
    px.scatter(x=x, y=y, color=data_class)
    # Square canvas; scaleanchor locks x/y units so distances are not stretched.
    .update_layout(width=600, height=600, xaxis_scaleanchor="y")
    # Transparent background with both axes hidden.
    .update_layout(plot_bgcolor="rgba(0,0,0,0)", xaxis_visible=False, yaxis_visible=False)
)

In [29]:
# jaxsne with a hyperbolic (Poincare) output metric.
res = jaxsne.sne(data, out_metric=jaxsne.metric.poincare)
# Radially squash every point into the open unit disk via r -> r / (1 + r) for display.
# NOTE(review): this is not the standard hyperboloid -> Poincare-disk map
# x / (1 + sqrt(1 + |x|^2)); confirm which coordinates jaxsne.metric.poincare uses.
res = res / (1 + jnp.linalg.norm(res, ord=2, axis=1, keepdims=True))
x, y = res.T
(
    px.scatter(x=x, y=y, color=data_class)
    # Unit circle: the boundary of the disk the points were squashed into.
    .add_shape(type="circle", x0=-1, y0=-1, x1=1, y1=1)
    .update_layout(width=600, height=600, xaxis_scaleanchor="y")
    .update_layout(plot_bgcolor="rgba(0,0,0,0)", xaxis_visible=False, yaxis_visible=False)
)

In [30]:
# jaxsne in 3-D with a cosine output metric. Cosine distance ignores vector length
# (assuming jaxsne.metric.cosine is the standard cosine distance), so each point is
# projected onto the unit sphere for display.
res = jaxsne.sne(data, n_components=3, out_metric=jaxsne.metric.cosine)
res = res / jnp.linalg.norm(res, ord=2, axis=1, keepdims=True)
x, y, z = res.T
(
    px.scatter_3d(x=x, y=y, z=z, color=data_class)
    .update_layout(width=600, height=600)
    # 3-D axes live under `scene`; hide all three.
    .update_layout(
        scene_xaxis_visible=False,
        scene_yaxis_visible=False,
        scene_zaxis_visible=False,
    )
)