
This notebook is used to present a toy trainer for an LeNet+DyN model.
The complete training procedure involves multiple buffer mechanisms for permutating Dyn settings on the dev set.
This procedure will be presented as an online open-sourced architecture.


In [1]:

import time
import random
import numpy as np
from datetime import datetime 

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from torchvision import datasets, transforms

import matplotlib.pyplot as plt

import copy

# check device
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(DEVICE)


cpu


In [2]:

# parameters
LEARNING_RATE = 0.001
BATCH_SIZE = 32

IMG_SIZE = 32
N_CLASSES = 10



In [3]:

# define transforms
data_transforms = transforms.Compose([transforms.Resize((32, 32)),
                                 transforms.ToTensor()])

# download and create datasets
train_dataset = datasets.MNIST(root='datasets/MNIST', 
                               train=True, 
                               transform=data_transforms,
                               download=True)

valid_dataset = datasets.MNIST(root='datasets/MNIST', 
                               train=False, 
                               transform=data_transforms)

# define the data loaders
train_loader = DataLoader(dataset=train_dataset, 
                          batch_size=BATCH_SIZE, 
                          shuffle=True)

valid_loader = DataLoader(dataset=valid_dataset, 
                          batch_size=BATCH_SIZE, 
                          shuffle=False)


In [4]:

def train(train_loader, model, criterion, optimizer, device):
    '''
    Function for the training step of the training loop
    '''

    model.train()
    running_loss = 0
    
    for X, y_true in train_loader:

        optimizer.zero_grad()
        
        X = X.to(device)
        y_true = y_true.to(device)
    
        # Forward pass
        y_hat, _ = model(X) 
        loss = criterion(y_hat, y_true) 
        running_loss += loss.item() * X.size(0)

        # Backward pass
        loss.backward()
        optimizer.step()
        
    epoch_loss = running_loss / len(train_loader.dataset)
    return model, optimizer, epoch_loss



def validate(valid_loader, model, criterion, device):
    '''
    Function for the validation step of the training loop
    '''
   
    model.eval()
    running_loss = 0
    
    for X, y_true in valid_loader:
    
        X = X.to(device)
        y_true = y_true.to(device)

        # Forward pass and record loss
        y_hat, _ = model(X) 
        loss = criterion(y_hat, y_true) 
        running_loss += loss.item() * X.size(0)

    epoch_loss = running_loss / len(valid_loader.dataset)
        
    return model, epoch_loss


def get_accuracy(model, data_loader, device):
    '''
    Function for computing the accuracy of the predictions over the entire data_loader
    '''
    
    correct_pred = 0 
    n = 0
    
    with torch.no_grad():
        model.eval()
        for X, y_true in data_loader:

            X = X.to(device)
            y_true = y_true.to(device)

            _, y_prob = model(X)
            _, predicted_labels = torch.max(y_prob, 1)

            n += y_true.size(0)
            correct_pred += (predicted_labels == y_true).sum()

    return correct_pred.float() / n


def training_loop(model, criterion, optimizer, train_loader, valid_loader, epochs, device, print_every=1):
    '''
    Function defining the entire training loop
    '''
    #global c_dyn_model
    # set objects for storing metrics
    best_val = 0
    best_loss = 1e10
    train_losses = []
    valid_losses = []
 
    # Train model
    for epoch in range(0, epochs):

        # training
        model, optimizer, train_loss = train(train_loader, model, criterion, optimizer, device)
        train_losses.append(train_loss)

        # validation
        with torch.no_grad():
            model, valid_loss = validate(valid_loader, model, criterion, device)
            valid_losses.append(valid_loss)

        if epoch % print_every == (print_every - 1):
            
            train_acc = get_accuracy(model, train_loader, device=device)
            valid_acc = get_accuracy(model, valid_loader, device=device)
                
