In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, random_split
from torch import optim
from torchvision.datasets import SVHN, MNIST
from tqdm import tqdm

from src.data_utils import (
    TransformDataset,
    SrcDataset,
    DstDataset,
    get_train_weak_transform,
    get_train_strong_transform,
    get_test_transform,
)
from src.models import get_model
from src.reproducibility import set_seed
from src.loss import consistency_loss

In [2]:
# 入力画像サイズ
INPUT_SIZE = 32

# learning rate
LR = 1e-3

# エポック数
EPOCHS = 10

In [3]:
def get_source_dl(batch_size=32):
    weak_trans = get_train_weak_transform()
    test_trans = get_test_transform()
    train_mnist = MNIST(
        root='./mnist/',
        train=True,
        download=True,
    )
    test_mnist = MNIST(
        root='./mnist/',
        train=False,
        download=True,
        transform=test_trans
    )
    split_generator = torch.Generator().manual_seed(2024)
    train_mnist, valid_mnist = random_split(train_mnist, [0.8, 0.2], split_generator)
    train_mnist = SrcDataset(
        train_mnist,
        weak_trans,
    )
    valid_mnist = TransformDataset(
        valid_mnist,
        test_trans
    )
    test_mnist = TransformDataset(
        test_mnist,
        test_trans
    )
    dl_generator = torch.Generator().manual_seed(2024)
    src_train_dl = DataLoader(
        train_mnist,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        generator=dl_generator,
    )
    src_valid_dl = DataLoader(
        valid_mnist,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
    )
    src_test_dl = DataLoader(
        test_mnist,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
    )
    return src_train_dl, src_valid_dl, src_test_dl

In [4]:
def train(
    G,
    F,
    optimizer_G,
    optimizer_F,
    loss_fn,
    epochs,
    train_dl,
    valid_dl,
    device,
):
    logs = {
        'train_accuracy': [],
        'train_loss': [],
        'valid_accuracy': [],
        'valid_loss': [],
    }
    for epoch in range(epochs):
        train_accuracy, train_avg_loss = train_one_epoch(
            G,
            F,
            optimizer_G,
            optimizer_F,
            loss_fn,
            train_dl,
            device,
        )
        valid_accuracy, valid_avg_loss = evaluate(
            G,
            F,
            loss_fn,
            valid_dl,
            device
        )
        logs['train_accuracy'].append(train_accuracy)
        logs['train_loss'].append(train_avg_loss)
        logs['valid_accuracy'].append(valid_accuracy)
        logs['valid_loss'].append(valid_avg_loss)

        print(f'epoch: {epoch + 1} train loss={train_avg_loss:.3f}, train acc={train_accuracy:.3f}, valid loss={valid_avg_loss:.3f}, valid acc={valid_accuracy:.3f}')
    return logs


def train_one_epoch(
    G,
    F,
    optimizer_G,
    optimizer_F,
    loss_fn,
    train_dl,
    device,
):
    G.train()
    F.train()
    total_correct = 0
    total_data = 0
    total_loss = 0
    def zero_grad():
        optimizer_G.zero_grad()
        optimizer_F.zero_grad()

    for X, y in tqdm(train_dl, total=len(train_dl)):
        zero_grad()
        X = X.to(device)
        y = y.to(device)
        pred = F(G(X))
        loss = loss_fn(pred, y)
        loss.backward()
        optimizer_G.step()
        optimizer_F.step()

        num_data = len(X)
        total_data += num_data
        total_correct += (torch.argmax(pred, dim=1) == y).sum().item()
        total_loss += loss.item() * total_data
    accuracy = total_correct / total_data
    avg_loss = total_loss / total_data
    return accuracy, avg_loss


@torch.no_grad()
def evaluate(G, F, loss_fn, valid_dl, device):
    G.eval()
    F.eval()
    total_correct = 0
    total_data = 0
    total_loss = 0
    for X, y in tqdm(valid_dl, total=len(valid_dl)):
        X = X.to(device)
        y = y.to(device)
        pred = F(G(X))
        loss = loss_fn(pred, y)

        num_data = len(X)
        total_data += num_data
        total_correct += (torch.argmax(pred, dim=1) == y).sum().item()
        total_loss += loss.item() * total_data
    accuracy = total_correct / total_data
    avg_loss = total_loss / total_data
    return accuracy, avg_loss

In [None]:
set_seed(0)
G, F = get_model(model_type='cnn', num_feature_extractor_out=512)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
G.to(device)
F.to(device)
mnist_train_dl, mnist_valid_dl, mnist_test_dl = get_source_dl()
loss_fn = nn.CrossEntropyLoss()
optimizer_G = optim.AdamW(G.parameters(), lr=LR, weight_decay=0.0005)
optimizer_F = optim.AdamW(F.parameters(), lr=LR, weight_decay=0.0005)
logs = train(
    G,
    F,
    optimizer_G,
    optimizer_F,
    loss_fn,
    epochs=EPOCHS,
    train_dl=mnist_train_dl,
    valid_dl=mnist_valid_dl,
    device=device,
)


100%|██████████| 1500/1500 [00:07<00:00, 188.15it/s]
100%|██████████| 375/375 [00:00<00:00, 452.95it/s]


epoch: 1 train loss=108.199, train acc=0.938, valid loss=20.061, valid acc=0.968


100%|██████████| 1500/1500 [00:07<00:00, 198.19it/s]
100%|██████████| 375/375 [00:00<00:00, 439.85it/s]


epoch: 2 train loss=67.753, train acc=0.972, valid loss=11.636, valid acc=0.982


 79%|███████▉  | 1191/1500 [00:06<00:01, 199.29it/s]