In [32]:
import os
import json
import torch
import numpy as np
import pandas as pd
from pprint import pprint
from torch.utils.data import Dataset
from itertools import product
import xml.etree.ElementTree as ET
from pathlib import Path

from utils import parse_rules, parse_panels, bbox_to_xyxy, plot_example
from avr_dataset import (
    extract_stage3_ground_truth,
    prepare_stage3_dataset
)

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from collections import Counter
from tqdm import tqdm
from torch import optim
from torch.optim.lr_scheduler import StepLR

In [120]:
def prepare_stage3_dataset(panels_df: pd.DataFrame, rules_df: pd.DataFrame | None, target_df: pd.DataFrame, all_panels=True):
    panels_df_copy = panels_df.copy()
    reshaped_indices = ['file', 'component', 'panel']

    reshaped_panels_df = panels_df_copy.set_index(reshaped_indices).unstack(level=-1)
    reshaped_panels_df.columns.names = ['slot_attr', 'panel']
    reshaped_panels_df.columns = reshaped_panels_df.columns.swaplevel(0, 1)
    reshaped_panels_df = reshaped_panels_df.sort_index(axis=1, level=0)

    index_tuples = []
    panel_idx_range = range(16) if all_panels else range(6, 16)
    for panel_idx, slot_idx, attr in list(product(panel_idx_range,
                                                  range(0, 22),
                                                  ['color', 'size', 'type'])):
        index_tuples.append((panel_idx, f'slot{slot_idx}_{attr}'))
    multi_index = pd.MultiIndex.from_tuples(index_tuples, names=['panel', 'slot_attr'])
    reshaped_panels_df = pd.DataFrame(reshaped_panels_df, columns=multi_index)

    reshaped_panels_df.columns = reshaped_panels_df.columns.map(lambda x: 'panel' + '_'.join(list(map(str, x))))
    reshaped_panels_df = reshaped_panels_df.groupby('file').max()

    if rules_df is not None:
        rules_df = rules_df.rename(columns={'file_path': 'file'})
        rules_df = rules_df.set_index(['file'])

        final_df = reshaped_panels_df.join(rules_df).join(target_df.set_index(['file']))
    else:
        final_df = reshaped_panels_df.join(target_df.set_index(['file']))

    return final_df

In [4]:
class AVRStage3DatasetV2(Dataset):
    """Map-style dataset serving one wide panel-feature row per item.

    The wide frame is built once in the constructor; its first column is kept
    as the item's `info`, its last column as the `target`, and everything in
    between is treated as flattened per-panel features (22 slots x 3
    attributes per panel).
    """

    def __init__(self, dataset_dir, split):
        super().__init__()
        # The rules frame returned by the extractor is intentionally not used here.
        panels_df, _rules_df, targets_df = extract_stage3_ground_truth(dataset_dir, split, all_panels=True)
        merged = prepare_stage3_dataset(panels_df, None, targets_df, all_panels=True)
        self.final_df = merged.reset_index()

        columns = self.final_df.columns.tolist()
        self.info_col = columns[0]
        self.panel_cols = columns[1:-1]
        self.target_col = columns[-1]
        self.rule2id = {'Constant': 0, 'Distribute_Three': 1, 'Progression': 2, 'Arithmetic': 3, -1: -1}

    def __len__(self):
        """Number of rows (items) in the wide frame."""
        return len(self.final_df)

    def __getitem__(self, idx):
        """Return a dict with the item's `info`, `panels` and `target`.

        `panels` is an int64 tensor with one row of 22 * 3 values per panel.
        """
        row = self.final_df.iloc[idx]
        features_per_panel = 22 * 3

        flat_features = torch.tensor(row[self.panel_cols].values.astype(np.int64))
        # Cut the flat vector into equal per-panel chunks and stack them as rows.
        per_panel = torch.stack(flat_features.split(features_per_panel))

        return {
            'info': row[self.info_col],
            'panels': per_panel,
            'target': torch.tensor(row[self.target_col]),
        }

In [5]:
# Training split wrapped in a shuffling loader that yields batches of 32 items.
dataset = AVRStage3DatasetV2('dataset', 'train')
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

In [6]:
# Pull a single (shuffled) batch from the loader to inspect what it yields.
a = next(iter(dataloader))

In [9]:
# Shape of the batched panel features: (batch, panels per item, features per panel).
a['panels'].shape

torch.Size([32, 16, 66])

