In [1]:
import torch
import torchvision
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torchonn as onn
from torchonn.models import ONNBaseModel
import torch.optim as optim
from torchonn.op.mzi_op import project_matrix_to_unitary
import torchvision.transforms as transforms
import scipy.stats as stats
from copy import deepcopy



In [2]:
class OmniglotLoader():
    def __init__(self, batch_size, n_way, k_spt, k_qry, downsampled_size):
        """
        Args:
        batchsize: imgs per batch
        n_way: number of classes in support set
        k_spt: number of images per class in support set
        k_qry: number of images per class in query set
        """
        
        transform = transforms.Compose([
            lambda x: x.resize((downsampled_size, downsampled_size)),
            lambda x: np.reshape(x, (downsampled_size, downsampled_size, 1)),
            lambda x: (x - 128)/128
        ])
        
        self.x_train = torchvision.datasets.Omniglot(root='./data',
                                               download=True, transform=transform)
        self.x_val = torchvision.datasets.Omniglot(root='./data', background = False,
                                            download=True, transform=transform)
        temp_train = {}
        for (img, label) in self.x_train:
            if label in temp_train.keys():
                temp_train[label].append(img)
            else:
                temp_train[label] = [img]

        temp_val = {}
        for (img, label) in self.x_val:
            if label in temp_val.keys():
                temp_val[label].append(img)
            else:
                temp_val[label] = [img]

        self.x = []
        for label, imgs in temp_train.items():
            self.x.append(np.array(imgs))
        for label, imgs in temp_val.items():
            self.x.append(np.array(imgs))

        self.x = np.array(self.x).astype(np.float32)

        temp = []
        
        self.x_train = self.x[:964]
        self.x_test = self.x[964:]
        self.num_classes = self.x.shape[0]
        self.batch_size = batch_size
        self.downsampled_size = downsampled_size
        self.n_way = n_way
        self.k_spt = k_spt
        self.k_qry = k_qry
        self.indexes = {"train": 0, "test": 0}
        self.datasets = {"train": self.x_train, "test": self.x_test}

        self.datasets_cache = {"train": self.load_data_cache(self.datasets["train"]),
                               "test": self.load_data_cache(self.datasets["test"])}

    def load_data_cache(self, data_pack):

        num_episodes = 10
        
        set_size = self.k_spt * self.n_way
        query_size = self.k_qry * self.n_way
        data_cache = []

        for sample in range(num_episodes):
            x_supports, y_supports, x_queries, y_queries = [], [], [], []
            for _ in range(self.batch_size):
                x_support, y_support, x_query, y_query = [], [], [], []
                selected_class = np.random.choice(data_pack.shape[0], self.n_way, False)
                for j, cur_class in enumerate(selected_class):
                    selected_image = np.random.choice(20, self.k_spt + self.k_qry, False)
                    # meta-training and meta-test
                    x_support.append(data_pack[cur_class][selected_image[:self.k_spt]])
                    x_query.append(data_pack[cur_class][selected_image[self.k_spt:]])
                    y_support.append([j for _ in range(self.k_spt)])
                    y_query.append([j for _ in range(self.k_qry)])
                    
                support_perm = np.random.permutation(self.n_way * self.k_spt)
                x_support = np.array(x_support).reshape(self.n_way * self.k_spt, 1, self.downsampled_size, self.downsampled_size)[support_perm]
                y_support = np.array(y_support).reshape(self.n_way * self.k_spt)[support_perm]
                query_perm = np.random.permutation(self.n_way * self.k_qry)
                x_query = np.array(x_query).reshape(self.n_way * self.k_qry, 1, self.downsampled_size, self.downsampled_size)[query_perm]
                y_query = np.array(y_query).reshape(self.n_way * self.k_qry)[query_perm]

                x_supports.append(x_support)
                y_supports.append(y_support)
                x_queries.append(x_query)
                y_queries.append(y_query)
                
            x_supports = np.array(x_supports).astype(np.float32).reshape(self.batch_size, set_size, 1, self.downsampled_size, self.downsampled_size)
            y_supports = np.array(y_supports).astype(np.int32).reshape(self.batch_size, set_size)

            x_queries = np.array(x_queries).astype(np.float32).reshape(self.batch_size, query_size, 1, self.downsampled_size, self.downsampled_size)
            y_queries = np.array(y_queries).astype(np.int32).reshape(self.batch_size, query_size)

            data_cache.append([x_supports, y_supports, x_queries, y_queries])
            
        return data_cache

    def next(self, mode = "train"):
        # update cache if needed
        if self.indexes[mode] >= len(self.datasets_cache[mode]):
            self.indexes[mode] = 0
            self.datasets_cache[mode] = self.load_data_cache(self.datasets[mode])
        next_batch = self.datasets_cache[mode][self.indexes[mode]]
        self.indexes[mode] += 1
        return next_batch

