# Large-Scale Dataset Pruning with Dynamic Uncertainty for MNIST

Implementation according to "Large-scale Dataset Pruning with Dynamic Uncertainty" (https://arxiv.org/abs/2306.05175)

**Objective:**

    "Implement data pruning using the dynamic uncertainty score on the MNIST dataset and train a model with 25% and 50% pruning (i.e., 25% or 50% of the data is removed for the final training using the calculated pruning scores). Compare your implementation with random subsampling of the data. Additionally, implement the following custom dynamic uncertainty score U(x) = abs(DFT(x)), where DFT is the Discrete Fourier Transform and abs takes the magnitude of the frequency spectrum. Consider the following: Do we need every value of the DFT, or can we remove some and get the same result? If we only want to consider the dynamics, which values do we need to remove? "

In [166]:
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader, Subset

from torchvision.datasets import MNIST
import torchvision.transforms as transforms

import numpy as np

from tqdm import tqdm

In [167]:
batch_size = 32
pruning_ratio = 0.25
epochs = 10
crit = nn.CrossEntropyLoss()
J = 5 # J < epochs

In [168]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


---

In [169]:
# Eval accuracy of a given net on a given loader
def eval_acc(net, loader : DataLoader):
    net.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in tqdm(loader, desc=f'Test acc'):
            outputs = net(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = correct / total
    return accuracy

---

Build Model

In [170]:
class CNN(nn.Module):
    def __init__(self) -> None:
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc = nn.Linear(in_features=14*14*16, out_features=10)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        # x = self.softmax(x)
        return x
    
# Test model 
x_random = torch.randn((2,1,28,28))
net = CNN()
print(net(x_random))
del net

tensor([[-0.0244,  0.0529,  0.2159, -0.0866, -0.0318,  0.2499, -0.1300,  0.0026,
         -0.2277, -0.2908],
        [ 0.1704,  0.1435,  0.1600, -0.0016, -0.0947,  0.3685, -0.1382, -0.0325,
         -0.2190, -0.3119]], grad_fn=<AddmmBackward0>)


Load dataset

In [171]:
transform = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.1307,), (0.3081,)) # Mean and std of MNIST dataset
        ])

trainset = MNIST(root='../data', train=True, download=True, transform=transform)
testset = MNIST(root='../data', train=False, download=True, transform=transform)

# Just shuffle once initially!
indices = torch.randperm(len(trainset))
train_size = int(0.8 * len(indices))
val_size = len(indices) - train_size
subset_size = int((1 - pruning_ratio) * train_size)
train_indices, val_indices = indices[:train_size], indices[train_size:]

train_subset = Subset(trainset, train_indices)
val_subset = Subset(trainset, val_indices)

trainloader = DataLoader(train_subset, batch_size=batch_size, shuffle=False)
valloader =  DataLoader(val_subset, batch_size=batch_size, shuffle=False)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)

In [172]:
print(f"Trainset size: {train_size}")
print(f"Validation size: {val_size}")
print(f"Subset size: {subset_size}")
print(f"Teset size: {len(testset)}")

Trainset size: 48000
Validation size: 12000
Subset size: 36000
Teset size: 10000


---

**Baseline** (full dataset)

In [173]:
net_baseline = CNN()
optimizer_baseline = AdamW(net_baseline.parameters())
name_baseline = f"baseline_{epochs}"

In [174]:
for epoch in range(epochs):
    total_loss = 0.0
    for inputs, labels in tqdm(trainloader, desc=f'Epoch {epoch + 1}/{epochs}'):
        optimizer_baseline.zero_grad()
        outputs = net_baseline(inputs)
        loss = crit(outputs, labels)
        loss.backward()
        optimizer_baseline.step()
        total_loss += loss.item()

    average_loss = total_loss / train_size
    print(f'Epoch {epoch + 1}/{epochs}, Average Loss: {average_loss:.4f}')

