<a href="https://colab.research.google.com/github/yaaili/test/blob/master/pytorch/01%E5%9F%BA%E4%BA%8Eminst%E7%9A%84%E6%A8%A1%E5%9E%8B%E5%89%AA%E6%9E%9D.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# 导入相关包
from google.colab import drive

drive.mount('/content/gdrive')
import os

os.chdir("/content/gdrive/My Drive/Colab Notebooks/pytorch深度学习")

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/gdrive


In [2]:
import torch
from torch import nn
import torch.nn.functional as F
print(torch.__version__)

1.5.0+cu101


In [0]:
#复制一个与x相同的mask
def to_var(x, requires_grad=False):
    """
    Automatically choose cpu or cuda
    """
    if torch.cuda.is_available():
        x = x.cuda()
    return x.clone().detach().requires_grad_(requires_grad)

# 2 构建网络

## 2.1 定义卷积层

In [0]:
#定义卷积层
class MaskedConv2d(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, dilation=1, groups=1, bias=True):
        super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        self.mask_flag = False

    def set_mask(self, mask):
        self.mask = to_var(mask, requires_grad=False)
        self.weight.data = self.weight.data * self.mask.data
        self.mask_flag = True

    def get_mask(self):
        print(self.mask_flag)
        return self.mask

    def forward(self, data):
        if self.mask_flag:
            weight = self.weight * self.mask
            return F.conv2d(data, weight, self.bias, self.stride,
                            self.padding, self.dilation, self.groups)
        else:
            return F.conv2d(data, self.weight, self.bias, self.stride,
                            self.padding, self.dilation, self.groups)


## 2.2 构建全连接层

In [0]:
class MaskedLinear(nn.Linear):
    def __init__(self, in_channels, out_channels, bias=True):
        super().__init__(in_channels, out_channels, bias)
        self.mask_flag = False

    def set_mask(self, mask):
        self.mask = to_var(mask, requires_grad=False)
        self.weight.data = self.weight.data * self.mask.data
        self.mask_flag = True

    def get_mask(self):
        print(self.mask_flag)
        return self.mask

    def forward(self, data):
        if self.mask_flag:
            weight = self.weight * self.mask
            return F.linear(data, weight, self.bias)
        else:
            return F.linear(data, self.weight, self.bias)

### 2.3定义网络结构
这个网络有三个卷积层，两个全连接层组成，最后输出的是10分分类

In [0]:
class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = MaskedConv2d(1, 32, kernel_size=3, padding=1, stride=1)
        self.relu1 = nn.ReLU(inplace=True)
        self.maxpool1 = nn.MaxPool2d(2)

        self.conv2 = MaskedConv2d(32, 64, kernel_size=3, padding=1, stride=1)
        self.relu2 = nn.ReLU(inplace=True)
        self.maxpool2 = nn.MaxPool2d(2)

        self.conv3 = MaskedConv2d(64, 64, kernel_size=3, padding=1, stride=1)
        self.relu3 = nn.ReLU(inplace=True)

        self.linear1 = MaskedLinear(7 * 7 * 64, 128)
        self.linear2 = MaskedLinear(128, 10)

        self.loss = nn.CrossEntropyLoss()

    def forward(self, data):
        out = self.maxpool1(self.relu1(self.conv1(data)))
        out = self.maxpool2(self.relu2(self.conv2(out)))
        out = self.relu3(self.conv3(out))
        out = out.view(out.size(0), -1)
        out = self.linear1(out)
        out = self.linear2(out)
        return out

    def get_loss(self, output, label):
        return self.loss(output, label)

    def set_masks(self, masks, isLinear=False):
        # Should be a less manual way to set masks
        # Leave it for the future
        if isLinear:
            self.linear1.set_mask(masks[0])
            self.linear2.set_mask(masks[1])
        else:
            self.conv1.set_mask(torch.from_numpy(masks[0]))
            self.conv2.set_mask(torch.from_numpy(masks[1]))
            self.conv3.set_mask(torch.from_numpy(masks[2]))


