In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# define training settings
NUM_EPOCHS = 150  # passed to the LR scheduler and to experiment.train
BATCH_SIZE = 128  # used for the model summary input size and by Experiment
MULTIPLIER = 0.1  # forwarded to Experiment as `multiplier`; semantics defined there

In [None]:
import torch

# Index of the GPU to use when CUDA is available.
CUDA_DEVICE = 3

# Only touch the CUDA runtime when it is actually usable. Calling
# torch.cuda.set_device() unconditionally fails on machines without CUDA,
# which made the CPU fallback unreachable.
if torch.cuda.is_available():
    torch.cuda.set_device(CUDA_DEVICE)
    device = torch.device(f"cuda:{CUDA_DEVICE}")
else:
    device = torch.device("cpu")

In [None]:
from doren_bnn.mobilenet import MobileNet, NetType
from torchinfo import summary

# Build the network used for training (the test-time section rebuilds it with
# test=True). The first positional argument 3 is presumably the number of
# input channels -- confirm against MobileNet's signature.
NETTYPE = NetType.XNORPP_SCA
model = MobileNet(3, num_classes=10, nettype=NETTYPE).to(device)

# Display a torchinfo summary for an input of shape (BATCH_SIZE, 3, 32, 32).
summary(model, input_size=(BATCH_SIZE, 3, 32, 32))

In [None]:
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW

# from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.optim.lr_scheduler import CosineAnnealingLR

# Cross-entropy loss, moved to the same device as the model.
criterion = CrossEntropyLoss()
criterion = criterion.to(device)

# AdamW over all model parameters.
optimizer = AdamW(params=model.parameters(), lr=1e-3, weight_decay=5e-6)

# Cosine annealing with T_max set to the total number of epochs.
# Alternative kept for reference:
#   CosineAnnealingWarmRestarts(optimizer, 25, eta_min=1e-4)
scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

In [None]:
# hyperparameters
ALPHA = 0.01  # passed to each Conv2d_XnorPP_SCA layer's wdr(alpha) by the regulariser
LAMBDA = 1e-4  # base weight that scales the summed regulariser terms

In [None]:
from doren_bnn.utils import Dataset, Experiment

# The experiment ID embeds ALPHA and LAMBDA, so runs with different regulariser
# settings get different IDs.
EXPERIMENT_ID = f"archived/mobilenet-xnorpp-sca-cifar10-full-{ALPHA}-{LAMBDA}"
print(EXPERIMENT_ID)
# Experiment object used below for checkpoint loading, training and testing.
experiment = Experiment(
    EXPERIMENT_ID, Dataset.CIFAR10, BATCH_SIZE, multiplier=MULTIPLIER
)

In [None]:
# Load a previous checkpoint for this experiment into model, optimizer and
# scheduler. Comment this line out if you wish to train from scratch.
# NOTE(review): behaviour when no checkpoint exists is not visible here -- confirm.
experiment.load_checkpoint(model, optimizer, scheduler)

# Training

In [None]:
from doren_bnn.xnorpp_sca import Conv2d_XnorPP_SCA

# Number of initial epochs with a zero regulariser weight; also the step width
# used by the schedule below.
LAMB_PERIOD = 50


def _alpha_schedule(_):
    """Return ALPHA for every epoch."""
    return ALPHA


def _lamb_schedule(epoch):
    """Return the regulariser weight for ``epoch``.

    The weight is 0 while ``epoch < LAMB_PERIOD``. Afterwards it is LAMBDA
    divided by ten once for every whole LAMB_PERIOD that still fits between
    ``epoch`` and NUM_EPOCHS.
    """
    if epoch < LAMB_PERIOD:
        return 0
    periods_left = (NUM_EPOCHS - epoch) // LAMB_PERIOD
    return LAMBDA * (10 ** -periods_left)


# Maps each hyperparameter name to a function of the epoch.
hyperparams_dict = {"alpha": _alpha_schedule, "lamb": _lamb_schedule}


def regulariser(model=None, alpha: float = ALPHA, lamb: float = LAMBDA):
    """Return the weighted sum of the ``wdr`` terms of all XNOR++ SCA conv layers.

    Args:
        model: Module whose ``modules()`` are scanned for ``Conv2d_XnorPP_SCA``
            layers. It is required; the ``None`` default is kept only so the
            signature stays compatible with existing callers.
        alpha: Value forwarded to each matching layer's ``wdr(alpha)``.
        lamb: Scalar multiplied with the summed terms.

    Returns:
        ``lamb`` times the sum of ``layer.wdr(alpha)`` over the matching
        layers, or ``lamb * 0`` when the model contains none.

    Raises:
        ValueError: If ``model`` is None.
    """
    if model is None:
        # Fail with a clear message instead of the opaque
        # "'NoneType' object has no attribute 'modules'".
        raise ValueError("regulariser() requires a model")

    return lamb * sum(
        layer.wdr(alpha)
        for layer in model.modules()
        if isinstance(layer, Conv2d_XnorPP_SCA)
    )

