In [1]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import Omniglot
from torchvision.models import resnet18
from tqdm import tqdm
from auxillary_functions import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
image_dim = 28

train_set = Omniglot(
    root = "./data",
    background = True,
    transform = transforms.Compose(
        [
            transforms.Grayscale(num_output_channels = 3),
            transforms.RandomResizedCrop(image_dim),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ],
    ),
    download = True
)

test_set = Omniglot(
    root = "./data",
    background = False,
    transform = transforms.Compose(
        [
            transforms.Grayscale(num_output_channels = 3),
            transforms.Resize([int(image_dim * 1.15), int(image_dim * 1.15)]),
            transforms.CenterCrop(image_dim),
            transforms.ToTensor()
        ]
    ),
    download = True
)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
class ProtoNN(nn.Module):
    def __init__(self, backbone:nn.Module):
        super(ProtoNN, self).__init__()
        self.backbone = backbone #will be using pretrained resnet backbone for feature extraction from images
    
    def forward(self, support_images: torch.Tensor, support_labels: torch.Tensor, query_images: torch.Tensor) -> torch.Tensor:
        # Predicting query labels from labeled support images
        
        #Feature extraction
        support = self.backbone.forward(support_images)
        query = self.backbone.forward(query_images)
        
        n_way = len(torch.unique(support_labels)) #num classes from num labels
        proto = torch.cat([support[torch.nonzero(support_labels == label)].mean(0) for label in range(n_way)])
        
        distances = torch.cdist(query, proto)
        
        return -distances #classification scores given by negative distances

cnn = resnet18(pretrained = True)
cnn.fc = nn.Flatten()
model = ProtoNN(cnn)



In [17]:
n_way = 5
n_shot = 5
n_query = 10
n_evaluation_tasks = 100
def label_getter(test_set):
    return [x[1] for x in test_set._flat_character_images]
#test_set.get_labels = lambda: [x[1] for x in test_set._flat_character_images]
test_set.get_labels = label_getter(test_set)
test_sampler = TaskSampler(test_set, n_way = n_way, n_shot = n_shot,n_query = n_query, n_tasks = n_evaluation_tasks)
test_load = DataLoader(test_set, batch_sampler=test_sampler,  collate_fn=test_sampler.episodic_collate)

test_loader = DataLoader(
    test_set,
    batch_sampler=test_sampler,
    num_workers=12,
    pin_memory=True,
    collate_fn=test_sampler.episodic_collate
)
if __name__ == "__main__":
    (
        example_support_images,
        example_support_labels,
        example_query_images,
        example_query_labels,
        example_class_ids
    ) = next(iter(test_loader))

# plot_images(example_support_images, "support images", images_per_row=n_shot)
# plot_images(example_query_images, "query images", images_per_row=n_query)

TypeError: 'list' object is not callable