In [1]:
from torchvision.datasets import Omniglot
from torchvision import transforms, models
from torch.utils.data import Sampler, DataLoader
import matplotlib.pyplot as plt
import random
import numpy as np
import torch
import torchvision
import torch.nn.functional as F

# Seed every RNG the notebook draws from so Restart & Run All is reproducible:
# `random` (task sampling in CustomSampler, per-class shuffling in collate_func),
# NumPy (the permutation in shuffle) and torch (random_split, weight init).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [2]:
# Shared preprocessing for both splits: Resize(28) scales the smaller image edge
# to 28 px, then ToTensor converts to a float tensor in [0, 1].
omniglot_transform = transforms.Compose(
    [
        transforms.Resize(28),
        transforms.ToTensor(),
    ]
)

# background=True loads the "background" image set (images_background.zip).
train = Omniglot(root="./data", download=True, background=True, transform=omniglot_transform)

# Derive the validation size from the remainder so the two lengths always sum to
# len(train); int(0.8*N) + int(0.2*N) can come up one short and random_split raises.
train_size = int(0.8 * len(train))
val_size = len(train) - train_size
# NOTE(review): this splits individual images, so the same characters appear in
# both train and val; a class-level split may be preferable for few-shot validation.
train_data, val_data = torch.utils.data.random_split(train, [train_size, val_size])

# background=False loads the "evaluation" image set (images_evaluation.zip).
test_data = Omniglot(root="./data", download=True, background=False, transform=omniglot_transform)

Downloading https://raw.githubusercontent.com/brendenlake/omniglot/master/python/images_background.zip to ./data/omniglot-py/images_background.zip


  0%|          | 0/9464212 [00:00<?, ?it/s]

Extracting ./data/omniglot-py/images_background.zip to ./data/omniglot-py
Downloading https://raw.githubusercontent.com/brendenlake/omniglot/master/python/images_evaluation.zip to ./data/omniglot-py/images_evaluation.zip


  0%|          | 0/6462886 [00:00<?, ?it/s]

Extracting ./data/omniglot-py/images_evaluation.zip to ./data/omniglot-py


In [3]:
class CustomSampler(Sampler):
  """Batch sampler that yields one N-way few-shot task per batch.

  Each yielded batch is a flat list of dataset indices: ``n`` distinct labels,
  with ``n_s + n_q`` distinct indices per label, grouped label by label.

  Args:
    dataset: indexable dataset whose items have the label at position 1.
    batch_size: number of tasks yielded per pass (also ``len(sampler)``).
    n: number of distinct labels (ways) per task.
    n_s: support examples drawn per label.
    n_q: query examples drawn per label.

  Raises:
    ValueError: if fewer than ``n`` labels have at least ``n_s + n_q`` examples.
  """
  def __init__(self, dataset, batch_size, n, n_s, n_q) -> None:
    self.n = n
    self.n_s = n_s
    self.n_q = n_q
    self.batch_size = batch_size
    # label -> list of dataset indices carrying that label.
    # NOTE: iterating the dataset applies its transform to every item, so this
    # costs one full pass over the data.
    self.d = {}
    for idx, item in enumerate(dataset):
      self.d.setdefault(item[1], []).append(idx)
    # Only labels with enough examples can supply n_s + n_q distinct samples;
    # random.sample would raise on the rest (likely in a small validation split).
    self._eligible = [lbl for lbl, idxs in self.d.items() if len(idxs) >= n_s + n_q]
    if len(self._eligible) < n:
      raise ValueError(
          f"need at least {n} labels with >= {n_s + n_q} examples each, "
          f"found {len(self._eligible)}"
      )

  def __iter__(self):
    per_label = self.n_s + self.n_q
    tasks = []
    for _ in range(self.batch_size):
      labels = random.sample(self._eligible, self.n)
      tasks.append([idx for lbl in labels for idx in random.sample(self.d[lbl], per_label)])
    return iter(tasks)

  def __len__(self):
    return self.batch_size

In [4]:
def shuffle(a,b):
  """Stack tensors and their labels, then apply one shared random permutation.

  Args:
    a: list of equally-shaped tensors.
    b: list of integer labels, same length as ``a``.

  Returns:
    ``(stacked[p], labels[p])``: the stacked tensors and an ``int32`` label
    tensor, both reordered by the same NumPy permutation ``p`` so they stay aligned.
  """
  assert len(a) == len(b)
  stacked = torch.stack(a)
  labels = torch.tensor(b, dtype=torch.int)
  perm = np.random.permutation(len(stacked))
  return stacked[perm], labels[perm]


