In [1]:
import torch
from torch import nn, optim

from data import DatasetGZSL
from models import CSVaDE

from train import *
from utils import *

In [2]:
# Build the dataset and the pretrained model on one shared device.
DEVICE = 'cpu'

dataset = DatasetGZSL('AWA2', True, DEVICE)

# Feature/attribute sizes and the class count are read off the dataset,
# so the model is always sized to match the data it will see.
model_kwargs = dict(
    cnn_dim=dataset.cnn_dim,
    att_dim=dataset.att_dim,
    num_classes=len(dataset.classes),
    device=DEVICE,
    load_pretrained=True,
    reset_classifier=True,
)
model = CSVaDE('cadavae', **model_kwargs)

Loading pretrained model from: saved/cadavae.pt



In [3]:
# Freeze every weight first, then switch gradients back on for the
# classifier only, so that it is the sole part of the model being trained.
for frozen in model.parameters():
    frozen.requires_grad_(False)
for trainable in model.classifier.parameters():
    trainable.requires_grad_(True)

# Optimisation setup
optimizer = optim.Adam(model.parameters(), lr=0.0001, betas=[0.5, 0.999])
loss_function = nn.NLLLoss()

# Training configuration
num_epochs = 100
batch_size = 100
num_seen = 200      # embeddings sampled per seen class
num_unseen = 400    # embeddings sampled per unseen class
early_stop = 4      # number of previous epochs the H-accuracy is compared against
top_k_acc = 1       # k used for the top-k accuracy
verbose = True
tensorboard_dir = 'tensorboards/models/'

In [4]:
import torch
import numpy as np

from torch.utils.data import Subset, Dataset

class EmbeddingsDataset(Dataset):
    """Fixed set of per-class embeddings produced by a model's encoders.

    For every class in ``dataset.seen_classes``, ``num_seen`` rows of the
    trainval CNN features are drawn (with replacement) and passed through
    ``model.cnn_encoder``. For every class in ``dataset.unseen_classes``,
    ``num_unseen`` rows of the unseen-test attributes are drawn (with
    replacement) and passed through ``model.att_encoder``.

    The resulting embeddings and their labels are stored on ``model.device``
    as ``self.embeddings`` and ``self.labels``.

    Sampling uses the global ``np.random`` state; seed it beforehand for
    reproducible datasets.

    Args:
        dataset: object indexable by an index array, returning a
            ``(features, attributes, labels)`` triple, and exposing
            ``trainval_idx``, ``test_unseen_idx``, ``seen_classes`` and
            ``unseen_classes``.
        model: object exposing ``device``, ``cnn_encoder`` and ``att_encoder``.
            Each encoder must return a 3-tuple; the second element is used
            as the embedding (presumably the latent mean -- TODO confirm).
        num_seen: number of samples drawn per seen class.
        num_unseen: number of samples drawn per unseen class.

    Raises:
        ValueError: if a requested class has no samples to draw from.
    """

    def __init__(self, dataset, model, num_seen, num_unseen):
        # Seen classes: embed CNN features from the trainval split.
        features, _, labels = dataset[dataset.trainval_idx]
        z_cnns, l_cnns = self._sample_embeddings(
            model.cnn_encoder, features, labels, dataset.seen_classes, num_seen)

        # Unseen classes: embed attribute vectors from the unseen test split.
        _, attributes, labels = dataset[dataset.test_unseen_idx]
        z_atts, l_atts = self._sample_embeddings(
            model.att_encoder, attributes, labels, dataset.unseen_classes, num_unseen)

        self.embeddings = torch.cat([z_cnns, z_atts]).to(model.device)
        self.labels     = torch.cat([l_cnns, l_atts]).to(model.device)

    @staticmethod
    def _sample_embeddings(encoder, inputs, labels, classes, num_samples):
        """Draw ``num_samples`` rows per class and return (embeddings, labels)."""
        # The class mask is evaluated through numpy, which needs host memory.
        labels = labels.to('cpu')

        embeddings = []
        targets = []
        for c in classes:
            idx = np.where(labels == c)[0]
            if idx.size == 0 and num_samples > 0:
                raise ValueError('No samples available for class {}'.format(c))
            idx = np.random.choice(idx, num_samples, replace=True)

            # The embeddings are stored as plain data: encode without building
            # an autograd graph, whether or not the encoder weights are frozen.
            with torch.no_grad():
                _, z, _ = encoder(inputs[idx])

            embeddings.append(z)
            targets.append(labels[idx])

        return torch.cat(embeddings), torch.cat(targets)

    def __len__(self):
        """Total number of stored embeddings."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``[embedding(s), label(s)]`` for an int, list or index array."""
        if isinstance(idx, np.ndarray):
            idx = idx.tolist()

        return [self.embeddings[idx, :], self.labels[idx]]

