In [2]:
import torch
import random
import torchvision
from torchvision import datasets, transforms
import torch.nn as nn
import torch.nn.functional as F
from scipy.stats import norm
import matplotlib.pyplot as plt
import numpy as np
import huffman
import math
import faiss
import sys
# Seed every RNG source (torch, numpy, stdlib random) so runs are reproducible.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(0)

# Silence library warnings to keep cell output readable.
import warnings
warnings.filterwarnings('ignore')

# Add the per-dataset directories to the import path for later local imports.
sys.path.extend(['./mnist', './cifar10', './imagenet'])

# MNIST Experiments: Symbolic Inference with an ImageNet Codebook

In [3]:
# --- MNIST test-set setup for inference ------------------------------------

# Carried over from the training script; nothing in this cell reads it.
random_seed = 1

# Re-seed all RNGs so this cell behaves the same if re-run in isolation.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(0)

# Preprocessing: resize to 32 px (a 32x32 input is what CNN_LeNet's
# 16*5*5 = 400-feature flatten is built for), convert to a tensor, then
# normalize the single channel with the mean/std below.
apply_transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1309,), std=(0.2893,)),
])

# Test split only. NOTE: point `root` at the proper dataset folder on a new machine.
testset = datasets.MNIST('../../dataset', train=False, transform=apply_transform, download=True)

# Two independent, unshuffled loaders over the same test set, one image per batch.
testloader, testloader2 = (
    torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False) for _ in range(2)
)

