In [None]:
import os
import torch
from torch.utils.data import DataLoader
from dataloaders import *
from scene_net import *
from loss import SceneNetLoss
from train import train

In [None]:
# Expose only GPU index 0 to CUDA for this kernel.
os.environ.update({'CUDA_VISIBLE_DEVICES': '0'})

In [None]:
# Run configuration: compute device, run name, checkpoint directory and dataset.
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Name of this run; used below for the log file name and the checkpoint file name.
network_name = 'network_name'
# Directory the checkpoint is later loaded from (passed to train() as `dest`).
dest = "path/to/save/model"
# Chosen dataset: one of "nyuv2_3", "cityscapes" or "taskonomy".
dataset = "nyuv2_3"

In [None]:
# Open (and truncate) the text log for this run.
# The per-dataset log directory is created first so that a fresh checkout
# does not fail with FileNotFoundError when `logs/<dataset>/` is missing.
log_path = f"logs/{dataset}/{network_name}.txt"
os.makedirs(os.path.dirname(log_path), exist_ok=True)
log_file = open(log_path, "w")

In [None]:
# Select the config module and dataset class that match `dataset`.
# Only the chosen config module is imported, exactly as before.
if dataset == "nyuv2_3":
    from config_nyuv2_3task import Config
    dataset_cls = NYU_v2
elif dataset == "cityscapes":
    from config_cityscapes import Config
    dataset_cls = CityScapes
elif dataset == "taskonomy":
    from config_taskonomy import Config
    dataset_cls = Taskonomy
else:
    # Raise instead of calling exit(): in a notebook exit() shuts the kernel
    # down, which hides the actual problem from the user.
    raise ValueError(
        f"Unrecognized Dataset Name: {dataset!r} "
        "(expected 'nyuv2_3', 'cityscapes' or 'taskonomy')."
    )

config = Config()

# Taskonomy is evaluated with the training batch size; the other datasets
# are evaluated one sample at a time.
test_batch_size = config.BATCH_SIZE if dataset == "taskonomy" else 1

# Training split: cropped to the configured size.
train_dataset = dataset_cls(config.DATA_ROOT, 'train', crop_h=config.CROP_H, crop_w=config.CROP_W)
train_loader = DataLoader(
    train_dataset,
    batch_size=config.BATCH_SIZE,
    num_workers=8,
    shuffle=True,
    pin_memory=True,
)

# Test split: built without crop arguments.
# NOTE(review): the test loader is shuffled, as in the original code; confirm
# whether evaluation relies on this before switching to shuffle=False.
test_dataset = dataset_cls(config.DATA_ROOT, 'test')
test_loader = DataLoader(
    test_dataset,
    batch_size=test_batch_size,
    num_workers=8,
    shuffle=True,
    pin_memory=True,
)

In [None]:
# Report how many samples each split contains.
for split_label, split_dataset in (("TrainDataset:", train_dataset), ("TestDataset:", test_dataset)):
    print(split_label, len(split_dataset))

In [None]:
# Instantiate SceneNet from config.TASKS_NUM_CLASS and config.BACKBONE_NAME,
# then move it to the selected device.
net = SceneNet(config.TASKS_NUM_CLASS, config.BACKBONE_NAME)
net = net.to(device)
# Disabled: wrapping the network in DataParallel over GPUs 0 and 1 for taskonomy.
# if dataset == "taskonomy":
#     net = nn.DataParallel(net, device_ids=[0, 1])

In [None]:
# Loss object, built from the dataset name and the task settings in `config`.
criterion = SceneNetLoss(
    dataset,
    config.TASKS,
    config.TASKS_NUM_CLASS,
    config.LAMBDAS,
    device,
    config.DATA_ROOT,
)

# Adam over all network parameters, with the configured initial LR and weight decay.
optimizer = torch.optim.Adam(
    net.parameters(),
    lr=config.INIT_LR,
    weight_decay=config.WEIGHT_DECAY,
)

# Step schedule: multiply the LR by DECAY_LR_RATE every DECAY_LR_FREQ scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=config.DECAY_LR_FREQ,
    gamma=config.DECAY_LR_RATE,
)

In [None]:
# Passed positionally to train() after `network_name`.
# NOTE(review): presumably the number of batches per optimizer update; confirm in train.py.
batch_update = 16

# Run training; train() returns the network, which replaces `net`.
net = train(
    net,
    dataset,
    criterion,
    optimizer,
    scheduler,
    train_loader,
    test_loader,
    network_name,
    batch_update,
    max_iters=config.MAX_ITERS,
    log_file=log_file,
    save_model=True,
    method="baseline",
    dest=dest,
)

In [None]:
# Final evaluation: reload the best checkpoint, compute the metrics on the
# test loader, then write them to the log file and print them.
from evaluation import SceneNetEval
import warnings

# Silence all Python warnings from here on (process-wide setting).
warnings.filterwarnings('ignore')

evaluator = SceneNetEval(device, config.TASKS, config.TASKS_NUM_CLASS, config.IMAGE_SHAPE, dataset, config.DATA_ROOT)

# NOTE(review): the checkpoint is presumably written by train() with save_model=True; confirm.
# map_location remaps the stored tensors onto the current device, so a
# checkpoint saved on a GPU can still be loaded on a CPU-only machine.
best_state = torch.load(f"{dest}/best_{network_name}.pth", map_location=device)
net.load_state_dict(best_state)
net.eval()

res = evaluator.get_final_metrics(net, test_loader)
log_file.write(str(res))
print(res)
# The log is closed here; the cell that opens `log_file` must be re-run
# before this cell can be executed again.
log_file.close()