In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.utils.prune as prune
import torchvision
import torchvision.transforms as transforms

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 15861145.81it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 507536.20it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 4427045.99it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 885041.99it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
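
The `transforms.Normalize((0.5,), (0.5,))` step in the loading cell computes `(x - mean) / std` per channel. `ToTensor()` has already scaled the pixels to [0, 1], so the normalized values fall in [-1, 1]. A plain-Python sketch of that arithmetic follows; the helper name `normalize` is illustrative and not part of torchvision.

```python
def normalize(pixel, mean=0.5, std=0.5):
    """Apply (x - mean) / std to one pixel value already scaled to [0, 1]."""
    return (pixel - mean) / std

# Darkest, middle, and brightest pixel values after ToTensor().
for value in (0.0, 0.5, 1.0):
    print(value, "->", normalize(value))
```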



In [4]:
class MNISTModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 300)
        self.fc2 = nn.Linear(300, 100)
        self.fc3 = nn.Linear(100, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x
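
To put the later sparsity figures in context, it helps to count the parameters of this three-layer network. The sketch below is plain arithmetic on the layer shapes from `MNISTModel`; the names `layer_shapes` and `count_parameters` are hypothetical helpers introduced here. The weight counts for `fc1` and `fc2` agree with the totals printed by `print_weight_statistics` further down.

```python
# (inputs, outputs) of each nn.Linear layer in MNISTModel.
layer_shapes = {"fc1": (28 * 28, 300), "fc2": (300, 100), "fc3": (100, 10)}

def count_parameters(shapes):
    """Return weight and bias counts per fully connected layer."""
    counts = {}
    for name, (n_in, n_out) in shapes.items():
        counts[name] = {"weights": n_in * n_out, "biases": n_out}
    return counts

param_counts = count_parameters(layer_shapes)
total_params = sum(c["weights"] + c["biases"] for c in param_counts.values())

for name, c in param_counts.items():
    print(f"{name}: {c['weights']} weights, {c['biases']} biases")
print(f"total: {total_params}")
```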

In [5]:
def train_model(model, trainloader, epochs=5):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for images, labels in trainloader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {running_loss / len(trainloader):.4f}")

def test_model(model, testloader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)  # Move to GPU/CPU
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f'Accuracy: {accuracy:.2f}%')
    return accuracy
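
`train_model` relies on `optim.SGD(..., lr=0.01, momentum=0.9)`. As a rough illustration of that update rule, here is a scalar sketch that assumes the default settings (no dampening, no Nesterov momentum, no weight decay). The function `sgd_momentum_step` and the toy loss `(w - 3)^2` are made up for this example.

```python
def sgd_momentum_step(w, grad, buf, lr=0.01, momentum=0.9):
    """One SGD-with-momentum step on a scalar parameter."""
    buf = momentum * buf + grad   # accumulate a velocity from past gradients
    w = w - lr * buf              # move against the accumulated direction
    return w, buf

# Toy problem: minimise (w - 3)^2, whose gradient is 2 * (w - 3).
w, buf = 0.0, 0.0
for _ in range(300):
    grad = 2.0 * (w - 3.0)
    w, buf = sgd_momentum_step(w, grad, buf)

print(f"w after 300 steps: {w:.4f}")
```

The velocity term is why each call to `train_model` starts "cold": a new optimizer is created per call, so no momentum carries over between pruning rounds.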

In [6]:
def iterative_pruning(model, trainloader, testloader, prune_iterations, prune_amount=0.2):
    for i in range(prune_iterations):
        print(f"\nIteration {i+1} of pruning:")

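        # Mask out the `prune_amount` fraction of the still-unpruned weights with the
        # smallest absolute value. fc3 (the output layer) is left dense.
        prune.l1_unstructured(model.fc1, name='weight', amount=prune_amount)
        prune.l1_unstructured(model.fc2, name='weight', amount=prune_amount)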

        model.fc1.to(device)
        model.fc2.to(device)

        print("Retraining after pruning...")
        train_model(model, trainloader, epochs=2)

        test_model(model, testloader)
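
The pruning step in `iterative_pruning` can be pictured with a small framework-free sketch. It is a simplified model of magnitude pruning, not PyTorch's implementation: each round masks the smallest-magnitude weights among those still alive, and the retraining between rounds is left out. The function `l1_prune` and the sample weights are hypothetical.

```python
def l1_prune(weights, mask, amount):
    """Return a new mask that also zeroes the smallest-magnitude survivors."""
    alive = [i for i, m in enumerate(mask) if m == 1]
    n_prune = round(amount * len(alive))
    # Rank surviving indices by absolute value and drop the smallest ones.
    to_prune = sorted(alive, key=lambda i: abs(weights[i]))[:n_prune]
    new_mask = list(mask)
    for i in to_prune:
        new_mask[i] = 0
    return new_mask

weights = [0.5, -0.1, 0.3, -0.9, 0.05, 0.7, -0.2, 0.4, 0.6, -0.8]
mask = [1] * len(weights)

for step in range(2):
    mask = l1_prune(weights, mask, amount=0.2)
    print(f"after round {step + 1}: {mask}")

# The layer effectively uses weight * mask in its forward pass.
effective = [w * m for w, m in zip(weights, mask)]
```

Because `amount` applies to the weights that are still unpruned, repeated rounds remove a shrinking number of weights each time.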

In [7]:
def finalize_pruning(model):
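    # Make pruning permanent: drop `weight_orig` and `weight_mask` and keep a plain
    # `weight` parameter that holds the masked (zeroed) values.
    prune.remove(model.fc1, 'weight')
    prune.remove(model.fc2, 'weight')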

    print("\nPruning finalized. Here are the weight statistics after final pruning:")
    print_weight_statistics(model.fc1, "fc1")
    print_weight_statistics(model.fc2, "fc2")

In [8]:
def print_weight_statistics(layer, name):
    nonzero = torch.count_nonzero(layer.weight)
    total = layer.weight.numel()
    sparsity = 100 - (nonzero / total * 100)
    print(f"{name}: Non-zero weights: {nonzero}, Total weights: {total}, Sparsity: {sparsity:.2f}%")

1. Random weight initialization

In [9]:
model = MNISTModel().to(device)

2. Training

In [10]:
train_model(model, trainloader, epochs=5)

Epoch [1/5], Loss: 0.4096
Epoch [2/5], Loss: 0.1635
Epoch [3/5], Loss: 0.1152
Epoch [4/5], Loss: 0.0905
Epoch [5/5], Loss: 0.0772


In [11]:
accuracy_initial = test_model(model, testloader)

Accuracy: 97.37%


3. Iterative training (pruning + retraining)

In [12]:
prune_iterations = 3  # number of pruning iterations
prune_amount = 0.2  # fraction of the remaining weights pruned per iteration

iterative_pruning(model, trainloader, testloader, prune_iterations, prune_amount)


Iteration 1 of pruning:
Retraining after pruning...
Epoch [1/2], Loss: 0.0579
Epoch [2/2], Loss: 0.0488
Accuracy: 97.73%

Iteration 2 of pruning:
Retraining after pruning...
Epoch [1/2], Loss: 0.0380
Epoch [2/2], Loss: 0.0323
Accuracy: 97.77%

Iteration 3 of pruning:
Retraining after pruning...
Epoch [1/2], Loss: 0.0245
Epoch [2/2], Loss: 0.0208
Accuracy: 98.01%


4. Removing the pruning mask (original initialization) + retraining

In [13]:
finalize_pruning(model)


Pruning finalized. Here are the weight statistics after final pruning:
fc1: Non-zero weights: 120422, Total weights: 235200, Sparsity: 48.80%
fc2: Non-zero weights: 15360, Total weights: 30000, Sparsity: 48.80%
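
The 48.80% figure follows from pruning 20% of the remaining weights three times, which leaves about 0.8^3 = 51.2% of them. The sketch below replays that arithmetic on the layer sizes. It assumes each pruning call removes `round(amount * remaining)` weights, which is consistent with the counts printed above; `surviving_weights` is a hypothetical helper.

```python
def surviving_weights(total, amount, iterations):
    """Count weights left after repeatedly pruning a fraction of the survivors."""
    alive = total
    for _ in range(iterations):
        alive -= round(amount * alive)
    return alive

for name, total in (("fc1", 235200), ("fc2", 30000)):
    alive = surviving_weights(total, amount=0.2, iterations=3)
    sparsity = 100 * (1 - alive / total)
    print(f"{name}: {alive} non-zero of {total}, sparsity {sparsity:.2f}%")
```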


6. Steps 2-5, iteratively

In [14]:
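# Note: finalize_pruning() has already removed the masks, so this retraining
# updates plain `weight` parameters and may move zeroed entries away from zero.
print("\nFinal retraining after pruning:")
train_model(model, trainloader, epochs=5)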

print("Final testing after pruning:")
accuracy_final = test_model(model, testloader)

print(f"Accuracy before pruning: {accuracy_initial:.2f}%")
print(f"Final accuracy after pruning: {accuracy_final:.2f}%")


Final retraining after pruning:
Epoch [1/5], Loss: 0.0328
Epoch [2/5], Loss: 0.0317
Epoch [3/5], Loss: 0.0265
Epoch [4/5], Loss: 0.0235
Epoch [5/5], Loss: 0.0208
Final testing after pruning:
Accuracy: 97.49%
Accuracy before pruning: 97.37%
Final accuracy after pruning: 97.49%
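
One caveat about the final retraining: `finalize_pruning` calls `prune.remove`, so the last `train_model` call runs on plain `weight` parameters with no mask attached. Gradient updates can then move the zeroed entries away from zero, so the 48.80% sparsity may not survive this step. The notebook does not print weight statistics again, so this is an inference, not a measured result. The toy sketch below shows the effect on a four-element weight vector and one possible remedy, re-applying a saved mask after each step. All names and numbers in it are illustrative.

```python
def sgd_step(weights, grads, lr=0.01, mask=None):
    """Plain SGD step; if a mask is given, pruned positions are reset to zero."""
    updated = [w - lr * g for w, g in zip(weights, grads)]
    if mask is not None:
        updated = [w if m else 0.0 for w, m in zip(updated, mask)]
    return updated

weights = [0.0, 0.8, 0.0, -0.4]   # positions 0 and 2 were pruned
mask = [0, 1, 0, 1]
grads = [0.5, -0.2, -0.3, 0.1]    # gradients are generally non-zero everywhere

dense_update = sgd_step(weights, grads)
masked_update = sgd_step(weights, grads, mask=mask)

print("without mask:", dense_update)
print("with mask:   ", masked_update)
```

In the notebook itself, calling `print_weight_statistics(model.fc1, "fc1")` after the final retraining would show whether the sparsity was preserved.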
