In [31]:
import torch
import torch.nn.functional as F

In [2]:
# from pytorch_lightning import Trainer, seed_everything

# from pytorch_lightning.callbacks import ModelCheckpoint

In [3]:
import argparse
import os

import torch
from torch.utils.data import DataLoader, Subset
from torchvision.transforms import Compose

In [4]:
from transforms.scene import (
    SeqToTensor,
    Augment_rotation,
    Augment_jitterring,
    Get_cat_shift_info,
    Padding_joint,
    Add_Relations,
    Add_Descriptions,
    Add_Glove_Embeddings,
)


In [5]:
from datasets.suncg_shift_seperate_dataset_deepsynth import SUNCG_Dataset

In [6]:
from separate_models.scene_shift_cat import scene_transformer

In [7]:
from utils.config import read_config

In [8]:
# No arguments are registered on this parser, so parsing an empty argument list
# just yields an empty Namespace that the following cell attaches settings to.
# The positional argument below is left disabled:
# parser.add_argument("cfg_path", help="Path to config file", default="configs/scene_shift_cat_config.yaml")
parser = argparse.ArgumentParser()
args = parser.parse_args([])

In [9]:
# Record the config file location on the namespace, then load that file with
# read_config; the next cell displays the result as a nested dict.
args.cfg_path = "configs/scene_shift_cat_config.yaml"
cfg = read_config(args.cfg_path)

In [10]:
# Echo the loaded configuration so the active settings are visible in the notebook.
cfg

