In [1]:
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from pathlib import Path

import torch
import torch.nn.functional as F
from torch.optim import Adam
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader

from torchvision import datasets, transforms
from torchvision.models.mobilenetv3 import mobilenet_v3_small
from torchvision.models.resnet import resnet18

import nni

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def build_mobilenet_v3(num_classes: int = 10):
    """Create a pretrained MobileNetV3-Small with a freshly initialised classification head.

    Args:
        num_classes: Number of outputs of the new final linear layer (10 by default).

    Returns:
        The model, moved to the module-level ``device``.
    """
    model = mobilenet_v3_small(pretrained=True)
    # Take the input width from the layer being replaced instead of hard-coding it.
    in_features = model.classifier[-1].in_features
    model.classifier[-1] = torch.nn.Linear(in_features, num_classes)
    return model.to(device)


def build_resnet18(num_classes: int = 10):
    """Create a pretrained ResNet-18 with a freshly initialised fully connected head.

    Args:
        num_classes: Number of outputs of the new ``fc`` layer (10 by default).

    Returns:
        The model, moved to the module-level ``device``.
    """
    model = resnet18(pretrained=True)
    # Take the input width from the layer being replaced instead of hard-coding it.
    in_features = model.fc.in_features
    model.fc = torch.nn.Linear(in_features, num_classes)
    return model.to(device)


