# MNIST Embeddings

## Imports and Hyper-parameters 

In [93]:
import torch 
import os 
import numpy as np 
from torchvision import transforms 
from torchvision.datasets import MNIST 
from torchvision import models 

print(os.getcwd())
root = os.getcwd()

from mnist_net import Net

from sklearn.ensemble import IsolationForest
import plotly.express as px
import pandas as pd

/home/praveens/Desktop/synthetic_biometrics/visualize_embeddings


In [94]:
batch_size = 128
data_folder = root + '/MNIST'
device = 'cpu'
model_path = './mnist_net.pth'

merge_class1 = 2
merge_class2 = 4

## Create Dataloader

In [95]:
transformations = transforms.Compose([transforms.Resize((221, 221)),
                                      transforms.ToTensor(),
                                      transforms.Normalize(mean=[0.485], std=[0.229])
                                    ])

mnist_data = MNIST(root=r'./MNIST',
                   download=False, # change to True to download MNIST data
                   train=False, 
                   transform=transformations)

data_loader = torch.utils.data.DataLoader(mnist_data,
                                          batch_size=batch_size,
                                          shuffle=True)

In [96]:
mnist_data

Dataset MNIST
    Number of datapoints: 10000
    Root location: ./MNIST
    Split: Test
    StandardTransform
Transform: Compose(
               Resize(size=(221, 221), interpolation=bilinear, max_size=None, antialias=None)
               ToTensor()
               Normalize(mean=[0.485], std=[0.229])
           )

In [97]:
# list(mnist_net.named_parameters())

## Initialize Tensorboard Logging Class 

In [98]:
mnist_net = Net().to(device)
mnist_net.load_state_dict(torch.load(model_path, map_location=torch.device(device)))  # load saved model 

mnist_net.eval()

Net(
  (conv1): Conv2d(1, 8, kernel_size=(2, 2), stride=(1, 1))
  (conv2): Conv2d(8, 16, kernel_size=(2, 2), stride=(1, 1))
  (conv3): Conv2d(16, 32, kernel_size=(2, 2), stride=(1, 1))
  (conv4): Conv2d(32, 64, kernel_size=(2, 2), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=9216, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=16, bias=True)
  (fc4): Linear(in_features=16, out_features=3, bias=True)
  (fc5): Linear(in_features=3, out_features=10, bias=True)
)

In [99]:
def filter_by_classes(data_loader, cl1, cl2):
    
    class1 = torch.tensor(cl1)
    class2 = torch.tensor(cl2)
    
    class1_images = []
    class2_images = []
    
    for images, labels in data_loader: 
        class1_images.extend(images[labels == class1])
        class2_images.extend(images[labels == class2])
        
    class1_images = torch.stack(class1_images)
    class2_images = torch.stack(class2_images)
        
    return class1_images, class2_images

In [100]:
def find_class_embeddings(model, data_loader, cl1, cl2, remove_outliers=True):
    images1, images2 = filter_by_classes(data_loader, cl1, cl2)
    
    e1 = model.embedding(images1).clone().detach()
    e2 = model.embedding(images2).clone().detach()
    
    if remove_outliers:
        model_if = IsolationForest(random_state=42)
        model_if.fit(e1.numpy())
        e1_anomaly = model_if.predict(e1)

        model_if.fit(e2.numpy())
        e2_anomaly = model_if.predict(e2)

        # keeping the inliers 
        e1 = e1[e1_anomaly==1]
        print('inliers for ' + str(cl1) + ' : ', e1.shape[0])
        e2 = e2[e2_anomaly==1]
        print('inliers for ' + str(cl2) + ' : ', e2.shape[0])
    
    ################## STEP 1: finding centroid of the 2 class embeddings  
    class1_mean = torch.mean(e1, dim=0)    # squish multiple rows to get mean 
    class2_mean = torch.mean(e2, dim=0)

    class1_normed = class1_mean/torch.norm(class1_mean)
    class2_normed = class2_mean/torch.norm(class2_mean)
    
    print('val of centroid of class ' + str(cl1) + ' : ', class1_normed)
    print('val of centroid of class ' + str(cl2)  + ' : ', class2_normed)
    print('cosine similarity : ', (torch.dot(class1_normed, class2_normed)/torch.linalg.norm(class1_normed)*torch.linalg.norm(class2_normed)))
    
    return class1_normed, class2_normed

In [101]:
x1, x2 = find_class_embeddings(mnist_net, 
                               data_loader, 
                               merge_class1, 
                               merge_class2)

inliers for 2 :  896
inliers for 4 :  841
val of centroid of class 2 :  tensor([ 0.2796,  0.1175, -0.9529])
val of centroid of class 4 :  tensor([-0.3523, -0.3211,  0.8791])
cosine similarity :  tensor(-0.9739)


