In [None]:
!pip install torch_pruning

In [None]:
import numpy as np 
import pandas as pd
import json
from PIL import Image
import os
import shutil
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import torch.nn.utils.prune as prune
from sklearn.metrics import roc_auc_score, accuracy_score
from sklearn.preprocessing import LabelBinarizer
import torch_pruning as tp

In [None]:
BATCH = 64
# NOTE(review): the pruning/fine-tuning cell uses its own `epochs`, not EPOCHS.
EPOCHS = 20

LR = 0.0001
IMG_SIZE = 256

# Seed every RNG the notebook touches (Python, NumPy, PyTorch) so DataLoader
# shuffling, random augmentation and fine-tuning are repeatable across runs
# (some CUDA kernels may still be nondeterministic).
SEED = 2021
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
TRAIN_DIR = '../input/cassava-leaf-disease-classification/train_images/'

In [None]:
# Class-id -> disease-name map; `with` closes the file instead of leaking the handle.
with open("../input/cassava-leaf-disease-classification/label_num_to_disease_map.json") as label_map_file:
    labels = json.load(label_map_file)
print(labels)

train = pd.read_csv('../input/cassava-leaf-disease-classification/train.csv')

In [None]:
def train_validate_test_split(df, train_percent=.7, validate_percent=.15, seed=2021, out_dir='.'):
    """Shuffle ``df`` and split it into train / validation / test frames.

    NumPy's global RNG is seeded with ``seed`` (as before), and rows are then
    permuted *by position*, so the split works for any index, not only a
    0..n-1 RangeIndex. The first ``int(train_percent * n)`` shuffled rows form
    the train split. The next ``int(validate_percent * n)`` form the validation
    split, and the remainder is the test split.

    Each split is also written to ``train.csv``, ``val.csv`` and ``test.csv``
    inside ``out_dir`` (default: the working directory). The index is included,
    as is the ``DataFrame.to_csv`` default.

    Returns:
        ``(train, validate, test)`` DataFrames.
    """
    np.random.seed(seed)
    m = len(df.index)
    # ``iloc`` needs positions. Permuting ``df.index`` yields *labels*, which
    # coincide with positions only for a default RangeIndex. For that index this
    # produces exactly the same split as before.
    perm = np.random.permutation(m)
    train_end = int(train_percent * m)
    validate_end = int(validate_percent * m) + train_end

    train = df.iloc[perm[:train_end]]
    validate = df.iloc[perm[train_end:validate_end]]
    test = df.iloc[perm[validate_end:]]

    train.to_csv(os.path.join(out_dir, 'train.csv'))
    validate.to_csv(os.path.join(out_dir, 'val.csv'))
    test.to_csv(os.path.join(out_dir, 'test.csv'))

    return train, validate, test

In [None]:
# Writes train.csv / val.csv / test.csv to the working directory. The trailing
# `;` stops IPython from rendering the three returned DataFrames as cell output.
train_validate_test_split(train);

In [None]:
# Reload the three splits from the CSV files written by the split step.
train, val, test = (pd.read_csv(path) for path in ('./train.csv', './val.csv', './test.csv'))

In [None]:
# Separate each split into image file names (inputs) and class labels (targets).
X_Train, Y_Train = train['image_id'].values, train['label'].values
X_Val, Y_Val = val['image_id'].values, val['label'].values
X_Test, Y_Test = test['image_id'].values, test['label'].values

# Test file names come from the held-out split above, not a separate test directory.

In [None]:
class GetData(Dataset):
    """Map-style dataset over image files, optionally paired with labels.

    Args:
        Dir: directory containing the image files.
        FNames: sequence of file names inside ``Dir``.
        Labels: sequence of targets aligned with ``FNames``, or ``None`` for
            unlabelled data.
        Transform: callable applied to each image (converted to RGB).

    ``__getitem__`` returns ``(Transform(image), label)`` when labels were
    given, and ``(Transform(image), file_name)`` otherwise.
    """

    def __init__(self, Dir, FNames, Labels, Transform):
        self.dir = Dir
        self.fnames = FNames
        self.transform = Transform
        self.lbs = Labels

    def __len__(self):
        return len(self.fnames)

    def __getitem__(self, index):
        path = os.path.join(self.dir, self.fnames[index])
        # `with` closes the file handle promptly. Converting to RGB guarantees
        # three channels even for grayscale/RGBA/palette files.
        with Image.open(path) as img:
            image = self.transform(img.convert('RGB'))
        # The target used to be chosen by substring-matching the directory
        # ("train"/"val"/"test"), and any other directory silently returned None.
        # Decide on whether labels were supplied instead.
        if self.lbs is not None:
            return image, self.lbs[index]
        return image, self.fnames[index]

