In [1]:
import copy
import os
import random
from statistics import mean

import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10, MNIST
from torchvision.models import alexnet
from tqdm import tqdm

from easyfsl.samplers import TaskSampler
from easyfsl.utils import plot_images, sliding_average

from AlexNetLastTwoLayers import AlexNetLastTwoLayers
from PrototypicalNetworks import PrototypicalNetworks
from PrototypicalFlagNetworks import PrototypicalFlagNetworks

In [None]:
def count_accuracy(logits, label):
    """
    Percentage of rows in ``logits`` whose highest-scoring class equals ``label``.

    Parameters:
      logits: a Tensor whose dim 1 holds the per-class scores.
      label:  a Tensor with one class index per prediction (flattened internally).
    Returns: a 0-dim float Tensor in the range [0, 100].
    """
    flat_predictions = logits.argmax(dim=1).view(-1)
    flat_targets = label.view(-1)
    # Boolean hits -> floats, so the mean is the fraction of correct predictions.
    hit_mask = (flat_predictions == flat_targets).float()
    return hit_mask.mean() * 100

def one_hot(indices, depth):
    """
    Returns a one-hot tensor.
    This is a PyTorch equivalent of Tensorflow's tf.one_hot.

    The result is allocated on the same device as ``indices`` (so the function
    works on CPU-only machines as well as on GPU) and the class axis is always
    the last one, which makes the documented (n_batch, m) input work too.

    Parameters:
      indices:  a (n_batch, m) Tensor or (m) Tensor of integer class ids in
                the range [0, depth).
      depth: a scalar. Represents the depth of the one hot dimension.
    Returns: a float (n_batch, m, depth) Tensor or (m, depth) Tensor.
    """
    # scatter_ requires an int64 index with the same rank as the output.
    index = indices.long().unsqueeze(-1)
    encoded_indices = torch.zeros(*indices.shape, depth, device=indices.device)
    # Write a 1 at the class position along the trailing (one-hot) axis.
    return encoded_indices.scatter_(-1, index, 1)