In [102]:
################## Step 3: calculate change of basis matrix
def change_of_basis(wm, d):
    ################## The first basis is d, how do we get the other 2 ?    

    new_wm = torch.eye(wm.shape[0])
    
    # replace first column of identity matrix with d 
    new_wm[:, 0] = d
    
    print('new weight matrix : ', new_wm)

    return new_wm

In [103]:
def perform_weight_surgery(mnist_net, x1, x2, save_model=True):
    
    ################## Step 2: calculating difference vector of the 2 classes 
    # calculate the difference between the 2 class vectors 
    d = x1 - x2  
    print('d : ', d)
    d = d/torch.norm(d)
    print('normed d : ', d)
    weight_matrix = mnist_net.fc4.weight.clone().detach()  # weight matrix of the penultimate layer 
    print('weight matrix : \n', weight_matrix)
    
    ################## Step 3: calculate change of basis matrix 
    # perform a change of basis, to put d as the first basis 
    new_wm = change_of_basis(weight_matrix, d)
    
    # gram schmidt to calculate orthogonal basis vectors 
    def gram_schmidt(weight_matrix):
        ################## d is first basis vector
        m, n = weight_matrix.shape
        
        Q = torch.zeros((m, n))

        for i in range(n):

            v = weight_matrix[:, i].clone()

            for j in range(i):
                proj = (v @ torch.outer(Q[:, j], Q[:, j]))
                v = v - proj
                        
            Q[:, i] = v/torch.linalg.norm(v)
            
        return Q
    
    ################## Step 4: using Gram-Schmidt to get orthogonal basis vectors 
    U = gram_schmidt(new_wm)    
    
    print('U matrix: ', U)
    print('result of U * U_transpose : ', torch.mm(U, U.T))
#     print('result of U * U_inv : ', U @ torch.inverse(U))
    
    S = torch.eye(U.shape[0])
    S[0, 0] = 1e-5
#     print('matrix S :\n', S)
    
    # projection matrix 
#     P_d = torch.mm(torch.mm(U, S), U.T)
    P_d = U @ S @ U.T

    print('projection matrix : ', P_d)
    
# #     modified_weight_matrix = torch.mm(P_d, weight_matrix)
    modified_weight_matrix = P_d @ weight_matrix
    
    print('modified weight matrix :\n', modified_weight_matrix)
    
    print('rank of the modified weights matrix : ', torch.linalg.matrix_rank(modified_weight_matrix))
    
    mnist_net.fc4.weight = torch.nn.Parameter(modified_weight_matrix)
    
    if save_model: 
        path = './modified_mnist_net.pth'
        torch.save(mnist_net.state_dict(), path)

In [104]:
perform_weight_surgery(mnist_net, x1, x2, save_model=True)

d :  tensor([ 0.6320,  0.4386, -1.8319])
normed d :  tensor([ 0.3181,  0.2207, -0.9220])
weight matrix : 
 tensor([[-0.4793, -0.3378,  0.2388,  0.2235,  0.0479,  0.6634, -0.3005,  0.0974,
          0.1657,  0.2789,  0.3297,  0.6865, -0.0954, -0.6407,  0.2100, -0.5907],
        [ 0.7700,  0.2041,  0.5233, -0.4937,  0.1929, -0.2195, -0.1123, -0.1531,
         -0.3558, -0.4096,  0.4457, -0.1805,  0.1870, -0.3620,  0.4298, -0.7580],
        [-0.0495, -0.4934, -0.1103, -0.4222, -0.3907,  0.5165,  0.0474, -0.0977,
          0.4707, -0.1545,  0.1174, -0.3732,  0.8841,  0.2384, -0.3835, -0.5321]])
new weight matrix :  tensor([[ 0.3181,  0.0000,  0.0000],
        [ 0.2207,  1.0000,  0.0000],
        [-0.9220,  0.0000,  1.0000]])
U matrix:  tensor([[ 0.3181, -0.0720,  0.9453],
        [ 0.2207,  0.9753,  0.0000],
        [-0.9220,  0.2087,  0.3261]])
