In [129]:
import json
import pandas as pd
import xml.etree.ElementTree as ET
from itertools import product
from pathlib import Path

import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

from utils import parse_panels, bbox_to_xyxy, parse_rules, plot_example
from avr_dataset import configurations, id2type, id2size, slot2id, id2slot, panel_dict_to_df, extract_stage2_ground_truth

In [101]:
def extract_stage3_ground_truth(dataset_dir: str, split: str):
    dataset_path = Path(dataset_dir)
    all_file_stems = list(fn.stem for fn in (dataset_path / Path(configurations[0])).glob(f'*_{split}.npz'))
    all_file_paths = [Path(dataset_path, config, base_fn) for config, base_fn in
                      product(configurations, all_file_stems)]

    all_panel_df = []
    full_rule_data = []
    full_target_data = []

    for file_path in all_file_paths:
        xml = ET.parse(file_path.with_suffix('.xml'))
        npz = np.load(file_path.with_suffix('.npz'))
        xml_root = xml.getroot()
        panel_info_list = parse_panels(xml_root)
        component_rules = parse_rules(xml_root)
        context_panels = panel_info_list[6:]
        
        full_target_data.append({'file': str(file_path), 'target': npz['target'].item()})

        # Get rules (labels)
        for component in component_rules:
            cid = int(component['component_id'])
            rule_data = {'file_path': str(file_path)}
            for rule in component['rules']:
                if (rule['attr'] == 'Number/Position') or (rule['attr'] == 'Number') or (rule['attr'] == 'Position'):
                    rule_data[f'component{cid}_number'] = rule['name']
                    rule_data[f'component{cid}_position'] = rule['name']
                elif rule['attr'] == 'Type':
                    rule_data[f'component{cid}_type'] = rule['name']
                elif rule['attr'] == 'Size':
                    rule_data[f'component{cid}_size'] = rule['name']
                elif rule['attr'] == 'Color':
                    rule_data[f'component{cid}_color'] = rule['name']
            full_rule_data.append(rule_data)

        # Get discrete panel representations (features)
        panel_df = panel_dict_to_df(range(6, 16), context_panels, str(file_path))
        all_panel_df.append(panel_df)

    return (pd.concat(all_panel_df).reset_index(drop=True),
            pd.DataFrame(full_rule_data),
            pd.DataFrame(full_target_data))

In [102]:
panels, rules, targets = extract_stage3_ground_truth('dataset', 'train')

In [103]:
def prepare_stage3_dataset(panels_df: pd.DataFrame, rules_df: pd.DataFrame, target_df: pd.DataFrame):
    panels_df_copy = panels_df.copy()
    reshaped_indices = ['file', 'component', 'panel']

    reshaped_panels_df = panels_df_copy.set_index(reshaped_indices).unstack(level=-1)
    reshaped_panels_df.columns.names = ['slot_attr', 'panel']
    reshaped_panels_df.columns = reshaped_panels_df.columns.swaplevel(0, 1)
    reshaped_panels_df = reshaped_panels_df.sort_index(axis=1, level=0)

    index_tuples = []
    for panel_idx, slot_idx, attr in list(product(range(6, 16),
                                                  range(0, 22),
                                                  ['color', 'size', 'type'])):
        index_tuples.append((panel_idx, f'slot{slot_idx}_{attr}'))
    multi_index = pd.MultiIndex.from_tuples(index_tuples, names=['panel', 'slot_attr'])
    reshaped_panels_df = pd.DataFrame(reshaped_panels_df, columns=multi_index)

    reshaped_panels_df.columns = reshaped_panels_df.columns.map(lambda x: 'panel' + '_'.join(list(map(str, x))))
    reshaped_panels_df = reshaped_panels_df.groupby('file').max()
    
    # return reshaped_panels_df
    rules_df = rules_df.rename(columns={'file_path': 'file'})
    rules_df = rules_df.set_index(['file'])

    final_df = reshaped_panels_df.join(rules_df).join(target_df.set_index(['file']))

    return final_df

In [115]:
final = prepare_stage3_dataset(panels, rules, targets)
final = final.reset_index()
final

