In [42]:
import torch

from factorisation.densenet import get_student_densenet, get_standard_densnet121
from utils import (
    calculate_score,
    count_nonzero_parameters,
    get_device,
    get_macs,
    get_test_cifar10_dataloader,
    load_trained_model,
    test,
)

In [38]:
# Shared setup for every evaluation cell in this notebook.
# `device` is where the student models are moved (`.to(device)`), and
# `test_loader` is the data handed to `test(...)` for the accuracy checks.
# NOTE(review): presumably the CIFAR-10 test split, going by the helper's name — confirm in utils.
device = get_device()
test_loader = get_test_cifar10_dataloader()

## Teacher

In [33]:
# load_trained_model() returns a pair; only the first element (the model) is kept.
teacher, _ = load_trained_model()
# Reference measurements taken on the teacher. They are passed as the last two
# arguments of every calculate_score(...) call in this notebook.
PARAM_REF = count_nonzero_parameters(teacher)
OPS_REF = get_macs(teacher)

## Student Model (growth rate 24)

In [48]:
# Restore the distilled student from its training checkpoint.
# map_location=device remaps the saved tensors onto the device selected for this
# notebook; without it torch.load tries to restore them onto the device they were
# saved from, which fails when a GPU-trained checkpoint is opened on a CPU-only
# machine. (Assumes get_device() returns a torch.device or device string, as its
# use in `.to(device)` suggests.)
res = torch.load(
    "train_checkpoint/model_distill_train_student_size_24.pth",
    map_location=device,
)
# Student built with the helper's default configuration.
student_24 = get_student_densenet()
student_24.to(device)
# The checkpoint is a dict: "net" holds the state dict, "acc" the stored accuracy.
student_24.load_state_dict(res["net"])
print(res['acc'])

91.32


In [49]:
# Size and compute cost of student_24, fed into calculate_score(...) further down.
# This cell sits before the `.half()` cast, so in a top-to-bottom run the
# measurements are taken on the model as loaded from the checkpoint.
w_24 = count_nonzero_parameters(student_24)
f_24 = get_macs(student_24)

### Half Quantisation

In [41]:
# Cast the model's floating-point parameters and buffers to float16.
# `.half()` modifies student_24 in place and returns it, so every later cell
# (and any earlier cell that is re-run) sees the half-precision model.
# NOTE(review): as the last expression of the cell this echoes the full module
# repr; a trailing `;` would suppress that output.
student_24.half()

DenseNet(
  (conv1): Conv2d(3, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (dense1): Sequential(
    (0): Bottleneck(
      (bn1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): Conv2d(48, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(96, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
    (1): Bottleneck(
      (bn1): BatchNorm2d(72, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): Conv2d(72, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(96, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
    (2): Bottleneck(
      (bn1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True

In [None]:
# Run the project's `test` helper on the half-precision student_24 over test_loader.
# NOTE(review): `half=True` presumably makes the helper cast the inputs to float16
# as well, and the return value is presumably the test accuracy — confirm in utils.test.
# NOTE(review): this cell has no execution count in the saved notebook, so no
# accuracy value was recorded for it.
half_quant_distill_24_acc = test(test_loader, student_24, half=True)
half_quant_distill_24_acc

In [59]:
# Evaluate calculate_score for student_24: w_24 / f_24 are the student's
# nonzero-parameter count and MACs, PARAM_REF / OPS_REF the same measurements
# taken on the teacher.
# NOTE(review): the literals 0, 0, 16, 16 are positional magic numbers; presumably
# two sparsity settings followed by 16-bit weight/activation widths matching the
# `.half()` cast — confirm against the signature of utils.calculate_score.
half_quant_distill_24_score = calculate_score(
    0, 0, 16, 16, w_24, f_24, PARAM_REF, OPS_REF
)
half_quant_distill_24_score

0.2903556607322233

## Student Model (growth rate 32)

In [51]:
# Restore the distilled growth-rate-32 student from its training checkpoint.
# map_location=device remaps the saved tensors onto the device selected for this
# notebook; without it torch.load tries to restore them onto the device they were
# saved from, which fails when a GPU-trained checkpoint is opened on a CPU-only
# machine. (Assumes get_device() returns a torch.device or device string, as its
# use in `.to(device)` suggests.)
res = torch.load(
    "train_checkpoint/model_distill_train_student_size_32.pth",
    map_location=device,
)
student_32 = get_student_densenet(growth_rate=32)
student_32.to(device)
# The checkpoint is a dict: "net" holds the state dict, "acc" the stored accuracy.
student_32.load_state_dict(res["net"])
print(res['acc'])

91.71


In [52]:
# Size and compute cost of student_32, fed into calculate_score(...) further down.
# This cell sits before the `.half()` cast, so in a top-to-bottom run the
# measurements are taken on the model as loaded from the checkpoint.
w_32 = count_nonzero_parameters(student_32)
f_32 = get_macs(student_32)

### Half Quantisation

In [55]:
# Cast the model's floating-point parameters and buffers to float16.
# `.half()` modifies student_32 in place and returns it, so every later cell
# (and any earlier cell that is re-run) sees the half-precision model.
# NOTE(review): as the last expression of the cell this echoes the full module
# repr; a trailing `;` would suppress that output.
student_32.half()

DenseNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (dense1): Sequential(
    (0): Bottleneck(
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
    (1): Bottleneck(
      (bn1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): Conv2d(96, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
    (2): Bottleneck(
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_sta

In [None]:
# Run the project's `test` helper on the half-precision student_32 over test_loader.
# NOTE(review): `half=True` presumably makes the helper cast the inputs to float16
# as well, and the return value is presumably the test accuracy — confirm in utils.test.
# NOTE(review): this cell has no execution count in the saved notebook, so no
# accuracy value was recorded for it.
half_quant_distill_32_acc = test(test_loader, student_32, half=True)
half_quant_distill_32_acc

In [63]:
# Evaluate calculate_score for student_32: w_32 / f_32 are the student's
# nonzero-parameter count and MACs, PARAM_REF / OPS_REF the same measurements
# taken on the teacher.
# NOTE(review): the literals 0, 0, 16, 16 are positional magic numbers; presumably
# two sparsity settings followed by 16-bit weight/activation widths matching the
# `.half()` cast — confirm against the signature of utils.calculate_score.
half_quant_distill_32_score = calculate_score(
    0, 0, 16, 16, w_32, f_32, PARAM_REF, OPS_REF
)
half_quant_distill_32_score

0.5147654779002551