In [16]:
import json
import numpy as np
import pandas as pd
import xml.etree.ElementTree as ET
from itertools import product
from pathlib import Path
from pprint import pprint
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import dataloader, Dataset

from avr_dataset import panel_dict_to_df, slot2id_distribute9, prepare_stage3_dataset
from utils import parse_panels,parse_rules

In [2]:
# Notebook-level settings.
# `split` is not read by any of the cells visible below (they pass 'train' literally).
split = 'train'
# Configuration sub-directory names; extract_ground_truth() reads this module-level list.
configurations = ['distribute_nine']

In [13]:
def extract_ground_truth(dataset_dir: str, split: str, all_panels=True, configurations=None):
    """Load every ``*_{split}.npz`` sample of the selected configurations into memory.

    File stems are discovered in the first configuration's directory only and are then
    looked up under every configuration, so each configuration directory is expected to
    contain the same file names.

    Args:
        dataset_dir: Root directory holding one sub-directory per configuration.
        split: Split tag used in the file-name pattern ``*_{split}.npz``.
        all_panels: Unused; kept so existing positional/keyword callers keep working.
        configurations: Names of the configuration sub-directories to read. When omitted,
            the module-level ``configurations`` variable is used (the historical behaviour).

    Returns:
        A list with one dict per sample:
        ``'file'`` (sample path as a string, without the ``.npz`` suffix),
        ``'images'`` (the archive's ``image`` array) and
        ``'target'`` (the archive's ``target`` entry converted with ``.item()``).

    Raises:
        ValueError: If no configuration names are available.
    """
    config_names = configurations
    if config_names is None:
        # Backward compatibility: earlier callers relied on the notebook-level global.
        config_names = globals().get('configurations')
    config_names = list(config_names) if config_names is not None else []
    if not config_names:
        raise ValueError('no configurations given and no module-level `configurations` is defined')

    dataset_path = Path(dataset_dir)
    # Sorted so the sample order is reproducible; glob() order is filesystem dependent.
    file_stems = sorted(fn.stem for fn in (dataset_path / config_names[0]).glob(f'*_{split}.npz'))

    full_data = []
    for config, stem in product(config_names, file_stems):
        file_path = dataset_path / config / stem
        # Append the suffix to the name instead of using with_suffix(), which would
        # replace everything after the last dot of a stem such as "RAVEN.v1_train".
        npz_path = file_path.with_name(f'{stem}.npz')
        # The context manager closes the archive once the arrays have been read.
        with np.load(npz_path) as npz:
            full_data.append({'file': str(file_path),
                              'images': npz['image'],
                              'target': npz['target'].item()})

    return full_data

In [18]:
class AVRCNNDataset(Dataset):
    """In-memory dataset of ``*_{split}.npz`` samples for the chosen configurations.

    Every sample is loaded eagerly in ``__init__`` and stored in ``self.data`` as a dict
    with keys ``'file'``, ``'images'`` and ``'target'``.

    Args:
        dataset_dir: Root directory holding one sub-directory per configuration.
        split: Split tag used in the file-name pattern ``*_{split}.npz``.
        configurations: Configuration sub-directory names. When omitted or empty, every
            sub-directory of ``dataset_dir`` is used.

    Raises:
        ValueError: If no configuration can be determined.
    """

    def __init__(self, dataset_dir, split, configurations=None):
        self.dataset_path = Path(dataset_dir)
        if configurations:
            self.configurations = list(configurations)
        else:
            # Fall back to every sub-directory of the dataset root (sorted for a stable order).
            self.configurations = sorted(p.name for p in self.dataset_path.iterdir() if p.is_dir())
        if not self.configurations:
            raise ValueError(f'no configurations found under {self.dataset_path}')

        # Stems are discovered in the first configuration only and reused for the others.
        first_config_dir = self.dataset_path / self.configurations[0]
        self.all_file_stems = sorted(fn.stem for fn in first_config_dir.glob(f'*_{split}.npz'))
        self.all_file_paths = [Path(self.dataset_path, config, base_fn) for config, base_fn in
                               product(self.configurations, self.all_file_stems)]

        # Load from this instance's own paths so the `configurations` argument is honoured.
        self.data = [self._load_sample(file_path) for file_path in self.all_file_paths]

    @staticmethod
    def _load_sample(file_path):
        """Read one ``.npz`` archive and return its ``{'file', 'images', 'target'}`` record."""
        # Append the suffix to the name; with_suffix() would truncate stems containing a dot.
        npz_path = file_path.with_name(file_path.name + '.npz')
        with np.load(npz_path) as npz:
            return {'file': str(file_path),
                    'images': npz['image'],
                    'target': npz['target'].item()}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