#             print(f'{datetime.now().time().replace(microsecond=0)} --- '
#                   f'Epoch: {epoch}\t'
#                   f'Train loss: {train_loss:.4f}\t'
#                   f'Valid loss: {valid_loss:.4f}\t'
#                   f'Train accuracy: {100 * train_acc:.2f}\t'
#                   f'Valid accuracy: {100 * valid_acc:.2f}')
    
    return model, optimizer, (train_losses, valid_losses)


class LeNet5(nn.Module):
    def __init__(self, num_classes):
        super(LeNet5, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0),
            nn.BatchNorm2d(6),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size = 2, stride = 2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size = 2, stride = 2))
        self.fc = nn.Linear(400, 120)
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(120, 84)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(84, num_classes)
        
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        out = self.relu(out)
        out = self.fc1(out)
        out = self.relu1(out)
        out = self.fc2(out)
        probs = F.softmax(out, dim=1)
        return out, probs





In [5]:

def compute_stress(dyn_model, c_model):

    FT_convs, rec_CF0, rec_CF1, rec_CF2 = dyn_model.forward(c_model)
    
    loss_W = torch.sum(((rec_CF0-c_model.state_dict()['fc.weight'])**2)/rec_CF0.numel())
    loss_W += torch.sum(((rec_CF1-c_model.state_dict()['fc1.weight'])**2)/rec_CF1.numel())
    loss_W += torch.sum(((rec_CF2-c_model.state_dict()['fc2.weight'])**2)/rec_CF2.numel())
    
    for row_id in range(5):
        for col_id in range(5):
            loss_W += torch.sum(((FT_convs[row_id*5+col_id]-c_model.state_dict()['layer2.0.weight'][:,:,row_id,col_id])**2)/(16*6))

    return loss_W


