In [None]:
import torch
import torch.nn as nn
import numpy as np
import os
import glob
import torch.utils.data as data
import torchvision.transforms as transforms
import PIL.Image as PILI
import matplotlib.pyplot as plt

In [None]:
print(f"PyTorch version: {torch.__version__}")

# MPS = Metal Performance Shaders, Apple's GPU backend. Report whether this build has it
# compiled in and whether the current machine can actually use it.
print(f"Is MPS (Metal Performance Shader) built? {torch.backends.mps.is_built()}")
mps_available = torch.backends.mps.is_available()
print(f"Is MPS available? {mps_available}")

# Prefer the Apple GPU when usable, otherwise run on the CPU.
device = torch.device("mps" if mps_available else "cpu")
print(f"Using device: {device}")

In [None]:
class Learner(nn.Module):
  """Four-block convolutional classifier whose weights are written by the meta-learner.

  Each block is Conv2d(3x3, 32 channels, padding 1) -> BatchNorm2d -> ReLU -> MaxPool2d(2).
  The flattened output of the last block feeds a linear layer producing ``n_classes`` logits.
  Four 2x poolings shrink the spatial size to ``image_size // 16``.

  Args:
    image_size: side length of the square input images; inputs are assumed to be
      (batch, 3, image_size, image_size) -- TODO confirm for non-square data.
    bn_eps: BatchNorm eps.
    bn_momentum: BatchNorm running-statistics momentum (PyTorch convention: weight given to
      the newest batch statistic).
    n_classes: number of output logits (the N of N-way classification).
  """
  def __init__(self, image_size, bn_eps=1e-3, bn_momentum=0.95, n_classes=5):
    super(Learner, self).__init__()

    # NOTE: registration order defines parameters() order, which get_params/copy_params and
    # the meta-learner's flat parameter vector depend on -- do not reorder these layers.
    self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
    self.norm1 = nn.BatchNorm2d(32, bn_eps, bn_momentum)
    self.relu1 = nn.ReLU(inplace=False)
    self.pool1 = nn.MaxPool2d(2)

    self.conv2 = nn.Conv2d(32, 32, 3, padding=1)
    self.norm2 = nn.BatchNorm2d(32, bn_eps, bn_momentum)
    self.relu2 = nn.ReLU(inplace=False)
    self.pool2 = nn.MaxPool2d(2)

    self.conv3 = nn.Conv2d(32, 32, 3, padding=1)
    self.norm3 = nn.BatchNorm2d(32, bn_eps, bn_momentum)
    self.relu3 = nn.ReLU(inplace=False)
    self.pool3 = nn.MaxPool2d(2)

    self.conv4 = nn.Conv2d(32, 32, 3, padding=1)
    self.norm4 = nn.BatchNorm2d(32, bn_eps, bn_momentum)
    self.relu4 = nn.ReLU(inplace=False)
    self.pool4 = nn.MaxPool2d(2)

    # Spatial side after four floor-halvings equals image_size // 16.
    clr_in = image_size // 2**4
    self.linear = nn.Linear(32 * clr_in * clr_in, n_classes)

    self.criterion = nn.CrossEntropyLoss()

  def forward(self, x):
    """Return class logits of shape (batch, n_classes) for an image batch ``x``."""
    dx = self.pool1(self.relu1(self.norm1(self.conv1(x))))
    dx = self.pool2(self.relu2(self.norm2(self.conv2(dx))))
    dx = self.pool3(self.relu3(self.norm3(self.conv3(dx))))
    dx = self.pool4(self.relu4(self.norm4(self.conv4(dx))))

    dx = torch.reshape(dx, [dx.size(0), -1])
    return self.linear(dx)

  def get_params(self):
    """Return every parameter flattened and concatenated, in parameters() order, as one 1-D
    tensor (still connected to the autograd graph of the parameters)."""
    return torch.cat([p.view(-1) for p in self.parameters()], 0)

  def copy_params(self, cI):
    """Overwrite all parameters in place from a flat vector laid out like get_params().

    Writes through ``.data``, so no autograd history links ``cI`` to the parameters.
    """
    idx = 0
    for p in self.parameters():
      plen = p.view(-1).size(0)
      p.data.copy_(cI[idx: idx+plen].view_as(p))
      idx += plen

  def reset_batch_stats(self):
    """Reset the running statistics of every BatchNorm2d layer."""
    for m in self.modules():
      if isinstance(m, nn.BatchNorm2d):
        m.reset_running_stats()

