In [1]:
import json

class TrainLogging:

    def __init__(self):
        self.log = []

    def stack(self, **kwargs):
        self.log.append(kwargs)

    def save(self, path: str):
        with open(path, "w") as f:
            json.dump(self.log, f, indent=4)

In [2]:
from collections import defaultdict
import torch
import torch.nn as nn
import torch.optim as optim
from typing import List, Tuple, TypeVar

def process(trainloader, testloader, model, epochs: int, lr: float, lr_scheduling=None, log_savepath=None):

    log_dict = defaultdict(list)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    if lr_scheduling is not None:
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_scheduling)

    def train(trainloader) -> Tuple[float, float]:
        sum_loss, sum_correct, sum_dataN = 0.0, 0, 0
        for (inputs, labels) in trainloader:
            optimizer.zero_grad()
            outputs, _ = model(inputs)
            loss = criterion(outputs, labels)
            sum_loss += loss.item()
            _, predicted = outputs.max(1)
            sum_dataN += labels.size(0)
            sum_correct += (predicted == labels).sum().item()
            loss.backward()
            optimizer.step()
        train_loss = sum_loss*trainloader.batch_size/len(trainloader.dataset)
        train_acc = float(sum_correct/sum_dataN)
        return train_loss, train_acc

    def test(testloader) -> Tuple[float, float]:
        sum_loss, sum_correct, sum_dataN = 0.0, 0, 0
        for (inputs, labels) in testloader:
            outputs, _ = model(inputs)
            loss = criterion(outputs, labels)
            sum_loss += loss.item()
            _, predicted = outputs.max(1)
            sum_dataN += labels.size(0)
            sum_correct += (predicted == labels).sum().item()
        test_loss = sum_loss*testloader.batch_size/len(testloader.dataset)
        test_acc = float(sum_correct/sum_dataN)
        return test_loss, test_acc

    print("\n{0:<13}{1:<13}{2:<13}{3:<13}{4:<13}{5:<6}".format("epoch","train/loss","train/acc","test/loss","test/acc","lr"))
    logging = TrainLogging()
    for epoch in range(1, epochs + 1):
        train_loss, train_acc = train(trainloader)
        test_loss, test_acc = test(testloader)
        lr = optimizer.param_groups[-1]["lr"]
        print("{0:<13}{1:<13.5f}{2:<13.5f}{3:<13.5f}{4:<13.5f}{5:<6.6f}".format(epoch, train_loss, train_acc, test_loss, test_acc, lr))
        logging.stack(epoch=epoch, train_loss=train_loss, train_acc=train_acc, test_loss=test_loss, test_acc=test_acc, lr=lr)
        if lr_scheduling is not None: scheduler.step()
    if log_savepath is not None:
        logging.save(log_savepath)

    return model

In [3]:
# net のstub
import torch.nn as nn
import torch.nn.functional as F

class LeNet(nn.Module):
    def __init__(self, out=3):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, 1, padding=1) # (1) 32*32*3 -> 32*32*16
        self.conv2 = nn.Conv2d(16, 32, 3, 1, padding=1) # (3) 16*16*16 -> 16*16*32
        self.conv3 = nn.Conv2d(32, 64, 3, 1, padding=1) # (5) 8*8*32 -> 8*8*64
        self.fc1 = nn.Linear(4*4*64, 500)
        self.dropout1 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(500, out)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2) # (2) 32*32*16 -> 16*16*16
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2) # (4) 16*16*32 -> 8*8*32
        x = F.relu(self.conv3(x))
        x = F.max_pool2d(x, 2, 2) # (6) 8*8*64 -> 4*4*64
        x = x.view(-1, 4*4*64)
        x = F.relu(self.fc1(x))
        feature = x
        x = self.dropout1(x)
        x = self.fc2(x)
        return x, feature

In [4]:
# trainloader, testloader を生成するstub
import torchvision
import torchvision.transforms as transforms

def update_labels(train, test):
    updated = [[], []]
    mapping_dict = defaultdict(lambda: -1)
    for i, t in enumerate([train, test]):
        t = sorted(t, key=lambda x:x[1])
        new_label = 0
        for data in t:
            if mapping_dict[data[1]] == -1:
                mapping_dict[data[1]] = new_label
                new_label += 1
            updated[i].append((data[0], mapping_dict[data[1]]))
    train, test = updated
    return train, test

def dataset_stub():
    path = "../../../prototype/proposal/data/"
    classes = [1,2,8]
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    train = torchvision.datasets.CIFAR10(root=path, train=True, download=True, transform=transform)
    test = torchvision.datasets.CIFAR10(root=path, train=False, download=True, transform=transform)
    train = [d for d in train if d[1] in classes]
    test = [d for d in test if d[1] in classes]
    train, test = update_labels(train, test)
    return train, test

def loader_stub():
    batch_size = 128
    train, test = dataset_stub()
    trainloader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=2)
    testloader = torch.utils.data.DataLoader(test, batch_size=batch_size, shuffle=False, num_workers=2)
    return trainloader, testloader

In [5]:
# lr_scheduling のstub
def lr_v1(epoch):
    if epoch < 10:
        return 1
    elif epoch < 20:
        return 0.1**1
    elif epoch < 30:
        return 0.1**2
    else:
        return 0.1**3

In [6]:
model = LeNet(3)
trainloader, testloader = loader_stub()
model = process(trainloader, testloader, model, epochs=15, lr=0.001, lr_scheduling=lr_v1, log_savepath="./assets/log.json")

Files already downloaded and verified
Files already downloaded and verified

epoch        train/loss   train/acc    test/loss    test/acc     lr    
1            1.10482      0.34820      1.12136      0.35033      0.001000
2            1.09926      0.38293      1.11336      0.41000      0.001000
3            1.09053      0.44947      1.10288      0.46267      0.001000
4            1.07346      0.49547      1.07907      0.50533      0.001000
5            1.04301      0.52713      1.03819      0.54467      0.001000
6            0.99340      0.55247      0.97853      0.57000      0.001000
7            0.93422      0.58020      0.92469      0.59300      0.001000
8            0.88373      0.60800      0.87588      0.61900      0.001000
9            0.84009      0.63067      0.83266      0.65200      0.001000
10           0.79425      0.66227      0.78007      0.67233      0.001000
11           0.76533      0.67673      0.77773      0.67900      0.000100
12           0.75810      0.68240    