def update_DyNs(dyn_model, c_model, _optim, dyn_epochs, loss_thres=0.1):

    global c_dyn_model, cur_best_val
    raw_valAcc = get_accuracy(c_model, valid_loader, device=DEVICE)
    
    for _ep in range(dyn_epochs):
        loss_W = compute_stress(dyn_model, c_model)
        _optim.zero_grad()
        loss_W.backward()
        _optim.step()
        if loss_W.item() < loss_thres: return dyn_model, _ep, loss_W.item()
            
        if _ep %(dyn_epochs//10)==0:
            rec_model = recover_LeNet(dyn_model, c_model, _prec=prec_id*1e-4)
            valid_acc = get_accuracy(rec_model, valid_loader, device=DEVICE)
            if valid_acc.item() > cur_best_val:
                cur_best_val = valid_acc.item()
                c_dyn_model = copy.deepcopy(dyn_model)
            print(datetime.now().time().replace(microsecond=0), _ep, loss_W.item(), valid_acc.item())
            if valid_acc > raw_valAcc: return dyn_model, _ep, loss_W.item()

    return dyn_model, _ep, loss_W.item()


def path_integral(dyn_model, _prec=0):
        
        FT_convs = []
        _scale = dyn_model._scale
        
        for row_id in range(5):
            for col_id in range(5):
                inFT_Qs = dyn_model.inFT_Qs[row_id]
                ouFT_Qs = dyn_model.ouFT_Qs[col_id]
                
                if _prec != 0:
                    inFT_Qs = _prec*(torch.div(inFT_Qs, _prec, rounding_mode='floor'))
                    ouFT_Qs = _prec*(torch.div(ouFT_Qs, _prec, rounding_mode='floor'))
                
                lambdas_FT = dyn_model.lambdas_FT[row_id*5+col_id]
                
                FT_convs.append(torch.sum((_scale/(torch.cdist(inFT_Qs, ouFT_Qs)))*lambdas_FT, 0))
            
        lambdas_ih1CF = dyn_model.lambdas_ih1CF
        lambdas_h12CF = dyn_model.lambdas_h12CF
        lambdas_h2oCF = dyn_model.lambdas_h2oCF
        
        inCF_Qs = dyn_model.inCF_Qs
        h1CF_Qs = dyn_model.h1CF_Qs
        h2CF_Qs = dyn_model.h2CF_Qs
        ouCF_Qs = dyn_model.ouCF_Qs
        
        if _prec != 0:            
            inCF_Qs = _prec*(torch.div(inCF_Qs, _prec, rounding_mode='floor'))
            h1CF_Qs = _prec*(torch.div(h1CF_Qs, _prec, rounding_mode='floor'))
            h2CF_Qs = _prec*(torch.div(h2CF_Qs, _prec, rounding_mode='floor'))
            ouCF_Qs = _prec*(torch.div(ouCF_Qs, _prec, rounding_mode='floor'))
            
        rec_CF0 = torch.sum((_scale/(torch.cdist(h1CF_Qs, inCF_Qs)))*lambdas_ih1CF, 0)
        rec_CF1 = torch.sum((_scale/(torch.cdist(h2CF_Qs, h1CF_Qs)))*lambdas_h12CF, 0)
        rec_CF2 = torch.sum((_scale/(torch.cdist(ouCF_Qs, h2CF_Qs)))*lambdas_h2oCF, 0)

        return FT_convs, rec_CF0, rec_CF1, rec_CF2



def recover_LeNet(dyn_model, _model, _prec=0):
    
    c_model = copy.deepcopy(_model)
    FT_convs, rec_CF0, rec_CF1, rec_CF2 = path_integral(dyn_model, _prec=_prec)
    
    c_model.state_dict()['fc.weight'].copy_(rec_CF0)
    c_model.state_dict()['fc1.weight'].copy_(rec_CF1)
    c_model.state_dict()['fc2.weight'].copy_(rec_CF2)
    
    for row_id in range(5):
        for col_id in range(5):
            conv_param = FT_convs[row_id*5+col_id]
            c_model.state_dict()['layer2.0.weight'][:,:,row_id,col_id].copy_(conv_param)
    
    return c_model




In [6]:


class leNetDyNMat(nn.Module):
    def __init__(self, num_ftQs, num_cfQs, q_dim, _scale):
        super(leNetDyNMat, self).__init__() 

        self.q_dim = q_dim
        
        self.inFT_Qs = nn.Parameter(torch.rand(5, num_ftQs, 16, q_dim, device=DEVICE))
        self.ouFT_Qs = nn.Parameter(torch.rand(5, num_ftQs, 6, q_dim, device=DEVICE))
        
        self.lambdas_FT = nn.Parameter(torch.randn(25, num_ftQs, 1, 1, device=DEVICE))
        
        self.inCF_Qs = nn.Parameter(torch.rand(num_cfQs, 400, q_dim, device=DEVICE))
        self.h1CF_Qs = nn.Parameter(torch.rand(num_cfQs, 120, q_dim, device=DEVICE))
        self.h2CF_Qs = nn.Parameter(torch.rand(num_cfQs, 84, q_dim, device=DEVICE))
        self.ouCF_Qs = nn.Parameter(torch.rand(num_cfQs, 10, q_dim, device=DEVICE))
        
        self.lambdas_ih1CF = nn.Parameter(torch.randn(num_cfQs, 1, 1, device=DEVICE))
        self.lambdas_h12CF = nn.Parameter(torch.randn(num_cfQs, 1, 1, device=DEVICE))
        self.lambdas_h2oCF = nn.Parameter(torch.randn(num_cfQs, 1, 1, device=DEVICE))
        
        self._scale = _scale
        self.relu = nn.ReLU()
        
        self.raw_numParams = 5*(16+6)*num_ftQs*q_dim + q_dim*num_cfQs*(400+120+84+10)+num_cfQs*3+150+5*6+5*16+120+84+10
        self.com_numParams = 5*(16+6)*num_ftQs + num_cfQs*(400+120+84+10)+num_cfQs*3+150+5*6+5*16+120+84+10
        
    def forward(self, c_model):
        
        FT_convs = []
        
        for row_id in range(5):
            for col_id in range(5):
                FT_convs.append(torch.sum((self._scale/(torch.cdist(self.inFT_Qs[row_id], self.ouFT_Qs[col_id])))\
                          *self.lambdas_FT[row_id*5+col_id], 0))
                
        rec_CF0 = torch.sum((self._scale/(torch.cdist(self.h1CF_Qs, self.inCF_Qs)))*self.lambdas_ih1CF, 0)
        rec_CF1 = torch.sum((self._scale/(torch.cdist(self.h2CF_Qs, self.h1CF_Qs)))*self.lambdas_h12CF, 0)
        rec_CF2 = torch.sum((self._scale/(torch.cdist(self.ouCF_Qs, self.h2CF_Qs)))*self.lambdas_h2oCF, 0)

        return FT_convs, rec_CF0, rec_CF1, rec_CF2
    



In [7]:

criterion = nn.CrossEntropyLoss()


In [8]:

buffer_model = LeNet5(N_CLASSES).to(DEVICE)
buffer_optim = torch.optim.Adam(buffer_model.parameters(), lr=LEARNING_RATE)


In [9]:

dyn_model = leNetDyNMat(4,3,9,2)
dyn_optim = torch.optim.Adam(dyn_model.parameters(), lr=1e-3)


In [10]:

c_dyn_model = leNetDyNMat(4,3,9,2)
cur_best_val = 0
prec_id = 10
NN_update_epochs = 1
DyN_update_epochs = 2000

for dyn_batch in range(20):
    buffer_model, optimizer, _ = training_loop(buffer_model, criterion, buffer_optim, 
                                               train_loader, valid_loader, NN_update_epochs, DEVICE)
    
    dyn_model, dyn_ep, dyn_loss = update_DyNs(dyn_model, buffer_model, dyn_optim, DyN_update_epochs, loss_thres=1e-5)
    buffer_model = recover_LeNet(dyn_model, buffer_model, _prec=prec_id*1e-4)
    buffer_optim = torch.optim.Adam(buffer_model.parameters(), lr=LEARNING_RATE)
    
    valid_acc = get_accuracy(buffer_model, valid_loader, device=DEVICE)
    
    print(dyn_batch, valid_acc.item(), '--- Best ValAcc:', cur_best_val, '--- Raw NumParams:', dyn_model.raw_numParams, '--- Compressed NumParams:', dyn_model.com_numParams)



17:37:53 0 501.8808898925781 0.09749999642372131
17:37:57 200 127.29206085205078 0.07670000195503235
17:38:00 400 56.02415466308594 0.07840000092983246
17:38:04 600 30.76397705078125 0.11010000109672546
17:38:07 800 18.857959747314453 0.12189999967813492
17:38:11 1000 11.801366806030273 0.13179999589920044
17:38:15 1200 6.245660781860352 0.10740000009536743
17:38:18 1400 2.774937152862549 0.08470000326633453
17:38:22 1600 1.1539548635482788 0.08839999884366989
17:38:26 1800 0.5317487716674805 0.10409999638795853
0 0.10010000318288803 --- Best ValAcc: 0.13179999589920044 --- Raw NumParams: 21021 --- Compressed NumParams: 2765
17:39:00 0 0.07143969088792801 0.10100000351667404
17:39:04 200 0.05131681263446808 0.40560001134872437
17:39:07 400 0.046170923858881 0.6087999939918518
17:39:11 600 0.04262854903936386 0.6560999751091003
17:39:15 800 0.03984040394425392 0.698199987411499
17:39:19 1000 0.0375080369412899 0.7487000226974487
17:39:22 1200 0.035486627370119095 0.7567999958992004
17:3

In [11]:

buffer_model = recover_LeNet(c_dyn_model, buffer_model, _prec=prec_id*1e-4)
valid_acc = get_accuracy(buffer_model, valid_loader, device=DEVICE)
print(valid_acc.item())


0.9922999739646912