# Sample the embeddings once, then serve them in shuffled mini-batches.
embeddingset = EmbeddingsDataset(
    dataset=dataset,
    model=model,
    num_seen=num_seen,
    num_unseen=num_unseen,
)
trainloader = DataLoader(embeddingset, shuffle=True, batch_size=batch_size)

Class: antelope: torch.Size([200, 64])
Class: grizzly bear: torch.Size([200, 64])
Class: killer whale: torch.Size([200, 64])
Class: beaver: torch.Size([200, 64])
Class: dalmatian: torch.Size([200, 64])
Class: persian cat: torch.Size([200, 64])
Class: german shepherd: torch.Size([200, 64])
Class: siamese cat: torch.Size([200, 64])
Class: skunk: torch.Size([200, 64])
Class: mole: torch.Size([200, 64])
Class: tiger: torch.Size([200, 64])
Class: hippopotamus: torch.Size([200, 64])
Class: leopard: torch.Size([200, 64])
Class: moose: torch.Size([200, 64])
Class: spider monkey: torch.Size([200, 64])
Class: humpback whale: torch.Size([200, 64])
Class: elephant: torch.Size([200, 64])
Class: gorilla: torch.Size([200, 64])
Class: ox: torch.Size([200, 64])
Class: fox: torch.Size([200, 64])
Class: chimpanzee: torch.Size([200, 64])
Class: hamster: torch.Size([200, 64])
Class: squirrel: torch.Size([200, 64])
Class: rhinoceros: torch.Size([200, 64])
Class: rabbit: torch.Size([200, 64])
Class: wolf: to

In [5]:
# Train the classifier on the pre-computed embeddings and evaluate after every epoch.
#
# model.classifier_history holds one array per tracked quantity:
#   [0] mean training loss, [1] seen accuracy, [2] unseen accuracy, [3] H-accuracy.
# The epoch counter resumes from the number of epochs already recorded there.
start_epoch = len(model.classifier_history[0]) + 1

for epoch in range(start_epoch, num_epochs+1):
    epoch_loss = 0.0

    if verbose:
        print('Epoch {}/{}'.format(epoch, num_epochs))

    # Train
    model.train()
    for (i_batch, (z, labels)) in enumerate(trainloader, start=1):
        start = time.time_ns()

        # Forward pass
        pred = model.classifier(z)
        loss = loss_function(pred, labels.long())

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        end = time.time_ns()

        # Update
        epoch_loss += loss.item()
        if verbose:
            print_progress(i_batch, len(trainloader), end-start, loss.item())

    model.classifier_history[0] = np.append(model.classifier_history[0], epoch_loss / len(trainloader))

    # Test
    model.eval()
    with torch.no_grad():
        # Forward pass for seen classes
        features, _, labels = dataset[dataset.test_seen_idx]
        _, pred = torch.topk(model(features), top_k_acc, dim=1)
        seen_acc = top_k_accuracy(pred, labels)

        # Forward pass for unseen classes
        features, _, labels = dataset[dataset.test_unseen_idx]
        _, pred = torch.topk(model(features), top_k_acc, dim=1)
        unseen_acc = top_k_accuracy(pred, labels)

    # Harmonic mean of seen/unseen accuracy; defined as 0 when both are 0
    # instead of dividing by zero.
    acc_sum = seen_acc + unseen_acc
    acc = 2*seen_acc*unseen_acc / acc_sum if acc_sum > 0 else 0.0
    print('Test: S = {:.1f}| U = {:.1f}| \033[1mH = {:.1f}\033[0m \n'.format(seen_acc, unseen_acc, acc))

    # Update history (no checkpoint is written here)
    model.classifier_history[1] = np.append(model.classifier_history[1], seen_acc)
    model.classifier_history[2] = np.append(model.classifier_history[2], unseen_acc)
    model.classifier_history[3] = np.append(model.classifier_history[3], acc)

    # Early stop: the current H-accuracy is no better than any of the
    # `early_stop` epochs before it.
    if early_stop is not None:
        if len(model.classifier_history[3]) > early_stop+1 and acc <= min(model.classifier_history[3][-early_stop-1:-1]):
            if verbose:
                print('Stopped at epoch {} because H-accuracy stopped improving\n'.format(epoch))
            # Stop regardless of verbosity; only the message is optional.
            break

# Index of the epoch with the best H-accuracy
idx = np.argmax(model.classifier_history[3])

Class: antelope: torch.Size([200, 64])
Class: grizzly bear: torch.Size([200, 64])
Class: killer whale: torch.Size([200, 64])
Class: beaver: torch.Size([200, 64])
Class: dalmatian: torch.Size([200, 64])
Class: persian cat: torch.Size([200, 64])
Class: german shepherd: torch.Size([200, 64])
Class: siamese cat: torch.Size([200, 64])
Class: skunk: torch.Size([200, 64])
Class: mole: torch.Size([200, 64])
Class: tiger: torch.Size([200, 64])
Class: hippopotamus: torch.Size([200, 64])
Class: leopard: torch.Size([200, 64])
Class: moose: torch.Size([200, 64])
Class: spider monkey: torch.Size([200, 64])
Class: humpback whale: torch.Size([200, 64])
Class: elephant: torch.Size([200, 64])
Class: gorilla: torch.Size([200, 64])
Class: ox: torch.Size([200, 64])
Class: fox: torch.Size([200, 64])
Class: chimpanzee: torch.Size([200, 64])
Class: hamster: torch.Size([200, 64])
Class: squirrel: torch.Size([200, 64])
Class: rhinoceros: torch.Size([200, 64])
Class: rabbit: torch.Size([200, 64])
Class: wolf: to

