In [None]:
import os
import torch
from torch.utils.data import DataLoader

from dataloaders import *
from scene_net import *
from prune_utils import *
from loss import Our_SceneNetLoss


In [None]:
# Expose only GPU index 0 to this process (takes effect only if CUDA has not
# been initialised yet in this kernel).
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [None]:
# --- Experiment configuration -------------------------------------------------
# Which dataset branch of the notebook to run.
dataset = "taskonomy"
# Pruning method identifier.
method = "prune_pt"
# Target sparsity as a number between 0 and 100: the larger the value, the
# higher the proportion of zeros in the model.
ratio = 50
# Data-driven rounds.
num_batches = 50
# Root directory under which the pruned model is saved.
dest = "path/to/save/model"
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Make sure the output root exists; exist_ok=True keeps re-runs of this cell harmless.
os.makedirs(dest, exist_ok=True)

In [None]:

# Load the dataset-specific Config and build the training DataLoader.
# The Config import is deliberately conditional: each dataset ships its own
# config module and only the selected one is imported.
if dataset == "nyuv2_3":
    from config_nyuv2_3task import Config
    config = Config()
    train_dataset = NYU_v2(config.DATA_ROOT, 'train', crop_h=config.CROP_H, crop_w=config.CROP_W)
    train_loader = DataLoader(train_dataset, batch_size = config.BATCH_SIZE, num_workers = 8, shuffle=True, pin_memory=True)
    # test_dataset = NYU_v2(config.DATA_ROOT, 'test')
    # test_loader = DataLoader(test_dataset, batch_size = 1, num_workers = 8, shuffle=True, pin_memory=True)
elif dataset == "cityscapes":
    from config_cityscapes import Config
    config = Config()
    train_dataset = CityScapes(config.DATA_ROOT, 'train', crop_h=config.CROP_H, crop_w=config.CROP_W)
    # Half of the configured batch size is used for this dataset.
    train_loader = DataLoader(train_dataset, batch_size = int(config.BATCH_SIZE / 2), num_workers = 8, shuffle=True, pin_memory=True)
    # test_dataset = CityScapes(config.DATA_ROOT, 'test')
    # test_loader = DataLoader(test_dataset, batch_size = 1, num_workers = 8, shuffle=True, pin_memory=True)
