In [1]:
import torch
import torchvision
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torchonn as onn
from torchonn.models import ONNBaseModel
import torch.optim as optim
import torchvision.transforms as transforms
import scipy.stats as stats
from copy import deepcopy



In [17]:
class OmniglotLoader():
    def __init__(self, batch_size, n_way, k_spt, k_qry, downsampled_size):
        """
        Args:
        batchsize: imgs per batch
        n_way: number of classes in support set
        k_spt: number of images per class in support set
        k_qry: number of images per class in query set
        """
        
        transform = transforms.Compose([
            lambda x: x.resize((downsampled_size, downsampled_size)),
            lambda x: np.reshape(x, (downsampled_size, downsampled_size, 1)),
            lambda x: (x - 128)/128
        ])
        
        self.x_train = torchvision.datasets.Omniglot(root='./data',
                                               download=True, transform=transform)
        self.x_val = torchvision.datasets.Omniglot(root='./data', background = False,
                                            download=True, transform=transform)
        temp_train = {}
        for (img, label) in self.x_train:
            if label in temp_train.keys():
                temp_train[label].append(img)
            else:
                temp_train[label] = [img]

        temp_val = {}
        for (img, label) in self.x_val:
            if label in temp_val.keys():
                temp_val[label].append(img)
            else:
                temp_val[label] = [img]

        self.x = []
        for label, imgs in temp_train.items():
            self.x.append(np.array(imgs))
        for label, imgs in temp_val.items():
            self.x.append(np.array(imgs))

        self.x = np.array(self.x).astype(np.float32)

        temp = []
        
        self.x_train = self.x[:964]
        self.x_test = self.x[964:]
        self.num_classes = self.x.shape[0]
        self.batch_size = batch_size
        self.downsampled_size = downsampled_size
        self.n_way = n_way
        self.k_spt = k_spt
        self.k_qry = k_qry
        self.indexes = {"train": 0, "test": 0}
        self.datasets = {"train": self.x_train, "test": self.x_test}

        self.datasets_cache = {"train": self.load_data_cache(self.datasets["train"]),
                               "test": self.load_data_cache(self.datasets["test"])}

    def load_data_cache(self, data_pack):

        num_episodes = 10
        
        set_size = self.k_spt * self.n_way
        query_size = self.k_qry * self.n_way
        data_cache = []

        for sample in range(num_episodes):
            x_supports, y_supports, x_queries, y_queries = [], [], [], []
            for _ in range(self.batch_size):
                x_support, y_support, x_query, y_query = [], [], [], []
                selected_class = np.random.choice(data_pack.shape[0], self.n_way, False)
                for j, cur_class in enumerate(selected_class):
                    selected_image = np.random.choice(20, self.k_spt + self.k_qry, False)
                    # meta-training and meta-test
                    x_support.append(data_pack[cur_class][selected_image[:self.k_spt]])
                    x_query.append(data_pack[cur_class][selected_image[self.k_spt:]])
                    y_support.append([j for _ in range(self.k_spt)])
                    y_query.append([j for _ in range(self.k_qry)])
                    
                support_perm = np.random.permutation(self.n_way * self.k_spt)
                x_support = np.array(x_support).reshape(self.n_way * self.k_spt, 1, self.downsampled_size, self.downsampled_size)[support_perm]
                y_support = np.array(y_support).reshape(self.n_way * self.k_spt)[support_perm]
                query_perm = np.random.permutation(self.n_way * self.k_qry)
                x_query = np.array(x_query).reshape(self.n_way * self.k_qry, 1, self.downsampled_size, self.downsampled_size)[query_perm]
                y_query = np.array(y_query).reshape(self.n_way * self.k_qry)[query_perm]

                x_supports.append(x_support)
                y_supports.append(y_support)
                x_queries.append(x_query)
                y_queries.append(y_query)
                
            x_supports = np.array(x_supports).astype(np.float32).reshape(self.batch_size, set_size, 1, self.downsampled_size, self.downsampled_size)
            y_supports = np.array(y_supports).astype(np.int32).reshape(self.batch_size, set_size)

            x_queries = np.array(x_queries).astype(np.float32).reshape(self.batch_size, query_size, 1, self.downsampled_size, self.downsampled_size)
            y_queries = np.array(y_queries).astype(np.int32).reshape(self.batch_size, query_size)

            data_cache.append([x_supports, y_supports, x_queries, y_queries])
            
        return data_cache

    def next(self, mode = "train"):
        # update cache if needed
        if self.indexes[mode] >= len(self.datasets_cache[mode]):
            self.indexes[mode] = 0
            self.datasets_cache[mode] = self.load_data_cache(self.datasets[mode])
        next_batch = self.datasets_cache[mode][self.indexes[mode]]
        self.indexes[mode] += 1
        return next_batch

