<a href="https://colab.research.google.com/github/sebastienwood/MemristorQuant/blob/main/jonathan_AICAS2022.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Import code

In [1]:
%cd /content
!git clone --recursive https://github.com/bearpaw/pytorch-classification.git

%cd /content/pytorch-classification/

/content
fatal: destination path 'pytorch-classification' already exists and is not an empty directory.
/content/pytorch-classification


In [2]:
'''Adapted from VGG for CIFAR10. FC layers are removed.
(c) YANG, Wei 
'''
import torch.nn as nn
import math


class VGG(nn.Module):
    """VGG-style network: a feature stack followed by one bias-free linear classifier.

    Args:
        features: module producing the feature maps; its flattened output must
            have 512 elements per sample to match the classifier.
        num_classes: number of classifier outputs.
    """

    def __init__(self, features, num_classes=1000):
        super(VGG, self).__init__()
        self.features = features
        self.classifier = nn.Linear(512, num_classes, bias=False)
        self._initialize_weights()

    def forward(self, x):
        """Run the feature stack, flatten per sample and classify."""
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

    def forward_fc(self, x):
        """Forward pass for a feature stack whose convolutions were replaced by ``nn.Linear``.

        Each ``nn.Linear`` found in ``self.features`` is treated as the dense
        form of a stride-1 convolution that expects the input zero-padded by
        one pixel on every side and keeps the spatial size (i.e. a 3x3 kernel
        with padding 1, which is what ``make_layers`` builds). The activation
        is padded, flattened, multiplied and reshaped back to
        ``(batch, out_channels, height, width)``. ``out_channels`` is derived
        from the layer's ``out_features`` and the current spatial size, so no
        architecture-specific channel/width schedule is needed.

        Args:
            x: 4-D input of shape (batch, channels, height, width).

        Returns:
            The classifier output, shape (batch, num_classes).

        Raises:
            ValueError: if a linear layer's ``out_features`` is not a multiple
                of the spatial size of its input.
        """
        for layer in self.features:
            if isinstance(layer, nn.Linear):
                batch, _, height, width = x.shape
                positions = height * width
                if layer.out_features % positions != 0:
                    raise ValueError(
                        "Linear layer with %d outputs cannot be reshaped to a %dx%d feature map"
                        % (layer.out_features, height, width))
                x = nn.functional.pad(x, (1, 1, 1, 1))  # zero padding the dense matrix expects
                x = x.view(batch, -1)
                x = layer(x)
                x = x.reshape(batch, layer.out_features // positions, height, width)
            else:
                x = layer(x)

        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        """Initialise convolutions (He-style normal), batch norms (1/0) and linears (N(0, 0.01))."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                if m.bias is not None:
                    m.bias.data.zero_()


def make_layers(cfg, batch_norm=False):
    """Assemble the feature extractor described by ``cfg``.

    Args:
        cfg: sequence where ``'A'`` adds a 2x2 average pooling layer and any
            other entry is the output-channel count of a bias-free 3x3
            convolution (padding 1) followed by Softplus.
        batch_norm: when true, a ``BatchNorm2d`` is inserted between each
            convolution and its Softplus.

    Returns:
        An ``nn.Sequential`` holding the layers in order; the first
        convolution expects 3 input channels.
    """
    modules = []
    prev_channels = 3
    for entry in cfg:
        if entry == 'A':
            modules.append(nn.AvgPool2d(kernel_size=2, stride=2))
            continue
        modules.append(nn.Conv2d(prev_channels, entry, kernel_size=3, padding=1, bias=False))
        if batch_norm:
            modules.append(nn.BatchNorm2d(entry))
        modules.append(nn.Softplus())
        prev_channels = entry
    return nn.Sequential(*modules)


# Layer layouts consumed by make_layers: an integer is the output-channel count
# of a convolution block, 'A' inserts an average-pooling layer.
cfg = {
    'small_vgg': [64, 'A', 128, 'A', 256, 'A', 512, 'A', 512, 'A'],
    'really_small_vgg': [16, 'A', 32, 'A', 64, 'A', 128, 'A', 128, 'A'],
}


def small_vgg(**kwargs):
    """Build the 'small_vgg' network (no batch normalisation).

    Args:
        **kwargs: forwarded unchanged to the ``VGG`` constructor
            (for example ``num_classes``).

    Returns:
        A freshly initialised ``VGG`` instance.
    """
    features = make_layers(cfg['small_vgg'], batch_norm=False)
    return VGG(features, **kwargs)

In [3]:
#from https://github.com/pytorch/pytorch/issues/26781#issuecomment-821054668
def convmatrix2d(kernel, image_shape, padding=None):
    """Build the dense matrix equivalent to a stride-1 2-D convolution (cross-correlation).

    The returned matrix ``M`` satisfies ``M @ image.flatten() == conv2d(image, kernel).flatten()``.
    When ``padding`` is given, the columns index the *zero-padded* image, so
    the caller must pad the input before flattening it (the padded columns are
    kept, not masked out).

    Args:
        kernel: tensor of shape (out_channels, in_channels, kernel_height, kernel_width).
        image_shape: (in_channels, image_height, image_width) of the unpadded image.
        padding: falsy for no padding, an int applied to both spatial axes, or
            a (pad_height, pad_width) pair.

    Returns:
        Tensor of shape (out_channels * out_height * out_width,
        in_channels * padded_height * padded_width), allocated with
        ``torch.zeros`` defaults.

    Raises:
        ValueError: if the kernel is not 4-D, the channel counts disagree, or
            the kernel is larger than the (padded) image.
    """
    if padding:
        if isinstance(padding, int):
            pad_h = pad_w = padding
        else:
            pad_h, pad_w = padding
        image_shape = (image_shape[0],
                       image_shape[1] + 2 * pad_h,
                       image_shape[2] + 2 * pad_w)
    else:
        image_shape = tuple(image_shape)

    if len(kernel.shape) != 4 or len(image_shape) != 3:
        raise ValueError("expected a 4-D kernel and a (channels, height, width) image shape")
    if image_shape[0] != kernel.shape[1]:
        raise ValueError("image has %d channels but kernel expects %d"
                         % (image_shape[0], kernel.shape[1]))

    out_channels = kernel.shape[0]
    kernel_h, kernel_w = kernel.shape[2], kernel.shape[3]
    out_h = image_shape[1] - kernel_h + 1
    out_w = image_shape[2] - kernel_w + 1
    if out_h <= 0 or out_w <= 0:
        raise ValueError("kernel is larger than the (padded) image")

    # m[o, i, j, c, y, x] is the weight linking input pixel (c, y, x) to output (o, i, j).
    m = torch.zeros((out_channels, out_h, out_w, *image_shape))
    for i in range(out_h):
        for j in range(out_w):
            m[:, i, j, :, i:i + kernel_h, j:j + kernel_w] = kernel
    # Rows: flattened output (o, i, j); columns: flattened (padded) input (c, y, x).
    return m.flatten(0, 2).flatten(1)

In [4]:
%cd /content/pytorch-classification/ 
'''
Training script for CIFAR-10/100
Copyright (c) Wei YANG, 2017
'''
from __future__ import print_function

import argparse
import os
import shutil
import time
import random
import sys

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import importlib
import models.cifar as models
from torchsummary import summary

from utils import Bar, Logger, AverageMeter, mkdir_p, savefig

# Experiment configuration. Every name below is a notebook-level global read by
# the later parts of this cell.
arch = "small_vgg"
# Directory that receives log.txt, the checkpoints and the log figure.
checkpoint = "/content/pytorch-classification/checkpoints/cifar10/" + arch

# Checkpoint file to load before running; a falsy value starts from scratch.
resume = "/content/pytorch-classification/checkpoints/cifar10/" + arch + "/model_best.pth.tar"
#resume = False
# evaluate: run a single test pass and stop. conv_to_fc: replace every Conv2d
# by an equivalent Linear layer and only evaluate (no training).
evaluate = True
conv_to_fc = True

dataset = "cifar10"
start_epoch = 0
# Training settings (from https://github.com/bearpaw/pytorch-classification/blob/24f1c456f48c78133088c4eefd182ca9e6199b03/TRAINING.md)
lr = 0.1
epochs = 164
# Epochs at which the learning rate is multiplied by 0.1.
decrease_lr_at_epochs = [81, 122]

# Prototyping mode overrides the settings above: one epoch, no resume, no
# evaluation-only run, conversion enabled.
prototyping = True
if prototyping:
  epochs = 1
  resume = False
  evaluate = False
  conv_to_fc = True

# Use CUDA
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
use_cuda = torch.cuda.is_available()
print(f'use_cuda={use_cuda}')
test_batch = 100
train_batch = 128
workers = 2

# Random seed
manualSeed = 2
random.seed(manualSeed)
torch.manual_seed(manualSeed)
if use_cuda:
    torch.cuda.manual_seed_all(manualSeed)
    # The gradient scaler exists only on CUDA; train() uses it under use_cuda.
    scaler = torch.cuda.amp.GradScaler()
# NOTE(review): `drop` is not referenced anywhere else in the visible notebook.
drop = 0

best_acc = 0  # best test accuracy

if not os.path.isdir(checkpoint):
    mkdir_p(checkpoint)
  

# Data
print('==> Preparing dataset %s' % dataset)
# Training pipeline: random 32x32 crop after 4-pixel padding, random horizontal
# flip, tensor conversion and per-channel normalisation.
# NOTE(review): the Normalize constants are presumably CIFAR-10 channel
# mean/std; they are also applied when CIFAR-100 is selected.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Test pipeline: no augmentation, same normalisation as training.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Despite its name, `dataloader` holds the torchvision dataset *class*; the
# actual DataLoader objects are `trainloader` and `testloader` below.
if dataset == 'cifar10':
    dataloader = datasets.CIFAR10
    num_classes = 10
    input_shape = (3, 32, 32)
else:
    dataloader = datasets.CIFAR100
    num_classes = 100
    input_shape = (3, 32, 32)


# Both splits are downloaded to ./data if missing; only the training loader shuffles.
trainset = dataloader(root='./data', train=True, download=True, transform=transform_train)
trainloader = data.DataLoader(trainset, batch_size=train_batch, shuffle=True, num_workers=workers)

testset = dataloader(root='./data', train=False, download=True, transform=transform_test)
testloader = data.DataLoader(testset, batch_size=test_batch, shuffle=False, num_workers=workers)

# Model
print("==> creating model '{}'".format(arch))
if arch == "small_vgg":
    model = small_vgg(num_classes=num_classes)
else:
    # Fail here with a clear message instead of a NameError on `model` below.
    raise ValueError("unsupported arch '{}': only 'small_vgg' is built by this notebook".format(arch))

# On GPU the network is wrapped in DataParallel, so the real module is
# reachable as `model.module` and state-dict keys carry a 'module.' prefix.
if use_cuda:
    model = torch.nn.DataParallel(model).cuda()
print(model)

cudnn.benchmark = True
print('    Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0))
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)

# Resume
title = 'cifar-10-' + arch
if resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isfile(resume), 'Error: no checkpoint directory found!'
    # `checkpoint` stays the output *directory*; it is used later by
    # save_checkpoint and savefig. The loaded dictionary gets its own name.
    checkpoint = os.path.dirname(resume)
    if use_cuda:
        resume_state = torch.load(resume)
        state_dict = resume_state['state_dict']
    else:
        resume_state = torch.load(resume, map_location=torch.device('cpu'))
        # Drop the 'module.' prefix so the keys match the bare (non-DataParallel)
        # model used on CPU.
        state_dict = {key.replace('module.', ''): value
                      for key, value in resume_state['state_dict'].items()}
    model.load_state_dict(state_dict)

    best_acc = resume_state['best_acc']
    start_epoch = resume_state['epoch']

    # The training loop appends to `logger`, so it must exist on this path too.
    # Continue an existing log file when there is one, otherwise start a new one.
    log_path = os.path.join(checkpoint, 'log.txt')
    if os.path.isfile(log_path):
        logger = Logger(log_path, title=title, resume=True)
    else:
        logger = Logger(log_path, title=title)
        logger.set_names(['Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.'])
else:
    logger = Logger(os.path.join(checkpoint, 'log.txt'), title=title)
    logger.set_names(['Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.'])

if conv_to_fc:

    print("==> converting Conv2d to Linear")
    print(summary(model, input_shape))

    # From here on every VGG instance runs the Linear-aware forward pass.
    VGG.forward = VGG.forward_fc

    # On GPU the network sits inside DataParallel, so reach through `.module`.
    net = model.module if use_cuda else model
    linear_device = "cuda:0" if use_cuda else None
    layers = net.features.children()
    x = torch.zeros(input_shape)
    if use_cuda:
        x = x.cuda()
    x = x[None, :, :, :]

    # Push a dummy batch through the original layers to learn the input shape
    # seen by each convolution while swapping it for its dense equivalent.
    current_input_shape = input_shape
    for ind_layer, layer in enumerate(layers):
        if isinstance(layer, nn.Conv2d):
            conv_matrix = convmatrix2d(layer.weight, current_input_shape, layer.padding)
            if use_cuda:
                conv_matrix = conv_matrix.cuda()
            dense = nn.Linear(conv_matrix.shape[1], conv_matrix.shape[0], bias=False, device=linear_device)
            dense.weight.data = conv_matrix
            net.features[ind_layer] = dense
        # `layer` is still the original module, so shapes follow the conv network.
        x = layer(x)
        current_input_shape = tuple(x.shape[1:])
    
    print("==> converted Conv2d to Linear")            
    print(summary(model, input_shape))

    #TODO: iterate though the layers and update the weights here

def accuracy(output, target, topk=(1,)):
    """Return the top-k accuracy (in percent) for every k in ``topk``.

    Args:
        output: (batch, classes) score tensor.
        target: (batch,) tensor of class indices.
        topk: iterable of k values.

    Returns:
        A list with one 0-dim tensor per requested k.
    """
    largest_k = max(topk)
    n_samples = target.size(0)

    # (largest_k, batch) matrix of predicted classes, best guess in row 0.
    _, top_idx = output.topk(largest_k, 1, True, True)
    top_idx = top_idx.t()
    hits = top_idx.eq(target.view(1, -1).expand_as(top_idx))

    scale = 100.0 / n_samples
    return [hits[:k].reshape(-1).float().sum(0).mul_(scale) for k in topk]

@torch.no_grad()
def test(testloader, model, criterion, epoch, use_cuda):
    """Evaluate ``model`` on ``testloader`` without tracking gradients.

    Args:
        testloader: iterable of (inputs, targets) batches.
        model: network to evaluate; it is switched to eval mode.
        criterion: loss function called as ``criterion(outputs, targets)``.
        epoch: unused; kept so existing call sites keep working.
        use_cuda: when true, batches are moved to the GPU and the forward pass
            runs under CUDA autocast.

    Returns:
        Tuple ``(average loss, average top-1 accuracy in percent)``.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    bar = Bar('Processing', max=len(testloader))

    for batch_idx, (inputs, targets) in enumerate(testloader):
        # measure data loading time
        data_time.update(time.time() - end)

        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()

        # compute output
        if use_cuda:
          with torch.cuda.amp.autocast():
            outputs = model(inputs)
        else:
          outputs = model(inputs)
        # No gradient scaling here: nothing is back-propagated during evaluation.
        loss = criterion(outputs, targets)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))

        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # plot progress
        bar.suffix  = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format(
                    batch=batch_idx + 1,
                    size=len(testloader),
                    data=data_time.avg,
                    bt=batch_time.avg,
                    total=bar.elapsed_td,
                    eta=bar.eta_td,
                    loss=losses.avg,
                    top1=top1.avg,
                    top5=top5.avg,
                    )
        bar.next()
    bar.finish()
    return (losses.avg, top1.avg)

