In [None]:
import os

import torch
from torch.utils.data import DataLoader

from dataloaders import *
from scene_net import *
from prune_utils import *
from loss import SceneNetLoss


In [None]:
# Pin the notebook to GPU 0 unless the launching environment already selected devices
# (e.g. `CUDA_VISIBLE_DEVICES=1 jupyter lab`); a hard assignment would silently override it.
# CUDA reads this variable when it is first initialised; torch initialises CUDA lazily,
# so this still takes effect as long as no CUDA call has run earlier in the kernel.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')

In [None]:
# Experiment configuration: every knob a reader might tune lives in this cell.
dataset = 'nyuv2_3'          # one of: 'nyuv2_3', 'cityscapes', 'taskonomy'
ratio = 70                   # percentage of weights to prune (keeps the remaining 100 - ratio %)
num_batches = 50             # number of training batches handed to SNIP_prune
task = 'T2+T3'               # tag used in the checkpoint filename
method = 'SNIP'              # pruning method label (not read by the cells below)
dest = "path/to/save/model"  # output directory for the pruned checkpoint
SEED = 0                     # fixed seed so the pruning run is reproducible

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
save_path = os.path.join(dest, f'{task}_{ratio}.pth')

# Seed torch's global RNG before the network and DataLoader are built: the shuffled
# training loader draws its order from it (and, presumably, SceneNet's weight init does too).
torch.manual_seed(SEED)

In [None]:
# Ensure the checkpoint directory (and any missing parents) exists; an existing directory is fine.
os.makedirs(name=dest, exist_ok=True)

In [None]:
# Build the training pipeline for the selected dataset. Each branch imports the
# dataset-specific Config, constructs the training split, and picks the loader's
# batch size; the DataLoader itself is built once below. Only the training split is
# needed here because SNIP_prune consumes `train_loader`.
if dataset == "nyuv2_3":
    from config_nyuv2_3task import Config
    config = Config()
    train_dataset = NYU_v2(config.DATA_ROOT, 'train', crop_h=config.CROP_H, crop_w=config.CROP_W)
    train_batch_size = config.BATCH_SIZE
elif dataset == "cityscapes":
    from config_cityscapes import Config
    config = Config()
    train_dataset = CityScapes(config.DATA_ROOT, 'train', crop_h=config.CROP_H, crop_w=config.CROP_W)
    train_batch_size = int(config.BATCH_SIZE / 2)
elif dataset == "taskonomy":
    from config_taskonomy import Config
    config = Config()
    train_dataset = Taskonomy(config.DATA_ROOT, 'train', crop_h=config.CROP_H, crop_w=config.CROP_W)
    train_batch_size = config.BATCH_SIZE // 4
else:
    # Raise rather than print + exit(): inside a Jupyter/IPython kernel exit() ends the
    # kernel session, whereas an exception just stops "Run All" with a clear traceback.
    raise ValueError(f"Unrecognized Dataset Name: {dataset!r}")

train_loader = DataLoader(train_dataset, batch_size=train_batch_size, num_workers=8, shuffle=True, pin_memory=True)

In [None]:
# Instantiate the multi-task network described by the dataset config and place it on `device`.
net = SceneNet(
    config.TASKS_NUM_CLASS,
    config.BACKBONE_NAME,
).to(device)

In [None]:
# Multi-task loss, configured from the dataset name plus the task list, class counts,
# loss weights (LAMBDAS) and data root taken from the dataset config.
criterion = SceneNetLoss(
    dataset,
    config.TASKS,
    config.TASKS_NUM_CLASS,
    config.LAMBDAS,
    device,
    config.DATA_ROOT,
)

In [None]:
# `ratio` is the percentage of weights to prune; SNIP_prune is given the complementary
# fraction to keep (70 -> 0.3). Deriving it for every ratio fixes the original behaviour,
# where any ratio other than 70 left `keep_ratio` undefined and the pruning cell hit a NameError.
if not 0 <= ratio < 100:
    raise ValueError(f"ratio must be a pruning percentage in [0, 100), got {ratio!r}")
keep_ratio = (100 - ratio) / 100

In [None]:
# Prune `net` with SNIP, presumably scoring weights on up to `num_batches` batches from
# `train_loader` and keeping the `keep_ratio` fraction of them.
# NOTE: `net` is rebound to the result, so re-running this cell feeds the already-pruned
# network back into SNIP_prune; re-run the model-building cell first.
net = SNIP_prune(
    net,
    criterion,
    train_loader,
    num_batches,
    keep_ratio,
)

In [None]:
# Report the sparsity of the pruned `net` (helper comes from one of the star imports,
# presumably prune_utils; its output format is defined there).
print_sparsity(net)

In [None]:
# Persist the pruned network's state_dict (not the whole module object) to `save_path`.
print(f"Saving the pruned model to {save_path}")
pruned_state = net.state_dict()
torch.save(obj=pruned_state, f=save_path)