In [94]:
import os
import json
import torch
import numpy as np
import pandas as pd
from pprint import pprint
from torch.utils.data import Dataset
from itertools import product
import xml.etree.ElementTree as ET
from pathlib import Path

from utils import parse_rules, parse_panels, bbox_to_xyxy, plot_example
from avr_dataset import (
    extract_stage3_ground_truth,
    #prepare_stage3_dataset
)

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from collections import Counter
from tqdm import tqdm
from torch import optim
from torch.optim.lr_scheduler import StepLR

In [9]:
def prepare_stage3_dataset(panels_df: pd.DataFrame, rules_df: pd.DataFrame | None, target_df: pd.DataFrame, all_panels=True):
    panels_df_copy = panels_df.copy()
    reshaped_indices = ['file', 'component', 'panel']

    reshaped_panels_df = panels_df_copy.set_index(reshaped_indices).unstack(level=-1)
    reshaped_panels_df.columns.names = ['slot_attr', 'panel']
    reshaped_panels_df.columns = reshaped_panels_df.columns.swaplevel(0, 1)
    reshaped_panels_df = reshaped_panels_df.sort_index(axis=1, level=0)

    index_tuples = []
    panel_idx_range = range(16) if all_panels else range(6, 16)
    for panel_idx, slot_idx, attr in list(product(panel_idx_range,
                                                  range(0, 22),
                                                  ['color', 'size', 'type'])):
        index_tuples.append((panel_idx, f'slot{slot_idx}_{attr}'))
    multi_index = pd.MultiIndex.from_tuples(index_tuples, names=['panel', 'slot_attr'])
    reshaped_panels_df = pd.DataFrame(reshaped_panels_df, columns=multi_index)

    reshaped_panels_df.columns = reshaped_panels_df.columns.map(lambda x: 'panel' + '_'.join(list(map(str, x))))
    reshaped_panels_df = reshaped_panels_df.groupby('file').max()

    if rules_df is not None:
        rules_df = rules_df.rename(columns={'file_path': 'file'})
        rules_df = rules_df.set_index(['file'])

        final_df = reshaped_panels_df.join(rules_df).join(target_df.set_index(['file']))
    else:
        final_df = reshaped_panels_df.join(target_df.set_index(['file']))

    return final_df

In [3]:
# Load the stage-3 ground truth for the 'train' split under ./dataset. The three returned frames
# (panel slot attributes, per-file rule annotations, per-file answers) are passed straight to
# prepare_stage3_dataset in the next cell.
panels, rules, targets = extract_stage3_ground_truth('dataset', 'train')

In [10]:
# Preview the wide per-file table: panel slot features, then rule columns, then the target.
prepare_stage3_dataset(panels_df=panels, rules_df=rules, target_df=targets)

Unnamed: 0_level_0,panel0_slot0_color,panel0_slot0_size,panel0_slot0_type,panel0_slot1_color,panel0_slot1_size,panel0_slot1_type,panel0_slot2_color,panel0_slot2_size,panel0_slot2_type,panel0_slot3_color,...,component0_position,component0_type,component0_size,component0_color,component1_number,component1_position,component1_type,component1_size,component1_color,target
file,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
dataset/center_single/RAVEN_0_train,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,Constant,Constant,Progression,Progression,,,,,,6
dataset/center_single/RAVEN_100_train,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,Constant,Constant,Progression,Progression,,,,,,6
dataset/center_single/RAVEN_101_train,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,Constant,Progression,Arithmetic,Distribute_Three,,,,,,1
dataset/center_single/RAVEN_102_train,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,Constant,Distribute_Three,Arithmetic,Constant,,,,,,2
dataset/center_single/RAVEN_103_train,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,Constant,Distribute_Three,Constant,Distribute_Three,,,,,,4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
dataset/up_center_single_down_center_single/RAVEN_991_train,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,Constant,Progression,Progression,Constant,Constant,Constant,Progression,Constant,Progression,5
dataset/up_center_single_down_center_single/RAVEN_992_train,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,Constant,Constant,Constant,Progression,Constant,Constant,Progression,Distribute_Three,Arithmetic,6
dataset/up_center_single_down_center_single/RAVEN_993_train,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,Constant,Distribute_Three,Arithmetic,Distribute_Three,Constant,Constant,Constant,Arithmetic,Progression,3
dataset/up_center_single_down_center_single/RAVEN_994_train,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,...,Constant,Constant,Arithmetic,Distribute_Three,Constant,Constant,Distribute_Three,Arithmetic,Distribute_Three,1