elif dataset == "taskonomy":
    from config_taskonomy import Config
    config = Config()
    train_dataset = Taskonomy(config.DATA_ROOT, 'train', crop_h=config.CROP_H, crop_w=config.CROP_W)
    # A quarter of the configured batch size is used for this dataset.
    train_loader = DataLoader(train_dataset, batch_size = config.BATCH_SIZE//4, num_workers = 8, shuffle=True, pin_memory=True)
    # test_dataset = Taskonomy(config.DATA_ROOT, 'test')
    # test_loader = DataLoader(test_dataset, batch_size = 1, num_workers = 8, shuffle=True, pin_memory=True)
else:
    # Raise instead of calling exit(): inside a notebook exit() shuts the kernel
    # down rather than reporting the problem, whereas an exception stops the
    # run with a visible message and keeps the session alive.
    raise ValueError(
        f"Unrecognized Dataset Name: {dataset!r}. "
        "Expected one of 'nyuv2_3', 'cityscapes', 'taskonomy'."
    )

In [None]:
# Tasks to be selected (use the commented line below to select every task).
selected_tasks = [config.TASKS[3], config.TASKS[4]]
# selected_tasks = config.TASKS

# Normalise to a list and derive the tag used in output file names:
# task names joined by "_" for a list, the bare value otherwise.
if isinstance(selected_tasks, list):
    str_task = "_".join(selected_tasks)
else:
    selected_tasks = [selected_tasks]
    str_task = selected_tasks[0]

print(selected_tasks)
print(str_task)

In [None]:
# Output layout: <dest>/<ratio>_<tasks>/<dataset>_<method>_<ratio>_<tasks>.pth
# Create the per-experiment directory first, then derive the checkpoint path.
os.makedirs(f"{dest}/{ratio}_{str_task}", exist_ok=True)
network_name = f"{dataset}_{method}_{ratio}_{str_task}"
save_path = f"{dest}/{ratio}_{str_task}/{network_name}.pth"
print(save_path)

In [None]:
# Build the network and restore the base (unpruned) weights.
net = SceneNet(config.TASKS_NUM_CLASS, config.BACKBONE_NAME).to(device)

# map_location=device lets a checkpoint that was saved on a GPU be restored on
# a CPU-only machine (or on a different GPU index) instead of failing.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
orgin_dict = torch.load("path/to/base_model.pth", map_location=device)

# A base model trained on multiple cards stores every parameter under a
# "module." prefix. Strip it automatically, but only when *every* key carries
# it, so a single-card checkpoint is passed through unchanged.
if orgin_dict and all(key.startswith("module.") for key in orgin_dict):
    modified_state_dict = {key[len("module."):]: value for key, value in orgin_dict.items()}
else:
    modified_state_dict = orgin_dict

net.load_state_dict(modified_state_dict)

In [None]:
def create_pruned_model_nyuv2(net, ratio, criterion, train_loader, num_batches, device, selected_tasks, tasks):
    """Prune ``net`` with the keep ratio used for the NYUv2 setting.

    The sparsity levels 90, 70, 50 and 30 map to fixed keep ratios
    (0.08, 0.257, 0.46, 0.675); any other ``ratio`` falls back to
    ``(100 - ratio) / 100``.
    NOTE(review): the fixed values are presumably calibrated so the final
    sparsity matches the target -- confirm against ``prune_net``.

    All remaining arguments are forwarded unchanged to ``prune_net`` and its
    result is returned.
    """
    fixed_keep_ratios = ((90, 0.08), (70, 0.257), (50, 0.46), (30, 0.675))
    for level, kept in fixed_keep_ratios:
        if ratio == level:
            keep_ratio = kept
            break
    else:
        # No fixed entry for this sparsity level: keep the complementary fraction.
        keep_ratio = (100 - ratio) / 100
    return prune_net(net, criterion, train_loader, num_batches, keep_ratio, device, selected_tasks, tasks)

def create_pruned_model_cityscapes(net, ratio, criterion, train_loader, num_batches, device, selected_tasks, tasks):
    """Prune ``net`` with the keep ratio used for the Cityscapes setting.

    The sparsity levels 90, 70, 50 and 30 map to fixed keep ratios
    (0.095, 0.3, 0.51, 0.71); any other ``ratio`` falls back to
    ``(100 - ratio) / 100``.
    NOTE(review): the fixed values are presumably calibrated so the final
    sparsity matches the target -- confirm against ``prune_net``.

    All remaining arguments are forwarded unchanged to ``prune_net`` and its
    result is returned.
    """
    fixed_keep_ratios = ((90, 0.095), (70, 0.3), (50, 0.51), (30, 0.71))
    # First fixed entry whose level equals the requested ratio, if any.
    keep_ratio = next((kept for level, kept in fixed_keep_ratios if ratio == level), None)
    if keep_ratio is None:
        keep_ratio = (100 - ratio) / 100
    return prune_net(net, criterion, train_loader, num_batches, keep_ratio, device, selected_tasks, tasks)

def create_pruned_model_taskonomy(net, ratio, criterion, train_loader, num_batches, device, selected_tasks, tasks):
    """Prune ``net`` with the keep ratio used for the Taskonomy setting.

    The sparsity levels 90, 70, 50 and 30 map to fixed keep ratios
    (0.1, 0.257, 0.5, 0.675); any other ``ratio`` falls back to
    ``(100 - ratio) / 100``.

    All remaining arguments are forwarded unchanged to ``prune_net`` and its
    result is returned.
    """
    keep_ratio = (
        0.1 if ratio == 90
        else 0.257 if ratio == 70
        else 0.5 if ratio == 50
        else 0.675 if ratio == 30
        else (100 - ratio) / 100
    )
    return prune_net(net, criterion, train_loader, num_batches, keep_ratio, device, selected_tasks, tasks)

In [None]:
def create_prune_model(dataset, ratio, num_batches, method, config, device, net, train_loader, selected_tasks):
    """Prune ``net`` with the routine that matches ``dataset``.

    Args:
        dataset: Dataset name; one of "nyuv2", "nyuv2_3", "cityscapes", "taskonomy".
        ratio: Target sparsity level, forwarded to the dataset-specific routine.
        num_batches: Forwarded to the dataset-specific routine.
        method: Pruning method; only "prune_pt" is supported.
        config: Object providing ``TASKS``, ``TASKS_NUM_CLASS``, ``LAMBDAS``
            and ``DATA_ROOT``.
        device: Device handed to the loss and to the pruning routine.
        net: Network to prune.
        train_loader: Data loader forwarded to the pruning routine.
        selected_tasks: Tasks forwarded to the pruning routine.

    Returns:
        Whatever the dataset-specific ``create_pruned_model_*`` function returns.

    Raises:
        ValueError: If ``method`` or ``dataset`` is not recognised. Previously
            both cases fell through and returned the *unpruned* network (the
            unknown-method case also printed a misleading "Dataset" message),
            so an unpruned model could be saved as if it were pruned.
    """
    if method != "prune_pt":
        raise ValueError(f"Unrecognized Method Name: {method!r}. Expected 'prune_pt'.")

    # Resolve the routine before building the loss so that a bad dataset name
    # fails fast, without constructing the criterion first.
    if dataset in ["nyuv2", 'nyuv2_3']:
        prune_fn = create_pruned_model_nyuv2
    elif dataset == "cityscapes":
        prune_fn = create_pruned_model_cityscapes
    elif dataset == "taskonomy":
        prune_fn = create_pruned_model_taskonomy
    else:
        raise ValueError(
            f"Unrecognized Dataset Name: {dataset!r}. "
            "Expected one of 'nyuv2', 'nyuv2_3', 'cityscapes', 'taskonomy'."
        )

    criterion = Our_SceneNetLoss(dataset, config.TASKS, config.TASKS_NUM_CLASS, config.LAMBDAS, device, config.DATA_ROOT)
    return prune_fn(net, ratio, criterion, train_loader, num_batches, device, selected_tasks, tasks=config.TASKS)

In [None]:
# Run the pruning. Note that `net` is rebound to the result, so re-running this
# cell feeds the previously returned network back in.
net = create_prune_model(
    dataset=dataset,
    ratio=ratio,
    num_batches=num_batches,
    method=method,
    config=config,
    device=device,
    net=net,
    train_loader=train_loader,
    selected_tasks=selected_tasks,
)

In [None]:
# Persist only the state dict (parameters/buffers), not the whole module object.
print(f"Saving the pruned model to {save_path}")
torch.save(obj=net.state_dict(), f=save_path)