result of U * U_transpose :  tensor([[ 1.0000e+00,  0.0000e+00, -2.7936e-07],
        [ 0.0000e+00,  1.0000e+00,  1.4901e-08],
        [-2.7936e-07

In [105]:
original_mnist_net = Net().to(device)
original_mnist_net.load_state_dict(torch.load('./mnist_net.pth', map_location=torch.device(device)))  # load saved model 

original_mnist_net.eval()

Net(
  (conv1): Conv2d(1, 8, kernel_size=(2, 2), stride=(1, 1))
  (conv2): Conv2d(8, 16, kernel_size=(2, 2), stride=(1, 1))
  (conv3): Conv2d(16, 32, kernel_size=(2, 2), stride=(1, 1))
  (conv4): Conv2d(32, 64, kernel_size=(2, 2), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=9216, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=16, bias=True)
  (fc4): Linear(in_features=16, out_features=3, bias=True)
  (fc5): Linear(in_features=3, out_features=10, bias=True)
)

In [106]:
_, _ = find_class_embeddings(original_mnist_net, data_loader, merge_class1, merge_class2, remove_outliers=False)

val of centroid of class 2 :  tensor([ 0.2778,  0.1143, -0.9538])
val of centroid of class 4 :  tensor([-0.3507, -0.3221,  0.8793])
cosine similarity :  tensor(-0.9730)


In [107]:
_, _ = find_class_embeddings(mnist_net, data_loader, merge_class1, merge_class2, remove_outliers=False)

val of centroid of class 2 :  tensor([-0.3223, -0.8889, -0.3254])
val of centroid of class 4 :  tensor([-0.3029, -0.8974, -0.3208])
cosine similarity :  tensor(0.9998)


## View Unmodified Embeddings through Plotly

In [108]:
def remove_outliers(embeds, labels):
    unique_labels = set(labels.numpy())
    
    model_if = IsolationForest(random_state=42)
    
    new_embeds, new_labels = [], []
    
    for label_x in unique_labels:
        embed_x = embeds[labels==label_x]
        
        model_if.fit(embed_x)
        embedx_anomaly = model_if.predict(embed_x)
        
        embed_x = embed_x[embedx_anomaly == 1]   # keeping inliers 
        labels_x = [label_x] * embed_x.shape[0]
        
        new_embeds.extend(embed_x)
        new_labels.extend(labels_x)
        
    return np.stack(new_embeds), np.stack(new_labels)

In [109]:
batch_images, batch_labels = next(iter(data_loader))

embeds = original_mnist_net.embedding(batch_images.to(device), normalize=True)
embeds = embeds.detach().numpy()

# new_embeds, new_labels = remove_outliers(embeds, batch_labels)

In [110]:
df = pd.DataFrame({'x': embeds[:, 0], 
                   'y': embeds[:, 1], 
                   'z': embeds[:, 2], 
                   'label': batch_labels.numpy()})

fig = px.scatter_3d(df, x='x', y='y', z='z',
                      color='label', text='label')
fig.update_coloraxes(showscale=False)
fig.update_traces(marker=dict(size=5), 
                  textposition='middle center', 
                  textfont=dict(color='white', size=10))

# fig.write_image("./figs/mnist_clusters.pdf", scale=10)

fig.show()

## View Modified Embeddings through Plotly

In [111]:
batch_images, batch_labels = next(iter(data_loader))

embeds = mnist_net.embedding(batch_images.to(device), normalize=True)
embeds = embeds.detach().numpy()

# new_embeds, new_labels = remove_outliers(embeds, batch_labels)

In [112]:
def project_to_2D_unit_circle(embeds):
    
    # Find the normal vector of the plane
    u, s, vh = np.linalg.svd(embeds)
    normal = vh[-1]    # the last eigenvector is the one that doesn't vary, i.e. 
                       # normal to the plane where the points lie 

    # Project the points onto the plane
    points_proj = embeds - (embeds @ np.outer(normal, normal))  # embeds minus projection of embeds onto normal
                                                                # would give us embeds that lie on the 2D plane,
                                                                # which is perpendicular to the normal 

    # normalize to unit circle 
    points_norm = np.linalg.norm(points_proj, axis=1)
    points_proj /= points_norm[:, np.newaxis]         # using broadcasting to normalize 
    
    return points_proj

In [113]:
projected_points = project_to_2D_unit_circle(embeds)
df = pd.DataFrame({'x': projected_points[:, 0], 
                   'y': projected_points[:, 1], 
                   'z': projected_points[:, 2], 
                   'label': batch_labels.numpy()})

fig = px.scatter_3d(df, x='x', y='y', z='z',
                      color='label', text='label')
fig.update_coloraxes(showscale=False)
fig.update_traces(marker=dict(size=5), 
                  textposition='middle center', 
                  textfont=dict(color='white', size=12))

camera = dict(
    up=dict(x=0, y=0, z=1),
    center=dict(x=0., y=0., z=0.),
    eye=dict(x=0.8, y=0.8, z=-1.2)
)

fig.update_layout(scene_camera=camera)
# fig.write_image("./figs/mc_plot.pdf", scale=5)
fig.show()

## Accuracy over test images

In [114]:
def accuracy_over_test_images(mnist_net, tag='original'):
    correct = 0
    total = 0
    # since we're not training, we don't need to calculate the gradients for our outputs
    with torch.no_grad():
        for images, labels in data_loader:
            images = images.to(device); labels = labels.to(device)
            # calculate outputs by running images through the network
            outputs = mnist_net(images)
            # the class with the highest energy is what we choose as prediction
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f'Accuracy of the ' +  tag +  ' network on the '+ str(total) +' test images: '+ str(100 * correct // total)+ '%')

In [115]:
accuracy_over_test_images(mnist_net, tag='modified')

Accuracy of the modified network on the 10000 test images: 77%


In [980]:
# accuracy_over_test_images(original_mnist_net)
# Accuracy of the original network on the 10000 test images: 97%