In [56]:
class ONNModel(ONNBaseModel):
    def __init__(self, device=torch.device("cpu")):
        super().__init__()
        self.conv = onn.layers.MZIBlockConv2d(
            in_channels=1,
            out_channels=6,
            kernel_size=3,
            stride=1,
            padding=1,
            dilation=1,
            bias=False,
            miniblock=4,
            mode="usv",
            decompose_alg="clements",
            photodetect=True,
            device=device,
        )
        self.pool = nn.AdaptiveAvgPool2d(5)
        self.linear = onn.layers.MZIBlockLinear(
            in_features=6*5*5,
            out_features=5, # because we're doing 5-way
            bias=False,
            miniblock=4,
            mode="usv",
            decompose_alg="clements",
            photodetect=True,
            device=device,
        )
        self.conv.reset_parameters()
        self.linear.reset_parameters()

    def forward(self, x):
        print(f"we haven't we exploded yet? x is: {x}")
        x = self.pool(torch.relu(self.conv(x)))
        x = torch.flatten(x, 1)
        print(f"have we exploded yet? x is: {x}")
        x = self.linear(x)
        return x

In [60]:
class Meta(nn.Module):
    # note that the maml-pytorch library uses an argparser to handle most of this instead of passing in individual parameters
    def __init__(self, update_lr, meta_lr, n_way, k_spt, k_qry, task_num, update_step, update_step_test):
        self.update_lr = update_lr
        self.meta_lr = meta_lr
        self.n_way = n_way
        self.k_spt = k_spt
        self.k_qry = k_qry
        self.task_num = task_num
        self.update_step = update_step
        self.update_step_test = update_step_test
        
        super(Meta, self).__init__()
        self.net = ONNModel()
        self.net.train()
        self.meta_optim = optim.Adam(list(self.net.parameters()), lr=self.meta_lr)

    def forward(self, x_spt, y_spt, x_qry, y_qry):
        y_spt = y_spt.to(torch.int64)
        y_qry = y_qry.to(torch.int64)
        task_num, setsz, c_, h, w = x_spt.size()
        querysz = x_qry.size(1)

        losses_q = [0 for _ in range(self.update_step + 1)]  # losses_q[i] is the loss on step i
        corrects = [0 for _ in range(self.update_step + 1)]
        for i in range(task_num):
            print(f"conv layer params: {list(maml.net.conv.parameters())}")
            # might need to clip grad norms here later

            # calculate loss before any update
            with torch.no_grad():
                logits_q = self.net(x_qry[i])
                loss_q = F.cross_entropy(logits_q, y_qry[i])
                losses_q[0] += loss_q
                pred_q = F.softmax(logits_q, dim = 1).argmax(dim = 1)
                correct = torch.eq(pred_q, y_qry[i]).sum().item()
                corrects[0] += correct
                
            for k in range(self.update_step):
                logits = self.net(x_spt[i])
                loss = F.cross_entropy(logits, y_spt[i])
                grad = torch.autograd.grad(loss, list(self.net.parameters()), create_graph=True)
                fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, list(self.net.parameters()))))
                # need to somehow update weights of nn with new parameters
                
                # set new net values
                for l, param in enumerate(self.net.parameters()):
                    param.data = nn.parameter.Parameter(torch.clone(fast_weights[l]))
                # apply to query
                logits_q = self.net(x_qry[i])
                loss_q = F.cross_entropy(logits_q, y_qry[i])
                losses_q[k + 1] += loss_q

                with torch.no_grad():
                    pred_q = F.softmax(logits_q, dim = 1).argmax(dim = 1)
                    correct = torch.eq(pred_q, y_qry[i]).sum().item()
                    corrects[k + 1] += correct
                    
        loss_q = losses_q[-1] / task_num
        self.meta_optim.zero_grad()
        loss_q.backward()
        self.meta_optim.step()
        accs = np.array(corrects) / (querysz * task_num)
        return accs

    def finetuning(self, x_spt, y_spt, x_qry, y_qry):
        y_spt = y_spt.to(torch.int64)
        y_qry = y_qry.to(torch.int64)
        querysz = x_qry.size(0)
        corrects = [0 for _ in range(self.update_step_test + 1)]
        net = deepcopy(self.net)
        logits = net(x_spt)
        loss = F.cross_entropy(logits, y_spt)
        grad = torch.autograd.grad(loss, net.parameters())
        with torch.no_grad():
            logits_q = net(x_qry)
            pred_q = F.softmax(logits_q, dim = 1).argmax(dim = 1)
            correct = torch.eq(pred_q, y_qry).sum().item()
            corrects[0] += correct
        for k in range(self.update_step_test):
            # problem: grad is somehow None
            fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, list(net.parameters()))))
            for l, param in enumerate(net.parameters()):
                param.data = nn.parameter.Parameter(torch.clone(fast_weights[l]))
            logits = net(x_spt)
            loss = F.cross_entropy(logits, y_spt)
            grad = torch.autograd.grad(loss, net.parameters())
            logits_q = net(x_qry)
            loss_q = F.cross_entropy(logits_q, y_qry)

            with torch.no_grad():
                pred_q = F.softmax(logits_q, dim = 1).argmax(dim = 1)
                correct = torch.eq(pred_q, y_qry).sum().item()
                corrects[k + 1] += correct
        del net
        return np.array(corrects)/querysz