Unnamed: 0,file,panel6_slot0_color,panel6_slot0_size,panel6_slot0_type,panel6_slot1_color,panel6_slot1_size,panel6_slot1_type,panel6_slot2_color,panel6_slot2_size,panel6_slot2_type,...,component0_position,component0_type,component0_size,component0_color,component1_number,component1_position,component1_type,component1_size,component1_color,target
0,dataset/center_single/RAVEN_0_train,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,Constant,Constant,Progression,Progression,,,,,,6
1,dataset/center_single/RAVEN_100_train,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,Constant,Constant,Progression,Progression,,,,,,6
2,dataset/center_single/RAVEN_101_train,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,Constant,Progression,Arithmetic,Distribute_Three,,,,,,1
3,dataset/center_single/RAVEN_102_train,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,Constant,Distribute_Three,Arithmetic,Constant,,,,,,2
4,dataset/center_single/RAVEN_103_train,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,Constant,Distribute_Three,Constant,Distribute_Three,,,,,,4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
6595,dataset/up_center_single_down_center_single/RA...,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,,,,,Constant,Constant,Constant,Arithmetic,Progression,3
6596,dataset/up_center_single_down_center_single/RA...,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,Constant,Constant,Arithmetic,Distribute_Three,,,,,,1
6597,dataset/up_center_single_down_center_single/RA...,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,,,,,Constant,Constant,Distribute_Three,Arithmetic,Distribute_Three,1
6598,dataset/up_center_single_down_center_single/RA...,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,Constant,Constant,Distribute_Three,Arithmetic,,,,,,7


In [180]:
class AVRStage3Dataset(Dataset):
    def __init__(self, dataset_dir, split):
        super().__init__()
        panels_df, rules_df, targets_df = extract_stage3_ground_truth(dataset_dir, split)
        self.final_df = prepare_stage3_dataset(panels_df, rules_df, targets_df)
        self.final_df = self.final_df.reset_index()
        self.info_col = self.final_df.columns.tolist()[0]
        self.panel_cols = self.final_df.columns.tolist()[1:-11]
        self.rule_cols = self.final_df.columns.tolist()[-11:-1]
        self.target_col = self.final_df.columns.tolist()[-1]
        self.rule2id = {'Constant': 0, 'Distribute_Three': 1, 'Progression': 2, 'Arithmetic': 3, -1: -1}
    
    def __len__(self):
        return len(self.final_df)
    
    def __getitem__(self, idx):
        data = self.final_df.iloc[idx]

        info = data[self.info_col]

        panels = torch.split(torch.tensor(data[self.panel_cols].values.astype(np.int64)), 22 * 3)
        reshaped_panels = list(torch.stack(torch.split(p, 3)) for p in panels)
        panel_features = torch.stack(reshaped_panels)
        
        rules = data[self.rule_cols].replace(np.nan, -1).map(self.rule2id).to_dict()
        for key, val, in rules.items():
            rules[key] = torch.tensor(val)

        return {
            'info': info,
            'panels': panel_features,
            'rules': rules,
            'target': data[self.target_col]
        }

In [181]:
dataset = AVRStage3Dataset('dataset', 'train')

