In [6]:
import os
import json
import torch
import numpy as np
import pandas as pd
from pprint import pprint
from torch.utils.data import Dataset
from itertools import product
import xml.etree.ElementTree as ET
from pathlib import Path

from utils import parse_rules, parse_panels, bbox_to_xyxy

In [9]:
def get_dataset(dataset_dir: str, split: str):
    """Print the parsed annotations of the first puzzle file of a split (inspection helper).

    Despite its name this does not build a dataset yet. It locates the puzzle
    files for ``split`` and, for the first one, prints the number of entries
    returned by ``parse_panels``, the value returned by ``parse_rules``, and
    the full panel list.

    Args:
        dataset_dir: Root directory holding one sub-directory per configuration.
        split: Split suffix used in the file names (``*_{split}.npz``), e.g. ``'train'``.

    Returns:
        None. All output is printed. Nothing is printed when no file matches.
    """
    configurations = [
        'center_single',
        'distribute_four',
        'distribute_nine',
        'in_center_single_out_center_single',
        'in_distribute_four_out_center_single',
        'left_center_single_right_center_single',
        'up_center_single_down_center_single'
    ]

    dataset_path = Path(dataset_dir)
    # Stems are discovered in the first configuration only and assumed to exist in
    # every other configuration directory. Sorted because glob order is arbitrary,
    # which would otherwise make "the first file" filesystem-dependent.
    all_file_stems = sorted(fn.stem for fn in (dataset_path / configurations[0]).glob(f'*_{split}.npz'))
    all_file_paths = [Path(dataset_path, config, stem) for config, stem in product(configurations, all_file_stems)]
    if not all_file_paths:
        return

    file_path = all_file_paths[0]
    xml_root = ET.parse(file_path.with_suffix('.xml')).getroot()
    panel_info_list = parse_panels(xml_root)
    rules = parse_rules(xml_root)
    print(len(panel_info_list))
    print(rules)
    pprint(panel_info_list)

# Inspect the annotations of the first training puzzle found under ./dataset.
get_dataset(dataset_dir='dataset', split='train')