In [11]:
class AVRStage3DatasetV2(Dataset):
    """Stage-3 AVR dataset: one item per puzzle file, with all panels as integer slot features.

    Each item is a dict with:
      * ``'info'``   -- value of the first column of the prepared frame (the ``file`` id);
      * ``'panels'`` -- int64 tensor of shape ``(n_panels, panel_dim)``, one row per panel;
      * ``'target'`` -- tensor built from the last column of the prepared frame.

    The prepared frame is converted to NumPy arrays once, at construction. ``__getitem__`` then
    only indexes arrays instead of materialising a mixed-dtype pandas row for every sample.
    """

    # Features per panel: 22 slots x 3 attributes (color, size, type), matching the column
    # layout produced by prepare_stage3_dataset.
    panel_dim = 22 * 3

    def __init__(self, dataset_dir, split):
        super().__init__()
        panels_df, _rules_df, targets_df = extract_stage3_ground_truth(dataset_dir, split, all_panels=True)
        final_df = prepare_stage3_dataset(panels_df, None, targets_df, all_panels=True)
        self._load_frame(final_df.reset_index())
        # Not used by this class; kept as a public attribute.
        self.rule2id = {'Constant': 0, 'Distribute_Three': 1, 'Progression': 2, 'Arithmetic': 3, -1: -1}

    def _load_frame(self, final_df):
        """Adopt a frame laid out as [info, panel features..., target] and cache it as arrays.

        Raises:
            ValueError: if the panel columns do not split evenly into ``panel_dim``-wide panels,
                or if they contain missing values (which have no int64 representation).
        """
        self.final_df = final_df
        columns = final_df.columns.tolist()
        self.info_col = columns[0]
        self.panel_cols = columns[1:-1]
        self.target_col = columns[-1]

        n_panels, remainder = divmod(len(self.panel_cols), self.panel_dim)
        if remainder:
            raise ValueError(
                f'{len(self.panel_cols)} panel columns is not a multiple of panel_dim={self.panel_dim}'
            )

        panel_frame = final_df[self.panel_cols]
        # Fail fast here rather than mid-epoch: NaN (e.g. a panel/slot column that
        # prepare_stage3_dataset had to create) cannot be cast to int64.
        if panel_frame.isna().to_numpy().any():
            raise ValueError('panel feature columns contain missing values')

        self._info = final_df[self.info_col].to_numpy()
        self._panels = panel_frame.to_numpy(dtype=np.int64).reshape(len(final_df), n_panels, self.panel_dim)
        self._targets = final_df[self.target_col].to_numpy()

    def __len__(self):
        return len(self.final_df)

    def __getitem__(self, idx):
        # torch.tensor copies, so callers can never mutate the cached arrays.
        return {
            'info': self._info[idx],
            'panels': torch.tensor(self._panels[idx]),
            'target': torch.tensor(self._targets[idx]),
        }

In [27]:
# Build the training dataset and a shuffled loader. The explicit, seeded generator makes the
# batch order (and any batch pulled for inspection) identical on every Restart & Run All.
BATCH_SIZE = 32
SHUFFLE_SEED = 0
dataset = AVRStage3DatasetV2('dataset', 'train')
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True,
                        generator=torch.Generator().manual_seed(SHUFFLE_SEED))

In [28]:
# Pull one batch from the shuffled loader for a quick sanity check. With the DataLoader's default
# collate this is a dict with the same keys as AVRStage3DatasetV2.__getitem__ ('info', 'panels',
# 'target'), batched along a new leading dimension.
a = next(iter(dataloader))

