In [1]:
import math
import os
import random
import time

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import transforms, InterpolationMode

from model import ProtoModel
from sampler import FewShotBatchSampler

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# ---- Episode layout --------------------------------------------------------
n_ways_tr = 60    # "ways" value used for the training sampler and train_on_batch
n_ways_ts = 5     # "ways" value used for the test sampler and test_on_batch
num_support = 1   # support samples per class in an episode
num_query = 1     # query samples per class in an episode

# ---- Schedule --------------------------------------------------------------
num_epochs = 100
eposides = 100    # passed to FewShotBatchSampler; presumably episodes per epoch (spelling kept: other cells use this name)

# ---- Optimisation ----------------------------------------------------------
lr = 1e-3
lr_gamma = 0.5
lr_policy = 'step'
lr_decay_epochs = 20

# ---- Network / device ------------------------------------------------------
num_hiddens = 64
out_channels = 64
# NOTE(review): the original spelled this `(0)`, which is the int 0 and not a
# one-element tuple -- confirm which form ProtoModel expects for `gpu_ids`.
gpu_ids = 0
distance = 'euclidean'

# ---- Logging ---------------------------------------------------------------
print_per_epoch = 1

# ---- Early stopping state (updated in place by the training cell) ----------
best_evaluation_accuracy = 0
best_evaluation_epoch = 0
endure_epochs = 20
save_dir = "../checkpoints/"

In [3]:
def init_seed(seed):
    '''
    Seed every random number generator this notebook may draw from.

    Seeds Python's built-in `random` module, NumPy's legacy global generator,
    and PyTorch (CPU generator plus the current CUDA device).

    This function does NOT change any cuDNN setting. If bit-exact GPU runs
    are required, see the PyTorch reproducibility notes
    (`torch.backends.cudnn.deterministic` / `torch.backends.cudnn.benchmark`).

    Args:
        seed: integer seed shared by all generators.
    '''
    # Python's own generator was previously left unseeded, so any code that
    # used `random` (e.g. for shuffling or sampling) differed between runs.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

init_seed(520)

In [4]:
# Random affine jitter used for training-time augmentation only.
# The shear is written as 0.3 rad and converted to degrees here.
shear_degrees = 0.3 * 180 / math.pi
random_affine = transforms.RandomAffine(degrees=15, shear=shear_degrees, scale=(0.8, 2.0))

# Training pipeline: resize -> (augmentation with probability 0.5) -> tensor -> normalise.
train_transform = transforms.Compose([
    transforms.Resize((28, 28), interpolation=InterpolationMode.BICUBIC),
    transforms.RandomApply([random_affine], p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=0.87, std=0.33),
])

# Evaluation pipeline: same as training but without the random augmentation.
test_transform = transforms.Compose([
    transforms.Resize((28, 28), interpolation=InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=0.87, std=0.33),
])

# `background=True` / `background=False` select the two Omniglot splits; the
# data must already be present under ../data because download is disabled.
train_dataset = datasets.Omniglot(
    root='../data',
    background=True,
    transform=train_transform,
    download=False,
)
test_dataset = datasets.Omniglot(
    root='../data',
    background=False,
    transform=test_transform,
    download=False,
)

In [5]:
# Build the ProtoModel wrapper from the values in the config cell, then
# initialise its weights and move it to the configured device.
# The leading positional `1` is presumably the number of input channels
# (single-channel Omniglot images) -- confirm against ProtoModel's signature.
model = ProtoModel(
    1, out_channels, num_hiddens, lr, gpu_ids, distance, is_train=True,
    lr_policy=lr_policy, gamma=lr_gamma,
    # The config cell defines `lr_decay_epochs`; no `lr_decay_iters` variable
    # exists, so the previous `lr_decay_iters=lr_decay_iters` raised NameError
    # on a fresh kernel. `model.update_lr()` is called once per epoch in the
    # training cell, so the decay interval is expressed in epochs.
    lr_decay_iters=lr_decay_epochs,
)
model.init_net()
model.to_device()

initialize network with normal


In [6]:
def _labels_of(dataset):
    """Return a NumPy array with the label (second item) of every sample in `dataset`.

    NOTE: each sample is fetched through `dataset[idx]`, so the whole dataset
    is read once just to collect its labels.
    """
    return np.array([dataset[idx][1] for idx in range(len(dataset))])


# Every class contributes its support and query samples to an episode.
samples_per_class = num_support + num_query

train_labels = _labels_of(train_dataset)
test_labels = _labels_of(test_dataset)

train_sampler = FewShotBatchSampler(train_labels, n_ways_tr, samples_per_class, eposides)
test_sampler = FewShotBatchSampler(test_labels, n_ways_ts, samples_per_class, eposides)

# The samplers yield whole batches of indices, hence `batch_sampler=`.
train_loader = DataLoader(train_dataset, batch_sampler=train_sampler)
test_loader = DataLoader(test_dataset, batch_sampler=test_sampler)