16
[{'component_id': '0', 'rules': [{'name': 'Constant', 'attr': 'Number/Position'}, {'name': 'Distribute_Three', 'attr': 'Type'}, {'name': 'Progression', 'attr': 'Size'}, {'name': 'Arithmetic', 'attr': 'Color'}]}]
[{'components': [{'component': {'id': '0', 'name': 'Grid'},
                  'entities': [{'Angle': '1',
                                'Color': '0',
                                'Size': '4',
                                'Type': '3',
                                'bbox': '[0.5, 0.5, 1, 1]',
                                'mask': '[3139,3,3296,7,3453,10,3610,14,3767,18,3924,22,4081,25,4238,29,4395,33,4552,37,4709,40,4866,44,5023,48,5180,52,5337,55,5493,60,5650,64,5807,68,5964,71,6121,75,6278,79,6435,83,6592,86,6750,89,6909,91,7069,92,7229,92,7389,93,7549,94,7709,95,7869,95,8029,96,8189,97,8349,98,8509,98,8669,99,8829,100,8989,101,9149,101,9309,102,9469,103,9629,104,9789,104,9949,105,10109,106,10269,107,10429,107,10589,108,10749,109,10909,110,11069,110,11229,111,113

In [None]:
class AVRStage1Dataset(Dataset):
    """Panel-level dataset built from AVR puzzle files.

    Each puzzle is a pair of files, ``<config>/<stem>.npz`` and ``<config>/<stem>.xml``:

    - The ``.npz`` holds an ``'image'`` array indexed as ``[panel, H, W]``.
    - The ``.xml`` holds annotations read with ``parse_panels``.

    Every file contributes ``PANELS_PER_FILE`` samples. Sample ``idx`` is panel
    ``idx % PANELS_PER_FILE`` of file ``idx // PANELS_PER_FILE``.

    Args:
        dataset_dir: Root directory with one sub-directory per configuration.
        split: One of ``'train'``, ``'val'``, ``'test'``; selects ``*_{split}.npz`` files.
        transform: Optional callable applied to the ``(1, H, W)`` float image tensor.
        target_transform: Optional callable applied to the targets dict.
    """

    # Number of panels per puzzle file; __getitem__ indexes npz['image'][panel_idx].
    PANELS_PER_FILE = 16

    def __init__(
            self,
            dataset_dir: str,
            split: str = 'train',
            transform = None,
            target_transform = None):
        assert split in ['train', 'val', 'test'], f"split must be 'train', 'val' or 'test', got {split!r}"

        self.split = split
        self.configurations = [
            'center_single',
            'distribute_four',
            'distribute_nine',
            'in_center_single_out_center_single',
            'in_distribute_four_out_center_single',
            'left_center_single_right_center_single',
            'up_center_single_down_center_single'
        ]
        self.dataset_path = Path(dataset_dir)
        self.transform = transform
        self.target_transform = target_transform
        # Stems are discovered in the first configuration only and assumed to exist
        # in every configuration. Sorted so the idx -> file mapping does not depend
        # on the (arbitrary) order in which the filesystem lists entries.
        self.all_file_stems = sorted(
            fn.stem for fn in (self.dataset_path / self.configurations[0]).glob(f'*_{self.split}.npz'))
        # Configuration-major order: all stems of configurations[0] first, and so on.
        self.all_file_paths = [Path(self.dataset_path, config, stem)
                               for config, stem in product(self.configurations, self.all_file_stems)]

        # Lookup tables, presumably indexed by the integer Type/Size/Color ids that
        # __getitem__ returns (not used inside this class) — TODO confirm.
        self.id2type = ['none', 'triangle', 'square', 'pentagon', 'hexagon', 'circle']
        self.id2size = [0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
        self.id2color = [255, 224, 196, 168, 140, 112, 84, 56, 28, 0]

    def get_panel_by_id(self, idx: int):
        """Map a flat sample index to ``(file_path, panel_idx)``.

        Uses floor division, so negative indices count back from the last panel
        of the last file, matching Python sequence semantics.
        """
        file_idx, panel_idx = divmod(int(idx), self.PANELS_PER_FILE)
        return self.all_file_paths[file_idx], panel_idx

    def __len__(self):
        """Total number of panels: files times ``PANELS_PER_FILE``."""
        return len(self.all_file_paths) * self.PANELS_PER_FILE

    def __getitem__(self, idx):
        """Load one panel image and its per-entity targets.

        Returns:
            ``(image, targets)``. ``image`` is a ``(1, H, W)`` float tensor divided
            by 255. ``targets`` is a dict with these keys:

            - ``'boxes'``: float ``(N, 4)`` tensor.
            - ``'types'``, ``'labels'``, ``'sizes'``, ``'colors'``: int64 ``(N,)`` tensors.
            - ``'image_id'``: tensor holding ``idx``.

            The optional ``transform`` / ``target_transform`` are applied last.
        """
        file_path, panel_idx = self.get_panel_by_id(idx)

        # NpzFile keeps the archive open; the context manager closes it instead of
        # leaving the handle to the garbage collector.
        with np.load(file_path.with_suffix('.npz')) as npz:
            panel = npz['image'][panel_idx, :, :]
        # Scale by 1/255 (assumes 8-bit pixel values — TODO confirm) and add a channel axis.
        image = torch.unsqueeze(torch.as_tensor(panel, dtype=torch.float) / 255, dim=0)

        xml_root = ET.parse(file_path.with_suffix('.xml')).getroot()
        panel_info = parse_panels(xml_root)[panel_idx]
        # Flatten the entities of all components of this panel.
        entities = [entity
                    for component in panel_info['components']
                    for entity in component['entities']]

        boxes = [bbox_to_xyxy(json.loads(entity['real_bbox'])) for entity in entities]
        types = [int(entity['Type']) for entity in entities]
        sizes = [int(entity['Size']) for entity in entities]
        colors = [int(entity['Color']) for entity in entities]

        targets = {
            # reshape keeps boxes as (N, 4) even for a panel with no entities
            # (assumes bbox_to_xyxy returns 4 coordinates per box).
            'boxes': torch.as_tensor(boxes, dtype=torch.float).reshape(-1, 4),
            'types': torch.as_tensor(types, dtype=torch.int64),
            # Duplicate of 'types', presumably for detection models that read 'labels'.
            'labels': torch.as_tensor(types, dtype=torch.int64),
            'sizes': torch.as_tensor(sizes, dtype=torch.int64),
            'colors': torch.as_tensor(colors, dtype=torch.int64),
            'image_id': torch.tensor(idx),
        }

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            targets = self.target_transform(targets)
        return image, targets