In [1]:
import os

import numpy as np
import torch
import torch.nn.functional as F
from tqdm.auto import tqdm

In [2]:
# from pytorch_lightning import Trainer, seed_everything

# from pytorch_lightning.callbacks import ModelCheckpoint

In [3]:
import argparse
from torchvision.transforms import Compose
import torch
from torch.utils.data import DataLoader, Subset

In [4]:
from transforms.scene import (
    SeqToTensor,
    Augment_rotation,
    Augment_jitterring,
    Get_cat_shift_info,
    Padding_joint,
    Add_Relations,
    Add_Descriptions,
    Add_Glove_Embeddings,
)


In [5]:
from datasets.suncg_shift_seperate_dataset_deepsynth import SUNCG_Dataset

In [6]:
from separate_models.scene_shift_cat import scene_transformer

In [7]:
from utils.config import read_config

In [8]:
# Mirror the CLI script's `args` object inside the notebook: parse an empty argv,
# then fill in the config path by hand instead of reading it from the command line.
parser = argparse.ArgumentParser()
args = parser.parse_args([])

args.cfg_path = "configs/scene_shift_cat_config.yaml"
cfg = read_config(args.cfg_path)

In [10]:
# Display the loaded configuration for inspection.
cfg

{'data': {'data_path': '/shared/data/new_room_data/bedroom_new_multi_img',
  'out_path': 'tests/data/scene_outputs',
  'list_path': 'None'},
 'model': {'cat': {'shuffle': False,
   'shape_cond': True,
   'text_cond': False,
   'start_token': 52,
   'stop_token': 51,
   'pad_token': 50},
  'coor': {'shape_cond': True,
   'text_cond': False,
   'start_token': 203,
   'stop_token': 202,
   'pad_token': 201},
  'orient': {'shape_cond': True,
   'text_cond': False,
   'start_token': 363,
   'stop_token': 362,
   'pad_token': 361},
  'relation': {'start_token': 4, 'stop_token': 3, 'pad_token': 2},
  'dim': {'shape_cond': True,
   'start_token': 83,
   'stop_token': 82,
   'pad_token': 81},
  'max_seq_len': 40,
  'max_obj_num': 100,
  'cat_num': 28,
  'emb_dim': 256,
  'dim_fwd': 256,
  'num_heads': 8,
  'num_blocks': 8,
  'dropout': 0.3},
 'text_model': {'max_seq_len': 100,
  'num_heads': 16,
  'num_blocks': 6,
  'dropout': 0.3,
  'voc': 120,
  'pad_token': 0},
 'train': {'aug': {'jitter_lis

In [11]:
# Point the config at this machine's copy of the dataset. Set SUNCG_DATA_PATH to use a
# different location without editing the notebook; otherwise the original path is kept.
cfg['data']['data_path'] = os.environ.get("SUNCG_DATA_PATH", "/home/ubuntu/research/suncg/bedroom")

# Run training

In [12]:
# Augmentations (rotation, jitter) followed by category shift-info extraction,
# all parameterised from the training config.
transforms = [
    Augment_rotation(cfg['train']['aug']['rotation_list']),
    Augment_jitterring(cfg['train']['aug']['jitter_list']),
    Get_cat_shift_info(cfg),
]

In [13]:
# Assemble the full preprocessing pipeline. A new list is built rather than appending
# to `transforms` in place, so re-running this cell does not stack duplicate
# Padding_joint / SeqToTensor steps onto the pipeline.
transform_list = list(transforms)
if cfg["model"]["cat"]["text_cond"]:
    # Extra text-related transforms, only added when the cat model is text-conditioned.
    transform_list += [
        Add_Relations(),
        Add_Descriptions(),
        Add_Glove_Embeddings(max_sentences=3, max_length=50),
    ]
# Padding and tensor conversion are appended last, after any optional text transforms.
transform_list += [Padding_joint(cfg), SeqToTensor()]
t = Compose(transform_list)

In [14]:
trainval_set = SUNCG_Dataset(
    data_folder=cfg['data']['data_path'],
    list_path=cfg['data']['list_path'],
    transform=t,
)
# NOTE(review): the last two samples are left out of both splits; the reason is not
# visible here — confirm whether they are known-bad entries.
total_len = len(trainval_set) - 2
# 80% of the usable samples go to training (truncated to an int).
train_len = int(0.8 * total_len)

In [15]:
# Contiguous split by index (no shuffling): [0, train_len) trains, [train_len, total_len) validates.
train_set = Subset(trainval_set, indices=range(0, train_len))
val_set = Subset(trainval_set, indices=range(train_len, total_len))

In [16]:
# Inspect one transformed training sample (a tuple of padded token tensors) as a sanity check.
train_set[3]

(tensor([52,  4,  0,  0,  1,  3,  8, 10, 11, 51, 50, 50, 50, 50, 50, 50, 50, 50,
         50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,
         50, 50, 50, 50]),
 tensor([203, 140,  57,  78, 164,  83, 103, 150,  51, 202, 201, 201, 201, 201,
         201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201,
         201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201]),
 tensor([203,  61,  61,  61,  61,  61,  61,  61,  61, 202, 201, 201, 201, 201,
         201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201,
         201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201]),
 tensor([203,  62, 129, 167, 124, 138,  36,  96, 121, 202, 201, 201, 201, 201,
         201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201,
         201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201, 201]),
 tensor([363, 120, 300, 300, 120, 300, 210, 120, 120, 362, 361, 361, 361, 361,
         361, 361, 361, 361, 361, 361, 361, 3

In [17]:
# Training batches are reshuffled every epoch; validation keeps dataset order.
# num_workers=0 loads batches in the main process.
train_loader = DataLoader(
    dataset=train_set,
    batch_size=cfg["train"]["batch_size"],
    shuffle=True,
    num_workers=0,
)
val_loader = DataLoader(
    dataset=val_set,
    batch_size=cfg["train"]["batch_size"],
    num_workers=0,
)

In [18]:
# Build the scene transformer from the loaded config.
model = scene_transformer(cfg)


Using shape conditioned model


In [19]:
# Move the model's parameters and buffers to the current CUDA device.
model = model.cuda()

In [20]:
# Adam with learning rate and weight decay (L2 penalty) taken from the model's training config.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=model.cfg["train"]["lr"],
    weight_decay=model.cfg["train"]["l2"],
)

In [21]:
# Training-loop settings.
total_epoch = 50                 # maximum number of epochs
print_every = 30                 # print the running training loss every N batches
best_valid_loss = float("inf")   # sentinel: any finite validation loss counts as an improvement

in_patience = 10                 # early-stopping patience, in epochs without improvement
current_patience = 0             # epochs since the last improvement

In [None]:
# Train the category model with per-epoch validation, best-checkpoint saving and early stopping.
os.makedirs("records", exist_ok=True)  # torch.save below fails if the directory is missing
device = torch.device("cuda:0")
pad_token = model.cfg["model"]["cat"]["pad_token"]


def _cat_loss(batch):
    """Return the category NLL loss for one batch.

    Moves the six tensors of `batch` (category, x/y/z location, orientation and
    room-shape sequences) to `device`, runs `model`, and scores its log-probabilities
    against `cat_seq[:, 1:]` (the category sequence without its first token).
    Positions equal to the category pad token are ignored.
    """
    cat_seq, x_loc_seq, y_loc_seq, z_loc_seq, orient_seq, room_shape = (
        tensor.to(device) for tensor in batch
    )
    logprobs_cat = model(
        cat_seq, x_loc_seq, y_loc_seq, z_loc_seq, orient_seq, room_shape=room_shape
    )
    return F.nll_loss(
        logprobs_cat.transpose(1, 2),
        cat_seq[:, 1:],
        ignore_index=pad_token,
    )


for epoch in range(total_epoch):
    print("Training epoch: ", epoch)

    # ---- train ----
    model.train()
    train_loss_list = []
    # Wrap the loader itself (not enumerate) so tqdm can read its length and show a total.
    for i, batch in enumerate(tqdm(train_loader, desc=f"train epoch {epoch}")):
        loss_cat = _cat_loss(batch)
        optimizer.zero_grad()
        loss_cat.backward()
        optimizer.step()

        train_loss_list.append(loss_cat.item())
        # Mean loss since the last print (the print at i == 0 covers only the first batch).
        if i % print_every == 0:
            print("loss: ", np.mean(train_loss_list))
            train_loss_list.clear()

    # ---- validate ----
    model.eval()
    val_loss_list = []
    # Validation needs no gradients; no_grad avoids building autograd graphs and saves memory.
    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"valid epoch {epoch}"):
            val_loss_list.append(_cat_loss(batch).item())

    mean_valid_loss = np.mean(val_loss_list)
    print("Validation loss: ", mean_valid_loss)

    if mean_valid_loss < best_valid_loss:
        best_valid_loss = mean_valid_loss
        torch.save(model.state_dict(), "records/cat_best_valid.pth")
        current_patience = 0
    else:
        current_patience += 1
        # Stop after `in_patience` consecutive epochs without improvement.
        if current_patience >= in_patience:
            break

Training epoch:  0


0it [00:00, ?it/s]

loss:  4.142786502838135
loss:  2.8436357816060385
loss:  2.2920062144597373


0it [00:00, ?it/s]

Validation loss:  2.006271427869797
Training epoch:  1


0it [00:00, ?it/s]

loss:  1.9522881507873535
loss:  1.9112616539001466
loss:  1.8216490149497986


0it [00:00, ?it/s]

Validation loss:  1.7733524978160857
Training epoch:  2


0it [00:00, ?it/s]

loss:  1.7373713692029318
loss:  1.6755924503008524


0it [00:00, ?it/s]

Validation loss:  1.6658772110939026
Training epoch:  3


0it [00:00, ?it/s]

loss:  1.5983136892318726
loss:  1.6298908869425455
loss:  1.6345850030581157


0it [00:00, ?it/s]

Validation loss:  1.6267554938793183
Training epoch:  4


0it [00:00, ?it/s]

loss:  1.4756146669387817
loss:  1.5875162879625957
loss:  1.593701926867167


0it [00:00, ?it/s]

Validation loss:  1.5968953311443328
Training epoch:  5


0it [00:00, ?it/s]

loss:  1.5466208457946777
loss:  1.5678273121515909
loss:  1.5807298064231872


0it [00:00, ?it/s]

Validation loss:  1.5863318383693694
Training epoch:  6


0it [00:00, ?it/s]

loss:  1.4992324113845825
loss:  1.5599275747934978
loss:  1.5410220185915628


0it [00:00, ?it/s]

Validation loss:  1.5677415132522583
Training epoch:  7


0it [00:00, ?it/s]

loss:  1.5015863180160522
loss:  1.5364260077476501
loss:  1.5391860405604045


0it [00:00, ?it/s]

Validation loss:  1.5654035568237306
Training epoch:  8


0it [00:00, ?it/s]

loss:  1.5544180870056152
loss:  1.528335444132487
loss:  1.534920831521352


0it [00:00, ?it/s]

Validation loss:  1.5812853157520295
Training epoch:  9


0it [00:00, ?it/s]

loss:  1.4767751693725586
loss:  1.5240188519159952
loss:  1.5211055874824524


0it [00:00, ?it/s]

Validation loss:  1.5533227801322937
Training epoch:  10


0it [00:00, ?it/s]

loss:  1.5114316940307617
loss:  1.5165903051694234
loss:  1.529936635494232


0it [00:00, ?it/s]

Validation loss:  1.553955191373825
Training epoch:  11


0it [00:00, ?it/s]

loss:  1.4984948635101318
loss:  1.5144067049026488
loss:  1.5200057824452717


0it [00:00, ?it/s]

Validation loss:  1.5414209008216857
Training epoch:  12


0it [00:00, ?it/s]

loss:  1.5532972812652588
loss:  1.5062721014022826
loss:  1.5145965496699014


0it [00:00, ?it/s]

Validation loss:  1.5441954672336577
Training epoch:  13


0it [00:00, ?it/s]

loss:  1.5500268936157227
loss:  1.5091872453689574
loss:  1.5078557173411051


0it [00:00, ?it/s]

Validation loss:  1.5362824559211732
Training epoch:  14


0it [00:00, ?it/s]

loss:  1.4238677024841309
loss:  1.5029184063275656
loss:  1.506085753440857


0it [00:00, ?it/s]

Validation loss:  1.531947875022888
Training epoch:  15


0it [00:00, ?it/s]

loss:  1.485878348350525
loss:  1.5014251232147218
loss:  1.5014480829238892


0it [00:00, ?it/s]

Validation loss:  1.547640508413315
Training epoch:  16


0it [00:00, ?it/s]

loss:  1.549347162246704
loss:  1.4890639265378316
loss:  1.5113752166430154


0it [00:00, ?it/s]

Validation loss:  1.5279921174049378
Training epoch:  17


0it [00:00, ?it/s]

loss:  1.4364923238754272
loss:  1.5084267934163411
loss:  1.4890702486038208


0it [00:00, ?it/s]

Validation loss:  1.597807091474533
Training epoch:  18


0it [00:00, ?it/s]

loss:  1.499051809310913
loss:  1.4870951573053997
loss:  1.498115070660909


0it [00:00, ?it/s]

Validation loss:  1.5316394090652465
Training epoch:  19


0it [00:00, ?it/s]

loss:  1.452535629272461
loss:  1.5022929668426515
loss:  1.4983164072036743


0it [00:00, ?it/s]

Validation loss:  1.5200016021728515
Training epoch:  20


0it [00:00, ?it/s]

loss:  1.4014338254928589
loss:  1.491713559627533
loss:  1.4972712715466818


0it [00:00, ?it/s]

Validation loss:  1.5292137920856477
Training epoch:  21


0it [00:00, ?it/s]

loss:  1.4967328310012817
loss:  1.4858655015627542
loss:  1.4935709396998087


0it [00:00, ?it/s]

Validation loss:  1.5338353216648102
Training epoch:  22


0it [00:00, ?it/s]

loss:  1.5727976560592651
loss:  1.4801759401957193
loss:  1.5063881754875184


0it [00:00, ?it/s]

Validation loss:  1.5314603984355926
Training epoch:  23


0it [00:00, ?it/s]

loss:  1.505948781967163
loss:  1.4969555298487345
loss:  1.4858598748842875


0it [00:00, ?it/s]

Validation loss:  1.54506014585495
Training epoch:  24


0it [00:00, ?it/s]

loss:  1.4926759004592896
loss:  1.4868239601453146
loss:  1.489581290880839


0it [00:00, ?it/s]

Validation loss:  1.6184792876243592
Training epoch:  25


0it [00:00, ?it/s]

loss:  1.528479814529419
loss:  1.4966655929883321
loss:  1.4873021801312765


0it [00:00, ?it/s]

Validation loss:  1.53280468583107
Training epoch:  26


0it [00:00, ?it/s]

loss:  1.514397144317627
loss:  1.494001070658366