Epoch 1/10:   1%|          | 11/1500 [00:00<00:48, 30.68it/s]

Epoch 1/10: 100%|██████████| 1500/1500 [00:35<00:00, 42.28it/s]


Epoch 1/10, Average Loss: 0.0059


Epoch 2/10: 100%|██████████| 1500/1500 [00:27<00:00, 53.99it/s]


Epoch 2/10, Average Loss: 0.0020


Epoch 3/10: 100%|██████████| 1500/1500 [00:27<00:00, 53.94it/s]


Epoch 3/10, Average Loss: 0.0014


Epoch 4/10: 100%|██████████| 1500/1500 [00:27<00:00, 53.68it/s]


Epoch 4/10, Average Loss: 0.0011


Epoch 5/10: 100%|██████████| 1500/1500 [00:28<00:00, 51.77it/s]


Epoch 5/10, Average Loss: 0.0009


Epoch 6/10: 100%|██████████| 1500/1500 [00:28<00:00, 52.41it/s]


Epoch 6/10, Average Loss: 0.0007


Epoch 7/10: 100%|██████████| 1500/1500 [00:27<00:00, 54.20it/s]


Epoch 7/10, Average Loss: 0.0006


Epoch 8/10: 100%|██████████| 1500/1500 [00:27<00:00, 53.94it/s]


Epoch 8/10, Average Loss: 0.0005


Epoch 9/10: 100%|██████████| 1500/1500 [00:27<00:00, 53.60it/s]


Epoch 9/10, Average Loss: 0.0004


Epoch 10/10: 100%|██████████| 1500/1500 [00:28<00:00, 52.15it/s]

Epoch 10/10, Average Loss: 0.0004





In [175]:
print(f'Test Accuracy: {eval_acc(net=net_baseline, loader=testloader) * 100:.2f}%')

Test acc: 100%|██████████| 313/313 [00:04<00:00, 74.09it/s]

Test Accuracy: 98.41%





In [176]:
torch.save(net_baseline.state_dict(), f"../models/{name_baseline}.pth")

---

**Random Subsampling**

In [177]:
net_subsampling = CNN()
optimizer_subsampling = AdamW(net_subsampling.parameters())
name_subsampling = f"subsampling_{pruning_ratio}_{epochs}"

In [178]:
from torch.utils.data import SubsetRandomSampler
idxs = torch.randperm(train_size)[:subset_size]
sampler = SubsetRandomSampler(idxs)
randomsubsampledloader = DataLoader(trainloader.dataset, batch_size=batch_size, sampler=sampler)

In [179]:
for epoch in range(epochs):
    total_loss = 0.0
    for inputs, labels in tqdm(randomsubsampledloader, desc=f'Epoch {epoch + 1}/{epochs}'):
        optimizer_subsampling.zero_grad()
        outputs = net_subsampling(inputs)
        loss = crit(outputs, labels)
        loss.backward()
        optimizer_subsampling.step()
        total_loss += loss.item()

    average_loss = total_loss / subset_size
    print(f'Epoch {epoch + 1}/{epochs}, Average Loss: {average_loss:.4f}')

Epoch 1/10: 100%|██████████| 1125/1125 [00:23<00:00, 48.87it/s]


Epoch 1/10, Average Loss: 0.0065


Epoch 2/10: 100%|██████████| 1125/1125 [00:20<00:00, 54.36it/s]


Epoch 2/10, Average Loss: 0.0021


Epoch 3/10: 100%|██████████| 1125/1125 [00:20<00:00, 53.64it/s]


Epoch 3/10, Average Loss: 0.0015


Epoch 4/10: 100%|██████████| 1125/1125 [00:21<00:00, 53.33it/s]


Epoch 4/10, Average Loss: 0.0012


Epoch 5/10: 100%|██████████| 1125/1125 [00:21<00:00, 52.41it/s]


Epoch 5/10, Average Loss: 0.0009