In [None]:
class ModifiedLSTMCell(nn.Module):
  """Meta-LSTM cell whose cell state *is* the flat learner-parameter vector.

  Every learner parameter is one row of the batch. Per row:

    f_t = [x_t, c_{t-1}, f_{t-1}] @ forgetWeight + forgetBias
    i_t = [x_t, c_{t-1}, i_{t-1}] @ inputWeight  + inputBias
    c_t = sigmoid(f_t) * c_{t-1} - sigmoid(i_t) * grad_t

  so ``c_t`` is a gradient-descent-like step with a learned per-parameter step size
  (``sigmoid(i_t)``) and decay (``sigmoid(f_t)``). Because ``c_t`` holds one scalar per
  parameter, the gate width (``hidden_size``) is fixed at 1.

  Args:
    n_learner_parameters: length of the flat learner-parameter vector.
    input_size: feature width of ``x_t``; must equal the hidden size of the LSTM feeding this
      cell (20 for ``nn.LSTMCell(input_size=4, hidden_size=20)``).
  """
  def __init__(self, n_learner_parameters, input_size=20):
    super(ModifiedLSTMCell, self).__init__()
    self.input_size = input_size
    # One scalar cell value per learner parameter, hence 1-wide gates.
    self.hidden_size = 1
    self.n_learner_params = n_learner_parameters

    # Initial cell state c_0; init_cI loads the learner's initial parameters into it.
    self.cellIin = nn.Parameter(torch.Tensor(n_learner_parameters, 1))

    # Gate inputs are [x_t, c_{t-1}, gate_{t-1}]: input_size + 1 + hidden_size columns.
    gate_in = self.input_size + 1 + self.hidden_size
    self.inputWeight = nn.Parameter(torch.Tensor(gate_in, self.hidden_size))
    self.forgetWeight = nn.Parameter(torch.Tensor(gate_in, self.hidden_size))

    self.inputBias = nn.Parameter(torch.Tensor(1, self.hidden_size))
    self.forgetBias = nn.Parameter(torch.Tensor(1, self.hidden_size))

    self.reset_parameters()

  def forward(self, inputs, hx=None):
    """Run one update step.

    Args:
      inputs: ``[x_all, grad]`` with ``x_all`` of shape (n_learner_params, input_size) and
        ``grad`` of shape (n_learner_params, 1).
      hx: ``[f_prev, i_prev, c_prev]`` from the previous step, or None to start from zero
        gates and ``cellIin``.

    Returns:
      ``(c_next, [f_next, i_next, c_next])``; ``c_next`` has shape (n_learner_params, 1).
    """
    x_all, grad = inputs
    batch, _ = x_all.size()

    if hx is None:
      f_prev = self.forgetWeight.new_zeros((batch, self.hidden_size))
      i_prev = self.inputWeight.new_zeros((batch, self.hidden_size))
      hx = [f_prev, i_prev, self.cellIin]

    f_prev, i_prev, c_prev = hx

    f_next = torch.mm(torch.cat((x_all, c_prev, f_prev), 1), self.forgetWeight) + self.forgetBias.expand_as(f_prev)
    i_next = torch.mm(torch.cat((x_all, c_prev, i_prev), 1), self.inputWeight) + self.inputBias.expand_as(i_prev)
    c_next = torch.sigmoid(f_next).mul(c_prev) - torch.sigmoid(i_next).mul(grad)

    return c_next, [f_next, i_next, c_next]

  def reset_parameters(self):
    """Small uniform init for all weights; forget bias in [4, 6] (keep most of c_{t-1}),
    input bias in [-5, -4] (small initial step size)."""
    for weight in self.parameters():
      nn.init.uniform_(weight, -0.01, 0.01)
    nn.init.uniform_(self.forgetBias, 4, 6)
    nn.init.uniform_(self.inputBias, -5, -4)

  def init_cI(self, flat_params):
    """Load a flat (n_learner_params,) parameter vector into the initial cell state."""
    self.cellIin.data.copy_(flat_params.unsqueeze(1))