def collate_func(batch, num_query=None):
  """Split one task's ``(image, label)`` items into support and query sets.

  Labels are re-indexed ``0..N-1`` in first-seen order. Within each label the
  items are shuffled; the first ``num_query`` become query examples and the rest
  support examples.

  Args:
    batch: list of ``(image_tensor, label)`` pairs (one task from CustomSampler).
    num_query: query examples per label; defaults to the notebook-level ``n_q``
      so the function can still be passed directly as ``collate_fn``.

  Returns:
    ``(x_s, y_s, x_q, y_q)``, each set independently permuted by ``shuffle``.
  """
  if num_query is None:
    # Fallback to the config-cell global keeps DataLoader(collate_fn=collate_func) working.
    num_query = n_q

  by_label = {}
  for image, label in batch:
    by_label.setdefault(label, []).append(image)

  x_s, y_s, x_q, y_q = [], [], [], []
  for new_label, images in enumerate(by_label.values()):
    random.shuffle(images)
    query, support = images[:num_query], images[num_query:]
    x_q += query
    y_q += [new_label] * len(query)
    x_s += support
    y_s += [new_label] * len(support)

  x_s, y_s = shuffle(x_s, y_s)
  x_q, y_q = shuffle(x_q, y_q)
  return x_s, y_s, x_q, y_q

In [5]:
# Task shape: n labels per task, n_s support and n_q query examples per label.
n, n_s, n_q = 5, 2, 1

In [6]:
# Each pass over `x` yields 16 tasks, each collated into (x_s, y_s, x_q, y_q).
batches = CustomSampler(train_data, batch_size=16, n=n, n_s=n_s, n_q=n_q)
x = DataLoader(train_data, collate_fn=collate_func, batch_sampler=batches)

In [7]:
# Conv backbone shape: in/hidden channel counts, square kernel size, depth.
NUM_INPUT_CHANNELS, NUM_HIDDEN_CHANNELS = 1, 64
KERNEL_SIZE, NUM_CONV_LAYERS = 3, 4

# Device string used when allocating the meta-parameters.
DEVICE = 'cpu'
if torch.cuda.is_available():
    DEVICE = 'cuda'

In [8]:
def get_params(num_out):
  """Allocate the meta-parameters of the conv backbone plus a linear head.

  For each conv layer ``i`` the dict holds ``conv{i}`` (Xavier-uniform kernel of
  shape (NUM_HIDDEN_CHANNELS, in_channels, KERNEL_SIZE, KERNEL_SIZE)) and ``b{i}``
  (zero bias). The head adds ``w{NUM_CONV_LAYERS}`` of shape
  (num_out, NUM_HIDDEN_CHANNELS) and a zero bias ``b{NUM_CONV_LAYERS}``.
  Every tensor is a leaf with ``requires_grad=True`` on ``DEVICE``.

  Args:
    num_out: output size of the linear head.

  Returns:
    dict mapping parameter names to tensors, in layer order.
  """
  def _leaf(*shape):
    # Fresh trainable tensor; the init function fills it in place.
    return torch.empty(*shape, requires_grad=True, device=DEVICE)

  params = {}
  channels_in = NUM_INPUT_CHANNELS
  for layer in range(NUM_CONV_LAYERS):
    params[f'conv{layer}'] = torch.nn.init.xavier_uniform_(
        _leaf(NUM_HIDDEN_CHANNELS, channels_in, KERNEL_SIZE, KERNEL_SIZE)
    )
    params[f'b{layer}'] = torch.nn.init.zeros_(_leaf(NUM_HIDDEN_CHANNELS))
    channels_in = NUM_HIDDEN_CHANNELS

  head = NUM_CONV_LAYERS
  params[f'w{head}'] = torch.nn.init.xavier_uniform_(_leaf(num_out, NUM_HIDDEN_CHANNELS))
  params[f'b{head}'] = torch.nn.init.zeros_(_leaf(num_out))
  return params

