In [1]:
# Move the working directory up one level (quietly); the relative
# "dataset/..." and "outputs/..." paths used by the cells below are resolved
# against it.
%cd -q ..

/home/zarizky/projects/neural-autoregressive-object-co-occurrence


In [2]:
from collections import OrderedDict
from copy import deepcopy
from math import log

import numpy as np
import pandas as pd
from scipy import special
from scipy.stats import poisson
from tqdm.auto import tqdm
from xgboost import XGBRegressor

In [3]:
# Position of the first modelled column in the CSV files; every column from
# here onward is treated as a variable of the autoregressive model.
# NOTE(review): presumably the skipped leading columns are image metadata --
# confirm against the CSV schema.
FIRST_VARIABLE_COLUMN = 8
# Number of random variable orderings to sample and evaluate.
N_ORDERS = 100

# Seed NumPy's global RNG so the sampled orderings are reproducible.
np.random.seed(0)

df_train = pd.read_csv("dataset/coco2017-cooccurences-train.csv")
df_valid = pd.read_csv("dataset/coco2017-cooccurences-valid.csv")

X_train = df_train.iloc[:, FIRST_VARIABLE_COLUMN:]
X_valid = df_valid.iloc[:, FIRST_VARIABLE_COLUMN:]

# One random permutation of the variable (column) names per ordering,
# stacked into an array of shape (N_ORDERS, number of variables).
orders = np.stack(
    [np.random.permutation(X_train.columns) for _ in range(N_ORDERS)]
)

# For every ordering, the train/valid frames with their columns rearranged
# into that ordering.
data = [
    dict(train=X_train.loc[:, order], valid=X_valid.loc[:, order])
    for order in orders
]

In [4]:
def _conditional_rates(models, first_rate, frame):
    """Return the Poisson rate of every variable in ``frame``, in column order.

    Parameters
    ----------
    models : list
        Fitted regressors; ``models[j]`` predicts column ``j + 1`` from
        columns ``0..j``.
    first_rate : float
        Constant rate used for column 0, which has no predecessors.
    frame : pandas.DataFrame
        Observed values, columns arranged in the ordering the models were
        fitted with.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as ``frame`` holding one rate per entry.
    """
    observed = frame.to_numpy()
    rates = [np.full(observed.shape[0], first_rate, dtype=float)]
    # Each model is fed the *observed* predecessors -- the same kind of input
    # it was fitted on -- so that the summed log-pmf is the autoregressive
    # log-likelihood sum_i log p(x_i | x_<i).
    for n_parents, model in enumerate(models, start=1):
        rates.append(model.predict(observed[:, :n_parents]))
    return np.stack(rates, axis=-1)


def _mixture_log_likelihood(per_order):
    """Mean per-sample log-likelihood of the uniform mixture over orderings.

    ``per_order`` is a sequence of equally long 1-D arrays, one per ordering,
    holding the per-sample log-likelihoods under that ordering.
    """
    per_order = np.asarray(per_order)
    return (special.logsumexp(per_order, 0) - np.log(len(per_order))).mean()


# NOTE(review): earlier revisions of this cell started the evaluation chain
# from the observed first column (using it as its own Poisson rate) and fed
# each model the previous models' predictions instead of the observed values.
# Numbers produced by this cell are therefore not comparable with outputs
# saved by those revisions.
log_likelihood = dict(train=[], valid=[])
for data_order in (pbar := tqdm(data, unit="order")):
    train_values = data_order["train"].to_numpy()

    # One Poisson regressor per variable after the first, each conditioned on
    # all variables that precede it in the current ordering.
    model_autoreg = []
    for i in range(1, train_values.shape[-1]):
        m = XGBRegressor(objective="count:poisson", random_state=0)
        m.fit(train_values[:, :i], train_values[:, i])
        model_autoreg.append(m)

    # The first variable has nothing to condition on: model it with a marginal
    # Poisson whose rate is the training mean (its maximum-likelihood rate).
    first_rate = train_values[:, 0].mean()

    for subset in ("train", "valid"):
        rates = _conditional_rates(model_autoreg, first_rate, data_order[subset])
        log_likelihood[subset].append(
            poisson.logpmf(data_order[subset].to_numpy(), rates).sum(-1)
        )

    # Running score of the mixture over all orderings processed so far.
    pbar.set_postfix(
        OrderedDict(
            (subset, "{:.4f}".format(_mixture_log_likelihood(log_likelihood[subset])))
            for subset in ["train", "valid"]
        )
    )

# Stack the per-ordering results into arrays of shape (orders, samples).
for key in log_likelihood:
    log_likelihood[key] = np.array(log_likelihood[key])

  0%|          | 0/100 [00:00<?, ?order/s]

In [97]:
# Derive the table widths from the data instead of hard-coding them, so this
# cell stays correct if the number of orderings or variables changes.
n_orders, n_variables = orders.shape
variable_columns = ["var_{:0=2d}".format(k + 1) for k in range(n_variables)]
order_columns = ["order_{:0=2d}".format(k + 1) for k in range(n_orders)]

# One row per ordering; column var_XX holds the name of the variable placed at
# position XX of that ordering.
df_order = pd.DataFrame(orders, columns=variable_columns)

# One row per sample, one column per ordering (hence the transpose).
# NOTE(review): these assignments rebind df_train / df_valid, which earlier
# held the raw input tables; the names are kept because other cells may rely
# on them.
df_train = pd.DataFrame(log_likelihood["train"].T, columns=order_columns)
df_valid = pd.DataFrame(log_likelihood["valid"].T, columns=order_columns)

