In [1]:
import os
import torch

import numpy as np
import matplotlib.pyplot as plt

import torchvision
from torchvision import transforms
import scipy.ndimage
from modules import *
from meta_modules import *

from torch.utils.data.sampler import SubsetRandomSampler

# Set up dataset

In [2]:
class SignedDistanceTransform:
    """Convert a grayscale image tensor into a normalized signed distance field (SDF).

    The image is binarized at ``threshold``: values >= threshold become 1 (foreground),
    the rest 0 (background). For every pixel, the Euclidean distance to the nearest
    pixel of the other class is computed. The result is negative inside the shape
    (foreground), positive outside, and divided by ``img_tensor.shape[1]``. That is the
    image height for a (channels, H, W) tensor such as ``transforms.ToTensor`` produces.

    Args:
        threshold: binarization threshold (default 0.5).

    Returns (from ``__call__``):
        ``(signed_distances, binary_image)`` as float32 tensors with the input's shape.
        The input tensor is not modified.
    """

    def __init__(self, threshold=0.5):
        self.threshold = threshold

    def __call__(self, img_tensor):
        # Binarize into a fresh float32 array instead of thresholding the caller's tensor in place.
        binary = (img_tensor.detach().cpu().numpy() >= self.threshold).astype(np.float32)

        # Distance from each foreground pixel to the nearest background pixel (0 on background).
        inside = scipy.ndimage.distance_transform_edt(binary)
        # Distance from each background pixel to the nearest foreground pixel (0 on foreground).
        # Build the background mask with logical_not rather than casting a negative float
        # (-1.0) to uint8: that cast is undefined behaviour and platform-dependent in NumPy.
        outside = scipy.ndimage.distance_transform_edt(np.logical_not(binary))

        # Negative inside, positive outside, normalized by shape[1].
        signed_distances = (outside - inside) / float(binary.shape[1])
        return torch.Tensor(signed_distances), torch.Tensor(binary)

def get_mgrid(sidelen):
    """Return the flattened pixel-coordinate grid of a sidelen x sidelen image.

    Each axis holds the pixel indices scaled by 1/sidelen and shifted by -0.5, i.e.
    -0.5, -0.5 + 1/sidelen, ..., 0.5 - 1/sidelen. The result is a float32 tensor of
    shape (sidelen * sidelen, 2) in row-major order: row i * sidelen + j holds the
    scaled (i, j) coordinates.
    """
    axis = np.arange(sidelen).astype(np.float32)
    axis /= sidelen
    axis -= 0.5
    rows, cols = np.meshgrid(axis, axis, indexing='ij')
    return torch.Tensor(np.stack((rows, cols), axis=-1)).view(-1, 2)

class MNISTSDFDataset(torch.utils.data.Dataset):
    """MNIST digits as signed distance fields, split into context and query point sets.

    Each item is a dict:
      'context': (coords, sdf) for a random half of the pixels,
      'query':   (coords, sdf) for the remaining pixels,
      'all':     (coords, sdf) for every pixel, in row-major order,
    where coords has shape (num_pixels, 2) and sdf has shape (num_pixels, 1).

    Args:
        split: 'train' selects the MNIST training set; any other value selects the
            MNIST test set (this notebook passes 'val' for it).
        size: (height, width) the digits are resized to. Must be square, because the
            coordinate grid is generated from a single side length.

    Raises:
        ValueError: if ``size`` is not square.
    """

    def __init__(self, split, size=(256, 256)):
        # Validate before downloading anything. A non-square size would give a coordinate
        # grid whose length differs from the number of SDF values, failing later with an
        # IndexError or silently mis-pairing coordinates and values.
        if size[0] != size[1]:
            raise ValueError(f"MNISTSDFDataset requires a square size, got {tuple(size)}")

        self.transform = transforms.Compose([
            transforms.Resize(size),
            transforms.ToTensor(),
            SignedDistanceTransform(),
        ])
        self.img_dataset = torchvision.datasets.MNIST('./datasets/MNIST', train=(split == 'train'),
                                                      download=True)
        # Flattened (size[0] * size[0], 2) grid of pixel coordinates, row-major.
        self.meshgrid = get_mgrid(size[0])
        self.im_size = size

    def __len__(self):
        return len(self.img_dataset)

    def __getitem__(self, item):
        img, _digit_class = self.img_dataset[item]
        # The transform also returns the binarized image; only the SDF is used here.
        signed_distance_img, _binary_image = self.transform(img)

        coord_values = self.meshgrid.reshape(-1, 2)
        signed_distance_values = signed_distance_img.reshape(-1, 1)

        # Random disjoint split of the pixels. The first half is the context (support) set
        # that the training loop adapts on; the rest is the query set it evaluates on.
        indices = torch.randperm(coord_values.shape[0])
        half = indices.shape[0] // 2
        support_indices, query_indices = indices[:half], indices[half:]

        return {
            'context': (coord_values[support_indices], signed_distance_values[support_indices]),
            'query': (coord_values[query_indices], signed_distance_values[query_indices]),
            'all': (coord_values, signed_distance_values),
        }

In [3]:
# 64x64 signed-distance datasets; 'val' maps to the MNIST test split.
train_dataset = MNISTSDFDataset('train', size=(64, 64))
val_dataset = MNISTSDFDataset('val', size=(64, 64))

# Shuffle training tasks so each epoch sees them in a fresh order. The validation loader
# keeps a fixed order so its losses and plots are comparable across epochs.
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=16)