In [7]:
def split_support_query(X, n_ways, n_support, n_query):
    """Split one episode batch into its support and query samples.

    The batch is assumed to be grouped by class: `n_ways` consecutive groups
    of `n_support + n_query` samples, with each group's support samples first
    (this is the layout the original stride-based indexing relied on --
    confirm against FewShotBatchSampler).

    Args:
        X: tensor whose first dimension has `n_ways * (n_support + n_query)` entries.
        n_ways: number of classes in the episode.
        n_support: support samples per class.
        n_query: query samples per class.

    Returns:
        (Xs, Xq): the support samples and the query samples, each kept in
        their original batch order.

    Raises:
        AssertionError: if the first dimension of `X` does not match the
            episode layout.
    """
    per_class = n_support + n_query
    assert X.size(0) == n_ways * per_class, (
        f"expected {n_ways * per_class} samples, got {X.size(0)}"
    )
    # Position of each sample inside its class group. Uses the function's own
    # arguments (the previous version read the notebook globals instead) and
    # keeps ALL support/query samples rather than only the first of each.
    offset_in_class = torch.arange(X.size(0), device=X.device) % per_class
    Xs = X[offset_in_class < n_support]
    Xq = X[offset_in_class >= n_support]

    return Xs, Xq

In [8]:
def _run_epoch(loader, step_fn, n_ways):
    """Run `step_fn` on every episode from `loader` and average the results.

    Each batch is split with `split_support_query` (using the notebook-level
    `num_support` / `num_query`), and the loss/accuracy returned by `step_fn`
    are weighted by the number of query samples in that batch.

    Returns:
        (mean_loss, mean_accuracy) over all query samples seen.
    """
    loss_sum, acc_sum, n_seen = 0.0, 0.0, 0
    for batch, _ in loader:
        support, query = split_support_query(batch, n_ways, num_support, num_query)
        batch_loss, batch_acc = step_fn(support, query, n_ways)
        n_query_samples = query.size(0)
        n_seen += n_query_samples
        loss_sum += batch_loss * n_query_samples
        acc_sum += batch_acc * n_query_samples
    return loss_sum / n_seen, acc_sum / n_seen


# Per-epoch history, indexed by epoch number.
train_losses, train_accuracies = [], []
test_losses, test_accuracies = [], []

for epoch in range(num_epochs):
    st = time.time()

    # --- training pass, followed by one learning-rate scheduler step --------
    epoch_loss, epoch_acc = _run_epoch(train_loader, model.train_on_batch, n_ways_tr)
    train_losses.append(epoch_loss)
    train_accuracies.append(epoch_acc)
    model.update_lr()

    # --- evaluation pass ----------------------------------------------------
    epoch_loss, epoch_acc = _run_epoch(test_loader, model.test_on_batch, n_ways_ts)
    test_losses.append(epoch_loss)
    test_accuracies.append(epoch_acc)

    # --- progress report ----------------------------------------------------
    is_last_epoch = epoch == num_epochs - 1
    if epoch % print_per_epoch == 0 or is_last_epoch:
        print(f"Epoch {epoch+1}/{num_epochs} {time.time() - st :.2f}sec :", end=' ')
        print(f"Train Loss: {train_losses[epoch] :.4f} Train Accuracy: {train_accuracies[epoch] :.4f}", end="  ")
        print(f"Test Loss: {test_losses[epoch] :.4f} Test Accuracy: {test_accuracies[epoch] :.4f}")

    # --- early stopping -----------------------------------------------------
    # Track the best test accuracy so far; stop (and save the current weights)
    # once more than `endure_epochs` epochs have passed without a new best.
    if test_accuracies[epoch] > best_evaluation_accuracy:
        best_evaluation_accuracy = test_accuracies[epoch]
        best_evaluation_epoch = epoch
    epochs_since_best = epoch - best_evaluation_epoch
    if epochs_since_best > endure_epochs:
        print("Early stop")
        print(f"Best evaluation accuracy: {best_evaluation_accuracy :.4f} Current evaluation accuracy: {test_accuracies[epoch] :.4f}")
        model.save_networks(save_dir=save_dir, epoch=epoch+1)
        break

Epoch 1/100 17.18sec : Train Loss: 6.7643 Train Accuracy: 0.1532  Test Loss: 0.9766 Test Accuracy: 0.6560
Epoch 2/100 17.92sec : Train Loss: 3.4611 Train Accuracy: 0.2047  Test Loss: 0.7770 Test Accuracy: 0.6940
Epoch 3/100 18.17sec : Train Loss: 2.9555 Train Accuracy: 0.2650  Test Loss: 0.6426 Test Accuracy: 0.7720
Epoch 4/100 17.62sec : Train Loss: 2.6585 Train Accuracy: 0.3093  Test Loss: 0.5547 Test Accuracy: 0.7940
Epoch 5/100 17.29sec : Train Loss: 2.4651 Train Accuracy: 0.3547  Test Loss: 0.4086 Test Accuracy: 0.8540
Epoch 6/100 18.02sec : Train Loss: 2.3107 Train Accuracy: 0.3948  Test Loss: 0.3767 Test Accuracy: 0.8600
Epoch 7/100 18.30sec : Train Loss: 2.1985 Train Accuracy: 0.4190  Test Loss: 0.4001 Test Accuracy: 0.8500
Epoch 8/100 18.33sec : Train Loss: 2.0191 Train Accuracy: 0.4513  Test Loss: 0.3154 Test Accuracy: 0.8920
Epoch 9/100 17.89sec : Train Loss: 1.9795 Train Accuracy: 0.4650  Test Loss: 0.3110 Test Accuracy: 0.8920
Epoch 10/100 18.39sec : Train Loss: 1.8925 Tra