# Classification on CIFAR

In [None]:
import os
import time
import json
import copy
from pathlib import Path
import datetime
import csv

import numpy as np
import matplotlib.pyplot as plt

import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torchvision

import ops.trains as trains
import ops.tests as tests
import ops.datasets as datasets
import ops.schedulers as schedulers
import models

In [None]:
# Path of the JSON experiment configuration. Swap the active line to change
# the dataset preset.
# config_path = "configs/cifar10_general.json"
config_path = "configs/cifar100_general.json"

# JSON text is UTF-8, so decode it explicitly instead of depending on the
# platform's locale encoding (which is not UTF-8 everywhere).
with open(config_path, encoding="utf-8") as f:
    args = json.load(f)

# Echo the parsed configuration once the file handle has been released.
print("args: \n", args)

In [None]:
# Work on private copies of the config sections so that nothing done here can
# leak back into the loaded `args`.
dataset_args = copy.deepcopy(args.get("dataset"))
train_args = copy.deepcopy(args.get("train"))
val_args = copy.deepcopy(args.get("val"))

# Build the raw train/test datasets from the "dataset" section of the config.
dataset_train, dataset_test = datasets.get_dataset(**dataset_args, download=True)
dataset_name = dataset_args["name"]
num_classes = len(dataset_train.classes)

# From here on `dataset_train` / `dataset_test` name the DataLoaders that wrap
# the datasets; only the training loader shuffles.
train_loader_kwargs = {
    "shuffle": True,
    "num_workers": train_args.get("num_workers", 4),
    "batch_size": train_args.get("batch_size", 128),
}
test_loader_kwargs = {
    "num_workers": val_args.get("num_workers", 4),
    "batch_size": val_args.get("batch_size", 128),
}
dataset_train = DataLoader(dataset_train, **train_loader_kwargs)
dataset_test = DataLoader(dataset_test, **test_loader_kwargs)

# Report the sample counts of the underlying datasets, not the batch counts.
n_train = len(dataset_train.dataset)
n_test = len(dataset_test.dataset)
print("Train: %s, Test: %s, Classes: %s" % (n_train, n_test, num_classes))

## Model

In [None]:
# `name` is the architecture identifier handed to `models.get_model`.
# Exactly one assignment below should be active; the commented lines are the
# other identifiers that can be swapped in.

# VGG
# name = "vgg_dnn_19"
# name = "vgg_mcdo_19"
# name = "vgg_dnn_smoothing_19"
# name = "vgg_mcdo_smoothing_19"

# ResNet
# name = "resnet_dnn_18"
# name = "resnet_mcdo_18"
# name = "resnet_dnn_smoothing_18"
name = "resnet_mcdo_smoothing_18"

# name = "resnet_dnn_50"
# name = "resnet_mcdo_50"
# name = "resnet_dnn_smoothing_50"
# name = "resnet_mcdo_smoothing_50"

# Preact ResNet
# name = "preresnet_dnn_18"
# name = "preresnet_mcdo_18"
# name = "preresnet_dnn_smoothing_18"
# name = "preresnet_mcdo_smoothing_18"

# name = "preresnet_dnn_50"
# name = "preresnet_mcdo_50"
# name = "preresnet_dnn_smoothing_50"
# name = "preresnet_mcdo_smoothing_50"

# ResNeXt
# name = "resnext_dnn_50"
# name = "resnext_mcdo_50"
# name = "resnext_dnn_smoothing_50"
# name = "resnext_mcdo_smoothing_50"

# WideResNet
# name = "wideresnet_dnn_50"
# name = "wideresnet_mcdo_50"
# name = "wideresnet_dnn_smoothing_50"
# name = "wideresnet_mcdo_smoothing_50"


# Keyword arguments forwarded to the model factory.
# NOTE(review): the meaning of `tiny` and `temp` is defined by
# `models.get_model`, which is not visible here - confirm before changing.
model_kwargs = {
    "num_classes": num_classes,
    "tiny": True,
    "temp": 5e1,
}
model = models.get_model(name, **model_kwargs)

# To start from previously saved weights instead, define `saved_time` and
# enable the line below.
# models.load(model, dataset_name, saved_time)

## Train

Define a TensorBoard writer:

In [None]:
# Timestamp identifying this run; later cells reuse `current_time` when saving
# the model and the leaderboard CSV.
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

# One log directory per (dataset, model, run).
log_dir = os.path.join("runs", dataset_name, model.name, current_time)
writer = SummaryWriter(log_dir)

# Keep the exact configuration and the model's repr next to the event files.
# The files are opened inside `log_dir`, so this relies on the directory
# existing once the writer has been constructed.
config_dump_path = "%s/config.json" % log_dir
model_dump_path = "%s/model.log" % log_dir

with open(config_dump_path, "w") as config_file:
    json.dump(args, config_file)

with open(model_dump_path, "w") as model_file:
    model_file.write(repr(model))

print("Create TensorBoard log dir: ", log_dir)

Train the model:

In [None]:
# Fresh copies of the config sections: the `pop` calls below must not modify
# the loaded `args`.
train_args = copy.deepcopy(args.get("train"))
epochs = train_args.pop("epochs")
warmup_epochs = train_args.get("warmup_epochs", 0)

val_args = copy.deepcopy(args.get("val"))
n_ff = val_args.pop("n_ff", 1)

