In [None]:
# Dimension of the ambient space (28 * 28 = 784).
D = 784
# Number of 2-D planes embedded in the D-dimensional space.
k = 3
# Number of points sampled on each plane (10**3 is a smaller alternative for quick runs).
n = 6_000

# Train and test datasets

In [None]:
from torch.nn.functional import normalize
from tqdm import tqdm
import matplotlib.pyplot as plt # plotting library
import numpy as np # this module is useful to work with numerical arrays
import pandas as pd 
import plotly.express as px
import random 
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import TensorDataset, DataLoader,random_split
from torch import nn
import torch.nn.functional as F
import torch.optim as optim

# Seed torch's global RNG so the planes, the samples and the train/test split are reproducible.
torch.manual_seed(0)

# phi[j] is a (D, 2) matrix whose columns are an orthonormal basis of the j-th random plane.
phi = []
for j in range(k):
    # Draw two random vectors (entries uniform in [0, 1)) and orthonormalise them with a reduced QR.
    rand_vectors = torch.rand(D, 2)
    q, r = torch.linalg.qr(rand_vectors)
    phi.append(q)

# Sample n points per plane from a standard 2-D Gaussian and lift them into R^D.
# Variants tried earlier: shifting the mean by i or 10*i per plane, and normalising
# the 2-D samples or the lifted points.
data = []
for i in range(k):
    gaussian = torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros(2), torch.eye(2))
    samples = gaussian.sample(sample_shape=(n,)).T  # (2, n)
    data.append(torch.matmul(phi[i], samples))  # (D, n)
data_tensor = torch.cat(data, dim=1)

# One row per point, then view every 784-vector as a 1x28x28 "image".
data_tensor = data_tensor.T
data_tensor = data_tensor.reshape(k*n, 1, 28, 28)

# Label of a point = index of the plane it was sampled from (stored as floats).
labels_list = [i * torch.ones(n) for i in range(k)]
labels = torch.cat(labels_list)

my_set = TensorDataset(data_tensor, labels)
train_dataset = my_set

train_transform = transforms.Compose([
    transforms.ToTensor(),
])

# NOTE(review): TensorDataset indexes its tensors directly and does not apply a
# `transform` attribute, so this assignment has no effect on the samples.
train_dataset.transform = train_transform

m = len(train_dataset)

# Hold out 20% of the points for testing. The train size is obtained by subtraction so the
# two lengths always sum to m; truncating both products independently can lose one sample,
# in which case random_split raises.
test_size = int(m * 0.2)
train_size = m - test_size
train_data, test_data = random_split(train_dataset, [train_size, test_size])
batch_size = 128

test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size)

# test_data[:][0] gives the data vectors (without labels) of the test split;
# test_data[:][1] gives the matching labels.

# TSNE check

In [None]:
from sklearn.manifold import TSNE
from numpy import reshape
import seaborn as sns
import pandas as pd  

In [None]:
# t-SNE sanity check on the test split.
# test_data[:][0] holds the samples; flatten each 1x28x28 sample back to a 784-vector.
synthetic_set = test_data[:][0].view(-1, 28*28)

tsne = TSNE(n_components=2, verbose=1, random_state=123)
z = tsne.fit_transform(synthetic_set.numpy())

# Collect the labels (test_data[:][1]) and the two embedding coordinates in one frame.
df = pd.DataFrame({
    "y": test_data[:][1].numpy(),
    "comp-1": z[:, 0],
    "comp-2": z[:, 1],
})


In [None]:
# Scatter the 2-D t-SNE embedding, coloured by plane label.
# The palette is sized to the number of distinct labels: handing seaborn more colours than
# hue levels triggers a warning and only uses the first, closely spaced hues.
hue_labels = df.y.tolist()
sns.scatterplot(x="comp-1", y="comp-2", hue=hue_labels,
                palette=sns.color_palette("hls", len(set(hue_labels))),
                data=df).set(title="Synthetic dataset data T-SNE projection")

# Fully connected AE

