In [1]:
from options.options import Options
import os
import torch
from build_dataset_model import build_loaders, build_model
from utils import get_model_attr, calculate_model_losses, tensor_aug
from collections import defaultdict
import math


In [2]:
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm
import numpy as np

from new.CustomVAE import *
from new.utils import resolve_relative_positions, obtain_sampled_relations
from utils import calculate_model_losses
from data.suncg_dataset import g_add_in_room_relation, g_use_heuristic_relation_matrix, \
    g_prepend_room, g_add_random_parent_link, g_shuffle_subject_object

In [45]:
from torch_scatter import scatter_mean

In [3]:
# Notebook-level switches for this experiment.
# Selects which decoder class is built further down: "original" or "rgcn".
g_decoder_option = "original"
# NOTE(review): this is the *string* "False", which is truthy in a boolean test.
# It is not read by any visible cell; confirm how its consumers compare it
# before replacing it with a real bool.
g_relative_location = "False"
# NOTE(review): presumably the relation id used for parent links - confirm
# against the dataset vocabulary.
g_parent_link_index = 16

# Parse the run options, then make sure both output folders exist.
args = Options().parse()
for out_dir in (args.output_dir, args.test_dir):
    if out_dir is not None:
        # makedirs(exist_ok=True) also creates missing parent folders and does
        # not fail if the folder appears between an existence check and the
        # creation, unlike a separate isdir() + mkdir() pair.
        os.makedirs(out_dir, exist_ok=True)

| options
dataset: suncg
suncg_train_dir: metadata/data_rot_train.json
suncg_val_dir: metadata/data_rot_val.json
suncg_data_dir: /home/yizhou/Research/SUNCG/suncg_data
loader_num_workers: 8
embedding_dim: 64
gconv_mode: feedforward
gconv_dim: 128
gconv_hidden_dim: 512
gconv_num_layers: 5
mlp_normalization: batch
vec_noise_dim: 0
layout_noise_dim: 32
batch_size: 16
num_iterations: 20000
eval_mode_after: -1
learning_rate: 0.0001
print_every: 100
checkpoint_every: 1000
snapshot_every: 10000
output_dir: ./checkpoints
checkpoint_name: latest_checkpoint
timing: False
multigpu: False
restore_from_checkpoint: False
checkpoint_start_from: None
test_dir: ./layouts_out
gpu_id: 0
KL_loss_weight: 0.1
use_AE: False
decoder_cat: True
train_3d: True
KL_linear_decay: False
use_attr_30: True
manual_seed: 42
batch_gen: False
measure_acc_l1_std: False
heat_map: False
draw_2d: False
draw_3d: False
fine_tune: False
gan_shade: False
blender_path: /home/yizhou/blender-2.92.0-linux64/blender



In [4]:
# Uncomment to force use_AE = False, the mode that samples z and has the KL divergence loss:
# args.use_AE = False

In [5]:
# Echo the parsed options as one string so the run configuration is kept in the cell output.
f"{args}"

"Namespace(dataset='suncg', suncg_train_dir='metadata/data_rot_train.json', suncg_val_dir='metadata/data_rot_val.json', suncg_data_dir='/home/yizhou/Research/SUNCG/suncg_data', loader_num_workers=8, embedding_dim=64, gconv_mode='feedforward', gconv_dim=128, gconv_hidden_dim=512, gconv_num_layers=5, mlp_normalization='batch', vec_noise_dim=0, layout_noise_dim=32, batch_size=16, num_iterations=20000, eval_mode_after=-1, learning_rate=0.0001, print_every=100, checkpoint_every=1000, snapshot_every=10000, output_dir='./checkpoints', checkpoint_name='latest_checkpoint', timing=False, multigpu=False, restore_from_checkpoint=False, checkpoint_start_from=None, test_dir='./layouts_out', gpu_id=0, KL_loss_weight=0.1, use_AE=True, decoder_cat=True, train_3d=True, KL_linear_decay=False, use_attr_30=True, manual_seed=42, batch_gen=False, measure_acc_l1_std=False, heat_map=False, draw_2d=False, draw_3d=False, fine_tune=False, gan_shade=False, blender_path='/home/yizhou/blender-2.92.0-linux64/blender'

In [6]:
# load data
# build_loaders returns the vocabulary plus the training and validation loaders
# for the options parsed above.
vocab, train_loader, val_loader = build_loaders(args)

# Handle on the training dataset; later cells read `dt.vocab` from it.
dt = train_loader.dataset

Starting to read the json file for SUNCG
Training dataset has 53860 scenes and 708041 objects
(13.15 objects per image)
Starting to read the json file for SUNCG


In [7]:
# Keyword arguments passed unchanged to every model constructor in this notebook.
kwargs = dict(
    vocab=dt.vocab,
    batch_size=args.batch_size,
    train_3d=args.train_3d,
    decoder_cat=args.decoder_cat,
    # Hard-coded to 84; args.embedding_dim is not used here.
    embedding_dim=84,
    gconv_mode=args.gconv_mode,
    gconv_num_layers=args.gconv_num_layers,
    mlp_normalization=args.mlp_normalization,
    vec_noise_dim=args.vec_noise_dim,
    layout_noise_dim=args.layout_noise_dim,
    use_AE=args.use_AE,
)

In [8]:
# Encoder: built from the shared kwargs, moved to the GPU, with its own Adam optimizer.
model_encoder = TransformerEncoder(**kwargs).cuda()
optimizer_encoder = torch.optim.Adam(
    model_encoder.parameters(),
    lr=args.learning_rate,
)


