In [None]:
import os
import json
import torch
import plotly
import argparse
import numpy as np
import plotly.express as px
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import pandas as pd

from tqdm import tqdm

from sklearn.preprocessing import normalize
from scipy.spatial import SphericalVoronoi, geometric_slerp

In [None]:
def unit_sphere(resolution=100):
    """Sample the surface of the unit sphere centred at the origin.

    Parameters
    ----------
    resolution : int
        Number of samples along the polar angle (0..pi); the azimuth (0..2*pi)
        is sampled twice as densely.

    Returns
    -------
    tuple of np.ndarray
        ``(x, y, z)`` coordinate grids, each of shape ``(2 * resolution, resolution)``.
    """
    # Complex step in mgrid means "number of points, endpoints included".
    azimuth, polar = np.mgrid[0:2*np.pi:resolution*2j, 0:np.pi:resolution*1j]
    ring_radius = np.sin(polar)
    return (np.cos(azimuth) * ring_radius,
            np.sin(azimuth) * ring_radius,
            np.cos(polar))

In [None]:
# Directory holding the token metadata CSV and the per-layer UMAP embedding
# files. Set the UMAP_VORONOI_BASE_DIR environment variable to run outside the
# original Google Drive location.
BASE_DIR = os.environ.get(
    'UMAP_VORONOI_BASE_DIR',
    '/content/drive/MyDrive/umap_voronoi_results/checkpoint-645000_gemma32_distributed/',
)
df = pd.read_csv(os.path.join(BASE_DIR, 'info_rows_umap_voronoi_first_target.csv'))
# Later cells read these columns; fail early with a clear message if they are missing.
assert {'src_lang_code', 'text_token'} <= set(df.columns), "info CSV is missing expected columns"

In [None]:
# Preview only the first rows of the token metadata; use df.shape for its full size.
df.head()

Unnamed: 0,index_sentence,id,index_id,src_lang_code,tgt_lang_code,bos,src_tag_token,last_token,text_token
0,0,1099,3,eng_Latn,glg_Latn,no,no,no,""""
1,0,1812,4,eng_Latn,glg_Latn,no,no,no,We
2,0,3193,5,eng_Latn,glg_Latn,no,no,no,now
3,0,913,6,eng_Latn,glg_Latn,no,no,no,have
4,0,755,7,eng_Latn,glg_Latn,no,no,no,4
...,...,...,...,...,...,...,...,...,...
271,0,25,30,spa_Latn,glg_Latn,no,no,no,","
272,0,15,31,spa_Latn,glg_Latn,no,no,no,""""
273,0,5299,32,spa_Latn,glg_Latn,no,no,no,agre
274,0,398,33,spa_Latn,glg_Latn,no,no,no,go


In [None]:
# Per-row source-language label (used to group points into Voronoi cells and
# colour them) and the token text shown next to each point.
targets, tokens = df['src_lang_code'].values, df['text_token'].values

