# Visualising embeddings with t-SNE


> 📚 [t-SNE algorithm](https://opentsne.readthedocs.io/en/stable/tsne_algorithm.html)

In [None]:
import os

import numpy as np
import pandas as pd
import plotly.express as px
import torch
from openTSNE import TSNE

from hopwise.utils import init_seed

### Loading the checkpoint

In [None]:
# File name of the checkpoint to visualise; it is resolved inside the "saved" directory when loaded.
checkpoint_name = "TransE-Jan-23-2025_16-48-43.pth"

In [None]:
# NOTE(review): weights_only=False unpickles arbitrary Python objects (presumably required
# because the checkpoint also stores a non-tensor "config" entry), so only load checkpoints
# that come from a trusted source.
# map_location="cpu" lets a checkpoint saved on a GPU be opened on a machine without CUDA.
checkpoint = torch.load(
    os.path.join("saved", checkpoint_name),
    map_location="cpu",
    weights_only=False,
)
config = checkpoint["config"]
init_seed(config["seed"], config["reproducibility"])

# Convert every stored tensor to a NumPy array once, so later cells can pass them to t-SNE directly.
checkpoint["state_dict"] = {
    weight_name: tensor.cpu().numpy() for weight_name, tensor in checkpoint["state_dict"].items()
}

In [None]:
def plot_fn(embeddings, desc="Entity"):
    """Display a scatter plot of the first two columns of ``embeddings``.

    Each point is coloured by its row index, so neighbouring IDs get similar colours.

    Args:
        embeddings: Array indexed as ``embeddings[:, 0]`` (x axis) and ``embeddings[:, 1]`` (y axis).
        desc: Human-readable name of what is plotted; used in the title and the colour label.

    The model name in the title is read from the module-level ``config``.
    """
    n_points = embeddings.shape[0]
    point_ids = [*range(n_points)]

    axis_labels = {
        "x": "Embedding Dimension 1",
        "y": "Embedding Dimension 2",
        "color": f"{desc} ID",
    }
    plot_title = f"{config['model']} {desc} Embeddings"

    figure = px.scatter(
        x=embeddings[:, 0],
        y=embeddings[:, 1],
        color=point_ids,
        labels=axis_labels,
        title=plot_title,
        width=1024,
        height=1024,
        template="plotly_white",
    )
    figure.show()

## Knowledge-Aware Models

> 💡 See the [openTSNE advanced usage guide](https://opentsne.readthedocs.io/en/stable/examples/02_advanced_usage/02_advanced_usage.html) for a more detailed configuration.

In [None]:
# Shared t-SNE configuration; the same object is fitted separately on each embedding table below.
tsne = TSNE(
    perplexity=30,
    n_jobs=8,
    initialization="random",
    metric="cosine",
    random_state=config["seed"],  # reuse the seed stored in the checkpoint's config
    verbose=True,
)

### Plot users

In [None]:
# Fit t-SNE on the user embedding table (once; each fit re-runs the whole optimisation).
user_weights = checkpoint["state_dict"]["user_embedding.weight"]
tsne_embeddings_users = tsne.fit(user_weights)

In [None]:
# Scatter plot of the fitted user embeddings.
plot_fn(tsne_embeddings_users, "User")

### Plot entities

In [None]:
# Fit t-SNE on the entity embedding table (once; each fit re-runs the whole optimisation).
entity_weights = checkpoint["state_dict"]["entity_embedding.weight"]
tsne_embeddings_entities = tsne.fit(entity_weights)

In [None]:
# Scatter plot of the fitted entity embeddings.
plot_fn(tsne_embeddings_entities, "Entity")

### Plot relations

In [None]:
# Fit t-SNE on the relation embedding table taken from the loaded checkpoint.
relation_weights = checkpoint["state_dict"]["relation_embedding.weight"]
tsne_embeddings_relations = tsne.fit(relation_weights)

In [None]:
# Scatter plot of the fitted relation embeddings.
plot_fn(tsne_embeddings_relations, "Relation")

### Combine embeddings in the same plot

In [None]:
def combine_embeddings(**kwargs):
    """Plot several sets of embeddings together in one scatter plot.

    Each keyword argument is one group: the keyword is the group name (used as the
    colour category and as the prefix of every point's hover identifier) and the value
    is an array whose first two columns are the x and y coordinates.

    Args:
        **kwargs: ``name=array`` pairs; each array is indexed as ``array[:, 0]`` and
            ``array[:, 1]`` and must expose ``.shape[0]``.

    Raises:
        ValueError: If no embeddings are passed.

    The plot title includes the module-level ``checkpoint_name``.
    """
    if not kwargs:
        raise ValueError("combine_embeddings requires at least one embeddings array")

    arrays = []
    type_labels = []
    identifiers = []

    for embeddings_name, embeddings in kwargs.items():
        n_rows = embeddings.shape[0]
        arrays.append(embeddings)
        # Record the group name directly rather than re-parsing it from the identifier,
        # so names that contain a space are still labelled correctly.
        type_labels.extend([embeddings_name] * n_rows)
        identifiers.extend(f"{embeddings_name} {row}" for row in range(n_rows))
        print(f"[+] {embeddings_name}: {embeddings.shape}")

    stacked = np.concatenate(arrays, axis=0)

    combined_df = pd.DataFrame(
        {
            "x": stacked[:, 0],
            "y": stacked[:, 1],
            "type": type_labels,
            "identifier": identifiers,
        }
    )

    fig = px.scatter(
        combined_df,
        x="x",
        y="y",
        color="type",
        hover_data=["identifier"],
        labels={"x": "Embedding Dimension 1", "y": "Embedding Dimension 2", "type": "Embedding Type"},
        title=f"Visualising Combined Embeddings {checkpoint_name}",
        width=1024,
        height=1024,
        template="plotly_white",
    )
    fig.show()

In [None]:
# Overlay the three fitted embedding sets; each keyword becomes the group name shown in the plot.
combine_embeddings(user=tsne_embeddings_users, entity=tsne_embeddings_entities, relation=tsne_embeddings_relations)