In [None]:
from options.options import Options
import os
import torch
from build_dataset_model import build_loaders, build_model
from utils import get_model_attr, calculate_model_losses, tensor_aug
from collections import defaultdict
import math


In [None]:
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm
import numpy as np

from new.CustomVAE import *
from new.utils import resolve_relative_positions
from utils import calculate_model_losses
from data.suncg_dataset import g_add_in_room_relation, g_use_heuristic_relation_matrix, \
    g_prepend_room, g_add_random_parent_link, g_shuffle_subject_object

# decoder option: "original" -> OriVAEDecoder, "rgcn" -> RGCNConv (chosen in the model cell)
g_decoder_option = "original" #"rgcn"
# Real boolean flag: the previous string "False" is truthy and would misbehave in any `if`.
g_relative_location = False
# NOTE(review): not used by any visible cell — confirm what index 16 refers to.
g_parent_link_index = 16

args = Options().parse()
# makedirs(exist_ok=True) also creates missing parent directories and avoids the
# check-then-create race of `isdir` + `mkdir`.
if args.output_dir is not None:
    os.makedirs(args.output_dir, exist_ok=True)
if args.test_dir is not None:
    os.makedirs(args.test_dir, exist_ok=True)

In [None]:
# VAE mode (use_AE=False): the latent is sampled with the reparameterization trick in
# the sampling cell, and mu/logvar are passed to the loss, which has a KL-divergence term.
setattr(args, "use_AE", False)

In [None]:
# # tensorboard
# writer = SummaryWriter()

# writer.add_hparams({
#     "experiment type": "Decoder only: original version of the decoder",
#     "decoder type": g_decoder_option,
#     "use relative location": g_relative_location,
#     "Add 'in_room' relation": g_add_in_room_relation,
#     "Use heuristic relation matrix": g_use_heuristic_relation_matrix,
#     "prepend/append room info": g_prepend_room,
#     "add random parent link": g_add_random_parent_link,
#     "shuffle object/subject when loading data": g_shuffle_subject_object,
# }, {"NA": 0})

In [None]:
# load data: vocabulary plus train / validation loaders built from the parsed options
(vocab,
 train_loader,
 val_loader) = build_loaders(args)

# Training dataset object; its `.vocab` is handed to the model constructors below.
dt = train_loader.dataset

In [None]:
# model args: constructor keyword arguments shared by the encoder and the decoder.
# NOTE(review): embedding_dim is pinned to 84 and ignores args.embedding_dim —
# confirm this override is intentional.
kwargs = dict(
    vocab=dt.vocab,
    batch_size=args.batch_size,
    train_3d=args.train_3d,
    decoder_cat=args.decoder_cat,
    embedding_dim=84,  # args.embedding_dim,
    gconv_mode=args.gconv_mode,
    gconv_num_layers=args.gconv_num_layers,
    mlp_normalization=args.mlp_normalization,
    vec_noise_dim=args.vec_noise_dim,
    layout_noise_dim=args.layout_noise_dim,
    use_AE=args.use_AE,
)

In [None]:
# Transformer encoder and its Adam optimizer, both on the GPU.
model_encoder = TransformerEncoder(**kwargs)
model_encoder = model_encoder.cuda()
optimizer_encoder = torch.optim.Adam(model_encoder.parameters(), lr=args.learning_rate)

# load decoder selected by g_decoder_option
if g_decoder_option == "original":
    model_decoder = OriVAEDecoder(**kwargs)
elif g_decoder_option == "rgcn":
    model_decoder = RGCNConv(**kwargs)
else:
    # `raise("...")` raises TypeError because a str is not an exception; raise a real one
    # so the unknown option name actually reaches the user.
    raise ValueError("MODEL MISSING {}".format(g_decoder_option))
model_decoder = model_decoder.cuda()
optimizer_decoder = torch.optim.Adam(model_decoder.parameters(), lr=args.learning_rate)