In [11]:
def sdf_loss(predictions, gt, **kwargs):
    """Mean squared error between predicted and ground-truth signed distances.

    Averages over every element of the residual. Extra keyword arguments are accepted
    and ignored.
    """
    residual = predictions - gt
    return torch.mean(torch.pow(residual, 2))


def inner_maml_sdf_loss(predictions, gt, **kwargs):
    """Squared error summed over dim 0, then averaged over all remaining elements.

    This loss is passed to the meta-learner below. Dim 0 is presumably the meta-batch
    (task) dimension; summing over it keeps each task's inner-loop gradient from being
    scaled by 1/batch_size (TODO confirm against MetaSDF). Extra keyword arguments are
    accepted and ignored.
    """
    squared_error = torch.pow(predictions - gt, 2)
    summed_over_dim0 = torch.sum(squared_error, dim=0)
    return torch.mean(summed_over_dim0)

In [5]:
def lin2img(tensor):
    """Reshape a batch of flattened square images back into (batch, channels, H, W) layout.

    Args:
        tensor: (batch, num_samples, channels) tensor whose samples are pixels in
            row-major order (the order get_mgrid produces).

    Returns:
        (batch, channels, sidelen, sidelen) tensor, with sidelen = sqrt(num_samples).

    Raises:
        ValueError: if num_samples is not a perfect square.
    """
    batch_size, num_samples, channels = tensor.shape
    # Round instead of truncating, so floating-point error in sqrt cannot lose a row.
    sidelen = int(round(np.sqrt(num_samples)))
    if sidelen * sidelen != num_samples:
        raise ValueError(f"lin2img expects a square number of samples, got {num_samples}")
    return tensor.permute(0, 2, 1).reshape(batch_size, channels, sidelen, sidelen)

# Initialize Model

For this task, we use a simple model with two hidden layers of 256 units each. We meta-train it with the original (second-order) MAML algorithm, using 3 inner-loop adaptation steps and a single learnable inner-loop learning rate (LR) initialized to 1e-1.

In [13]:
# Base network: 2 input features (the pixel coordinates) and 1 output (the signed
# distance), with 2 hidden layers of 256 features.
hypo_module = ReLUFC(
    in_features=2,
    out_features=1,
    num_hidden_layers=2,
    hidden_features=256,
)
# Apply sal_init across the whole net first, then sal_init_last_layer to its final entry.
hypo_module.net.apply(sal_init)
hypo_module.net[-1].apply(sal_init_last_layer)

# Meta-learner around the base network. inner_maml_sdf_loss is presumably the inner-loop
# loss. It uses one ('global') learnable LR starting at 1e-1 and second-order MAML
# (first_order=False), and is placed on the GPU.
model = MetaSDF(
    hypo_module,
    inner_maml_sdf_loss,
    init_lr=1e-1,
    lr_type='global',
    first_order=False,
).cuda()

# Training

In [None]:
optim = torch.optim.Adam(lr=1e-4, params=model.parameters())

train_losses = []
val_losses = []

NUM_EPOCHS = 3
LOG_EVERY = 1000  # steps between loss printouts / reconstruction plots

# Move inputs to wherever the model's parameters live (CUDA in this notebook).
device = next(model.parameters()).device

for epoch in range(NUM_EPOCHS):
    model.train()
    for step, meta_batch in enumerate(train_dataloader):
        context_x, context_y = (t.to(device) for t in meta_batch['context'])
        query_x, query_y = (t.to(device) for t in meta_batch['query'])

        # Adapt model using context examples
        fast_params = model.generate_params(context_x, context_y)

        # Use the adapted parameters to make predictions on query
        pred_sd = model.forward_with_params(query_x, fast_params)

        # Calculate loss on query examples and update the meta-parameters
        loss = sdf_loss(pred_sd, query_y)
        train_losses.append(loss.item())

        optim.zero_grad()
        loss.backward()
        optim.step()

        if step % LOG_EVERY == 0:
            with torch.no_grad():
                pred_image = model.forward_with_params(meta_batch['all'][0].to(device), fast_params)
            print(f"Epoch: {epoch} \t step: {step} \t loss: {loss.item()}")
            plt.imshow(lin2img(pred_image).cpu().numpy()[0][0])
            plt.show()

    print("Evaluating model")
    model.eval()
    epoch_val_losses = []
    with torch.no_grad():
        for step, meta_batch in enumerate(val_dataloader):
            # Instead of explicitly calling generate_params and forward_with_params,
            # we can pass the meta_batch dictionary to the model's forward method
            pred_sd, _ = model(meta_batch)
            val_loss = sdf_loss(pred_sd, meta_batch['query'][1].to(device))
            epoch_val_losses.append(val_loss.item())

            if step % LOG_EVERY == 0:
                # Adapt to THIS validation batch's context before rendering. Reusing the
                # training loop's fast_params would plot a training digit's adaptation.
                # enable_grad: the inner-loop adaptation takes gradient steps, so allow
                # autograd here regardless of how generate_params handles no_grad itself.
                with torch.enable_grad():
                    val_fast_params = model.generate_params(meta_batch['context'][0].to(device),
                                                            meta_batch['context'][1].to(device))
                pred_image = model.forward_with_params(meta_batch['all'][0].to(device), val_fast_params)
                print(f"Val Image -- Epoch: {epoch} \t step: {step} \t loss: {val_loss.item()}")
                plt.imshow(lin2img(pred_image).cpu().numpy()[0][0])
                plt.show()

    val_losses.extend(epoch_val_losses)
    print(f"Epoch: {epoch} \t mean val loss: {np.mean(epoch_val_losses)}")