In [90]:
class ONNModel(ONNBaseModel):
    def __init__(self, device=torch.device("cpu")):
        super().__init__()
        self.conv0 = onn.layers.MZIBlockConv2d(
            in_channels=1,
            out_channels=64,
            kernel_size=3,
            stride=2,
            padding=0,
            dilation=1,
            bias=False,
            miniblock=1,
            mode="usv",
            decompose_alg="clements",
            photodetect=True,
            device=device,
        )
        self.conv1 = onn.layers.MZIBlockConv2d(
            in_channels=64,
            out_channels=64,
            kernel_size=3,
            stride=2,
            padding=0,
            dilation=1,
            bias=False,
            miniblock=1,
            mode="usv",
            decompose_alg="clements",
            photodetect=True,
            device=device,
        )
        self.conv2 = onn.layers.MZIBlockConv2d(
            in_channels=64,
            out_channels=64, 
            kernel_size=3,
            stride=2,
            padding=0,
            dilation=1,
            bias=False,
            miniblock=1,
            mode="usv",
            decompose_alg="clements",
            photodetect=True,
            device=device,
        )
        self.conv3 = onn.layers.MZIBlockConv2d(
            in_channels=64,
            out_channels=64,
            kernel_size=2,
            stride=1,
            padding=0,
            dilation=1,
            bias=False,
            miniblock=1,
            mode="usv",
            decompose_alg="clements",
            photodetect=True,
            device=device,
        )
        self.linear = onn.layers.MZIBlockLinear(
            in_features=64,
            out_features=5, # because we're doing 5-way
            bias=False,
            miniblock=4,
            mode="usv",
            decompose_alg="clements",
            photodetect=True,
            device=device,
        )
        self.reset_all_parameters()

        
    def unitary_projection(self):
        for m in self.modules():
            if isinstance(m, onn.layers.MZIBlockLinear) or isinstance(m, onn.layers.MZIBlockConv2d):
                m.U.data.copy_(project_matrix_to_unitary(m.U.data))
                m.V.data.copy_(project_matrix_to_unitary(m.V.data))

    # debugging the code
    def reset_all_parameters(self):
        for m in self.modules():
            if isinstance(m, onn.layers.MZIBlockLinear):
                m.reset_parameters()
            if isinstance(m, onn.layers.MZIBlockConv2d):
                W = torch.nn.init.kaiming_normal_(
                    torch.empty(
                        m.grid_dim_y,
                        m.grid_dim_x,
                        m.miniblock,
                        m.miniblock,
                        dtype=m.U.dtype,
                        device=m.device,
                    )
                )
                print(torch.svd(W, some = False))
                U, S, V = torch.svd(W, some=False)
                print(f"W is: {W}")
                V = V.transpose(-2, -1)
                m.U.data.copy_(U)
                m.V.data.copy_(V)
                m.S.data.copy_(torch.ones_like(S, device=m.device))

    def forward(self, x):
        x = torch.relu(self.conv0(x))
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        x = torch.flatten(x, 1)
        x = self.linear(x)
        return x

In [78]:
max_norm = 5
torch.autograd.set_detect_anomaly(True)
# used to get rid of exploding gradients
def normalize_grads(grad):
    total_norm = torch.norm(torch.stack([torch.norm(g.detach()).to(device) for g in grad]))
    clip_coef_clamped = torch.clamp(max_norm/(total_norm + 1e-6), max = 1.0)
    for g in grad:
        g.detach().mul_(clip_coef_clamped.to(g.device))