gpu = torch.cuda.is_available()

optim_args = copy.deepcopy(args.get("optim"))
optimizer, train_scheduler = trains.get_optimizer(model, **optim_args)

# The warmup scheduler is sized in iterations: batches per epoch times the
# number of warmup epochs.
warmup_steps = len(dataset_train) * warmup_epochs
warmup_scheduler = schedulers.WarmupScheduler(optimizer, warmup_steps)

if gpu:
    model = model.cuda()
else:
    model = model.cpu()

# --- Warmup phase: same training routine, driven by the warmup scheduler. ---
warmup_start = time.time()
for epoch in range(warmup_epochs):
    train_metrics = list(
        trains.train_epoch(optimizer, model, dataset_train, warmup_scheduler, gpu=gpu)
    )
if warmup_epochs > 0:
    print("The model is warmed up: %.2f sec" % (time.time() - warmup_start))

# --- Main phase. ---
# Both periods are 1, i.e. metrics are logged and the model is evaluated after
# every epoch.
log_period = 1
test_period = 1

for epoch in range(epochs):
    epoch_start = time.time()
    train_metrics = list(trains.train_epoch(optimizer, model, dataset_train, gpu=gpu))
    train_scheduler.step()
    epoch_secs = time.time() - epoch_start

    if (epoch + 1) % log_period == 0:
        trains.add_train_metrics(writer, train_metrics, epoch)
        # The learning rate is read after the scheduler step, from the first
        # parameter group.
        current_lr = optimizer.param_groups[0]["lr"]
        print(
            "(%.2f sec/epoch) Epoch: %d, Loss: %.4f, lr: %.3e"
            % (epoch_secs, epoch, train_metrics[0], current_lr)
        )

    if (epoch + 1) % test_period == 0:
        # The last value returned by the test routine is the calibration
        # diagram; everything before it is logged as test metrics.
        # NOTE(review): `num_classes` is passed positionally here, while the
        # test cells further down call `tests.test` without it - confirm
        # against the signature in ops/tests.py.
        *test_metrics, cal_diag = tests.test(
            model, n_ff, dataset_test, num_classes, verbose=False, gpu=gpu
        )
        trains.add_test_metrics(writer, test_metrics, epoch)

        cal_diag = torchvision.utils.make_grid(cal_diag)
        writer.add_image("test/calibration diagrams", cal_diag, global_step=epoch)

# Optional per-parameter histograms (disabled):
#         for name, param in model.named_parameters():
#             name = name.split(".")
#             writer.add_histogram("%s/%s" % (name[0], ".".join(name[1:])), param, global_step=epoch)


Save the model:

In [None]:
# Hand the trained model to `models.save` together with the dataset name, the
# run timestamp (`current_time`, set in the TensorBoard cell) and the optimizer.
# NOTE(review): where and in which format the checkpoint is written is decided
# by `models.save`, which is not visible here.
models.save(model, dataset_name, current_time, optimizer=optimizer)

## Test

In [None]:
gpu = torch.cuda.is_available()
if gpu:
    model = model.cuda()
else:
    model = model.cpu()

# Evaluate with the second argument of `tests.test` fixed to 1 and verbose
# output on. The last returned value (calibration diagram) is not used here.
*metrics, cal_diag = tests.test(model, 1, dataset_test, verbose=True, gpu=gpu)
metrics_list = [[1] + metrics]

# Results go to leaderboard/logs/<dataset>/<model>/<dataset>_<model>_<time>.csv
leaderboard_path = os.path.join("leaderboard", "logs", dataset_name, model.name)
os.makedirs(leaderboard_path, exist_ok=True)

# Despite its name, `metrics_dir` is the path of the CSV file itself.
# NOTE(review): the sweep cell below writes to this same path; whether that
# overwrites these rows depends on `tests.save_metrics` - confirm.
csv_name = "%s_%s_%s.csv" % (dataset_name, model.name, current_time)
metrics_dir = os.path.join(leaderboard_path, csv_name)
tests.save_metrics(metrics_dir, metrics_list)

In [None]:
gpu = torch.cuda.is_available()
if gpu:
    model = model.cuda()
else:
    model = model.cpu()

# Repeat the evaluation for a range of `n_ff` values and collect one row per
# setting: [n_ff, metric_1, metric_2, ...]. The calibration diagram returned
# last is not used here.
metrics_list = []
for n_ff in (1, 2, 3, 4, 5, 10, 20, 50):
    # No newline: the prefix is left on the current output line.
    print("N: %s, " % n_ff, end="")
    *metrics, cal_diag = tests.test(model, n_ff, dataset_test, verbose=False, gpu=gpu)
    metrics_list.append([n_ff] + metrics)

# Results go to leaderboard/logs/<dataset>/<model>/<dataset>_<model>_<time>.csv
leaderboard_path = os.path.join("leaderboard", "logs", dataset_name, model.name)
os.makedirs(leaderboard_path, exist_ok=True)

# Despite its name, `metrics_dir` is the path of the CSV file itself.
# NOTE(review): this is the same path the single-pass test cell writes to;
# whether those rows are overwritten depends on `tests.save_metrics` - confirm.
csv_name = "%s_%s_%s.csv" % (dataset_name, model.name, current_time)
metrics_dir = os.path.join(leaderboard_path, csv_name)
tests.save_metrics(metrics_dir, metrics_list)