KeyboardInterrupt: 

In [5]:
# Train the classifier directly on encoded trainval features and evaluate after every epoch.
#
# model.classifier_history holds one array per tracked quantity:
#   [0] mean training loss, [1] seen accuracy, [2] unseen accuracy, [3] H-accuracy.
# The epoch counter resumes from the number of epochs already recorded there.
start_epoch = len(model.classifier_history[0]) + 1

# Only the trainval split is iterated here.
# NOTE(review): presumably that split holds seen classes only, in which case the
# classifier gets no training signal for unseen classes in this cell -- confirm.
trainset   = Subset(dataset, dataset.trainval_idx)
dataloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)

for epoch in range(start_epoch, num_epochs+1):
    epoch_loss = 0.0

    if verbose:
        print('Epoch {}/{}'.format(epoch, num_epochs))

    # Train
    model.train()
    for (i_batch, (features, _, labels)) in enumerate(dataloader, start=1):
        start = time.time_ns()

        # Forward pass
        # NOTE(review): this takes the FIRST encoder output, whereas
        # EmbeddingsDataset takes the SECOND -- confirm which one is intended.
        z, _, _ = model.cnn_encoder(features)
        pred = model.classifier(z)
        loss = loss_function(pred, labels.long())

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        end = time.time_ns()

        # Update
        epoch_loss += loss.item()
        if verbose:
            print_progress(i_batch, len(dataloader), end-start, loss.item())

    model.classifier_history[0] = np.append(model.classifier_history[0], epoch_loss / len(dataloader))

    # Test
    model.eval()
    with torch.no_grad():
        # Forward pass for seen classes
        features, _, labels = dataset[dataset.test_seen_idx]
        _, pred = torch.topk(model(features), top_k_acc, dim=1)
        seen_acc = top_k_accuracy(pred, labels)

        # Forward pass for unseen classes
        features, _, labels = dataset[dataset.test_unseen_idx]
        _, pred = torch.topk(model(features), top_k_acc, dim=1)
        unseen_acc = top_k_accuracy(pred, labels)

    # Harmonic mean of seen/unseen accuracy; defined as 0 when both are 0
    # instead of dividing by zero.
    acc_sum = seen_acc + unseen_acc
    acc = 2*seen_acc*unseen_acc / acc_sum if acc_sum > 0 else 0.0
    print('Test: S = {:.1f}| U = {:.1f}| \033[1mH = {:.1f}\033[0m \n'.format(seen_acc, unseen_acc, acc))

    # Update history (no checkpoint is written here)
    model.classifier_history[1] = np.append(model.classifier_history[1], seen_acc)
    model.classifier_history[2] = np.append(model.classifier_history[2], unseen_acc)
    model.classifier_history[3] = np.append(model.classifier_history[3], acc)

    # Early stop -- disabled for this run; kept as comments for reference.
    # if early_stop is not None:
    #     if len(model.classifier_history[3]) > early_stop+1 and acc <= min(model.classifier_history[3][-early_stop-1:-1]):
    #         if verbose:
    #             print('Stopped at epoch {} because H-accuracy stopped improving\n'.format(epoch))
    #         break

# Index of the epoch with the best H-accuracy
idx = np.argmax(model.classifier_history[3])

Epoch 1/100
236/236 |████████████████████████████████████████████████████████████| 100.0% - 4 ms/batch - loss: 3.88
Test: S = 2.6| U = 0.4| [1mH = 0.7[0m 

Epoch 2/100
236/236 |████████████████████████████████████████████████████████████| 100.0% - 5 ms/batch - loss: 3.98
Test: S = 4.6| U = 0.0| [1mH = 0.0[0m 

Epoch 3/100
236/236 |████████████████████████████████████████████████████████████| 100.0% - 4 ms/batch - loss: 3.96
Test: S = 6.7| U = 0.0| [1mH = 0.0[0m 

Epoch 4/100
236/236 |████████████████████████████████████████████████████████████| 100.0% - 4 ms/batch - loss: 3.86
Test: S = 9.3| U = 0.0| [1mH = 0.0[0m 

Epoch 5/100
236/236 |████████████████████████████████████████████████████████████| 100.0% - 4 ms/batch - loss: 4.02
Test: S = 11.9| U = 0.0| [1mH = 0.0[0m 

Epoch 6/100


KeyboardInterrupt: 