In [1]:
# Move the working directory up one level; the relative `utils` import and
# `dataset/` paths used by later cells presumably resolve from there.
# NOTE(review): not idempotent — re-running this cell moves up one more level.
%cd ..

/home/zarizky/projects/neural-autoregressive-object-co-occurrence


In [2]:
import einops
import matplotlib.pyplot as plt
import networkx as nx
import torch
import torch_optimizer as optim
from tqdm.auto import tqdm

from utils.dataset import ObjectCooccurrenceCOCODataset

# Seed torch's global RNG so that runs start from the same random state.
torch.manual_seed(0)

# Use the GPU when torch reports one as available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [3]:
# Training and validation splits, each read from the CSV path given below.
dataset_train = ObjectCooccurrenceCOCODataset(
    "dataset/coco2017-cooccurences-train.csv"
)
dataset_valid = ObjectCooccurrenceCOCODataset(
    "dataset/coco2017-cooccurences-valid.csv"
)

# NOTE(review): 8196 is not a power of two; 8192 may have been intended — confirm.
batch_size = 8196

# Only the training loader shuffles; memory is pinned only when running on CUDA.
dataloader_train = torch.utils.data.DataLoader(
    dataset=dataset_train,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
    pin_memory=(device == "cuda"),
)

dataloader_valid = torch.utils.data.DataLoader(
    dataset=dataset_valid,
    batch_size=batch_size,
    num_workers=8,
    pin_memory=(device == "cuda"),
)

In [4]:
class CategoricalGLM(torch.nn.Module):
    """Linear model producing categorical logits for every feature from the others.

    For an input ``x`` of shape ``(..., num_features)`` the output has shape
    ``(..., num_features, max_value)`` with

        out[..., j, k] = sum_{i != j} x[..., i] * weight[i, j, k] + bias[j, k]

    The ``i == j`` terms are removed by a fixed off-diagonal mask, so the logits
    of feature ``j`` never depend on the input value of feature ``j`` itself.

    Args:
        num_features: Size of the last input dimension.
        max_value: Number of logits (categories) produced per feature.
    """

    def __init__(self, num_features, max_value):
        super().__init__()

        self.num_features = num_features
        self.max_value = max_value

        # einops pattern used by ``forward``: contract the input feature axis d1.
        self.pattern = "... d1, d1 d2 k -> ... d2 k"

        # True everywhere except on the diagonal (d1 == d2).
        mask = torch.eye(self.num_features, dtype=torch.bool)
        mask = mask.logical_not()
        self.register_buffer("mask", mask)

        # Parameters are registered exactly once; ``reset_parameters`` only
        # overwrites their values, so references held elsewhere (e.g. by an
        # optimizer) stay valid after a re-initialisation.
        self.weight = torch.nn.Parameter(
            torch.empty(self.num_features, self.num_features, self.max_value)
        )
        self.bias = torch.nn.Parameter(
            torch.empty(self.num_features, self.max_value)
        )
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise ``weight`` (Xavier uniform) and ``bias`` (zeros) in place."""
        torch.nn.init.xavier_uniform_(self.weight)
        torch.nn.init.zeros_(self.bias)

    def mask_weight(self):
        """Return ``weight`` with every ``[d, d, :]`` slice zeroed.

        The boolean mask is broadcast over the trailing category axis; the
        product keeps the dtype of ``weight`` and gradients for masked entries
        are zero.
        """
        return self.weight * self.mask.unsqueeze(-1)

    def forward(self, inputs):
        """Map ``inputs`` of shape ``(..., num_features)`` to logits.

        Returns:
            Tensor of shape ``(..., num_features, max_value)``.
        """
        weight = self.mask_weight()
        outputs = einops.einsum(inputs, weight, self.pattern)
        outputs = outputs + self.bias
        return outputs

In [5]:
epochs = 500

# Model dimensions, named after the CategoricalGLM constructor arguments.
# NOTE(review): presumably 80 COCO object categories and per-image counts in
# [0, 29) — confirm against the dataset.
num_features = 80
max_value = 29

glm = CategoricalGLM(num_features, max_value).to(device)
# The learning rate is driven by the one-cycle schedule below (peak 1e-1),
# stepped once per training batch.
opt = optim.Lamb(glm.parameters(), lr=1e-5)
sch = torch.optim.lr_scheduler.OneCycleLR(
    optimizer=opt, max_lr=1e-1, steps_per_epoch=len(dataloader_train), epochs=epochs
)

# Support of each categorical output (0 .. max_value - 1). Not used by the
# training loop itself; kept because later cells may rely on it.
values = torch.arange(max_value).to(device)


def _batch_log_likelihood(model, batch):
    """Return the per-sample log-likelihood of an integer ``batch`` under ``model``.

    The batch is moved to ``device`` once and serves both as the target and,
    cast to float, as the model input. Per-feature categorical log-probabilities
    are summed over the feature axis (``Independent(..., 1)``).
    """
    targets = batch.to(device)
    outputs = model(targets.float())
    dist = torch.distributions.Categorical(logits=outputs)
    dist = torch.distributions.Independent(dist, 1)
    return dist.log_prob(targets)


def _total_log_likelihood(model, dataloader):
    """Sum the log-likelihood over every sample in ``dataloader`` without autograd."""
    with torch.inference_mode():
        return sum(
            _batch_log_likelihood(model, batch).sum().item() for batch in dataloader
        )


for epoch in (pbar := tqdm(range(1, epochs + 1), unit="epoch")):
    # One optimisation pass over the training set, minimising the mean NLL.
    for inputs in dataloader_train:
        loss = -_batch_log_likelihood(glm, inputs).mean()
        loss.backward()
        opt.step()
        sch.step()
        opt.zero_grad()

    # Re-evaluate both splits with the end-of-epoch parameters.
    ll_train = _total_log_likelihood(glm, dataloader_train)
    ll_valid = _total_log_likelihood(glm, dataloader_valid)

    # Report the average negative log-likelihood per sample.
    pbar.set_postfix(
        [
            ("train_nll", f"{-ll_train / len(dataset_train):.4f}"),
            ("valid_nll", f"{-ll_valid / len(dataset_valid):.4f}"),
        ]
    )

  0%|          | 0/500 [00:00<?, ?epoch/s]