def prepare_dataloader(batch_size: int = 128):
    """Build the CIFAR-10 training and test data loaders.

    The dataset is stored in a ``data`` directory next to this file. When
    ``__file__`` is not defined (for example when the code is executed from a
    notebook cell), the ``data`` directory of the current working directory is
    used instead.

    Args:
        batch_size: Batch size used by both loaders.

    Returns:
        A ``(train_loader, test_loader)`` tuple. The training loader shuffles and
        applies a random horizontal flip and a padded random crop; the test
        loader only converts to tensors and normalises.
    """
    try:
        data_dir = Path(__file__).parent / 'data'
    except NameError:
        # ``__file__`` does not exist in interactive / notebook sessions.
        data_dir = Path.cwd() / 'data'

    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    train_loader = DataLoader(
        datasets.CIFAR10(data_dir, train=True, transform=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),
            transforms.ToTensor(),
            normalize,
        ]), download=True),
        batch_size=batch_size, shuffle=True, num_workers=8)

    # No ``download=True`` here: the training dataset above has already fetched
    # the archive into the same directory.
    test_loader = DataLoader(
        datasets.CIFAR10(data_dir, train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=batch_size, shuffle=False, num_workers=8)
    return train_loader, test_loader


def prepare_optimizer(model: torch.nn.Module):
    """Return an ``nni.trace``-wrapped Adam optimizer (lr=0.001) for ``model``.

    Only parameters whose ``requires_grad`` flag is set are handed to the optimizer.
    """
    trainable_params = list(filter(lambda p: p.requires_grad, model.parameters()))
    traced_adam = nni.trace(Adam)
    return traced_adam(trainable_params, lr=0.001)


def train(model: torch.nn.Module, optimizer: torch.optim.Optimizer, training_step,
          lr_scheduler: _LRScheduler, max_steps: int, max_epochs: int):
    """Train ``model`` on the loaders from ``prepare_dataloader`` and report test accuracy.

    Args:
        model: Network to optimise; batches are moved to the module-level ``device``.
        optimizer: Optimizer whose ``zero_grad``/``step`` are called once per batch.
        training_step: Callable ``training_step((data, target), model)`` returning the loss.
        lr_scheduler: Accepted for signature compatibility; this function never steps it.
        max_steps: Total number of optimisation steps. When falsy, it is derived
            from ``max_epochs``.
        max_epochs: Number of epochs, used only when ``max_steps`` is falsy.

    The test accuracy is printed after every completed epoch and once more when
    the step budget is exhausted in the middle of an epoch.
    """
    assert max_epochs is not None or max_steps is not None
    train_loader, test_loader = prepare_dataloader()
    steps_per_epoch = len(train_loader)
    max_steps = max_steps if max_steps else max_epochs * steps_per_epoch
    # Round up so that a partial final epoch is still iterated.
    max_epochs = max_steps // steps_per_epoch + (0 if max_steps % steps_per_epoch == 0 else 1)
    count_steps = 0

    model.train()
    for epoch in range(max_epochs):
        # evaluate() leaves the model in eval mode at the end of every epoch, so
        # training mode has to be restored before the next epoch starts.
        model.train()
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            loss = training_step((data, target), model)
            loss.backward()
            optimizer.step()
            count_steps += 1
            if count_steps >= max_steps:
                acc = evaluate(model, test_loader)
                print(f'[Training Epoch {epoch} / Step {count_steps}] Final Acc: {acc}%')
                return
        acc = evaluate(model, test_loader)
        print(f'[Training Epoch {epoch} / Step {count_steps}] Final Acc: {acc}%')


def evaluate(model: torch.nn.Module, test_loader):
    """Return the top-1 accuracy of ``model`` on ``test_loader`` as a percentage.

    The model is switched to eval mode and is left in that mode on return.
    Gradients are disabled while the loader is iterated.
    """
    model.eval()
    num_correct = 0.0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            predictions = model(inputs).argmax(dim=1, keepdim=True)
            hits = predictions.eq(labels.view_as(predictions))
            num_correct += hits.sum().item()
    dataset_size = len(test_loader.dataset)
    return 100 * num_correct / dataset_size


def training_step(batch, model: torch.nn.Module):
    """Run ``model`` on ``batch[0]`` and return the cross-entropy loss against ``batch[1]``."""
    logits = model(batch[0])
    return F.cross_entropy(logits, batch[1])


In [None]:
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from pathlib import Path
import sys
sys.path.append(str(Path(__file__).absolute().parents[1]))

import torch

# from models import (
#     build_resnet18,
#     prepare_dataloader,
#     prepare_optimizer,
#     train,
#     training_step,
#     evaluate,
#     device
# )

from nni.compression.pytorch import TorchEvaluator
from nni.compression.pytorch.pruning import SlimPruner
# from nni.compression.pytorch import auto_set_denpendency_group_ids
from nni.compression.pytorch import ModelSpeedup


# End-to-end pruning script: fine-tune a ResNet-18, prune it with SlimPruner,
# physically shrink it with ModelSpeedup, then fine-tune the pruned model again.
# The parameter count and test accuracy are printed at each stage.
if __name__ == '__main__':
    # finetuning resnet18 on Cifar10
    model = build_resnet18()
    optimizer = prepare_optimizer(model)
    # No LR scheduler and no explicit step budget: train for 10 full epochs.
    train(model, optimizer, training_step, lr_scheduler=None, max_steps=None, max_epochs=10)
    _, test_loader = prepare_dataloader()
    print('Original model paramater number: ', sum([param.numel() for param in model.parameters()]))
    print('Original model after 10 epochs finetuning acc: ', evaluate(model, test_loader), '%')

    # Pruning configuration handed to SlimPruner: targets Conv2d and Linear
    # modules with a sparse ratio of 0.7.
    # NOTE(review): 'Linear' also matches the final classification layer built
    # by build_resnet18 -- confirm that pruning it is intended.
    config_list = [{
        'op_types': ['Conv2d','Linear'],
        'sparse_ratio': 0.7
    }]
    # Random batch passed to ModelSpeedup below.
    # NOTE(review): the CIFAR-10 pipeline in prepare_dataloader produces 32x32
    # crops, while this tensor is 224x224 -- confirm the size is intentional.
    dummy_input = torch.rand(8, 3, 224, 224).to(device)
    # config_list = auto_set_denpendency_group_ids(model, config_list, dummy_input)
    # A fresh optimizer is created for the evaluator used by the pruner.
    optimizer = prepare_optimizer(model)
    evaluator = TorchEvaluator(train, optimizer, training_step)

    pruner = SlimPruner(model, config_list, evaluator, training_steps=1000)

    # Only the second element returned by compress() (the masks) is used.
    _, masks = pruner.compress()
    # Presumably removes the pruner's wrappers from `model` before speedup -- TODO confirm.
    pruner.unwrap_model()

    model = ModelSpeedup(model, dummy_input, masks).speedup_model()
    print('Pruned model paramater number: ', sum([param.numel() for param in model.parameters()]))
    print('Pruned model without finetuning acc: ', evaluate(model, test_loader), '%')

    # Fine-tune the pruned model with a new optimizer bound to its parameters.
    optimizer = prepare_optimizer(model)
    train(model, optimizer, training_step, lr_scheduler=None, max_steps=None, max_epochs=10)
    # The test loader is rebuilt here although one was already created above.
    _, test_loader = prepare_dataloader()
    print('Pruned model after 10 epochs finetuning acc: ', evaluate(model, test_loader), '%')

In [None]:
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import sys
sys.path.append(r'D:\31890\Desktop\codefile\Utils')
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
from model__ import segformer_m,segnet


# Save a freshly built segnet, apply L1 unstructured pruning to its Conv2d and
# Linear weights, and save the pruned module to a second file.
if __name__ == '__main__':

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Snapshot of the network before any pruning is applied.
    model = segnet()
    torch.save(model,r'D:\31890\Desktop\codefile\Utils\Packages\prune__\test\para1.pt')

    # (layer type, fraction of connections to prune); the first matching entry wins.
    prune_plan = (
        (torch.nn.Conv2d, 0.2),   # prune 20% of connections in all 2D-conv layers
        (torch.nn.Linear, 0.4),   # prune 40% of connections in all linear layers
    )
    for name, module in model.named_modules():
        for layer_type, amount in prune_plan:
            if isinstance(module, layer_type):
                prune.l1_unstructured(module, name='weight', amount=amount)
                break

    # Snapshot of the network after pruning.
    torch.save(model,r'D:\31890\Desktop\codefile\Utils\Packages\prune__\test\para2.pt')

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import numpy as np
import sys
sys.path.append(r'D:\31890\Desktop\codefile\Utils')

import torchvision
import torchvision.transforms as transforms

import os
import argparse

from model__ import ViT, channel_selection
from model__ import ViT_slim

# Run on the GPU when one is available, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
cudnn.benchmark = True

# Network whose weights are restored from the checkpoint below.
# NOTE(review): these hyper-parameters presumably have to match the ones the
# checkpoint was trained with -- confirm against the training script.
model = ViT(
    image_size = 32,
    patch_size = 4,
    num_classes = 10,
    dim = 512,                  # 512
    depth = 6,
    heads = 8,
    mlp_dim = 512,
    dropout = 0.1,
    emb_dropout = 0.1
    )
model = model.to(device)

model_path = "checkpoint/pruning-adamw-vit-4-79.84.t7"
print("=> loading checkpoint '{}'".format(model_path))
# map_location remaps the stored tensors onto the selected device, so a
# checkpoint saved on a GPU machine can also be loaded on a CPU-only one.
checkpoint = torch.load(model_path, map_location=device)
start_epoch = checkpoint['epoch']
best_prec1 = checkpoint['acc']
model.load_state_dict(checkpoint['net'])
print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}".format(model_path, checkpoint['epoch'], best_prec1))