In [7]:
if __name__ == '__main__':
    net = MyNet()
    for p in net.conv1.parameters():
        print(p.data.size())
    for p in net.linear1.parameters():
        print(p.data.size())

torch.Size([32, 1, 3, 3])
torch.Size([32])
torch.Size([128, 3136])
torch.Size([128])


In [8]:
total = sum([param.nelement() for param in net.parameters()])
print('  + Number of params: %.2fM' % (total / 1e6))


  + Number of params: 0.46M


# 3 训练

In [0]:
import torch
from torchvision import datasets, transforms

from torch.utils.data import DataLoader
import math
import time

## 3.1 数据准备

In [0]:
class Trainer:
    def __init__(self, save_path):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.save_path = save_path
        self.net = MyNet().to(self.device)
        self.trans = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])
        #加载训练集,不是独热编码
        self.train_data = DataLoader(datasets.MNIST("./datasets/", train=True, transform=self.trans, download=True),
                                     batch_size=100, shuffle=True, drop_last=True)
        #加载测试集
        self.test_data = DataLoader(datasets.MNIST("./datasets/", train=False, transform=self.trans, download=True),
                                     batch_size=100, shuffle=True, drop_last=True)
        self.optimizer = torch.optim.Adam(self.net.parameters())
        self.net.train()
        
    def evaluate_accuracy(self,data_iter):
        acc_sum, n = 0.0, 0
        for X, y in data_iter:
            X, y = X.to(self.device), y.to(self.device)
            if isinstance(self.net, torch.nn.Module):
                net.eval() # 评估模式, 这会关闭dropout
                acc_sum += (self.net(X).argmax(dim=1) == y).float().sum().item()
                net.train() # 改回训练模式
            else: # 自定义的模型
                if('is_training' in net.__code__.co_varnames): # 如果有is_training这个参数
                    # 将is_training设置成False
                    acc_sum += (self.net(X, is_training=False).argmax(dim=1) == y).float().sum().item() 
                else:
                    acc_sum += (self.net(X).argmax(dim=1) == y).float().sum().item() 
            n += y.shape[0]
        return acc_sum / n

    def train(self):
        for epoch in range(1, 5):
            total = 0
            train_acc_sum,train_l_sum,n,start = 0.0,0.0,0,time.time()
            for i, (data, label) in enumerate(self.train_data):
                data, label = data.to(self.device), label.to(self.device)
                output = self.net(data)
                loss = self.net.get_loss(output, label)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                total += len(data)
                #训练损失
                train_l_sum += loss.item()
                train_acc_sum += ((output.argmax(dim=1)) == label).sum().item()
                n += 100  
                progress = math.ceil(i / len(self.train_data) * 50)
                print("\rTrain epoch %d: %d/%d, [%-51s] %d%%" %
                      (epoch, total, len(self.train_data.dataset),
                       '-' * progress + '>', progress * 2), end='')
            test_acc = self.evaluate_accuracy(self.test_data)
            print("\nepoch %d,loss %.4f, train_acc %.3f, test_acc %.3f,time %.1f sec"
              %(epoch+1,train_l_sum/n,train_acc_sum/n,test_acc,time.time()-start))
              
            torch.save(self.net.state_dict(), self.save_path)





In [11]:
if __name__ == '__main__':
    trainer = Trainer("./net.pth")
    trainer.train()

Train epoch 1: 60000/60000, [-------------------------------------------------->] 100%
epoch 2,loss 0.0015, train_acc 0.954, test_acc 0.986,time 11.5 sec
Train epoch 2: 60000/60000, [-------------------------------------------------->] 100%
epoch 3,loss 0.0004, train_acc 0.987, test_acc 0.991,time 11.2 sec
Train epoch 3: 60000/60000, [-------------------------------------------------->] 100%
epoch 4,loss 0.0003, train_acc 0.990, test_acc 0.988,time 11.2 sec
Train epoch 4: 60000/60000, [-------------------------------------------------->] 100%
epoch 5,loss 0.0002, train_acc 0.992, test_acc 0.992,time 11.4 sec


