In [1]:
import os
import json
import torch
import numpy as np
import pandas as pd
from pprint import pprint
from torch.utils.data import Dataset
from itertools import product
import xml.etree.ElementTree as ET
from pathlib import Path

from utils import parse_rules, parse_panels, bbox_to_xyxy, plot_example
from avr_dataset import (
    extract_stage3_ground_truth,
    #prepare_stage3_dataset
)

import torch
from torch import nn
from torch.utils.data import DataLoader

from collections import Counter
from tqdm import tqdm
from torch import optim
from torch.optim.lr_scheduler import StepLR

In [2]:
def prepare_stage3_dataset(panels_df: pd.DataFrame, rules_df: pd.DataFrame | None, target_df: pd.DataFrame, all_panels=True):
    panels_df_copy = panels_df.copy()
    reshaped_indices = ['file', 'component', 'panel']

    reshaped_panels_df = panels_df_copy.set_index(reshaped_indices).unstack(level=-1)
    reshaped_panels_df.columns.names = ['slot_attr', 'panel']
    reshaped_panels_df.columns = reshaped_panels_df.columns.swaplevel(0, 1)
    reshaped_panels_df = reshaped_panels_df.sort_index(axis=1, level=0)

    index_tuples = []
    panel_idx_range = range(16) if all_panels else range(6, 16)
    for panel_idx, slot_idx, attr in list(product(panel_idx_range,
                                                  range(0, 22),
                                                  ['color', 'size', 'type'])):
        index_tuples.append((panel_idx, f'slot{slot_idx}_{attr}'))
    multi_index = pd.MultiIndex.from_tuples(index_tuples, names=['panel', 'slot_attr'])
    reshaped_panels_df = pd.DataFrame(reshaped_panels_df, columns=multi_index)

    reshaped_panels_df.columns = reshaped_panels_df.columns.map(lambda x: 'panel' + '_'.join(list(map(str, x))))
    reshaped_panels_df = reshaped_panels_df.groupby('file').max()

    # if rules_df is not None:
    #     rules_df = rules_df.rename(columns={'file_path': 'file'})
    #     rules_df = rules_df.set_index(['file'])

    #     final_df = reshaped_panels_df.join(rules_df).join(target_df.set_index(['file']))
    # else:
    #     final_df = reshaped_panels_df.join(target_df.set_index(['file']))

    # return final_df

In [3]:
# Ground-truth tables (panels, rules, targets) that prepare_stage3_dataset consumes below.
panels, rules, targets = extract_stage3_ground_truth(
    'dataset',  # dataset root directory
    'train',    # split name
)

In [8]:
# Wide per-file frame over all 16 panels (all_panels defaults to True); displayed as the cell output.
stage3_df = prepare_stage3_dataset(panels, rules, targets)
stage3_df

Unnamed: 0_level_0,panel0_slot0_color,panel0_slot0_size,panel0_slot0_type,panel0_slot1_color,panel0_slot1_size,panel0_slot1_type,panel0_slot2_color,panel0_slot2_size,panel0_slot2_type,panel0_slot3_color,...,panel15_slot18_type,panel15_slot19_color,panel15_slot19_size,panel15_slot19_type,panel15_slot20_color,panel15_slot20_size,panel15_slot20_type,panel15_slot21_color,panel15_slot21_size,panel15_slot21_type
file,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
dataset/center_single/RAVEN_0_train,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
dataset/center_single/RAVEN_100_train,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
dataset/center_single/RAVEN_101_train,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
dataset/center_single/RAVEN_102_train,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
dataset/center_single/RAVEN_103_train,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
dataset/up_center_single_down_center_single/RAVEN_991_train,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,-1,5,5,1,-1,-1,-1,-1,-1,-1
dataset/up_center_single_down_center_single/RAVEN_992_train,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,-1,4,2,1,-1,-1,-1,-1,-1,-1
dataset/up_center_single_down_center_single/RAVEN_993_train,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,-1,1,2,5,-1,-1,-1,-1,-1,-1
dataset/up_center_single_down_center_single/RAVEN_994_train,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,-1,1,5,2,-1,-1,-1,-1,-1,-1


In [None]:
class AVRStage3DatasetV2(Dataset):
    """Stage-3 AVR dataset restricted to panels 6-15, one item per puzzle file.

    After ``reset_index`` the backing frame is expected to be laid out as
    ``[file | panel feature columns | 10 rule columns | target]``; the positional
    column slices in ``__init__`` rely on that order.
    """

    def __init__(self, dataset_dir, split):
        super().__init__()
        panels_df, rules_df, targets_df = extract_stage3_ground_truth(dataset_dir, split, all_panels=False)
        wide_df = prepare_stage3_dataset(panels_df, rules_df, targets_df, all_panels=False)
        self.final_df = wide_df.reset_index()

        # Positional column split; assumes exactly 10 rule columns followed by one target column.
        column_names = self.final_df.columns.tolist()
        self.info_col = column_names[0]
        self.panel_cols = column_names[1:-11]
        self.rule_cols = column_names[-11:-1]
        self.target_col = column_names[-1]

        # Rule name -> integer id; -1 marks a missing rule (NaN is replaced by -1 before mapping).
        self.rule2id = {'Constant': 0, 'Distribute_Three': 1, 'Progression': 2, 'Arithmetic': 3, -1: -1}

    def __len__(self):
        """Number of rows (puzzle files) in the backing frame."""
        return self.final_df.shape[0]

    def __getitem__(self, idx):
        """Return one puzzle as a dict of tensors.

        Keys:
            'info': value of the first (file) column.
            'panels': int64 tensor of shape (len(panel_cols) // 66, 22, 3).
                Axes are panel, slot, and attribute triple.
            'rules': dict mapping each rule column to a scalar tensor of its rule id.
                NOTE(review): names missing from rule2id map to NaN.
            'target': tensor built from the target column.
        """
        row = self.final_df.iloc[idx]
        file_info = row[self.info_col]

        # Flat panel vector -> 66-value chunks (one per panel) -> 22 triples per chunk.
        flat_panels = torch.tensor(row[self.panel_cols].values.astype(np.int64))
        panel_features = torch.stack([
            torch.stack(torch.split(panel_chunk, 3))
            for panel_chunk in torch.split(flat_panels, 22 * 3)
        ])

        rule_ids = row[self.rule_cols].replace(np.nan, -1).map(self.rule2id)
        rule_tensors = {rule_name: torch.tensor(rule_id) for rule_name, rule_id in rule_ids.to_dict().items()}

        return {
            'info': file_info,
            'panels': panel_features,
            'rules': rule_tensors,
            'target': torch.tensor(row[self.target_col]),
        }