In [None]:
class Encoder(nn.Module):
    """MLP encoder: input_dim -> 128 -> 128 -> hidden_dim, followed by an element-wise sine.

    Args:
        input_dim: size of each (flattened) input vector.
        hidden_dim: size of the latent code.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        width = 128  # width of both hidden layers
        layers = [
            nn.Linear(input_dim, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, hidden_dim),
        ]
        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        """Return sin(MLP(x)); every latent coordinate therefore lies in [-1, 1].

        Sigmoid and leaky-ReLU were tried as alternative output non-linearities.
        """
        return torch.sin(self.encoder(x))

class Decoder(nn.Module):
    """MLP decoder: hidden_dim -> 128 -> 128 -> output_dim.

    The output layer is linear by default. The reconstruction targets in this notebook are
    zero-mean Gaussian samples lifted into R^784, so they contain negative entries; a ReLU
    on the output would clamp every reconstruction to be non-negative and make those
    entries unreachable.

    Args:
        hidden_dim: size of the latent code fed to the decoder.
        output_dim: size of each reconstructed vector.
        output_activation: optional ``nn.Module`` applied after the last linear layer
            (e.g. ``nn.ReLU()`` or ``nn.Sigmoid()`` for data known to be non-negative).
            ``None`` (default) leaves the output unconstrained.
    """

    def __init__(self, hidden_dim, output_dim, output_activation=None):
        super().__init__()
        layers = [
            nn.Linear(hidden_dim, 128, bias=True),
            nn.ReLU(),
            nn.Linear(128, 128, bias=True),
            nn.ReLU(),
            nn.Linear(128, output_dim, bias=True),
        ]
        if output_activation is not None:
            # Parameter-free activations keep the state_dict keys (0, 2, 4) unchanged.
            layers.append(output_activation)
        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of latent codes to reconstructions of size output_dim."""
        return self.decoder(x)

In [None]:
# Reconstruction loss: mean squared error.
loss_fn = torch.nn.MSELoss()

# Learning rate used when the optimizer is built.
lr = 1e-4

# Re-seed torch so that what follows (e.g. weight initialisation) is reproducible.
torch.manual_seed(0)

# Dimension of the latent space.
d = 2

In [None]:
# Build the two halves of the autoencoder around a d-dimensional bottleneck.
encoder = Encoder(input_dim=784, hidden_dim=d)
decoder = Decoder(hidden_dim=d, output_dim=784)

# One parameter group per network, so a single optimizer updates both.
params_to_optimize = [{'params': net.parameters()} for net in (encoder, decoder)]

# NOTE: this rebinds the name `optim`, shadowing the `torch.optim` alias imported at the top.
optim = torch.optim.RMSprop(params_to_optimize, lr=lr, weight_decay=1e-03)

# The device is fixed to the CPU; the GPU auto-detection below is left commented out.
#device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device("cpu")
print(f'Selected device: {device}')

# Move both networks to the selected device.
encoder.to(device)
decoder.to(device)

In [None]:
def point_plot(encoder, mydata):
    """Scatter-plot the 2-D latent codes of every sample in ``mydata``, coloured by label.

    Args:
        encoder: module mapping inputs of shape (N, 1, 784) to 2-D codes.
        mydata: dataset supporting ``mydata[:]`` and returning ``(inputs, labels)`` tensors.

    Returns:
        The plotly figure produced by ``px.scatter``.
    """
    import plotly.express as px

    inputs, labels = mydata[:]

    # Run the encoder on whatever device its weights live on.
    first_param = next(iter(encoder.parameters()), None) if hasattr(encoder, "parameters") else None
    if first_param is not None:
        inputs = inputs.to(first_param.device)

    # Plotting needs no gradients; skip building an autograd graph over the whole dataset.
    with torch.no_grad():
        codes = encoder(inputs.view(-1, 1, 28*28))

    s = codes.cpu().numpy().reshape(-1, 2)
    l = labels.cpu().numpy().reshape(-1, 1)
    # Columns: latent x, latent y, label.
    s = np.concatenate((s, l), axis=1)
    myplot = px.scatter(s, x=s[:, 0], y=s[:, 1], color=s[:, 2].astype(str), opacity=0.5)
    return myplot

