In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, models, transforms
import torchvision.transforms as T
from torchvision.utils import make_grid
from torchvision.utils import save_image
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
from torch.optim import lr_scheduler

import time
import os
import shutil
import copy
import sys

import PIL
from IPython.display import Image
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.cluster import KMeans, MiniBatchKMeans
from statistics import mean
from collections  import OrderedDict
import numpy as np
from skimage import io, transform
import random
import scipy
import cv2
from math import floor, ceil

# !pip install torchinfo
from torchinfo import summary

%matplotlib inline

In [2]:
IMG_SIZE = 32
N_CLASSES = 10
PATH = r"C:\Users\ameyv\CNN_Accelerators\LeNet_CIFAR.pth"

In [None]:
mult_counts = torch.zeros(3)

In [None]:
class LeNet5(nn.Module):

    def __init__(self, n_classes):
        super(LeNet5, self).__init__()
        
        self.feature_extractor = nn.Sequential(            
            nn.SortConv2D(in_channels=1, out_channels=6, kernel_size=5, stride=1),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2),
            nn.SortConv2D(in_channels=6, out_channels=16, kernel_size=5, stride=1),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2),
            nn.SortConv2D(in_channels=16, out_channels=84, kernel_size=5, stride=1),
            nn.ReLU()
        )

        self.classifier = nn.Sequential(
            nn.Linear(in_features=84, out_features=64),
            nn.Tanh(),
            nn.Linear(in_features=64, out_features=n_classes),
        )

    def forward(self, x):
        x = self.feature_extractor(x)
        x = torch.flatten(x, 1)
        logits = self.classifier(x)
        probs = F.softmax(logits, dim=1)
        return logits, probs
    
    def compute_conv(x, in_channels, kernel_size, centroids, labels, weights, r, c):
        x_out_cell = 0
        mult_count = 0
        act_clusters = torch.zeros(1,num_clusters)
        for i in range(width*depth*height):
            act_clusters[0][labels[i]] += img[floor(i/(height*width))][r+floor((i%(width*height))/width)][c+floor(i%width)]
        for it in range(width):
            img_out_cell += act_clusters[0][it]*centroids[it]
            conv_mult_count += 1
        return x_out_cell, mult_count

    def compute_filter_conv(x, h, w, in_channels, kernel_size, centroids, labels, weights, bias, kernel_id, w_out, h_out, stride, padding):
        x_out_channel = torch.zeros(w_out,h_out)
        for r in range(0,h+2*padding-kernel_size,stride):
            for c in range(0,w+2*padding-kernel_size,stride):
                r_out = floor(r/stride)
                c_out = floor(c/stride)
                #print(r_out, c_out)
                x_out_channel[r_out][c_out], mult = self.compute_conv(x, in_channels, kernel_size, centroids, labels, weights, r, c)
                x_out_channel[r_out][c_out] += bias
        return x_out_channel, mult

    def SortConv2D(x, wt_str, b_str, in_channels, out_channels, kernel_size, stride, padding) -> torch.Tensor:
        #   x = x[0]
        #   print(x.shape)
        layer_mult_count = 0
        h = x.shape[1]
        w = x.shape[2]
        w_out = floor((w+2*padding-kernel_size)/stride+1)
        h_out = floor((h+2*padding-kernel_size)/stride+1)
        x_out = torch.zeros(out_channels, h_out, w_out)
        pad_transform = transforms.Pad(padding)
        x = pad_transform(x)
        for kernel_id in range(out_channels):
            weights = parameters[wt_str][kernel_id]
            bias = parameters[b_str][kernel_id]
            weight_list = compute_kmeans(weights, in_channels, out_channels, kernel_size, kernel_size)
            centroids, labels = compute_kmeans(parameters[wt_str], kernel_size)
            x_out[kernel_id], mult = self.compute_filter_conv(x, h, w, in_channels, kernel_size, centroids, labels, weights, bias, kernel_id, w_out, h_out, stride, padding)
            layer_mult_count += mult
        if wt_str == 'feature_extractor.0.weight':
            mult_counts[0] += layer_mult_count
        elif wt_str == 'feature_extractor.3.weight':
            mult_counts[1] += layer_mult_count
        else:    
            mult_counts[2] += layer_mult_count
        return x_out    