In [None]:
def plot_voronoi_3d(points_2d, targets, layer, tokens = None):
    """Plot points on the unit sphere with the spherical Voronoi diagram of their
    per-label centroids, and save the figure as a standalone HTML file.

    Parameters
    ----------
    points_2d : array-like, shape (n_points, 2)
        Angular coordinates in radians: column 0 is used as the polar angle and
        column 1 as the azimuth when projecting onto the sphere.
    targets : array-like, length n_points
        Label per point (e.g. a language code). One Voronoi generator is built
        per unique label and points are coloured by label.
    layer : int
        Used in the figure title and the output file name.
    tokens : array-like, optional
        Text drawn next to each point. When given, the output file name gets a
        ``_text`` suffix.

    Returns
    -------
    plotly.graph_objects.Figure
        The figure that was written to ``voronoi_layer_{layer}[_text].html``.
    """
    points_2d = np.asarray(points_2d)
    # A plain list would make `targets == target` a single bool instead of a mask.
    targets = np.asarray(targets)

    # Spherical -> Cartesian on the unit sphere.
    theta, phi = points_2d[:, 0], points_2d[:, 1]
    points = np.column_stack((np.sin(theta) * np.cos(phi),
                              np.sin(theta) * np.sin(phi),
                              np.cos(theta)))

    # One generator per label: the mean of its points, projected back onto the
    # unit sphere because SphericalVoronoi requires generators on the sphere.
    unique_targets = np.unique(targets)
    centroid_points = normalize(np.vstack([points[targets == target].mean(axis=0)
                                           for target in unique_targets]))

    sv = SphericalVoronoi(centroid_points, 1, np.zeros(3))
    # Order each region's vertices around its cell so consecutive vertices
    # share an edge; the edge-drawing loop below depends on this.
    sv.sort_vertices_of_regions()
    t_vals = np.linspace(0, 1, 100)

    # text=None is plotly express's default, so one call covers both cases.
    fig_umap_3d = px.scatter_3d(
        points, x=0, y=1, z=2,
        title=f"UMAP 3D projection (layer {layer})",
        color=targets,
        text=tokens,
    )

    x_sphere_surface, y_sphere_surface, z_sphere_surface = unit_sphere()
    fig_umap_3d.add_trace(
        go.Surface(x=x_sphere_surface, y=y_sphere_surface, z=z_sphere_surface,
                   colorscale=['#f0f3f3', '#f0f3f3'],
                   showscale=False,
                   lighting=dict(ambient=1),
                   opacity=0.75)
    )

    fig_umap_3d.add_trace(go.Scatter3d(
        x=centroid_points[:, 0], y=centroid_points[:, 1], z=centroid_points[:, 2],
        name='centroids',
        mode='markers',
        marker=dict(size=12, symbol="cross", color="black"),
    ))

    for region in sv.regions:
        region_vertices = sv.vertices[region]
        # Pair every vertex with the next one, wrapping around to close the cell.
        for start, end in zip(region_vertices, np.roll(region_vertices, -1, axis=0)):
            result = geometric_slerp(start, end, t_vals)
            fig_umap_3d.add_trace(go.Scatter3d(
                x=result[..., 0],
                y=result[..., 1],
                z=result[..., 2],
                showlegend=False,
                mode="lines",
                line=dict(color="black", width=3),
            ))

    # Hide axes, grids and background panes so only the sphere is visible.
    hidden_axis = dict(backgroundcolor="white",
                       gridcolor="white",
                       showbackground=True,
                       zerolinecolor="white",
                       showticklabels=False,
                       visible=False)
    fig_umap_3d.update_layout(scene=dict(xaxis=hidden_axis, yaxis=hidden_axis, zaxis=hidden_axis))

    # The "_text" suffix marks the variant with token labels (previously inverted).
    suffix = "_text" if tokens is not None else ""
    plotly.offline.plot(fig_umap_3d, filename=f"voronoi_layer_{layer}{suffix}.html")
    return fig_umap_3d

In [None]:
# Render one standalone HTML Voronoi figure per layer.
NUM_LAYERS = 19  # assumes embedding files exist for layers 0..18 — TODO confirm against the export
for layer in tqdm(range(NUM_LAYERS), desc="layers"):
    embeddings_path = os.path.join(BASE_DIR, f'lang_umap_voronoi_first_target_layer{layer}.npy')
    embeddings = np.load(embeddings_path)
    plot_voronoi_3d(embeddings, targets, layer, tokens)

### Side-by-side plots (one panel per selected layer)

In [None]:
import matplotlib
import matplotlib.colors as mcolors

# Give each source language a fixed tab20 colour (in order of first appearance)
# so every panel below uses the same language -> colour mapping.
unique_lang_codes = df['src_lang_code'].unique()
# colormaps[name].resampled(n) is the replacement for plt.cm.get_cmap(name, n),
# which is deprecated since Matplotlib 3.7.
# NOTE(review): with more than 20 languages, tab20 colours will repeat.
lang_cmap = matplotlib.colormaps['tab20'].resampled(len(unique_lang_codes))
lang_to_color = {lang: mcolors.rgb2hex(lang_cmap(i)[:3]) for i, lang in enumerate(unique_lang_codes)}
df['color'] = df['src_lang_code'].map(lang_to_color)
# Per-row hex colour and language name, consumed by plot_voronoi_3d_subplot.
colors = df['color'].values
lang_names = list(df['src_lang_code'].values)


The get_cmap function was deprecated in Matplotlib 3.7 and will be removed two minor releases later. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap(obj)`` instead.



In [None]:
pip install -U kaleido



In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from scipy.spatial import SphericalVoronoi
from sklearn.preprocessing import normalize