In [None]:
def train_batch(encoder, decoder, device, dataloader, loss_fn, optimizer, num_batches, batches_per_plot, plot_data=None):
    """Train the autoencoder for exactly ``num_batches`` optimisation steps.

    The dataloader is re-iterated as many times as needed to reach ``num_batches``.

    Args:
        encoder: network mapping flattened 784-vectors to latent codes.
        decoder: network mapping latent codes back to 784-vectors.
        device: device the batches are moved to before the forward pass.
        dataloader: iterable of ``(inputs, labels)`` batches; labels are ignored.
        loss_fn: reconstruction loss called as ``loss_fn(decoded, inputs)``.
        optimizer: optimizer holding the parameters of both networks.
        num_batches: total number of optimisation steps to perform.
        batches_per_plot: plot the latent space every this many steps (never at step 0);
            a falsy value disables plotting.
        plot_data: dataset passed to ``point_plot``; defaults to the notebook-level
            ``test_data`` when omitted.

    Returns:
        List with the loss of every processed batch (length ``num_batches``).

    Raises:
        ValueError: if the dataloader yields no batches while steps remain.
    """
    # Set train mode for both the encoder and the decoder
    encoder.train()
    decoder.train()
    mse_loss = []
    batch_idx = 0
    while batch_idx < num_batches:
        saw_batch = False
        # Unsupervised learning: the labels (second tuple element) are ignored.
        for image_batch, _ in dataloader:
            saw_batch = True
            # Flatten each sample to a 784-vector and move it to the target device.
            image_batch = image_batch.view(-1, 28*28).to(device)

            # Forward pass: encode, decode, compare with the input.
            decoded_data = decoder(encoder(image_batch))
            loss = loss_fn(decoded_data, image_batch)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            print('\t batch number: {:} \t partial train loss (single batch): {:.6}' .format(batch_idx, batch_loss))

            if batches_per_plot and batch_idx > 0 and batch_idx % batches_per_plot == 0:
                plot = point_plot(encoder, test_data if plot_data is None else plot_data)
                plot.show()

            mse_loss.append(batch_loss)
            batch_idx += 1
            # ">=" (not ">") so that exactly num_batches steps are performed.
            if batch_idx >= num_batches:
                break

        if not saw_batch:
            # Without this guard an empty dataloader would spin in the while-loop forever.
            raise ValueError("dataloader yielded no batches")

    return mse_loss

In [None]:
# Training budget: total number of optimisation steps, and how often (in steps)
# the latent space is plotted during training.
num_batches = 300
batches_per_plot = 50
mse_loss = train_batch(
    encoder, decoder, device, train_loader, loss_fn,
    optimizer=optim,
    num_batches=num_batches,
    batches_per_plot=batches_per_plot,
)

In [None]:
# Per-batch training loss on a logarithmic y-axis.
fig, ax = plt.subplots(figsize=(10, 8))
ax.semilogy(mse_loss, label='Train_loss')
ax.set_title('Loss')
ax.set_xlabel('Batch')
ax.set_ylabel('Loss')
ax.legend()
plt.show()

In [None]:
from tqdm import tqdm

In [None]:
# Encode the whole dataset in a single batched forward pass instead of one call per sample.
encoder.eval()
all_inputs, all_labels = train_dataset[:]
with torch.no_grad():
    # Flatten each sample to a 784-vector before encoding.
    latent_codes = encoder(all_inputs.view(-1, 28*28).to(device)).cpu().numpy()

# One column per latent coordinate, named as before.
encoded_samples = pd.DataFrame(
    latent_codes,
    columns=[f"Enc. Variable {i}" for i in range(latent_codes.shape[1])],
)
# Store labels as plain numbers rather than 0-dim tensors, so that converting them to
# strings yields "0.0", "1.0", ... instead of "tensor(0.)".
encoded_samples['label'] = all_labels.numpy()
encoded_samples

In [None]:
import plotly.express as px

# Scatter of the first two latent coordinates of every encoded sample, coloured by label.
label_as_text = encoded_samples.label.astype(str)
px.scatter(encoded_samples, x='Enc. Variable 0', y='Enc. Variable 1',
           color=label_as_text, opacity=0.5)