In [4]:
lenet = LeNet5(N_CLASSES)
lenet.load_state_dict(torch.load(PATH))
lenet.eval()

LeNet5(
  (feature_extractor): Sequential(
    (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (3): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (4): ReLU()
    (5): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (6): Conv2d(16, 84, kernel_size=(5, 5), stride=(1, 1))
    (7): ReLU()
  )
  (classifier): Sequential(
    (0): Linear(in_features=84, out_features=64, bias=True)
    (1): Tanh()
    (2): Linear(in_features=64, out_features=10, bias=True)
  )
)

In [5]:
parameters = lenet.state_dict()
for param_tensor in parameters:
    #print(type(alexnet.state_dict()[param_tensor]))
    print(param_tensor, "\t", parameters[param_tensor].size())

feature_extractor.0.weight 	 torch.Size([6, 1, 5, 5])
feature_extractor.0.bias 	 torch.Size([6])
feature_extractor.3.weight 	 torch.Size([16, 6, 5, 5])
feature_extractor.3.bias 	 torch.Size([16])
feature_extractor.6.weight 	 torch.Size([84, 16, 5, 5])
feature_extractor.6.bias 	 torch.Size([84])
classifier.0.weight 	 torch.Size([64, 84])
classifier.0.bias 	 torch.Size([64])
classifier.2.weight 	 torch.Size([10, 64])
classifier.2.bias 	 torch.Size([10])


In [6]:
print(parameters['feature_extractor.0.weight'][0])

tensor([[[ 0.0907,  0.1396, -0.0428,  0.0474, -0.1452],
         [ 0.1162,  0.0521,  0.1174,  0.0070, -0.2547],
         [ 0.3676,  0.2714,  0.3123,  0.1142,  0.0710],
         [ 0.1612,  0.3490,  0.2863,  0.1326,  0.0660],
         [-0.2164, -0.1052, -0.1675, -0.0691, -0.4119]]])


In [11]:
def compute_kmeans(kernel, num_cluster):    
    centroids = []
    labels = []
    num_cluster = kernel_size
    flattened_weights = kernel.reshape(-1)
    flattened_weights = flattened_weights.reshape(height*width*depth,1)
    kmeans = MiniBatchKMeans(n_clusters=num_cluster, batch_size=height*width, compute_labels=True)
    kmeans.fit(flattened_weights)
    centroids = kmeans.cluster_centers_
    labels = kmeans.labels_
#     print(labels)
#     print(len(labels))
#     print(type(labels))
#     print(labels.shape)
    return centroids, labels

In [None]:
centroids1, labels1 = compute_kmeans(parameters['feature_extractor.0.weight'], 5)
centroids2, labels2 = compute_kmeans(parameters['feature_extractor.3.weight'], 5)
centroids3, labels3 = compute_kmeans(parameters['feature_extractor.6.weight'], 5)

In [15]:
"""
def compute_conv(x, in_channels, kernel_size, weight_list, weights, r, c):
    x_out_cell = 0
    for k in range(in_channels):
        for i in range(kernel_size):
            for j in range(kernel_size):          
                if weights[k][i][j] > 0:
                    x_out_cell += (x[k][r+i][c+j]*weights[k][i][j])
    for tup in weight_list:
        x_out_cell += tup[0]*x[tup[1]][r+tup[2]][c+tup[3]]
        if x_out_cell < 0:
            break
    return x_out_cell

def compute_filter_conv(x, h, w, in_channels, kernel_size, weight_list, weights, bias, kernel_id, w_out, h_out, stride, padding):
    x_out_channel = torch.zeros(w_out,h_out)
    for r in range(0,h+2*padding-kernel_size,stride):
        for c in range(0,w+2*padding-kernel_size,stride):
            r_out = floor(r/stride)
            c_out = floor(c/stride)
            #print(r_out, c_out)
            x_out_channel[r_out][c_out] = self.compute_conv(x, in_channels, kernel_size, weight_list, weights, r, c)
            x_out_channel[r_out][c_out] += bias
    return x_out_channel

def SortConv2D(x, wt_str, b_str, in_channels, out_channels, kernel_size, stride, padding) -> torch.Tensor:
    #   x = x[0]
    #   print(x.shape)
    h = x.shape[1]
    w = x.shape[2]
    w_out = floor((w+2*padding-kernel_size)/stride+1)
    h_out = floor((h+2*padding-kernel_size)/stride+1)
    x_out = torch.zeros(out_channels, h_out, w_out)
    pad_transform = transforms.Pad(padding)
    x = pad_transform(x)
    for kernel_id in range(out_channels):
        weights = parameters[wt_str][kernel_id]
        bias = parameters[b_str][kernel_id]
        weight_list = self.compute_weights_list(weights, in_channels, out_channels, kernel_size, kernel_size)
        x_out[kernel_id] = self.compute_filter_conv(x, h, w, in_channels, kernel_size, weight_list, weights, bias, kernel_id, w_out, h_out, stride, padding)
    return x_out    
"""

'\ndef compute_conv(x, in_channels, kernel_size, weight_list, weights, r, c):\n    x_out_cell = 0\n    for k in range(in_channels):\n        for i in range(kernel_size):\n            for j in range(kernel_size):          \n                if weights[k][i][j] > 0:\n                    x_out_cell += (x[k][r+i][c+j]*weights[k][i][j])\n    for tup in weight_list:\n        x_out_cell += tup[0]*x[tup[1]][r+tup[2]][c+tup[3]]\n        if x_out_cell < 0:\n            break\n    return x_out_cell\n\ndef compute_filter_conv(x, h, w, in_channels, kernel_size, weight_list, weights, bias, kernel_id, w_out, h_out, stride, padding):\n    x_out_channel = torch.zeros(w_out,h_out)\n    for r in range(0,h+2*padding-kernel_size,stride):\n        for c in range(0,w+2*padding-kernel_size,stride):\n            r_out = floor(r/stride)\n            c_out = floor(c/stride)\n            #print(r_out, c_out)\n            x_out_channel[r_out][c_out] = self.compute_conv(x, in_channels, kernel_size, weight_list, weig

In [None]:
transform = transforms.Compose([transforms.Resize((32, 32)),
                                transforms.ToTensor(), 
                                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                                transforms.Grayscale(num_output_channels=1), 
                               ])

train_data = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(train_data, batch_size=4, shuffle=True, num_workers=2)

test_data = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(test_data, batch_size=4, shuffle=False, num_workers=2)

classes = ('Airplane', 'Car', 'Bird', 'Cat', 'Deer', 'Dog', 'Frog', 'Horse', 'Ship', 'Truck')

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data\cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

In [None]:
correct_pred = 0 
n = 0
num_images = 10

with torch.no_grad():
    model.eval()
    for i in range(num_images):
        X = testloader[i][0]
        y = testloader[i][1]

        X = X.to(device)
        y_true = y_true.to(device)

        _, y_prob = model(X)
        _, predicted_labels = torch.max(y_prob, 1)

        n += y_true.size(0)
        correct_pred += (predicted_labels == y_true).sum()

test_accuracy = correct_pred.float() / n

In [None]:
mult_counts = mult_counts/num_images

In [7]:
print("Test accuracy for ", num_images, "images is: ", test_accuracy)
print("Averaged Multipication Counts: ", mult_counts)

Test accuracy for  10 images is:  40
Averaged Multipication Counts:  [20480, 8840, 440]