# Gather every channel_selection layer once; both passes below iterate over
# them in the order produced by model.modules().
selection_layers = [mod for mod in model.modules() if isinstance(mod, channel_selection)]

# Total number of selection indexes across all of those layers.
total = sum(layer.indexes.data.shape[0] for layer in selection_layers)

# Concatenate the absolute index values of all layers into one flat tensor.
bn = torch.zeros(total)
index = 0
for layer in selection_layers:
    size = layer.indexes.data.shape[0]
    bn[index:index + size] = layer.indexes.data.abs().clone()
    index += size

# The global threshold is the value found at the `percent` position of the
# ascending sort of all magnitudes.
percent = 0.3
thre_index = int(total * percent)
sorted_scores = torch.sort(bn)[0]
thre = sorted_scores[thre_index]

# print(thre)

# Build one 0/1 mask per channel_selection layer from the global threshold
# `thre`, zero the masked-out indexes in place, and record how many channels
# each layer keeps (`cfg`) together with the masks themselves (`cfg_mask`).
pruned = 0
cfg = []
cfg_mask = []
for k, m in enumerate(model.modules()):
    if isinstance(m, channel_selection):
        # print(k)
        # print(m)
        weight_copy = m.indexes.data.abs().clone()
        # Use the configured device rather than a hard-coded .cuda(), so the
        # script also runs when `device` fell back to 'cpu'.
        mask = weight_copy.gt(thre).float().to(device)
        # NOTE(review): these module positions are presumably the selection
        # layers that feed multi-head attention -- confirm against the ViT
        # definition. For them the threshold is lowered step by step until the
        # number of kept channels is a multiple of 8.
        if k in [16,40,64,88,112,136]:
            thre_ = thre.clone()
            while (torch.sum(mask)%8 !=0):                       # heads
                thre_ = thre_ - 0.0001
                mask = weight_copy.gt(thre_).float().to(device)
        remaining = int(torch.sum(mask))
        pruned = pruned + mask.shape[0] - torch.sum(mask)
        m.indexes.data.mul_(mask)
        cfg.append(remaining)
        cfg_mask.append(mask.clone())
        print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.
            format(k, mask.shape[0], remaining))

