In [3]:
import argparse
import time
import os
import importlib.util
import random
import torch
import numpy as np
import torch.nn.functional as F
from google_drive_downloader import GoogleDriveDownloader as gdd

from torch import nn, Tensor
from torch.utils.data import IterableDataset
import imageio

In [4]:
def get_images(paths, labels, nb_samples=None, shuffle=True):
    """
    Takes a set of character folders and labels and returns paths to image files
    paired with labels.
    """
    if nb_samples is not None:
        sampler = lambda x: random.sample(x, nb_samples)
    else:
        sampler = lambda x: x
    image_labels = [
        (i, os.path.join(path, image))
        for i, path in zip(labels, paths)
        for image in sampler(os.listdir(path))
    ]
    if shuffle:
        random.shuffle(image_labels)
    return image_labels


class DataGenerator(IterableDataset):
    """
    Data Generator capable of generating batches of Omniglot data.
    A "class" is considered a class of omniglot digits.
    """

    def __init__(
        self,
        num_classes,
        num_samples_per_class,
        batch_type,
        config={},
        cache=True,
    ):
        """
        Args:
            num_classes: Number of classes for classification (N-way)
            num_samples_per_class: num samples to generate per class in one batch (K+1)
            batch_size: size of meta batch size (e.g. number of functions)
            batch_type: train/val/test
        """
        self.num_samples_per_class = num_samples_per_class
        self.num_classes = num_classes

        data_folder = config.get("data_folder", "./omniglot_resized")
        self.img_size = config.get("img_size", (28, 28))

        self.dim_input = np.prod(self.img_size)
        self.dim_output = self.num_classes

        character_folders = [
            os.path.join(data_folder, family, character)
            for family in os.listdir(data_folder)
            if os.path.isdir(os.path.join(data_folder, family))
            for character in os.listdir(os.path.join(data_folder, family))
            if os.path.isdir(os.path.join(data_folder, family, character))
        ]

        random.seed(1)
        random.shuffle(character_folders)
        num_val = 100
        num_train = 1100
        self.metatrain_character_folders = character_folders[:num_train]
        self.metaval_character_folders = character_folders[num_train : num_train + num_val]
        self.metatest_character_folders = character_folders[num_train + num_val :]
        self.image_caching = cache
        self.stored_images = {}

        if batch_type == "train":
            self.folders = self.metatrain_character_folders
        elif batch_type == "val":
            self.folders = self.metaval_character_folders
        else:
            self.folders = self.metatest_character_folders

    def image_file_to_array(self, filename, dim_input):
        """Takes an image path and returns numpy array"""
        if self.image_caching and (filename in self.stored_images):
            return self.stored_images[filename]
        image = imageio.imread(filename)  # misc.imread(filename)
        image = image.reshape([dim_input])
        image = image.astype(np.float32) / 255.0
        image = 1.0 - image
        if self.image_caching:
            self.stored_images[filename] = image
        return image

    def _sample(self):
        """Samples a batch for training, validation, or testing"""
        K, N = self.num_samples_per_class, self.num_classes

        sampled_classes = random.sample(self.folders, N)
        one_hot_labels = np.eye(N)

        images_and_labels = get_images(sampled_classes, one_hot_labels, nb_samples=K, shuffle=False)

        support_labels, support_images = [], []
        query_labels, query_images = [], []
        for i, (label, image) in enumerate(images_and_labels):
            if i % K == 0:
                support_labels.append(label)
                support_images.append(image)
            else:
                query_labels.append(label)
                query_images.append(image)

        zip_test_label_images = list(zip(query_labels, query_images))
        np.random.shuffle(zip_test_label_images)
        query_labels, query_images = zip(*zip_test_label_images)

        images_support = [self.image_file_to_array(x, self.dim_input) for x in support_images]
        images_query = [self.image_file_to_array(x, self.dim_input) for x in query_images]

        images = np.array(images_support+images_query).reshape(K, N, self.dim_input)
        labels = np.array(support_labels + [q for q in query_labels]).reshape(K, N, N)

        images = images.astype("float32")
        labels = labels.astype("float32")

        return images, labels

    def __iter__(self):
        while True:
            yield self._sample()


In [5]:
def initialize_weights(model):
    if type(model) in [nn.Linear]:
        nn.init.xavier_uniform_(model.weight)
        nn.init.zeros_(model.bias)
    elif type(model) in [nn.LSTM, nn.RNN, nn.GRU]:
        nn.init.orthogonal_(model.weight_hh_l0)
        nn.init.xavier_uniform_(model.weight_ih_l0)
        nn.init.zeros_(model.bias_hh_l0)
        nn.init.zeros_(model.bias_ih_l0)


class MANN(nn.Module):
    def __init__(self, num_classes, samples_per_class, hidden_dim):
        super(MANN, self).__init__()
        self.num_classes = num_classes
        self.samples_per_class = samples_per_class

        self.layer1 = torch.nn.LSTM(num_classes + 784, hidden_dim, batch_first=True)
        self.layer2 = torch.nn.LSTM(hidden_dim, num_classes, batch_first=True)
        initialize_weights(self.layer1)
        initialize_weights(self.layer2)

    def forward(self, input_images, input_labels):
        B, Kplusone, N, D = input_images.shape
        zero_labels = torch.zeros_like(input_labels[:, -1:, :, :])
        labels = torch.concat([input_labels[:, :-1, :, :], zero_labels], dim=1)

        x = torch.concat([input_images, labels], dim=3)
        x = x.reshape(
            x.shape[0],
            x.shape[1] * x.shape[2],
            x.shape[3],
        )

        x, _ = self.layer1(x)
        out, _ = self.layer2(x)

        out = out.reshape(B, Kplusone, N, N)
        return out
        
    def loss_function(self, preds, labels):
        """Computes MANN loss"""
        loss = None

        query_preds, query_labels = preds[:, -1, :, :], labels[:, -1, :, :]

        query_preds = query_preds.reshape(query_preds.shape[0]*self.num_classes, -1)
        query_labels = query_labels.reshape(query_labels.shape[0]*self.num_classes, -1)

        loss = F.cross_entropy(query_preds, query_labels)

        return loss