In [26]:
class Stage3RelNetV2(nn.Module):
    """Relation-network scorer over context and candidate panel features.

    For every candidate, the context panels plus that candidate form a panel
    set. Every ordered pair of panels in the set (self-pairs included) is
    concatenated and passed through `g`, the pair embeddings are summed, and
    `f` maps the sum to one score. Scores are softmax-normalised across
    candidates.

    Args:
        mlp_hidden: width of every hidden layer in `g` and `f`.
        classes: accepted for backward compatibility; not used by the module.
        n_candidates: number of trailing panels in the input that are
            treated as answer candidates.
        panel_dim: length of one panel's feature vector.
    """

    def __init__(self, mlp_hidden=64, classes=8, n_candidates=8, panel_dim=22 * 3):
        super().__init__()
        self.panel_dim = panel_dim
        # `g` consumes two panel vectors concatenated along the feature axis.
        self.n_concat = panel_dim * 2
        self.mlp_hidden = mlp_hidden
        self.n_candidates = n_candidates

        self.g = nn.Sequential(
            nn.Linear(self.n_concat, mlp_hidden),
            nn.ReLU(),
            nn.Linear(mlp_hidden, mlp_hidden),
            nn.ReLU(),
            nn.Linear(mlp_hidden, mlp_hidden),
            nn.ReLU(),
            nn.Linear(mlp_hidden, mlp_hidden),
            nn.ReLU(),
        )

        # NOTE(review): the trailing Sigmoid bounds every score to (0, 1)
        # before the softmax in forward(), so with 8 candidates no candidate
        # can ever receive more than e / (e + 7) ~= 0.28 probability. Kept
        # unchanged so existing checkpoints produce the same numbers; consider
        # removing it if sharper predictions are needed.
        self.f = nn.Sequential(
            nn.Linear(mlp_hidden, mlp_hidden),
            nn.ReLU(),
            nn.Linear(mlp_hidden, mlp_hidden),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(mlp_hidden, 1),
            nn.Sigmoid(),
        )

    def forward(self, X):
        """Score every candidate against the context panels.

        Args:
            X: tensor of shape (n_batch, n_context + n_candidates, panel_dim);
                the last `n_candidates` panels along dim 1 are the
                candidates. Any dtype is accepted; it is cast to float.

        Returns:
            Tensor of shape (n_batch, n_candidates) on the same device as
            `X`; each row is a probability distribution over candidates.
        """
        X = X.float()
        n_cand = self.n_candidates
        contexts, candidates = X[:, :-n_cand, :], X[:, -n_cand:, :]

        # Shape: (n_batch, n_cand, n_context, panel_dim) - the same contexts for every candidate.
        contexts = contexts.unsqueeze(1).expand(-1, n_cand, -1, -1)
        # Shape: (n_batch, n_cand, 1, panel_dim)
        candidates = candidates.unsqueeze(2)
        # Shape: (n_batch, n_cand, n_panels = n_context + 1, panel_dim)
        panels = torch.cat([contexts, candidates], dim=2)
        n_panels = panels.size(2)

        # Ordered pairs (i, j) of panels within each candidate's panel set.
        first = panels.unsqueeze(3).expand(-1, -1, -1, n_panels, -1)   # [..., i, j, :] = panel i
        second = panels.unsqueeze(2).expand(-1, -1, n_panels, -1, -1)  # [..., i, j, :] = panel j
        # Shape: (n_batch, n_cand, n_panels, n_panels, 2 * panel_dim)
        pairs = torch.cat([first, second], dim=-1)

        # nn.Linear acts on the last dimension, so every pair is embedded in a
        # single batched call. All intermediates derive from X and therefore
        # stay on X's device.
        # Shape: (n_batch, n_cand, mlp_hidden)
        relation_sum = self.g(pairs).sum(dim=(2, 3))

        # Shape: (n_batch, n_cand)
        scores = self.f(relation_sum).squeeze(-1)
        return F.softmax(scores, dim=1)

In [27]:
# Instantiate the relation network with its default hyperparameters.
net = Stage3RelNetV2()

In [28]:
# Sanity check: width of two concatenated panel vectors, i.e. the input size of Stage3RelNetV2.g.
22*3*2

132

In [29]:
# Smoke-test the forward pass on the sampled batch; each output row is a softmax over the candidates.
net(a['panels'])

tensor([[0.1283, 0.1067, 0.1372, 0.1329, 0.1176, 0.1115, 0.1213, 0.1445],
        [0.1285, 0.1254, 0.1320, 0.1455, 0.1317, 0.1128, 0.1120, 0.1121],
        [0.1239, 0.1147, 0.1382, 0.1114, 0.1468, 0.1285, 0.1426, 0.0939],
        [0.1227, 0.1054, 0.1425, 0.1382, 0.1191, 0.1250, 0.1124, 0.1347],
        [0.1374, 0.1364, 0.1200, 0.1252, 0.1369, 0.1096, 0.1231, 0.1115],
        [0.1424, 0.1350, 0.1179, 0.1104, 0.1289, 0.1218, 0.1143, 0.1293],
        [0.1435, 0.1004, 0.1057, 0.1375, 0.1321, 0.1402, 0.1183, 0.1224],
        [0.1296, 0.1505, 0.1064, 0.1268, 0.1194, 0.1291, 0.1166, 0.1216],
        [0.1237, 0.1185, 0.1488, 0.1213, 0.1263, 0.1250, 0.1149, 0.1215],
        [0.1241, 0.1378, 0.1187, 0.1526, 0.1283, 0.1255, 0.1137, 0.0993],
        [0.1244, 0.1152, 0.1387, 0.1349, 0.1004, 0.1243, 0.1183, 0.1437],
        [0.1035, 0.1363, 0.1464, 0.1240, 0.1198, 0.1206, 0.1330, 0.1164],
        [0.1061, 0.1210, 0.1296, 0.1356, 0.1331, 0.1234, 0.1396, 0.1116],
        [0.1019, 0.1041, 0.1487, 0.101

In [37]:
# Scratch check: DataFrame.replace with a dict swaps matching values (-1 -> 15) wherever they occur.
pd.DataFrame({'a': [-1, -2], 'b': [-3, -4]}).replace({-1: 15})

Unnamed: 0,a,b
0,15,-3
1,-2,-4