In [9]:
# Build the decoder selected by g_decoder_option, move it to the GPU and give
# it its own Adam optimizer.
if g_decoder_option == "original":
    model_decoder = OriVAEDecoder(**kwargs)
elif g_decoder_option == "rgcn":
    model_decoder = RGCNConv(**kwargs)
else:
    # Raise a real exception: raising a plain string is itself a TypeError in
    # Python 3 ("exceptions must derive from BaseException"), which would hide
    # this message.
    raise ValueError("MODEL MISSING {}".format(g_decoder_option))
model_decoder = model_decoder.cuda()
optimizer_decoder = torch.optim.Adam(model_decoder.parameters(), lr=args.learning_rate)

In [10]:
# Graph generator: built from the shared kwargs, moved to the GPU, with its own Adam optimizer.
model_generator = GraphGenerator(**kwargs).cuda()
optimizer_generator = torch.optim.Adam(
    model_generator.parameters(),
    lr=args.learning_rate,
)

In [11]:
# Pull a single training batch for the step-by-step walkthrough in the cells below.
# next(iter(...)) fetches exactly one batch and raises StopIteration if the
# loader is empty, whereas a for/break loop would silently leave these names
# undefined (or stale from an earlier run) and an unfinished progress bar behind.
batch = next(iter(train_loader))
ids, objs, boxes, triples, angles, attributes, obj_to_img, triple_to_img = tensor_aug(batch)

  0%|          | 0/3367 [00:00<?, ?it/s]

In [12]:
# Attention mask: a block-diagonal matrix of ones with one square block per
# image index in range(args.batch_size); each block's side is the number of
# entries of obj_to_img equal to that index.
obj_counts = []
for img_idx in range(args.batch_size):
    obj_counts.append((obj_to_img == img_idx).sum().item())
block_list = [torch.ones(n, n) for n in obj_counts]
attention_mask = torch.block_diag(*block_list).to(objs.device)

# Encode the batch under that mask, then derive mu / logvar from the hidden states.
hidden_states = model_encoder.encoder(
    objs, boxes, angles, attributes, attention_mask
)
mu, logvar = model_encoder.get_hidden_representation(hidden_states)

In [14]:
# calculate edges: build the score matrix from the encoder's hidden states and
# sample from it. `sample` returns three values, named here as the samples,
# their log-probabilities and their entropies (by name - confirm in GraphGenerator).
score_matrix = model_generator.get_score_matrix(hidden_states, attention_mask)
all_samples, all_log_probs, all_entropy = model_generator.sample(score_matrix, obj_to_img)

# query new relation: the boxes are handed over as a detached CPU tensor.
# detach() is the supported replacement for the legacy `.data` attribute, which
# bypasses autograd's in-place modification checks.
new_triples = obtain_sampled_relations(objs, all_samples, boxes.detach().cpu(), dt.vocab)

In [28]:
# Latent code z: sampled as mu + eps * std unless the model runs as a plain
# auto-encoder, in which case z is just mu.
if not args.use_AE:
    # std = exp(logvar / 2); eps is standard normal noise of the same shape.
    std = (0.5 * logvar).exp()
    eps = torch.randn_like(std)
    z = mu + eps * std
else:
    z = mu

In [32]:
# decoder
# Predict boxes and angles from the latent code z, the object ids, the triples
# produced by obtain_sampled_relations (not the dataset's own `triples`) and
# the attributes.
boxes_pred, angles_pred = model_decoder.decoder(z, objs, new_triples, attributes)

In [34]:
# Loss for the current batch.
if args.KL_linear_decay:
    # `t` is the training-iteration counter, but no visible cell defines it, so
    # fall back to iteration 0 instead of failing with a NameError.
    iteration = globals().get("t", 0)
    KL_weight = 10 ** (iteration // 1e5 - 6)
else:
    KL_weight = args.KL_loss_weight
total_loss, losses = calculate_model_losses(
    args, None, boxes, boxes_pred, angles, angles_pred,
    mu=mu, logvar=logvar, KL_weight=KL_weight,
)
# Store the scalar value next to the other entries of `losses`.
losses['total_loss'] = total_loss.item()
if not math.isfinite(losses['total_loss']):
    # Triggers for inf as well as NaN. Outside a training loop there is no
    # batch to skip, so this only warns; nothing here stops later cells.
    print('WARNING: Got loss = NaN, not backpropping')

In [80]:
# Unreduced losses: element-wise L1 on the boxes averaged over dim 1, and the
# unreduced NLL on the angles. (A per-image scatter_mean over obj_to_img was
# tried for both and is currently disabled.)
loss_bbox = F.l1_loss(boxes_pred, boxes, reduction="none").mean(dim=1)
loss_angle = F.nll_loss(angles_pred, angles, reduction="none")

# Policy-gradient objective. The costs are detached, so J only carries
# gradients through all_log_probs and all_entropy.
# NOTE(review): the cost term and the entropy bonus share the same sign inside
# the negated mean - confirm that is intended for however J is optimized.
detached_cost = loss_bbox.detach() + 0.1 * loss_angle.detach()
J = -(all_log_probs * detached_cost + 0.1 * all_entropy).mean()


In [81]:
# Display the scalar policy-gradient objective for this batch.
J

tensor(1.8952, device='cuda:0', grad_fn=<NegBackward>)

In [None]:
# Backpropagate J. The bbox/angle losses were detached when J was built, so
# gradients flow only through all_log_probs and all_entropy.
# NOTE(review): no optimizer zero_grad()/step() call is visible in this notebook.
J.backward()