def plot_voronoi_3d_subplot(points_2d, targets, layer, colors, lang_names, lang_to_color, tokens=None, fig=None, row=1, col=1):
    """Draw points, per-label centroids and the spherical Voronoi diagram of
    those centroids into one 3D subplot of ``fig``.

    Parameters
    ----------
    points_2d : array-like, shape (n_points, 2)
        Angular coordinates in radians: column 0 is used as the polar angle and
        column 1 as the azimuth.
    targets : array-like, length n_points
        Label per point; one Voronoi generator (centroid) is built per unique label.
    layer : int
        Unused; kept for call-site compatibility.
    colors : sequence, length n_points
        Marker colour per point; each language trace uses the colour of its
        first point.
    lang_names : sequence, length n_points
        Name per point; points are grouped into one trace (and legend entry)
        per distinct name.
    lang_to_color : dict
        Unused; kept for call-site compatibility.
    tokens : array-like, optional
        Unused; kept for call-site compatibility.
    fig : plotly.graph_objects.Figure, optional
        Figure from ``make_subplots`` whose cell (row, col) is a 3D scene. When
        None, a new figure with a grid just large enough for (row, col) is made.
    row, col : int
        1-based subplot position. Legend entries are only emitted for
        ``col == 1`` so they are not repeated for every panel.

    Returns
    -------
    plotly.graph_objects.Figure
        The figure drawn into.
    """
    if fig is None:
        fig = make_subplots(rows=row, cols=col,
                            specs=[[{'type': 'surface'}] * col for _ in range(row)])

    points_2d = np.asarray(points_2d)
    # A plain list would make `targets == target` a single bool instead of a mask.
    targets = np.asarray(targets)

    # Spherical -> Cartesian on the unit sphere.
    theta, phi = points_2d[:, 0], points_2d[:, 1]
    points = np.column_stack((np.sin(theta) * np.cos(phi),
                              np.sin(theta) * np.sin(phi),
                              np.cos(theta)))

    # Mean point per label, projected back onto the unit sphere (required by
    # SphericalVoronoi).
    unique_targets = np.unique(targets)
    centroid_points = normalize(np.vstack([points[targets == target].mean(axis=0)
                                           for target in unique_targets]))

    sv = SphericalVoronoi(centroid_points, 1, np.zeros(3))
    # Order each region's vertices around its cell so consecutive vertices share an edge.
    sv.sort_vertices_of_regions()
    t_vals = np.linspace(0, 1, 100)

    x_sphere_surface, y_sphere_surface, z_sphere_surface = unit_sphere()

    # Group point indices by language in first-appearance order, in one pass.
    # Iterating a set() instead would make trace/legend order depend on string
    # hash randomization, which differs between Python runs.
    indices_by_lang = {}
    for idx, lang_name in enumerate(lang_names):
        indices_by_lang.setdefault(lang_name, []).append(idx)

    for lang_name, indices in indices_by_lang.items():
        fig.add_trace(go.Scatter3d(
            x=points[indices, 0], y=points[indices, 1], z=points[indices, 2],
            name=lang_name,
            # Shared group across panels: one legend click toggles the language everywhere.
            legendgroup=lang_name,
            mode='markers',
            marker=dict(size=7, color=colors[indices[0]]),
            showlegend=(col == 1),
        ), row=row, col=col)

    # Centroids, the sphere surface, then the Voronoi cell boundaries.
    fig.add_trace(go.Scatter3d(
        x=centroid_points[:, 0], y=centroid_points[:, 1], z=centroid_points[:, 2],
        mode='markers', showlegend=False, marker=dict(size=7, symbol="cross", color="black")),
        row=row, col=col)
    fig.add_trace(go.Surface(x=x_sphere_surface, y=y_sphere_surface, z=z_sphere_surface,
                             colorscale=['#f0f3f3', '#f0f3f3'],
                             opacity=0.9, showscale=False, showlegend=False, lighting=dict(ambient=1)),
                  row=row, col=col)

    for region in sv.regions:
        region_vertices = sv.vertices[region]
        # Pair every vertex with the next one, wrapping around to close the cell.
        for start, end in zip(region_vertices, np.roll(region_vertices, -1, axis=0)):
            result = geometric_slerp(start, end, t_vals)
            fig.add_trace(go.Scatter3d(
                x=result[..., 0],
                y=result[..., 1],
                z=result[..., 2],
                mode="lines",
                showlegend=False,
                line=dict(color="black", width=2)),
                row=row, col=col)

    return fig

# One 3D panel per selected layer, side by side in a single row.
layers = [0, 17, 18]
fig = make_subplots(
    rows=1, cols=len(layers),
    specs=[[{'type': 'surface'}] * len(layers)],
    # Label each panel so the figure can be read without the code.
    subplot_titles=[f"Layer {layer}" for layer in layers],
)

for col, layer in enumerate(layers, start=1):
    embeddings_path = os.path.join(BASE_DIR, f'lang_umap_voronoi_first_target_layer{layer}.npy')
    embeddings = np.load(embeddings_path)
    plot_voronoi_3d_subplot(embeddings, targets, layer, colors, lang_names, lang_to_color, tokens, fig, 1, col)

# Hide axes, grids and background panes in every scene so only the spheres show.
hidden_axis = dict(backgroundcolor="white", gridcolor="white", showbackground=True,
                   zerolinecolor="white", showticklabels=False, visible=False)
fig.update_scenes(xaxis=hidden_axis, yaxis=hidden_axis, zaxis=hidden_axis)

fig.update_layout(
    title="",
    height=550,
    legend=dict(
        x=0.5,
        xanchor="center",
        orientation="h",  # horizontal legend, centred under the panels
        font=dict(size=13)
    )
)
fig.show()