In [None]:
class MetaLearner(nn.Module):
  """Two-stage meta-learner.

  A standard ``nn.LSTMCell`` reads per-parameter features (preprocessed loss + preprocessed
  gradient). Its hidden output, together with the raw gradient, drives a ModifiedLSTMCell
  whose cell state is the flat learner-parameter vector.

  Args:
    n_learner_params: number of scalar learner parameters (length of the flat vector).
  """
  def __init__(self, n_learner_params):
    super(MetaLearner, self).__init__()
    # 4 input features per parameter: 2 loss features + 2 gradient features.
    self.lstm = nn.LSTMCell(input_size=4, hidden_size=20)
    self.metalstm = ModifiedLSTMCell(n_learner_parameters=n_learner_params)

  def forward(self, inputs, hs=None):
    """Run one meta-learner step.

    Args:
      inputs: ``[loss_prep, grad_prep, grad]``. ``loss_prep`` is broadcast via ``expand_as``
        to every row of ``grad_prep``; ``grad`` is the raw gradient column passed to the
        ModifiedLSTMCell. Presumably (n, 2), (n, 2) and (n, 1) as built by train_learner --
        confirm for other callers.
      hs: ``[lstm_state, metalstm_state]`` from the previous step, or None to start fresh.

    Returns:
      ``(new_flat_params, [lstm_state, metalstm_state])`` where the parameters are squeezed.
    """
    loss_prep, grad_prep, grad = inputs
    lstm_state = None if hs is None else hs[0]
    meta_state = None if hs is None else hs[1]

    features = torch.cat((loss_prep.expand_as(grad_prep), grad_prep), 1)
    hx, cx = self.lstm(features, lstm_state)
    new_params, meta_state = self.metalstm([hx, grad], meta_state)

    return new_params.squeeze(), [(hx, cx), meta_state]


In [None]:
class EpisodeDataset(data.Dataset):
    """One item per class: a random batch of ``n_shot + n_eval`` images of that class.

    Expects ``root/<phase>/<class_dir>/<image files>``. Class indices follow the sorted class
    directory names. Each ``__getitem__`` draws a fresh shuffled batch from that class.
    """
    def __init__(self, root, phase='train', n_shot=5, n_eval=15, transform=None):
        root = os.path.join(root, phase)
        # Only visible sub-directories are classes. This skips stray files such as macOS'
        # .DS_Store without dropping a real class when no such file exists.
        self.labels = self._list_class_dirs(root)
        # Sorted so the index -> image mapping does not depend on filesystem order.
        images = [sorted(glob.glob(os.path.join(root, label, '*'))) for label in self.labels]
        self.episode_loader = [data.DataLoader(
            ClassDataset(images=images[idx], label=idx, transform=transform),
            batch_size=n_shot+n_eval, shuffle=True, num_workers=0) for idx in range(len(self.labels))]

    @staticmethod
    def _list_class_dirs(root):
        """Return the sorted names of the non-hidden sub-directories of ``root``."""
        return sorted(
            name for name in os.listdir(root)
            if not name.startswith('.') and os.path.isdir(os.path.join(root, name)))

    def __getitem__(self, idx):
        # A new iterator each call -> a fresh random (shuffled) batch for class `idx`.
        return next(iter(self.episode_loader[idx]))

    def __len__(self):
        return len(self.labels)