# Evaluation-only mode: run one pass over the test set, report, and stop the
# cell via sys.exit() so the training code below is not executed.
if evaluate:
    print('\nEvaluation only')
    test_loss, test_acc = test(testloader, model, criterion, start_epoch, use_cuda)
    print(' Test Loss:  %.8f, Test Acc:  %.2f' % (test_loss, test_acc))
    sys.exit()

def train(trainloader, model, criterion, optimizer, epoch, use_cuda):
    """Train ``model`` for one pass over ``trainloader``.

    On CUDA the forward pass runs under autocast and the update goes through
    the notebook-level gradient ``scaler``; on CPU a plain backward/step is used.

    Args:
        trainloader: iterable of (inputs, targets) batches.
        model: network to train; it is switched to train mode.
        criterion: loss function called as ``criterion(outputs, targets)``.
        optimizer: optimizer stepping the model parameters.
        epoch: not used inside the function.
        use_cuda: selects the GPU/mixed-precision path.

    Returns:
        Tuple ``(average loss, average top-1 accuracy in percent)``.
    """
    model.train()

    batch_time, data_time = AverageMeter(), AverageMeter()
    losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
    end = time.time()

    n_batches = len(trainloader)
    bar = Bar('Processing', max=n_batches)
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        # time spent waiting for the batch
        data_time.update(time.time() - end)

        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()

        # forward pass (mixed precision on GPU)
        if use_cuda:
            with torch.cuda.amp.autocast():
                outputs = model(inputs)
        else:
            outputs = model(inputs)
        loss = criterion(outputs, targets)

        # running statistics, weighted by batch size
        n_samples = inputs.size(0)
        prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
        losses.update(loss.data.item(), n_samples)
        top1.update(prec1.item(), n_samples)
        top5.update(prec5.item(), n_samples)

        # backward pass and parameter update
        optimizer.zero_grad()
        if not use_cuda:
            loss.backward()
            optimizer.step()
        else:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        # time for the whole iteration
        batch_time.update(time.time() - end)
        end = time.time()

        # progress bar text
        bar.suffix  = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format(
            batch=batch_idx + 1,
            size=n_batches,
            data=data_time.avg,
            bt=batch_time.avg,
            total=bar.elapsed_td,
            eta=bar.eta_td,
            loss=losses.avg,
            top1=top1.avg,
            top5=top5.avg,
        )
        bar.next()
    bar.finish()
    return (losses.avg, top1.avg)