In [61]:
device = torch.device('cpu')
maml = Meta(update_lr = 0.4, meta_lr = 0.001, n_way = 5, k_spt = 1, k_qry = 15, task_num = 32, update_step = 5, update_step_test = 10).to(device)
db_train = OmniglotLoader(batch_size = 32, n_way = 5, k_spt = 1, k_qry = 15, downsampled_size = 28)

Files already downloaded and verified
Files already downloaded and verified


In [62]:
num_epochs = 10000

for step in range(num_epochs):
    x_spt, y_spt, x_qry, y_qry = db_train.next()
    x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

    accs = maml(x_spt, y_spt, x_qry, y_qry)

    if step % 50 == 0:
        print('step:', step, '\ttraining acc:', accs)

    if step % 500 == 0:
        accs = []
        for _ in range(1000//32):
            # test
            x_spt, y_spt, x_qry, y_qry = db_train.next('test')
            x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                        torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)
            # split to single task each time
            for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(x_spt, y_spt, x_qry, y_qry):
                test_acc = maml.finetuning(x_spt_one, y_spt_one, x_qry_one, y_qry_one)
                accs.append(test_acc)
        accs = np.array(accs).mean(axis = 0).astype(np.float16)
        print('Test acc:', accs)


conv layer params: [Parameter containing:
tensor([[[[-0.0465,  0.3086,  0.8370,  0.4495],
          [-0.8802,  0.3038, -0.2839,  0.2290],
          [-0.0376,  0.6262,  0.1746, -0.7589],
          [-0.4709, -0.6483,  0.4340, -0.4118]],

         [[-0.5124,  0.4457, -0.4534,  0.5773],
          [-0.6847,  0.0634,  0.7204, -0.0908],
          [-0.4114,  0.0235, -0.4900, -0.7682],
          [-0.3153, -0.8926, -0.1881,  0.2615]],

         [[ 0.3364, -0.3205, -0.6565, -0.5942],
          [ 0.3783,  0.9020, -0.2022, -0.0490],
          [-0.7646,  0.1928, -0.6015,  0.1277],
          [ 0.3988, -0.2156, -0.4078,  0.7926]]],


        [[[-0.5796,  0.4690, -0.6659,  0.0262],
          [-0.4744,  0.2137,  0.5373, -0.6638],
          [ 0.2590, -0.3744, -0.5176, -0.7245],
          [-0.6099, -0.7708, -0.0048,  0.1838]],

         [[-0.5179,  0.0366,  0.7253, -0.4520],
          [-0.3762, -0.0895,  0.2829,  0.8778],
          [-0.7266, -0.3053, -0.5970, -0.1501],
          [-0.2497,  0.9473, -0.1937

KeyboardInterrupt: 