In [230]:
class Stage3RelNet(nn.Module):
    def __init__(self, mlp_hidden, classes=8):
        super().__init__()
        self.n_concat = 3 * 3 + 10
        self.g = nn.Sequential(
            nn.Linear(self.n_concat, mlp_hidden),
            nn.ReLU(),
            nn.Linear(mlp_hidden, mlp_hidden),
            nn.ReLU(),
            nn.Linear(mlp_hidden, mlp_hidden),
            nn.ReLU(),
            nn.Linear(mlp_hidden, mlp_hidden),
            nn.ReLU(),
        )

        self.f = nn.Sequential(
            nn.Linear(mlp_hidden, mlp_hidden),
            nn.ReLU(),
            nn.Linear(mlp_hidden, mlp_hidden),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(mlp_hidden, 1),
        )
        
    def forward(self, x_panels, x_rules):
        context1, context2 = x_panels[:, 0, :, :], x_panels[:, 1, :, :]
        candidate1, candidate2, candidate3, candidate4, candidate5, candidate6, candidate7, candidate8 = x_panels[:, 2, :, :], x_panels[:, 3, :, :], x_panels[:, 4, :, :], x_panels[:, 5, :, :], x_panels[:, 6, :, :], x_panels[:, 7, :, :], x_panels[:, 8, :, :], x_panels[:, 9, :, :]
        
        context1_repeated = torch.repeat_interleave(context1, 22*22, dim=1)
        context2_repeated = torch.repeat_interleave(context2, 22, dim=1).repeat(1, 22, 1)
        
        c1_repeated = candidate1.repeat(1, 22*22, 1)
        c2_repeated = candidate2.repeat(1, 22*22, 1)
        c3_repeated = candidate3.repeat(1, 22*22, 1)
        c4_repeated = candidate4.repeat(1, 22*22, 1)
        c5_repeated = candidate5.repeat(1, 22*22, 1)
        c6_repeated = candidate6.repeat(1, 22*22, 1)
        c7_repeated = candidate7.repeat(1, 22*22, 1)
        c8_repeated = candidate8.repeat(1, 22*22, 1)
        
        batch_rules = torch.stack(list(x_rules.values())).T
        rules_repeated = batch_rules.unsqueeze(dim=1).repeat(1, 22*22*22, 1)
        
        final_c1 = torch.cat([context1_repeated, context2_repeated, c1_repeated, rules_repeated], dim=2).float()
        final_c2 = torch.cat([context1_repeated, context2_repeated, c2_repeated, rules_repeated], dim=2).float()
        final_c3 = torch.cat([context1_repeated, context2_repeated, c3_repeated, rules_repeated], dim=2).float()
        final_c4 = torch.cat([context1_repeated, context2_repeated, c4_repeated, rules_repeated], dim=2).float()
        final_c5 = torch.cat([context1_repeated, context2_repeated, c5_repeated, rules_repeated], dim=2).float()
        final_c6 = torch.cat([context1_repeated, context2_repeated, c6_repeated, rules_repeated], dim=2).float()
        final_c7 = torch.cat([context1_repeated, context2_repeated, c7_repeated, rules_repeated], dim=2).float()
        final_c8 = torch.cat([context1_repeated, context2_repeated, c8_repeated, rules_repeated], dim=2).float()
        
        g_c1 = self.g(final_c1)
        g_c2 = self.g(final_c2)
        g_c3 = self.g(final_c3)
        g_c4 = self.g(final_c4)
        g_c5 = self.g(final_c5)
        g_c6 = self.g(final_c6)
        g_c7 = self.g(final_c7)
        g_c8 = self.g(final_c8)
        
        g_c1_sum = g_c1.sum(1).squeeze()
        
        f_c1 = self.f(g_c1_sum)
        
        return f_c1

In [None]:
class RelationNet(nn.Module):
    def __init__(
        self,
        mlp_hidden=64,
        classes=4,
    ):
        super().__init__()

        self.n_concat = 9

        self.g = nn.Sequential(
            nn.Linear(self.n_concat, mlp_hidden),
            nn.ReLU(),
            nn.Linear(mlp_hidden, mlp_hidden),
            nn.ReLU(),
            nn.Linear(mlp_hidden, mlp_hidden),
            nn.ReLU(),
            nn.Linear(mlp_hidden, mlp_hidden),
            nn.ReLU(),
        )

        def get_head():
            return nn.Sequential(
            nn.Linear(mlp_hidden, mlp_hidden),
            nn.ReLU(),
            nn.Linear(mlp_hidden, mlp_hidden),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(mlp_hidden, classes),
        )

        self.mlp_hidden = mlp_hidden

        self.f_num = get_head()
        self.f_pos = get_head()
        self.f_type = get_head()
        self.f_size = get_head()
        self.f_color = get_head()


    def forward(self, x):
        x1, x2, x3, x4, x5, x6 = x[:, 0, :, :], x[:, 1, :, :], x[:, 2, :, :], x[:, 3, :, :], x[:, 4, :, :], x[:, 5, :, :]

        # xi.shape is [nbatch, nrow, ncol]
        x1_repeated = torch.repeat_interleave(x1, 22*22, dim=1)
        x2_repeated = torch.repeat_interleave(x2, 22, dim=1).repeat(1, 22, 1)
        x3_repeated = x3.repeat(1, 22*22, 1)
        x_row1 = torch.cat([x1_repeated, x2_repeated, x3_repeated], dim=2).float()

        x4_repeated = torch.repeat_interleave(x4, 22*22, dim=1)
        x5_repeated = torch.repeat_interleave(x5, 22, dim=1).repeat(1, 22, 1)
        x6_repeated = x6.repeat(1, 22*22, 1)
        x_row2 = torch.cat([x4_repeated, x5_repeated, x6_repeated], dim=2).float()

        g_row1 = self.g(x_row1)
        g_row2 = self.g(x_row2)
        g_row1 = g_row1.view(-1, 22*22*22, self.mlp_hidden).sum(1).squeeze()
        g_row2 = g_row2.view(-1, 22*22*22, self.mlp_hidden).sum(1).squeeze()

        row1_num = self.f_num(g_row1)
        row1_pos = self.f_pos(g_row1)
        row1_type = self.f_type(g_row1)
        row1_size = self.f_size(g_row1)
        row1_color = self.f_color(g_row1)

        row2_num = self.f_num(g_row2)
        row2_pos = self.f_pos(g_row2)
        row2_type = self.f_type(g_row2)
        row2_size = self.f_size(g_row2)
        row2_color = self.f_color(g_row2)

        predictions = {
            'number': (row1_num + row2_num) / 2,
            'position': (row1_pos + row2_pos) / 2,
            'type': (row1_type + row2_type) / 2,
            'size': (row1_size + row2_size) / 2,
            'color': (row1_color + row2_color) / 2,
        }

        return predictions

