# I. Train and test datasets

In [None]:
# Dataset hyperparameters

D = 28 * 28  # ambient dimension (784); samples are later reshaped to 1x28x28
k = 3        # number of 2-D planes embedded in the D-dimensional space
n = 6_000    # number of points sampled in each plane
# alternative: n = 10**3 points per plane

batch_size = 16    # batch size used by the train/test data loaders
split_ratio = 0.2  # fraction of the samples held out as the test split
shift_class = 0    # offset of the Gaussian mean; plane i is shifted by shift_class*(i+1)

# For reproducible runs, seed the generator here:
# torch.manual_seed(0)

In [None]:
import torch
import torchvision

# Build one random 2-D plane per class. Each plane is represented by an
# orthonormal basis: the Q factor of the reduced QR decomposition of a
# random (D, 2) matrix, so every entry of `phi` has shape (D, 2).
phi = []  # list of k orthonormal bases, one per plane
for _ in range(k):
    rand_vectors = torch.rand(D, 2)
    # torch.qr is deprecated; torch.linalg.qr with its default
    # mode="reduced" returns the same (Q, R) pair.
    q, r = torch.linalg.qr(rand_vectors)
    phi.append(q)

# Sample n points from a 2-D Gaussian inside each plane and lift them into
# the D-dimensional ambient space through that plane's orthonormal basis.
# (An optional normalisation of the samples was tried and is disabled.)
data = []
for plane_idx, basis in enumerate(phi):
    mean = torch.zeros(2) + shift_class * (plane_idx + 1)
    gaussian = torch.distributions.multivariate_normal.MultivariateNormal(
        mean, torch.eye(2)
    )
    coords = gaussian.sample(sample_shape=(n,)).T  # (2, n) in-plane coordinates
    data.append(torch.matmul(basis, coords))       # (D, n) embedded points

# Stack the planes side by side -> (D, k*n), put samples first -> (k*n, D),
# then view every sample as a single-channel 28x28 image (requires D == 784).
data_tensor = torch.cat(data, dim=1).T.reshape(k * n, 1, 28, 28)

# Class labels: the n samples drawn from plane i all get the label i.
# The labels are float32 tensors (i * ones), in the same order as data_tensor.
labels_list = [class_idx * torch.ones(n) for class_idx in range(k)]
labels = torch.cat(labels_list)

# Pair every sample with its label; `my_set` and `train_dataset` are two
# names for the same TensorDataset object.
train_dataset = my_set = torch.utils.data.TensorDataset(data_tensor, labels)

train_transform = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor()]
)

# NOTE(review): this only sets an attribute on the dataset object.
# TensorDataset presumably never reads `transform` when indexed, so the
# transform is likely not applied to the samples - confirm.
train_dataset.transform = train_transform

m = len(train_dataset)

# Split sizes: compute the test size once and give the remainder to the
# training split. Truncating both sizes independently with int() can make
# them sum to m - 1 (e.g. m=7, split_ratio=0.2 -> 5 + 1), in which case
# random_split raises because the lengths do not cover the dataset.
test_size = int(m * split_ratio)
train_size = m - test_size

train_data, test_data = torch.utils.data.random_split(train_dataset, [train_size, test_size])

# Only the training loader shuffles; the test loader keeps a fixed order.
test_loader  = torch.utils.data.DataLoader(test_data, batch_size=batch_size)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)

# test_data[:][0] gives the data vectors (without labels) of the test split;
# test_data[:][1] gives the matching labels.

## TSNE check

In [None]:
from sklearn.manifold import TSNE

# t-SNE check on the test set: flatten every 1x28x28 test sample back to a
# 784-dimensional vector and project the vectors down to 2 dimensions.
# (To project the whole dataset instead: data_tensor.reshape(-1, 28*28).)
test_vectors = test_data[:][0]
synthetic_set = test_vectors.view(-1, 28 * 28)

tsne = TSNE(n_components=2, verbose=1, random_state=123)
z = tsne.fit_transform(synthetic_set.numpy())

In [None]:
import pandas as pd 

# Collect the test labels and the two t-SNE components in one frame for plotting.
# (To plot the whole dataset instead, use labels.numpy() for the "y" column.)
df = pd.DataFrame(
    {
        "y": test_data[:][1].numpy(),  # test_data[:][1] are the labels
        "comp-1": z[:, 0],
        "comp-2": z[:, 1],
    }
)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt # plotting library
import numpy as np # this module is useful to work with numerical arrays


# Scatter plot of the 2-D t-SNE embedding, coloured by class label.
# Size the palette to the number of classes actually present: a fixed
# 10-colour "hls" palette has more entries than there are classes, which
# makes seaborn warn and leaves the classes with adjacent, similar hues.
n_classes = df["y"].nunique()
sns.scatterplot(x="comp-1", y="comp-2", hue=df.y.tolist(),
                palette=sns.color_palette("hls", n_classes),
                data=df).set(title="Synthetic dataset data T-SNE projection")