# MNIST classifier

+ MNIST
    + 28x28 images
    + 2 sets of labels
        + `targets`: the digit class
        + `color_targets`: the color each MNIST digit is dyed with, a continuous tuning parameter in [0, 1]

        
+ CNN classifier
    + predicts either `targets` or the binarized `color_targets`

In [3]:
import os

import numpy as np
import matplotlib.pyplot as plt

from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision 
import torchvision.transforms as tv_transforms
import torchvision.datasets as tv_datasets
import torchvision.utils as tv_utils

from torch.utils.tensorboard import SummaryWriter

from models import MnistCNN
from datasets import ColorMNIST
from plot_tools import plot_im
from utils import makedirs_exists_ok, seed_rng, set_cuda_visible_devices, load_weights_from_file

In [None]:
# Paths for datasets, checkpoints, figures and TensorBoard logs.
data_root = './data'
model_root = './models/mnist_classifier'
figure_root = './figures/mnist_classifier'
log_root = './logs/mnist_classifier'

# Data pipeline: images are resized to `image_size` by the transforms below.
image_size = 32
batch_size = 64
n_workers = 4  # DataLoader worker processes (single assignment; 4 is the value actually used)

# Reproducibility / hardware.
seed = 9
gpu_id = '0'

# Model and optimization.
# Weights file handed to load_weights_from_file; presumably '' means
# "train from scratch" -- NOTE(review): confirm against that helper.
load_weights = ''
lr = 0.0002
n_epochs = 20
log_interval = 100  # report the training loss every `log_interval` iterations

# Label to predict: 'color' (binarized color_targets) or 'digit'.
target_type = 'color'

In [None]:
# Output directories (created when missing).
makedirs_exists_ok(data_root)
makedirs_exists_ok(model_root)
makedirs_exists_ok(figure_root)
makedirs_exists_ok(log_root)

writer = SummaryWriter(log_root)
writer.flush()

seed_rng(seed)
device = set_cuda_visible_devices(gpu_id)

# Resize, convert to a tensor, then normalize with the usual MNIST
# mean/std (0.1307, 0.3081).
transforms = tv_transforms.Compose([
    tv_transforms.Resize(image_size),
    tv_transforms.ToTensor(),
    tv_transforms.Normalize((0.1307,), (0.3081,))
])

train_loader = torch.utils.data.DataLoader(
    ColorMNIST(
        root=data_root, download=True, train=True, transform=transforms),
    batch_size=batch_size, shuffle=True, num_workers=n_workers)
# Evaluation metrics do not depend on sample order, so iterate the test set
# in a fixed order to keep evaluation deterministic.
test_loader = torch.utils.data.DataLoader(
    ColorMNIST(
        root=data_root, download=True, train=False, transform=transforms),
    batch_size=batch_size, shuffle=False, num_workers=n_workers)


model = MnistCNN(3, 1, 32).to(device)
load_weights_from_file(model, load_weights)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

print(model)

In [None]:

def evaluate(model, test_loader, device, target_type):
    """Compute the average loss and accuracy of ``model`` over ``test_loader``.

    Args:
        model: network mapping an image batch to logits.
        test_loader: DataLoader yielding ``(x, y_digit, y_color)`` batches.
        device: device each batch is moved to before the forward pass.
        target_type: ``'digit'`` scores the multi-class digit logits with
            cross-entropy / argmax; ``'color'`` scores one logit per sample
            against the binarized color label ``y_color < 0.5`` with
            binary cross-entropy / ``logit > 0``.

    Returns:
        ``(mean_loss, accuracy_percent)`` averaged over every sample in
        ``test_loader.dataset``; both are NaN when the dataset is empty.

    Raises:
        ValueError: if ``target_type`` is neither ``'digit'`` nor ``'color'``.
    """
    if target_type == 'digit':
        select_target = lambda y_digit, y_color: y_digit
        # cross_entropy expects the (batch, n_classes) logits unchanged.
        prepare_logits = lambda output, y: output
        criterion = F.cross_entropy
        compute_decision = lambda output: output.argmax(dim=1, keepdim=True)
    elif target_type == 'color':
        select_target = lambda y_digit, y_color: (y_color < 0.5).float()
        # One logit per sample: flatten it to the target's shape for BCE.
        prepare_logits = lambda output, y: output.view_as(y)
        criterion = F.binary_cross_entropy_with_logits
        # sigmoid(z) > 0.5 <=> z > 0. Thresholding the logit directly avoids
        # sigmoid rounding to exactly 0.5 for tiny |z| (and sigmoid(z) > 0,
        # the previous test, is always true).
        compute_decision = lambda output: (output > 0).float()
    else:
        raise ValueError(
            f"unknown target_type {target_type!r}; expected 'digit' or 'color'")

    n_samples = len(test_loader.dataset)
    if n_samples == 0:
        return float('nan'), float('nan')

    was_training = model.training
    model.eval()
    test_loss = 0.0
    correct = 0
    try:
        with torch.no_grad():
            for x, y_digit, y_color in test_loader:
                y = select_target(y_digit, y_color)
                x, y = x.to(device), y.to(device)
                output = model(x)
                test_loss += criterion(
                    prepare_logits(output, y), y, reduction='sum').item()
                pred = compute_decision(output)
                correct += pred.eq(y.view_as(pred)).sum().item()
    finally:
        # Give the caller back the mode it had instead of forcing train().
        model.train(was_training)

    return test_loss / n_samples, 100. * correct / n_samples


In [None]:
# This loop trains only the binarized-color task: the loss is BCE-with-logits
# on one logit per sample (the output is reshaped to the target's shape).
if target_type != 'color':
    raise ValueError(
        f"training loop only supports target_type='color', got {target_type!r}")

criterion = nn.BCEWithLogitsLoss()

for epoch in range(n_epochs):
    for it, (x, _y_digit, c) in enumerate(train_loader):
        x = x.to(device)
        # Binarize the continuous color label the same way `evaluate` does.
        c = (c.to(device) < 0.5).float()

        optimizer.zero_grad()
        output = model(x)
        loss = criterion(output.view_as(c), c)
        loss.backward()
        optimizer.step()

        ##############################################################
        # log
        ##############################################################

        global_step = epoch * len(train_loader) + it

        # Only read the loss back at log time: .item() forces a device sync.
        if it % log_interval == log_interval - 1:
            writer.add_scalar('train/loss', loss.item(), global_step)
            print(f'[{epoch+1}/{n_epochs}]\t'
                  f'[{(it+1)*batch_size}/{len(train_loader.dataset)} ({100.*(it+1)/len(train_loader):.0f}%)]\t'
                  f'loss={loss.item():.4}')

    ##############################################################
    # evaluate
    ##############################################################

    # `correct` is the accuracy as a percentage, not a sample count.
    test_loss, correct = evaluate(model, test_loader, device, target_type)
    writer.add_scalar('test/loss', test_loss, epoch + 1)
    writer.add_scalar('test/accuracy', correct, epoch + 1)
    writer.flush()

    print(f'[{epoch+1}/{n_epochs}]\t'
          f'Average Loss: {test_loss:.4}\t'
          f'Accuracy: {correct:.2f}%')
    # Optional per-epoch checkpoint:
    # torch.save(model.state_dict(), os.path.join(model_root, f'mnist_cnn_{epoch}.pt'))