# MNIST Embeddings

## Imports and Hyper-parameters 

In [1]:
import torch 
import os 
import numpy as np 
from torchvision import transforms 
from torchvision.datasets import MNIST 
from torchvision import models 

print(os.getcwd())
root = os.getcwd()

from DeepFeatures import DeepFeatures
from mnist_net import Net

/home/praveens/Desktop/synthetic_biometrics/visualize_embeddings


In [2]:
batch_size = 128
data_folder = root + '/MNIST'
device = 'cpu'
model_path = './mnist_net.pth'

merge_class1 = 1
merge_class2 = 0

## Create Dataloader

In [3]:
transformations = transforms.Compose([transforms.Resize((221, 221)),
                                      transforms.ToTensor(),
                                      transforms.Normalize(mean=[0.485], std=[0.229])
                                    ])

mnist_data = MNIST(root=r'./MNIST',
                   download=False, # change to True to download MNIST data
                   train=False, 
                   transform=transformations)

data_loader = torch.utils.data.DataLoader(mnist_data,
                                          batch_size=batch_size,
                                          shuffle=True)

In [4]:
# list(mnist_net.named_parameters())

## Initialize Tensorboard Logging Class 

In [5]:
mnist_net = Net().to(device)
mnist_net.load_state_dict(torch.load(model_path, map_location=torch.device(device)))  # load saved model 

mnist_net.eval()

Net(
  (conv1): Conv2d(1, 8, kernel_size=(2, 2), stride=(1, 1))
  (conv2): Conv2d(8, 16, kernel_size=(2, 2), stride=(1, 1))
  (conv3): Conv2d(16, 32, kernel_size=(2, 2), stride=(1, 1))
  (conv4): Conv2d(32, 64, kernel_size=(2, 2), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=9216, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=3, bias=True)
  (fc4): Linear(in_features=3, out_features=3, bias=True)
  (fc5): Linear(in_features=3, out_features=10, bias=True)
)

In [6]:
deep_features = DeepFeatures(model=mnist_net)

In [7]:
def find_class_embeddings(deep_features, data_loader, class1, class2):
    images, labels = next(iter(data_loader))
    e = deep_features.generate_embeddings(images).clone().detach()
    
    ################## STEP 1: finding centroid of the 2 class embeddings  
    x1 = None
    class1 = torch.tensor(class1)
    class1_count = 0

    x2 = None
    class2 = torch.tensor(class2)
    class2_count = 0

    ################## Should centroid be found by summing and averaging ?
    for index, (image, label) in enumerate(zip(images, labels)):
        if class1_count == 1 and class2_count == 1:
            break 
        
        if (not class1_count) and label == class1:
            x1 = e[index]/torch.norm(e[index])
            class1_count += 1
        if (not class2_count) and label == class2:
            x2 = e[index]/torch.norm(e[index])
            class2_count += 1
    
    print('val of class ' + str(class1.numpy()) + ' : ', x1)
    print('val of class '  + str(class2.numpy()) + ' : ', x2)
    print('cosine similarity : ', (np.dot(x1, x2)/np.linalg.norm(x1)*np.linalg.norm(x2)))
    
    return x1, x2

In [8]:
x1, x2 = find_class_embeddings(deep_features, 
                               data_loader, 
                               merge_class1, 
                               merge_class2)
x1, x2

val of class 1 :  tensor([ 0.2814,  0.0522, -0.9582])
val of class 0 :  tensor([0.7092, 0.5632, 0.4241])
cosine similarity :  -0.1773846


(tensor([ 0.2814,  0.0522, -0.9582]), tensor([0.7092, 0.5632, 0.4241]))

In [9]:
################## Step 3: calculate change of basis matrix
def change_of_basis(wm, d):
    ################## The first basis is d, how do we get the other 2 ?
    combs = [list(range(wm.shape[1]-1)), 
             list(range(1, wm.shape[1]))]
    
    print('combinations : ', combs)
    
    det_threshold = 1e-1
    new_wm = None 

    for c in combs:
        curr_det = np.linalg.det(np.column_stack((d, wm[:, c])))
        
        # skip in case determinant is close to 0
        if abs(curr_det) < det_threshold:
            continue 
        print('det : ', curr_det)
        new_wm = np.column_stack((d, wm[:, c]))

    return torch.tensor(new_wm)

In [10]:
def perform_weight_surgery(mnist_net, x1, x2, save_model=True):
    
    ################## Step 2: calculating difference vector of the 2 classes 
    # calculate the difference between the 2 class vectors 
    d = (x2/torch.norm(x2)) - (x1/torch.norm(x1))  
    print('d : ', d)
    d = d/torch.norm(d)
    print('normed d : ', d)
    weight_matrix = mnist_net.fc4.weight.clone().detach()  # weight matrix of the penultimate layer 
    print('weight matrix : \n', weight_matrix)
    
    ################## Step 3: calculate change of basis matrix 
    # perform a change of basis, to put d as the first basis 
    new_wm = change_of_basis(weight_matrix, d)
    print('new weight matrix : \n', new_wm)
    
    # gram schmidt to calculate orthogonal basis vectors 
    def gram_schmidt(weight_matrix):
        ################## d is first basis vector
        new_basis_vectors = [weight_matrix[:, 0]]
        m, n = weight_matrix.shape

        for i in range(1, n):

            a_i = weight_matrix[:, i].clone()
            q_i = 0

            for j, q_j in enumerate(new_basis_vectors):
                q_i -= (torch.dot(q_j, a_i) * q_j)
            
            q_i += a_i
            q_i = q_i / torch.norm(q_i)
            
            new_basis_vectors.append(q_i)
            
        return new_basis_vectors
    
    ################## Step 4: using Gram-Schmidt to get orthogonal basis vectors 
    new_basis_vectors = gram_schmidt(new_wm)    # 
    
    U = [nbv.reshape(-1, 1) for nbv in new_basis_vectors]
    U = torch.cat(U, axis=1)
    print('result of U * U_transpose : ', U @ U.T)
    print('result of U * U_inv : ', U @ torch.inverse(U))
    