In [None]:
# Training-time preprocessing and augmentation:
# tensor conversion -> fixed-size resize -> random rotation (up to +/-90 degrees)
# -> random horizontal flip -> per-channel mean/std normalisation.
Transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomRotation(90),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])

In [None]:
# Create datasets and dataloaders.
# Validation/test reuse the training preprocessing minus the random
# augmentations. Otherwise the reported metrics would be computed on randomly
# rotated/flipped images and change from run to run.
# NOTE: update the excluded types if the augmentation pipeline gains new random steps.
EvalTransform = transforms.Compose(
    [t for t in Transform.transforms
     if not isinstance(t, (transforms.RandomRotation, transforms.RandomHorizontalFlip))])

trainset = GetData(TRAIN_DIR, X_Train, Y_Train, Transform)
trainloader = DataLoader(trainset, batch_size=BATCH, shuffle=True, num_workers=4)

validset = GetData(TRAIN_DIR, X_Val, Y_Val, EvalTransform)
validloader = DataLoader(validset, batch_size=BATCH, shuffle=False, num_workers=4)

testset = GetData(TRAIN_DIR, X_Test, Y_Test, EvalTransform)
testloader = DataLoader(testset, batch_size=BATCH, shuffle=False, num_workers=4)

In [None]:
# Load the model: a ResNet-50 with a 5-class head, initialised from the
# ResNet50_teacher.pth checkpoint.
model = torchvision.models.resnet50()
model.fc = nn.Linear(model.fc.in_features, 5, bias=True)
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
model.load_state_dict(torch.load('../input/borisresnet50/ResNet50_teacher.pth', map_location=DEVICE))
model = model.to(DEVICE)

In [None]:
def calc_weights(model):
    """Count the weight elements of ``model.conv1`` and ``model.fc``.

    Only these two layers are included. Every element is counted, including
    entries that are zero (e.g. after pruning).
    """
    counted_layers = (model.conv1, model.fc)
    return sum(layer.weight.numel() for layer in counted_layers)

In [None]:
# Amount of storage taken by the neural network (serialized state_dict).
def calc_size(model):
    """Return the on-disk size of ``model.state_dict()`` saved with ``torch.save``.

    The state dict is written to a private temporary file, so no file in the
    working directory is overwritten. The file is always deleted, even if
    saving fails.

    NOTE(review): after ``torch.nn.utils.prune`` reparametrization the state dict
    presumably holds both ``weight_orig`` and ``weight_mask``, so this size will
    not shrink as sparsity grows — confirm.

    Returns:
        The size in KiB formatted with three decimals, e.g. ``'12.345 KB'``.
    """
    import tempfile

    fd, path = tempfile.mkstemp(suffix='.pth')
    os.close(fd)  # torch.save reopens the path itself
    try:
        torch.save(model.state_dict(), path)
        size = os.path.getsize(path)
    finally:
        os.remove(path)
    return '{:.3f} KB'.format(size / 1024)

In [None]:
def get_metrics(model, dataloader, DEVICE):
    """Compute accuracy and ROC AUC of ``model`` over a labelled dataloader.

    The model is run in eval mode without gradient tracking. Its previous
    train/eval mode is restored afterwards, even if an exception is raised.

    Args:
        model: module mapping an image batch to per-class logits.
        dataloader: iterable of ``(images, label)`` batches where ``label`` is a
            CPU tensor of integer class ids.
        DEVICE: device the images are moved to before the forward pass.

    Returns:
        ``(accuracy, roc_auc)``. ``roc_auc`` is the macro average of per-class
        one-vs-rest ROC AUC, computed on softmax probabilities.

    Raises:
        ValueError: if ``dataloader`` yields no batches. ``roc_auc_score`` also
            raises if some class never occurs in the labels.
    """
    was_training = model.training
    model.eval()
    y_true = []
    prob_batches = []
    try:
        with torch.no_grad():
            for images, label in dataloader:
                output = model(images.to(DEVICE))
                # One device->host copy per batch (was one per sample).
                prob_batches.append(F.softmax(output, dim=1).cpu().numpy())
                y_true.extend(label.numpy())
    finally:
        model.train(was_training)

    if not prob_batches:
        raise ValueError('get_metrics: dataloader yielded no batches')

    probs = np.concatenate(prob_batches)
    y_true = np.asarray(y_true)
    y_pred = probs.argmax(axis=1)

    # One-hot over every class the model outputs. Fitting a LabelBinarizer on
    # y_true drops classes absent from y_true, and yields a single column when
    # exactly two classes are present; either way the indicator no longer lines
    # up with the probability columns.
    y_true_binarized = np.eye(probs.shape[1], dtype=int)[y_true]

    return accuracy_score(y_true, y_pred), roc_auc_score(y_true_binarized, probs)