class Meta(nn.Module):
    # note that the maml-pytorch library uses an argparser to handle most of this instead of passing in individual parameters
    def __init__(self, update_lr, meta_lr, n_way, k_spt, k_qry, task_num, update_step, update_step_test):
        self.update_lr = update_lr
        self.meta_lr = meta_lr
        self.n_way = n_way
        self.k_spt = k_spt
        self.k_qry = k_qry
        self.task_num = task_num
        self.update_step = update_step
        self.update_step_test = update_step_test
        
        super(Meta, self).__init__()
        self.net = ONNModel()
        self.net.train()
        self.meta_optim = optim.Adam(list(self.net.parameters()), lr=self.meta_lr)

    def forward(self, x_spt, y_spt, x_qry, y_qry):
        self.net.unitary_projection()
        y_spt = y_spt.to(torch.int64)
        y_qry = y_qry.to(torch.int64)
        task_num, setsz, c_, h, w = x_spt.size()
        querysz = x_qry.size(1)

        losses_q = [0 for _ in range(self.update_step + 1)]  # losses_q[i] is the loss on step i
        corrects = [0 for _ in range(self.update_step + 1)]
        for i in range(task_num):
            # might need to clip grad norms here later

            # calculate loss before any update
            with torch.no_grad():
                logits_q = self.net(x_qry[i])
                loss_q = F.cross_entropy(logits_q, y_qry[i])
                losses_q[0] += loss_q
                pred_q = F.softmax(logits_q, dim = 1).argmax(dim = 1)
                correct = torch.eq(pred_q, y_qry[i]).sum().item()
                corrects[0] += correct
                
            for k in range(self.update_step):
                logits = self.net(x_spt[i])
                loss = F.cross_entropy(logits, y_spt[i])
                print(f"logits are: {logits}")
                print(f"loss is: {loss}")
                grad = torch.autograd.grad(loss, list(self.net.parameters()), create_graph=True)
                # avoid very large gradients killing everything
                print(f"grads are: {list(grad)}")
                normalize_grads(grad)
                fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, list(self.net.parameters()))))
                # set new net values
                for l, param in enumerate(self.net.parameters()):
                    param.data = nn.parameter.Parameter(torch.clone(fast_weights[l]))
                
                self.net.unitary_projection()
                # apply to query
                logits_q = self.net(x_qry[i])
                loss_q = F.cross_entropy(logits_q, y_qry[i])
                losses_q[k + 1] += loss_q

                with torch.no_grad():
                    pred_q = F.softmax(logits_q, dim = 1).argmax(dim = 1)
                    correct = torch.eq(pred_q, y_qry[i]).sum().item()
                    corrects[k + 1] += correct
                    
        loss_q = losses_q[-1] / task_num
        self.meta_optim.zero_grad()
        loss_q.retain_grad()
        loss_q.backward()
        self.meta_optim.step()
        # is this needed?
        self.net.unitary_projection()
        accs = np.array(corrects) / (querysz * task_num)
        return accs

    def finetuning(self, x_spt, y_spt, x_qry, y_qry):
        y_spt = y_spt.to(torch.int64)
        y_qry = y_qry.to(torch.int64)
        querysz = x_qry.size(0)
        corrects = [0 for _ in range(self.update_step_test + 1)]
        net = deepcopy(self.net)
        logits = net(x_spt)
        loss = F.cross_entropy(logits, y_spt)
        grad = torch.autograd.grad(loss, net.parameters())
        # avoid very large gradients killing everything
        normalize_grads(grad)
        with torch.no_grad():
            logits_q = net(x_qry)
            pred_q = F.softmax(logits_q, dim = 1).argmax(dim = 1)
            correct = torch.eq(pred_q, y_qry).sum().item()
            corrects[0] += correct
        for k in range(self.update_step_test):
            fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, list(net.parameters()))))
            for l, param in enumerate(net.parameters()):
                param.data = nn.parameter.Parameter(torch.clone(fast_weights[l]))
            net.unitary_projection()
            logits = net(x_spt)
            loss = F.cross_entropy(logits, y_spt)
            grad = torch.autograd.grad(loss, net.parameters())
            normalize_grads(grad)
            logits_q = net(x_qry)
            loss_q = F.cross_entropy(logits_q, y_qry)

            with torch.no_grad():
                pred_q = F.softmax(logits_q, dim = 1).argmax(dim = 1)
                correct = torch.eq(pred_q, y_qry).sum().item()
                corrects[k + 1] += correct
        del net
        return np.array(corrects)/querysz