pruned_ratio = pruned/total
print('Pre-processing Successful!')
print(cfg)


def test(model, data_root='/home/lxc/ABCPruner/data', batch_size=100, num_workers=8):
    """Evaluate ``model`` on the CIFAR-10 test split and print its top-1 accuracy.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        data_root: Directory holding (or receiving, via download) the CIFAR-10 data.
        batch_size: Batch size of the evaluation loader.
        num_workers: Number of data-loading worker processes.

    Inputs are moved to the module-level ``device``. Nothing is returned; the
    accuracy is printed as ``Acc: <percent>% (<correct>/<total>)``.
    """
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    testset = torchvision.datasets.CIFAR10(root=data_root, train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            # Index of the largest logit per sample is the predicted class.
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

        print('Acc: %.3f%% (%d/%d)' % (100.*correct/total, correct, total))

# Accuracy of the original model after its selection indexes were masked.
test(model)

# Group the per-layer kept-channel counts into consecutive pairs
# [cfg[0], cfg[1]], [cfg[2], cfg[3]], ... for the slim model constructor.
# NOTE(review): presumably one pair per transformer block -- confirm against ViT_slim.
cfg_prune = [[cfg[j - 1], cfg[j]] for j in range(1, len(cfg), 2)]

newmodel = ViT_slim(image_size = 32,
    patch_size = 4,
    num_classes = 10,
    dim = 512,
    depth = 6,
    heads = 8,
    mlp_dim = 512,
    dropout = 0.1,
    emb_dropout = 0.1,
    cfg=cfg_prune)

newmodel.to(device)
# num_parameters = sum([param.nelement() for param in newmodel.parameters()])

newmodel_dict = newmodel.state_dict().copy()


def _kept_indices(mask):
    """Return the positions of the non-zero entries of ``mask`` as a list of ints.

    ``np.flatnonzero`` always yields a 1-D result, so a mask with exactly one
    surviving channel still produces a one-element list. Indexing with that list
    keeps the tensor dimension, whereas a bare integer index would drop it.
    """
    return np.flatnonzero(mask.cpu().numpy()).tolist()


# Key fragments of tensors sliced along dim 0 with the current mask.
ROW_SLICED_KEYS = ('net1.0.weight', 'net1.0.bias', 'to_q', 'to_k', 'to_v')
# Key fragments of tensors sliced along dim 1 with the current mask. Each of
# these also advances to the next mask in cfg_mask.
COL_SLICED_KEYS = ('net2.0.weight', 'to_out.0.weight')

i = 0  # position of the mask currently in use inside cfg_mask
newdict = {}
for k,v in model.state_dict().items():
    if any(fragment in k for fragment in ROW_SLICED_KEYS):
        newdict[k] = v[_kept_indices(cfg_mask[i])].clone()
    elif any(fragment in k for fragment in COL_SLICED_KEYS):
        newdict[k] = v[:, _kept_indices(cfg_mask[i])].clone()
        i = i + 1
    elif k in newmodel_dict:
        # Tensor is not affected by pruning: copy it over unchanged.
        newdict[k] = v

newmodel_dict.update(newdict)
newmodel.load_state_dict(newmodel_dict)

torch.save(newmodel.state_dict(), 'pruned.pth')
print('after pruning: ', end=' ')
test(newmodel)