In [2]:
!pip install torchvision
!pip install tqdm



In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

from osciquant.regularization import OsciQuantLoss
from osciquant.quantizers import UniformQuantizer
from osciquant.handler import attach_weight_quantizers, toggle_quantization

In [3]:
class TinySkyNet(nn.Module):
    """Small fully connected classifier: 3072 -> 256 -> 128 -> 10.

    The input is flattened per sample, passed through two ReLU-activated
    linear layers and a final linear layer. The output is raw logits
    (no softmax is applied).
    """

    def __init__(self):
        super().__init__()
        self.width = 256  # number of units in the first hidden layer
        self.fc1 = nn.Linear(32 * 32 * 3, self.width)
        self.fc2 = nn.Linear(self.width, 128)
        self.output = nn.Linear(128, 10)

    def forward(self, x):
        """Compute class logits for a batch.

        Args:
            x: Tensor whose first dimension is the batch and whose remaining
                dimensions multiply to 32 * 32 * 3 (the input size of fc1).

        Returns:
            Tensor of shape (batch, 10) holding the unnormalized class scores.
        """
        # flatten() copies when needed, so unlike view() it also accepts
        # non-contiguous inputs (e.g. permuted NHWC data) and empty batches.
        x = torch.flatten(x, start_dim=1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.output(x)
        return x


def train(model, epoch, train_loader, optimizer, criterion, device):
    """Run one training pass over ``train_loader`` and report loss/accuracy.

    Args:
        model: Module to optimize; it is switched to training mode.
        epoch: Epoch index, used only in the progress-bar description.
        train_loader: Iterable yielding ``(inputs, targets)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Callable mapping ``(logits, targets)`` to a scalar loss
            tensor that supports ``backward()``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Tuple ``(mean_batch_loss, accuracy_percent)``. The loss is the mean of
        the per-batch loss values; the accuracy is computed over all samples
        seen. If the loader yields no samples, both values are NaN instead of
        raising ZeroDivisionError.
    """
    model.train()
    train_loader_tqdm = tqdm(train_loader, desc=f"Training Epoch {epoch}", leave=False)

    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0  # counted here so loaders without __len__ also work

    for X, y in train_loader_tqdm:
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()

        output = model(X)
        batch_loss = criterion(output, y)

        batch_loss.backward()
        optimizer.step()

        num_batches += 1
        running_loss += batch_loss.item()
        # Predicted class = index of the largest logit along dim 1.
        _, predicted = output.max(1)
        total += y.size(0)
        correct += predicted.eq(y).sum().item()

        # Guard against zero-sized batches when reporting running accuracy.
        accuracy = 100.0 * correct / total if total else 0.0

        # Update TQDM postfix
        train_loader_tqdm.set_postfix({
            'loss': f'{running_loss / num_batches:.4f}',
            'acc': f'{accuracy:.2f}%'
        })

    if num_batches == 0 or total == 0:
        # Nothing was processed: the metrics are undefined.
        return float('nan'), float('nan')

    final_loss = running_loss / num_batches
    final_accuracy = 100.0 * correct / total

    return final_loss, final_accuracy


def test(model, test_loader, criterion, device, desc="Test"):
    """Evaluate ``model`` on ``test_loader`` without computing gradients.

    Args:
        model: Module to evaluate; it is switched to eval mode and left there.
        test_loader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Callable mapping ``(logits, targets)`` to a scalar loss
            tensor.
        device: Device the batches are moved to before the forward pass.
        desc: Label shown on the progress bar.

    Returns:
        Tuple ``(mean_batch_loss, accuracy_percent)``. If the loader yields no
        samples (for example a validation split of size zero), both values
        are NaN instead of raising ZeroDivisionError.
    """
    model.eval()
    test_loader_tqdm = tqdm(test_loader, desc=desc, leave=True)

    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0  # counted here so loaders without __len__ also work

    with torch.no_grad():
        for X, y in test_loader_tqdm:
            X, y = X.to(device), y.to(device)
            output = model(X)
            loss = criterion(output, y).item()

            num_batches += 1
            running_loss += loss
            # Predicted class = index of the largest logit along dim 1.
            _, predicted = output.max(1)
            total += y.size(0)
            correct += predicted.eq(y).sum().item()

            # Guard against zero-sized batches when reporting running accuracy.
            running_accuracy = 100. * correct / total if total else 0.0

            # Update the TQDM postfix
            test_loader_tqdm.set_postfix({
                'loss': f'{running_loss / num_batches:.4f}',
                'acc': f'{running_accuracy:.2f}%',
            })

    if num_batches == 0 or total == 0:
        # Nothing was evaluated: the metrics are undefined.
        return float('nan'), float('nan')

    final_loss = running_loss / num_batches
    final_accuracy = 100. * correct / total

    return final_loss, final_accuracy


def build_dataset(train_ratio, batch_size=128, root='./data', seed=None):
    """Build CIFAR-10 train/validation/test data loaders.

    The official training set is split randomly into a training part and a
    validation part; the official test set is used unchanged. All images are
    converted to tensors and normalized per channel.

    Args:
        train_ratio: Fraction of the official training set kept for training,
            in the interval (0, 1]. The remainder becomes the validation set;
            with 1.0 the validation loader is empty.
        batch_size: Batch size used by all three loaders.
        root: Directory the dataset is downloaded to / read from.
        seed: Optional integer seed for the train/validation split. ``None``
            keeps the original behavior of using the global RNG state.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``. Only the training
        loader shuffles.

    Raises:
        ValueError: If ``train_ratio`` is outside (0, 1].
    """
    if not 0.0 < train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be in (0, 1], got {train_ratio!r}")

    transform = transforms.Compose([
        transforms.ToTensor(),
        # Per-channel mean / std (presumably CIFAR-10 training-set statistics).
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
    ])

    train_dataset = torchvision.datasets.CIFAR10(root=root, train=True, download=True, transform=transform)
    test_dataset = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=transform)

    # round() instead of int(): a float product that lands just below a whole
    # number must not lose a sample to truncation.
    train_size = round(train_ratio * len(train_dataset))
    val_size = len(train_dataset) - train_size

    if seed is None:
        train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])
    else:
        # A dedicated generator makes the split reproducible without touching
        # the global RNG state.
        split_generator = torch.Generator().manual_seed(seed)
        train_dataset, val_dataset = random_split(
            train_dataset, [train_size, val_size], generator=split_generator
        )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader

In [4]:
# Experiment configuration: everything a reader might tune lives in this cell.
EPOCHS = 20  # number of passes over the training loader
LR = 0.00025  # learning rate passed to Adam
BIT = 2  # Ternary because of the symmetric quantizer
LAMBDA_VAL = 14.0  # passed to OsciQuantLoss as regularization_lambda
EXCLUDE = []  # quantize all layers
TRAIN_SIZE = 1.0  # train_ratio for build_dataset; 1.0 leaves the validation split empty
QUANTIZER = UniformQuantizer(bit_width=BIT)
REGULARIZATION = True  # passed to OsciQuantLoss as regularization=

# Use the best available accelerator instead of hard-coding "mps", so the
# notebook also runs on machines without Apple-silicon support.
_mps_backend = getattr(torch.backends, "mps", None)  # absent on old torch builds
if _mps_backend is not None and _mps_backend.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

In [None]:
# With OsciQuant and QAT
# NOTE(review): weight quantization is switched off for every training pass
# below, so only the OsciQuant loss term acts during training; confirm that
# this matches the "QAT" label above.
train_loader, val_loader, test_loader = build_dataset(TRAIN_SIZE)  # CIFAR-10
device = DEVICE

model = TinySkyNet()
model.to(device)

criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=LR)

# Osciquant
attach_weight_quantizers(model=model, exclude_layers=EXCLUDE, quantizer=QUANTIZER, enabled=False)
criterion_reg = OsciQuantLoss(base_loss=criterion, model=model, regularization_lambda=LAMBDA_VAL, regularization=REGULARIZATION)

# One record per epoch, so the FP32 and quantized results can be compared
# after the run instead of only being visible in the progress bars.
history = []

for epoch in range(EPOCHS):
    # train (regularized loss, quantizers disabled)
    toggle_quantization(model, enabled=False)
    train_loss, train_acc = train(model, epoch, train_loader, optimizer, criterion_reg, device)

    # test
    # Todo: split train into val and select best model for test
    # Evaluation uses the plain criterion, without the regularization term.
    toggle_quantization(model, enabled=False)
    fp32_loss, fp32_acc = test(model, test_loader, criterion, device, desc="Test FP32")
    toggle_quantization(model, enabled=True)
    # test_loss / test_acc keep holding the quantized metrics, as before.
    test_loss, test_acc = test(model, test_loader, criterion, device, desc=f"Test {BIT}-bit")

    history.append({
        "epoch": epoch,
        "train_loss": train_loss,
        "train_acc": train_acc,
        "fp32_loss": fp32_loss,
        "fp32_acc": fp32_acc,
        "quant_loss": test_loss,
        "quant_acc": test_acc,
    })

    # cross bit test
    # todo: reset to BIT after cross bit test. Make a function for this in util
    # for bit_width in [4,3,2]:
    #     for name, submodule in model.named_modules():
    #         if hasattr(submodule, 'parametrizations'):
    #             # submodule.parametrizations is a dictionary like {"weight": [param_module, ...]}
    #             for param_name, param_list in submodule.parametrizations.items():
    #                 for p in param_list:
    #                     if isinstance(p, FakeQuantParametrization):
    #                         p.quantizer.set_bits(bit_width)
    #     toggle_quantization(model, enabled=True)
    #     test_loss, test_acc = test(model, test_loader, criterion, device, desc=f"Test {bit_width}")

Attached weight quantizer to layer: fc1
Attached weight quantizer to layer: fc2
Attached weight quantizer to layer: output
Osciquant


Test 4: 100%|██████████| 79/79 [00:00<00:00, 81.28it/s, loss=1.6561, acc=41.05%]            
Test 3: 100%|██████████| 79/79 [00:00<00:00, 84.44it/s, loss=1.6662, acc=40.95%]
Test 2: 100%|██████████| 79/79 [00:00<00:00, 84.88it/s, loss=1.5577, acc=44.83%]
Test 4: 100%|██████████| 79/79 [00:00<00:00, 83.82it/s, loss=1.5460, acc=45.60%]            
Test 3: 100%|██████████| 79/79 [00:00<00:00, 84.90it/s, loss=1.5378, acc=45.75%]
Test 2: 100%|██████████| 79/79 [00:00<00:00, 84.79it/s, loss=1.4537, acc=48.75%]
Training Epoch 2:  86%|████████▌ | 335/391 [00:07<00:01, 45.41it/s, loss=1.3921, acc=51.05%]

In [None]:
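# NOTE(review): disabled draft of a cross-bit-width evaluation. Before re-enabling:
#   - FakeQuantParametrization is not imported anywhere in this notebook.
#   - The setter is called on the class (FakeQuantParametrization.quantizer)
#     rather than on the matched instance `p`.
#   - This draft calls set_bit_width() while the draft in the training cell
#     calls set_bits(); confirm which one the quantizer API provides.
# for bit_width in [2,3,4]: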
#     for name, submodule in model.named_modules():
#         if hasattr(submodule, 'parametrizations'):
#             # submodule.parametrizations is a dictionary like {"weight": [param_module, ...]}
#             for param_name, param_list in submodule.parametrizations.items():
#                 for p in param_list:
#                     if isinstance(p, FakeQuantParametrization):
#                         FakeQuantParametrization.quantizer.set_bit_width(bit_width)
#     toggle_quantization(model, enabled=True)
#     test_loss, test_acc = test(model, test_loader, criterion, device, desc=f"Test {bit_width}")
# 
# FakeQuantParametrization.quantizer.set_bit_width(BIT)