def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar'):
    """Serialise ``state`` into the ``checkpoint`` directory.

    The state is written with ``torch.save`` to ``checkpoint/filename``; when
    ``is_best`` is true that file is also copied to
    ``checkpoint/model_best.pth.tar``.
    """
    target = os.path.join(checkpoint, filename)
    torch.save(state, target)
    if not is_best:
        return
    shutil.copyfile(target, os.path.join(checkpoint, 'model_best.pth.tar'))

def adjust_learning_rate(optimizer):
    """Multiply the learning rate of every parameter group by 0.1, printing old and new values."""
    for group in optimizer.param_groups:
        print("\nOld learning rate: ", group['lr'])
        group['lr'] *= 0.1
        print("\nNew learning rate: ", group['lr'])

# Train and val
# With conv_to_fc disabled: train for the remaining epochs, evaluating, logging
# and checkpointing after each one. With conv_to_fc enabled: only evaluate the
# converted model once.
if not conv_to_fc:
  for epoch in range(start_epoch, epochs):
      # Step decay at the configured epochs.
      if epoch in decrease_lr_at_epochs:
          adjust_learning_rate(optimizer)

      print('\nEpoch: [%d | %d]' % (epoch + 1, epochs))

      train_loss, train_acc = train(trainloader, model, criterion, optimizer, epoch, use_cuda)
      test_loss, test_acc = test(testloader, model, criterion, epoch, use_cuda)

      # %d formats the accuracies as integers.
      print(' train acc: %d - test acc: %d' % (train_acc, test_acc))

      # append logger file
      logger.append([train_loss, test_loss, train_acc, test_acc])
      # save model
      is_best = test_acc > best_acc
      best_acc = max(test_acc, best_acc)
      # 'epoch' stores the next epoch to run, so it can seed start_epoch on resume.
      save_checkpoint({
              'epoch': epoch + 1,
              'state_dict': model.state_dict(),
              'acc': test_acc,
              'best_acc': best_acc,
              'optimizer' : optimizer.state_dict(),
          }, is_best, checkpoint=checkpoint)

  logger.close()
  logger.plot()
  savefig(os.path.join(checkpoint, 'log.eps'))

  print('Best acc:')
  print(best_acc)