def training_epoch(model_, data_loader, optimizer):
    """
    Run one epoch of episodic training and return the mean episode loss.

    Each item of ``data_loader`` must unpack into
    ``(support_images, support_labels, query_images, query_labels, _)``.
    The model is called as ``model_(support_images, support_labels, query_images)``
    and its output is reshaped to ``(-1, n_way)`` before the cross-entropy
    against the flattened query labels is computed.

    Parameters:
      model_:      the few-shot model to optimise (put in train mode here).
      data_loader: sized iterable of episodes.
      optimizer:   optimizer bound to ``model_``'s parameters.
    Returns: the mean of the per-episode losses as a Python float.
    Raises: statistics.StatisticsError if ``data_loader`` yields no episode.
    """
    # Run on whatever device the model already lives on instead of hard-coding
    # CUDA, so the same loop works on CPU-only machines.
    first_param = next(model_.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    all_loss = []
    model_.train()
    with tqdm(
        enumerate(data_loader), total=len(data_loader), desc="Training"
    ) as tqdm_train:
        for episode_index, (
            support_images,
            support_labels,
            query_images,
            query_labels,
            _,
        ) in tqdm_train:
            optimizer.zero_grad()

            logit_query = model_(
                support_images.to(device),
                support_labels.to(device),
                query_images.to(device),
            )

            # Number of classes in this episode, read off the support set.
            train_way = len(torch.unique(support_labels))
            # Plain (unsmoothed) cross-entropy; identical to summing
            # -one_hot * log_softmax per query and averaging.
            loss = F.cross_entropy(
                logit_query.reshape(-1, train_way),
                query_labels.reshape(-1).to(device),
            )

            loss.backward()
            optimizer.step()

            all_loss.append(loss.item())

            tqdm_train.set_postfix(loss=mean(all_loss))

    return mean(all_loss)

def val_evaluate(model_, val_loader):
    """
    Compute query-set accuracy of ``model_`` over all episodes in ``val_loader``.

    Each item of ``val_loader`` must unpack into
    ``(support_images, support_labels, query_images, query_labels, _)``.
    The predicted class is the argmax over the last axis of the model output.

    Parameters:
      model_:     the few-shot model to evaluate (put in eval mode here).
      val_loader: iterable of validation episodes.
    Returns: fraction of correctly classified queries as a float in [0, 1];
             0.0 when the loader yields no query at all.
    """
    # Run on whatever device the model already lives on instead of hard-coding
    # CUDA, so evaluation also works on CPU-only machines.
    first_param = next(model_.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_.eval()  # Set model to evaluation mode
    correct = 0
    total = 0
    with torch.no_grad():  # No gradients needed during validation
        for val_support_images, val_support_labels, val_query_images, val_query_labels, _ in val_loader:
            # Obtain validation predictions
            val_preds = model_(
                val_support_images.to(device),
                val_support_labels.to(device),
                val_query_images.to(device),
            )

            # Flatten both sides so predictions and labels line up one-to-one
            # regardless of a leading task dimension on the model output.
            predicted = val_preds.argmax(dim=-1).reshape(-1)
            targets = val_query_labels.reshape(-1).to(device)

            correct += (predicted == targets).sum().item()
            total += targets.numel()

    # An empty loader would otherwise end in a ZeroDivisionError.
    if total == 0:
        return 0.0
    return correct / total

In [2]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [3]:
random_seed = 0

# Seed every RNG this notebook draws from:
#  - Python's `random`: presumably what easyfsl's TaskSampler uses to draw the
#    classes and images of each episode (verify against the installed version);
#  - torch: used by random_split and by weight initialisation.
# NumPy is not imported or used directly in this notebook, so it is not seeded.
random.seed(random_seed)
torch.manual_seed(random_seed)

# Trade cuDNN autotuning for run-to-run determinism.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [4]:

# Preprocessing shared by every split: resize the images to 224 pixels, convert
# them to tensors, then normalise with the per-channel statistics
# conventionally used for ImageNet-pretrained torchvision backbones.
transform = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
)

train_data = CIFAR10(
    root="../data",
    transform=transform,
    download=True,
    train=True,
)
test_data = CIFAR10(
    root="../data",
    transform=transform,
    download=True,
    train=False,
)

# 80/20 split of the official training set into train and validation subsets.
train_size = int(0.8 * len(train_data))
val_size = len(train_data) - train_size
# A dedicated, explicitly seeded generator makes the split identical every time
# this cell runs, independent of how much global RNG state was consumed before.
train_subset, val_subset = torch.utils.data.random_split(
    train_data,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(random_seed),
)


Files already downloaded and verified
Files already downloaded and verified


In [5]:
N_WAY = 5  # Number of classes in a task
N_SHOT = 5  # Number of images per class in the support set
N_QUERY = 10  # Number of images per class in the query set
N_EVALUATION_TASKS = 100  # Episodes drawn from the test set
N_VALIDATION_TASKS = 100  # Episodes drawn from the validation subset
N_TASKS_PER_EPOCH = 500  # Training episodes per epoch

# The sampler needs a dataset with a "get_labels" method. Check the code if you have any doubt!
# Labels are read straight from CIFAR10.targets: iterating the dataset instead
# would decode, resize and normalise every image just to throw the pixels away.
train_data.get_labels = lambda: list(train_data.targets)
test_data.get_labels = lambda: list(test_data.targets)
train_subset.get_labels = lambda: [
    train_data.targets[index] for index in train_subset.indices
]
val_subset.get_labels = lambda: [
    train_data.targets[index] for index in val_subset.indices
]

# Only pin host memory when there is a GPU to copy the batches to.
PIN_MEMORY = torch.cuda.is_available()

test_sampler = TaskSampler(
    test_data, n_way=N_WAY, n_shot=N_SHOT, n_query=N_QUERY, n_tasks=N_EVALUATION_TASKS
)

val_sampler = TaskSampler(
    val_subset, n_way=N_WAY, n_shot=N_SHOT, n_query=N_QUERY, n_tasks=N_VALIDATION_TASKS
)

# Training episodes are drawn from train_subset, not the full train_data:
# train_data also contains every validation image, so sampling from it would
# leak the validation set into training.
train_sampler = TaskSampler(
    train_subset, n_way=N_WAY, n_shot=N_SHOT, n_query=N_QUERY, n_tasks=N_TASKS_PER_EPOCH
)

test_loader = DataLoader(
    test_data,
    batch_sampler=test_sampler,
    num_workers=0,
    pin_memory=PIN_MEMORY,
    collate_fn=test_sampler.episodic_collate_fn,
)

val_loader = DataLoader(
    val_subset,
    batch_sampler=val_sampler,
    num_workers=0,
    pin_memory=PIN_MEMORY,
    collate_fn=val_sampler.episodic_collate_fn,
)

train_loader = DataLoader(
    train_subset,
    batch_sampler=train_sampler,
    num_workers=0,
    pin_memory=PIN_MEMORY,
    collate_fn=train_sampler.episodic_collate_fn,
)

In [None]:
# Train Prototypical Networks (ProtoNet head) and keep the best checkpoint

backbone = alexnet(pretrained = True)
# Swap the last classifier layer for a Flatten so the backbone outputs
# features rather than the original classifier's scores.
backbone.classifier[6] = nn.Flatten()
model = PrototypicalNetworks(backbone, head = 'ProtoNet').to(device)


train_optimizer = optim.Adam(model.parameters(), lr=1e-5)
n_epochs = 40

train_losses = []
val_accs = []
# state_dict() returns references to the live parameters, so snapshot it;
# otherwise the "best" state would silently follow training if no epoch
# ever beats the initial accuracy of 0.0.
best_state = copy.deepcopy(model.state_dict())
best_validation_accuracy = 0.0
for epoch in range(n_epochs):
    print(f"Epoch {epoch}")
    average_loss = training_epoch(model, train_loader, train_optimizer)
    validation_accuracy = val_evaluate(model, val_loader)

    if validation_accuracy > best_validation_accuracy:
        best_validation_accuracy = validation_accuracy
        # state_dict() returns a reference to the still evolving model's state so we deepcopy
        # https://pytorch.org/tutorials/beginner/saving_loading_models
        best_state = copy.deepcopy(model.state_dict())
        print(f"Ding ding ding! We found a new best model! {best_validation_accuracy}")

    train_losses.append(average_loss)
    val_accs.append(validation_accuracy)


# NOTE(review): the file name is misspelled ("cirfar10") and, unlike the other
# checkpoints, is written to the working directory instead of ../models.
# Left unchanged because other code may load it by this exact name - confirm.
torch.save(best_state, 'cirfar10_protonets.pth')


loss_fig, loss_ax = plt.subplots()
loss_ax.plot(train_losses)
loss_ax.set_title('ProtoNet: training loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Training Loss')

acc_fig, acc_ax = plt.subplots()
acc_ax.plot(val_accs)
acc_ax.set_title('ProtoNet: validation accuracy')
acc_ax.set_xlabel('Epoch')
acc_ax.set_ylabel('Validation Accuracy')

plt.show()

In [7]:
# Train Prototypical Networks with the SubspaceNet head and keep the best checkpoint

# Make sure the checkpoint directory exists before spending 40 epochs training.
os.makedirs('../models', exist_ok=True)

backbone = alexnet(pretrained = True)
# Swap the last classifier layer for a Flatten so the backbone outputs
# features rather than the original classifier's scores.
backbone.classifier[6] = nn.Flatten()
# Move the model to the selected device, as the ProtoNet cell does.
model = PrototypicalNetworks(backbone, head = 'SubspaceNet').to(device)


train_optimizer = optim.Adam(model.parameters(), lr=1e-5)
n_epochs = 40

train_losses = []
val_accs = []
# state_dict() returns references to the live parameters, so snapshot it;
# otherwise the "best" state would silently follow training if no epoch
# ever beats the initial accuracy of 0.0.
best_state = copy.deepcopy(model.state_dict())
best_validation_accuracy = 0.0
for epoch in range(n_epochs):
    print(f"Epoch {epoch}")
    average_loss = training_epoch(model, train_loader, train_optimizer)
    validation_accuracy = val_evaluate(model, val_loader)

    if validation_accuracy > best_validation_accuracy:
        best_validation_accuracy = validation_accuracy
        # state_dict() returns a reference to the still evolving model's state so we deepcopy
        # https://pytorch.org/tutorials/beginner/saving_loading_models
        best_state = copy.deepcopy(model.state_dict())
        print(f"Ding ding ding! We found a new best model! {best_validation_accuracy}")

    train_losses.append(average_loss)
    val_accs.append(validation_accuracy)


# NOTE(review): the file name is misspelled ("cirfar10"); left unchanged
# because other code may load it by this exact name - confirm before renaming.
torch.save(best_state, '../models/cirfar10_subspacenets.pth')


loss_fig, loss_ax = plt.subplots()
loss_ax.plot(train_losses)
loss_ax.set_title('SubspaceNet: training loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Training Loss')

acc_fig, acc_ax = plt.subplots()
acc_ax.plot(val_accs)
acc_ax.set_title('SubspaceNet: validation accuracy')
acc_ax.set_xlabel('Epoch')
acc_ax.set_ylabel('Validation Accuracy')

plt.show()





In [11]:

# Train Prototypical Flag Networks and keep the best checkpoint

# Make sure the checkpoint directory exists before spending 40 epochs training.
os.makedirs('../models', exist_ok=True)

my_alexnet = alexnet(pretrained = True)
backbone = AlexNetLastTwoLayers(my_alexnet)
# Move the model to the selected device, as the ProtoNet cell does.
model = PrototypicalFlagNetworks(backbone).to(device)


train_optimizer = optim.Adam(model.parameters(), lr=1e-5)
n_epochs = 40

train_losses = []
val_accs = []
# state_dict() returns references to the live parameters, so snapshot it;
# otherwise the "best" state would silently follow training if no epoch
# ever beats the initial accuracy of 0.0.
best_state = copy.deepcopy(model.state_dict())
best_validation_accuracy = 0.0
for epoch in range(n_epochs):
    print(f"Epoch {epoch}")
    average_loss = training_epoch(model, train_loader, train_optimizer)
    validation_accuracy = val_evaluate(model, val_loader)

    if validation_accuracy > best_validation_accuracy:
        best_validation_accuracy = validation_accuracy
        # state_dict() returns a reference to the still evolving model's state so we deepcopy
        # https://pytorch.org/tutorials/beginner/saving_loading_models
        best_state = copy.deepcopy(model.state_dict())
        print(f"Ding ding ding! We found a new best model! {best_validation_accuracy}")

    train_losses.append(average_loss)
    val_accs.append(validation_accuracy)


torch.save(best_state, '../models/cifar10_flagnets.pth')


loss_fig, loss_ax = plt.subplots()
loss_ax.plot(train_losses)
loss_ax.set_title('FlagNet: training loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Training Loss')

acc_fig, acc_ax = plt.subplots()
acc_ax.plot(val_accs)
acc_ax.set_title('FlagNet: validation accuracy')
acc_ax.set_xlabel('Epoch')
acc_ax.set_ylabel('Validation Accuracy')

plt.show()


