NB! UMAP must be installed first: run `pip install umap-learn`.

This notebook visualises the Synthetic Gaussians dataset and compares its embedding into a pre-trained AE latent space to standard dimensionality reduction techniques such as:

0) PCA https://pytorch.org/docs/stable/generated/torch.pca_lowrank.html
1) LLE https://cs.nyu.edu/~roweis/lle/papers/lleintroa4.pdf
2) t-SNE https://en.wikipedia.org/wiki/T-distributed_stochastic_neighbor_embedding
3) UMAP https://umap-learn.readthedocs.io/en/latest/

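Each point is coloured by its distance to the center of its own Gaussian cloud. How the level sets of this distance (circles within each original plane) are deformed by an embedding gives an estimate of that embedding's distortion.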

# I. Train and test datasets

In [None]:
# Directory where figures are written. No trailing slash: every save below builds
# its path as f"{Path_pictures}/<name>", so a trailing slash produced "//".
Path_pictures = "../plots"

# Hyperparameters for the Synthetic Gaussians dataset (passed to SyntheticDataset)
D = 784  # ambient dimension
#D = 3
k = 3  # number of Gaussian clouds (classes), one per d-dimensional plane in R^D
n = 6*(10**3)  # num of points in each plane
d = 2  # starting dimension of the gaussians
#n = 10**3 # num of points in each plane
shift_class = 0
intercl_var = 0.1  # initially 0.1
var_class = 1

split_ratio = 0.2  # fraction of the samples held out as the test split

# adding path to the set generating package
import os
import sys
sys.path.append('../')  # have to go 1 level up

import torch
import ricci_regularization
import matplotlib.pyplot as plt
import numpy as np
import math
from sklearn import manifold

# Create the output directory up front so the savefig calls later in the notebook
# do not fail on a fresh checkout.
os.makedirs(Path_pictures, exist_ok=True)

In [None]:
# Generate the dataset via the SyntheticDataset class, then split it into train/test.
torch.manual_seed(0)  # reproducibility
my_dataset = ricci_regularization.SyntheticDataset(k=k, n=n, d=d, D=D,
                                                   shift_class=shift_class,
                                                   intercl_var=intercl_var,
                                                   var_class=var_class)

train_dataset = my_dataset.create

m = len(train_dataset)
# Derive the train size from the test size so the two always sum to m.
# Two independent int() truncations can lose a sample when m*split_ratio is not
# an integer, and random_split then raises.
n_test = int(m * split_ratio)
train_data, test_data = torch.utils.data.random_split(train_dataset, [m - n_test, n_test])


# II. Fit dimensionality reduction models

## LLE

In [None]:
# Slice the whole held-out split at once: the feature tensor and its class labels.
[points, labels] = test_data[:]

In [None]:
# LLE of the raw test points into 2-D; synthetic_err is the reconstruction error.
# random_state fixes the random start of the eigen-solver (used when sklearn picks
# ARPACK), so the embedding is reproducible across runs.
synthetic_lle, synthetic_err = manifold.locally_linear_embedding(
    points, n_neighbors=9, n_components=2, reg=0.0001, random_state=0)

In [None]:
# Raw 2-D LLE embedding of the test points, coloured by class label.
fig, axs = plt.subplots(figsize=(8, 8))
axs.scatter(*synthetic_lle[:, :2].T, c=labels)
#axs.set_title("LLE Embedding of Synthetic Gaussians dataset")

## TSNE

In [None]:
from sklearn.manifold import TSNE

In [None]:
def discrete_cmap(N, base_cmap=None):
    """Create an N-bin discrete colormap from the specified input map.

    Parameters
    ----------
    N : int
        Number of discrete colour bins (used here with one bin per class).
    base_cmap : str, Colormap or None
        Colormap to resample; None selects Matplotlib's default colormap.

    Returns
    -------
    matplotlib.colors.Colormap
        ``base_cmap`` resampled to N entries.
    """
    # pyplot.get_cmap accepts a name, a Colormap or None and resamples to N colours.
    # The previous plt.cm.get_cmap (matplotlib.cm.get_cmap) was deprecated in
    # Matplotlib 3.7 and removed in 3.9.
    return plt.get_cmap(base_cmap, N)

## II.1. TSNE check

In [None]:
# Size of the test feature tensor (number of samples first)
test_data[:][0].size()

In [None]:
# t-SNE check on the test split: embed the raw test points into 2-D.
synthetic_points = test_data[:][0]
tsne = TSNE(random_state=123,
            n_components=2,
            verbose=1)
z_test = tsne.fit_transform(synthetic_points.numpy())

In [None]:
plt.rcParams.update({'font.size': 20})  # makes all fonts on the plot be 20
plt.figure(figsize=(12, 9), dpi=400)
# One discrete colour per class (k classes) so the colorbar shows class ids.
plt.scatter(z_test[:, 0], z_test[:, 1], c=test_data[:][1], alpha=0.5,
            cmap=discrete_cmap(k, 'jet'), marker='o', edgecolors=None)
plt.title("t-SNE projection of the \n Synthetic dataset")
plt.colorbar(ticks=range(k), orientation='vertical', shrink=0.7)
# Save next to the other figures instead of a hard-coded absolute home-directory path:
#plt.savefig(f'{Path_pictures}/TSNE_synthetic.eps', bbox_inches='tight', format='eps')
plt.show()


In [None]:
# t-SNE check on the train set -- DISABLED: the code is kept inside a string
# literal, so running this cell only evaluates the string (echoed as the output).
# NOTE(review): .view(-1, 28*28) only fits D == 784; adjust before re-enabling.
"""
synthetic_points = train_data[:][0].view(-1,28*28)

tsne   = TSNE(n_components=2, verbose=1, random_state=123)
z_train = tsne.fit_transform(synthetic_points.numpy())
"""

In [None]:
# DISABLED plot of the train-set t-SNE embedding (needs z_train, which the
# disabled train-set t-SNE cell would define); kept as a string literal.
"""
plt.scatter( z_train[:,0], z_train[:,1], c=train_data[:][1], alpha=0.5 )
plt.title( "TSNE projection of train data")
plt.show()
"""

### Plots using Seaborn

In [None]:
# DISABLED seaborn version of the test-set t-SNE plot, kept as a string literal.
# NOTE(review): the palette has 10 colours although the dataset has k classes.
"""
import pandas as pd 

# Format data
df = pd.DataFrame()
#df["y"] = labels.numpy()
df["y"] = test_data[:][1].numpy() #test_data[:][1] are labels
df["comp-1"] = z_test[:,0]
df["comp-2"] = z_test[:,1]

import seaborn as sns
import numpy as np # this module is useful to work with numerical arrays
sns.scatterplot(x="comp-1", y="comp-2", hue=df.y.tolist(),
                palette=sns.color_palette("hls", 10),
                data=df).set(title="Synthetic dataset data T-SNE projection")
"""

## II.2. UMAP

In [None]:
# UMAP is fitted on the same held-out split: features and their class labels.
points, labels = test_data[:]

In [None]:
# Size of the points passed to UMAP
points.size()

In [None]:
!pip install umap-learn

In [None]:
import umap

# Fit UMAP on the raw test points. A fixed random_state makes the embedding
# reproducible, like the seeded t-SNE run. The UMAP docs note that this disables
# parallelism, so fitting is slower.
mapper = umap.UMAP(random_state=123).fit(points)


In [None]:
# Low-dimensional coordinates the fitted UMAP model assigns to `points`
# (one row per test point; columns 0 and 1 are what gets plotted).
encoded_points = mapper.embedding_

In [None]:
plt.rcParams.update({'font.size': 20})  # makes all fonts on the plot be 20
plt.figure(figsize=(12, 9), dpi=400)
# One discrete colour per class (k classes) so the colorbar shows class ids.
plt.scatter(encoded_points[:, 0], encoded_points[:, 1], c=test_data[:][1], alpha=0.5,
            cmap=discrete_cmap(k, 'jet'), marker='o', edgecolors=None)
plt.title("UMAP embedding of the \n Synthetic dataset")
plt.colorbar(ticks=range(k), orientation='vertical', shrink=0.7)
# Save next to the other figures instead of a hard-coded absolute home-directory path:
#plt.savefig(f'{Path_pictures}/UMAP_synthetic.eps', bbox_inches='tight', format='eps')
plt.show()

# III. 3D visualization

In [None]:
# DISABLED regeneration of a D = 3 dataset, kept as a string literal for reference.
"""
D = 3
k = 3 # num of 2d planes in dim D
n = 6*(10**3) # num of points in each plane
d = 2 # starting dimention of gaussians
#n = 10**3 # num of points in each plane
shift_class = 0
intercl_var = 1 #initially 0.1
var_class = 0.1
torch.manual_seed(0) # reproducibility
my_dataset = ricci_regularization.SyntheticDataset(k=k,n=n,d=d,D=D,
                                    shift_class=shift_class, intercl_var=intercl_var, var_class = var_class)

train_dataset = my_dataset.create

split_ratio = 0.2
m = len(train_dataset)
train_data, test_data = torch.utils.data.random_split(train_dataset, [int(m-m*split_ratio), int(m*split_ratio)])
"""
# Scatter the raw test points in their ambient space -- only possible when D == 3.
if D == 3:
    points = test_data[:][0].squeeze()
    labels = test_data[:][1]
    # Set the font size before creating the figure so it applies to all of its text.
    plt.rcParams.update({'font.size': 20})
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    ax.scatter(points[:, 0], points[:, 1], points[:, 2],
               c=labels, s=30, alpha=0.5,
               # plt.get_cmap replaces plt.cm.get_cmap, removed in Matplotlib 3.9
               cmap=plt.get_cmap('jet', k))
    plt.title("Synthetic dataset in 3d")
    ax.view_init(azim=155, elev=15)
    plt.show()

# Histograms

In [None]:
import numpy as np
import math
from collections import defaultdict

### Synthetic dataset check
# Single pass over the training set: keep every sample and group samples by their
# integer label, so the per-plane statistics below need no further passes.
points_by_plane = defaultdict(list)
all_points = []
for tensor, label in train_dataset:
    points_by_plane[int(label)].append(tensor)
    all_points.append(tensor)

# --- Squared deviations of one plane's samples from that plane's own shift ---
plane_idx = 0
# One row per sample (squeeze drops singleton axes, as before).
points_in_0th_Gaussian = torch.stack(points_by_plane[plane_idx]).squeeze()
# Use the shift of the same plane the samples came from (was hard-coded to shifts[0]).
shift_1 = my_dataset.shifts[plane_idx]
# reshape(1, -1) turns the shift into a row vector that broadcasts over all samples.
deviations_squared = (points_in_0th_Gaussian - shift_1.reshape(1, -1)).norm(dim=1)**2

fig, ax = plt.subplots()
ax.set_title(f"Squares of l2 norms of deviations in plane {plane_idx} \nof the dataset")
ax.hist(deviations_squared, bins=round(math.sqrt(n)))
ax.set_xlabel("Squared l2 norm of deviation")
# Backslashes in the LaTeX are doubled: a bare backslash-s is an invalid escape
# sequence in a normal string and warns on Python >= 3.12.
fig.text(0.0, -0.15, f"Mean square of l2 norms of deviations:{deviations_squared.mean().item():.4f} \nSet params: n={n}, k={k}, d={d}, D={D}, $\\sigma$={var_class}, $\\sigma_{{I}}$={intercl_var}.")
plt.show()

# --- Squared norms of all training points ---
all_points = torch.stack(all_points).squeeze()
norms_squared = all_points.norm(dim=1)**2

fig, ax = plt.subplots()
ax.set_title("Squares of l2 norms of all points of the dataset")
ax.hist(norms_squared, bins=round(math.sqrt(n)))
ax.set_xlabel("Squared l2 norm of a point")
fig.text(0.0, -0.15, f"Mean square of l2 norms of points in the set: {norms_squared.mean().item():.4f} \nSet params: n={n}, k={k}, d={d}, D={D}, $\\sigma$={var_class}, $\\sigma_{{I}}$={intercl_var}.")
plt.show()

# --- Per-plane sample mean and covariance compared with their estimates ---
for plane_idx in range(k):
    matrix_of_Gaussian_samples = torch.stack(points_by_plane[plane_idx]).squeeze()
    mean = matrix_of_Gaussian_samples.mean(dim=0)

    print(f"\n Plane {plane_idx}.")
    print(f"The L2 norm of the mean of samples from plane {plane_idx}:\n {mean.norm()}")
    print(f"To be compared to its estimate (the norm of the random shift):\n {my_dataset.shifts[plane_idx].norm()} ~sqrt(D): {math.sqrt(D)}")

    # torch.cov expects variables in rows, hence the transpose.
    cov_matrix = torch.cov(matrix_of_Gaussian_samples.T)
    print(f"Frobenius norm of the covariance matrix of samples from plane {plane_idx}:\n {cov_matrix.norm()}")
    print(f'To be compared to its estimate ~var_class*sqrt(d):\n {var_class*math.sqrt(d)}')

# Distance from the mean Heatmaps

In [None]:
# Per-class shifts (cloud centers) of the dataset; indexed by integer label below.
shifts = my_dataset.shifts

In [None]:
data_for_plot = test_data

#latent = encoder(data_for_plot[:][0].squeeze()).detach()
labels = data_for_plot[:][1]
int_labels = labels.to(int)
init_data = data_for_plot[:][0]
# Look up every point's cloud center with one gather instead of a per-point
# Python loop plus a NumPy round-trip; squeeze leaves one row per point.
centers_tensor = torch.stack(list(shifts))[int_labels].squeeze()
# Euclidean distance of each test point to the center of its own cloud.
distances = torch.norm(init_data - centers_tensor, dim=1)

In [None]:
# Inspect the test-split labels that select each point's center
labels

In [None]:
# Inspect the gathered centers (one row per test point)
centers_tensor

In [None]:
# t-SNE embedding coloured by each point's distance to its own cloud center.
plt.figure(dpi=400, figsize=(9, 9))
plt.scatter(*z_test[:, :2].T,
            c=distances, s=40, alpha=0.5,
            marker='o', edgecolor='none', cmap='jet')
# use for logscale: norm=matplotlib.colors.LogNorm()
#plt.title( "TSNE embedding of the \n Synthetic dataset")
#plt.colorbar(label="Distance to cloud center",orientation='vertical',shrink = 0.7)
plt.show()

### 3 colormaps t-SNE

In [None]:
import pandas as pd
plt.rcParams.update({'font.size': 20})  # makes all fonts on the plot be 20
# Table layout:
#   columns 0..latent_dim-1 -> t-SNE coordinates
#   column latent_dim       -> class label
#   column latent_dim+1     -> distance to the cloud center
latent_dim = z_test.shape[1]
latent_labels_distances = torch.cat((torch.tensor(z_test),
                                     labels.unsqueeze(1),
                                     distances.unsqueeze(1)),
                                     dim=1)
my_dataframe = pd.DataFrame(latent_labels_distances)
cmaps = ["jet","hsv","twilight"]  # one colormap per cloud
#cmaps = ["jet","plasma","twilight"]
#cmaps = ["jet","jet","jet"]
# Colorbar layout; only used by the commented-out fig.colorbar call below.
colorbar_locations = ["right","bottom","left"]
colorbar_orientations = ["vertical","horizontal","vertical"]
colorbar_shrinks = [0.5,0.5,0.5]
colorbar_anchors = [(0.5,0.75),(0.75,0.5),(0.5,0.5)]

fig, ax = plt.subplots(figsize=(9,9),dpi=400)
#plt.title("t-SNE embedding for the Synthetic dataset")
for plane_idx in range(k):
    # Rows of this cloud. The label column sits right after the coordinates; it was
    # previously indexed with the dataset parameter `d`, which matched only by coincidence.
    results_df = my_dataframe.loc[my_dataframe[latent_dim] == plane_idx]
    # Drop the label column: the coordinates remain, followed by the distance.
    latent_points_in_plane = torch.tensor(results_df.loc[:, results_df.columns != latent_dim].values)
    p = ax.scatter(latent_points_in_plane[:, 0], latent_points_in_plane[:, 1],
                   c=latent_points_in_plane[:, latent_dim],
                   alpha=0.5, marker='o', edgecolor='none', cmap=cmaps[plane_idx])
    #fig.colorbar(p, label=f"Distance to the center of cloud {plane_idx}", orientation=colorbar_orientations[plane_idx],shrink = colorbar_shrinks[plane_idx],location = colorbar_locations[plane_idx],pad = 0.05, anchor = colorbar_anchors[plane_idx])
fig.savefig(f'{Path_pictures}/t-SNE_distance_to_means_3heatmaps.pdf',bbox_inches='tight',format='pdf')
fig.savefig(f'{Path_pictures}/t-SNE_distance_to_means_3heatmaps.png',bbox_inches='tight',format='png')
# plt.show() rather than fig.show(): the latter warns on the non-GUI inline backend.
plt.show()

### UMAP

In [None]:
# UMAP embedding coloured by each point's distance to its own cloud center.
plt.figure(dpi=400, figsize=(9, 9))
plt.scatter(*encoded_points[:, :2].T,
            c=distances, s=40, alpha=0.5,
            marker='o', edgecolor='none', cmap='jet')
# use for logscale: norm=matplotlib.colors.LogNorm()
#plt.title( "UMAP embedding of the \n Synthetic dataset")
#plt.colorbar(label="Distance to cloud center",orientation='vertical',shrink = 0.7)
plt.show()

In [None]:
import pandas as pd
plt.rcParams.update({'font.size': 20})  # makes all fonts on the plot be 20
# Table layout:
#   columns 0..latent_dim-1 -> UMAP coordinates
#   column latent_dim       -> class label
#   column latent_dim+1     -> distance to the cloud center
latent_dim = encoded_points.shape[1]
latent_labels_distances = torch.cat((torch.tensor(encoded_points),labels.unsqueeze(1),distances.unsqueeze(1)),dim=1)
my_dataframe = pd.DataFrame(latent_labels_distances)
cmaps = ["jet","hsv","twilight"]  # one colormap per cloud
#cmaps = ["jet","plasma","twilight"]
#cmaps = ["jet","jet","jet"]
# Colorbar layout; only used by the commented-out fig.colorbar call below.
colorbar_locations = ["right","bottom","left"]
colorbar_orientations = ["vertical","horizontal","vertical"]
colorbar_shrinks = [0.5,0.5,0.5]
colorbar_anchors = [(0.5,0.75),(0.75,0.5),(0.5,0.5)]

fig, ax = plt.subplots(figsize=(9,9),dpi=400)
#plt.title("UMAP embedding for the Synthetic dataset")
for plane_idx in range(k):
    # Rows of this cloud; the label column index is the latent dimension (not `d`).
    results_df = my_dataframe.loc[my_dataframe[latent_dim] == plane_idx]
    # Drop the label column: the coordinates remain, followed by the distance.
    latent_points_in_plane = torch.tensor(results_df.loc[:, results_df.columns != latent_dim].values)
    p = ax.scatter(latent_points_in_plane[:, 0], latent_points_in_plane[:, 1],
                   c=latent_points_in_plane[:, latent_dim],
                   alpha=0.5, marker='o', edgecolor='none', cmap=cmaps[plane_idx])
    #fig.colorbar(p, label=f"Distance to the center of cloud {plane_idx}", orientation=colorbar_orientations[plane_idx],shrink = colorbar_shrinks[plane_idx],location = colorbar_locations[plane_idx],pad = 0.05, anchor = colorbar_anchors[plane_idx])
fig.savefig(f'{Path_pictures}/UMAP_distance_to_means_3heatmaps.pdf',bbox_inches='tight',format='pdf')
# plt.show() rather than fig.show(): the latter warns on the non-GUI inline backend.
plt.show()

### PCA

In [None]:
# Rank-2 PCA of the raw test points via torch.pca_lowrank, which returns (U, S, V).
A, labels = test_data[:][0], test_data[:][1]
u, s, v = torch.pca_lowrank(A, q=2)

In [None]:
import pandas as pd
plt.rcParams.update({'font.size': 20})  # makes all fonts on the plot be 20
# PCA scores: torch.pca_lowrank factors the centred data as U diag(S) V^T, so the
# projection onto the top principal axes is U * S. Plotting U alone rescales every
# component to unit norm and distorts the relative spread of the two axes.
pca_embedding = u * s
latent_dim = pca_embedding.shape[1]
# Table layout:
#   columns 0..latent_dim-1 -> PCA scores
#   column latent_dim       -> class label
#   column latent_dim+1     -> distance to the cloud center
latent_labels_distances = torch.cat((pca_embedding,labels.unsqueeze(1),distances.unsqueeze(1)),dim=1)
my_dataframe = pd.DataFrame(latent_labels_distances)
cmaps = ["jet","hsv","twilight"]  # one colormap per cloud
#cmaps = ["jet","plasma","twilight"]
#cmaps = ["jet","jet","jet"]
# Colorbar layout; only used by the commented-out fig.colorbar call below.
colorbar_locations = ["right","bottom","left"]
colorbar_orientations = ["vertical","horizontal","vertical"]
colorbar_shrinks = [0.5,0.5,0.5]
colorbar_anchors = [(0.5,0.75),(0.75,0.5),(0.5,0.5)]

fig, ax = plt.subplots(figsize=(9,9),dpi=400)
#plt.title("PCA embedding for the Synthetic dataset")
for plane_idx in range(k):
    # Rows of this cloud; the label column index is the latent dimension (not `d`).
    results_df = my_dataframe.loc[my_dataframe[latent_dim] == plane_idx]
    # Drop the label column: the coordinates remain, followed by the distance.
    latent_points_in_plane = torch.tensor(results_df.loc[:, results_df.columns != latent_dim].values)
    p = ax.scatter(latent_points_in_plane[:, 0], latent_points_in_plane[:, 1],
                   c=latent_points_in_plane[:, latent_dim],
                   alpha=0.5, marker='o', edgecolor='none', cmap=cmaps[plane_idx])
    #fig.colorbar(p, label=f"Distance to the center of cloud {plane_idx}", orientation=colorbar_orientations[plane_idx],shrink = colorbar_shrinks[plane_idx],location = colorbar_locations[plane_idx],pad = 0.05, anchor = colorbar_anchors[plane_idx])
fig.savefig(f'{Path_pictures}/PCA_distance_to_means_3heatmaps.pdf',bbox_inches='tight',format='pdf')
# plt.show() rather than fig.show(): the latter warns on the non-GUI inline backend.
plt.show()

## LLE

In [None]:

import pandas as pd
plt.rcParams.update({'font.size': 20})  # makes all fonts on the plot be 20
# Table layout:
#   columns 0..latent_dim-1 -> LLE coordinates
#   column latent_dim       -> class label
#   column latent_dim+1     -> distance to the cloud center
latent_dim = synthetic_lle.shape[1]
latent_labels_distances = torch.cat((torch.tensor(synthetic_lle),labels.unsqueeze(1),distances.unsqueeze(1)),dim=1)
my_dataframe = pd.DataFrame(latent_labels_distances)
cmaps = ["jet","hsv","twilight"]  # one colormap per cloud
#cmaps = ["jet","plasma","twilight"]
#cmaps = ["jet","jet","jet"]
# Colorbar layout; only used by the commented-out fig.colorbar call below.
colorbar_locations = ["right","bottom","left"]
colorbar_orientations = ["vertical","horizontal","vertical"]
colorbar_shrinks = [0.5,0.5,0.5]
colorbar_anchors = [(0.5,0.75),(0.75,0.5),(0.5,0.5)]

fig, ax = plt.subplots(figsize=(9,9),dpi=400)
#plt.title("LLE embedding for the Synthetic dataset")
for plane_idx in range(k):
    # Rows of this cloud; the label column index is the latent dimension (not `d`).
    results_df = my_dataframe.loc[my_dataframe[latent_dim] == plane_idx]
    # Drop the label column: the coordinates remain, followed by the distance.
    latent_points_in_plane = torch.tensor(results_df.loc[:, results_df.columns != latent_dim].values)
    p = ax.scatter(latent_points_in_plane[:, 0], latent_points_in_plane[:, 1],
                   c=latent_points_in_plane[:, latent_dim],
                   alpha=0.5, marker='o', edgecolor='none', cmap=cmaps[plane_idx])
    #fig.colorbar(p, label=f"Distance to the center of cloud {plane_idx}", orientation=colorbar_orientations[plane_idx],shrink = colorbar_shrinks[plane_idx],location = colorbar_locations[plane_idx],pad = 0.05, anchor = colorbar_anchors[plane_idx])
fig.savefig(f'{Path_pictures}/LLE_distance_to_means_3heatmaps.pdf',bbox_inches='tight',format='pdf')
# plt.show() rather than fig.show(): the latter warns on the non-GUI inline backend.
plt.show()

## Distance to means in 3D

In [None]:
# Regenerate a D = 3 dataset so the clouds can be drawn directly in 3-D, coloured
# by each point's distance to its cloud center.
# NOTE: this overwrites the global D, k, n, d, ... used by the earlier cells.
D = 3
k = 3 # number of Gaussian clouds (classes)
n = 6*(10**3) # num of points in each plane
d = 2 # starting dimension of gaussians
#n = 10**3 # num of points in each plane
shift_class = 0
intercl_var = 2 #initially 0.1
var_class = 0.75
torch.manual_seed(0) # reproducibility
my_dataset = ricci_regularization.SyntheticDataset(k=k,n=n,d=d,D=D,
                                    shift_class=shift_class, intercl_var=intercl_var, var_class = var_class)

train_dataset = my_dataset.create

split_ratio = 0.2
m = len(train_dataset)
# m - n_test keeps the split sizes summing to m, which random_split requires.
n_test = int(m * split_ratio)
train_data, test_data = torch.utils.data.random_split(train_dataset, [m - n_test, n_test])

shifts = my_dataset.shifts

data_for_plot = test_data

labels = data_for_plot[:][1]
int_labels = labels.to(int)
init_data = data_for_plot[:][0]
# Gather every point's cloud center in one indexing op; one row per point.
centers_tensor = torch.stack(list(shifts))[int_labels].squeeze()
distances = torch.norm(init_data - centers_tensor, dim=1)

# 3-D scatter of the raw test points (D == 3 here)
points = test_data[:][0].squeeze()
labels = test_data[:][1]
plt.rcParams.update({'font.size': 20})
fig = plt.figure(figsize=(12,9), dpi=400)
ax = fig.add_subplot(projection='3d')

plot = ax.scatter(points[:,0],
                  points[:,1],
                  points[:,2],
                  c=distances, s=15, alpha=0.5, cmap='jet')
plt.colorbar(plot,label="Distance to cloud center",orientation='vertical',shrink = 0.5,location='left')
#plt.title("Synthetic dataset in 3d")
ax.view_init(azim=145, elev=15)

# Transparent panes with white ("invisible") edges
for pane_axis in (ax.xaxis, ax.yaxis, ax.zaxis):
    pane_axis.pane.fill = False
    pane_axis.pane.set_edgecolor('w')

# Grid lines are kept; pass False to remove them as well.
ax.grid(True)
#plt.savefig(f'{Path_pictures}/synthetic_3d.png',bbox_inches='tight',format='png')

plt.show()

### 3 colormaps in 3d

In [None]:
import pandas as pd
# Table layout:
#   columns 0..n_coords-1 -> 3-D coordinates
#   column n_coords       -> class label
#   column n_coords+1     -> distance to the cloud center
n_coords = points.shape[1]
latent_labels_distances = torch.cat((points,labels.unsqueeze(1),distances.unsqueeze(1)),dim=1)
my_dataframe = pd.DataFrame(latent_labels_distances)
cmaps = ["jet","hsv","twilight"]  # one colormap per cloud
#cmaps = ["jet","plasma","twilight"]
#cmaps = ["jet","jet","jet"]
# Colorbar layout; only used by the commented-out fig.colorbar call below.
colorbar_locations = ["right","bottom","left"]
colorbar_orientations = ["vertical","horizontal","vertical"]
colorbar_shrinks = [0.5,0.5,0.5]
colorbar_anchors = [(0.5,0.75),(0.75,0.5),(0.5,0.5)]

#plt.title("t-SNE embedding for the Synthetic dataset")
plt.rcParams.update({'font.size': 20})
fig = plt.figure(figsize=(12,9), dpi=400)
ax = fig.add_subplot(projection='3d')

for plane_idx in range(k):
    # Rows of this cloud; the label column follows the coordinate columns.
    results_df = my_dataframe.loc[my_dataframe[n_coords] == plane_idx]
    # Drop the label column: the coordinates remain, followed by the distance.
    latent_points_in_plane = torch.tensor(results_df.loc[:, results_df.columns != n_coords].values)
    p = ax.scatter(latent_points_in_plane[:, 0],
                   latent_points_in_plane[:, 1],
                   latent_points_in_plane[:, 2],
                   c=latent_points_in_plane[:, n_coords],
                   alpha=0.5, marker='o', s=15, edgecolor='none',
                   cmap=cmaps[plane_idx])
    #fig.colorbar(p, label=f"Distance to the center of cloud {plane_idx}", orientation=colorbar_orientations[plane_idx],shrink = colorbar_shrinks[plane_idx],location = colorbar_locations[plane_idx],pad = 0.05, anchor = colorbar_anchors[plane_idx])
# The view only needs setting once (it was re-applied on every loop iteration).
ax.view_init(azim=145, elev=15)

# Transparent panes with white ("invisible") edges
for pane_axis in (ax.xaxis, ax.yaxis, ax.zaxis):
    pane_axis.pane.fill = False
    pane_axis.pane.set_edgecolor('w')

# Grid lines are kept; pass False to remove them as well.
ax.grid(True)
fig.savefig(f'{Path_pictures}/3d_distance_to_means_3heatmaps.pdf',bbox_inches='tight',format='pdf')
fig.savefig(f'{Path_pictures}/3d_distance_to_means_3heatmaps.png',bbox_inches='tight',format='png')
# plt.show() rather than fig.show(): the latter warns on the non-GUI inline backend.
plt.show()

In [None]:
# Reference picture with D = d = 2: the raw points are plotted directly, with no
# embedding step. NOTE: this overwrites the global D, k, n, d, ... again.
D = 2
k = 3 # number of Gaussian clouds (classes)
n = 6*(10**3) # num of points in each plane
d = 2 # starting dimension of gaussians
#n = 10**3 # num of points in each plane
shift_class = 0.0
intercl_var = 1/3 #initially 0.1
var_class = 1/81
torch.manual_seed(7) # reproducibility
my_dataset = ricci_regularization.SyntheticDataset(k=k,n=n,d=d,D=D,
                                    shift_class=shift_class, intercl_var=intercl_var, var_class = var_class)

train_dataset = my_dataset.create

split_ratio = 0.2
m = len(train_dataset)
# m - n_test keeps the split sizes summing to m, which random_split requires.
n_test = int(m * split_ratio)
train_data, test_data = torch.utils.data.random_split(train_dataset, [m - n_test, n_test])

shifts = my_dataset.shifts

data_for_plot = test_data

labels = data_for_plot[:][1]
int_labels = labels.to(int)
init_data = data_for_plot[:][0]
# Gather every point's cloud center in one indexing op; one row per point.
centers_tensor = torch.stack(list(shifts))[int_labels].squeeze()
# Distances are rescaled by 3/0.4 = 7.5. The reason is not stated in the notebook;
# presumably it matches the colour range of other figures -- TODO confirm.
distance_scale = 3/0.4
distances = distance_scale*torch.norm(init_data - centers_tensor, dim=1)

# 2-D scatter of the raw test points (D == 2 here)
points = test_data[:][0].squeeze()
labels = test_data[:][1]
plt.rcParams.update({'font.size': 20})
plt.figure(figsize=(9,9), dpi=400)

plt.scatter(points[:,0],
            points[:,1],
            c=distances, s=15, alpha=0.5, cmap='jet')
plt.xlim((-1,1))
plt.ylim((-1,1))
#plt.colorbar(label="Distance to cloud center",orientation='vertical',shrink = 0.5,location='left')
#plt.title("Synthetic dataset in 2d")

#plt.savefig(f'{Path_pictures}/synthetic_2d_ideal.png',bbox_inches='tight',format='png')

plt.show()

In [None]:
import pandas as pd
plt.rcParams.update({'font.size': 20})  # makes all fonts on the plot be 20
# Table layout:
#   columns 0..n_coords-1 -> 2-D coordinates
#   column n_coords       -> class label
#   column n_coords+1     -> distance to the cloud center
n_coords = points.shape[1]
latent_labels_distances = torch.cat((points,labels.unsqueeze(1),distances.unsqueeze(1)),dim=1)
my_dataframe = pd.DataFrame(latent_labels_distances)
cmaps = ["jet","hsv","twilight"]  # one colormap per cloud
#cmaps = ["jet","plasma","twilight"]
#cmaps = ["jet","jet","jet"]
colorbar_locations = ["left","left","left"]
colorbar_orientations = ["vertical","vertical","vertical"]
colorbar_shrinks = [0.5,0.5,0.5]
colorbar_anchors = [(0.5,0.5),(0.5,0.5),(0.5,0.5)]

fig, ax = plt.subplots(figsize=(9,9),dpi=400)
#plt.title("Desired AE latent space for the Synthetic dataset")
# Hide the y ticks once; presumably so they do not clash with the colorbars
# stacked on the left.
ax.set_yticks([])
for plane_idx in range(k):
    # Rows of this cloud; the label column follows the coordinate columns.
    results_df = my_dataframe.loc[my_dataframe[n_coords] == plane_idx]
    # Drop the label column: the coordinates remain, followed by the distance.
    latent_points_in_plane = torch.tensor(results_df.loc[:, results_df.columns != n_coords].values)
    p = ax.scatter(latent_points_in_plane[:, 0], latent_points_in_plane[:, 1],
                   c=latent_points_in_plane[:, n_coords],
                   alpha=0.5, marker='o', edgecolor='none', cmap=cmaps[plane_idx])
    fig.colorbar(p, label=f"Distance to the center of cloud {plane_idx}",
                 orientation=colorbar_orientations[plane_idx],
                 shrink = colorbar_shrinks[plane_idx],
                 location = colorbar_locations[plane_idx],
                 pad = 0.25, anchor = colorbar_anchors[plane_idx])
ax.set_xlim((-1,1))
ax.set_ylim((-1,1))
fig.savefig(f'{Path_pictures}/ideal_synthetic_distance_to_means_3heatmaps.png',bbox_inches='tight',format='png')
# plt.show() rather than fig.show(): the latter warns on the non-GUI inline backend.
plt.show()