In [9]:
import json
import numpy as np
import pandas as pd
import xml.etree.ElementTree as ET
from itertools import product
from pathlib import Path
from pprint import pprint
import torch
from torch.utils.data import dataloader, Dataset

from avr_dataset import panel_dict_to_df, slot2id_distribute9, prepare_stage3_dataset
from utils import parse_panels,parse_rules

In [2]:
# Notebook settings: which split to load and which RAVEN configuration directories to read.
split = 'train'
configurations = ['distribute_nine']

In [3]:
# Files of the configured split, for inspection; stems are discovered in the first configuration.
dataset_path = Path('dataset')
# Sorted so the file order is reproducible; raw glob order depends on the filesystem.
all_file_stems = sorted(fn.stem for fn in (dataset_path / configurations[0]).glob(f'*_{split}.npz'))
all_file_paths = [dataset_path / config / stem
                  for config, stem in product(configurations, all_file_stems)]

In [4]:
def extract_stage3_ground_truth(dataset_dir: str, split: str, all_panels=True,
                                configurations=('distribute_nine',)):
    """Load panel features, rule labels and answer targets for every puzzle of a split.

    Stems are discovered as ``<dataset_dir>/<configurations[0]>/*_<split>.npz``; for each
    stem and each configuration, ``<stem>.xml`` (panels and rules) and ``<stem>.npz``
    (answer index) are read.

    Args:
        dataset_dir: Root directory holding one sub-directory per configuration.
        split: Split name used in the file suffix, e.g. ``'train'``.
        all_panels: If True keep all 16 panels; otherwise keep panels 6..15 only.
        configurations: Configuration sub-directories to read. Defaults to
            ``('distribute_nine',)``, the notebook's ``configurations`` setting.

    Returns:
        Tuple ``(panels_df, rules_df, targets_df)``:
        ``panels_df`` - the per-file ``panel_dict_to_df`` frames concatenated with a fresh index;
        ``rules_df`` - one row per file (``file_path`` plus ``component<i>_<attr>`` rule names);
        ``targets_df`` - one row per file with ``file`` and ``target`` columns.

    Raises:
        ValueError: if no ``*_<split>.npz`` file is found in the first configuration directory.
    """
    dataset_path = Path(dataset_dir)
    reference_dir = dataset_path / configurations[0]
    # Sorted so the row order is reproducible; raw glob order depends on the filesystem.
    all_file_stems = sorted(fn.stem for fn in reference_dir.glob(f'*_{split}.npz'))
    if not all_file_stems:
        # Fail with a clear message instead of pd.concat's "No objects to concatenate".
        raise ValueError(f"no '*_{split}.npz' files found in {reference_dir}")
    # Stems come from the first configuration and are assumed to exist in every other one.
    all_file_paths = [dataset_path / config / stem
                      for config, stem in product(configurations, all_file_stems)]

    # Rule attributes whose rule name is recorded under both the number and position columns.
    number_position_attrs = {'Number/Position', 'Number', 'Position'}
    # Remaining rule attributes and the column suffix each one is stored under.
    attr_suffix = {'Type': 'type', 'Size': 'size', 'Color': 'color'}

    all_panel_df = []
    full_rule_data = []
    full_target_data = []

    for file_path in all_file_paths:
        xml_root = ET.parse(file_path.with_suffix('.xml')).getroot()
        # NpzFile keeps its archive open; the context manager closes it once the target is read.
        with np.load(file_path.with_suffix('.npz')) as npz:
            target = npz['target'].item()
        panel_info_list = parse_panels(xml_root)
        component_rules = parse_rules(xml_root)

        full_target_data.append({'file': str(file_path), 'target': target})

        # Get rules (labels): one column per component and attribute.
        rule_data = {'file_path': str(file_path)}
        for component in component_rules:
            cid = int(component['component_id'])
            for rule in component['rules']:
                attr = rule['attr']
                if attr in number_position_attrs:
                    rule_data[f'component{cid}_number'] = rule['name']
                    rule_data[f'component{cid}_position'] = rule['name']
                elif attr in attr_suffix:
                    rule_data[f'component{cid}_{attr_suffix[attr]}'] = rule['name']
        full_rule_data.append(rule_data)

        # Get discrete panel representations (features); panels 0-5 are dropped unless all_panels.
        panels = panel_info_list if all_panels else panel_info_list[6:]
        panel_idx_range = range(16) if all_panels else range(6, 16)
        all_panel_df.append(panel_dict_to_df(panel_idx_range, panels, str(file_path),
                                             slot2id=slot2id_distribute9))

    return (pd.concat(all_panel_df).reset_index(drop=True),
            pd.DataFrame(full_rule_data),
            pd.DataFrame(full_target_data))

In [10]:
# Extract per-file panels/rules/targets for the train split, then flatten them into one wide table.
panels, rules, targets = extract_stage3_ground_truth(dataset_dir='dataset', split='train')
dataset_df = prepare_stage3_dataset(panels, None, targets, num_slots=9)

In [11]:
# Wide table indexed by file: panel<p>_slot<s>_{color,size,type} feature columns (-1 in empty slots) plus target.
dataset_df