In [231]:
rn = Stage3RelNet(64)

In [232]:
z = rn(a['panels'], a['rules'])

In [233]:
z.shape

torch.Size([16, 1])

In [151]:
dataset = AVRStage3Dataset('dataset', 'train')

In [182]:
train_loader = DataLoader(dataset, 16, shuffle=True)

In [183]:
a = next(iter(train_loader))

In [184]:
a['rules']

{'component0_number': tensor([ 0,  0,  0,  0,  0, -1,  0,  0,  3,  0, -1,  3, -1, -1,  0,  1]),
 'component0_position': tensor([ 0,  0,  0,  0,  0, -1,  0,  0,  3,  0, -1,  3, -1, -1,  0,  1]),
 'component0_type': tensor([ 2,  1,  0,  2,  0, -1,  2,  0,  0,  0, -1,  2, -1, -1,  2,  0]),
 'component0_size': tensor([ 3,  2,  0,  1,  1, -1,  0,  1,  1,  1, -1,  2, -1, -1,  2,  0]),
 'component0_color': tensor([ 3,  3,  1,  1,  0, -1,  3,  3,  2,  0, -1,  1, -1, -1,  3,  2]),
 'component1_number': tensor([-1, -1, -1, -1, -1,  0, -1, -1, -1, -1,  0, -1,  3,  3, -1, -1]),
 'component1_position': tensor([-1, -1, -1, -1, -1,  0, -1, -1, -1, -1,  0, -1,  3,  3, -1, -1]),
 'component1_type': tensor([-1, -1, -1, -1, -1,  2, -1, -1, -1, -1,  2, -1,  1,  0, -1, -1]),
 'component1_size': tensor([-1, -1, -1, -1, -1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1, -1]),
 'component1_color': tensor([-1, -1, -1, -1, -1,  3, -1, -1, -1, -1,  0, -1,  2,  3, -1, -1])}

In [179]:
a

{'info': ['dataset/distribute_nine/RAVEN_154_train',
  'dataset/distribute_four/RAVEN_651_train',
  'dataset/distribute_four/RAVEN_630_train',
  'dataset/in_distribute_four_out_center_single/RAVEN_182_train',
  'dataset/distribute_four/RAVEN_140_train',
  'dataset/in_center_single_out_center_single/RAVEN_192_train',
  'dataset/in_distribute_four_out_center_single/RAVEN_302_train',
  'dataset/in_distribute_four_out_center_single/RAVEN_522_train',
  'dataset/distribute_four/RAVEN_110_train',
  'dataset/in_distribute_four_out_center_single/RAVEN_220_train',
  'dataset/in_distribute_four_out_center_single/RAVEN_513_train',
  'dataset/left_center_single_right_center_single/RAVEN_911_train',
  'dataset/up_center_single_down_center_single/RAVEN_974_train',
  'dataset/left_center_single_right_center_single/RAVEN_44_train',
  'dataset/in_distribute_four_out_center_single/RAVEN_355_train',
  'dataset/up_center_single_down_center_single/RAVEN_101_train'],
 'panels': tensor([[[[ 9,  0,  3],
      

In [164]:
b = torch.stack(list(a['rules'].values())).T

In [165]:
b.shape

torch.Size([16, 10])

In [170]:
b.unsqueeze(dim=1).repeat(1, 22*22*22, 1).shape

torch.Size([16, 10648, 10])

In [168]:
22*22*22

10648