In [2]:
import json
import pandas as pd
import xml.etree.ElementTree as ET
from itertools import product
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import Dataset

from utils import parse_panels, bbox_to_xyxy, parse_rules
from avr_dataset import configurations, id2type, id2size, slot2id, id2slot, panel_dict_to_df, extract_stage2_ground_truth

In [76]:
def extract_stage3_ground_truth(dataset_dir: str, split: str):
    dataset_path = Path(dataset_dir)
    all_file_stems = list(fn.stem for fn in (dataset_path / Path(configurations[0])).glob(f'*_{split}.npz'))
    all_file_paths = [Path(dataset_path, config, base_fn) for config, base_fn in
                      product(configurations, all_file_stems)]

    all_panel_df = []
    full_rule_data = []

    for file_path in all_file_paths:
        xml = ET.parse(file_path.with_suffix('.xml'))
        xml_root = xml.getroot()
        panel_info_list = parse_panels(xml_root)
        component_rules = parse_rules(xml_root)
        context_panels = panel_info_list[6:]

        # Get rules (labels)
        for component in component_rules:
            cid = int(component['component_id'])
            rule_data = {'file_path': str(file_path)}
            for rule in component['rules']:
                if (rule['attr'] == 'Number/Position') or (rule['attr'] == 'Number') or (rule['attr'] == 'Position'):
                    rule_data[f'component{cid}_number'] = rule['name']
                    rule_data[f'component{cid}_position'] = rule['name']
                elif rule['attr'] == 'Type':
                    rule_data[f'component{cid}_type'] = rule['name']
                elif rule['attr'] == 'Size':
                    rule_data[f'component{cid}_size'] = rule['name']
                elif rule['attr'] == 'Color':
                    rule_data[f'component{cid}_color'] = rule['name']
            full_rule_data.append(rule_data)

        # Get discrete panel representations (features)
        panel_df = panel_dict_to_df(range(6, 16), context_panels, str(file_path))
        all_panel_df.append(panel_df)

    return pd.concat(all_panel_df).reset_index(drop=True), pd.DataFrame(full_rule_data)

In [77]:
panels, rules = extract_stage3_ground_truth('dataset', 'train')

In [82]:
def prepare_stage3_dataset(panels_df: pd.DataFrame, rules_df: pd.DataFrame):
    panels_df_copy = panels_df.copy()
    reshaped_indices = ['file', 'component', 'panel']

    reshaped_panels_df = panels_df_copy.set_index(reshaped_indices).unstack(level=-1)
    reshaped_panels_df.columns.names = ['slot_attr', 'panel']
    reshaped_panels_df.columns = reshaped_panels_df.columns.swaplevel(0, 1)
    reshaped_panels_df = reshaped_panels_df.sort_index(axis=1, level=0)

    index_tuples = []
    for panel_idx, slot_idx, attr in list(product(range(6, 16),
                                                  range(0, 22),
                                                  ['color', 'size', 'type'])):
        index_tuples.append((panel_idx, f'slot{slot_idx}_{attr}'))
    multi_index = pd.MultiIndex.from_tuples(index_tuples, names=['panel', 'slot_attr'])
    reshaped_panels_df = pd.DataFrame(reshaped_panels_df, columns=multi_index)

    reshaped_panels_df.columns = reshaped_panels_df.columns.map(lambda x: 'panel' + '_'.join(list(map(str, x))))
    reshaped_panels_df = reshaped_panels_df.groupby('file').max()
    
    # return reshaped_panels_df
    rules_df = rules_df.rename(columns={'file_path': 'file'})
    rules_df = rules_df.set_index(['file'])

    final_df = reshaped_panels_df.join(rules_df)

    return final_df

In [83]:
prepare_stage3_dataset(panels, rules)

Unnamed: 0_level_0,panel6_slot0_color,panel6_slot0_size,panel6_slot0_type,panel6_slot1_color,panel6_slot1_size,panel6_slot1_type,panel6_slot2_color,panel6_slot2_size,panel6_slot2_type,panel6_slot3_color,...,component0_number,component0_position,component0_type,component0_size,component0_color,component1_number,component1_position,component1_type,component1_size,component1_color
file,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
dataset/center_single/RAVEN_0_train,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,Constant,Constant,Constant,Progression,Progression,,,,,
dataset/center_single/RAVEN_100_train,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,Constant,Constant,Constant,Progression,Progression,,,,,
dataset/center_single/RAVEN_101_train,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,Constant,Constant,Progression,Arithmetic,Distribute_Three,,,,,
dataset/center_single/RAVEN_102_train,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,Constant,Constant,Distribute_Three,Arithmetic,Constant,,,,,
dataset/center_single/RAVEN_103_train,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,Constant,Constant,Distribute_Three,Constant,Distribute_Three,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
dataset/up_center_single_down_center_single/RAVEN_993_train,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,,,,,,Constant,Constant,Constant,Arithmetic,Progression
dataset/up_center_single_down_center_single/RAVEN_994_train,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,Constant,Constant,Constant,Arithmetic,Distribute_Three,,,,,
dataset/up_center_single_down_center_single/RAVEN_994_train,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,,,,,,Constant,Constant,Distribute_Three,Arithmetic,Distribute_Three
dataset/up_center_single_down_center_single/RAVEN_995_train,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,Constant,Constant,Constant,Distribute_Three,Arithmetic,,,,,
