# The challenge as I see it:
Learn a function $f(x, K) = y$, where $x$ is an input grid sampled from the ARC dataset, $K$ is the core-knowledge prior, and $y$ is the corresponding output grid from the ARC dataset. Note that $x$ and $K$ are the only inputs.
For each task we already have a few training pairs $(x, y)$. We could easily overfit a neural network to them and reach 100% accuracy on those pairs, but by doing so
we learn nothing transferable and get no closer to AGI. Even a DSL approach does not really learn anything! In such an approach
we impose handcrafted constraints on the space of functions, i.e. we have manually engineered $K$.

A smart algorithm should have a representation of cognitive priors and should automatically decide which specific
prior suits a particular task. The tasks are independent of each other, but they share something: the cognitive priors.
We can use this common thread among all the tasks to learn a representation of the priors. Whenever we want to predict on a new task, we can fine-tune the learnt priors and select the ones relevant to that task. So how do we do this?

# Meta Learning
Meta-learning is the process of learning how to learn. A meta-learning algorithm takes in a distribution of tasks, where each task is a learning problem, and produces a quick learner: a learner that can generalize from a small number of examples. MAML is one of the best-known meta-learning approaches
out there, but it requires second-order derivatives (Hessian terms from differentiating through the inner optimization), which are expensive to compute. Another interesting approach is the Reptile algorithm.
It is mathematically similar to first-order MAML and performs a very similar update, and its performance is similar to MAML's
on the Omniglot and Mini-ImageNet benchmarks. So I decided to use Reptile instead of MAML.
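In symbols, one Reptile meta-update for a sampled task $\tau$ can be written as below, where $\phi$ denotes the shared weights, $U_\tau^{k}(\phi)$ the weights reached after $k$ optimizer steps on task $\tau$ starting from $\phi$, and $\epsilon$ the outer step size:

$$
\widetilde{W} = U_\tau^{k}(\phi), \qquad \phi \leftarrow \phi + \epsilon\,\bigl(\widetilde{W} - \phi\bigr)
$$

No gradient is propagated back through the inner optimization, so no second derivatives are needed; the difference $\widetilde{W} - \phi$ itself serves as the update direction.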



# About the code

In this kernel, I only run prediction on the test set. You can find the full code for training, validation and testing in [my github repo](https://github.com/sidml/reptile-transformer).
The ARC dataset is divided into training, evaluation and test sets. 100 tasks from the evaluation set are also part of the test set, and I use these 100 tasks
for validation. Each task can have 3 to 5 training pairs, and image sizes vary across and within tasks from 2x2 to 30x30.
The model weights are shared among the tasks, so the model has to work with images of any size. To handle this, I
add an extra class (index 10, giving 11 classes in total) that represents non-existent pixels for a task. For example, if a task's image is 10x10, I pad it with class 10 so
that it becomes 15x15, the common size used in this kernel (tasks with larger images are skipped). This has two advantages (hopefully):
1. It allows us to have a common size across tasks (important if you want to experiment with CNNs).
2. It introduces the concepts of emptiness and varying size to our model.
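A minimal NumPy sketch of the padding scheme and its inverse, assuming (as in the `pad_im` method below) that the padding is added at the top and left of each grid. `pad_to` and `crop` are hypothetical helper names, not part of the original code, and the 1x1 fallback for an all-padding prediction is my own assumption.

```python
import numpy as np

EMPTY = 10  # extra "non-existent pixel" class; ARC colours are 0-9


def pad_to(grid, size=15, cval=EMPTY):
    """Pad a grid at the top and left with the empty class up to size x size."""
    grid = np.asarray(grid)
    nrows, ncols = grid.shape
    if nrows > size or ncols > size:
        return None  # too large for the common size; such tasks are skipped
    return np.pad(grid, ((size - nrows, 0), (size - ncols, 0)),
                  mode='constant', constant_values=cval)


def crop(padded, cval=EMPTY):
    """Undo the padding: drop leading rows/columns that contain only the empty class."""
    padded = np.asarray(padded)
    rows = np.nonzero((padded != cval).any(axis=1))[0]
    cols = np.nonzero((padded != cval).any(axis=0))[0]
    if len(rows) == 0 or len(cols) == 0:
        # hypothetical fallback: an all-padding grid becomes a 1x1 black grid
        return np.zeros((1, 1), dtype=padded.dtype)
    out = padded[rows[0]:, cols[0]:].copy()
    out[out == cval] = 0  # stray empty-class pixels inside the crop become black
    return out


grid = np.array([[1, 2], [3, 4]])
padded = pad_to(grid, size=4)
print(padded)
print(crop(padded))
```

Because the padding only ever sits at the top and left, cropping from the first non-empty row and column recovers the original grid exactly.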
   
I experimented with CNNs, but found their performance lacking, so the current code uses a transformer model.
I flatten each image into a sequence of pixel tokens and pass it to an embedding layer, followed by a positional encoding layer that accounts for the order of the pixels. I
use PyTorch's implementation of TransformerEncoder. The output of the TransformerEncoder is sent to a final Linear layer, which gives
us logits for each of the 11 classes. The model is then trained with the usual cross-entropy loss.
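To make the shapes concrete, here is a NumPy re-implementation of the per-pixel loss and accuracy that `F.cross_entropy` and the accuracy line in the training loop compute. It is an illustration, not the PyTorch code: each 15x15 grid contributes 225 pixel tokens, the model emits 11 logits per pixel, and the loss is averaged over all pixels. The function names are hypothetical.

```python
import numpy as np

NUM_CLASSES = 11  # colours 0-9 plus the "empty" class 10


def pixel_cross_entropy(logits, targets):
    """Mean cross-entropy over pixels.

    logits:  (num_pixels, NUM_CLASSES) unnormalised scores
    targets: (num_pixels,) integer class labels
    """
    # log-softmax over the class axis, shifted by the row max for numerical stability
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()


def pixel_accuracy(logits, targets):
    """Fraction of pixels whose highest-scoring class matches the target."""
    return (logits.argmax(axis=1) == targets).mean()


# Toy example: two 15x15 output grids, flattened to 2 * 225 pixels
rng = np.random.default_rng(0)
targets = rng.integers(0, NUM_CLASSES, size=2 * 15 * 15)
uniform_logits = np.zeros((targets.size, NUM_CLASSES))
print(pixel_cross_entropy(uniform_logits, targets))  # log(11), about 2.398
```

With uniform logits every class gets probability 1/11, so the loss is log 11 regardless of the targets; this is roughly where training starts from an uninformed model.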


Since we aim to learn parameters that work well across all the tasks, training needs an outer loop that iterates over the tasks.
For each task, we train the transformer model on that task's inputs and outputs. This is our inner loop.
We then update the shared weights by interpolating between the current model weights and the weights trained on this task.
I found this figure from the MAML paper very useful for building intuition about how training works: ![how training works](https://bair.berkeley.edu/blog/assets/maml/maml.png)
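A minimal sketch of this outer/inner loop in NumPy. To keep it small and self-contained, it swaps in a hypothetical family of 1-D linear-regression tasks for ARC tasks and plain gradient descent for the transformer and its optimizer; `sample_task` and `inner_loop` are illustrative names, not functions from the original repo.

```python
import numpy as np

rng = np.random.default_rng(0)


def sample_task():
    """Hypothetical toy task family: y = a*x + b with (a, b) scattered around (2, -1)."""
    a, b = rng.normal(2.0, 0.3), rng.normal(-1.0, 0.3)
    x = rng.uniform(-1, 1, size=10)
    return x, a * x + b


def inner_loop(w, x, y, lr=0.1, steps=20):
    """Task-specific training: gradient descent on mean squared error, starting from w."""
    w = w.copy()
    for _ in range(steps):
        err = w[0] * x + w[1] - y
        grad = np.array([2 * (err * x).mean(), 2 * err.mean()])
        w -= lr * grad
    return w


phi = np.zeros(2)      # shared (meta) weights: [slope, intercept]
outer_lr = 0.1         # epsilon in the Reptile update
for _ in range(1000):  # outer loop over sampled tasks
    x, y = sample_task()
    w_task = inner_loop(phi, x, y)    # inner loop on this task
    phi += outer_lr * (w_task - phi)  # interpolate toward the adapted weights

print(phi)  # should end up near the centre of the task family, roughly [2, -1]
```

Because every task's parameters are drawn around (2, -1), the meta-weights drift toward that centre: an initialisation from which a few gradient steps reach any task in the family quickly.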


I use a transformer model here, but you can easily replace it with a CNN or any other deep learning model you like.
I just wanted to introduce the idea of meta-learning and show how it may be pertinent to the ARC dataset.

In [None]:
import torch
from torch.utils.data import Dataset
import numpy as np
import json
from glob import glob

import os
import random
import pickle
import torch.nn.functional as F
import matplotlib.pyplot as plt
from matplotlib import colors
import pandas as pd
import torch.nn as nn
import math

In [None]:
seed = 42
print(f'setting everything to seed {seed}')
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True

cmap = colors.ListedColormap(
    ['#000000', '#0074D9', '#FF4136', '#2ECC40', '#FFDC00',
        '#AAAAAA', '#F012BE', '#FF851B', '#7FDBFF', '#870C25', '#FFFFFF'])
norm = colors.Normalize(vmin=0, vmax=10)

# Test dataset

In [None]:
class ARCTest(Dataset):
    """Test-time ARC tasks, padded to a common size for meta-learning.

    Each item is one task: its training (support) pairs are padded to
    imgsz x imgsz with the "empty" class 10, and its test (query) inputs
    are padded the same way and stored in self.query_x_batch.
    """

    def __init__(self, root, imgsz=30):
        """
        :param root: root directory of the ARC data (must contain a `test/` folder of task JSON files)
        :param imgsz: common side length that every grid is padded to
        """
        super(ARCTest, self).__init__()
        self.out_rows, self.out_cols = imgsz, imgsz
        task_paths = sorted(glob(f'{root}/test/*.json'))
        self.support_x_batch, self.support_y_batch, self.query_x_batch, self.task_paths = self.create_batch(task_paths)

    @staticmethod
    def _pad(grid, out_rows, out_cols, cval):
        """Pad a grid at the top and left with cval; return None if it is too large."""
        im = np.array(grid)
        nrows, ncols = im.shape
        if (nrows > out_rows) or (ncols > out_cols):
            return None
        return np.pad(im, ((out_rows - nrows, 0), (out_cols - ncols, 0)),
                      mode='constant', constant_values=cval)

    def pad_im(self, task, out_rows, out_cols, cval=10):
        """Pad all grids of a task.

        Returns (train_inputs, train_outputs, not_valid, test_inputs).
        not_valid is 1 (and the other values are 0) if any grid exceeds out_rows x out_cols.
        """
        train = task['train']
        input_im = np.zeros((len(train), out_rows, out_cols))
        # np.long was a deprecated alias (removed in NumPy 1.24); use an explicit int64
        output_im = np.zeros((len(train), out_rows, out_cols), dtype=np.int64)
        for i, pair in enumerate(train):
            ip = self._pad(pair['input'], out_rows, out_cols, cval)
            op = self._pad(pair['output'], out_rows, out_cols, cval)
            if ip is None or op is None:
                return 0, 0, 1, 0
            input_im[i], output_im[i] = ip, op

        test = task['test']
        test_im = np.zeros((len(test), out_rows, out_cols))
        for i, pair in enumerate(test):
            ip = self._pad(pair['input'], out_rows, out_cols, cval)
            if ip is None:
                return 0, 0, 1, 0
            test_im[i] = ip

        return input_im, output_im, 0, test_im

    def create_batch(self, task_paths):
        """Load and pad every task, skipping tasks whose grids exceed the common size.

        :param task_paths: list of task JSON file paths
        :return: (support inputs, support outputs, query inputs, kept task paths), one entry per task
        """
        x_batch, y_batch, query_x_batch = [], [], []
        all_task_paths = []
        for task_file in task_paths:
            with open(task_file, 'r') as f:
                task = json.load(f)
            input_im, output_im, not_valid, query_im = self.pad_im(task, self.out_rows,
                                                                   self.out_cols)
            if not_valid:
                continue
            x_batch.append(input_im)
            y_batch.append(output_im)
            query_x_batch.append(query_im)
            all_task_paths.append(task_file)
        return x_batch, y_batch, query_x_batch, all_task_paths

    def __getitem__(self, index):
        """Return (support inputs [n, 1, H, W], flattened support outputs [n*H*W], task path)."""
        support_x = torch.tensor(
            self.support_x_batch[index], dtype=torch.float32)
        support_y = torch.tensor(
            self.support_y_batch[index], dtype=torch.long)
        return support_x[:, None], support_y.reshape(-1), self.task_paths[index]

    def __len__(self):
        # number of tasks whose grids fit within the common image size
        return len(self.support_x_batch)


# Defining the Transformer Model

In [None]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding, added to the token embeddings."""

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(
            0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
        pe = pe.unsqueeze(0).transpose(0, 1)          # (max_len, 1, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (seq_len, batch, d_model); encodings are indexed along dim 0
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)


class TransformerModel(nn.Module):
    """Embedding -> positional encoding -> TransformerEncoder -> per-token class logits."""

    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super(TransformerModel, self).__init__()
        self.model_type = 'Transformer'
        self.src_mask = None
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        encoder_layers = nn.TransformerEncoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, nlayers)
        self.encoder = nn.Embedding(ntoken, ninp)  # one embedding per colour / empty class
        self.ninp = ninp
        self.decoder = nn.Linear(ninp, ntoken)     # logits over the ntoken classes
        self.init_weights()

    def _generate_square_subsequent_mask(self, sz):
        # causal mask: position i may only attend to positions <= i
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float(
            '-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src):
        # src: (num_grids, num_pixels) integer tokens. With the default batch_first=False,
        # nn.TransformerEncoder treats dim 0 as the sequence dimension, so the
        # positional encoding and the causal mask run along dim 0 of src.
        if self.src_mask is None or self.src_mask.size(0) != len(src):
            device = src.device
            mask = self._generate_square_subsequent_mask(len(src)).to(device)
            self.src_mask = mask

        src = self.encoder(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, self.src_mask)

        output = self.decoder(output)  # (num_grids, num_pixels, ntoken)
        return output
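The `PositionalEncoding` class above uses the fixed sinusoidal encoding. A NumPy re-implementation of the same formula shows what gets added to the embeddings; the function name `sinusoidal_encoding` is hypothetical, and 225 positions (15x15 pixels) with `emsize = 32` are chosen only to match this kernel's settings.

```python
import math

import numpy as np


def sinusoidal_encoding(max_len, d_model):
    """Same formula as PositionalEncoding (d_model must be even):

    PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
    PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))
    """
    position = np.arange(max_len)[:, None]  # (max_len, 1)
    div_term = np.exp(np.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(position * div_term)
    pe[:, 1::2] = np.cos(position * div_term)
    return pe


pe = sinusoidal_encoding(max_len=225, d_model=32)
print(pe.shape)   # (225, 32)
print(pe[0, :4])  # [0. 1. 0. 1.]
```

Each position gets a distinct, bounded vector (all values lie in [-1, 1]). In `TransformerModel.forward` these rows are indexed along dimension 0 of the embedded input.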


# Utilities

In [None]:
# https://www.kaggle.com/inversion/abstraction-and-reasoning-starter-notebook
def flattener(pred):
    str_pred = str([row for row in pred])
    str_pred = str_pred.replace(', ', '')
    str_pred = str_pred.replace('[[', '|')
    str_pred = str_pred.replace('][', '|')
    str_pred = str_pred.replace(']]', '|')
    return str_pred

def plot_figure(x_spt, y_spt, x_qry,
                pred_q, im_num, img_sz=30):
    """Plot the first training pair, the first test input and the cropped prediction for it."""
    plt.figure(figsize=(15, 15))
    plt.subplot(2, 2, 1)
    plt.imshow(x_spt[0].cpu().numpy().reshape(img_sz, img_sz),
               cmap=cmap, norm=norm)
    plt.title('train input 1')
    plt.subplot(2, 2, 2)
    plt.imshow(y_spt[:img_sz*img_sz].cpu().numpy().reshape(img_sz, img_sz),
               cmap=cmap, norm=norm)
    plt.title('ideal output 1')
    plt.subplot(2, 2, 3)
    plt.imshow(x_qry[0].cpu().numpy().reshape(img_sz, img_sz),
               cmap=cmap, norm=norm)
    plt.title('test input 1')

    # do visualization only for the first input: crop away the leading
    # padding rows/columns (class 10) and map any remaining padding to black
    pred = pred_q[0].cpu().numpy().reshape(img_sz, img_sz)
    rows = np.nonzero(np.count_nonzero(pred - 10, axis=1))[0]
    cols = np.nonzero(np.count_nonzero(pred - 10, axis=0))[0]
    if len(rows) and len(cols):
        a = np.copy(pred[rows[0]:, cols[0]:])
    else:
        a = np.zeros((1, 1), dtype=pred.dtype)  # prediction is all padding
    a[a == 10] = 0
    plt.subplot(2, 2, 4)
    plt.imshow(a, cmap=cmap, norm=norm)
    plt.title('model prediction')
    plt.suptitle(f'{im_num}')

    plt.savefig(f'.epoch_30_preds_{im_num}.png')
    plt.show()
    plt.close()
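The submission file expects each predicted grid as a single string with rows delimited by `|`. A quick standalone check of the format `flattener` produces (the function is copied here, with comments added, so the snippet runs on its own):

```python
def flattener(pred):
    # same logic as the utility above: '[[1, 2], [3, 4]]' -> '|12|34|'
    str_pred = str([row for row in pred])
    str_pred = str_pred.replace(', ', '')   # drop separators between cells and rows
    str_pred = str_pred.replace('[[', '|')  # opening of the first row
    str_pred = str_pred.replace('][', '|')  # boundary between rows
    str_pred = str_pred.replace(']]', '|')  # closing of the last row
    return str_pred


print(flattener([[1, 2], [3, 4]]))  # |12|34|
print(flattener([[7]]))             # |7|
```

Since every cell becomes its decimal digits with no separator, this encoding only works for single-digit colours, which is one reason the empty class 10 is mapped back to 0 before a prediction is flattened.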

In [None]:
ntokens = 11  # vocabulary size: colours 0-9 plus the "empty" class 10
emsize = 32  # embedding dimension
nhid = 64  # the dimension of the feedforward network model in nn.TransformerEncoder
nlayers = 2  # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
nhead = 4  # the number of heads in the multiheadattention models
dropout = 0.5  # the dropout value
device = torch.device('cpu')  # this kernel fine-tunes and predicts on the CPU

# ntoken, ninp, nhead, nhid, nlayers, dropout=0.5
model = TransformerModel(ntokens, emsize, nhead,
                         nhid, nlayers, dropout).to(device)

sample_sub = pd.read_csv('/kaggle/input/abstraction-and-reasoning-challenge/sample_submission.csv')
sample_sub = sample_sub.set_index('output_id')
sample_sub.head()


# Model Predictions on Test Set

In [None]:
innerstepsize = 1e-2  # learning rate for per-task fine-tuning (AdamW)
innerepochs = 1000  # number of fine-tuning steps on each task's training pairs

device = torch.device('cpu')

# one dataset item per test task whose grids fit within 15x15
arc_dataset = ARCTest(
    root='/kaggle/input/abstraction-and-reasoning-challenge/', imgsz=15)

all_train_acc = []
imgsz, num_class = 15, 11
for step, ((x, y, task_path), q) in enumerate(zip(arc_dataset, arc_dataset.query_x_batch)):

    task_id = task_path.split('/')[-1]
    # start every task from the same meta-learned weights
    state = torch.load('../input/arctransformermodel/epoch_1_step_264_acc_0.688.pth', map_location='cpu')
    model.load_state_dict(state)

    optimizer = torch.optim.AdamW(model.parameters(), lr=innerstepsize)
    x = x.reshape(-1, imgsz*imgsz).long().to(device)  # (num_pairs, 225) pixel tokens
    y = y.to(device)                                  # (num_pairs * 225,) target classes

    # inner loop: fine-tune on this task's training pairs
    train_losses = []
    train_acc = []
    model.train()
    for _ in range(innerepochs):
        optimizer.zero_grad()
        outputs = model(x).reshape(-1, num_class)
        loss = F.cross_entropy(outputs, y)
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())
        acc = (outputs.argmax(1) == y).float().mean().item()
        train_acc.append(acc)
    print('\ttraining loss:',
          np.mean(train_losses), '\ttraining acc:', np.mean(train_acc))

    all_train_acc.append(np.mean(train_acc))
    model.eval()
    with torch.no_grad():
        q = torch.tensor(
            q.reshape(-1, imgsz*imgsz)).to(device).long()
        # softmax over the class dimension (the last one); argmax picks the predicted colour
        outputs = F.softmax(model(q), dim=-1)
        outputs = outputs.argmax(2).reshape(-1, imgsz, imgsz)
        plot_figure(x, y, q, outputs, im_num=task_id, img_sz=imgsz)

    for task_num, preds in enumerate(outputs.cpu().numpy()):
        # crop leading padding rows/columns, then map leftover padding to black
        rows = np.nonzero(np.count_nonzero(preds - 10, axis=1))[0]
        cols = np.nonzero(np.count_nonzero(preds - 10, axis=0))[0]
        if len(rows) and len(cols):
            preds = np.copy(preds[rows[0]:, cols[0]:])
        else:
            preds = np.zeros((1, 1), dtype=preds.dtype)  # prediction is all padding
        preds[preds == 10] = 0
        sample_sub.loc[f'{task_id[:-5]}_{task_num}',
                       'output'] = flattener(preds.astype(int).tolist())
print('\nmean train acc:', np.mean(all_train_acc),
      'stddev train acc:', np.std(all_train_acc))

In [None]:
sample_sub.head()
sample_sub.to_csv('submission.csv')