class CNN_LeNet(nn.Module):
    """LeNet-style classifier: 1 input channel, 10 output classes.

    The flatten step expects 16 feature maps of 5x5 (400 features), which is
    what a 32x32 input produces after the two conv(5) + max-pool(2) stages.

    Attribute names (conv1, pool1, conv2, pool2, fc1, fc2, fc3) must match the
    keys of the saved state_dict, so they are deliberately left unchanged.

    ``forward`` returns softmax probabilities (each row sums to 1), not logits.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: two conv -> ReLU -> 2x2 max-pool stages.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Classifier head: 400 flattened features -> 120 -> 84 -> 10 classes.
        self.fc1 = nn.Linear(in_features=400, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        features = self.pool1(F.relu(self.conv1(x)))
        features = self.pool2(F.relu(self.conv2(features)))
        # view(-1, 400) also accepts an unbatched input and yields one row.
        flat = features.view(-1, 400)
        hidden = F.relu(self.fc2(F.relu(self.fc1(flat))))
        return F.softmax(self.fc3(hidden), dim=1)

# Load the pretrained LeNet weights for CPU inference.
# map_location="cpu" lets a checkpoint that was saved from a GPU session load on a
# CPU-only machine; nothing in this notebook moves tensors to a GPU.
# NOTE(security): torch.load unpickles the file -- only load checkpoints you trust.
pretrained_model = "./mnist_v0.pt"
mnist_model = CNN_LeNet()
mnist_model.load_state_dict(torch.load(pretrained_model, map_location="cpu"))
num_class = 10  # matches the 10 outputs of CNN_LeNet.fc3
mnist_model.eval()

CNN_LeNet(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

In [4]:
# Standard (non-symbolic) inference: accuracy of the model on the raw test images.
def mnist_test_base_acc(model, loader=None):
    """Return the classification accuracy of ``model`` as a percentage.

    Args:
        model: torch module whose output holds one score per class along dim 1
            (CNN_LeNet returns softmax probabilities).
        loader: iterable of ``(X, y)`` batches. Defaults to the module-level
            ``testloader``. Any batch size works.

    Returns:
        Accuracy in percent, rounded to two decimals (e.g. 99.03).

    Raises:
        ValueError: if the loader yields no samples.
    """
    if loader is None:
        loader = testloader
    correct = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for X, y in loader:
            output = model(X)
            # Predicted class = index of the highest score in each row.
            pred = output.argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.numel()
    if total == 0:
        raise ValueError("loader yielded no samples")
    # Scale before rounding so the value prints cleanly (no 99.07000000000001).
    return round(100 * correct / total, 2)

In [5]:
# Accuracy of the model when every test image is first replaced by its symbolic
# (codebook-reconstructed) version.
def mnist_test_sym_acc(model, sym_mnist, n_clusters, index, patch_size, stride, channel_count, instr=False, loader=None):
    """Evaluate ``model`` on symbolically encoded test images.

    Args:
        model: torch module returning one score per class along dim 1.
        sym_mnist: conversion function. Called as
            ``sym_mnist(img, n_clusters, index, centroid_lut, patch_size, stride, channel_count)``
            and returning a numpy array. When ``instr`` is True it is called with
            ``pdf`` inserted after ``centroid_lut`` and must return ``(array, pdf)``.
        n_clusters: number of centroids to read from ``index``.
        index: faiss-style index; its first ``n_clusters`` vectors are fetched via
            ``reconstruct_n`` and passed to ``sym_mnist`` as a lookup table.
        patch_size, stride, channel_count: passed through to ``sym_mnist`` unchanged.
        instr: if True, also thread a per-centroid integer array ``pdf`` (length
            ``n_clusters``, initially zero) through ``sym_mnist`` and return it.
        loader: iterable of ``(X, y)`` batches. Defaults to the module-level
            ``testloader``. Must yield one image per batch (see below).

    Returns:
        Accuracy as a fraction in [0, 1] rounded to 4 decimals (NOT a percentage,
        unlike mnist_test_base_acc), or ``(accuracy, pdf)`` when ``instr`` is True.

    Raises:
        ValueError: if the loader yields no samples.
    """
    if loader is None:
        loader = testloader
    correct = 0
    total = 0
    # Fetch the centroid table once rather than per image.
    centroid_lut = index.reconstruct_n(0, n_clusters)
    pdf = np.zeros((n_clusters,), dtype=int) if instr else None
    model.eval()
    with torch.no_grad():
        for X, y in loader:
            # X.squeeze() drops the batch dimension (and unsqueeze(0) below adds a
            # single one back), so this loop assumes batch_size=1.
            if instr:
                Xsym_, pdf = sym_mnist(X.squeeze(), n_clusters, index, centroid_lut, pdf, patch_size, stride, channel_count)
            else:
                Xsym_ = sym_mnist(X.squeeze(), n_clusters, index, centroid_lut, patch_size, stride, channel_count)
            Xsym = torch.from_numpy(Xsym_).unsqueeze(0).float()
            output = model(Xsym)
            pred = output.argmax(dim=1)
            correct += (pred == y).sum().item()
            total += output.size(0)
    if total == 0:
        raise ValueError("loader yielded no samples")
    acc = round(correct / total, 4)
    if instr:
        return acc, pdf
    return acc

In [7]:
# Baseline: accuracy of the unmodified model on the raw (non-symbolic) test images.
acc = mnist_test_base_acc(mnist_model)
print(f"Non Symbolic test accuracy:{acc}% ")

Non Symbolic test accuracy:99.03% 


In [11]:
# Symbolic inference with the 512-centroid codebook stored in
# kmeans_img_mnist_*.index (presumably k-means over MNIST patches), 2x2 patches.
from  patchutils_mnist import  fm_to_symbolic_fm as sym_mnist

index = faiss.read_index("./kmeans_img_mnist_k2_s0_c512_v0.index")
n_clusters = 512
# reconstruct_n(0, n_clusters) needs at least n_clusters stored vectors; fail with a clear message.
assert index.ntotal >= n_clusters, f"codebook holds {index.ntotal} vectors, expected >= {n_clusters}"
patch_size = (2, 2)
channel_count = 1
repeat = 2        # not used in this cell
location = False  # not used in this cell
stride = 0        # passed through to sym_mnist; NOTE(review): meaning of 0 is defined in patchutils_mnist
acc = mnist_test_sym_acc(mnist_model, sym_mnist, n_clusters, index,  patch_size, stride, channel_count)
# Fixed two decimals: 100 * acc can carry float noise (e.g. 99.07000000000001).
print("Symbolic test accuracy (codebook 512):{:.2f}% ".format(100 * acc))

Symbolic test accuracy (codebook 512):99.02% 


In [13]:
# Symbolic inference on MNIST using the ImageNet codebook (2048 centroids, 2x2 patches).
from  patchutils_mnist import  fm_to_symbolic_fm as sym_mnist
index = faiss.read_index("../imagenet/kmeans_img_imgnet_k2_s0_c2048_v0.index")
n_clusters = 2048
# reconstruct_n(0, n_clusters) needs at least n_clusters stored vectors; fail with a clear message.
assert index.ntotal >= n_clusters, f"codebook holds {index.ntotal} vectors, expected >= {n_clusters}"
patch_size = (2, 2)
channel_count = 1
repeat = 2        # not used in this cell
location = False  # not used in this cell
stride = 0        # passed through to sym_mnist; NOTE(review): meaning of 0 is defined in patchutils_mnist
acc = mnist_test_sym_acc(mnist_model, sym_mnist, n_clusters, index,  patch_size, stride, channel_count)
# Fixed two decimals: 100 * acc can carry float noise (e.g. 99.07000000000001).
print("Symbolic test accuracy (codebook 2048):{:.2f}% ".format(100 * acc))

Symbolic test accuracy (codebook 2048):99.07000000000001% 
