In [1]:
import torch
from torchvision import datasets
from torch.utils.data import Dataset
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import models as torchvision_models
import os
import shutil
from pathlib import Path
from torch.cuda.amp import autocast
from tqdm import tqdm
from kmeans_pytorch import kmeans
import wandb

import torchvision

import sys
sys.path.append("../")
from src.utils import ResNet18, transform_train, transform_test

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torchvision_archs = sorted(name for name in torchvision_models.__dict__
                           if name.islower() and not name.startswith("__")
                           and callable(torchvision_models.__dict__[name]))

print(device)

cuda


In [4]:
class ReturnIndexDataset(datasets.ImageFolder):
    def __getitem__(self, idx):
        img, lab = super(ReturnIndexDataset, self).__getitem__(idx)
        # path = super(ReturnIndexDataset, self).samples[idx]
        return idx, img, lab, idx

In [5]:
def find_class_means(X, labels, num_clusters):
    dim = X[0].shape[0]
    labels_sum = {i: torch.zeros(dim) for i in range(num_clusters)}
    labels_count = {i: 0 for i in range(num_clusters)}
    for i in range(len(X)):
        tensor = X[i]
        label = int(labels[i].item())
        labels_sum[label] += tensor
        labels_count[label] += 1
    labels_mean_tensor = torch.zeros((num_clusters, dim))
    for i in range(num_clusters):
        labels_mean_tensor[i] = labels_sum[i] / labels_count[i]
    return labels, labels_mean_tensor


In [6]:
class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

In [26]:
def find_desired_samples(reps, indices, labels, base_dataset, target_dataset, cluster_centers, cluster_ids_x, quantile):
    res_values = []
    res_indices = []
    res_class_labels = []
    res_cluster_labels = []

    batch_size = 16
    num_clusters = len(cluster_centers)
    reps_dataset = CustomDataset(reps.detach())
    reps_dataloader = DataLoader(reps_dataset, batch_size=batch_size, shuffle=False)

    indices = torch.squeeze(indices)
    labels = torch.squeeze(labels)
    cluster_ids_x = torch.squeeze(cluster_ids_x)
    cluster_centers = cluster_centers.to(device)

    # calculate norm
    i = 0
    for tensor in tqdm(reps_dataloader, desc='Calculating norms'):
        tensor = tensor.to(device)
        norm_tensor = torch.linalg.norm(tensor.unsqueeze(dim=1) - cluster_centers.unsqueeze(dim=0), dim=2).detach()
        norm_tensor, norm_tensor_indecies = torch.sort(norm_tensor, dim=1)
        res_values += (-norm_tensor[:, 0]).tolist()
        res_indices += (indices[batch_size * i: (i + 1) * batch_size]).tolist()
        res_class_labels += (labels[batch_size * i: (i + 1) * batch_size]).tolist()
        res_cluster_labels += norm_tensor_indecies[:, 0].tolist()
        i += 1

    # reordering samples and finding quantiles baesd on each class
    cluster_scores = {k: [res_values[i] for i in range(len(res_values)) if int(res_cluster_labels[i]) == k] for k in
                        range(len(cluster_centers))}

    quantiles = {k: torch.quantile(torch.tensor(cluster_scores[k]), q=quantile) for k in
                    range(num_clusters) if len(cluster_scores[k]) != 0}
    score_dicts = {int(res_indices[i]): (res_values[i], int(res_class_labels[i]), int(res_cluster_labels[i])) for i
                    in
                    range(len(res_values))}
    results_based_on_class = {i: [] for i in range(len(target_dataset.classes))}

    # finding images which are in the quantile period
    for k, v in tqdm(score_dicts.items(), desc='Finding images in quntile'):
        if v[0] > quantiles[v[2]].item():
            results_based_on_class[v[1]].append(k)

    # find path of desired samples
    img_paths = {}
    for idx, img, label, ind in tqdm(target_dataset, desc='Gathering paths of desired samples'):
        image_path = target_dataset.samples[idx][0]
        if ind in results_based_on_class[label]:
            try:
                img_paths[label].append(image_path)
            except KeyError:
                img_paths[label] = [image_path]

    return img_paths

