In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from time import time

from models.cifar import bit_resnet20_b158
from calculate_statistics import calculate_mean_std

In [None]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# CIFAR-10 training split with only ToTensor applied (no normalization yet);
# it is iterated once to obtain the statistics used by transforms.Normalize.
dataset = datasets.CIFAR10(
    root="data",
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
dataloader = DataLoader(dataset, batch_size=64)

# calculate_mean_std is a project helper; its result is unpacked into the
# (mean, std) pair consumed by the Normalize transforms.
# NOTE(review): presumably per-channel statistics -- confirm against the helper.
mean, std = calculate_mean_std(dataloader, device)

In [None]:
def _tensor_and_normalize():
    """Return fresh ToTensor + Normalize steps that end both pipelines."""
    return [transforms.ToTensor(), transforms.Normalize(mean, std)]


# Training pipeline: random crop (with 4 px padding) and horizontal flip
# for augmentation, followed by tensor conversion and normalization.
train_transform = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    + _tensor_and_normalize()
)

# Evaluation pipeline: no augmentation, only tensor conversion and normalization.
test_transform = transforms.Compose(_tensor_and_normalize())

# CIFAR-10 train / test splits, each wired to its own pipeline.
train_dataset = datasets.CIFAR10(
    root="data", train=True, transform=train_transform, download=True
)
test_dataset = datasets.CIFAR10(
    root="data", train=False, transform=test_transform, download=True
)

In [None]:
# Number of output classes and the epoch horizon used by the LR schedule.
num_classes = 10
max_epoch = 200

# Optimizer / scheduler are described as {"name": <class name>, "args": <kwargs>}
# so the concrete classes can be looked up by name below.
optim_setting = dict(
    name="SGD",
    args=dict(
        lr=0.1,
        momentum=0.9,
        weight_decay=5e-4,
        nesterov=True,
    ),
)
# NOTE(review): T_max is max_epoch (200), while the training cell runs only
# epoch_num = 10 epochs, so the cosine schedule covers a small part of its
# period -- confirm the short run is intentional.
scheduler_setting = dict(
    name="CosineAnnealingLR",
    args=dict(T_max=max_epoch, eta_min=0.0),
)

# Build the network and move it to the selected device.
model = bit_resnet20_b158(num_classes).to(device)

# Resolve the optimizer class from torch.optim by name and instantiate it.
optimizer_cls = getattr(torch.optim, optim_setting["name"])
optimizer = optimizer_cls(model.parameters(), **optim_setting["args"])

# Resolve the scheduler class from torch.optim.lr_scheduler likewise.
scheduler_cls = getattr(torch.optim.lr_scheduler, scheduler_setting["name"])
scheduler = scheduler_cls(optimizer, **scheduler_setting["args"])

In [None]:
# Mini-batch size and number of epochs
batch_size = 64
epoch_num = 10

# Data loader (reshuffled every epoch)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Number of mini-batches per epoch. len(train_loader) includes the final
# partial batch, whereas len(train_dataset) / batch_size gives a fractional
# value that does not match the number of losses actually summed.
n_iter = len(train_loader)

# Loss function
criterion = nn.CrossEntropyLoss().to(device)

# Switch the network to training mode
model.train()

start = time()
for epoch in range(1, epoch_num+1):
    sum_loss = 0.0  # running sum of per-batch losses for this epoch
    count = 0       # running number of correct predictions for this epoch

    for image, label in train_loader:

        image = image.to(device)
        label = label.to(device)

        y = model(image)

        loss = criterion(y, label)

        # Clear gradients from the previous step, back-propagate, update weights.
        # The optimizer was built from model.parameters(), so this clears the
        # same gradients as model.zero_grad().
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        sum_loss += loss.item()

        # Predicted class = index of the largest logit along dim 1.
        pred = torch.argmax(y, dim=1)
        count += torch.sum(pred == label)

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()

    print("epoch: {}, mean loss: {}, mean accuracy: {}, elapsed_time :{}".format(epoch,
                                                                                 sum_loss / n_iter,
                                                                                 count.item() / len(train_dataset),
                                                                                 time() - start))

In [None]:
# Prepare the data loader (fixed order; no shuffling needed for evaluation)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=100, shuffle=False)

# Switch the network to evaluation mode
model.eval()

# Run the evaluation
count = 0  # running number of correct predictions
with torch.no_grad():  # gradients are not needed for inference
    for image, label in test_loader:

        image = image.to(device)
        label = label.to(device)

        y = model(image)

        # Predicted class = index of the largest logit along dim 1.
        pred = torch.argmax(y, dim=1)
        count += torch.sum(pred == label)

# Divide by the actual test-set size rather than a hard-coded 10000 so the
# accuracy stays correct if the dataset is swapped or subsampled.
print("test accuracy: {}".format(count.item() / len(test_dataset)))