# I. Train and test datasets

In [None]:
# ---- Synthetic dataset hyperparameters ----
D = 3              # ambient dimension (an alternative setting used D = 784)
k = 3              # number of 2d planes placed in dimension D
n = 6_000          # number of points in each plane (a smaller run used n = 10**3)
d = 2              # presumably the intrinsic dimension of each plane — TODO confirm
shift_class = 0    # forwarded to ricci_regularization.SyntheticDataset
intercl_var = 2    # forwarded to ricci_regularization.SyntheticDataset

# ---- Data-loader hyperparameters ----
batch_size = 16
split_ratio = 0.8

# Reproducibility: the torch seed is set in the dataset-generation cell,
# immediately before the data is sampled.

In [None]:
# Make the parent directory importable so the local `ricci_regularization` package
# (the synthetic-set generator) can be found. Assumes the notebook runs from its own folder.
import sys

# Guard so re-running this cell does not keep appending the same entry to sys.path.
if '../' not in sys.path:
    sys.path.append('../')  # have to go 1 level up

import ricci_regularization

In [None]:
import torch

# Generate the synthetic dataset: k planes in dimension D, n points per plane
# (see the hyperparameter cell). `ricci_regularization` is imported in the cell above.
torch.manual_seed(0)  # reproducibility; also fixes the permutation used by random_split below
my_dataset = ricci_regularization.SyntheticDataset(k=k, n=n, d=d, D=D,
                                                   shift_class=shift_class,
                                                   intercl_var=intercl_var)

# NOTE(review): `create` is read as an attribute, not called; presumably a property
# returning the generated torch Dataset — confirm against SyntheticDataset.
# Despite its name, `train_dataset` holds the FULL dataset before splitting.
train_dataset = my_dataset.create

# A `split_ratio` fraction of the points goes to training, the remainder to test.
# Deriving n_test from n_train guarantees the sizes sum to m; two independently
# truncated int(...) values could sum to m - 1, which random_split rejects.
m = len(train_dataset)
n_train = int(m * split_ratio)
n_test = m - n_train
train_data, test_data = torch.utils.data.random_split(train_dataset, [n_train, n_test])

test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)

# test_data[:][0] gives the data vectors (without labels) of the test split;
# test_data[:][1] gives the corresponding labels.

# II. Fit dimension reduction models

In [None]:
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

## II.1. t-SNE check

In [None]:
# t-SNE sanity check on the test split: flatten the samples into D-dimensional
# vectors, then embed them in 2D (fixed random_state so the layout is repeatable).
synthetic_points = test_data[:][0].view(-1, D)
tsne = TSNE(
    random_state=123,
    n_components=2,
    verbose=1,
)
z_test = tsne.fit_transform(synthetic_points.numpy())

In [None]:
# Scatter the 2D t-SNE embedding of the test split, coloured by class label.
plt.rcParams.update({'font.size': 20})  # global: makes all fonts on this and later plots 20pt
fig, ax = plt.subplots()
ax.scatter(z_test[:, 0], z_test[:, 1], c=test_data[:][1], alpha=0.5)
ax.set(title="TSNE projection of the Synthetic data-set",
       xlabel="t-SNE component 1",
       ylabel="t-SNE component 2")
plt.show()

In [None]:
# t-SNE check on the train split. Disabled by default; set RUN_TSNE_TRAIN = True to run it.
# Samples are flattened with the configured dimension D (not a hard-coded 28*28),
# so the cell works for any ambient dimension.
RUN_TSNE_TRAIN = False
if RUN_TSNE_TRAIN:
    synthetic_points = train_data[:][0].view(-1, D)
    tsne = TSNE(n_components=2, verbose=1, random_state=123)
    z_train = tsne.fit_transform(synthetic_points.numpy())

In [None]:
"""
plt.scatter( z_train[:,0], z_train[:,1], c=train_data[:][1], alpha=0.5 )
plt.title( "TSNE projection of train data")
plt.show()
"""

## II.2. Plots using Seaborn

In [None]:
"""
import pandas as pd 

# Format data
df = pd.DataFrame()
#df["y"] = labels.numpy()
df["y"] = test_data[:][1].numpy() #test_data[:][1] are labels
df["comp-1"] = z_test[:,0]
df["comp-2"] = z_test[:,1]

import seaborn as sns
import numpy as np # this module is useful to work with numerical arrays
sns.scatterplot(x="comp-1", y="comp-2", hue=df.y.tolist(),
                palette=sns.color_palette("hls", 10),
                data=df).set(title="Synthetic dataset data T-SNE projection")
"""

# III. 3D visualization

In [None]:
# 3D scatter of the test split; only possible when the ambient dimension is 3.
if D == 3:
    points = test_data[:][0].squeeze()
    labels = test_data[:][1]
    # Set the font size before creating the figure so it applies to all of its text.
    plt.rcParams.update({'font.size': 20})
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    # plt.get_cmap(name, lut) replaces matplotlib.cm.get_cmap (removed in matplotlib 3.9);
    # lut=k gives a discrete 'jet' map with k colours (k = number of planes).
    ax.scatter(points[:, 0], points[:, 1], points[:, 2],
               c=labels, s=30, alpha=0.5, cmap=plt.get_cmap('jet', k))
    ax.set_title("Synthetic dataset in 3d")
    ax.view_init(azim=160, elev=55)
    plt.show()