# 4 对模型进行剪枝
## 4.1 构建剪枝网络

In [92]:
import torch
import torch.nn.utils.prune as prune


class Pruning:
    #net_path是修建的模型，amount是模型的修建率
    def __init__(self, net_path, amount):
        self.net = MyNet()
        #加载模型
        self.net.load_state_dict(torch.load(net_path))
        #将模型都定义为元组,这是全局修剪的方法
        self.parameters_to_prune = (
            (self.net.conv1, 'weight'),
            (self.net.conv2, 'weight'),
            (self.net.conv3, 'weight'),
            (self.net.linear1, 'weight'),
            (self.net.linear2, 'weight'),
        )
        self.amount = amount

    def pruning(self):
      #全局修剪参数，方法是修剪绝对值参数
        prune.global_unstructured(
            self.parameters_to_prune,
            pruning_method=prune.L1Unstructured,
            amount=self.amount,
        )
        # print(self.net.state_dict().keys())
        # 删除weight_orig 、weight_mask以及forward_pre_hook
        prune.remove(self.net.conv1, 'weight')
        prune.remove(self.net.conv2, 'weight')
        prune.remove(self.net.conv3, 'weight')
        prune.remove(self.net.linear1, 'weight')
        prune.remove(self.net.linear2, 'weight')
        # print(self.net.linear1.weight)
        # mask = weight_prune(self.net, 60)
        # self.net.set_masks(mask, True)
        # torch.save(self.net.state_dict(), "self.nets/pruned_net_without_conv.pth")
        # filter_prune(self.net, 50)
        print(
            "Sparsity in conv1.weight: {:.2f}%".format(
                100. * float(torch.sum(self.net.conv1.weight == 0))
                / float(self.net.conv1.weight.nelement())
            )
        )
        print(
            "Sparsity in conv2.weight: {:.2f}%".format(
                100. * float(torch.sum(self.net.conv2.weight == 0))
                / float(self.net.conv2.weight.nelement())
            )
        )
        print(
            "Sparsity in conv3.weight: {:.2f}%".format(
                100. * float(torch.sum(self.net.conv3.weight == 0))
                / float(self.net.conv3.weight.nelement())
            )
        )
        print(
            "Sparsity in linear1.weight: {:.2f}%".format(
                100. * float(torch.sum(self.net.linear1.weight == 0))
                / float(self.net.linear1.weight.nelement())
            )
        )
        print(
            "Sparsity in linear2.weight: {:.2f}%".format(
                100. * float(torch.sum(self.net.linear2.weight == 0))
                / float(self.net.linear2.weight.nelement())
            )
        )
        print(
            "Global sparsity: {:.2f}%".format(
                100. * float(
                    torch.sum(self.net.conv1.weight == 0)
                    + torch.sum(self.net.conv2.weight == 0)
                    + torch.sum(self.net.conv3.weight == 0)
                    + torch.sum(self.net.linear1.weight == 0)
                    + torch.sum(self.net.linear2.weight == 0)
                )
                / float(
                    self.net.conv1.weight.nelement()
                    + self.net.conv2.weight.nelement()
                    + self.net.conv3.weight.nelement()
                    + self.net.linear1.weight.nelement()
                    + self.net.linear2.weight.nelement()
                )
            )
        )
        # torch.save(self.net.state_dict(), "models/pruned_net_with_conv.pth")
        torch.save(self.net.state_dict(), f"./pruned_net_with_torch_{self.amount:.1f}_l1.pth")


if __name__ == '__main__':
    for i in range(1, 10):
        pruning = Pruning("./net.pth", 0.1 * i)
        pruning.pruning()


