In [1]:
from options.options import Options
import os
import torch
from build_dataset_model import build_loaders, build_model
from utils import get_model_attr, calculate_model_losses, tensor_aug
from collections import defaultdict
import math

In [2]:
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm

In [3]:
# Parse the run configuration, then make sure the output directories exist.
args = Options().parse()
# makedirs(exist_ok=True) also creates missing parent directories and is a
# no-op when the directory is already there, so no separate isdir check is needed.
for out_dir in (args.output_dir, args.test_dir):
    if out_dir is not None:
        os.makedirs(out_dir, exist_ok=True)

| options
dataset: suncg
suncg_train_dir: metadata/data_rot_train.json
suncg_val_dir: metadata/data_rot_val.json
suncg_data_dir: /home/yizhou/Research/SUNCG/suncg_data
loader_num_workers: 8
embedding_dim: 64
gconv_mode: feedforward
gconv_dim: 128
gconv_hidden_dim: 512
gconv_num_layers: 5
mlp_normalization: batch
vec_noise_dim: 0
layout_noise_dim: 32
batch_size: 16
num_iterations: 60000
eval_mode_after: -1
learning_rate: 0.0001
print_every: 100
checkpoint_every: 1000
snapshot_every: 10000
output_dir: ./checkpoints
checkpoint_name: latest_checkpoint
timing: False
multigpu: False
restore_from_checkpoint: False
checkpoint_start_from: None
test_dir: ./layouts_out
gpu_id: 0
KL_loss_weight: 0.1
use_AE: False
decoder_cat: True
train_3d: True
KL_linear_decay: False
use_attr_30: True
manual_seed: 42
batch_gen: False
measure_acc_l1_std: False
heat_map: False
draw_2d: False
draw_3d: False
fine_tune: False
gan_shade: False
blender_path: /home/yizhou/blender-2.92.0-linux64/blender



In [4]:
# Override the parsed option for this notebook run: use_AE was parsed as False
# (see the options dump above) and is switched on here.
# NOTE(review): presumably selects the plain auto-encoder path in the loss
# computation, since args is passed to calculate_model_losses — confirm.
args.use_AE = True

In [5]:
# TensorBoard writer used by the training and validation loops to log loss
# scalars. No log directory is passed, so SummaryWriter uses its default one.
writer = SummaryWriter()

In [6]:
# Build the vocabulary and the training / validation data loaders from the
# parsed options.
vocab, train_loader, val_loader = build_loaders(args)

Starting to read the json file for SUNCG
loading relation score matrix from:  new/relation_graph_v1.p
Training dataset has 53860 scenes and 708041 objects
(13.15 objects per image)
Starting to read the json file for SUNCG
loading relation score matrix from:  new/relation_graph_v1.p


In [7]:
# Shorthand for the dataset behind the training loader; its `vocab` attribute
# is used to construct the decoder.
dt = train_loader.dataset

In [8]:
from new.CustomVAE import *
from utils import calculate_model_losses

In [9]:
# Instantiate the decoder model from the dataset vocabulary. This constructs a
# new model; no checkpoint is restored in this cell. embedding_dim=64 matches
# the `embedding_dim` value shown in the options dump.
ovaed = OriVAEDecoder(dt.vocab, embedding_dim=64)

In [10]:
# Adam over all parameters of the decoder model, using the learning rate from
# the parsed options.
# NOTE(review): the optimizer is created before the model is moved to the GPU
# in the next cell; torch.optim documentation has historically advised moving
# the model first — confirm this ordering is fine for the PyTorch version used.
optimizer = torch.optim.Adam(ovaed.parameters(), lr=args.learning_rate)

In [11]:
# Move the model to the GPU (requires a CUDA device to be available).
ovaed = ovaed.cuda()

In [12]:
# Global training-step counter: incremented once per batch by the training
# loop and used as the TensorBoard step. Re-running this cell resets it.
t = 0 # total steps

