In [None]:
import os

import numpy as np
import torch
import torch.nn.functional as F
from tqdm.auto import tqdm

In [None]:
import argparse
from torchvision.transforms import Compose
#from pytorch_lightning import Trainer, seed_everything
import torch
from torch.utils.data import DataLoader, Subset
from transforms.scene import (
    SeqToTensor,
    Padding_shift_ori_model,
    Augment_rotation,
    Augment_jitterring,
    Get_cat_shift_info,
)


In [None]:
from datasets.suncg_shift_seperate_dataset_deepsynth import SUNCG_Dataset
from separate_models.scene_shift_ori_col import scene_transformer
#from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from utils.config import read_config

#seed_everything(1)

In [None]:
# Resolve the config file path through argparse and load it.
# `cfg_path` is an optional positional with a default, and an explicit empty
# argument list is parsed so the kernel's own command line is never consumed
# (parse_args expects a list of strings, not a bare string).
parser = argparse.ArgumentParser()
parser.add_argument(
    "cfg_path",
    nargs="?",
    default="configs/scene_shift_ori_config.yaml",
    help="Path to config file",
)
args = parser.parse_args([])

# Configuration object used by the transforms, dataset and model cells below.
cfg = read_config(args.cfg_path)

In [None]:
# Override the dataset location from the config file with a machine-local path.
# NOTE(review): absolute path is specific to one machine; adjust before running elsewhere.
cfg["data"]["data_path"] = "/home/ubuntu/research/suncg/bedroom"

In [None]:
# Per-sample transform pipeline, applied in the order listed:
# rotation and jitter augmentation first, then the config-driven category/shift
# and padding steps, and finally conversion to tensors.
t = Compose([
    Augment_rotation(cfg["train"]["aug"]["rotation_list"]),
    Augment_jitterring(cfg["train"]["aug"]["jitter_list"]),
    Get_cat_shift_info(cfg),
    Padding_shift_ori_model(cfg),
    SeqToTensor(),
])

In [None]:
# Build the full dataset from the configured folder, applying the pipeline `t`
# to every sample.
trainval_set = SUNCG_Dataset(
    data_folder=cfg["data"]["data_path"],
    transform=t,
)
# presumably selects the orientation ("ori") training mode of the dataset — TODO confirm
trainval_set.train_type = "ori"

In [None]:
# Sequential 80/20 train/validation split over the dataset indices.
# NOTE(review): the last two dataset entries are excluded from both subsets;
# the reason is not visible in this notebook — confirm.
total_len = len(trainval_set) - 2
train_len = int(0.8 * total_len)

train_set = Subset(trainval_set, range(0, train_len))
val_set = Subset(trainval_set, range(train_len, total_len))

# NOTE(review): both subsets wrap the same dataset object, so the augmentation
# transforms in `t` are applied to validation samples as well.
train_loader = DataLoader(
    dataset=train_set,
    batch_size=cfg["train"]["batch_size"],
    shuffle=True,
    num_workers=4,
)
val_loader = DataLoader(
    dataset=val_set,
    batch_size=cfg["train"]["batch_size"],
    num_workers=4,
)


In [None]:
# Sanity check: display the first training sample after the transform pipeline.
train_set[0]

In [None]:
# Instantiate the model from the loaded configuration.
model = scene_transformer(cfg)

In [None]:
# Move the model's parameters and buffers onto the current CUDA device.
model = model.to("cuda")

In [None]:
# Adam over every model parameter; the learning rate and L2 weight decay are
# read from the "train" section of the model's own config.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=model.cfg["train"]["lr"],
    weight_decay=model.cfg["train"]["l2"],
)

In [None]:
# Training-loop settings.
total_epoch = 30   # maximum number of epochs to run
print_every = 30   # report the running training loss every N batches

# Start from +inf so the first epoch's validation loss always counts as an
# improvement and a checkpoint is always written, whatever the loss scale.
best_valid_loss = float("inf")

# Early stopping: number of consecutive non-improving epochs tolerated, and
# the counter tracking the current streak.
in_patience = 5
current_patience = 0

In [None]:
# Train the model on the orientation loss with early stopping on the mean
# validation loss. The best weights (lowest validation loss so far) are saved
# to `checkpoint_path`.
device = torch.device("cuda:0")
pad_token = model.cfg["model"]["orient"]["pad_token"]  # target value ignored by the loss
checkpoint_path = "records/ori_best_valid.pth"
# torch.save does not create missing parent directories, so make sure the
# checkpoint folder exists before the first save.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

for epoch in range(total_epoch):
    print("Training epoch: ", epoch)

    # ---- train ----
    model.train()
    train_loss_list = []
    # Wrap the loader (not the enumerate object) so tqdm knows the total length.
    for i, batch in enumerate(tqdm(train_loader, desc="train")):
        cat_seq, x_loc_seq, y_loc_seq, z_loc_seq, orient_seq, room_shape = (
            tensor.to(device) for tensor in batch
        )

        logprobs_ori = model(cat_seq, x_loc_seq, y_loc_seq, z_loc_seq, orient_seq, room_shape=room_shape)
        # The model output is transposed on dims 1 and 2 to put the class
        # dimension where nll_loss expects it; targets are the orientation
        # sequence without its first position.
        loss_ori = F.nll_loss(
            logprobs_ori.transpose(1, 2),
            orient_seq[:, 1:],
            ignore_index=pad_token,
        )

        optimizer.zero_grad()
        loss_ori.backward()
        optimizer.step()

        train_loss_list.append(loss_ori.item())
        if i % print_every == 0:
            # Mean over the batches seen since the previous report.
            print("loss: ", np.mean(train_loss_list))
            train_loss_list.clear()

    # ---- validate ----
    model.eval()
    val_loss_list = []
    # No gradients are needed for validation; skipping the autograd graph
    # saves memory and time.
    with torch.no_grad():
        for i, batch in enumerate(tqdm(val_loader, desc="valid")):
            cat_seq, x_loc_seq, y_loc_seq, z_loc_seq, orient_seq, room_shape = (
                tensor.to(device) for tensor in batch
            )

            logprobs_ori = model(cat_seq, x_loc_seq, y_loc_seq, z_loc_seq, orient_seq, room_shape=room_shape)
            loss_ori = F.nll_loss(
                logprobs_ori.transpose(1, 2),
                orient_seq[:, 1:],
                ignore_index=pad_token,
            )

            val_loss_list.append(loss_ori.item())

    mean_valid_loss = np.mean(val_loss_list)
    print("Validation loss: ", mean_valid_loss)

    if mean_valid_loss < best_valid_loss:
        # New best: checkpoint the weights and reset the patience counter.
        best_valid_loss = mean_valid_loss
        torch.save(model.state_dict(), checkpoint_path)
        current_patience = 0
    else:
        current_patience += 1
        # Stop once the non-improving streak exceeds `in_patience` epochs.
        if current_patience > in_patience:
            break