def reverse_normalization(images):
    mean = torch.tensor([0.485, 0.456, 0.406])
    std = torch.tensor([0.229, 0.224, 0.225])
    un_normalize = transforms.Normalize((-mean / std).tolist(), (1.0 / std).tolist())
    return un_normalize(images)


def save_outputs(dst_path, dataset, farthest_samples_paths):
    try:
        shutil.rmtree(dst_path)
    except FileNotFoundError:
        pass
    Path(dst_path).mkdir(parents=True, exist_ok=True)
    for cls in dataset.classes:
        Path(os.path.join(dst_path, str(cls))).mkdir(parents=True, exist_ok=True)
    for cls, paths in farthest_samples_paths.items():
        for i, path in enumerate(paths):
            shutil.copy(path, os.path.join(dst_path, dataset.classes[cls]))


def generate_representations(batch_size, model, dataloader, dataset, desc=''):
    model.eval()

    reps = torch.zeros((len(dataloader) * batch_size, 1000))
    indices = torch.zeros((len(dataloader) * batch_size, 1))
    labels = torch.zeros((len(dataloader) * batch_size, 1))
    i = 0
    for idx, tensor, label, index in tqdm(dataloader, desc=desc):
        tensor = tensor.to(device)
        with autocast(enabled=True):
            feats = model(tensor)
        reps[i * batch_size: min((i + 1) * batch_size, len(dataset))] = feats.detach().cpu()
        labels[i * batch_size: min((i + 1) * batch_size, len(dataset))] = label[:, None]
        indices[i * batch_size: min((i + 1) * batch_size, len(dataset))] = index[:, None]
        i += 1
    return reps, indices, labels

In [27]:
# pretrained resnet-18 model
model = torch.hub.load('facebookresearch/swav:main', 'resnet50', pretrained=True)
model = model.to(device)

Using cache found in /nfs/homedirs/dhp/.cache/torch/hub/facebookresearch_swav_main


In [9]:
cifar_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)

# Create directories for each class
classes = cifar_dataset.classes
data_dir = './cifar10_data'

for cls in classes:
    os.makedirs(os.path.join(data_dir, cls), exist_ok=True)

import torchvision.transforms.functional as TF

# Move images to respective class directories
for idx, (image, label) in enumerate(cifar_dataset):
    class_dir = os.path.join(data_dir, classes[label])
    image_path = os.path.join(class_dir, f"img_{idx}.jpg")
    tensor_image = TF.to_tensor(image)  # Convert PIL image to tensor
    torchvision.utils.save_image(tensor_image, image_path)

