In [1]:
import os

import numpy as np
import torch
import torch.nn.functional as F
from tqdm.auto import tqdm

In [2]:
import argparse
from torchvision.transforms import Compose
#from pytorch_lightning import Trainer, seed_everything
import torch
from torch.utils.data import DataLoader, Subset
from transforms.scene import (
    SeqToTensor,
    Augment_rotation,
    Augment_jitterring,
    Get_dim_shift_info,
    Padding_shift_dim_model,
)
from datasets.suncg_shift_seperate_dataset_deepsynth import SUNCG_Dataset
from separate_models.scene_shift_dim import scene_transformer
#from pytorch_lightning.callbacks import ModelCheckpoint
from utils.config import read_config

In [3]:
# Build an empty argument namespace. parse_args gets an explicit empty list
# because inside a Jupyter kernel sys.argv holds the kernel's own launch
# arguments, not arguments intended for this notebook.
parser = argparse.ArgumentParser()
args = parser.parse_args([])

# The config path is fixed here instead of being read from the command line.
args.cfg_path = "configs/scene_shift_dim_config.yaml"
cfg = read_config(args.cfg_path)

In [4]:
# Display the loaded configuration: a nested dict of data / model / train / test settings.
cfg

{'data': {'data_path': '/shared/data/new_room_data/bedroom_new_window_door_Img',
  'out_path': 'tests/data/scene_outputs',
  'list_path': 'None'},
 'model': {'cat': {'start_token': 52, 'stop_token': 51, 'pad_token': 50},
  'coor': {'start_token': 203, 'stop_token': 202, 'pad_token': 201},
  'orient': {'start_token': 363, 'stop_token': 362, 'pad_token': 361},
  'relation': {'start_token': 4, 'stop_token': 3, 'pad_token': 2},
  'dim': {'start_token': 83,
   'stop_token': 82,
   'pad_token': 81,
   'shape_cond': True},
  'max_seq_len': 80,
  'max_obj_num': 80,
  'cat_num': 28,
  'emb_dim': 128,
  'dim_fwd': 128,
  'num_heads': 4,
  'num_blocks': 8,
  'dropout': 0.3},
 'train': {'aug': {'jitter_list': [0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3],
   'rotation_list': [0, 30, 60, 90, 120, 150, 180, 210, 240, 270]},
  'batch_size': 128,
  'epochs': 2000,
  'lr': 0.0003,
  'resume': None,
  'total_len': 1142,
  'train_len': 950,
  'l2': 0.001,
  'lr_restart': 10000,
  'warmup': 2000},
 'test': {'num

In [5]:
# Point the dataset at the local SUNCG bedroom data, overriding the data_path
# from the YAML config. Set the SUNCG_BEDROOM_DIR environment variable to use a
# different location without editing this cell; the default is the original path.
cfg['data']['data_path'] = os.environ.get(
    "SUNCG_BEDROOM_DIR", "/home/ubuntu/research/suncg/bedroom"
)

In [6]:
# NOTE(review): no cell visible here reads cfg["train"]["batch"]. The DataLoaders
# below use cfg["train"]["batch_size"] (128 in the loaded config), so this line
# only matters if project code that receives `cfg` reads the "batch" key. Verify.
# If a training batch size of 4 was intended, set "batch_size" instead.
cfg["train"]["batch"] = 4

# Training

In [7]:
# Per-sample training preprocessing. Compose applies the steps left to right:
# 1. rotation and jitter augmentation, parameterised by the lists in cfg["train"]["aug"];
# 2. the project's dim-shift info and padding steps;
# 3. conversion to tensors.
t = Compose([
    Augment_rotation(cfg["train"]["aug"]["rotation_list"]),
    Augment_jitterring(cfg["train"]["aug"]["jitter_list"]),
    Get_dim_shift_info(cfg),
    Padding_shift_dim_model(cfg),
    SeqToTensor(),
])

In [8]:
# SUNCG dataset configured to produce samples for the "dim" model.
# NOTE(review): cfg["data"]["list_path"] is the *string* 'None' in the loaded
# config, not Python None. Confirm that SUNCG_Dataset handles this.
trainval_set = SUNCG_Dataset(
    data_folder=cfg["data"]["data_path"],
    list_path=cfg["data"]["list_path"],
    transform=t,
)
trainval_set.train_type = "dim"

# Deterministic, unshuffled 80/20 split by index.
# NOTE(review): the last two samples are excluded from both splits, for a reason
# not visible here. cfg["train"] also has 'total_len'/'train_len' (1142/950),
# which this cell ignores. Confirm which split is intended.
total_len = len(trainval_set) - 2
train_len = int(0.8 * total_len)
train_set, val_set = (
    Subset(trainval_set, range(train_len)),
    Subset(trainval_set, range(train_len, total_len)),
)

# Both loaders use the same batch size and load in the main process
# (num_workers=0). Only the training loader reshuffles every epoch.
loader_kwargs = dict(batch_size=cfg["train"]["batch_size"], num_workers=0)
train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_set, **loader_kwargs)


In [9]:
# Sanity check on one training sample. Element 0 is the category sequence, going by
# how the training loop unpacks batches. It has length 80, matching
# cfg["model"]["max_seq_len"]; presumably that is the padded sequence length.
train_set[99][0].shape

torch.Size([80])

In [10]:
# Build the scene-transformer model for the "dim" task (presumably object
# dimensions) from the config. It prints "Using shape cond model", presumably
# because cfg["model"]["dim"]["shape_cond"] is True, i.e. the model is conditioned
# on room shape. Confirm in separate_models/scene_shift_dim.
model = scene_transformer(cfg)

Using shape cond model


In [11]:
# Put the model on the same explicit device ("cuda:0") that the training loop moves
# every batch to. model.cuda() would use the *current* CUDA device instead, which
# only matches cuda:0 as long as the current device has not been changed.
model = model.to(torch.device("cuda:0"))

In [12]:
# Adam optimiser. The learning rate and weight decay come from the config
# attached to the model (model.cfg["train"]["lr"] and ["l2"]). torch.optim.Adam's
# weight_decay is a classic L2 penalty added to the gradient, not AdamW-style
# decoupled decay.
train_hparams = model.cfg["train"]
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=train_hparams["lr"],
    weight_decay=train_hparams["l2"],
)