Epoch 6/10: 100%|██████████| 1125/1125 [00:20<00:00, 54.77it/s]


Epoch 6/10, Average Loss: 0.0007


Epoch 7/10: 100%|██████████| 1125/1125 [00:21<00:00, 52.53it/s]


Epoch 7/10, Average Loss: 0.0006


Epoch 8/10: 100%|██████████| 1125/1125 [00:20<00:00, 54.37it/s]


Epoch 8/10, Average Loss: 0.0005


Epoch 9/10: 100%|██████████| 1125/1125 [00:20<00:00, 54.93it/s]


Epoch 9/10, Average Loss: 0.0004


Epoch 10/10: 100%|██████████| 1125/1125 [00:20<00:00, 53.94it/s]

Epoch 10/10, Average Loss: 0.0003





In [180]:
print(f'Test Accuracy: {eval_acc(net=net_subsampling, loader=testloader) * 100:.2f}%')
torch.save(net_subsampling.state_dict(), f"../models/{name_subsampling}.pth")

Test acc: 100%|██████████| 313/313 [00:03<00:00, 83.85it/s]

Test Accuracy: 98.43%





---

**With Pruning**

In [181]:
net_pruning = CNN()
optimizer_pruning = AdamW(net_pruning.parameters())

In [182]:
# Algorithm 1: Dataset pruning with dynamic uncertainty.
# Input: Trainingset - trainloader, pruning ratio: pruning_ratio
# Required Model: net_pruning, traing epochs: epochs, uncertainty window: J

# To track uncertainty
uncertainty_window = np.zeros((train_size, J)) # Uncertainty window
uncertainty_EQ2 = np.zeros((train_size, epochs-J+1)) # Uncertainty according to Eq.2
uncertainty = np.zeros(train_size) # Overall uncertainty according to Eq.3

# for k = 0, · · · , K − 1 do
for epoch in range(epochs):
    total_loss = 0.0
    idx = 0
    
    # Sample a batch B ∼ T.
    # for (xi, yi) ∈ B do:
    for inputs, labels in tqdm(trainloader, desc=f'Epoch {epoch + 1}/{epochs}'):
        optimizer_pruning.zero_grad()

        # Compute prediction P(yi, xi, θ) and loss ℓ(ϕθ(A(xi)), yi)
        outputs = net_pruning(inputs)
        loss = crit(outputs, labels)

        # Store window
        predicted_values = outputs[range(outputs.size(0)), labels]
        uncertainty_window[idx:idx+len(labels), epoch%J] = predicted_values.detach().numpy()
        idx += len(labels)

        # Update θ ← θ − η∇θL, where L =Σℓ(ϕθ(A(xi)),yi) / |B|
        loss.backward()
        optimizer_pruning.step()
        total_loss += loss.item()

    # if k ≥ J then
        # Compute uncertainty Uk−J (xi) using Eq. 2
    if epoch >= J-1:
            U_epoch = np.std(uncertainty_window, ddof=1, axis=1)
            uncertainty_EQ2[:, epoch-J+1] = U_epoch

    average_loss = total_loss / subset_size
    print(f'Epoch {epoch + 1}/{epochs}, Average Loss: {average_loss:.4f}')

# for (xi, yi) ∈ T do
    # Compute dynamic uncertainty U(xi) using Eq. 3
uncertainty = np.mean(uncertainty_EQ2, axis=1)

# Sort T in the descending order of U(·)
sorted_indices = np.argsort(uncertainty)[::-1]

# S ← front (1 − r) × |T | samples in the sorted T
subset_indices = sorted_indices[:int(len(sorted_indices)*(1-pruning_ratio))]

# Output: Pruned dataset S
train_dynamic_uncertainty_subset = Subset(dataset=trainloader.dataset, indices=subset_indices)

Epoch 1/10:  43%|████▎     | 652/1500 [00:13<00:25, 33.90it/s]