print("CIFAR10 dataset downloaded and organized successfully.")

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|███████████████████████████████████████████████████| 170498071/170498071 [00:07<00:00, 22878958.47it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
CIFAR10 dataset downloaded and organized successfully.


In [10]:
def get_data(data_path):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    dataset = ReturnIndexDataset(data_path, transform=transform)
    dataloader = DataLoader(
        dataset,
        batch_size=16,
        num_workers=2,
        pin_memory=True,
        drop_last=False,
        shuffle=True
    )
    return dataloader, dataset

In [28]:
dataloader, dataset = get_data("./cifar10_data")

In [29]:
reps, indices, labels = generate_representations(
    batch_size=16,
    model=model,
    dataloader=dataloader,
    dataset=dataset,
    desc='Generating representations'
)

Generating representations: 100%|███████████████████████████████████████| 3126/3126 [00:46<00:00, 67.11it/s]


In [30]:
data_size, dims = reps.shape
num_clusters = len(dataset.classes)

cluster_ids_x, cluster_centers = kmeans(X=reps,
                                        num_clusters=num_clusters,
                                        distance='euclidean',
                                        device=device,
                                        tol=1e-5)

running k-means on cuda..


[running kmeans]: 117it [03:36,  1.85s/it, center_shift=0.000006, iteration=117, tol=0.000010]


Quantile 50

In [31]:
farthest_samples_paths = find_desired_samples(reps=reps,
                                                indices=indices,
                                                labels=labels,
                                                base_dataset=dataset,
                                                target_dataset=dataset,
                                                cluster_centers=cluster_centers,
                                                cluster_ids_x=cluster_ids_x,
                                                quantile=0.5
                                              )

save_outputs(dst_path='./easy_samples_50/',
                 dataset=dataset,
                 farthest_samples_paths=farthest_samples_paths)

Calculating norms: 100%|██████████████████████████████████████████████| 3126/3126 [00:01<00:00, 1882.88it/s]
Finding images in quntile: 100%|█████████████████████████████████| 50001/50001 [00:00<00:00, 1023091.31it/s]
Gathering paths of desired samples: 100%|████████████████████████████| 50001/50001 [01:02<00:00, 800.57it/s]


Quantile 10

In [32]:
farthest_samples_paths = find_desired_samples(reps=reps,
                                                indices=indices,
                                                labels=labels,
                                                base_dataset=dataset,
                                                target_dataset=dataset,
                                                cluster_centers=cluster_centers,
                                                cluster_ids_x=cluster_ids_x,
                                                quantile=0.1
                                              )

save_outputs(dst_path='./easy_samples_10/',
                 dataset=dataset,
                 farthest_samples_paths=farthest_samples_paths)

Calculating norms: 100%|██████████████████████████████████████████████| 3126/3126 [00:01<00:00, 1863.11it/s]
Finding images in quntile: 100%|█████████████████████████████████| 50001/50001 [00:00<00:00, 1001841.05it/s]
Gathering paths of desired samples: 100%|████████████████████████████| 50001/50001 [01:02<00:00, 798.45it/s]


Quantile 20

In [33]:
farthest_samples_paths = find_desired_samples(reps=reps,
                                                indices=indices,
                                                labels=labels,
                                                base_dataset=dataset,
                                                target_dataset=dataset,
                                                cluster_centers=cluster_centers,
                                                cluster_ids_x=cluster_ids_x,
                                                quantile=0.2
                                              )

save_outputs(dst_path='./easy_samples_20/',
                 dataset=dataset,
                 farthest_samples_paths=farthest_samples_paths)

Calculating norms: 100%|██████████████████████████████████████████████| 3126/3126 [00:01<00:00, 1852.45it/s]
Finding images in quntile: 100%|█████████████████████████████████| 50001/50001 [00:00<00:00, 1007888.36it/s]
Gathering paths of desired samples: 100%|████████████████████████████| 50001/50001 [01:02<00:00, 804.08it/s]


In [34]:
!find easy_samples_50 -mindepth 1 -type d -exec sh -c 'echo -n "{}: "; ls -1 "{}" | wc -l' \;

easy_samples_50/ship: 2140
easy_samples_50/cat: 2299
easy_samples_50/truck: 2884
easy_samples_50/horse: 2480
easy_samples_50/airplane: 2278
easy_samples_50/bird: 2011
easy_samples_50/dog: 2434
easy_samples_50/automobile: 2642
easy_samples_50/deer: 2729
easy_samples_50/frog: 3093


### Train

In [19]:
import torch
import torchvision
import time
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms

In [35]:
prune_frac = "10"
wandb.init(project='cifar10_pruning', name='easy-neural-scaling-'+prune_frac)

0,1
Accuracy,▁▃▃▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇██████████
Final-Accuracy,▁
Loss,█▆▅▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁
Top-5 Accuracy,▁
Training Time,▁

0,1
Accuracy,82.11
Final-Accuracy,82.11
Loss,0.18271
Top-5 Accuracy,98.91
Training Time,676.53654


In [36]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_folder = "./easy_samples_"+ prune_frac

In [37]:
trainset = datasets.ImageFolder(train_folder, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2)

# Initialize the ConvNet model
net = ResNet18().to(device)
# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

# Train the ConvNet
start_time = time.time()
for epoch in range(30):  # Adjust the number of epochs as needed
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data


        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 200 == 199:    # Print every 200 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 200))
            wandb.log({"Loss": running_loss / 200}, step=epoch * len(trainloader) + i)
            # Log the loss to wandb, so that we can visualize it
            running_loss = 0.0
            step_val = epoch * len(trainloader) + i + 1

    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)

            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    print('[epoch:%d,  accuracy: %.3f' % (epoch + 1, accuracy))
    wandb.log({"Accuracy": accuracy}, step=step_val)

    scheduler.step()


end_time = time.time()
training_time = end_time - start_time