In [16]:
# Training-run settings.
total_epoch = 30  # number of passes over train_loader
print_every = 30  # print the running mean training loss every N batches
# Best validation loss seen so far. Starting at +inf guarantees the first real
# validation loss counts as an improvement. A finite sentinel such as 100 would
# silently skip best-checkpointing whenever the loss started above it.
best_valid_loss = float("inf")

In [17]:
# Early-stopping state for the training loop's validation step, which is
# currently disabled. Training stops once the validation loss fails to improve
# for more than `in_patience` consecutive epochs.
in_patience, current_patience = 10, 0

In [None]:
# Train the dimension model for `total_epoch` epochs.
# After every epoch the current weights are written to records/dim_best_valid_new.pth.
# Despite its name, that file is the *latest* checkpoint, not the best on validation.
# Validation, best-checkpointing and early stopping are available behind
# RUN_VALIDATION. It defaults to False, which matches the previous setup where
# that code was commented out.
device = torch.device("cuda:0")
dim_pad_token = model.cfg["model"]["dim"]["pad_token"]  # loop-invariant, read once
RUN_VALIDATION = False
os.makedirs("records", exist_ok=True)  # avoid failing at the first save, after a full epoch


def _dim_loss(batch):
    """Move one batch to `device`, run the model and return the NLL loss on dim tokens.

    The targets are dim_seq[:, 1:], i.e. shifted by one position (presumably
    next-token prediction). Positions equal to the dim pad token are ignored.
    """
    cat_seq, x_loc_seq, y_loc_seq, z_loc_seq, orient_seq, dim_seq, room_shape = (
        tensor.to(device) for tensor in batch
    )
    logprobs_dim = model(
        cat_seq, x_loc_seq, y_loc_seq, z_loc_seq, orient_seq, dim_seq,
        room_shape=room_shape,
    )
    # nll_loss expects the class dimension second, hence the transpose.
    return F.nll_loss(
        logprobs_dim.transpose(1, 2),
        dim_seq[:, 1:],
        ignore_index=dim_pad_token,
    )


