<a href="https://colab.research.google.com/github/tayfununal/Uniform-Autoencoder-with-Latent-Flow-Matching/blob/main/datasets/visualization/Spiral_Vis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch

import numpy as np
import matplotlib.pyplot as plt

# Matplotlib defaults for any static figures: large text in a serif face.
# 'font.serif' is only consulted when 'font.family' is the generic 'serif',
# so the family must be 'serif' for the Times New Roman preference to apply.
# DejaVu Serif (bundled with matplotlib) is listed as a fallback for machines
# without Times New Roman installed.
plt.rcParams['font.size'] = 20
plt.rcParams['font.family'] = 'serif'
plt.rcParams['font.serif'] = ['Times New Roman', 'DejaVu Serif']

In [2]:
def make_spiral_4class(n_points=1000, noise=0.5, seed=None):
    """Sample a noisy 3-D conical spiral split into four contiguous classes.

    The curve is parameterised by ``t`` evenly spaced on ``[0, 4*pi]`` and
    traced as ``(t*cos t, t*sin t, t)``: both the radius and the height grow
    with ``t``. Each pi-long stretch of ``t`` gets its own label:
    ``[0, pi) -> 0``, ``[pi, 2pi) -> 1``, ``[2pi, 3pi) -> 2``, ``[3pi, 4pi] -> 3``.

    Args:
        n_points: Number of samples along the curve.
        noise: Standard deviation of the Gaussian noise added independently to
            every coordinate (``0`` gives the clean curve).
        seed: Optional seed for the noise. ``None`` (default) draws from
            NumPy's global RNG, so ``np.random.seed`` still controls it. An int
            uses a private ``np.random.default_rng(seed)`` stream, which makes
            the sample reproducible without touching global state.

    Returns:
        data: ``float32`` tensor of shape ``(n_points, 3)``.
        c: ``long`` tensor of class labels with shape ``(n_points, 1)``.
    """
    t = np.linspace(0, 4 * np.pi, n_points)
    data = np.stack([t * np.cos(t), t * np.sin(t), t], axis=1)

    # digitize (right=False) returns i with edges[i-1] <= t < edges[i], i.e.
    # exactly the four half-open pi-wide bands above; t >= 3*pi maps to 3.
    c = np.digitize(t, [np.pi, 2 * np.pi, 3 * np.pi])

    if seed is None:
        # Same single draw from the global RNG as before (keeps np.random.seed
        # reproducibility for existing callers).
        eps = np.random.randn(*data.shape)
    else:
        eps = np.random.default_rng(seed).standard_normal(data.shape)
    data += noise * eps

    data = torch.tensor(data, dtype=torch.float32)
    c = torch.tensor(c, dtype=torch.long).unsqueeze(1)  # (N, 1)

    return data, c


In [None]:
!pip uninstall -y plotly kaleido
!pip install plotly==5.24.1 kaleido==0.2.1

In [5]:
import pandas as pd
import plotly.express as px
import plotly.io as pio
import torch

# Data.
# make_spiral_4class draws its noise from NumPy's global RNG, so seeding it here
# makes the plotted points (and spiral.pdf) identical on every Restart & Run All.
SEED = 0
np.random.seed(SEED)
X, y = make_spiral_4class(n_points=1000, noise=0.8)
X = X.numpy()
# squeeze(1) drops only the (N, 1) label axis, so even a single sample stays a
# 1-D array instead of collapsing to a 0-d scalar.
y = y.squeeze(1).numpy()

X_np = X
y_np = y

# Long-form table for Plotly. Labels are strings ('0', '1', '2', '3') so Plotly
# treats them as discrete categories rather than a continuous colour scale.
df = pd.DataFrame({
    'x1': X_np[:, 0],
    'x2': X_np[:, 1],
    'x3': X_np[:, 2],
    'label': y_np.astype(str),
})

# Explicit class -> colour mapping plus a fixed legend order.
color_order = ['0', '1', '2', '3']
colors = ['blue', 'green', 'orange', 'red']
color_map = dict(zip(color_order, colors))

fig = px.scatter_3d(
    df,
    x='x1',
    y='x2',
    z='x3',
    color='label',
    category_orders={'label': color_order},  # Fixed class order
    color_discrete_map=color_map,
    labels={'x1': 'x₁', 'x2': 'x₂', 'x3': 'x₃'},
)

fig.update_layout(
    margin=dict(l=0, r=0, t=20, b=0),
    scene=dict(
        xaxis_title_font=dict(size=30),
        yaxis_title_font=dict(size=30),
        zaxis_title_font=dict(size=30),
    ),
    font=dict(  # Global font setting for all text
        family='Times New Roman',
        size=15,  # Global font size
    ),
    scene_camera=dict(
        eye=dict(x=1, y=2, z=0.5)  # Camera position
    ),
    legend_title_text='Class',
    legend=dict(
        font=dict(size=22),  # Legend text
        title_font=dict(     # Legend title
            family='Times New Roman',
            size=22,
            color='black'
        )
    )
)

fig.update_traces(marker=dict(size=4))

# kaleido 0.2.x can stamp a "Loading [MathJax]..." box onto exported PDFs.
# This figure uses no LaTeX, so MathJax is disabled when the scope exists.
_kaleido_scope = getattr(pio.kaleido, 'scope', None)
if _kaleido_scope is not None:
    _kaleido_scope.mathjax = None

fig.write_image("spiral.pdf", format="pdf")
fig.show()