In [None]:
import os
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
from dataloaders import *
from scene_net import *
from loss import SceneNetLoss
from train import train

from evaluation import SceneNetEval


In [None]:
# Expose only GPU 0 to this process. This must happen before CUDA is initialised,
# so keep this cell ahead of any torch.cuda usage.
GPU_IDS = '0'
os.environ['CUDA_VISIBLE_DEVICES'] = GPU_IDS

In [None]:
# Run configuration. dataset/method/ratio/postion are joined into the run name;
# `dest` is the directory passed to train() and used for the best checkpoint.
dataset = 'nyuv2_3'
method = 'prune_pt'
ratio = '70_seg_sn'
postion = 'high'  # (sic) later cells reference this exact spelling
dest = "path/to/save_dir"  # placeholder: point this at a real directory

# Prefer the GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Run identifier, used for the log file name and checkpoint names.
network_name = "_".join((dataset, method, ratio, postion))
print(network_name)

In [None]:
# Ensure the checkpoint directory and the per-dataset log directory exist.
for out_dir in (dest, f"logs/{dataset}"):
    os.makedirs(out_dir, exist_ok=True)

# Open (and truncate) this run's text log. It stays open across cells and must be
# closed explicitly once logging is finished.
log_path = f"logs/{dataset}/{network_name}.txt"
log_file = open(log_path, "w")

In [None]:
# Build the train/test datasets and loaders for the selected benchmark.
# Each branch picks the matching Config module and dataset class; loader construction is shared.
if dataset == "nyuv2_3":
    from config_nyuv2_3task import Config
    dataset_cls = NYU_v2
elif dataset == "cityscapes":
    from config_cityscapes import Config
    dataset_cls = CityScapes
elif dataset == "taskonomy":
    from config_taskonomy import Config
    dataset_cls = Taskonomy
else:
    # `exit()` is an interactive-shell helper and does not cleanly stop a notebook run;
    # raising halts "Run All" with a clear message instead.
    raise ValueError(f"Unrecognized Dataset Name: {dataset!r}")

config = Config()
NUM_WORKERS = 8

train_dataset = dataset_cls(config.DATA_ROOT, 'train', crop_h=config.CROP_H, crop_w=config.CROP_W)
train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, num_workers=NUM_WORKERS,
                          shuffle=True, pin_memory=True)

test_dataset = dataset_cls(config.DATA_ROOT, 'test')
# Taskonomy evaluates with the training batch size; NYUv2 and Cityscapes evaluate one image at a time.
test_batch_size = config.BATCH_SIZE if dataset == "taskonomy" else 1
# Evaluation does not need shuffling; a fixed order keeps per-sample results reproducible.
test_loader = DataLoader(test_dataset, batch_size=test_batch_size, num_workers=NUM_WORKERS,
                         shuffle=False, pin_memory=True)

In [None]:
# Quick sanity check that both splits were found under DATA_ROOT.
for split_label, split_ds in (("TrainDataset:", train_dataset), ("TestDataset:", test_dataset)):
    print(split_label, len(split_ds))

In [None]:
####################################
# Task selection for retraining works in two ways:
#   * losses of unselected tasks are dropped through the task list passed to SceneNetLoss
#     (`criterion_task`);
#   * parameters of unselected tasks are frozen through the `froze` name filter below.
# NOTE(review): an earlier note said tasks are chosen by indices 0-4 (task 1 to task 5),
# but the list is taken from config.TASKS here; confirm which form SceneNetLoss expects.
####################################
criterion_task = config.TASKS if isinstance(config.TASKS, list) else [config.TASKS]
print(criterion_task)

##############################################################
# Parameters whose name contains this substring are frozen.
froze = 'task3'
##############################################################

In [None]:
# Instantiate SceneNet from config.TASKS_NUM_CLASS and config.BACKBONE_NAME, then move it to `device`.
net = SceneNet(config.TASKS_NUM_CLASS, config.BACKBONE_NAME)
net = net.to(device)

