In [25]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

device = torch.device('xpu')

In [2]:
transform = transforms.Compose([
    transforms.ToTensor(),  # Convert PIL Image to PyTorch Tensor
    transforms.Normalize((0.1307,), (0.3081,)) # Normalize pixel values
])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)

### Simple Auto encoder

In [100]:


256  / 2

128.0

In [106]:
class MnistAutoEncoder(nn.Module):
  def __init__(self):
    super(MnistAutoEncoder, self).__init__()
    self.encoder = nn.Sequential(
      nn.Linear(in_features=784, out_features=512, bias=True),
      nn.LayerNorm(512),
      nn.ReLU(),
      nn.Linear(in_features=512, out_features=256, bias=True),
      nn.LayerNorm(256),
      nn.ReLU(),
      nn.Linear(in_features=256, out_features=128, bias=True),
      nn.LayerNorm(128),
      nn.ReLU(),
      nn.Linear(in_features=128, out_features=32, bias=True),
      nn.LayerNorm(32),
      nn.ReLU(),
    )
    self.decoder = nn.Sequential(
      nn.Linear(in_features=32, out_features=128, bias=True),
      nn.LayerNorm(128),
      nn.ReLU(),
      nn.Linear(in_features=128, out_features=256, bias=True),
      nn.LayerNorm(256),
      nn.ReLU(),
      nn.Linear(in_features=256, out_features=512, bias=True),
      nn.LayerNorm(512),
      nn.ReLU(),
      nn.Linear(in_features=512, out_features=784, bias=True),
      nn.LayerNorm(784),
      nn.ReLU(),
    )
  
  def forward(self, X):
    X_embed = self.encoder(X)
    X_restored = self.decoder(X_embed)
    return X_embed, X_restored
  

mae = MnistAutoEncoder()
mae.to(device)
print('number of parameters: ', sum([p.numel() for p in mae.parameters()]))

number of parameters:  1146512


In [107]:
lr = 1e-2
optimizer = optim.SGD(mae.parameters(), lr=lr, weight_decay=1e-4)
criterion = nn.MSELoss()

#training loop
for epoch in range(5):
  for batch_idx, batch in enumerate(train_loader):
    X, y = batch
    X = X.flatten(start_dim=1)
    X = X.to(device)
    batch_emb, batch_restored = mae(X)
    
    optimizer.zero_grad()
    loss = criterion(batch_restored, X)
    loss.backward()
    optimizer.step()

    if batch_idx % 400 == 0:
      print(f'epoch {epoch}, batch {batch_idx}: loss = {loss.item()}')


epoch 0, batch 0: loss = 1.492504358291626
epoch 0, batch 400: loss = 0.7581777572631836
epoch 0, batch 800: loss = 0.7430914640426636
epoch 1, batch 0: loss = 0.7466987371444702
epoch 1, batch 400: loss = 0.6881937980651855
epoch 1, batch 800: loss = 0.6563748717308044
epoch 2, batch 0: loss = 0.656299352645874
epoch 2, batch 400: loss = 0.6139695644378662
epoch 2, batch 800: loss = 0.6040909886360168
epoch 3, batch 0: loss = 0.5591377019882202
epoch 3, batch 400: loss = 0.5295407176017761
epoch 3, batch 800: loss = 0.5612890720367432
epoch 4, batch 0: loss = 0.5158297419548035
epoch 4, batch 400: loss = 0.4893154799938202
epoch 4, batch 800: loss = 0.49156373739242554


In [None]:
def compare_cos_similarity(model, sampling_dataloader):
  x, y = next(iter(sampling_dataloader))
  x = x.flatten(start_dim=1)
  x, y = x.to(device), y.to(device)
  x_emb, _ = model(x)
  x_emb_norm = F.normalize(x_emb, p=2.0, dim=1)
  cos_sim = x_emb_norm @ x_emb_norm.T
  cos_sim = cos_sim.cpu()
  y_cpu = y.cpu()

  print('class    |  same class  |  diff class  |  ratio (same / diff)')
  print('-------------------------------------------------------------')
  for label in set(y_cpu.numpy()):
    idxs = (y_cpu == label).nonzero().squeeze(1)
    not_idxs = (y_cpu != label).nonzero().squeeze(1)
    mean_dot_prod_sharedclass = cos_sim[idxs][:, idxs].mean()
    mean_dot_prod_diffclass = cos_sim[idxs][:, not_idxs].mean()
    ratio = mean_dot_prod_sharedclass / mean_dot_prod_diffclass
    print(f'class {label}        {mean_dot_prod_sharedclass:.2f}           {mean_dot_prod_diffclass:.2f}          {ratio:.2f}')

In [109]:
compare_cos_similarity(mae, test_loader)

class    |  same class  |  diff class  |  ratio (same / diff)
-------------------------------------------------------------
class 0        0.93           0.87          1.07
class 1        0.98           0.90          1.09
class 2        0.94           0.91          1.03
class 3        0.95           0.91          1.04
class 4        0.95           0.91          1.04
class 5        0.94           0.92          1.03
class 6        0.95           0.92          1.04
class 7        0.95           0.91          1.04
class 8        0.96           0.93          1.03
class 9        0.96           0.92          1.04