In [None]:
# Put the model in evaluation mode; `train(False)` is exactly what `eval()` does.
model = model.train(False)

In [None]:
# conv_strategy = tp.strategy.L1Strategy() # or tp.strategy.RandomStrategy()

In [None]:
# DG = tp.DependencyGraph()
# DG.build_dependency(model, example_inputs=torch.randn(1,3,224,224))

In [None]:
# pruning_idxs = conv_strategy(model.conv1.weight, amount=0.1)
# pruning_plan = DG.get_pruning_plan( model.conv1, tp.prune_conv, idxs=pruning_idxs )
# print(pruning_plan)

In [None]:
# Make sure the model lives on DEVICE (a no-op if it is already there).
model = model.to(device=DEVICE)

In [None]:
# (module, parameter-name) pairs for prune.global_unstructured: the weights of
# the first convolution and of the classifier head.
parameters_to_prune = tuple(
    (layer, 'weight') for layer in (model.conv1, model.fc)
)

In [None]:
# Iterative global magnitude pruning with fine-tuning.
# Each round prunes `compress_rate` of the still-unpruned weights in
# `parameters_to_prune`, ranked by |w| jointly across those tensors. It then
# fine-tunes for `epochs` epochs and reports test metrics, checkpoint size and
# sparsity.
epochs = 2
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
min_valid_loss = np.inf
compress_count = 7
# Fraction pruned per round. It was referenced below but never defined, which
# raised NameError on the first round. 0.2 is a placeholder — TODO confirm the
# intended value. Successive pruning calls act on the remaining weights, so
# after `compress_count` rounds about 1 - (1 - compress_rate) ** compress_count
# of the weights are zero.
compress_rate = 0.2


def _sparsity(*tensors):
    """Percentage of entries that are exactly zero across ``tensors``."""
    zeros = sum(float(torch.sum(t == 0)) for t in tensors)
    total = sum(float(t.nelement()) for t in tensors)
    return 100. * zeros / total


print(f'Before Compress:\n{get_metrics(model, testloader, DEVICE)}\n{calc_size(model)}\n{calc_weights(model)} parameters\n')

for round_idx in range(compress_count):
    print(f'Compress №{round_idx+1}\n-------')
    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=compress_rate,
    )

    for e in range(epochs):
        print(f'Epoch {e+1}')

        # Re-enter training mode every epoch. The validation pass below leaves
        # the model in eval mode; calling train() once per round meant every
        # epoch after the first was trained in eval mode.
        model.train()
        train_loss = 0.0
        for images, targets in trainloader:
            images, targets = images.to(DEVICE), targets.to(DEVICE)
            optimizer.zero_grad()
            loss = criterion(model(images), targets)
            loss.backward()
            optimizer.step()
            # Accumulate the summed loss. The old `=` kept only the last batch.
            train_loss += loss.item() * images.size(0)

        model.eval()
        valid_loss = 0.0
        with torch.no_grad():  # validation needs no autograd graphs
            for images, targets in validloader:
                images, targets = images.to(DEVICE), targets.to(DEVICE)
                loss = criterion(model(images), targets)
                valid_loss += loss.item() * images.size(0)

        # Per-sample mean losses for this epoch.
        print(f'Training Loss: {train_loss / len(trainloader.dataset):.4f} \t\t '
              f'Validation Loss: {valid_loss / len(validloader.dataset):.4f}\n')

    print(f'\nCompress Results:\n{get_metrics(model, testloader, DEVICE)}\n{calc_size(model)}\n{calc_weights(model)} parameters\n')

    print("Sparsity in conv1.weight: {:.2f}%".format(_sparsity(model.conv1.weight)))
    print("Sparsity in fc.weight: {:.2f}%".format(_sparsity(model.fc.weight)))
    print("Global sparsity: {:.2f}%\n".format(_sparsity(model.conv1.weight, model.fc.weight)))