In [19]:
# Build the training dataset for the single 'distribute_nine' configuration.
train_dataset = AVRCNNDataset(
    dataset_dir='dataset',
    split='train',
    configurations=['distribute_nine'],
)

In [22]:
# Shape of the first sample's 'images' array; the recorded output below is (16, 160, 160).
train_dataset[0]['images'].shape

(16, 160, 160)

In [113]:
class CNN(nn.Module):
    """Panel encoder: four strided conv stages followed by a linear projection.

    Every stage is ``Conv2d(kernel 3, stride 2) -> BatchNorm2d -> ReLU`` with 32 output
    channels. The projection expects ``32 * 9 * 9`` flattened features, which is what a
    single-channel 160x160 input yields after the four stages.

    Args:
        out_dim: Size of the embedding produced per input image.
    """

    @staticmethod
    def _stage(in_channels, out_channels=32):
        """Build one conv -> batch-norm -> ReLU stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, stride=2),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def __init__(self, out_dim=256):
        super().__init__()
        # Attribute names and creation order are kept so state-dict keys are unchanged.
        self.conv1 = self._stage(1)
        self.conv2 = self._stage(32)
        self.conv3 = self._stage(32)
        self.conv4 = self._stage(32)
        self.fc = nn.Linear(32 * 9 * 9, out_dim)

    def forward(self, X):
        """Map ``X`` of shape (batch, 1, H, W) to embeddings of shape (batch, out_dim)."""
        features = X
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            features = stage(features)
        flat = torch.flatten(features, start_dim=1)
        return self.fc(flat)

class RelNet(nn.Module):
    """Relation network that scores 8 candidate panels against 8 context panels.

    For each candidate, the 8 context embeddings plus that candidate form a set of 9
    panels. Every ordered pair of panels (81 pairs) is concatenated and passed through
    ``g``; the pair outputs are summed and ``f`` maps the sum to a single score.

    Args:
        input_dim: Size of one concatenated pair, i.e. twice the panel embedding size.
        g_hidden: Hidden width of the pair network ``g``.
        f_hidden: Hidden width of the scoring network ``f``.
    """

    def __init__(self, input_dim=512, g_hidden=512, f_hidden=256):
        super().__init__()
        self.g = nn.Sequential(
            nn.Linear(input_dim, g_hidden),
            nn.ReLU(),
            nn.Linear(g_hidden, g_hidden),
            nn.ReLU(),
            nn.Linear(g_hidden, g_hidden),
            nn.ReLU(),
            nn.Linear(g_hidden, g_hidden),
            nn.ReLU(),
        )

        self.f = nn.Sequential(
            nn.Linear(g_hidden, f_hidden),
            nn.ReLU(),
            nn.Linear(f_hidden, f_hidden),
            nn.ReLU(),
            nn.Linear(f_hidden, 13),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(13, 1)
        )

    def forward(self, X, device=None):
        """Return softmax-normalised candidate scores.

        Args:
            X: Panel embeddings of shape (batch, 16, embedding_dim); the first 8 panels
                are the context, the remaining 8 are the candidates.
            device: Optional device for the returned tensor. When omitted the result
                stays on the device the computation ran on, so GPU inputs are no longer
                copied to the CPU candidate by candidate.

        Returns:
            Float tensor of shape (batch, 8); each row sums to 1.
        """
        contexts, candidates = X[:, :8, :], X[:, 8:, :]

        per_candidate = []
        for i in range(8):
            candidate_i = candidates[:, i, :].unsqueeze(dim=1)
            # (batch, 9, embedding_dim): the context panels followed by candidate i.
            panels = torch.cat([contexts, candidate_i], dim=1)
            n_panels = panels.shape[1]
            # All ordered pairs: row p * n_panels + q holds [panel p, panel q].
            pairs = torch.cat([
                panels.repeat_interleave(n_panels, dim=1),
                panels.repeat(1, n_panels, 1)
            ], dim=2)
            g_res_sum = torch.sum(self.g(pairs), dim=1)
            per_candidate.append(self.f(g_res_sum))  # (batch, 1)

        # Collecting the columns with cat() avoids in-place writes into a buffer that
        # may live on a different device than the model.
        scores = torch.cat(per_candidate, dim=1).to(dtype=torch.float)
        if device is not None:
            scores = scores.to(device)
        return F.softmax(scores, dim=1)
        

class CNNRelNet(nn.Module):
    """End-to-end model: a CNN embeds each panel, a relation network scores candidates.

    Args:
        embedding_dim: Size of the per-panel embedding produced by the CNN.
        g_hidden: Hidden width of the relation network's pair module.
        f_hidden: Hidden width of the relation network's scoring module.
    """

    # The model consumes 16 panels per sample (the relation network splits them 8 + 8).
    N_PANELS = 16

    def __init__(self, embedding_dim=256, g_hidden=512, f_hidden=256):
        super().__init__()
        self.g_hidden = g_hidden
        self.f_hidden = f_hidden
        self.embedding_size = embedding_dim
        # A pair is two concatenated embeddings (512 for the default embedding size).
        self.n_concat = 2 * embedding_dim

        # Forward the hyper-parameters; building the sub-modules with their own defaults
        # made any non-default constructor argument fail with a shape mismatch in forward().
        self.cnn = CNN(out_dim=embedding_dim)
        self.relnet = RelNet(input_dim=self.n_concat, g_hidden=g_hidden, f_hidden=f_hidden)

    def forward(self, X, device=None):
        """Score the candidates of each sample.

        Args:
            X: Image tensor of shape (batch, 16, H, W); each panel is passed to the CNN
                as a single-channel image.
            device: Optional device for the returned scores. When omitted, the device
                of the computed embeddings is used.

        Returns:
            Whatever the relation network returns for the (batch, 16, embedding_dim)
            embeddings.
        """
        # Stacking the CNN outputs keeps the embeddings on the model's device; a
        # pre-allocated CPU buffer forced a device mismatch when running on GPU.
        embeddings = torch.stack(
            [self.cnn(X[:, i, :, :].unsqueeze(dim=1)) for i in range(self.N_PANELS)],
            dim=1,
        )
        target_device = embeddings.device if device is None else device
        return self.relnet(embeddings, device=target_device)

In [112]:
# Smoke test: run 5 dummy samples (16 panels of 160x160 each) through the full model
# and display the shape of the result.
X = torch.ones(5, 16, 160, 160, dtype=torch.float)
model = CNNRelNet()
model(X).shape

tensor([[-0.0338,  0.0110, -0.0198,  ...,  0.0153,  0.0668, -0.0591],
        [-0.0338,  0.0110, -0.0198,  ...,  0.0153,  0.0668, -0.0591],
        [-0.0338,  0.0110, -0.0198,  ...,  0.0153,  0.0668, -0.0591],
        ...,
        [-0.0338,  0.0110, -0.0198,  ...,  0.0153,  0.0668, -0.0591],
        [-0.0338,  0.0110, -0.0198,  ...,  0.0153,  0.0668, -0.0591],
        [-0.0338,  0.0110, -0.0198,  ...,  0.0153,  0.0668, -0.0591]],
       grad_fn=<SelectBackward0>)
tensor([[-0.0338,  0.0110, -0.0198,  ...,  0.0153,  0.0668, -0.0591],
        [-0.0338,  0.0110, -0.0198,  ...,  0.0153,  0.0668, -0.0591],
        [-0.0338,  0.0110, -0.0198,  ...,  0.0153,  0.0668, -0.0591],
        ...,
        [-0.0338,  0.0110, -0.0198,  ...,  0.0153,  0.0668, -0.0591],
        [-0.0338,  0.0110, -0.0198,  ...,  0.0153,  0.0668, -0.0591],
        [-0.0338,  0.0110, -0.0198,  ...,  0.0153,  0.0668, -0.0591]],
       grad_fn=<SelectBackward0>)
tensor([[-0.0338,  0.0110, -0.0198,  ...,  0.0153,  0.0668, -0.0

torch.Size([5, 8])

In [108]:
# repeat_interleave repeats each element in place: [1, 2, 3] -> [1, 1, 2, 2, 3, 3].
torch.repeat_interleave(torch.tensor([1, 2, 3]), 2)

tensor([1, 1, 2, 2, 3, 3])

In [109]:
# Smoke test for RelNet alone: 10 identical samples of 16 panels with 256 features,
# where panel i is filled with the value i.
model = RelNet()
X = torch.arange(16, dtype=torch.float).view(1, 16, 1).repeat(10, 1, 256)
print(X.shape)
print()
model(X).shape

torch.Size([10, 16, 256])

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 1., 1., 1.],
        [0., 0., 0.,  ..., 2., 2., 2.],
        ...,
        [8., 8., 8.,  ..., 6., 6., 6.],
        [8., 8., 8.,  ..., 7., 7., 7.],
        [8., 8., 8.,  ..., 8., 8., 8.]])
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 1., 1., 1.],
        [0., 0., 0.,  ..., 2., 2., 2.],
        ...,
        [9., 9., 9.,  ..., 6., 6., 6.],
        [9., 9., 9.,  ..., 7., 7., 7.],
        [9., 9., 9.,  ..., 9., 9., 9.]])
tensor([[ 0.,  0.,  0.,  ...,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  ...,  1.,  1.,  1.],
        [ 0.,  0.,  0.,  ...,  2.,  2.,  2.],
        ...,
        [10., 10., 10.,  ...,  6.,  6.,  6.],
        [10., 10., 10.,  ...,  7.,  7.,  7.],
        [10., 10., 10.,  ..., 10., 10., 10.]])
tensor([[ 0.,  0.,  0.,  ...,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  ...,  1.,  1.,  1.],
        [ 0.,  0.,  0.,  ...,  2.,  2.,  2.],
        ...,
        [11., 11., 11., 

torch.Size([10, 8])

In [72]:
# Scratch check of whatever `X` is currently bound in the kernel.
# NOTE(review): the recorded output below, torch.Size([16, 512]), matches none of the `X`
# tensors built in the visible cells, so it presumably comes from an earlier kernel
# state -- confirm with Restart & Run All.
X.shape

torch.Size([16, 512])

In [None]:
class Stage3RelNetV2(nn.Module):
    """Relation network over per-panel feature vectors (8 context panels + candidates).

    For each candidate, the 8 context panels plus that candidate form a set of 9 panels;
    every ordered pair (81 pairs) is concatenated, passed through ``g``, summed over the
    pairs and mapped by ``f`` to one value per candidate.

    Args:
        mlp_hidden: Hidden width of both MLPs.
        classes: Unused; kept so existing callers keep working.
        n_candidates: Number of candidate panels scored per sample.

    NOTE(review): ``f`` ends in a Sigmoid and ``forward`` applies a softmax on top, so
    the values fed to the softmax are confined to (0, 1) -- confirm this is intended.
    """

    # Number of leading panels treated as context.
    N_CONTEXT = 8

    def __init__(self, mlp_hidden=32, classes=8, n_candidates=8):
        super().__init__()
        # Width of one concatenated pair, i.e. 27 features per panel.
        self.n_concat = 9 * 3 * 2
        self.mlp_hidden = mlp_hidden
        self.n_candidates = n_candidates

        self.g = nn.Sequential(
            nn.Linear(self.n_concat, mlp_hidden),
            nn.LeakyReLU(),
            nn.Linear(mlp_hidden, mlp_hidden),
            nn.LeakyReLU(),
            nn.Linear(mlp_hidden, mlp_hidden),
            nn.LeakyReLU(),
            nn.Linear(mlp_hidden, mlp_hidden),
            nn.LeakyReLU(),
        )

        self.f = nn.Sequential(
            nn.Linear(mlp_hidden, mlp_hidden),
            nn.LeakyReLU(),
            nn.Linear(mlp_hidden, mlp_hidden),
            nn.LeakyReLU(),
            nn.Dropout(),
            nn.Linear(mlp_hidden, 1),
            nn.Sigmoid(),
        )

    def forward(self, X, device=None):
        """Return softmax-normalised candidate scores.

        Args:
            X: Tensor of shape (n_batch, 8 + n_candidates or more, D) with
                ``2 * D == self.n_concat``; it is cast to float. Panels beyond the first
                ``8 + n_candidates`` are ignored.
            device: Optional device for the returned tensor. When omitted the result
                stays on the device the computation ran on.

        Returns:
            Float tensor of shape (n_batch, n_candidates); each row sums to 1.

        Raises:
            ValueError: If ``X`` holds fewer than ``n_candidates`` candidate panels.
        """
        X = X.float()
        n_ctx, n_cand = self.N_CONTEXT, self.n_candidates
        contexts, candidates = X[:, :n_ctx, :], X[:, n_ctx:n_ctx + n_cand, :]
        if candidates.shape[1] != n_cand:
            raise ValueError(
                f'expected at least {n_ctx + n_cand} panels, got {X.shape[1]}')

        # Shape: (n_batch, n_candidates, n_ctx, D) -- the same context for every candidate.
        contexts = contexts.unsqueeze(1).expand(-1, n_cand, -1, -1)
        # Shape: (n_batch, n_candidates, 1, D)
        candidates = candidates.unsqueeze(2)

        # Shape: (n_batch, n_candidates, n_panels=9, D)
        context_candidate_panels = torch.cat([contexts, candidates], dim=2)
        n_panels = context_candidate_panels.shape[2]

        # Shape: (n_batch, n_candidates, n_panels * n_panels, 2 * D); row p * n_panels + q
        # holds the concatenation [panel p, panel q].
        context_candidate_concat = torch.cat([
            context_candidate_panels.repeat_interleave(n_panels, dim=2),
            context_candidate_panels.repeat(1, 1, n_panels, 1),
        ], dim=3)

        # nn.Linear acts on the last dimension, so all candidates go through g and f in
        # one pass. Shapes: g -> (n_batch, n_candidates, n_pairs, mlp_hidden),
        # summed over pairs -> (n_batch, n_candidates, mlp_hidden), f -> (..., 1).
        g_res_sum = self.g(context_candidate_concat).sum(dim=2)
        candidate_logits = self.f(g_res_sum).squeeze(-1)

        if device is not None:
            candidate_logits = candidate_logits.to(device)
        return F.softmax(candidate_logits, dim=1)

In [47]:
# Smoke test for the panel encoder: 5 single-channel 160x160 images in,
# display the shape of the embeddings that come out.
X = torch.ones(5, 1, 160, 160)
model = CNN()
model(X).shape

torch.Size([5, 256])

In [3]:
# Display the repr of a conv layer with 1 input channel, 4 output channels and a 32x32 kernel.
nn.Conv2d(in_channels=1, out_channels=4, kernel_size=32)

Conv2d(1, 4, kernel_size=(32, 32), stride=(1, 1))

In [5]:
# List the learnable tensors of a single strided conv layer. A bare nn.Conv2d names its
# parameters 'weight' and 'bias', so no name filter is applied here: filtering on
# 'conv' would match nothing and the cell would print no output.
for name, params in nn.Conv2d(1, 32, 3, stride=2).named_parameters():
    print(name, params.size())