In [13]:
# training: one pass over train_loader, fitting the decoder on random latents
ovaed.train()
for batch in tqdm(train_loader):
    t += 1

    # Clear gradients left over from the previous step before the new forward pass.
    optimizer.zero_grad()

    (ids, objs, boxes, triples,
     angles, attributes, obj_to_img, triple_to_img) = tensor_aug(batch)

    # One 64-dim standard-normal latent per object, moved to the objects' device.
    z = torch.randn(objs.size(0), 64)
    z = z.to(objs.device)
    boxes_pred, angles_pred = ovaed.decoder(z, objs, triples, attributes)

    total_loss, losses = calculate_model_losses(
        args, ovaed, boxes_pred, boxes, angles, angles_pred
    )

    total_loss.backward()
    optimizer.step()

    # Only report every `print_every` steps.
    if t % args.print_every != 0:
        continue

    print("On batch {} out of {}".format(t, args.num_iterations))
    for name, val in losses.items():
        print(' [%s]: %.4f' % (name, val))
        writer.add_scalar('Loss/'+ name, val, t)
    

  0%|          | 0/3367 [00:00<?, ?it/s]

On batch 100 out of 60000
 [bbox_pred]: 0.2670
 [angle_pred]: 1.4085
On batch 200 out of 60000
 [bbox_pred]: 0.2327
 [angle_pred]: 1.5395
On batch 300 out of 60000
 [bbox_pred]: 0.2356
 [angle_pred]: 1.8668
On batch 400 out of 60000
 [bbox_pred]: 0.2043
 [angle_pred]: 1.4522
On batch 500 out of 60000
 [bbox_pred]: 0.1894
 [angle_pred]: 1.2239
On batch 600 out of 60000
 [bbox_pred]: 0.1893
 [angle_pred]: 1.2829
On batch 700 out of 60000
 [bbox_pred]: 0.1986
 [angle_pred]: 1.1789
On batch 800 out of 60000
 [bbox_pred]: 0.1813
 [angle_pred]: 1.4180
On batch 900 out of 60000
 [bbox_pred]: 0.1820
 [angle_pred]: 1.1814
On batch 1000 out of 60000
 [bbox_pred]: 0.1759
 [angle_pred]: 1.1798
On batch 1100 out of 60000
 [bbox_pred]: 0.1923
 [angle_pred]: 1.3681
On batch 1200 out of 60000
 [bbox_pred]: 0.1686
 [angle_pred]: 1.3169
On batch 1300 out of 60000
 [bbox_pred]: 0.1704
 [angle_pred]: 1.6263
On batch 1400 out of 60000
 [bbox_pred]: 0.1708
 [angle_pred]: 1.3382
On batch 1500 out of 60000
 [

In [14]:
# Validation-step counter: incremented once per validation batch and used as
# the TensorBoard step for the validation losses. Re-running this cell resets it.
valid_t = 0

In [15]:
# validation: one pass over val_loader, logging the losses without updating
# any weights
ovaed.eval()
# Nothing is back-propagated during validation, so skip building the autograd
# graph (saves memory and time).
with torch.no_grad():
    for batch in tqdm(val_loader):
        valid_t += 1

        ids, objs, boxes, triples, angles, attributes, obj_to_img, triple_to_img = tensor_aug(batch)
        # One 64-dim standard-normal latent per object, on the objects' device.
        z = torch.randn(objs.size(0), 64).to(objs.device)
        boxes_pred, angles_pred = ovaed.decoder(z, objs, triples, attributes)

        total_loss, losses = calculate_model_losses(args, ovaed, boxes_pred, boxes, angles, angles_pred)

        # Gate and label the report with the validation counter. The training
        # counter `t` never changes inside this loop, so testing it would make
        # the report fire on every batch or on none.
        if valid_t % args.print_every == 0:
            print("On batch {} out of {}".format(valid_t, len(val_loader)))
            for name, val in losses.items():
                print(' [%s]: %.4f' % (name, val))
                writer.add_scalar('Loss/Validation/'+ name, val, valid_t)
    

  0%|          | 0/842 [00:00<?, ?it/s]