else:
  test_loss, test_acc = test(testloader, model, criterion, 0, use_cuda)
  print('Acc:')
  print(test_acc)



/content/pytorch-classification
use_cuda=True
==> Preparing dataset cifar10
Files already downloaded and verified
Files already downloaded and verified
==> creating model 'small_vgg'
DataParallel(
  (module): VGG(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): Softplus(beta=1, threshold=20)
      (2): AvgPool2d(kernel_size=2, stride=2, padding=0)
      (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): Softplus(beta=1, threshold=20)
      (5): AvgPool2d(kernel_size=2, stride=2, padding=0)
      (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (7): Softplus(beta=1, threshold=20)
      (8): AvgPool2d(kernel_size=2, stride=2, padding=0)
      (9): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (10): Softplus(beta=1, threshold=20)
      (11): AvgPool2d(kernel_size=2, stride=2, padding=0)


In [None]:
from typing import Iterator
def net_param_iterator(model: nn.Module) -> Iterator:
  """Yield the sub-modules of ``model`` whose exact type is Conv2d, BatchNorm2d, ReLU or Linear.

  Modules of any other type (including ``model`` itself and subclasses of the
  listed types) are skipped; once iteration is exhausted the set of skipped
  types is printed.
  """
  yielded_types = (nn.Conv2d, nn.BatchNorm2d, nn.ReLU, nn.Linear)
  skipped = []
  for _, submodule in model.named_modules():
    kind = type(submodule)
    if kind not in yielded_types:
      skipped.append(kind)
      continue
    #TODO preparse params ?
    yield submodule
  print(set(skipped))


class MemristorQuant(object):
  """Temporarily replace the weights of selected layers by quantized, noisy copies.

  The full-precision weights are kept in ``saved_params`` (on CPU), the
  quantized noise-free weights in ``intermediate_params`` (on CPU), and
  ``actual_params`` references the live parameters of the model.

  Args:
    model: network whose layers are quantized in place.
    types_handled: exact module types whose ``weight`` is quantized.
    N: number of quantization steps between 0 and ``Gmax``.
    wmax_mode: how the normalising maximum is chosen: 'all' (over every
      handled layer), 'layerwise' (per tensor) or 'columnwise' (per column).
    Gmax: value the largest absolute weight is mapped to.
    std_noise: standard deviation of the Gaussian noise added by ``renoise``.
  """

  def __init__(self, model: nn.Module, types_handled = (nn.Linear,), N: int = 100, wmax_mode='all', Gmax=0.1, std_noise:float=1.) -> None:
    super().__init__()
    self.model = model
    self.saved_params = []
    self.actual_params = []
    self.intermediate_params = {}
    for m in model.modules():
      if type(m) in types_handled:
        self.saved_params.append(m.weight.data.clone().cpu())
        self.actual_params.append(m.weight)
    self.quanted = False
    self.N = N
    self.wmax_mode = wmax_mode
    self.Gmax = Gmax
    self.std_noise = std_noise
    print(f"Initialized memquant with {len(self.saved_params)} parameters quantified")

  def __call__(self, input):
    return self.forward(input)

  def forward(self, input):
    """Return ``(noisy quantized output, detached full-precision output)`` for ``input``.

    The weights are restored afterwards and the mean squared difference between
    both outputs is stored in ``self.MSE``.
    """
    # WARNING: do not use this forward for learning, it does 2 forward with 1 batch
    res_reliable = self.model(input).detach()
    if not self.quanted:
      self.quant()
    self.renoise()
    res = self.model(input)
    self.unquant()
    self.MSE = torch.mean(torch.square(res - res_reliable))
    return res, res_reliable

  @staticmethod
  def memory_usage(tensor):
    '''Return memory usage in MB'''
    return tensor.element_size() * tensor.nelement() / 1e6

  def memory_info(self):
    """Print shape, dtype, device and size of the saved and live version of every tensor."""
    for i in range(len(self.saved_params)):
      print(f'Tensor {i} of shape {self.saved_params[i].shape}')
      print(f'Saved version dtype {self.saved_params[i].dtype} on {self.saved_params[i].device} (taking {self.memory_usage(self.saved_params[i])}) MB')
      print(f'Current version dtype {self.actual_params[i].dtype} on {self.actual_params[i].device} (taking {self.memory_usage(self.actual_params[i])}) MB')

  def quant(self):
    """Save the current weights and overwrite them in place with their quantized values."""
    if self.quanted:
      self.unquant()
    # Refresh every saved copy first so that the 'all' mode maximum is computed
    # from the current weights of all layers, not from stale copies.
    for saved, param in zip(self.saved_params, self.actual_params):
      saved.copy_(param.data)
    for i, param in enumerate(self.actual_params):
      weights = param.data
      self._quantize(weights)  # in place on the live parameter
      self.intermediate_params[i] = weights.clone().cpu()
    self.quanted = True

  def renoise(self):
    """Reset the live weights to the quantized values and add fresh Gaussian noise."""
    # Quantize first when needed: writing quantized values while `quanted` is
    # False would let a later quant() save them as the full-precision reference.
    if not self.quanted:
      self.quant()
    for i, quantized in self.intermediate_params.items():
      weights = self.actual_params[i].data
      weights.copy_(quantized)
      weights += torch.normal(mean=0., std=self.std_noise, size=weights.shape, device=weights.device)

  def unquant(self):
    """Restore the saved full-precision weights if the model is currently quantized."""
    if self.quanted:
      for i in range(len(self.saved_params)):
        self.actual_params[i].data.copy_(self.saved_params[i].to(self.actual_params[i].data))
    self.quanted = False

  @torch.no_grad()
  def _quantize(self, tensor) -> None:
    """Quantize ``tensor`` in place: scale by Gmax/Wmax, then snap to bin centres of width Gmax/N."""
    wmax = self.Wmax(tensor)
    # An all-zero tensor/column has no scale; use 1 to avoid dividing by zero.
    wmax = torch.where(wmax > 0, wmax, torch.ones_like(wmax))
    c = self.Gmax / wmax
    delta = self.Gmax / self.N
    tensor *= c
    tensor /= delta
    torch.floor_(tensor)
    tensor += 0.5
    tensor *= delta

  @torch.no_grad()
  def Wmax(self, tensor):
    """Return the normalising maximum absolute weight for ``tensor`` according to ``wmax_mode``.

    Raises:
      ValueError: if ``wmax_mode`` is not 'all', 'layerwise' or 'columnwise'.
    """
    assert len(tensor.shape) == 2, 'Only works for 2d tensors !'
    if self.wmax_mode == 'all':
      return max([torch.max(torch.abs(t)) for t in self.saved_params])
    elif self.wmax_mode == 'layerwise':
      return torch.max(torch.abs(tensor))
    elif self.wmax_mode == 'columnwise':
      # torch.max with `dim` returns (values, indices); only the values are needed.
      return torch.max(torch.abs(tensor), dim=0).values
    raise ValueError(f"unknown wmax_mode {self.wmax_mode!r}; expected 'all', 'layerwise' or 'columnwise'")

# Ex usage
# Wrap the model built in the previous cell; by default only nn.Linear weights
# are handled.
quanter = MemristorQuant(model)
# To quant
quanter.quant()
# To noise
quanter.renoise()
# To unquant
quanter.unquant()

# Report dtype, device and size of the saved and live weight tensors.
quanter.memory_info()

# List the modules selected by net_param_iterator.
for i in net_param_iterator(model):
  print(i)

Initialized memquant with 6 parameters quantified


In [None]:
batch, _ = next(iter(testloader))
print(batch.shape)

In [None]:
%%timeit
model(batch)

In [None]:
%%timeit
quanter.quant()
model(batch)
quanter.unquant()

In [None]:
%%timeit
quanter.quant()

In [None]:
%%timeit
quanter.renoise()