In [210]:
import os
import json
import torch
import numpy as np
import pandas as pd
from pprint import pprint
from torch.utils.data import Dataset
from itertools import product
import xml.etree.ElementTree as ET
from pathlib import Path

from utils import parse_rules, parse_panels, bbox_to_xyxy, plot_example
from avr_dataset import extract_stage2_ground_truth, prepare_stage2_dataset

import torch
from torch import nn
from torch.utils.data import DataLoader

from collections import Counter
from tqdm import tqdm
from torch import optim
from torch.optim.lr_scheduler import StepLR

In [211]:
class AVRStage2Dataset(nn.Module):
    def __init__(self, dataset_dir, split):
        super().__init__()
        panels_df, rules_df = extract_stage2_ground_truth(dataset_dir, split)
        self.final_df = prepare_stage2_dataset(panels_df, rules_df, merge_row=False)
        self.final_df = self.final_df.reset_index()
        self.info_cols = self.final_df.columns.tolist()[:2]
        self.feature_cols = self.final_df.columns.tolist()[2:-5]
        self.label_cols = self.final_df.columns.tolist()[-5:]
        self.label2id = {'Constant': 0, 'Distribute_Three': 1, 'Progression': 2, 'Arithmetic': 3}

    def __len__(self):
        return len(self.final_df)
    
    def __getitem__(self, idx):
        data = self.final_df.iloc[idx]
        
        info = data[self.info_cols].to_dict()
        
        panels = torch.split(torch.tensor(data[self.feature_cols].values.astype(np.int64)), 22 * 3)
        reshaped_panels = list(torch.stack(torch.split(p, 3)) for p in panels)
        features = torch.stack(reshaped_panels)
        
        labels = data[self.label_cols].map(self.label2id).to_dict()
        for key, val, in labels.items():
            labels[key] = torch.tensor(val)
        
        return {
            'info': info,
            'features': features,
            'labels': labels
        }
    

In [212]:
class RelationNet(nn.Module):
    def __init__(
        self,
        mlp_hidden=32,
        classes=4,
    ):
        super().__init__()
        
        self.n_concat = 9
        
        self.g = nn.Sequential(
            nn.Linear(self.n_concat, mlp_hidden),
            nn.ReLU(),
            nn.Linear(mlp_hidden, mlp_hidden),
            nn.ReLU(),
            nn.Linear(mlp_hidden, mlp_hidden),
            nn.ReLU(),
            nn.Linear(mlp_hidden, mlp_hidden),
            nn.ReLU(),
        )
        
        def get_head():
            return nn.Sequential(
            nn.Linear(mlp_hidden, mlp_hidden),
            nn.ReLU(),
            nn.Linear(mlp_hidden, mlp_hidden),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(mlp_hidden, classes),
        )
        
        self.mlp_hidden = mlp_hidden
        
        self.f_num = get_head()
        self.f_pos = get_head()
        self.f_type = get_head()
        self.f_size = get_head()
        self.f_color = get_head()


    def forward(self, x):
        x1, x2, x3, x4, x5, x6 = x[:, 0, :, :], x[:, 1, :, :], x[:, 2, :, :], x[:, 3, :, :], x[:, 4, :, :], x[:, 5, :, :]

        x1_repeated = torch.repeat_interleave(x1, 22*22, dim=1)
        x2_repeated = torch.repeat_interleave(x2, 22, dim=1).repeat(1, 22, 1)
        x3_repeated = x3.repeat(1, 22*22, 1)
        x_row1 = torch.cat([x1_repeated, x2_repeated, x3_repeated], dim=2).float()
        
        x4_repeated = torch.repeat_interleave(x4, 22*22, dim=1)
        x5_repeated = torch.repeat_interleave(x5, 22, dim=1).repeat(1, 22, 1)
        x6_repeated = x6.repeat(1, 22*22, 1)
        x_row2 = torch.cat([x4_repeated, x5_repeated, x6_repeated], dim=2).float()
        
        g_row1 = self.g(x_row1)
        g_row2 = self.g(x_row2)
        g_row1 = g_row1.view(-1, 22*22*22, self.mlp_hidden).sum(1).squeeze()
        g_row2 = g_row2.view(-1, 22*22*22, self.mlp_hidden).sum(1).squeeze()
        
        row1_num = self.f_num(g_row1)
        row1_pos = self.f_pos(g_row1)
        row1_type = self.f_type(g_row1)
        row1_size = self.f_size(g_row1)
        row1_color = self.f_color(g_row1)
        
        row2_num = self.f_num(g_row2)
        row2_pos = self.f_pos(g_row2)
        row2_type = self.f_type(g_row2)
        row2_size = self.f_size(g_row2)
        row2_color = self.f_color(g_row2)
        
        predictions = {
            'num': (row1_num + row2_num) / 2,
            'pos': (row1_pos + row2_pos) / 2,
            'type': (row1_type + row2_type) / 2,
            'size': (row1_size + row2_size) / 2,
            'color': (row1_color + row2_color) / 2,
        }

        return predictions

In [213]:
batch_size = 256
lr = 5e-6
lr_max = 5e-4
lr_gamma = 2
lr_step = 20
clip_norm = 50
weight_decay = 1e-4
n_epoch = 500
n_worker = 9
data_parallel = True

In [214]:
train_dataset = AVRStage2Dataset('dataset', 'train')
test_dataset = AVRStage2Dataset('dataset', 'test')