In [6]:
def train_step(images, labels, model, optim, eval=False):
    predictions = model(images, labels)
    loss = model.loss_function(predictions, labels)
    if not eval:
        optim.zero_grad()
        loss.backward()
        optim.step()
    return predictions.detach(), loss.detach()

In [7]:
def main(config):
    print(config)
    random.seed(config.random_seed)
    np.random.seed(config.random_seed)
    
    if config.device == "gpu" and torch.backends.mps.is_available() and torch.backends.mps.is_built():
        # Waiting for PyTorch 2.0 to enable MPS
        # device = torch.device("mps")
        
        # Default to cpu as for now mps it is not stable. 
        device = torch.device("cpu")
    elif config.device == "gpu" and torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    print("Using device: ", device)

    torch.manual_seed(config.random_seed)
    
    assert os.path.isdir("./omniglot_resized")

    # Create Data Generator
    train_iterable = DataGenerator(
        config.num_classes,
        config.num_shot + 1,
        batch_type="train",
        cache=config.image_caching,
    )
    train_loader = iter(
        torch.utils.data.DataLoader(
            train_iterable,
            batch_size=config.meta_batch_size,
            num_workers=0,
            pin_memory=True,
        )
    )
    test_iterable = DataGenerator(
        config.num_classes,
        config.num_shot + 1,
        batch_type="test",
        cache=config.image_caching,
    )
    test_loader = iter(
        torch.utils.data.DataLoader(
            test_iterable,
            batch_size=config.meta_batch_size,
            num_workers=0,
            pin_memory=True,
        )
    )

    # Create model
    model = MANN(config.num_classes, config.num_shot + 1, config.hidden_dim)
    model.to(device)

    # Create optimizer
    optim = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

    times = []
    for step in range(config.train_steps):
        ## Sample Batch
        t0 = time.time()
        i, l = next(train_loader)
        i, l = i.to(device), l.to(device)
        t1 = time.time()

        ## Train
        _, ls = train_step(i, l, model, optim)
        t2 = time.time()
        times.append([t1 - t0, t2 - t1])

        ## Evaluate
        if (step + 1) % config.eval_freq == 0:
            if config.debug == True:
                print("*" * 5 + "Iter " + str(step + 1) + "*" * 5)
            i, l = next(test_loader)
            i, l = i.to(device), l.to(device)
            pred, tls = train_step(i, l, model, optim, eval=True)
            print("Train Loss:", ls.cpu().numpy(), "Test Loss:", tls.cpu().numpy())
            pred = torch.reshape(
                pred, [-1, config.num_shot + 1, config.num_classes, config.num_classes]
            )

            pred = torch.argmax(pred[:, -1, :, :], axis=2)
            l = torch.argmax(l[:, -1, :, :], axis=2)
            acc = pred.eq(l).sum().item() / (config.meta_batch_size * config.num_classes)
            print("Test Accuracy", acc)

            times = np.array(times)
            print(f"Sample time {times[:, 0].mean()} Train time {times[:, 1].mean()}")
            times = []

In [13]:
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_classes", type=int, default=2)
    parser.add_argument("--num_shot", type=int, default=1)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--eval_freq", type=int, default=100)
    parser.add_argument("--meta_batch_size", type=int, default=128)
    parser.add_argument("--hidden_dim", type=int, default=256)
    parser.add_argument("--random_seed", type=int, default=123)
    parser.add_argument("--learning_rate", type=float, default=1e-3)
    parser.add_argument("--train_steps", type=int, default=5000)
    parser.add_argument("--image_caching", type=bool, default=True)
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--debug", type=str, default=True)
    parser.add_argument('--cache', action='store_true')

    args, unknown = parser.parse_known_args()
    
    main(args)

Namespace(num_classes=2, num_shot=1, num_workers=4, eval_freq=100, meta_batch_size=128, hidden_dim=256, random_seed=123, learning_rate=0.001, train_steps=5000, image_caching=True, device='cpu', debug=True, cache=False)
Using device:  cpu


  image = imageio.imread(filename)  # misc.imread(filename)


*****Iter 100*****
Train Loss: 0.6764758 Test Loss: 0.68418396
Test Accuracy 0.5234375
Sample time 0.3911420273780823 Train time 0.10928846120834351
*****Iter 200*****
Train Loss: 0.59750336 Test Loss: 0.53428966
Test Accuracy 0.72265625
Sample time 0.15379758834838866 Train time 0.16037741422653198
*****Iter 300*****
Train Loss: 0.5734464 Test Loss: 0.5729797
Test Accuracy 0.64453125
Sample time 0.07374746322631837 Train time 0.10701999664306641
*****Iter 400*****
Train Loss: 0.53648233 Test Loss: 0.53488576
Test Accuracy 0.703125
Sample time 0.07112741947174073 Train time 0.11463311910629273
*****Iter 500*****
Train Loss: 0.522375 Test Loss: 0.52411914
Test Accuracy 0.69921875
Sample time 0.06677749156951904 Train time 0.1030724835395813
*****Iter 600*****
Train Loss: 0.50266016 Test Loss: 0.5433821
Test Accuracy 0.703125
Sample time 0.07270204067230225 Train time 0.10892043828964233
*****Iter 700*****
Train Loss: 0.5039217 Test Loss: 0.5193359
Test Accuracy 0.75390625
Sample time 0.