class ClassDataset(data.Dataset):
    """All images of a single class, each returned paired with the same integer label.

    Args:
        images: sequence of image file paths.
        label: integer label attached to every item.
        transform: optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, images, label, transform=None):
        self.images = images
        self.label = label
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = PILI.open(self.images[idx]).convert('RGB')
        sample = img if self.transform is None else self.transform(img)
        return sample, self.label


class EpisodicSampler(data.Sampler):
    """Batch sampler yielding ``n_episode`` random sets of class indices.

    Each yielded batch holds ``n_class`` distinct indices drawn from ``range(total_classes)``.
    """
    def __init__(self, total_classes, n_class, n_episode):
        self.total_classes = total_classes
        self.n_class = n_class
        self.n_episode = n_episode

    def __len__(self):
        return self.n_episode

    def __iter__(self):
        for _ in range(self.n_episode):
            shuffled = torch.randperm(self.total_classes)
            yield shuffled[:self.n_class]

In [None]:
def prepare_data(data_root, n_shot, n_eval, n_class, episode, episode_val, image_size=84):
    """Build episodic train/val/test loaders over ``data_root/{train,val,test}/<class>/``.

    Each loader yields one episode per iteration: ``n_class`` random classes, each with
    ``n_shot + n_eval`` images (presumably collated to (n_class, n_shot + n_eval, C, H, W) --
    the training loop slices it that way).

    Args:
        data_root: dataset root containing ``train``, ``val`` and ``test`` sub-directories.
        n_shot: support images per class.
        n_eval: query images per class.
        n_class: classes per episode (N-way).
        episode: number of training episodes per pass over ``train_loader``.
        episode_val: number of episodes per pass over the val/test loaders.
        image_size: output side length of every image (default 84, miniImageNet's usual size).

    Returns:
        ``(train_loader, val_loader, test_loader)``.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(image_size), transforms.ToTensor(), normalize])
    # Resize + center-crop so val/test images always come out image_size x image_size,
    # even if the source images have a different resolution.
    eval_transform = transforms.Compose([
        transforms.Resize(image_size), transforms.CenterCrop(image_size),
        transforms.ToTensor(), normalize])

    def make_loader(phase, transform, n_episode):
        # One EpisodeDataset per split, sampled by class via EpisodicSampler.
        dataset = EpisodeDataset(data_root, phase, n_shot, n_eval, transform=transform)
        sampler = EpisodicSampler(len(dataset), n_class, n_episode)
        return data.DataLoader(dataset, num_workers=0, batch_sampler=sampler)

    train_loader = make_loader('train', train_transform, episode)
    val_loader = make_loader('val', eval_transform, episode_val)
    test_loader = make_loader('test', eval_transform, episode_val)

    return train_loader, val_loader, test_loader

In [None]:
# 5-way episodes: 5 support + 15 query images per class; 1000 training episodes,
# 10 episodes per validation/test pass.
train_loader, val_loader, test_loader = prepare_data(
    "./data/miniImagenet/", n_shot=5, n_eval=15, n_class=5, episode=1000, episode_val=10)

In [None]:
# ---- Hyper-parameters ------------------------------------------------------------
image_size = 84     # side length of the square input images passed to Learner
bn_eps = 1e-3       # BatchNorm eps; Learner(image_size) is built without it -- keep in sync
bn_momentum = 0.95  # BatchNorm momentum (PyTorch convention); same caveat as bn_eps
n_class = 5         # N-way: classes per episode (must match the prepare_data call)
n_eval = 15         # query images per class (must match the prepare_data call)
n_shot = 5          # support images per class, K-shot (must match the prepare_data call)
grad_clip = 0.25    # max gradient norm for each meta-learner update
lr = 1e-3           # Adam learning rate for the meta-learner
epoch = 8           # passes over the support set per episode in train_learner
batch_size = 25     # support-set mini-batch size in train_learner
val_freq = 100      # meta-validate every `val_freq` training episodes

# Seed the RNGs used from here on (weight init, torch.randperm episode sampling,
# torch-based augmentation) so that a Restart & Run All is reproducible.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
def preprocess_grad_loss(x):
    """Log/sign preprocessing of gradients or losses (Andrychowicz et al., 2016).

    With p = 10, each element x is mapped to the pair

        (log|x| / p, sign(x))   if |x| >= e^-p   (about 4.5e-5)
        (-1,         e^p * x)   otherwise

    which keeps values of very different magnitudes in a range an LSTM can consume.

    Args:
        x: tensor of values; callers pass 1-D tensors of shape (n,).

    Returns:
        ``torch.stack([first, second], dim=1)`` -- shape (n, 2) for a 1-D input.
    """
    p = 10
    eps = 1e-8
    abs_x = x.abs()
    large = abs_x >= np.exp(-p)
    # torch.where selects one branch per element. The previous `mask*a + (1-mask)*b` form
    # turned an overflowing e^p * x (|x| > ~1.5e34 in float32) into 0 * inf = NaN.
    x_1 = torch.where(large, torch.log(abs_x + eps) / p, torch.full_like(x, -1.0))
    x_2 = torch.where(large, torch.sign(x), x * np.exp(p))
    return torch.stack([x_1, x_2], dim=1)