for epoch in range(total_epoch):
    print("Training epoch: ", epoch)
    # train
    model.train()
    train_loss_list = []
    # Wrap the loader itself (not enumerate(...)) so tqdm can use len() and show a total.
    for i, batch in enumerate(tqdm(train_loader)):
        loss_dim = _dim_loss(batch)

        optimizer.zero_grad()
        loss_dim.backward()
        optimizer.step()

        train_loss_list.append(loss_dim.item())
        if i % print_every == 0:
            print("loss: ", np.mean(train_loss_list))
            train_loss_list.clear()

    torch.save(model.state_dict(), "records/dim_best_valid_new.pth")

    if not RUN_VALIDATION:
        continue

    # eval
    model.eval()
    val_loss_list = []
    with torch.no_grad():  # no autograd graph is needed for evaluation
        for batch in tqdm(val_loader):
            val_loss_list.append(_dim_loss(batch).item())

    mean_valid_loss = np.mean(val_loss_list)
    print("Validation loss: ", mean_valid_loss)

    if mean_valid_loss < best_valid_loss:
        best_valid_loss = mean_valid_loss
        torch.save(model.state_dict(), "records/dim_best_valid.pth")
        current_patience = 0
    else:
        current_patience += 1
        if current_patience > in_patience:
            break

Training epoch:  0


0it [00:00, ?it/s]

loss:  1.865207552909851
loss:  1.792915940284729
Training epoch:  1


0it [00:00, ?it/s]

loss:  1.807572603225708
loss:  1.7876659433046977
Training epoch:  2


0it [00:00, ?it/s]

loss:  1.7798100709915161
loss:  1.788124144077301
Training epoch:  3


0it [00:00, ?it/s]

loss:  1.8014047145843506
loss:  1.779676330089569
Training epoch:  4


0it [00:00, ?it/s]

loss:  1.7924044132232666
loss:  1.7774298508961996
Training epoch:  5


0it [00:00, ?it/s]

loss:  1.776291847229004
loss:  1.7733834584554036
Training epoch:  6


0it [00:00, ?it/s]

loss:  1.80858314037323
loss:  1.7738741477330526
Training epoch:  7


0it [00:00, ?it/s]

loss:  1.7342982292175293
loss:  1.7639914313952128
Training epoch:  8


0it [00:00, ?it/s]

loss:  1.7437413930892944
loss:  1.761853567759196
Training epoch:  9


0it [00:00, ?it/s]

loss:  1.7373099327087402
loss:  1.7612302343050639
Training epoch:  10


0it [00:00, ?it/s]

loss:  1.756481409072876
loss:  1.764276913801829
Training epoch:  11


0it [00:00, ?it/s]

loss:  1.8046578168869019
loss:  1.7474348425865174
Training epoch:  12


0it [00:00, ?it/s]

loss:  1.7558350563049316
loss:  1.7489638050397238
Training epoch:  13


0it [00:00, ?it/s]

loss:  1.7440974712371826
loss:  1.7477061430613199
Training epoch:  14


0it [00:00, ?it/s]

loss:  1.7465617656707764
loss:  1.7525884628295898
Training epoch:  15


0it [00:00, ?it/s]

loss:  1.728040337562561
loss:  1.7391502459843953
Training epoch:  16


0it [00:00, ?it/s]

loss:  1.7750576734542847
loss:  1.7409571766853333
Training epoch:  17


0it [00:00, ?it/s]

loss:  1.6904557943344116
loss:  1.7387949705123902
Training epoch:  18


0it [00:00, ?it/s]

loss:  1.7879306077957153
