In [145]:
import torch
import torchvision
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torchonn as onn
from torchonn.models import ONNBaseModel
import torch.optim as optim
import torchvision.transforms as transforms
from torchonn.op.mzi_op import project_matrix_to_unitary
import scipy.stats as stats
from copy import deepcopy

In [315]:
torch.manual_seed(3407)

class ONNModel(ONNBaseModel):
    def __init__(self, device=torch.device("cpu")):
        super().__init__()
        self.counter = 0
        self.linear0 = onn.layers.MZIBlockLinear(
            in_features=28*28,
            out_features=100, # because we're doing 5-way
            bias=False,
            miniblock=4,
            mode="usv",
            decompose_alg="clements",
            photodetect=True,
            device=device,
        )
        # self.linear1 = onn.layers.MZIBlockLinear(
        #     in_features=100,
        #     out_features=5, # because we're doing 5-way
        #     bias=False,
        #     miniblock=4,
        #     mode="usv",
        #     decompose_alg="clements",
        #     photodetect=True,
        #     device=device,
        # )
        self.linear0.reset_parameters()
        # self.linear1.reset_parameters()

    def unitary_projection(self) -> None:
        for m in self.modules():
            if isinstance(m, onn.layers.MZIBlockLinear):
                # print(f"U is: {m.U.data}")
                m.U.data.copy_(project_matrix_to_unitary(m.U.data))
                m.V.data.copy_(project_matrix_to_unitary(m.V.data))

    def forward(self, x):
        # if self.counter < 7:
        #     print(self.linear0.U)
        #     self.counter += 1
        x = torch.flatten(x, 1)
        x = torch.relu(self.linear0(x))
        # x = self.linear1(x)
        return x

In [346]:
max_norm = 5

def normalize_grads(grad):
    total_norm = torch.norm(torch.stack([torch.norm(g.detach()).to(device) for g in grad]))
    clip_coef_clamped = torch.clamp(max_norm/(total_norm + 1e-6), max = 1.0)
    for g in grad:
        g.detach().mul_(clip_coef_clamped.to(g.device))