In [None]:
def accuracy(output, target, topk=(1,)):
    """Top-k accuracy, in percent, of logits ``output`` (batch, classes) against ``target``.

    Returns a float when ``topk`` has a single entry, otherwise a list of floats, one per k.
    """
    with torch.no_grad():
        n_samples = target.size(0)
        _, top_idx = output.topk(max(topk), dim=1)
        # hits[j, i] is True when the j-th ranked prediction for sample i is correct.
        hits = top_idx.t().eq(target.view(1, -1))
        scores = [(hits[:k].reshape(-1).float().sum() * (100.0 / n_samples)).item() for k in topk]
    if len(scores) == 1:
        return scores[0]
    return scores

In [None]:
# Per-episode training metrics and per-validation-run metrics, appended as training runs.
training_loss, val_loss = [], []
training_acc, val_acc = [], []

In [None]:
def meta_test(eval_loader, learner, metalearner):
    """Evaluate the meta-learner over every episode of ``eval_loader``.

    For each episode, the learner is fine-tuned on the support set by ``train_learner``. It is
    then loaded with the final parameters that call returns and scored on the query set. The
    episode-averaged loss and accuracy are appended to the global ``val_loss`` / ``val_acc``.

    Returns:
        Mean query accuracy (percent) over all episodes.

    Raises:
        ValueError: if ``eval_loader`` yields no episodes.
    """
    episode_losses, episode_accs = [], []
    for episode_x, _ in eval_loader:
        train_input = episode_x[:, :n_shot].flatten(0, 1).to(device)
        train_target = torch.arange(n_class).repeat_interleave(n_shot).to(device)
        test_input = episode_x[:, n_shot:].flatten(0, 1).to(device)
        test_target = torch.arange(n_class).repeat_interleave(n_eval).to(device)

        learner.reset_batch_stats()
        learner.train()
        cI = train_learner(learner, metalearner, train_input, train_target)
        # As written, train_learner leaves the learner holding the parameters from *before*
        # its last meta-update; load the final ones before scoring the query set.
        learner.copy_params(cI)

        with torch.no_grad():
            output = learner(test_input)
            loss = learner.criterion(output, test_target)
        episode_losses.append(loss.item())
        episode_accs.append(accuracy(output, test_target))

    if not episode_losses:
        raise ValueError("eval_loader yielded no episodes")

    mean_loss = sum(episode_losses) / len(episode_losses)
    mean_acc = sum(episode_accs) / len(episode_accs)
    val_loss.append(mean_loss)
    val_acc.append(mean_acc)
    return mean_acc

In [None]:
def train_learner(learner, metalearner, train_input, train_target):
  """Inner loop: let the meta-learner update the learner on one episode's support set.

  The loop makes ``epoch`` passes over ``train_input`` in chunks of ``batch_size``. For each chunk:
  - the current flat parameters are copied into ``learner``;
  - its loss and gradients are computed;
  - the meta-learner maps (preprocessed loss, preprocessed gradient, raw gradient) to the next
    flat parameter vector.

  Returns:
    The flat parameter vector from the final meta-learner step (``cellIin.data`` if no step ran).
    ``learner`` itself is left holding the parameters copied in at the start of that last
    step, not the returned ones.
  """
  flat_params = metalearner.metalstm.cellIin.data
  meta_state = None  # no previous meta-learner hidden state on the first step
  for _ in range(epoch):
    for start in range(0, len(train_input), batch_size):
      stop = start + batch_size
      inputs, targets = train_input[start:stop], train_target[start:stop]

      learner.copy_params(flat_params)
      loss = learner.criterion(learner(inputs), targets)
      learner.zero_grad()
      loss.backward()
      # Gradients are additionally divided by the global batch_size before preprocessing.
      flat_grad = torch.cat([p.grad.data.view(-1) / batch_size for p in learner.parameters()], 0)

      grad_features = preprocess_grad_loss(flat_grad)
      loss_features = preprocess_grad_loss(loss.data.unsqueeze(0))
      flat_params, meta_state = metalearner(
          [loss_features, grad_features, flat_grad.unsqueeze(1)], meta_state)

  return flat_params