Unnamed: 0_level_0,panel0_slot0_color,panel0_slot0_size,panel0_slot0_type,panel0_slot1_color,panel0_slot1_size,panel0_slot1_type,panel0_slot2_color,panel0_slot2_size,panel0_slot2_type,panel0_slot3_color,...,panel15_slot6_color,panel15_slot6_size,panel15_slot6_type,panel15_slot7_color,panel15_slot7_size,panel15_slot7_type,panel15_slot8_color,panel15_slot8_size,panel15_slot8_type,target
file,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
dataset/distribute_nine/RAVEN_0_train,3,1,2,-1,-1,-1,3,1,2,-1,...,-1,-1,-1,-1,-1,-1,3,0,4,0
dataset/distribute_nine/RAVEN_100_train,-1,-1,-1,-1,-1,-1,9,4,1,-1,...,-1,-1,-1,-1,-1,-1,-1,-1,-1,6
dataset/distribute_nine/RAVEN_101_train,7,2,1,-1,-1,-1,-1,-1,-1,7,...,0,4,5,0,5,5,0,5,5,7
dataset/distribute_nine/RAVEN_102_train,-1,-1,-1,-1,-1,-1,-1,-1,-1,9,...,9,2,2,-1,-1,-1,9,3,5,5
dataset/distribute_nine/RAVEN_103_train,-1,-1,-1,5,1,5,5,1,5,-1,...,1,3,5,1,3,5,1,3,5,7
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
dataset/distribute_nine/RAVEN_991_train,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,6,2,1,6,2,1,6,5,1,6
dataset/distribute_nine/RAVEN_992_train,-1,-1,-1,-1,-1,-1,-1,-1,-1,6,...,-1,-1,-1,5,4,2,-1,-1,-1,0
dataset/distribute_nine/RAVEN_993_train,5,4,3,-1,-1,-1,5,4,3,5,...,-1,-1,-1,7,4,2,-1,-1,-1,4
dataset/distribute_nine/RAVEN_994_train,-1,-1,-1,9,1,5,9,1,2,0,...,5,5,3,8,0,2,-1,-1,-1,4


In [16]:
class AVRDiagnosticDataset(Dataset):
    """Torch dataset of flattened RAVEN puzzles: per-panel slot features plus the answer index.

    Each item is a dict with:
        ``'info'``   - the source file path of the puzzle;
        ``'panels'`` - an int64 tensor of shape ``(num_panels, NUM_SLOTS * FEATURES_PER_SLOT)``
                       (``(16, 27)`` with all panels), -1 entries replaced by ``PAD_VALUE``;
        ``'target'`` - the answer index as a tensor.
    """

    # Slots per panel in the distribute_nine layout.
    NUM_SLOTS = 9
    # Features per slot: color, size, type (the panel*_slot*_{color,size,type} columns).
    FEATURES_PER_SLOT = 3
    # Replacement for -1 entries (seen in empty-slot triples); presumably chosen so every
    # feature is non-negative, e.g. for embedding lookups - TODO confirm.
    PAD_VALUE = 12

    def __init__(self, dataset_dir, split):
        super().__init__()
        panels_df, _rules_df, targets_df = extract_stage3_ground_truth(dataset_dir, split, all_panels=True)
        wide_df = prepare_stage3_dataset(panels_df, None, targets_df,
                                         num_slots=self.NUM_SLOTS, all_panels=True)
        self.final_df = self._prepare_frame(wide_df)

        columns = self.final_df.columns.tolist()
        # Column layout after reset_index: [file, panel feature columns..., target].
        self.info_col = columns[0]
        self.panel_cols = columns[1:-1]
        self.target_col = columns[-1]
        # Rule name -> integer id; not used inside this class.
        self.rule2id = {'Constant': 0, 'Distribute_Three': 1, 'Progression': 2, 'Arithmetic': 3, -1: -1}

    @classmethod
    def _prepare_frame(cls, wide_df):
        """Return ``wide_df`` with its index moved into a column and every -1 replaced by PAD_VALUE."""
        # Must be a dict: a set such as {-1, 12} is not a mapping and does not substitute -1 with 12.
        return wide_df.reset_index().replace({-1: cls.PAD_VALUE})

    def __len__(self):
        return len(self.final_df)

    def __getitem__(self, idx):
        row = self.final_df.iloc[idx]
        # The row is object-dtype (mixed with the file column), so cast features explicitly.
        flat_features = torch.tensor(row[self.panel_cols].to_numpy(dtype=np.int64))
        # Columns are panel-major (panel0_slot0_color, panel0_slot0_size, ...), so each row
        # of the reshaped tensor holds one panel's NUM_SLOTS * FEATURES_PER_SLOT values.
        panel_features = flat_features.reshape(-1, self.NUM_SLOTS * self.FEATURES_PER_SLOT)

        return {
            'info': row[self.info_col],
            'panels': panel_features,
            'target': torch.tensor(row[self.target_col]),
        }

In [17]:
# Build the train-split dataset; construction parses every XML/NPZ file under dataset/.
dataset = AVRDiagnosticDataset(dataset_dir='dataset', split='train')

In [19]:
# Sanity check: (num_panels, slots * features) for the first puzzle.
dataset[0]['panels'].size()

torch.Size([16, 27])