class Meta(nn.Module):
    # note that the maml-pytorch library uses an argparser to handle most of this instead of passing in individual parameters
    def __init__(self, update_lr, meta_lr, n_way, k_spt, k_qry, task_num, update_step, update_step_test):
        self.update_lr = update_lr
        self.meta_lr = meta_lr
        self.n_way = n_way
        self.k_spt = k_spt
        self.k_qry = k_qry
        self.task_num = task_num
        self.update_step = update_step
        self.update_step_test = update_step_test
        
        super(Meta, self).__init__()
        self.net = ONNModel()
        self.net.train()
        self.meta_optim = optim.Adam(list(self.net.parameters()), lr=self.meta_lr)

    def forward(self, x_spt, y_spt, x_qry, y_qry):
        y_spt = y_spt.to(torch.int64)
        y_qry = y_qry.to(torch.int64)
        task_num, setsz, c_, h, w = x_spt.size()
        querysz = x_qry.size(1)

        losses_q = [0 for _ in range(self.update_step + 1)]  # losses_q[i] is the loss on step i
        corrects = [0 for _ in range(self.update_step + 1)]
        for i in range(task_num):
            # might need to clip grad norms here later

            # calculate loss before any update
            with torch.no_grad():
                logits_q = self.net(x_qry[i])
                loss_q = F.cross_entropy(logits_q, y_qry[i])
                losses_q[0] += loss_q
                pred_q = F.softmax(logits_q, dim = 1).argmax(dim = 1)
                correct = torch.eq(pred_q, y_qry[i]).sum().item()
                corrects[0] += correct
                
            for k in range(self.update_step):
                logits = self.net(x_spt[i])
                loss = F.cross_entropy(logits, y_spt[i])
                grad = torch.autograd.grad(loss, list(self.net.parameters()), create_graph=True)
                normalize_grads(grad)
                fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, list(self.net.parameters()))))
                # set new net values
                for l, param in enumerate(self.net.parameters()):
                    param.data = nn.parameter.Parameter(torch.clone(fast_weights[l]))
                # PROJECT TO UNITARY
                # print(self.net.linear0.U)
                self.net.unitary_projection()
                # apply to query
                logits_q = self.net(x_qry[i])
                loss_q = F.cross_entropy(logits_q, y_qry[i])
                losses_q[k + 1] += loss_q

                with torch.no_grad():
                    pred_q = F.softmax(logits_q, dim = 1).argmax(dim = 1)
                    correct = torch.eq(pred_q, y_qry[i]).sum().item()
                    corrects[k + 1] += correct
                    
        loss_q = losses_q[-1] / task_num
        self.meta_optim.zero_grad()
        loss_q.retain_grad()
        loss_q.backward()
        self.meta_optim.step()
        # is this needed?
        self.net.unitary_projection()
        accs = np.array(corrects) / (querysz * task_num)
        return accs

    def finetuning(self, x_spt, y_spt, x_qry, y_qry):
        y_spt = y_spt.to(torch.int64)
        y_qry = y_qry.to(torch.int64)
        querysz = x_qry.size(0)
        corrects = [0 for _ in range(self.update_step_test + 1)]
        net = deepcopy(self.net)
        logits = net(x_spt)
        loss = F.cross_entropy(logits, y_spt)
        grad = torch.autograd.grad(loss, net.parameters())
        normalize_grads(grad)
        with torch.no_grad():
            logits_q = net(x_qry)
            pred_q = F.softmax(logits_q, dim = 1).argmax(dim = 1)
            correct = torch.eq(pred_q, y_qry).sum().item()
            corrects[0] += correct
        for k in range(self.update_step_test):
            fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, list(net.parameters()))))
            for l, param in enumerate(net.parameters()):
                param.data = nn.parameter.Parameter(torch.clone(fast_weights[l]))
            net.unitary_projection()
            logits = net(x_spt)
            loss = F.cross_entropy(logits, y_spt)
            grad = torch.autograd.grad(loss, net.parameters())
            normalize_grads(grad)
            logits_q = net(x_qry)
            loss_q = F.cross_entropy(logits_q, y_qry)

            with torch.no_grad():
                pred_q = F.softmax(logits_q, dim = 1).argmax(dim = 1)
                correct = torch.eq(pred_q, y_qry).sum().item()
                corrects[k + 1] += correct
        del net
        return np.array(corrects)/querysz

In [230]:
device = torch.device('cpu')
db_train = OmniglotLoader(batch_size = 32, n_way = 5, k_spt = 1, k_qry = 15, downsampled_size = 28)

Files already downloaded and verified
Files already downloaded and verified


In [344]:
torch.manual_seed(1)
maml = Meta(update_lr = 0.01, meta_lr = 0.001, n_way = 5, k_spt = 1, k_qry = 15, task_num = 32, update_step = 5, update_step_test = 10).to(device)

In [345]:
num_epochs = 10000

for step in range(num_epochs):
    x_spt, y_spt, x_qry, y_qry = db_train.next()
    x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

    accs = maml(x_spt, y_spt, x_qry, y_qry)

    if step % 5 == 0:
        print('step:', step, '\ttraining acc:', accs)

    if step % 500 == 0 and step > 0:
        accs = []
        for _ in range(1000//32):
            # test
            x_spt, y_spt, x_qry, y_qry = db_train.next('test')
            x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                        torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)
            # split to single task each time
            for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(x_spt, y_spt, x_qry, y_qry):
                test_acc = maml.finetuning(x_spt_one, y_spt_one, x_qry_one, y_qry_one)
                accs.append(test_acc)
        accs = np.array(accs).mean(axis = 0).astype(np.float16)
        print('Test acc:', accs)


step: 0 	training acc: [0.1925     0.1875     0.20458333 0.19375    0.20541667 0.19875   ]


KeyboardInterrupt: 