Sparsity in conv1.weight: 0.35%
Sparsity in conv2.weight: 4.50%
Sparsity in conv3.weight: 6.08%
Sparsity in linear1.weight: 10.64%
Sparsity in linear2.weight: 4.06%
Global sparsity: 10.00%
Sparsity in conv1.weight: 0.69%
Sparsity in conv2.weight: 9.61%
Sparsity in conv3.weight: 12.47%
Sparsity in linear1.weight: 21.22%
Sparsity in linear2.weight: 7.03%
Global sparsity: 20.00%
Sparsity in conv1.weight: 2.08%
Sparsity in conv2.weight: 14.75%
Sparsity in conv3.weight: 19.11%
Sparsity in linear1.weight: 31.78%
Sparsity in linear2.weight: 10.55%
Global sparsity: 30.00%
Sparsity in conv1.weight: 3.12%
Sparsity in conv2.weight: 20.15%
Sparsity in conv3.weight: 25.90%
Sparsity in linear1.weight: 42.32%
Sparsity in linear2.weight: 13.44%
Global sparsity: 40.00%
Sparsity in conv1.weight: 4.51%
Sparsity in conv2.weight: 26.18%
Sparsity in conv3.weight: 33.31%
Sparsity in linear1.weight: 52.76%
Sparsity in linear2.weight: 17.58%
Global sparsity: 50.00%
Sparsity in conv1.weight: 6.94%
Sparsity in c

# 5 检测

In [94]:
class Detector:
    def __init__(self, net_path):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.net = MyNet().to(self.device)
        self.trans = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])
        self.test_data = DataLoader(datasets.MNIST("../datasets/", train=False, transform=self.trans, download=False),
                                    batch_size=100, shuffle=True, drop_last=True)
        # 如果没有GPU的话把在GPU上训练的参数放在CPU上运行，cpu-->gpu 1:lambda storage, loc: storage.cuda(1)
        self.map_location = None if torch.cuda.is_available() else lambda storage, loc: storage
        self.net.load_state_dict(torch.load(net_path, map_location=self.map_location))
        self.net.eval()

    def detect(self):
        test_loss = 0
        correct = 0
        start = time.time()
        with torch.no_grad():
            for data, label in self.test_data:
                data, label = data.to(self.device), label.to(self.device)
                output = self.net(data)
                test_loss += self.net.get_loss(output, label)
                pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
                correct += pred.eq(label.view_as(pred)).sum().item()

        end = time.time()
        print(f"total time:{end - start}")
        test_loss /= len(self.test_data.dataset)

        print('Test: average loss: {:.4f}, accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(self.test_data.dataset),
            100. * correct / len(self.test_data.dataset)))


if __name__ == '__main__':
    print("./net.pth")
    detector1 = Detector("./net.pth")
    detector1.detect()

    for i in range(1, 10):
        amount = 0.1 * i
        print(f"./pruned_net_with_torch_{amount:.1f}_l1.pth")
        detector1 = Detector(f"./pruned_net_with_torch_{amount:.1f}_l1.pth")
        detector1.detect()


./net.pth
total time:1.4062409400939941
Test: average loss: 0.0003, accuracy: 9915/10000 (99%)

./pruned_net_with_torch_0.1_l1.pth
total time:1.404999017715454
Test: average loss: 0.0003, accuracy: 9915/10000 (99%)

./pruned_net_with_torch_0.2_l1.pth
total time:1.3932301998138428
Test: average loss: 0.0003, accuracy: 9916/10000 (99%)

./pruned_net_with_torch_0.3_l1.pth
total time:1.404151439666748
Test: average loss: 0.0003, accuracy: 9918/10000 (99%)

./pruned_net_with_torch_0.4_l1.pth
total time:1.3637206554412842
Test: average loss: 0.0003, accuracy: 9915/10000 (99%)

./pruned_net_with_torch_0.5_l1.pth
total time:1.3584282398223877
Test: average loss: 0.0003, accuracy: 9912/10000 (99%)

./pruned_net_with_torch_0.6_l1.pth
total time:1.3912827968597412
Test: average loss: 0.0003, accuracy: 9906/10000 (99%)

./pruned_net_with_torch_0.7_l1.pth
total time:1.3662254810333252
Test: average loss: 0.0003, accuracy: 9907/10000 (99%)

./pruned_net_with_torch_0.8_l1.pth
total time:1.39269590377

In [0]:
module = model.conv1
print(list(module.named_parameters()))