In [None]:
import os

import torch
import torch.nn as nn


from dataloaders import *
from scene_net import *
from prune_utils import *

import torch.nn.utils.prune as prune

In [None]:
# Expose only GPU 0 to this process.
# NOTE: this is only honored if set before CUDA is initialized in this kernel.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [None]:
# Experiment settings.
#   dataset: selects which Config class is imported ('nyuv2_3', 'cityscapes' or 'taskonomy').
#   task:    task-combination label. NOTE(review): not referenced by any cell shown here.
#   ration:  fraction of the selected weights to prune (passed as `amount` to the pruner).
#            The name is kept as-is (typo for "ratio") because later cells refer to it.
dataset, task, ration = 'nyuv2_3', 'T2+T3', 0.7

In [None]:
# Import the Config class that matches `dataset` and instantiate it.
# An unknown dataset name fails here with a clear message. Otherwise `config` would
# be left undefined and the error would only show up later as a NameError.
if dataset == "nyuv2_3":
    from config_nyuv2_3task import Config
elif dataset == "cityscapes":
    from config_cityscapes import Config
elif dataset == "taskonomy":
    from config_taskonomy import Config
else:
    raise ValueError(
        f"Unknown dataset {dataset!r}; expected one of 'nyuv2_3', 'cityscapes', 'taskonomy'."
    )
config = Config()

In [None]:
# Use the (first visible) GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Build the multi-task network from the dataset config, then move it to `device`.
net = SceneNet(config.TASKS_NUM_CLASS, config.BACKBONE_NAME)
net = net.to(device)

In [None]:
# Collect (module, 'weight') pairs for every Conv2d / Linear layer whose qualified
# module name contains 'backbone' or 'task' (plain substring match).
# NOTE(review): `task` = 'T2+T3' is defined in the settings cell but not used here.
# A narrower filter that was previously commented out read:
#     'backbone' in name or 'task2' in name or 'task3' in name
parameters_to_prune = [
    (layer, 'weight')
    for name, layer in net.named_modules()
    if isinstance(layer, (nn.Conv2d, nn.Linear))
    and ('backbone' in name or 'task' in name)
]

In [None]:
# Randomly mask a `ration` fraction of all collected weights, drawn from the pooled
# weights of every layer rather than per layer. torch.nn.utils.prune reparametrizes
# each module as weight = weight_orig * weight_mask.
# Seed first so the random mask is reproducible across Restart & Run All.
PRUNE_SEED = 0
torch.manual_seed(PRUNE_SEED)
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.RandomUnstructured,
    amount=ration,
)

In [None]:
def print_sparsity(prune_net, printing=True):
    """Measure the fraction of exactly-zero weights in every Conv2d / Linear layer.

    Args:
        prune_net: ``nn.Module`` whose submodules (recursively) are scanned.
        printing: when True, print the sparsity of each scanned layer and the
            overall model sparsity (as a percentage).

    Returns:
        0-dim tensor: (number of zero weights) / (number of weights) summed over
        all scanned layers. If the model has no Conv2d / Linear layer, returns
        ``torch.tensor(0.0)`` instead of dividing by zero.
    """
    num_zero = 0
    num_total = 0
    layer_idx = 0
    for module in prune_net.modules():
        if not isinstance(module, (nn.Conv2d, nn.Linear)):
            continue
        if not hasattr(module, 'weight'):
            continue
        # Count zeros once and reuse the count for both the layer and the model totals.
        zeros = torch.sum(module.weight == 0)
        total = module.weight.nelement()
        num_zero += zeros
        num_total += total
        if printing:
            print(
                f"Layer {layer_idx}",
                "Sparsity in weight: {:.2f}%".format(100. * zeros / total),
            )
        layer_idx += 1

    if num_total == 0:
        # Nothing to measure: report zero sparsity rather than raising on 0 / 0.
        sparsity = torch.tensor(0.0)
    else:
        sparsity = num_zero / num_total
    if printing:
        print(f"Model Sparsity Now: {sparsity * 100}")
    return sparsity

In [None]:
# Print per-layer and overall sparsity. Binding the result keeps the notebook from
# also echoing the raw tensor repr as the cell's output.
overall_sparsity = print_sparsity(net)

In [None]:
# NOTE(review): after torch.nn.utils.prune, each pruned module's state_dict holds
# `weight_orig` and `weight_mask` instead of `weight`. Call
# prune.remove(module, 'weight') for each entry first if a plain checkpoint is wanted.
SAVE_PATH = "path/to/save_model.pth"  # placeholder: point this at a real location
save_dir = os.path.dirname(SAVE_PATH)
if save_dir:
    # torch.save does not create missing parent directories.
    os.makedirs(save_dir, exist_ok=True)
torch.save(net.state_dict(), SAVE_PATH)