{'data': {'data_path': '/shared/data/new_room_data/bedroom_new_multi_img',
  'out_path': 'tests/data/scene_outputs',
  'list_path': 'None'},
 'model': {'cat': {'shuffle': False,
   'shape_cond': True,
   'text_cond': False,
   'start_token': 52,
   'stop_token': 51,
   'pad_token': 50},
  'coor': {'shape_cond': True,
   'text_cond': False,
   'start_token': 203,
   'stop_token': 202,
   'pad_token': 201},
  'orient': {'shape_cond': True,
   'text_cond': False,
   'start_token': 363,
   'stop_token': 362,
   'pad_token': 361},
  'relation': {'start_token': 4, 'stop_token': 3, 'pad_token': 2},
  'dim': {'shape_cond': True,
   'start_token': 83,
   'stop_token': 82,
   'pad_token': 81},
  'max_seq_len': 40,
  'max_obj_num': 100,
  'cat_num': 28,
  'emb_dim': 256,
  'dim_fwd': 256,
  'num_heads': 8,
  'num_blocks': 8,
  'dropout': 0.3},
 'text_model': {'max_seq_len': 100,
  'num_heads': 16,
  'num_blocks': 6,
  'dropout': 0.3,
  'voc': 120,
  'pad_token': 0},
 'train': {'aug': {'jitter_lis

In [11]:
# Point the dataset at this machine's copy of the data instead of the path stored
# in the config file. The location can be changed without editing the notebook by
# exporting SUNCG_DATA_PATH; when that variable is unset, the previous hard-coded
# location is used unchanged.
cfg['data']['data_path'] = os.environ.get(
    "SUNCG_DATA_PATH", "/home/ubuntu/research/suncg/bedroom"
)

# run_training

In [12]:
# Augmentation stage of the per-sample pipeline: rotation and jitter are driven by
# the lists under cfg['train']['aug'], followed by Get_cat_shift_info.
transforms = [
    Augment_rotation(cfg['train']['aug']['rotation_list']),
    Augment_jitterring(cfg['train']['aug']['jitter_list']),
    Get_cat_shift_info(cfg),
]

In [13]:
# Assemble the complete per-sample pipeline on top of the augmentation steps held
# in `transforms`. A fresh list is built instead of appending to `transforms` in
# place, so re-running this cell cannot stack duplicate padding / tensor-conversion
# steps onto the pipeline. `transforms` itself keeps only the augmentation steps.
pipeline_steps = list(transforms)

# The text-related transforms are only added when the category model is
# configured with text conditioning.
if cfg["model"]["cat"]["text_cond"]:
    pipeline_steps += [
        Add_Relations(),
        Add_Descriptions(),
        Add_Glove_Embeddings(max_sentences=3, max_length=50),
    ]

# Padding and tensor conversion always run last.
pipeline_steps += [Padding_joint(cfg), SeqToTensor()]
t = Compose(pipeline_steps)

In [14]:
# Full dataset, with the composed transform pipeline applied to each sample.
trainval_set = SUNCG_Dataset(
    data_folder=cfg['data']['data_path'],
    list_path=cfg['data']['list_path'],
    transform=t,
)

# The last two samples are left out of the split sizes (NOTE(review): the reason
# is not evident from this notebook — confirm); 80% of the rest goes to training.
total_len = len(trainval_set) - 2
train_len = int(0.8 * total_len)

In [15]:
# Contiguous split by index: the first `train_len` samples are used for training,
# the following ones (up to `total_len`) for validation.
train_set = Subset(trainval_set, range(0, train_len))
val_set = Subset(trainval_set, range(train_len, total_len))

In [16]:
# Spot-check one transformed sample. In the recorded output each sequence opens
# with its start token, ends with a stop token, and is padded out to 40 entries.
train_set[3]

(tensor([52,  4,  0,  0,  1,  3,  8, 10, 11, 51, 50, 50, 50, 50, 50, 50, 50, 50,
         50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,
         50, 50, 50, 50]),
 tensor([203, 150,  50,  28, 109,  55, 154, 126,  55, 202, 201, 201, 201, 201,
         201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201,
         201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201]),
 tensor([203,  56,  57,  57,  57,  57,  57,  57,  57, 202, 201, 201, 201, 201,
         201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201,
         201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201]),
 tensor([203, 114,  75, 114, 166, 103,  69, 140,  67, 202, 201, 201, 201, 201,
         201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201,
         201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201]),
 tensor([363,  60, 240, 240,  60, 240, 150,  60,  60, 362, 361, 361, 361, 361,
         361, 361, 361, 361, 361, 361, 361, 3

In [17]:
# Training batches are shuffled; loading happens in the main process
# (num_workers=0).
train_loader = DataLoader(
    dataset=train_set,
    batch_size=cfg["train"]["batch_size"],
    shuffle=True,
    num_workers=0,
)

# Validation uses the same batch size; shuffle is left at its default.
val_loader = DataLoader(
    dataset=val_set,
    batch_size=cfg["train"]["batch_size"],
    num_workers=0,
)

In [18]:
# Pull a single batch from the training loader for inspection.
batch = next(iter(train_loader))

In [19]:
# Shape of the first tensor in the batch (torch.Size([64, 40]) in the recorded run).
batch[0].shape

torch.Size([64, 40])

In [20]:
# Instantiate the model from the loaded configuration.
model = scene_transformer(cfg)


Using shape conditioned model


In [21]:
# Move the model onto the GPU; this cell requires a CUDA-capable device.
model = model.cuda()

In [33]:
# Adam over all model parameters. The learning rate and the weight-decay
# coefficient are read from the config attached to the model (`model.cfg`),
# under its "train" section.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=model.cfg["train"]["lr"],
    weight_decay=model.cfg["train"]["l2"],
)

In [34]:
# Smoke test of the optimisation step: run exactly one batch through the model,
# take one optimizer step on the category loss, print the loss, and stop.

# Use whichever device the model's parameters live on rather than a literal
# "cuda:0", so the inputs always end up on the same device as the model.
device = next(model.parameters()).device

for batch in train_loader:
    # The batch holds six tensors; move each of them to the model's device.
    cat_seq, x_loc_seq, y_loc_seq, z_loc_seq, orient_seq, room_shape = (
        tensor.to(device) for tensor in batch
    )

    logprobs_cat = model(
        cat_seq, x_loc_seq, y_loc_seq, z_loc_seq, orient_seq, room_shape=room_shape
    )

    # Targets are the category sequence without its leading start token, and
    # padded positions are excluded from the loss. nll_loss expects the class
    # dimension at index 1, hence the transpose (presumably logprobs_cat is
    # (batch, seq, classes) — TODO confirm against the model).
    loss_cat = F.nll_loss(
        logprobs_cat.transpose(1, 2),
        cat_seq[:, 1:],
        ignore_index=model.cfg["model"]["cat"]["pad_token"],
    )

    optimizer.zero_grad()
    loss_cat.backward()
    optimizer.step()

    print("loss: ", loss_cat.item())

    # Only a single step is wanted here.
    break
    

loss:  4.149105548858643