Epoch 1/10: 100%|██████████| 1500/1500 [00:43<00:00, 34.20it/s]


Epoch 1/10, Average Loss: 0.0067


Epoch 2/10: 100%|██████████| 1500/1500 [00:37<00:00, 40.37it/s]


Epoch 2/10, Average Loss: 0.0027


Epoch 3/10: 100%|██████████| 1500/1500 [00:33<00:00, 44.65it/s]


Epoch 3/10, Average Loss: 0.0019


Epoch 4/10: 100%|██████████| 1500/1500 [00:33<00:00, 44.68it/s]


Epoch 4/10, Average Loss: 0.0014


Epoch 5/10: 100%|██████████| 1500/1500 [00:33<00:00, 44.20it/s]


Epoch 5/10, Average Loss: 0.0011


Epoch 6/10: 100%|██████████| 1500/1500 [00:33<00:00, 44.14it/s]


Epoch 6/10, Average Loss: 0.0008


Epoch 7/10: 100%|██████████| 1500/1500 [00:34<00:00, 43.50it/s]


Epoch 7/10, Average Loss: 0.0007


Epoch 8/10: 100%|██████████| 1500/1500 [00:34<00:00, 43.44it/s]


Epoch 8/10, Average Loss: 0.0005


Epoch 9/10: 100%|██████████| 1500/1500 [00:34<00:00, 43.39it/s]


Epoch 9/10, Average Loss: 0.0005


Epoch 10/10: 100%|██████████| 1500/1500 [00:34<00:00, 43.59it/s]

Epoch 10/10, Average Loss: 0.0004





In [183]:
net_pruning = CNN()
optimizer_pruning = AdamW(net_pruning.parameters())
name_pruning = f"pruning_{pruning_ratio}_{epochs}"
dynamic_uncertainty = DataLoader(train_dynamic_uncertainty_subset, batch_size=batch_size, shuffle=False)

In [184]:
for epoch in range(epochs):
    total_loss = 0.0
    for inputs, labels in tqdm(dynamic_uncertainty, desc=f'Epoch {epoch + 1}/{epochs}'):
        optimizer_pruning.zero_grad()
        outputs = net_pruning(inputs)
        loss = crit(outputs, labels)
        loss.backward()
        optimizer_pruning.step()
        total_loss += loss.item()

    average_loss = total_loss / train_size
    print(f'Epoch {epoch + 1}/{epochs}, Average Loss: {average_loss:.4f}')

Epoch 1/10: 100%|██████████| 1125/1125 [00:23<00:00, 48.18it/s]


Epoch 1/10, Average Loss: 0.0033


Epoch 2/10: 100%|██████████| 1125/1125 [00:22<00:00, 49.30it/s]


Epoch 2/10, Average Loss: 0.0011


Epoch 3/10: 100%|██████████| 1125/1125 [00:22<00:00, 50.22it/s]


Epoch 3/10, Average Loss: 0.0007


Epoch 4/10: 100%|██████████| 1125/1125 [00:23<00:00, 48.33it/s]


Epoch 4/10, Average Loss: 0.0004


Epoch 5/10: 100%|██████████| 1125/1125 [00:22<00:00, 49.46it/s]


Epoch 5/10, Average Loss: 0.0003


Epoch 6/10: 100%|██████████| 1125/1125 [00:22<00:00, 49.49it/s]


Epoch 6/10, Average Loss: 0.0003


Epoch 7/10: 100%|██████████| 1125/1125 [00:22<00:00, 50.31it/s]


Epoch 7/10, Average Loss: 0.0003


Epoch 8/10: 100%|██████████| 1125/1125 [00:22<00:00, 49.64it/s]


Epoch 8/10, Average Loss: 0.0002


Epoch 9/10: 100%|██████████| 1125/1125 [00:22<00:00, 49.37it/s]


Epoch 9/10, Average Loss: 0.0001