In [115]:
class Stage3RelNetV2(nn.Module):
    """Relation-network scorer for stage-3 answer selection.

    For each candidate, the context panels plus that candidate form a set of ``n_context + 1``
    panels. Every ordered pair of panels in the set (self-pairs included) is concatenated and
    passed through the pair MLP ``g``. The results are summed, the sum is scored by ``f`` (a
    sigmoid output), and the per-candidate scores are softmax-normalised across candidates.

    Args:
        mlp_hidden: width of the hidden layers of ``g`` and ``f``.
        classes: not used; kept for signature compatibility.
        n_candidates: number of candidate panels at the end of each input sequence.
        panel_dim: feature length of a single panel (default 22 slots x 3 attributes).
    """

    def __init__(self, mlp_hidden=64, classes=8, n_candidates=8, panel_dim=22 * 3):
        super().__init__()
        self.panel_dim = panel_dim
        self.n_concat = panel_dim * 2
        self.mlp_hidden = mlp_hidden
        self.n_candidates = n_candidates

        self.g = nn.Sequential(
            nn.Linear(self.n_concat, mlp_hidden),
            nn.ReLU(),
            nn.Linear(mlp_hidden, mlp_hidden),
            nn.ReLU(),
            nn.Linear(mlp_hidden, mlp_hidden),
            nn.ReLU(),
            nn.Linear(mlp_hidden, mlp_hidden),
            nn.ReLU(),
        )

        self.f = nn.Sequential(
            nn.Linear(mlp_hidden, mlp_hidden),
            nn.ReLU(),
            nn.Linear(mlp_hidden, mlp_hidden),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(mlp_hidden, 1),
            nn.Sigmoid(),
        )

    def forward(self, X):
        """Score every candidate panel.

        Args:
            X: tensor of shape ``(n_batch, n_context + n_candidates, panel_dim)``, with the
                context panels first and the candidates last. Integer inputs (as yielded by the
                dataset) are cast to the parameter dtype.

        Returns:
            Tensor of shape ``(n_batch, n_candidates)`` whose rows sum to 1.

        Raises:
            ValueError: if ``X`` is not 3-D or has no context panels.
        """
        if X.dim() != 3 or X.size(1) <= self.n_candidates:
            raise ValueError(
                f'expected X of shape (n_batch, n_context + {self.n_candidates}, panel_dim) '
                f'with n_context >= 1, got {tuple(X.shape)}'
            )
        # nn.Linear rejects integer tensors, so match the weights' dtype.
        X = X.to(dtype=self.g[0].weight.dtype)

        n_context = X.size(1) - self.n_candidates
        contexts, candidates = X[:, :n_context, :], X[:, n_context:, :]

        # Shape: (n_batch, n_candidates, n_context + 1, panel_dim). All candidates share the same
        # contexts; each candidate is appended as the last panel of its own set.
        panel_sets = torch.cat(
            [contexts.unsqueeze(1).expand(-1, self.n_candidates, -1, -1), candidates.unsqueeze(2)],
            dim=2,
        )
        n_set = panel_sets.size(2)

        # The ordered pair (p, q) sits at index p * n_set + q along dim 2.
        left = panel_sets.repeat_interleave(n_set, dim=2)
        right = panel_sets.repeat(1, 1, n_set, 1)
        # Shape: (n_batch, n_candidates, n_set ** 2, 2 * panel_dim)
        pairs = torch.cat([left, right], dim=3)

        # nn.Linear acts on the last dim, so g handles every pair of every candidate in a single
        # call. Building the result this way also keeps it on the input's device.
        relation = self.g(pairs).sum(dim=2)      # (n_batch, n_candidates, mlp_hidden)
        scores = self.f(relation).squeeze(-1)    # (n_batch, n_candidates), each in (0, 1)

        # NOTE(review): softmax over sigmoid outputs confines the logits to (0, 1), so the largest
        # achievable probability ratio between two candidates is e. Returning `scores` (or dropping
        # the Sigmoid) may train better. Kept unchanged for callers that expect probabilities.
        return F.softmax(scores, dim=1)

In [116]:
# Instantiate the relation network with its default sizes spelled out explicitly.
net = Stage3RelNetV2(mlp_hidden=64, classes=8, n_candidates=8)