In [37]:
class MAML:
  """Model-Agnostic Meta-Learning over the functional conv net built by `get_params`.

  Args:
    num_out: output size of the linear head (number of labels per task).
    num_inner_steps: gradient steps taken on the support set in `adapt`.
    outer_lr: Adam learning rate for the meta-parameters.
    inner_lr: initial inner-loop step size, one scalar per parameter tensor.
    learn_inner_lrs: if True, the per-parameter inner step sizes are also
      updated by the outer Adam optimizer.
  """

  def __init__(self, num_out, num_inner_steps, outer_lr=0.001, inner_lr=0.4, learn_inner_lrs=False):
    self.params = get_params(num_out)
    self.num_out = num_out
    self.num_inner_steps = num_inner_steps
    self.outer_lr = outer_lr
    # Take the device from the allocated weights instead of hard-coding CUDA,
    # so the model also runs on CPU.
    self.device = next(iter(self.params.values())).device
    self.inner_lrs = {
        k: torch.tensor(inner_lr, requires_grad=learn_inner_lrs, device=self.device)
        for k in self.params
    }
    self.optim = torch.optim.Adam(
        list(self.params.values()) + list(self.inner_lrs.values()), lr=self.outer_lr
    )

  def forward(self, x, parameters):
    """Run the conv net functionally with the given parameter dict.

    Args:
      x: 4-D image batch (N, C, H, W).
      parameters: dict with ``conv{i}``/``b{i}`` per conv layer and
        ``w{NUM_CONV_LAYERS}``/``b{NUM_CONV_LAYERS}`` for the linear head.

    Returns:
      Logits of shape (N, num_out).
    """
    for i in range(NUM_CONV_LAYERS):
      x = F.conv2d(input=x, weight=parameters[f'conv{i}'], bias=parameters[f'b{i}'], stride=1, padding='same')
      # Normalise with the current batch's statistics only (no running stats, no affine).
      x = F.batch_norm(x, None, None, training=True)
      x = F.relu(x)
    # Global average pool over the spatial dimensions.
    x = torch.mean(x, dim=[2, 3])
    return F.linear(input=x, weight=parameters[f'w{NUM_CONV_LAYERS}'], bias=parameters[f'b{NUM_CONV_LAYERS}'])

  def adapt(self, xs, ys, adapt=False):
    """Return task-specific parameters, optionally fine-tuned on the support set.

    Args:
      xs: support images.
      ys: integer support labels (any integer dtype; cast to long here).
      adapt: if False, return clones of the meta-parameters unchanged.

    Returns:
      dict of parameters that remain differentiable w.r.t. ``self.params``.
    """
    # clone (not detach) keeps the graph back to the meta-parameters.
    params = {k: torch.clone(v) for k, v in self.params.items()}
    if not adapt:
      return params

    targets = ys.to(device=xs.device, dtype=torch.long)
    for _ in range(self.num_inner_steps):
      loss = F.cross_entropy(self.forward(xs, params), targets)
      # create_graph=True makes the inner update itself differentiable, so the
      # outer loss can back-propagate through it (second-order MAML).
      grads = torch.autograd.grad(loss, list(params.values()), create_graph=True)
      params = {k: p - self.inner_lrs[k] * g for (k, p), g in zip(params.items(), grads)}
    return params

  def train(self, train_data, adapt=False, epochs=15000):
    """Meta-train on the tasks yielded by ``train_data``.

    Args:
      train_data: iterable of ``(xs, ys, xq, yq)`` tasks, e.g. a DataLoader
        using ``collate_func``; it is re-iterated every epoch.
      adapt: whether to run the inner loop on the support set.
      epochs: number of passes over ``train_data``.

    Raises:
      ValueError: if ``train_data`` yields no tasks.
    """
    for epoch in range(epochs):
      running_loss = 0.0
      running_corrects = 0
      total = 0

      for xs, ys, xq, yq in train_data:
        xs, xq = xs.to(self.device), xq.to(self.device)
        ys = ys.to(self.device, dtype=torch.long)
        yq = yq.to(self.device, dtype=torch.long)

        self.optim.zero_grad()
        params = self.adapt(xs, ys, adapt=adapt)
        logits = self.forward(xq, params)
        loss = F.cross_entropy(logits, yq)
        loss.backward()
        self.optim.step()

        running_loss += loss.item() * xq.size(0)
        running_corrects += (logits.argmax(dim=-1) == yq).sum().item()
        total += yq.numel()

      if total == 0:
        raise ValueError("train_data yielded no tasks")
      epoch_loss = running_loss / total
      epoch_acc = running_corrects / total

      if epoch == 0 or (epoch + 1) % 100 == 0:
        print(f'Epoch {epoch}/{epochs - 1}:- Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
        print('-' * 10)

In [38]:
# n-way MAML with a single inner-loop step, meta-trained on the task loader `x`.
# NOTE(review): train() defaults to 15000 epochs; the recorded run was interrupted.
maml = MAML(num_out=n, num_inner_steps=1)
maml.train(x, adapt=True)

Epoch 0/14999:- Loss: 1.3894 Acc: 0.5500
----------
Epoch 99/14999:- Loss: 0.5883 Acc: 0.7750
----------
Epoch 199/14999:- Loss: 0.4267 Acc: 0.8500
----------
Epoch 299/14999:- Loss: 0.3079 Acc: 0.9125
----------
Epoch 399/14999:- Loss: 0.4178 Acc: 0.8250
----------
Epoch 499/14999:- Loss: 0.2613 Acc: 0.9250
----------
Epoch 599/14999:- Loss: 0.2529 Acc: 0.9250
----------
Epoch 699/14999:- Loss: 0.2152 Acc: 0.9000
----------
Epoch 799/14999:- Loss: 0.1761 Acc: 0.9250
----------
Epoch 899/14999:- Loss: 0.1530 Acc: 0.9625
----------
Epoch 999/14999:- Loss: 0.2284 Acc: 0.9250
----------
Epoch 1099/14999:- Loss: 0.1797 Acc: 0.9625
----------
Epoch 1199/14999:- Loss: 0.1106 Acc: 0.9875
----------
Epoch 1299/14999:- Loss: 0.0866 Acc: 0.9875
----------
Epoch 1399/14999:- Loss: 0.0836 Acc: 0.9750
----------


KeyboardInterrupt: ignored