## Demonstration of per-layer weight sharing on LeNet-5

In [1]:
import random

import numpy as np
import pandas as pd
import torch
import torch.nn as nn

from data.mnist import MnistDataset
from models.lenet.lenet import LeNet5
from utils.train import *
from utils.quantize import *
from utils.weight_sharing import *
from utils.plot import *

Parameters

In [2]:
# --- Training hyper-parameters ---
LEARNING_RATE = 0.0001
BATCH_SIZE = 32
N_CLASSES = 10   # number of output classes passed to LeNet5 (MNIST digits)
DEVICE = None    # forwarded to train_net / get_accuracy; NOTE(review): presumably None = auto-select device — confirm
EPOCHS = 100

# --- Reproducibility ---
# Seed every RNG that the dataset split, training or weight-sharing code may draw
# from, so re-running the notebook gives the same numbers.
# NOTE(review): GPU kernels may still be nondeterministic; see the PyTorch reproducibility notes.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# --- Paths ---
NET_PATH = './models/lenet/saves/lenet.save'  # passed to get_trained (trained-network save file)
SAVE_DATA = './results/test_share.csv'        # NOTE(review): not used in the cells shown here

Setting up the components

In [3]:
dataset = MnistDataset(BATCH_SIZE, './data', val_split=0.5)
model = LeNet5(N_CLASSES)
criterion = nn.CrossEntropyLoss()


def lam_opt(mod):
    """Create a fresh Adam optimizer over ``mod``'s parameters using LEARNING_RATE."""
    return torch.optim.Adam(mod.parameters(), lr=LEARNING_RATE)


# lam_train / lam_test look up the notebook-level ``model`` at call time (late
# binding). If ``model`` were rebound to a new object, they would act on the new
# one while ``ws_controller`` keeps the object it was constructed with.
def lam_train(opt, epochs):
    """Train the notebook-level ``model`` with optimizer ``opt`` for ``epochs`` epochs via train_net."""
    return train_net(model, criterion, opt, dataset, epochs, device=DEVICE)


def lam_test():
    """Evaluate the notebook-level ``model`` on ``dataset.test_dl`` via get_accuracy."""
    return get_accuracy(model, dataset.test_dl, DEVICE)


# The weight-sharing controller receives the optimizer factory and the
# train/test callbacks.
ws_controller = WeightShare(model, lam_opt, lam_train, lam_test)
ws_controller.print_layers_info()

layer_name #weights #bias w_locked CR
feature_extractor.0 150 6 False 1.00
feature_extractor.3 2400 16 False 1.00
feature_extractor.6 48000 120 False 1.00
classifier.0 10080 84 False 1.00
classifier.2 840 10 False 1.00
Sum num weights, bias:  61470 236
Compression rate 1.00


Getting the trained network

In [4]:
# Positional settings consumed by get_trained.
train_settings = [
    criterion,
    lam_opt(model),  # optimizer
    dataset,
    EPOCHS,
    DEVICE,
    1,     # NOTE(review): meaning defined by get_trained's settings list — confirm and name it
    True,  # NOTE(review): meaning defined by get_trained's settings list — confirm and name it
]
optimizer = train_settings[1]

get_trained(model, NET_PATH, train_settings)

Test: sharing the whole network with the given parameters

In [5]:
# Reload the trained weights and reset the controller before sharing. This makes
# the cell idempotent: re-running it no longer shares an already-shared model.
# It also mirrors the two experiment cells below.
get_trained(model, NET_PATH, train_settings)
ws_controller.reset()
ws_controller.share([20, 20, 20, 20, 20], [0], [0, 0, 0, 0, 0])

{'accuracy': 0.9846,
 'compression': 1.4442953020134228,
 'times': {'train': 0,
  'share': 0.08937478065490723,
  'test': 0.8639969825744629}}

In [6]:
# Start from the trained weights again, reset the controller (presumably dropping
# the previous share state), then share all five layers with the 'f2' option.
get_trained(model, NET_PATH, train_settings)
ws_controller.reset()
ws_controller.share([20] * 5, [0], [0] * 5, ['f2'] * 5)

{'accuracy': 0.9846,
 'compression': 1.6205128205128205,
 'times': {'train': 0,
  'share': 0.09102702140808105,
  'test': 0.8078582286834717}}

In [7]:
# Same experiment as the previous cell, but with the 'f1' option for every layer.
# Weights are reloaded and the controller reset so the run starts from the trained net.
get_trained(model, NET_PATH, train_settings)
ws_controller.reset()
ws_controller.share([20] * 5, [0], [0] * 5, ['f1'] * 5)

{'accuracy': 0.985,
 'compression': 1.7504950495049507,
 'times': {'train': 0,
  'share': 0.11257600784301758,
  'test': 0.9388778209686279}}