#     U, _ = torch.qr(new_wm)
    print('unitary matrix :\n', U)
#     print('QR decomp :\n', Q)
    
    S = torch.eye(U.shape[0])
    S[0, 0] = 0
    print('projection matrix :\n', S)
    
    print('modified weight matrix :\n', weight_matrix @ U @ S @ torch.inverse(U))
    
    mnist_net.fc4.weight = torch.nn.Parameter(weight_matrix @ U @ S @ torch.inverse(U))
    
    if save_model: 
        path = './modified_mnist_net.pth'
        torch.save(mnist_net.state_dict(), path)

In [11]:
perform_weight_surgery(mnist_net, x1, x2, save_model=True)

d :  tensor([0.4278, 0.5110, 1.3823])
normed d :  tensor([0.2788, 0.3330, 0.9008])
weight matrix : 
 tensor([[ 0.3257, -1.0893,  0.8606],
        [ 0.0483,  0.1397,  0.5508],
        [-1.0913,  0.2588,  0.4324]])
combinations :  [[0, 1], [1, 2]]
det :  0.77067333
det :  -0.6762598
new weight matrix : 
 tensor([[ 0.4278, -1.0893,  0.8606],
        [ 0.5110,  0.1397,  0.5508],
        [ 1.3823,  0.2588,  0.4324]])
result of U * U_transpose :  tensor([[1.1409, 0.0847, 0.5542],
        [0.0847, 0.2809, 0.7441],
        [0.5542, 0.7441, 2.9330]])
result of U * U_inv :  tensor([[ 1.0000e+00, -7.3592e-09,  6.1799e-09],
        [ 3.6311e-11,  1.0000e+00,  2.7344e-09],
        [-3.3433e-08, -1.3429e-07,  1.0000e+00]])
unitary matrix :
 tensor([[ 4.2779e-01, -9.5126e-01, -2.3024e-01],
        [ 5.1098e-01,  1.4054e-01,  8.3702e-04],
        [ 1.3823e+00,  2.7452e-01, -9.7313e-01]])
projection matrix :
 tensor([[0., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])
modified weight matrix :
 t

In [12]:
original_mnist_net = Net().to(device)
original_mnist_net.load_state_dict(torch.load('./mnist_net.pth', map_location=torch.device(device)))  # load saved model 

original_mnist_net.eval()

original_deep_features = DeepFeatures(model=original_mnist_net)

In [13]:
_, _ = find_class_embeddings(original_deep_features, data_loader, merge_class1, merge_class2)

val of class 1 :  tensor([ 0.2812,  0.0526, -0.9582])
val of class 0 :  tensor([0.6185, 0.6097, 0.4957])
cosine similarity :  -0.268965


In [14]:
_, _ = find_class_embeddings(deep_features, data_loader, merge_class1, merge_class2)

val of class 1 :  tensor([ 0.1037, -0.1388, -0.9849])
val of class 0 :  tensor([0.6243, 0.4941, 0.6051])
cosine similarity :  -0.5997719


## Write Modified Embeddings to Tensorboard

In [15]:
deep_features.create_tensorboard_dirs()

In [16]:
batch_images, batch_labels = next(iter(data_loader))
deep_features.write_embeddings(x=batch_images.to(device))

True

In [17]:
deep_features.create_tensorboard_log()

torch.Size([128, 3])
torch.Size([128, 1, 28, 28])


  all_embeds = torch.Tensor(all_embeds)


## Write Unmodified Embeddings to Tensorboard

In [18]:
original_deep_features.create_tensorboard_dirs(model_type='original')

In [19]:
batch_images, batch_labels = next(iter(data_loader))
original_deep_features.write_embeddings(x=batch_images.to(device))

True

In [20]:
original_deep_features.create_tensorboard_log()

torch.Size([128, 3])
torch.Size([128, 1, 28, 28])


## Accuracy over test images

In [21]:
def accuracy_over_test_images(mnist_net, tag='original'):
    correct = 0
    total = 0
    # since we're not training, we don't need to calculate the gradients for our outputs
    with torch.no_grad():
        for images, labels in data_loader:
            images = images.to(device); labels = labels.to(device)
            # calculate outputs by running images through the network
            outputs = mnist_net(images)
            # the class with the highest energy is what we choose as prediction
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f'Accuracy of the ' +  tag +  ' network on the '+ str(total) +' test images: '+ str(100 * correct // total)+ '%')

In [22]:
accuracy_over_test_images(mnist_net, tag='modified')

Accuracy of the modified network on the 10000 test images: 40%


In [23]:
accuracy_over_test_images(original_mnist_net)

Accuracy of the original network on the 10000 test images: 96%


In [24]:
# 1, 0 split (accuracy 45%)
# 1, 2 split (accuracy 47%)
# 1, 3 split (accuracy 46%)
# 1, 4 split (accuracy 36%)
# 1, 5 split (accuracy 39%)
# 1, 6 split (accuracy 52%) had to run twice 
# 1, 7 split (accuracy Not Possible %)
# 1, 8 split (accuracy  Not possible %)
# 1, 9 split (accuracy  37%)