In [33]:
device = torch.device('cpu')
db_train = OmniglotLoader(batch_size = 32, n_way = 5, k_spt = 1, k_qry = 15, downsampled_size = 28)

Files already downloaded and verified
Files already downloaded and verified


In [91]:
maml = Meta(update_lr = 0.1, meta_lr = 0.001, n_way = 5, k_spt = 1, k_qry = 15, task_num = 32, update_step = 5, update_step_test = 10).to(device)

torch.return_types.svd(
U=tensor([[[[ 1.]],

         [[-1.]],

         [[ 1.]],

         [[ 1.]],

         [[ 1.]],

         [[ 1.]],

         [[ 1.]],

         [[ 1.]],

         [[ 1.]]],


        [[[-1.]],

         [[-1.]],

         [[ 1.]],

         [[ 1.]],

         [[ 1.]],

         [[-1.]],

         [[ 1.]],

         [[ 1.]],

         [[ 1.]]],


        [[[ 1.]],

         [[ 1.]],

         [[-1.]],

         [[ 1.]],

         [[ 1.]],

         [[-1.]],

         [[-1.]],

         [[ 1.]],

         [[-1.]]],


        [[[-1.]],

         [[-1.]],

         [[-1.]],

         [[ 1.]],

         [[ 1.]],

         [[-1.]],

         [[-1.]],

         [[-1.]],

         [[-1.]]],


        [[[ 1.]],

         [[ 1.]],

         [[-1.]],

         [[-1.]],

         [[ 1.]],

         [[-1.]],

         [[ 1.]],

         [[-1.]],

         [[-1.]]],


        [[[ 1.]],

         [[ 1.]],

         [[ 1.]],

         [[-1.]],

         [[-1.]],

         [[ 1.

In [53]:
num_epochs = 10000

for step in range(num_epochs):
    x_spt, y_spt, x_qry, y_qry = db_train.next()
    x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)

    accs = maml(x_spt, y_spt, x_qry, y_qry)

    if step % 10 == 0:
        print('step:', step, '\ttraining acc:', accs)

    if step % 500 == 0 and step > 0:
        accs = []
        for _ in range(1000//32):
            # test
            x_spt, y_spt, x_qry, y_qry = db_train.next('test')
            x_spt, y_spt, x_qry, y_qry = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device), \
                                        torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)
            # split to single task each time
            for x_spt_one, y_spt_one, x_qry_one, y_qry_one in zip(x_spt, y_spt, x_qry, y_qry):
                test_acc = maml.finetuning(x_spt_one, y_spt_one, x_qry_one, y_qry_one)
                accs.append(test_acc)
        accs = np.array(accs).mean(axis = 0).astype(np.float16)
        print('Test acc:', accs)


logits are: tensor([[inf, inf, inf, inf, inf],
        [inf, inf, inf, inf, inf],
        [inf, inf, inf, inf, inf],
        [inf, inf, inf, inf, inf],
        [inf, inf, inf, inf, inf]], grad_fn=<PowBackward0>)
loss is: nan


  File "/opt/homebrew/Cellar/python@3.10/3.10.6_2/Frameworks/Python.framework/Versions/3.10/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/homebrew/Cellar/python@3.10/3.10.6_2/Frameworks/Python.framework/Versions/3.10/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/Users/matthewho/Photonic_computing/photonics_env/lib/python3.10/site-packages/ipykernel_launcher.py", line 17, in <module>
    app.launch_new_instance()
  File "/Users/matthewho/Photonic_computing/photonics_env/lib/python3.10/site-packages/traitlets/config/application.py", line 1043, in launch_instance
    app.start()
  File "/Users/matthewho/Photonic_computing/photonics_env/lib/python3.10/site-packages/ipykernel/kernelapp.py", line 725, in start
    self.io_loop.start()
  File "/Users/matthewho/Photonic_computing/photonics_env/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 195, in start
    self.asyn

RuntimeError: Function 'LogSoftmaxBackward0' returned nan values in its 0th output.