Epoch 10/10: 100%|██████████| 1125/1125 [00:22<00:00, 49.17it/s]

Epoch 10/10, Average Loss: 0.0002





In [185]:
print(f'Test Accuracy: {eval_acc(net=net_pruning, loader=testloader) * 100:.2f}%')
torch.save(net_pruning.state_dict(), f"../models/{name_pruning}.pth")

Test acc: 100%|██████████| 313/313 [00:05<00:00, 61.59it/s]

Test Accuracy: 97.93%





---

**Discrete Fourier Transform**

In [186]:
from torch.fft import fft, fftn

In [187]:
x_test = torch.randn((5,1,28,28))
x_fft = fft(x_test).abs()

In [188]:
x_fft

tensor([[[[ 7.0016,  2.3408,  9.4974,  ...,  5.6659,  9.4974,  2.3408],
          [ 3.6225,  4.8430,  3.0725,  ...,  3.4119,  3.0725,  4.8430],
          [ 0.2753,  5.6806,  4.8276,  ...,  2.6253,  4.8276,  5.6806],
          ...,
          [ 0.3260,  8.0120, 11.6901,  ...,  2.6139, 11.6901,  8.0120],
          [ 5.0772,  5.0792,  8.8253,  ...,  2.1639,  8.8253,  5.0792],
          [ 5.1682,  3.1012,  6.1085,  ...,  4.5523,  6.1085,  3.1012]]],


        [[[ 1.9000, 11.5732,  0.8645,  ...,  6.9359,  0.8645, 11.5732],
          [ 0.8429,  7.1713,  1.7348,  ...,  3.8792,  1.7348,  7.1713],
          [ 4.5909,  2.1676,  3.8562,  ...,  5.0738,  3.8562,  2.1676],
          ...,
          [10.2103,  6.2847,  7.5095,  ...,  4.6634,  7.5095,  6.2847],
          [ 1.8560,  1.9186,  2.7215,  ...,  0.9371,  2.7215,  1.9186],
          [ 2.8874,  2.0643,  1.8409,  ...,  2.2602,  1.8409,  2.0643]]],


        [[[ 4.5337,  7.6894,  5.4230,  ...,  1.7024,  5.4230,  7.6894],
          [ 8.3174,  1.694

In [189]:
net_pruning = CNN()
optimizer_pruning = AdamW(net_pruning.parameters())
name_pruning = f"pruning_{pruning_ratio}"

In [190]:
x_test = torch.randn((5,1,28,28))
fftn(x_test).abs()

tensor([[[[ 95.6847,  33.9431,  47.4507,  ...,  85.2217,  47.4507,  33.9431],
          [ 24.1194,  97.7538,  47.1660,  ...,   3.2349,   7.6042,  64.5337],
          [ 35.5327,  57.3047,  64.6004,  ...,  73.6595,  87.1928,  41.8536],
          ...,
          [ 72.6193,  52.8112,  28.6272,  ..., 103.2676,  39.9172,  57.4693],
          [ 35.5327,  41.8536,  87.1928,  ...,  21.3594,  64.6004,  57.3047],
          [ 24.1193,  64.5337,   7.6042,  ...,  38.8029,  47.1660,  97.7538]]],


        [[[ 59.5293,  70.1278, 100.1183,  ...,  18.5314,  21.3645,  74.7055],
          [ 34.4041,  77.9607,  99.2598,  ...,  67.2690,  72.3460,  20.9297],
          [ 37.0432,  38.8848,  63.0727,  ..., 154.7969, 163.1138,  53.1937],
          ...,
          [ 76.3877,  38.3539,  29.4337,  ...,  29.7651,  74.1724,  58.2424],
          [ 91.4610,  43.8345,  51.3067,  ...,  46.2332,  77.5871,  57.5014],
          [ 39.6031,  65.3447,  64.8417,  ...,  64.9992,  52.2133,  72.0684]]],


        [[[ 78.7171,  52.8