df_order.to_csv("outputs/xgb_order.csv", index=False)
df_train.to_csv("outputs/xgb_log_likelihood_train.csv", index=False)
df_valid.to_csv("outputs/xgb_log_likelihood_valid.csv", index=False)

Unnamed: 0,var_01,var_02,var_03,var_04,var_05,var_06,var_07,var_08,var_09,var_10,...,var_71,var_72,var_73,var_74,var_75,var_76,var_77,var_78,var_79,var_80
0,[broccoli]-[food],[tie]-[accessory],[skis]-[sports],[teddy bear]-[indoor],[bowl]-[kitchen],[sink]-[appliance],[keyboard]-[electronic],[microwave]-[appliance],[sandwich]-[food],[oven]-[appliance],...,[parking meter]-[outdoor],[toaster]-[appliance],[skateboard]-[sports],[bear]-[animal],[traffic light]-[outdoor],[scissors]-[indoor],[cell phone]-[electronic],[mouse]-[electronic],[apple]-[food],[spoon]-[kitchen]
1,[airplane]-[vehicle],[spoon]-[kitchen],[tv]-[electronic],[backpack]-[accessory],[bus]-[vehicle],[surfboard]-[sports],[parking meter]-[outdoor],[truck]-[vehicle],[clock]-[indoor],[carrot]-[food],...,[bench]-[outdoor],[giraffe]-[animal],[potted plant]-[furniture],[elephant]-[animal],[cat]-[animal],[hot dog]-[food],[scissors]-[indoor],[motorcycle]-[vehicle],[apple]-[food],[microwave]-[appliance]
2,[teddy bear]-[indoor],[keyboard]-[electronic],[skis]-[sports],[person]-[person],[orange]-[food],[train]-[vehicle],[broccoli]-[food],[handbag]-[accessory],[laptop]-[electronic],[potted plant]-[furniture],...,[bear]-[animal],[frisbee]-[sports],[backpack]-[accessory],[dog]-[animal],[chair]-[furniture],[stop sign]-[outdoor],[cell phone]-[electronic],[airplane]-[vehicle],[bird]-[animal],[toilet]-[furniture]
3,[bowl]-[kitchen],[microwave]-[appliance],[banana]-[food],[bed]-[furniture],[backpack]-[accessory],[clock]-[indoor],[mouse]-[electronic],[potted plant]-[furniture],[apple]-[food],[bench]-[outdoor],...,[boat]-[vehicle],[tv]-[electronic],[pizza]-[food],[oven]-[appliance],[baseball bat]-[sports],[vase]-[indoor],[person]-[person],[skateboard]-[sports],[bus]-[vehicle],[kite]-[sports]
4,[toaster]-[appliance],[broccoli]-[food],[microwave]-[appliance],[laptop]-[electronic],[couch]-[furniture],[cow]-[animal],[wine glass]-[kitchen],[carrot]-[food],[spoon]-[kitchen],[chair]-[furniture],...,[bird]-[animal],[fork]-[kitchen],[sports ball]-[sports],[umbrella]-[accessory],[skateboard]-[sports],[vase]-[indoor],[mouse]-[electronic],[cake]-[food],[potted plant]-[furniture],[bench]-[outdoor]
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,[microwave]-[appliance],[stop sign]-[outdoor],[keyboard]-[electronic],[suitcase]-[accessory],[oven]-[appliance],[train]-[vehicle],[airplane]-[vehicle],[horse]-[animal],[sink]-[appliance],[fire hydrant]-[outdoor],...,[cake]-[food],[bicycle]-[vehicle],[chair]-[furniture],[baseball glove]-[sports],[hot dog]-[food],[tv]-[electronic],[cat]-[animal],[hair drier]-[indoor],[bowl]-[kitchen],[refrigerator]-[appliance]
96,[frisbee]-[sports],[hot dog]-[food],[stop sign]-[outdoor],[airplane]-[vehicle],[bird]-[animal],[remote]-[electronic],[dog]-[animal],[cow]-[animal],[person]-[person],[scissors]-[indoor],...,[knife]-[kitchen],[bus]-[vehicle],[baseball glove]-[sports],[skateboard]-[sports],[sports ball]-[sports],[cell phone]-[electronic],[broccoli]-[food],[chair]-[furniture],[bed]-[furniture],[clock]-[indoor]
97,[bottle]-[kitchen],[baseball bat]-[sports],[cup]-[kitchen],[frisbee]-[sports],[sheep]-[animal],[bed]-[furniture],[umbrella]-[accessory],[bear]-[animal],[person]-[person],[bird]-[animal],...,[backpack]-[accessory],[cat]-[animal],[knife]-[kitchen],[elephant]-[animal],[potted plant]-[furniture],[mouse]-[electronic],[teddy bear]-[indoor],[keyboard]-[electronic],[cake]-[food],[pizza]-[food]
98,[refrigerator]-[appliance],[toilet]-[furniture],[truck]-[vehicle],[carrot]-[food],[bowl]-[kitchen],[book]-[indoor],[oven]-[appliance],[fire hydrant]-[outdoor],[tv]-[electronic],[airplane]-[vehicle],...,[tie]-[accessory],[sports ball]-[sports],[hot dog]-[food],[mouse]-[electronic],[parking meter]-[outdoor],[tennis racket]-[sports],[toothbrush]-[indoor],[cup]-[kitchen],[horse]-[animal],[bicycle]-[vehicle]
