In [None]:
# Automatically reload edited modules (e.g. doren_bnn) before each cell runs,
# so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# --- Training settings (edit here to change the run) -------------------------
# Number of training epochs; also used as the cosine LR schedule length below.
NUM_EPOCHS: int = 150
# Mini-batch size; passed to Experiment and used for the torchinfo summary input.
BATCH_SIZE: int = 128
# Passed to Experiment and embedded in EXPERIMENT_ID (e.g. "...-1.0x").
# NOTE(review): not passed to MobileNet in this notebook -- confirm the model
# width actually follows this value.
MULTIPLIER: float = 1.0

In [None]:
import torch

# Index of the GPU to use; ignored when CUDA is unavailable.
CUDA_DEVICE = 0

# torch.cuda.set_device fails on machines without CUDA, so only call it when a
# GPU is present; otherwise fall back to the CPU as the device choice intends.
if torch.cuda.is_available():
    torch.cuda.set_device(CUDA_DEVICE)
    device = torch.device(f"cuda:{CUDA_DEVICE}")
else:
    device = torch.device("cpu")

In [None]:
from doren_bnn.mobilenet import MobileNet, NetType
from torchinfo import summary

# Binarised MobileNet variant under study; keep in sync with the "xnorpp-sttn"
# tag used in EXPERIMENT_ID.
NETTYPE = NetType.XNORPP_STTN

# First positional argument is presumably the number of input channels (RGB)
# -- TODO confirm against MobileNet's signature. 10 classes for CIFAR-10.
model = MobileNet(3, num_classes=10, nettype=NETTYPE)
model = model.to(device)

# Layer-by-layer summary for a batch of 3x32x32 inputs.
summary(model, input_size=(BATCH_SIZE, 3, 32, 32))

In [None]:
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

# Optimisation hyper-parameters.
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 5e-6

criterion = CrossEntropyLoss().to(device)
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# Cosine decay from LEARNING_RATE towards eta_min (default 0) over NUM_EPOCHS
# scheduler steps. This spans exactly one run only if `experiment.train` calls
# scheduler.step() once per epoch -- TODO confirm.
# Alternative schedule that was considered:
#   CosineAnnealingWarmRestarts(optimizer, 25, eta_min=1e-4)
scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

In [None]:
from doren_bnn.utils import Dataset, Experiment

# Run identifier built from architecture, binarisation scheme, dataset and MULTIPLIER.
EXPERIMENT_ID = f"mobilenet-xnorpp-sttn-cifar10-{MULTIPLIER}x"
print(EXPERIMENT_ID)

experiment = Experiment(
    EXPERIMENT_ID,
    Dataset.CIFAR10,
    BATCH_SIZE,
    multiplier=MULTIPLIER,
)

In [None]:
# Resume from a previous checkpoint via `experiment.load_checkpoint`, which is
# handed model, optimizer and scheduler (presumably restoring their state in
# place -- TODO confirm). Set to False to train from scratch.
RESUME_FROM_CHECKPOINT = True

if RESUME_FROM_CHECKPOINT:
    experiment.load_checkpoint(model, optimizer, scheduler)

# Training

In [None]:
# Run the training loop for NUM_EPOCHS using the objects configured above.
experiment.train(model, criterion, optimizer, scheduler, NUM_EPOCHS, device=device)

# Test-time Inference

In [None]:
# Rebuild a fresh network of the same architecture (reusing MobileNet, summary and
# NETTYPE from the setup cells) so the evaluation below presumably reflects only
# the weights restored from the checkpoint, not in-memory training state.
model = MobileNet(3, num_classes=10, nettype=NETTYPE).to(device)

summary(model, input_size=(BATCH_SIZE, 3, 32, 32))

In [None]:
# Restore the checkpoint into the freshly built `model`.
# NOTE(review): `optimizer` and `scheduler` were created for the training-time
# model's parameters, not this new `model`; they are presumably passed only
# because load_checkpoint requires them -- confirm they are not needed here.
experiment.load_checkpoint(model, optimizer, scheduler)

In [None]:
# Evaluate the restored model; keep the result around for later inspection.
test_result = experiment.test(model, device=device)
test_result

In [None]:
from doren_bnn.xnorpp_sttn import Conv2d_XnorPP_STTN


def ternary_sparsity_stats(weight1, weight2):
    """Sparsity statistics for the pair of latent weight tensors of one STTN layer.

    An entry counts as *sparse* when exactly one of ``weight1`` / ``weight2`` is
    strictly positive (XOR of the two ``> 0`` masks) -- presumably the positions
    where the combined ternary weight is zero (TODO confirm against
    ``Conv2d_XnorPP_STTN``). Statistics are taken per slice along dim 0 ("row")
    in one vectorised pass, avoiding a Python loop with a device sync per row.

    Args:
        weight1: tensor with at least one dimension.
        weight2: tensor with the same shape as ``weight1``.

    Returns:
        ``(max_num_nonsparse, sparsity)``: the largest number of non-sparse
        entries in any row (``-1`` if there are no rows), and the fraction of
        all entries that are sparse (``nan`` if the tensors are empty).
    """
    sparse_mask = torch.bitwise_xor(weight1.gt(0), weight2.gt(0))

    # View as (rows, entries_per_row); a 1-D tensor has one entry per row.
    if sparse_mask.dim() > 1:
        rows = sparse_mask.flatten(start_dim=1)
    else:
        rows = sparse_mask.unsqueeze(1)

    num_sparse_per_row = rows.sum(dim=1)
    num_nonsparse_per_row = rows.shape[1] - num_sparse_per_row

    if rows.shape[0] > 0:
        max_num_nonsparse = int(num_nonsparse_per_row.max())
    else:
        max_num_nonsparse = -1

    total_numel = sparse_mask.numel()
    if total_numel > 0:
        sparsity = int(num_sparse_per_row.sum()) / total_numel
    else:
        sparsity = float("nan")
    return max_num_nonsparse, sparsity


# Per-layer report for every STTN convolution in the restored model.
for module in model.modules():
    if not isinstance(module, Conv2d_XnorPP_STTN):
        continue

    print(module.in_channels, module.out_channels, module.kernel_size)
    print(module.weight1.size())

    max_num_nonsparse, sparsity = ternary_sparsity_stats(module.weight1, module.weight2)
    print("max non-sparse entries in a row:", max_num_nonsparse)
    print("sparsity:", sparsity)
    print("---")