In [None]:
# Take a single training batch for step-by-step debugging of the forward pass.
# next(iter(...)) replaces the `for ...: break` loop, which drew the same single batch
# but left a dangling tqdm progress bar behind.
batch = next(iter(train_loader))
ids, objs, boxes, triples, angles, attributes, obj_to_img, triple_to_img = tensor_aug(batch)
# Placeholder 64-dim latent; the sampling cell reassigns `z` before it is used.
z = torch.randn(objs.size(0), 64).to(objs.device)

In [None]:
# attention mask: block-diagonal, so each object attends only to objects of its own image.
# Shape is [N_obj x N_obj] with N_obj = objs.size(0).
# NOTE(review): assumes objects are grouped contiguously by image in obj_to_img order.
# bincount counts objects per image in a single pass. minlength keeps one (possibly empty)
# block per batch slot, as the previous range(args.batch_size) loop did. Unlike that loop,
# it never silently drops objects whose image index is >= batch_size.
obj_counts = torch.bincount(obj_to_img.long().cpu(), minlength=args.batch_size).tolist()
block_list = [torch.ones((count, count)) for count in obj_counts]
attention_mask = torch.block_diag(*block_list).to(objs.device)  # [N_obj x N_obj]
assert attention_mask.size(0) == objs.size(0), \
    "attention mask size does not match the number of objects"

# encoder: hidden states, then the latent mean / log-variance used for sampling
hidden_states = model_encoder.encoder(objs, boxes, angles, attributes, attention_mask)
mu, logvar = model_encoder.get_hidden_representation(hidden_states)

In [None]:
# obj_vecs = hidden_states.squeeze(0)
# obj_vecs_box = model_encoder.box_mean_var(obj_vecs)
# mu_box = model_encoder.box_mean(obj_vecs_box)
# logvar_box = model_encoder.box_var(obj_vecs_box)

# obj_vecs_angle = model_encoder.angle_mean_var(obj_vecs)
# mu_angle = model_encoder.angle_mean(obj_vecs_angle)
# logvar_angle = model_encoder.angle_var(obj_vecs_angle)

# mu = torch.cat([mu_box, mu_angle], dim=1)
# logvar = torch.cat([logvar_box, logvar_angle], dim=1)

# mu.shape

# logvar.shape

In [None]:
# Latent code: in VAE mode, draw a reparameterized sample z = mu + eps * std, where
# std = exp(0.5 * logvar) and eps ~ N(0, I). In AE mode, use the mean directly.
if not args.use_AE:
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    z = eps * std + mu
else:
    z = mu

In [None]:
z.size()  # display the shape of the latent tensor

In [None]:
# decoder: predict boxes and angles from the latent code, object labels, triples and attributes
boxes_pred, angles_pred = model_decoder.decoder(
    z, objs, triples, attributes
)

In [None]:
# KL weight: either a step-based schedule or the fixed weight from the options.
# Schedule: 1e-6 for steps t < 1e5, then x10 for every further 1e5 steps.
if args.KL_linear_decay:
    # The step counter `t` is never defined in this notebook (only a commented-out
    # `t += 1` exists). Use step 1 for this single debug batch instead of raising NameError.
    if "t" not in globals():
        t = 1
    KL_weight = 10 ** (t // 1e5 - 6)
else:
    KL_weight = args.KL_loss_weight
total_loss, losses = calculate_model_losses(args, None, boxes, boxes_pred, angles, angles_pred, mu=mu, logvar=logvar, KL_weight=KL_weight)
losses['total_loss'] = total_loss.item()
if not math.isfinite(losses['total_loss']):
    # No backward pass runs in this notebook; the former `#continue` suggests this check
    # was lifted from a training loop, so only warn here.
    print('WARNING: Got loss = NaN, not backpropping')

In [None]:
# Display the loss dictionary returned by calculate_model_losses, including 'total_loss'.
losses

In [None]:
float(total_loss)  # scalar value of the total loss as a Python float