In [None]:
# Base learner, plus the meta-learner that will generate its parameters.
learner = Learner(image_size).to(device)
n_learner_params = learner.get_params().size(0)
metalearner = MetaLearner(n_learner_params).to(device)
# Start the meta-LSTM's cell state at the learner's freshly initialised weights.
metalearner.metalstm.init_cI(learner.get_params())

optim = torch.optim.Adam(metalearner.parameters(), lr)

In [None]:
# Meta-training. The query loss must be computed with the returned parameters cI *as graph
# tensors*: Learner.copy_params writes through `.data` (detaching them), so a plain
# `learner(test_input)` gives the meta-learner no gradient and optim.step() never changes it.
try:
  from torch.func import functional_call  # PyTorch >= 2.0
except ImportError:
  from torch.nn.utils.stateless import functional_call  # PyTorch 1.12 / 1.13

# (name, shape, numel) of each learner parameter, in the same order as get_params/copy_params.
learner_param_layout = [(name, p.shape, p.numel()) for name, p in learner.named_parameters()]

best_acc = 0.0
for episode_idx, (episode_x, _) in enumerate(train_loader):
  train_input = episode_x[:, :n_shot].flatten(0, 1).to(device)
  train_target = torch.arange(n_class).repeat_interleave(n_shot).to(device)
  test_input = episode_x[:, n_shot:].flatten(0, 1).to(device)
  test_target = torch.arange(n_class).repeat_interleave(n_eval).to(device)

  learner.reset_batch_stats()
  learner.train()
  cI = train_learner(learner, metalearner, train_input, train_target)

  # Score the query set with the final cI, keeping the autograd path back to the meta-learner.
  flat_chunks = torch.split(cI, [numel for _, _, numel in learner_param_layout])
  query_params = {name: chunk.view(shape)
                  for (name, shape, _), chunk in zip(learner_param_layout, flat_chunks)}
  output = functional_call(learner, query_params, (test_input,))
  loss = learner.criterion(output, test_target)
  acc = accuracy(output, test_target)
  training_loss.append(loss.item())
  training_acc.append(acc)

  optim.zero_grad()
  loss.backward()
  nn.utils.clip_grad_norm_(metalearner.parameters(), grad_clip)
  optim.step()

  # Meta-validate every `val_freq` episodes (skipping episode 0).
  if episode_idx % val_freq == 0 and episode_idx != 0:
    acc = meta_test(val_loader, learner, metalearner)
    best_acc = max(best_acc, acc)
    print(episode_idx, "val accuracy", acc)

In [None]:
# Training loss is recorded every episode. Validation loss is recorded once every `val_freq`
# episodes, first at episode `val_freq` (never at episode 0), so offset its x-axis accordingly.
val_episodes = [val_freq * (i + 1) for i in range(len(val_loss))]

fig, ax = plt.subplots()
ax.plot(range(len(training_loss)), training_loss, label="train")
ax.plot(val_episodes, val_loss, label="validation")
ax.set(xlabel="episode", ylabel="loss", title="Meta-training loss")
ax.legend()
plt.show()

In [None]:
# Validation accuracy is recorded once every `val_freq` episodes, first at episode `val_freq`.
val_episodes = [val_freq * (i + 1) for i in range(len(val_acc))]

fig, ax = plt.subplots()
ax.plot(range(len(training_acc)), training_acc, label="train")
ax.plot(val_episodes, val_acc, label="validation")
ax.set(xlabel="episode", ylabel="accuracy (%)", title="Meta-training accuracy")
ax.legend()
plt.show()