In [2]:
import json
import pandas as pd
import xml.etree.ElementTree as ET
from itertools import product
from pathlib import Path

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from utils import parse_panels, bbox_to_xyxy, parse_rules, plot_example
from avr_dataset import (
    configurations, id2type, id2size, slot2id, id2slot,
    panel_dict_to_df, extract_stage2_ground_truth,
    extract_stage3_ground_truth, prepare_stage3_dataset,
    AVRStage3Dataset
)

In [3]:
class Stage3RelNet(nn.Module):
    """Relation-network style scorer that ranks candidate panels.

    For every candidate panel, each (context-1 object, context-2 object,
    candidate object) triple is concatenated with the sample's rule vector,
    embedded by ``g``, summed over all triples and mapped by ``f`` to a single
    scalar score.  The scores of all candidates are normalised with
    log-softmax.

    Args:
        mlp_hidden: Width of every hidden layer in ``g`` and ``f``.
        classes: Number of candidate panels to score.  Panels
            ``2 .. 2 + classes - 1`` of ``x_panels`` are treated as
            candidates.  The default (8) matches the original hard-coded
            behaviour.
    """

    def __init__(self, mlp_hidden, classes=8):
        super().__init__()
        self.classes = classes
        # Input width of ``g``: three object vectors of 3 attributes each plus
        # 10 rule values.  ``forward`` must therefore receive objects with 3
        # attributes and an ``x_rules`` mapping with 10 entries.
        self.n_concat = 3 * 3 + 10
        self.g = nn.Sequential(
            nn.Linear(self.n_concat, mlp_hidden),
            nn.ReLU(),
            nn.Linear(mlp_hidden, mlp_hidden),
            nn.ReLU(),
            nn.Linear(mlp_hidden, mlp_hidden),
            nn.ReLU(),
            nn.Linear(mlp_hidden, mlp_hidden),
            nn.ReLU(),
        )

        self.f = nn.Sequential(
            nn.Linear(mlp_hidden, mlp_hidden),
            nn.ReLU(),
            nn.Linear(mlp_hidden, mlp_hidden),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(mlp_hidden, 1),
        )

    def forward(self, x_panels, x_rules):
        """Score every candidate panel against the two context panels.

        Args:
            x_panels: Tensor indexed as ``(batch, panel, object, attribute)``.
                Panels 0 and 1 are the context panels; the following
                ``classes`` panels are the candidates.  The number of object
                slots is read from the tensor (22 in the notebook's data).
            x_rules: Mapping whose values are stacked into a
                ``(batch, n_rules)`` matrix; each value is expected to be a
                1-D tensor with one entry per sample.

        Returns:
            Tensor of shape ``(batch, classes)`` holding log-probabilities
            over the candidates.
        """
        n_obj = x_panels.shape[2]
        n_pairs = n_obj * n_obj
        n_triples = n_pairs * n_obj

        context1 = x_panels[:, 0]
        context2 = x_panels[:, 1]
        candidates = x_panels[:, 2:2 + self.classes]

        # Row ``i * n_obj**2 + j * n_obj + k`` of the expanded tensors pairs
        # context1[i] with context2[j] and candidate[k], so that all
        # n_obj**3 combinations are enumerated exactly once.
        context1_rep = torch.repeat_interleave(context1, n_pairs, dim=1)
        context2_rep = torch.repeat_interleave(context2, n_obj, dim=1).repeat(1, n_obj, 1)

        # (n_rules, batch) -> (batch, n_rules), then one copy per triple.
        batch_rules = torch.stack(list(x_rules.values())).T
        rules_rep = batch_rules.unsqueeze(dim=1).repeat(1, n_triples, 1)

        scores = []
        for idx in range(candidates.shape[1]):
            candidate_rep = candidates[:, idx].repeat(1, n_pairs, 1)
            triples = torch.cat(
                [context1_rep, context2_rep, candidate_rep, rules_rep], dim=2
            ).float()
            # ``sum(dim=1)`` already yields (batch, mlp_hidden).  No squeeze
            # here: squeezing would drop the batch axis for a batch of one
            # and break the concatenation below.
            relation_sum = self.g(triples).sum(dim=1)
            scores.append(self.f(relation_sum))

        logits = torch.cat(scores, dim=1)

        return F.log_softmax(logits, dim=1)

In [4]:
# Training split of the stage-3 data, served in shuffled mini-batches of 16.
dataset = AVRStage3Dataset('dataset', 'train')
train_loader = DataLoader(dataset, batch_size=16, shuffle=True)

In [5]:
# Smoke test: push one training batch through a freshly initialised network.
rn = Stage3RelNet(mlp_hidden=64)
a = next(iter(train_loader))
z = rn(x_panels=a['panels'], x_rules=a['rules'])

In [6]:
# Sanity check: the network returns one log-probability per candidate for each sample in the batch.
z.shape

torch.Size([16, 8])

In [7]:
# Look at a single raw (un-batched) sample to see how the dataset encodes it.
dataset[0]

{'info': 'dataset/center_single/RAVEN_0_train',
 'panels': tensor([[[-1, -1, -1],
          [-1, -1, -1],
          [-1, -1, -1],
          [-1, -1, -1],
          [ 5,  0,  4],
          [-1, -1, -1],
          [-1, -1, -1],
          [-1, -1, -1],
          [-1, -1, -1],
          [-1, -1, -1],
          [-1, -1, -1],
          [-1, -1, -1],
          [-1, -1, -1],
          [-1, -1, -1],
          [-1, -1, -1],
          [-1, -1, -1],
          [-1, -1, -1],
          [-1, -1, -1],
          [-1, -1, -1],
          [-1, -1, -1],
          [-1, -1, -1],
          [-1, -1, -1]],
 
         [[-1, -1, -1],
          [-1, -1, -1],
          [-1, -1, -1],
          [-1, -1, -1],
          [ 4,  2,  4],
          [-1, -1, -1],
          [-1, -1, -1],
          [-1, -1, -1],
          [-1, -1, -1],
          [-1, -1, -1],
          [-1, -1, -1],
          [-1, -1, -1],
          [-1, -1, -1],
          [-1, -1, -1],
          [-1, -1, -1],
          [-1, -1, -1],
          [-1, -1, -1],
   

In [10]:
# Load the stage-3 ground truth of the training split as three separate objects.
panels, rules, targets = extract_stage3_ground_truth('dataset', 'train')

In [12]:
# Distinct values of the 'target' column (presumably the index of the correct
# candidate panel — confirm against avr_dataset).
targets['target'].unique()

array([4, 0, 7, 1, 6, 3, 2, 5])

In [None]:
# Scratch cell: a bare name only displays the function object, it does not call it.
prepare_stage3_dataset