In [None]:
# Run training for NUM_EPOCHS epochs, passing in the per-epoch hyperparameter
# schedules and the regulariser defined above.
experiment.train(
    model,
    criterion,
    optimizer,
    scheduler,
    NUM_EPOCHS,
    device=device,
    hyperparams_dict=hyperparams_dict,
    regulariser=regulariser,
)

# Test-time Inference

In [None]:
from doren_bnn.mobilenet import MobileNet, NetType
from torchinfo import summary

# Rebuild the network with test=True and rebind `model` to it; the weights are
# loaded from the checkpoint in the next cell.
NETTYPE = NetType.XNORPP_SCA
model = MobileNet(3, num_classes=10, nettype=NETTYPE, test=True).to(device)

# Display a torchinfo summary for an input of shape (BATCH_SIZE, 3, 32, 32).
summary(model, input_size=(BATCH_SIZE, 3, 32, 32))

In [None]:
# Load the checkpoint into the test-time model.
# NOTE(review): `optimizer` and `scheduler` were created for the earlier
# training model instance, not this one -- confirm load_checkpoint tolerates that.
experiment.load_checkpoint(model, optimizer, scheduler)

In [None]:
# Run the experiment's test routine on the test-time model.
experiment.test(model, device=device)

In [None]:
from doren_bnn.xnorpp_sca import Conv2d_XnorPP_SCA

# One (layer index, sparsity fraction, kernel size) record per Conv2d_XnorPP_SCA
# layer, in model.modules() order. The plotting cell unpacks exactly this
# triple; appending bare floats made that unpacking raise a TypeError.
# NOTE(review): the layer index counts only Conv2d_XnorPP_SCA layers -- confirm
# this is the intended "layer no." for the plot.
sparsity = []
sca_layers = [m for m in model.modules() if isinstance(m, Conv2d_XnorPP_SCA)]
for layer_idx, module in enumerate(sca_layers):
    print(module.in_channels, module.out_channels, module.kernel_size)
    print(module.weight.size())

    # Pure inspection, so no autograd graph is needed.
    with torch.no_grad():
        tanh_weight = torch.tanh(module.weight)
        tanh_weight_sq = tanh_weight.square()
        quant_err = (tanh_weight_sq * (1 - tanh_weight_sq)).sum().item()

        # A weight counts as sparse when tanh(w) rounds to 0.
        is_sparse = torch.round(tanh_weight) == 0
        total_num_sparse = is_sparse.sum().item()

        # Largest number of non-sparse weights in any single row (slice along
        # the first weight dimension). Assumes the weight has at least 2 dims,
        # as a conv weight does. Stays at -1 when there are no rows.
        max_num_nonsparse = -1
        if is_sparse.shape[0] > 0:
            nonsparse_per_row = (~is_sparse).flatten(1).sum(dim=1)
            max_num_nonsparse = int(nonsparse_per_row.max().item())

    sparsity_fraction = total_num_sparse / module.weight.numel()

    # The plot filters on `k == 1` / `k == 3`, so reduce a tuple kernel size
    # such as (3, 3) to its first entry.
    kernel_size = module.kernel_size
    if isinstance(kernel_size, (tuple, list)):
        kernel_size = kernel_size[0]

    print(max_num_nonsparse)
    print(sparsity_fraction, quant_err)
    print("---")

    sparsity.append((layer_idx, sparsity_fraction, kernel_size))

In [None]:
import matplotlib.pyplot as plt

def _points_for_kernel(records, kernel):
    """Split the (id, value, kernel) records matching ``kernel`` into ids and values."""
    layer_ids, values = [], []
    for layer_id, value, k in records:
        if k == kernel:
            layer_ids.append(layer_id)
            values.append(value)
    return layer_ids, values


# Layers with kernel size 1 are drawn in red.
ids_1, vals_1 = _points_for_kernel(sparsity, 1)
plt.scatter(ids_1, vals_1, c="red")

# Layers with kernel size 3 are drawn in blue.
ids_3, vals_3 = _points_for_kernel(sparsity, 3)
plt.scatter(ids_3, vals_3, c="blue")

plt.xlabel("layer no.")
plt.ylabel("sparsity")