In [215]:
def train(epoch):
    train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True)
    pbar = tqdm(iter(train_dataloader))

    net.train(True)

    for i, data_dict in enumerate(pbar):
        x = data_dict['features'].cuda()
        y = data_dict['labels']
        for key, val in y.items():
            y[key] = val.cuda()

        net.zero_grad()
        out = net(x)
        loss = criterion(out['num'], y['num'])
        for key in ['pos', 'type', 'size', 'color']:
            loss += criterion(out[key], y[key])

        loss.backward()
        nn.utils.clip_grad_norm_(net.parameters(), clip_norm)
        optimizer.step()
        
        acc = {'num': 0, 'pos': 0, 'type': 0, 'size': 0, 'color': 0}
        moving_loss = acc.copy()
        for key in acc:
            correct = out[key].data.cpu().numpy().argmax(1) == y['key'].data.cpu().numpy()
            acc[key] = correct.sum() / batch_size
            if moving_loss[key] == 0:
                moving_loss[key] = acc[key]
            else:
                moving_loss[key] = moving_loss[key] * 0.99 + acc[key] * 0.01
        avg_acc = sum(acc.values()) / 5

        pbar.set_description(
            'Epoch: {}; Loss: {:.5f}; Acc: {:.5f}; Moving Loss: {:.5f}; LR: {:.6f}'.format(
                epoch + 1,
                loss.detach().item(),
                avg_acc,
                moving_loss,
                optimizer.param_groups[0]['lr'],
            )
        )
        return moving_loss, acc

In [216]:
def valid(epoch):
    test_dataloader = DataLoader(test_dataset, batch_size, shuffle=True)
    pbar = tqdm(iter(test_dataloader))
    
    net.eval()
    num_correct = {'num': 0, 'pos': 0, 'type': 0, 'size': 0, 'color': 0}
    acc = num_correct.copy()
    num_total = 0

    with torch.no_grad():
        for i, data_dict in enumerate(pbar):
            x = data_dict['features'].cuda()
            y = data_dict['labels']
            
            out = net(x)
            acc = {'num': 0, 'pos': 0, 'type': 0, 'size': 0, 'color': 0}
            for key in num_correct:
                correct = out[key].data.cpu().numpy().argmax(1) == y['key'].data.cpu().numpy()
                num_correct[key] = correct.sum()
            
            num_total += batch_size
    
    for key in num_correct:
        acc[key] = num_correct[key] / num_total

    print('Avg Acc: {:.5f}'.format(sum(num_correct.values()) / 5))
    
    return acc

In [217]:
net = RelationNet()
if torch.cuda.is_available():
    net = net.cuda()

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = StepLR(optimizer, step_size=lr_step, gamma=lr_gamma)

train_acc_list = []
test_acc_list = []

for epoch in range(n_epoch):
    if scheduler.get_lr()[0] < lr_max:
        scheduler.step()

    _, train_acc = train(epoch)
    test_acc = valid(epoch)
    
    train_acc_list.append(train_acc)
    test_acc_list.append(test_acc)

  0%|          | 0/26 [00:00<?, ?it/s]


AssertionError: Torch not compiled with CUDA enabled

# SVM and XGBoost Experiments

In [5]:
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split
label_columns = ['number', 'position', 'type', 'size', 'color']

panels_df, rules_df = extract_stage2_ground_truth('dataset', 'train')
final_df = prepare_stage2_dataset(panels_df, rules_df, merge_row=False)

X, Y = final_df[final_df.columns[:-5]], final_df[final_df.columns[-5:]]
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.33, random_state=42)

enc = OneHotEncoder(handle_unknown='ignore')
enc.fit(X)
X_train_encoded = enc.transform(X_train).toarray()
X_test_encoded = enc.transform(X_test).toarray()

In [253]:
from sklearn import svm

clf = svm.SVC()
clf.fit(X_train_encoded, y_train['type'].values)
clf.score(X_test_encoded, y_test['type'].values)

0.4866850321395776

In [249]:
label_map = {'Constant': 0, 'Distribute_Three': 1, 'Progression': 2, 'Arithmetic': 3}

In [243]:
from xgboost import XGBClassifier

bst_type = XGBClassifier(n_estimators=20, max_depth=50, learning_rate=1, objective='multi:softprob')
bst_type.fit(X_train_encoded, y_train['type'].map(label_map).values)
bst_type.score(X_test_encoded, y_test['type'].map(label_map).values)

0.8523875114784206

In [251]:
bst_color = XGBClassifier(n_estimators=20, max_depth=100, learning_rate=0.1, objective='multi:softprob')
bst_color.fit(X_train_encoded, y_train['color'].map(label_map).values)
bst_color.score(X_test_encoded, y_test['color'].map(label_map).values)

0.6779155188246098

In [252]:
bst_size = XGBClassifier(n_estimators=20, max_depth=100, learning_rate=0.1, objective='multi:softprob')
bst_size.fit(X_train_encoded, y_train['size'].map(label_map).values)
bst_size.score(X_test_encoded, y_test['size'].map(label_map).values)

0.73989898989899

In [255]:
bst_number = XGBClassifier(n_estimators=20, max_depth=100, learning_rate=0.1, objective='multi:softprob')
bst_number.fit(X_train_encoded, y_train['number'].map(label_map).values)
bst_number.score(X_test_encoded, y_test['number'].map(label_map).values)

0.8675390266299358

In [256]:
bst_position = XGBClassifier(n_estimators=20, max_depth=100, learning_rate=0.1, objective='multi:softprob')
bst_position.fit(X_train_encoded, y_train['position'].map(label_map).values)
bst_position.score(X_test_encoded, y_test['position'].map(label_map).values)

0.8675390266299358