In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

In [2]:
from transformers import BertModel, BertConfig
from transformers.models.bert.modeling_bert import *

In [3]:
from new.networks import *
from new.utils import *
from new.models import *

In [4]:
from options.options import Options
import os
import torch
from tqdm.auto import tqdm

In [5]:
from build_dataset_model import build_suncg_dsets
from data.suncg_dataset import SuncgDataset

In [6]:
args = Options().parse()
if (args.output_dir is not None) and (not os.path.isdir(args.output_dir)):
    os.mkdir(args.output_dir)
if (args.test_dir is not None) and (not os.path.isdir(args.test_dir)):
    os.mkdir(args.test_dir)

| options
dataset: suncg
suncg_train_dir: metadata/data_rot_train.json
suncg_val_dir: metadata/data_rot_val.json
suncg_data_dir: /home/yizhou/Research/SUNCG/suncg_data
loader_num_workers: 8
embedding_dim: 64
gconv_mode: feedforward
gconv_dim: 128
gconv_hidden_dim: 512
gconv_num_layers: 5
mlp_normalization: batch
vec_noise_dim: 0
layout_noise_dim: 32
batch_size: 16
num_iterations: 60000
eval_mode_after: -1
learning_rate: 0.0001
print_every: 100
checkpoint_every: 1000
snapshot_every: 10000
output_dir: ./checkpoints
checkpoint_name: latest_checkpoint
timing: False
multigpu: False
restore_from_checkpoint: False
checkpoint_start_from: None
test_dir: ./layouts_out
gpu_id: 0
KL_loss_weight: 0.1
use_AE: False
decoder_cat: True
train_3d: True
KL_linear_decay: False
use_attr_30: True
manual_seed: 42
batch_gen: False
measure_acc_l1_std: False
heat_map: False
draw_2d: False
draw_3d: False
fine_tune: False
gan_shade: False
blender_path: /home/yizhou/blender-2.92.0-linux64/blender



In [7]:
dset_kwargs = {
        'data_dir': args.suncg_train_dir,
        'train_3d': args.train_3d,
        'use_attr_30': args.use_attr_30,
    }

In [8]:
train_dset = SuncgDataset(**dset_kwargs)

Starting to read the json file for SUNCG
loading relation score matrix from:  new/relation_graph_v1.p


In [9]:
train_dset.vocab

{'object_idx_to_name': ['__room__',
  'curtain',
  'shower_curtain',
  'dresser',
  'counter',
  'bookshelf',
  'picture',
  'mirror',
  'floor_mat',
  'chair',
  'sink',
  'desk',
  'table',
  'lamp',
  'door',
  'clothes',
  'person',
  'toilet',
  'cabinet',
  'floor',
  'window',
  'blinds',
  'wall',
  'pillow',
  'whiteboard',
  'bathtub',
  'television',
  'night_stand',
  'sofa',
  'refridgerator',
  'bed',
  'shelves'],
 'object_name_to_idx': {'__room__': 0,
  'curtain': 1,
  'shower_curtain': 2,
  'dresser': 3,
  'counter': 4,
  'bookshelf': 5,
  'picture': 6,
  'mirror': 7,
  'floor_mat': 8,
  'chair': 9,
  'sink': 10,
  'desk': 11,
  'table': 12,
  'lamp': 13,
  'door': 14,
  'clothes': 15,
  'person': 16,
  'toilet': 17,
  'cabinet': 18,
  'floor': 19,
  'window': 20,
  'blinds': 21,
  'wall': 22,
  'pillow': 23,
  'whiteboard': 24,
  'bathtub': 25,
  'television': 26,
  'night_stand': 27,
  'sofa': 28,
  'refridgerator': 29,
  'bed': 30,
  'shelves': 31},
 'pred_idx_to_na

In [10]:
train_dset[0]

	nonzero()
Consider using one of the following signatures instead:
	nonzero(*, bool as_tuple) (Triggered internally at  /opt/conda/conda-bld/pytorch_1607370172916/work/torch/csrc/utils/python_arg_parser.cpp:882.)
  real_objs = (objs != __room__).nonzero().squeeze(1)


(4377,
 tensor([14, 18, 27, 26, 20,  9, 22, 26, 30,  3,  0]),
 tensor([[ 6.6050e-01,  7.2728e-03,  7.8242e-01,  9.3001e-01,  7.7818e-01,
           1.0000e+00],
         [-1.8332e-02,  6.7337e-01,  5.6516e-01,  1.9248e-01,  7.8162e-01,
           5.9611e-01],
         [ 4.6290e-02,  1.8182e-02,  3.2368e-01,  1.9460e-01,  2.1091e-01,
           4.4950e-01],
         [ 5.3706e-02,  2.1091e-01,  3.5881e-01,  1.7235e-01,  2.7273e-01,
           4.1805e-01],
         [ 7.1195e-01, -5.4443e-16,  5.0570e-01,  1.2310e+00,  5.0909e-01,
           5.5158e-01],
         [ 1.2971e-01,  1.8182e-02,  6.6184e-01,  4.3746e-01,  3.2920e-01,
           8.7699e-01],
         [ 8.4345e-01,  1.8182e-02,  4.0563e-02,  9.5468e-01,  1.0000e+00,
           1.1921e-01],
         [ 5.9412e-02,  6.3743e-01,  9.3840e-01,  5.4142e-01,  9.0634e-01,
           9.5150e-01],
         [ 2.2797e-01,  1.8182e-02, -1.0230e-01,  6.3582e-01,  2.9879e-01,
           4.4557e-01],
         [ 6.6919e-01,  1.8182e-02,  4.8725e-01

In [11]:
args.loader_num_workers

8

In [12]:
loader_kwargs = {
        'batch_size': args.batch_size,
        'num_workers': 1, #args.loader_num_workers,
        'shuffle': True,
        'collate_fn': new_collate_fn,
    }

In [13]:
train_loader = DataLoader(train_dset, **loader_kwargs)

In [14]:
fetgg = FromEncoderToGraphGenerator()

In [None]:
use_cuda = False

In [None]:
if use_cuda:
    fetgg = fetgg.cuda()

In [15]:
fetgg_optim = torch.optim.Adam(fetgg.parameters(),lr = 1e-4)

In [None]:
for c, batch in tqdm(enumerate(train_loader)):
    objs, boxes, angles, attention_mask = batch[0]
    if use_cuda:
        objs = objs.to(fetgg.device)
        boxes = boxes.to(fetgg.device)
        angles = angles.to(fetgg.device)
        attention_mask = attention_mask.to(fetgg.device)
    
    logits, loss = output = fetgg(objs, boxes, angles, attention_mask)
    
    fetgg_optim.zero_grad()
    loss.backward()
    fetgg_optim.step()
    
    if c % 100 == 0:
        print(loss.item())

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

3.5674777030944824