# Test the ConvNet on the test set
correct = 0
total = 0
top5_correct = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        _, top5_predicted = torch.topk(outputs.data, 5, dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        top5_correct += (top5_predicted == labels.view(-1, 1)).sum().item()

accuracy = 100 * correct / total
top5_accuracy = 100 * top5_correct / total
wandb.log({"Final-Accuracy": accuracy, "Top-5 Accuracy": top5_accuracy, "Training Time": training_time})

Files already downloaded and verified
[1,   200] loss: 1.665
[epoch:1,  accuracy: 54.900
[2,   200] loss: 1.100
[epoch:2,  accuracy: 63.110
[3,   200] loss: 0.881
[epoch:3,  accuracy: 69.070
[4,   200] loss: 0.747
[epoch:4,  accuracy: 72.230
[5,   200] loss: 0.660
[epoch:5,  accuracy: 75.980
[6,   200] loss: 0.601
[epoch:6,  accuracy: 76.800
[7,   200] loss: 0.557
[epoch:7,  accuracy: 78.220
[8,   200] loss: 0.512
[epoch:8,  accuracy: 79.360
[9,   200] loss: 0.467
[epoch:9,  accuracy: 79.910
[10,   200] loss: 0.444
[epoch:10,  accuracy: 80.290
[11,   200] loss: 0.412
[epoch:11,  accuracy: 81.730
[12,   200] loss: 0.382
[epoch:12,  accuracy: 82.640
[13,   200] loss: 0.368
[epoch:13,  accuracy: 82.340
[14,   200] loss: 0.339
[epoch:14,  accuracy: 82.000
[15,   200] loss: 0.317
[epoch:15,  accuracy: 84.230
[16,   200] loss: 0.295
[epoch:16,  accuracy: 83.310
[17,   200] loss: 0.285
[epoch:17,  accuracy: 84.810
[18,   200] loss: 0.262
[epoch:18,  accuracy: 83.520
[19,   200] loss: 0.241
[e

Prune 20

In [38]:
prune_frac = "20"
wandb.init(project='cifar10_pruning', name='easy-neural-scaling-'+prune_frac)
train_folder = "./easy_samples_"+ prune_frac

trainset = datasets.ImageFolder(train_folder, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2)

# Initialize the ConvNet model
net = ResNet18().to(device)
# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

# Train the ConvNet
start_time = time.time()
for epoch in range(30):  # Adjust the number of epochs as needed
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data


        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 200 == 199:    # Print every 200 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 200))
            wandb.log({"Loss": running_loss / 200}, step=epoch * len(trainloader) + i)
            # Log the loss to wandb, so that we can visualize it
            running_loss = 0.0
            step_val = epoch * len(trainloader) + i + 1

    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)

            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    print('[epoch:%d,  accuracy: %.3f' % (epoch + 1, accuracy))
    wandb.log({"Accuracy": accuracy}, step=step_val)

    scheduler.step()


end_time = time.time()
training_time = end_time - start_time

# Test the ConvNet on the test set
correct = 0
total = 0
top5_correct = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        _, top5_predicted = torch.topk(outputs.data, 5, dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        top5_correct += (top5_predicted == labels.view(-1, 1)).sum().item()

accuracy = 100 * correct / total
top5_accuracy = 100 * top5_correct / total
wandb.log({"Final-Accuracy": accuracy, "Top-5 Accuracy": top5_accuracy, "Training Time": training_time})

0,1
Accuracy,▁▃▄▅▆▆▆▇▇▇▇▇▇▇█▇█▇████████████
Final-Accuracy,▁
Loss,█▅▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁
Top-5 Accuracy,▁
Training Time,▁

0,1
Accuracy,85.7
Final-Accuracy,85.7
Loss,0.11597
Top-5 Accuracy,99.27
Training Time,1154.72579


Files already downloaded and verified
[1,   200] loss: 1.664
[epoch:1,  accuracy: 54.950
[2,   200] loss: 1.111
[epoch:2,  accuracy: 62.530
[3,   200] loss: 0.921
[epoch:3,  accuracy: 69.100
[4,   200] loss: 0.776
[epoch:4,  accuracy: 72.430
[5,   200] loss: 0.689
[epoch:5,  accuracy: 76.160
[6,   200] loss: 0.613
[epoch:6,  accuracy: 77.640
[7,   200] loss: 0.571
[epoch:7,  accuracy: 77.360
[8,   200] loss: 0.527
[epoch:8,  accuracy: 79.700
[9,   200] loss: 0.489
[epoch:9,  accuracy: 80.940
[10,   200] loss: 0.457
[epoch:10,  accuracy: 80.100
[11,   200] loss: 0.436
[epoch:11,  accuracy: 82.960
[12,   200] loss: 0.407
[epoch:12,  accuracy: 82.030
[13,   200] loss: 0.381
[epoch:13,  accuracy: 83.810
[14,   200] loss: 0.358
[epoch:14,  accuracy: 82.480
[15,   200] loss: 0.334
[epoch:15,  accuracy: 84.340
[16,   200] loss: 0.315
[epoch:16,  accuracy: 83.710
[17,   200] loss: 0.295
[epoch:17,  accuracy: 84.040
[18,   200] loss: 0.285
[epoch:18,  accuracy: 84.690
[19,   200] loss: 0.262
[e

Prune 50

In [39]:
prune_frac = "50"
wandb.init(project='cifar10_pruning', name='easy-neural-scaling-'+prune_frac)
train_folder = "./easy_samples_"+ prune_frac

trainset = datasets.ImageFolder(train_folder, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2)

# Initialize the ConvNet model
net = ResNet18().to(device)
# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

# Train the ConvNet
start_time = time.time()
for epoch in range(30):  # Adjust the number of epochs as needed
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data


        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        printed = False
        if i % 200 == 199:  # Print every 200 mini-batches
            printed = True
            print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1, running_loss / 200))
            wandb.log({"Loss": running_loss / 200}, step=epoch * len(trainloader) + i)
            # Log the loss to wandb, so that we can visualize it
            running_loss = 0.0
            step_val = epoch * len(trainloader) + i + 1

        if len(trainloader) < 199 and not printed:
            if i % 100 == 99:
                print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1, running_loss / 100))
                wandb.log({"Loss": running_loss / 100}, step=epoch * len(trainloader) + i)
                # Log the loss to wandb, so that we can visualize it
                running_loss = 0.0
                step_val = epoch * len(trainloader) + i + 1

    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)

            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    print('[epoch:%d,  accuracy: %.3f' % (epoch + 1, accuracy))
    wandb.log({"Accuracy": accuracy}, step=step_val)

    scheduler.step()


end_time = time.time()
training_time = end_time - start_time

# Test the ConvNet on the test set
correct = 0
total = 0
top5_correct = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        _, top5_predicted = torch.topk(outputs.data, 5, dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        top5_correct += (top5_predicted == labels.view(-1, 1)).sum().item()

accuracy = 100 * correct / total
top5_accuracy = 100 * top5_correct / total
wandb.log({"Final-Accuracy": accuracy, "Top-5 Accuracy": top5_accuracy, "Training Time": training_time})

0,1
Accuracy,▁▃▄▅▆▆▆▇▇▇▇▇▇▇█▇██████████████
Final-Accuracy,▁
Loss,█▅▅▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁
Top-5 Accuracy,▁
Training Time,▁

0,1
Accuracy,85.56
Final-Accuracy,85.56
Loss,0.13025
Top-5 Accuracy,99.41
Training Time,1037.73031


Files already downloaded and verified
[1,   100] loss: 1.741


  return F.conv2d(input, weight, bias, self.stride,


[epoch:1,  accuracy: 47.480
[2,   100] loss: 1.242
[epoch:2,  accuracy: 54.250
[3,   100] loss: 1.047
[epoch:3,  accuracy: 61.440
[4,   100] loss: 0.916
[epoch:4,  accuracy: 65.850
[5,   100] loss: 0.826
[epoch:5,  accuracy: 67.730
[6,   100] loss: 0.740
[epoch:6,  accuracy: 70.640
[7,   100] loss: 0.669
[epoch:7,  accuracy: 70.790
[8,   100] loss: 0.630
[epoch:8,  accuracy: 72.250
[9,   100] loss: 0.603
[epoch:9,  accuracy: 72.140
[10,   100] loss: 0.550
[epoch:10,  accuracy: 73.060
[11,   100] loss: 0.522
[epoch:11,  accuracy: 74.250
[12,   100] loss: 0.507
[epoch:12,  accuracy: 76.200
[13,   100] loss: 0.476
[epoch:13,  accuracy: 74.900
[14,   100] loss: 0.440
[epoch:14,  accuracy: 77.180
[15,   100] loss: 0.426
[epoch:15,  accuracy: 76.760
[16,   100] loss: 0.405
[epoch:16,  accuracy: 78.450
[17,   100] loss: 0.380
[epoch:17,  accuracy: 77.470
[18,   100] loss: 0.363
[epoch:18,  accuracy: 78.630
[19,   100] loss: 0.356
[epoch:19,  accuracy: 76.950
[20,   100] loss: 0.331
[epoch:20,

In [40]:
wandb.init(project='cifar10_pruning', name='unpruned-again')
train_folder = "./cifar10_data"

trainset = datasets.ImageFolder(train_folder, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2)

# Initialize the ConvNet model
net = ResNet18().to(device)
# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

# Train the ConvNet
start_time = time.time()
for epoch in range(30):  # Adjust the number of epochs as needed
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data


        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        printed = False
        if i % 200 == 199:  # Print every 200 mini-batches
            printed = True
            print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1, running_loss / 200))
            wandb.log({"Loss": running_loss / 200}, step=epoch * len(trainloader) + i)
            # Log the loss to wandb, so that we can visualize it
            running_loss = 0.0
            step_val = epoch * len(trainloader) + i + 1

        if len(trainloader) < 199 and not printed:
            if i % 100 == 99:
                print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1, running_loss / 100))
                wandb.log({"Loss": running_loss / 100}, step=epoch * len(trainloader) + i)
                # Log the loss to wandb, so that we can visualize it
                running_loss = 0.0
                step_val = epoch * len(trainloader) + i + 1

    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)

            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    print('[epoch:%d,  accuracy: %.3f' % (epoch + 1, accuracy))
    wandb.log({"Accuracy": accuracy}, step=step_val)

    scheduler.step()


end_time = time.time()
training_time = end_time - start_time

# Test the ConvNet on the test set
correct = 0
total = 0
top5_correct = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        _, top5_predicted = torch.topk(outputs.data, 5, dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        top5_correct += (top5_predicted == labels.view(-1, 1)).sum().item()

accuracy = 100 * correct / total
top5_accuracy = 100 * top5_correct / total
wandb.log({"Final-Accuracy": accuracy, "Top-5 Accuracy": top5_accuracy, "Training Time": training_time})

0,1
Accuracy,▁▂▄▅▅▆▆▆▆▆▇▇▇▇▇▇▇█▇███████████
Final-Accuracy,▁
Loss,█▆▅▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁
Top-5 Accuracy,▁
Training Time,▁

0,1
Accuracy,80.34
Final-Accuracy,80.34
Loss,0.19435
Top-5 Accuracy,98.61
Training Time,682.71345


Files already downloaded and verified
[1,   200] loss: 1.646
[epoch:1,  accuracy: 59.050
[2,   200] loss: 1.065
[epoch:2,  accuracy: 66.510
[3,   200] loss: 0.852
[epoch:3,  accuracy: 69.740
[4,   200] loss: 0.724
[epoch:4,  accuracy: 74.000
[5,   200] loss: 0.638
[epoch:5,  accuracy: 78.680
[6,   200] loss: 0.577
[epoch:6,  accuracy: 78.700
[7,   200] loss: 0.536
[epoch:7,  accuracy: 80.880
[8,   200] loss: 0.495
[epoch:8,  accuracy: 80.900
[9,   200] loss: 0.467
[epoch:9,  accuracy: 81.580
[10,   200] loss: 0.433
[epoch:10,  accuracy: 83.060
[11,   200] loss: 0.393
[epoch:11,  accuracy: 83.640
[12,   200] loss: 0.370
[epoch:12,  accuracy: 83.610
[13,   200] loss: 0.344
[epoch:13,  accuracy: 84.510
[14,   200] loss: 0.330
[epoch:14,  accuracy: 85.040
[15,   200] loss: 0.301
[epoch:15,  accuracy: 84.910
[16,   200] loss: 0.284
[epoch:16,  accuracy: 83.940
[17,   200] loss: 0.272
[epoch:17,  accuracy: 85.600
[18,   200] loss: 0.243
[epoch:18,  accuracy: 85.220
[19,   200] loss: 0.230
[e

In [42]:
import random

In [43]:
for prune_percentage in [0.1, 0.2, 0.5]:
    str_prune_percentage = str(int(prune_percentage * 100))
    wandb.init(project="cifar10_pruning", name="random-prune-again-" + str_prune_percentage)

    trainset = datasets.ImageFolder(train_folder, transform=transform_train)

    num_samples = len(trainset)
    frac_to_keep = 1 - prune_percentage
    num_samples_to_keep = int(frac_to_keep * num_samples)

    # Generate a random list of indices to keep
    indices_to_keep = random.sample(range(num_samples), num_samples_to_keep)

    pruned_trainset = torch.utils.data.Subset(trainset, indices_to_keep)


    trainloader = torch.utils.data.DataLoader(
        pruned_trainset, batch_size=128, shuffle=True, num_workers=2
    )

    testset = torchvision.datasets.CIFAR10(
        root="./data", train=False, download=True, transform=transform_test
    )
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=100, shuffle=False, num_workers=2
    )

    # Initialize the ConvNet model
    net = ResNet18().to(device)
    # Define the loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

    # Train the ConvNet
    start_time = time.time()
    for epoch in range(30):  # Adjust the number of epochs as needed
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            printed = False
            if i % 200 == 199:  # Print every 200 mini-batches
                printed = True
                print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1, running_loss / 200))
                wandb.log({"Loss": running_loss / 200}, step=epoch * len(trainloader) + i)
                # Log the loss to wandb, so that we can visualize it
                running_loss = 0.0
                step_val = epoch * len(trainloader) + i + 1

            if len(trainloader) < 199 and not printed:
                if i % 100 == 99:
                    print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1, running_loss / 100))
                    wandb.log({"Loss": running_loss / 100}, step=epoch * len(trainloader) + i)
                    # Log the loss to wandb, so that we can visualize it
                    running_loss = 0.0
                    step_val = epoch * len(trainloader) + i + 1

        correct = 0
        total = 0
        with torch.no_grad():
            for data in testloader:
                images, labels = data
                images, labels = images.to(device), labels.to(device)
                outputs = net(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        accuracy = 100 * correct / total
        print("[epoch:%d,  accuracy: %.3f" % (epoch + 1, accuracy))
        wandb.log({"Accuracy": accuracy}, step=step_val)

        scheduler.step()


    end_time = time.time()
    training_time = end_time - start_time

    # Test the ConvNet on the test set
    correct = 0
    total = 0
    top5_correct = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            _, top5_predicted = torch.topk(outputs.data, 5, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            top5_correct += (top5_predicted == labels.view(-1, 1)).sum().item()

    accuracy = 100 * correct / total
    top5_accuracy = 100 * top5_correct / total
    wandb.log(
        {
            "Final-Accuracy": accuracy,
            "Top-5 Accuracy": top5_accuracy,
            "Training Time": training_time,
        }
    )

    wandb.finish()

Files already downloaded and verified
[1,   200] loss: 1.700
[epoch:1,  accuracy: 56.370
[2,   200] loss: 1.142
[epoch:2,  accuracy: 64.340
[3,   200] loss: 0.937
[epoch:3,  accuracy: 70.450
[4,   200] loss: 0.790
[epoch:4,  accuracy: 74.980
[5,   200] loss: 0.701
[epoch:5,  accuracy: 75.830
[6,   200] loss: 0.628
[epoch:6,  accuracy: 78.930
[7,   200] loss: 0.572
[epoch:7,  accuracy: 80.300
[8,   200] loss: 0.533
[epoch:8,  accuracy: 80.120
[9,   200] loss: 0.493
[epoch:9,  accuracy: 81.420
[10,   200] loss: 0.461
[epoch:10,  accuracy: 82.420
[11,   200] loss: 0.428
[epoch:11,  accuracy: 82.890
[12,   200] loss: 0.396
[epoch:12,  accuracy: 83.410
[13,   200] loss: 0.382
[epoch:13,  accuracy: 82.970
[14,   200] loss: 0.360
[epoch:14,  accuracy: 83.860
[15,   200] loss: 0.328
[epoch:15,  accuracy: 84.690
[16,   200] loss: 0.314
[epoch:16,  accuracy: 85.450
[17,   200] loss: 0.298
[epoch:17,  accuracy: 85.630
[18,   200] loss: 0.275
[epoch:18,  accuracy: 86.080
[19,   200] loss: 0.256
[e

0,1
Accuracy,▁▃▄▅▅▆▆▆▇▇▇▇▇▇▇███████████████
Final-Accuracy,▁
Loss,█▆▅▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁
Top-5 Accuracy,▁
Training Time,▁

0,1
Accuracy,86.97
Final-Accuracy,86.97
Loss,0.12119
Top-5 Accuracy,99.55
Training Time,1164.85288


Files already downloaded and verified
[1,   200] loss: 1.656
[epoch:1,  accuracy: 56.670
[2,   200] loss: 1.169
[epoch:2,  accuracy: 62.050
[3,   200] loss: 0.952
[epoch:3,  accuracy: 68.670
[4,   200] loss: 0.806
[epoch:4,  accuracy: 72.860
[5,   200] loss: 0.710
[epoch:5,  accuracy: 75.110
[6,   200] loss: 0.643
[epoch:6,  accuracy: 78.590
[7,   200] loss: 0.597
[epoch:7,  accuracy: 76.840
[8,   200] loss: 0.550
[epoch:8,  accuracy: 79.680
[9,   200] loss: 0.505
[epoch:9,  accuracy: 78.730
[10,   200] loss: 0.484
[epoch:10,  accuracy: 81.340
[11,   200] loss: 0.443
[epoch:11,  accuracy: 81.000
[12,   200] loss: 0.430
[epoch:12,  accuracy: 81.920
[13,   200] loss: 0.405
[epoch:13,  accuracy: 82.820
[14,   200] loss: 0.376
[epoch:14,  accuracy: 84.140
[15,   200] loss: 0.348
[epoch:15,  accuracy: 83.760
[16,   200] loss: 0.324
[epoch:16,  accuracy: 84.400
[17,   200] loss: 0.306
[epoch:17,  accuracy: 84.490
[18,   200] loss: 0.293
[epoch:18,  accuracy: 84.240
[19,   200] loss: 0.269
[e

0,1
Accuracy,▁▂▄▅▅▆▆▆▆▇▇▇▇█▇███████████████
Final-Accuracy,▁
Loss,█▆▅▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁
Top-5 Accuracy,▁
Training Time,▁

0,1
Accuracy,85.89
Final-Accuracy,85.89
Loss,0.13626
Top-5 Accuracy,99.22
Training Time,1035.78863


Files already downloaded and verified
[1,   100] loss: 1.828
[epoch:1,  accuracy: 48.110
[2,   100] loss: 1.353
[epoch:2,  accuracy: 56.050
[3,   100] loss: 1.148
[epoch:3,  accuracy: 59.720
[4,   100] loss: 1.023
[epoch:4,  accuracy: 66.470
[5,   100] loss: 0.890
[epoch:5,  accuracy: 69.280
[6,   100] loss: 0.805
[epoch:6,  accuracy: 73.170
[7,   100] loss: 0.724
[epoch:7,  accuracy: 74.730
[8,   100] loss: 0.671
[epoch:8,  accuracy: 75.560
[9,   100] loss: 0.610
[epoch:9,  accuracy: 75.190
[10,   100] loss: 0.590
[epoch:10,  accuracy: 77.540
[11,   100] loss: 0.541
[epoch:11,  accuracy: 77.190
[12,   100] loss: 0.516
[epoch:12,  accuracy: 78.680
[13,   100] loss: 0.501
[epoch:13,  accuracy: 79.510
[14,   100] loss: 0.461
[epoch:14,  accuracy: 80.490
[15,   100] loss: 0.427
[epoch:15,  accuracy: 80.720
[16,   100] loss: 0.409
[epoch:16,  accuracy: 80.170
[17,   100] loss: 0.386
[epoch:17,  accuracy: 81.460
[18,   100] loss: 0.372
[epoch:18,  accuracy: 81.140
[19,   100] loss: 0.342
[e

0,1
Accuracy,▁▃▃▅▅▆▆▆▆▇▇▇▇▇▇▇██████████████
Final-Accuracy,▁
Loss,█▆▅▅▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁
Top-5 Accuracy,▁
Training Time,▁

0,1
Accuracy,82.14
Final-Accuracy,82.14
Loss,0.18526
Top-5 Accuracy,98.92
Training Time,683.03586