In [None]:
# Freeze every parameter whose name contains the `froze` substring; the rest stay trainable.
params_to_freeze = [p for pname, p in net.named_parameters() if froze in pname]
for p in params_to_freeze:
    p.requires_grad = False

In [None]:
import torch.nn.utils.prune as prune
from prune_utils import print_sparsity

# Add the pruning re-parametrization to every Conv2d / Linear layer without removing any weights.
# prune.identity adds a `weight_orig` parameter and an all-ones `weight_mask` buffer. This lets
# the pruned checkpoint's mask/orig tensors be loaded into the model in the next cell.
# prune.identity modifies the module in place, so its return value is not needed.
for module in net.modules():
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        prune.identity(module, 'weight')

In [None]:
# Load the pruned checkpoint into the re-parametrized network.
# map_location keeps this working when `device` fell back to CPU but the checkpoint was saved
# with CUDA tensors.
# NOTE: torch.load unpickles arbitrary objects; only load checkpoints from a trusted source.
PRUNED_CKPT_PATH = "path/to/pruned/model.pth"
saved_state_dict = torch.load(PRUNED_CKPT_PATH, map_location=device)
net.load_state_dict(saved_state_dict)

In [None]:
# After load_state_dict, each pruned layer's `weight` attribute still holds the product computed
# when the re-parametrization was applied. Recompute it from the loaded weight_orig/weight_mask,
# presumably so print_sparsity reflects the checkpoint's masks before any forward pass.
for layer in net.modules():
    if not isinstance(layer, (nn.modules.conv.Conv2d, nn.modules.Linear)):
        continue
    layer.weight = layer.weight_orig * layer.weight_mask
print_sparsity(net)

In [None]:
# Multi-task loss over the tasks listed in criterion_task, weighted by config.LAMBDAS.
criterion = SceneNetLoss(dataset, criterion_task, config.TASKS_NUM_CLASS, config.LAMBDAS, device, config.DATA_ROOT)

# Adam over all parameters with the retraining learning rate. Frozen parameters are included,
# but they receive no gradient and so are not updated.
optimizer = torch.optim.Adam(
    net.parameters(),
    lr=config.RETRAIN_LR,
    weight_decay=config.WEIGHT_DECAY,
)

In [None]:
# Multiply the learning rate by DECAY_LR_RATE every RETRAIN_DECAY_LR_FREQ scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=config.RETRAIN_DECAY_LR_FREQ,
    gamma=config.DECAY_LR_RATE,
)
# Passed positionally to train(); presumably the number of batches per optimizer update (confirm in train()).
batch_update = 16

In [None]:
# Sanity check: show each parameter's name and whether it is still trainable after freezing.
for pname, p in net.named_parameters():
    print(pname, p.requires_grad)

In [None]:
# Retrain the pruned network. train() receives save_model/dest/log_file so it can write
# checkpoints and logs. NOTE(review): the bound is passed as `max_iters` but comes from
# config.RETRAIN_EPOCH; confirm which unit train() expects.
net = train(
    net, dataset, criterion, optimizer, scheduler,
    train_loader, test_loader, network_name, batch_update,
    max_iters=config.RETRAIN_EPOCH,
    save_model=True,
    log_file=log_file,
    method=method,
    dest=dest,
)
print_sparsity(net)

In [None]:
# Final evaluation: reload the best checkpoint, compute test metrics, log them and close the log.
evaluator = SceneNetEval(device, config.TASKS, config.TASKS_NUM_CLASS, config.IMAGE_SHAPE, dataset, config.DATA_ROOT)

# Presumably written by train() (save_model=True, dest=dest); confirm the naming there.
best_ckpt_path = f"{dest}/best_{network_name}.pth"
print(best_ckpt_path)
# map_location: the checkpoint may hold CUDA tensors while `device` fell back to CPU.
net.load_state_dict(torch.load(best_ckpt_path, map_location=device))
net.eval()
res = evaluator.get_final_metrics(net, test_loader)

log_file.write(str(res))
print(res)
log_file.close()