In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [1]:
!pip install -qU kmeans_pytorch timm

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m10.5 MB/s[0m eta [36m0:00:00[0m
[?25h

In [3]:
%cd /content/drive/MyDrive/GR

/content/drive/MyDrive/GR


In [4]:
import torch
from torchvision import datasets
from torch.utils.data import Dataset
from functools import partial
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import models as torchvision_models
import os
import shutil
from pathlib import Path
from torch.cuda.amp import autocast
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from kmeans_pytorch import kmeans

import logging
import torch.nn.functional as F
import torch.distributed as dist

import numpy as np
import torchvision
from math import sqrt

import timm.models.vision_transformer
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torchvision_archs = sorted(name for name in torchvision_models.__dict__
                           if name.islower() and not name.startswith("__")
                           and callable(torchvision_models.__dict__[name]))

print(device)

cuda


In [6]:
class ReturnIndexDataset(datasets.ImageFolder):
    def __getitem__(self, idx):
        img, lab = super(ReturnIndexDataset, self).__getitem__(idx)
        # path = super(ReturnIndexDataset, self).samples[idx]
        return idx, img, lab, idx

In [7]:
def find_class_means(X, labels, num_clusters):
    dim = X[0].shape[0]
    labels_sum = {i: torch.zeros(dim) for i in range(num_clusters)}
    labels_count = {i: 0 for i in range(num_clusters)}
    for i in range(len(X)):
        tensor = X[i]
        label = int(labels[i].item())
        labels_sum[label] += tensor
        labels_count[label] += 1
    labels_mean_tensor = torch.zeros((num_clusters, dim))
    for i in range(num_clusters):
        labels_mean_tensor[i] = labels_sum[i] / labels_count[i]
    return labels, labels_mean_tensor


In [8]:
class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

In [24]:
def find_desired_samples(reps, indices, labels, base_dataset, target_dataset, cluster_centers, cluster_ids_x, quantile):
    res_values = []
    res_indices = []
    res_class_labels = []
    res_cluster_labels = []

    batch_size = 16
    num_clusters = len(cluster_centers)
    reps_dataset = CustomDataset(reps.detach())
    reps_dataloader = DataLoader(reps_dataset, batch_size=batch_size, shuffle=False)

    indices = torch.squeeze(indices)
    labels = torch.squeeze(labels)
    cluster_ids_x = torch.squeeze(cluster_ids_x)
    cluster_centers = cluster_centers.to(device)

    # calculate norm
    i = 0
    for tensor in tqdm(reps_dataloader, desc='Calculating norms'):
        tensor = tensor.to(device)
        norm_tensor = torch.linalg.norm(tensor.unsqueeze(dim=1) - cluster_centers.unsqueeze(dim=0), dim=2).detach()
        norm_tensor, norm_tensor_indecies = torch.sort(norm_tensor, dim=1)
        res_values += (norm_tensor[:, 0] - norm_tensor[:, 1] - norm_tensor[:, 2]).tolist()
        res_indices += (indices[batch_size * i: (i + 1) * batch_size]).tolist()
        res_class_labels += (labels[batch_size * i: (i + 1) * batch_size]).tolist()
        res_cluster_labels += norm_tensor_indecies[:, 0].tolist()
        i += 1

    # reordering samples and finding quantiles baesd on each class
    cluster_scores = {k: [res_values[i] for i in range(len(res_values)) if int(res_cluster_labels[i]) == k] for k in
                        range(len(cluster_centers))}

    # save representation's distribution histogram
    save_histograms(cluster_scores)

    quantiles = {k: torch.quantile(torch.tensor(cluster_scores[k]), q=quantile) for k in
                    range(num_clusters) if len(cluster_scores[k]) != 0}
    score_dicts = {int(res_indices[i]): (res_values[i], int(res_class_labels[i]), int(res_cluster_labels[i])) for i
                    in
                    range(len(res_values))}
    results_based_on_class = {i: [] for i in range(len(target_dataset.classes))}

    # finding images which are in the quantile period
    for k, v in tqdm(score_dicts.items(), desc='Finding images in quntile'):
        if v[0] > quantiles[v[2]].item():
            results_based_on_class[v[1]].append(k)

    # find path of desired samples
    img_paths = {}
    for idx, img, label, ind in tqdm(target_dataset, desc='Gathering paths of desired samples'):
        image_path = target_dataset.samples[idx][0]
        if ind in results_based_on_class[label]:
            try:
                img_paths[label].append(image_path)
            except KeyError:
                img_paths[label] = [image_path]

    return img_paths


def save_histograms(cluster_scores):
    histograms_path = 'histograms'
    if os.path.exists(histograms_path):
        shutil.rmtree(histograms_path)
    os.mkdir(histograms_path)

    for cls, scores in cluster_scores.items():
        sns.distplot(scores)
        plt.title('class :' + str(cls))
        plt.savefig(os.path.join(histograms_path, 'class ' + str(cls) + '.jpg'))
        plt.clf()


def reverse_normalization(images):
    mean = torch.tensor([0.485, 0.456, 0.406])
    std = torch.tensor([0.229, 0.224, 0.225])
    un_normalize = transforms.Normalize((-mean / std).tolist(), (1.0 / std).tolist())
    return un_normalize(images)


def save_outputs(dst_path, dataset, farthest_samples_paths):
    try:
        shutil.rmtree(dst_path)
    except FileNotFoundError:
        pass
    Path(dst_path).mkdir(parents=True, exist_ok=True)
    for cls in dataset.classes:
        Path(os.path.join(dst_path, str(cls))).mkdir(parents=True, exist_ok=True)
    for cls, paths in farthest_samples_paths.items():
        for i, path in enumerate(paths):
            shutil.copy(path, os.path.join(dst_path, dataset.classes[cls]))


def generate_representations(batch_size, model, dataloader, dataset, desc=''):
    model.eval()

    reps = torch.zeros((len(dataloader) * batch_size, 1000))
    indices = torch.zeros((len(dataloader) * batch_size, 1))
    labels = torch.zeros((len(dataloader) * batch_size, 1))
    i = 0
    for idx, tensor, label, index in tqdm(dataloader, desc=desc):
        tensor = tensor.to(device)
        with autocast(enabled=True):
            feats = model(tensor)
        reps[i * batch_size: min((i + 1) * batch_size, len(dataset))] = feats.detach().cpu()
        labels[i * batch_size: min((i + 1) * batch_size, len(dataset))] = label[:, None]
        indices[i * batch_size: min((i + 1) * batch_size, len(dataset))] = index[:, None]
        i += 1
    return reps, indices, labels

In [10]:
# pretrained resnet-18 model
model = torch.hub.load('facebookresearch/swav:main', 'resnet50', pretrained=True)
model = model.to(device)

Downloading: "https://github.com/facebookresearch/swav/zipball/main" to /root/.cache/torch/hub/main.zip
Downloading: "https://dl.fbaipublicfiles.com/deepcluster/swav_800ep_pretrain.pth.tar" to /root/.cache/torch/hub/checkpoints/swav_800ep_pretrain.pth.tar
100%|██████████| 108M/108M [00:00<00:00, 274MB/s]


In [11]:
# cifar_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)

# # Create directories for each class
# classes = cifar_dataset.classes
# data_dir = './cifar10_data'

# for cls in classes:
#     os.makedirs(os.path.join(data_dir, cls), exist_ok=True)

# import torchvision.transforms.functional as TF

# # Move images to respective class directories
# for idx, (image, label) in enumerate(cifar_dataset):
#     class_dir = os.path.join(data_dir, classes[label])
#     image_path = os.path.join(class_dir, f"img_{idx}.jpg")
#     tensor_image = TF.to_tensor(image)  # Convert PIL image to tensor
#     torchvision.utils.save_image(tensor_image, image_path)

# print("CIFAR10 dataset downloaded and organized successfully.")

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:03<00:00, 48468316.91it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
CIFAR10 dataset downloaded and organized successfully.


In [12]:
def get_data(data_path):
    transform = transforms.Compose([
        transforms.Resize(256, interpolation=3),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    dataset = ReturnIndexDataset(data_path, transform=transform)
    dataloader = DataLoader(
        dataset,
        batch_size=16,
        num_workers=2,
        pin_memory=True,
        drop_last=False,
        shuffle=True
    )
    return dataloader, dataset

In [13]:
dataloader, dataset = get_data("/content/drive/MyDrive/GR/cifar10_data/")

In [14]:
reps, indices, labels = generate_representations(
    batch_size=16,
    model=model,
    dataloader=dataloader,
    dataset=dataset,
    desc='Generating representations'
)

  self.pid = os.fork()
Generating representations: 100%|██████████| 3125/3125 [05:46<00:00,  9.02it/s]


In [15]:
data_size, dims = reps.shape
num_clusters = len(dataset.classes)

cluster_ids_x, cluster_centers = kmeans(X=reps,
                                        num_clusters=num_clusters,
                                        distance='euclidean',
                                        device=device,
                                        tol=1e-5)

running k-means on cuda..


[running kmeans]: 83it [04:16,  3.09s/it, center_shift=0.000005, iteration=83, tol=0.000010]


In [18]:
farthest_samples_paths = find_desired_samples(reps=reps,
                                                indices=indices,
                                                labels=labels,
                                                base_dataset=dataset,
                                                target_dataset=dataset,
                                                cluster_centers=cluster_centers,
                                                cluster_ids_x=cluster_ids_x,
                                                quantile=0.9
                                              )

Calculating norms: 100%|██████████| 3125/3125 [00:01<00:00, 1997.71it/s]

`distplot` is a deprecated function and will be removed in seaborn v0.14.0.

Please adapt your code to use either `displot` (a figure-level function with
similar flexibility) or `histplot` (an axes-level function for histograms).

For a guide to updating your code to use the new functions, please see
https://gist.github.com/mwaskom/de44147ed2974457ad6372750bbe5751

  sns.distplot(scores)

`distplot` is a deprecated function and will be removed in seaborn v0.14.0.

Please adapt your code to use either `displot` (a figure-level function with
similar flexibility) or `histplot` (an axes-level function for histograms).

For a guide to updating your code to use the new functions, please see
https://gist.github.com/mwaskom/de44147ed2974457ad6372750bbe5751

  sns.distplot(scores)

`distplot` is a deprecated function and will be removed in seaborn v0.14.0.

Please adapt your code to use either `displot` (a figure-level fu

<Figure size 640x480 with 0 Axes>

In [19]:
save_outputs(dst_path='/content/drive/MyDrive/GR/hard_samples_90/',
                 dataset=dataset,
                 farthest_samples_paths=farthest_samples_paths)

In [20]:
!find cifar10_data -mindepth 1 -type d -exec sh -c 'echo -n "{}: "; ls -1 "{}" | wc -l' \;

cifar10_data/airplane: 5000
cifar10_data/automobile: 5000
cifar10_data/bird: 5000
cifar10_data/cat: 5000
cifar10_data/deer: 5000
cifar10_data/dog: 5000
cifar10_data/frog: 5000
cifar10_data/horse: 5000
cifar10_data/ship: 5000
cifar10_data/truck: 5000


In [21]:
!find hard_samples_90 -mindepth 1 -type d -exec sh -c 'echo -n "{}: "; ls -1 "{}" | wc -l' \;

hard_samples/airplane: 480
hard_samples/automobile: 439
hard_samples/bird: 475
hard_samples/cat: 529
hard_samples/deer: 705
hard_samples/dog: 429
hard_samples/frog: 355
hard_samples/horse: 560
hard_samples/ship: 502
hard_samples/truck: 529


### Quantile 0.5

In [30]:
farthest_samples_paths = find_desired_samples(reps=reps,
                                                indices=indices,
                                                labels=labels,
                                                base_dataset=dataset,
                                                target_dataset=dataset,
                                                cluster_centers=cluster_centers,
                                                cluster_ids_x=cluster_ids_x,
                                                quantile=0.5
                                              )

save_outputs(dst_path='/content/drive/MyDrive/GR/hard_samples_50/',
                 dataset=dataset,
                 farthest_samples_paths=farthest_samples_paths)

Calculating norms: 100%|██████████| 3125/3125 [00:01<00:00, 1797.03it/s]

`distplot` is a deprecated function and will be removed in seaborn v0.14.0.

Please adapt your code to use either `displot` (a figure-level function with
similar flexibility) or `histplot` (an axes-level function for histograms).

For a guide to updating your code to use the new functions, please see
https://gist.github.com/mwaskom/de44147ed2974457ad6372750bbe5751

  sns.distplot(scores)

`distplot` is a deprecated function and will be removed in seaborn v0.14.0.

Please adapt your code to use either `displot` (a figure-level function with
similar flexibility) or `histplot` (an axes-level function for histograms).

For a guide to updating your code to use the new functions, please see
https://gist.github.com/mwaskom/de44147ed2974457ad6372750bbe5751

  sns.distplot(scores)

`distplot` is a deprecated function and will be removed in seaborn v0.14.0.

Please adapt your code to use either `displot` (a figure-level fu

<Figure size 640x480 with 0 Axes>

In [31]:
!find hard_samples_50 -mindepth 1 -type d -exec sh -c 'echo -n "{}: "; ls -1 "{}" | wc -l' \;

hard_samples_50/airplane: 2496
hard_samples_50/automobile: 2374
hard_samples_50/bird: 2585
hard_samples_50/cat: 2542
hard_samples_50/deer: 3131
hard_samples_50/dog: 2372
hard_samples_50/frog: 1896
hard_samples_50/horse: 2615
hard_samples_50/ship: 2532
hard_samples_50/truck: 2455


### Quantile 0.1

In [25]:
farthest_samples_paths = find_desired_samples(reps=reps,
                                                indices=indices,
                                                labels=labels,
                                                base_dataset=dataset,
                                                target_dataset=dataset,
                                                cluster_centers=cluster_centers,
                                                cluster_ids_x=cluster_ids_x,
                                                quantile=0.1
                                              )

save_outputs(dst_path='/content/drive/MyDrive/GR/hard_samples_10/',
                 dataset=dataset,
                 farthest_samples_paths=farthest_samples_paths)

Calculating norms: 100%|██████████| 3125/3125 [00:01<00:00, 2063.33it/s]

`distplot` is a deprecated function and will be removed in seaborn v0.14.0.

Please adapt your code to use either `displot` (a figure-level function with
similar flexibility) or `histplot` (an axes-level function for histograms).

For a guide to updating your code to use the new functions, please see
https://gist.github.com/mwaskom/de44147ed2974457ad6372750bbe5751

  sns.distplot(scores)

`distplot` is a deprecated function and will be removed in seaborn v0.14.0.

Please adapt your code to use either `displot` (a figure-level function with
similar flexibility) or `histplot` (an axes-level function for histograms).

For a guide to updating your code to use the new functions, please see
https://gist.github.com/mwaskom/de44147ed2974457ad6372750bbe5751

  sns.distplot(scores)

`distplot` is a deprecated function and will be removed in seaborn v0.14.0.

Please adapt your code to use either `displot` (a figure-level fu

<Figure size 640x480 with 0 Axes>

In [27]:
!find hard_samples_10 -mindepth 1 -type d -exec sh -c 'echo -n "{}: "; ls -1 "{}" | wc -l' \;

hard_samples_10/airplane: 4498
hard_samples_10/automobile: 4459
hard_samples_10/bird: 4608
hard_samples_10/cat: 4481
hard_samples_10/deer: 4783
hard_samples_10/dog: 4456
hard_samples_10/frog: 4212
hard_samples_10/horse: 4506
hard_samples_10/ship: 4493
hard_samples_10/truck: 4501


### Quantile 20

In [32]:
farthest_samples_paths = find_desired_samples(reps=reps,
                                                indices=indices,
                                                labels=labels,
                                                base_dataset=dataset,
                                                target_dataset=dataset,
                                                cluster_centers=cluster_centers,
                                                cluster_ids_x=cluster_ids_x,
                                                quantile=0.2
                                              )

save_outputs(dst_path='/content/drive/MyDrive/GR/hard_samples_20/',
                 dataset=dataset,
                 farthest_samples_paths=farthest_samples_paths)

Calculating norms: 100%|██████████| 3125/3125 [00:02<00:00, 1243.69it/s]

`distplot` is a deprecated function and will be removed in seaborn v0.14.0.

Please adapt your code to use either `displot` (a figure-level function with
similar flexibility) or `histplot` (an axes-level function for histograms).

For a guide to updating your code to use the new functions, please see
https://gist.github.com/mwaskom/de44147ed2974457ad6372750bbe5751

  sns.distplot(scores)

`distplot` is a deprecated function and will be removed in seaborn v0.14.0.

Please adapt your code to use either `displot` (a figure-level function with
similar flexibility) or `histplot` (an axes-level function for histograms).

For a guide to updating your code to use the new functions, please see
https://gist.github.com/mwaskom/de44147ed2974457ad6372750bbe5751

  sns.distplot(scores)

`distplot` is a deprecated function and will be removed in seaborn v0.14.0.

Please adapt your code to use either `displot` (a figure-level fu

<Figure size 640x480 with 0 Axes>

In [34]:
!find hard_samples_20 -mindepth 1 -type d -exec sh -c 'echo -n "{}: "; ls -1 "{}" | wc -l' \;

hard_samples_20/airplane: 3991
hard_samples_20/automobile: 3912
hard_samples_20/bird: 4135
hard_samples_20/cat: 3989
hard_samples_20/deer: 4472
hard_samples_20/dog: 3940
hard_samples_20/frog: 3521
hard_samples_20/horse: 4014
hard_samples_20/ship: 4000
hard_samples_20/truck: 4022


In [36]:
!pip install wandb -qU

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m15.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m207.3/207.3 kB[0m [31m16.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m266.8/266.8 kB[0m [31m17.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.7/62.7 kB[0m [31m9.0 MB/s[0m eta [36m0:00:00[0m
[?25h

In [37]:
import torch
import torchvision
import time
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import wandb

In [38]:
wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [39]:
prune_frac = "10"
wandb.init(project='cifar10_pruning', name='pruned-neural-scaling-'+prune_frac)

[34m[1mwandb[0m: Currently logged in as: [33mprasanga[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [41]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_folder = "/content/drive/MyDrive/GR/hard_samples_"+ prune_frac

In [None]:
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion *
                               planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def ResNet18():
    return ResNet(BasicBlock, [2, 2, 2, 2])

# Load the CIFAR-10 dataset
transform_train = transforms.Compose([
    transforms.ToTensor(),
])

transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

trainset = datasets.ImageFolder(train_folder, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2)

# Initialize the ConvNet model
net = ResNet18().to(device)
# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

# Train the ConvNet
start_time = time.time()
for epoch in range(30):  # Adjust the number of epochs as needed
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data


        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 200 == 199:    # Print every 200 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 200))
            wandb.log({"Loss": running_loss / 200}, step=epoch * len(trainloader) + i)
            # Log the loss to wandb, so that we can visualize it
            running_loss = 0.0
            step_val = epoch * len(trainloader) + i + 1

    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)

            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    print('[epoch:%d,  accuracy: %.3f' % (epoch + 1, accuracy))
    wandb.log({"Accuracy": accuracy}, step=step_val)

    scheduler.step()


end_time = time.time()
training_time = end_time - start_time

# Test the ConvNet on the test set
correct = 0
total = 0
top5_correct = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        _, top5_predicted = torch.topk(outputs.data, 5, dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        top5_correct += (top5_predicted == labels.view(-1, 1)).sum().item()

accuracy = 100 * correct / total
top5_accuracy = 100 * top5_correct / total
wandb.log({"Final-Accuracy": accuracy, "Top-5 Accuracy": top5_accuracy, "Training Time": training_time})

Files already downloaded and verified
Input shape: torch.Size([128, 3, 32, 32])
[1,   200] loss: 1.586
test Input shape: torch.Size([100, 3, 32, 32])
test Input shape: torch.Size([100, 3, 32, 32])
test Input shape: torch.Size([100, 3, 32, 32])
test Input shape: torch.Size([100, 3, 32, 32])
test Input shape: torch.Size([100, 3, 32, 32])
test Input shape: torch.Size([100, 3, 32, 32])
test Input shape: torch.Size([100, 3, 32, 32])
test Input shape: torch.Size([100, 3, 32, 32])
test Input shape: torch.Size([100, 3, 32, 32])
test Input shape: torch.Size([100, 3, 32, 32])
test Input shape: torch.Size([100, 3, 32, 32])
test Input shape: torch.Size([100, 3, 32, 32])
test Input shape: torch.Size([100, 3, 32, 32])
test Input shape: torch.Size([100, 3, 32, 32])
test Input shape: torch.Size([100, 3, 32, 32])
test Input shape: torch.Size([100, 3, 32, 32])
test Input shape: torch.Size([100, 3, 32, 32])
test Input shape: torch.Size([100, 3, 32, 32])
test Input shape: torch.Size([100, 3, 32, 32])
test

### Prune 20

In [None]:
prune_frac = "20"
wandb.init(project='cifar10_pruning', name='pruned-neural-scaling-'+prune_frac)
train_folder = "/content/drive/MyDrive/GR/hard_samples_"+ prune_frac

trainset = datasets.ImageFolder(train_folder, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2)

# Initialize the ConvNet model
net = ResNet18().to(device)
# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

# Train the ConvNet
start_time = time.time()
for epoch in range(30):  # Adjust the number of epochs as needed
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data


        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 200 == 199:    # Print every 200 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 200))
            wandb.log({"Loss": running_loss / 200}, step=epoch * len(trainloader) + i)
            # Log the loss to wandb, so that we can visualize it
            running_loss = 0.0
            step_val = epoch * len(trainloader) + i + 1

    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)

            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    print('[epoch:%d,  accuracy: %.3f' % (epoch + 1, accuracy))
    wandb.log({"Accuracy": accuracy}, step=step_val)

    scheduler.step()


end_time = time.time()
training_time = end_time - start_time

# Test the ConvNet on the test set
correct = 0
total = 0
top5_correct = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        _, top5_predicted = torch.topk(outputs.data, 5, dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        top5_correct += (top5_predicted == labels.view(-1, 1)).sum().item()

accuracy = 100 * correct / total
top5_accuracy = 100 * top5_correct / total
wandb.log({"Final-Accuracy": accuracy, "Top-5 Accuracy": top5_accuracy, "Training Time": training_time})

### Prune 50

In [None]:
prune_frac = "50"
wandb.init(project='cifar10_pruning', name='pruned-neural-scaling-'+prune_frac)
train_folder = "/content/drive/MyDrive/GR/hard_samples_"+ prune_frac

trainset = datasets.ImageFolder(train_folder, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2)

# Initialize the ConvNet model
net = ResNet18().to(device)
# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

# Train the ConvNet
start_time = time.time()
for epoch in range(30):  # Adjust the number of epochs as needed
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data


        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 200 == 199:    # Print every 200 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 200))
            wandb.log({"Loss": running_loss / 200}, step=epoch * len(trainloader) + i)
            # Log the loss to wandb, so that we can visualize it
            running_loss = 0.0
            step_val = epoch * len(trainloader) + i + 1

    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)

            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    print('[epoch:%d,  accuracy: %.3f' % (epoch + 1, accuracy))
    wandb.log({"Accuracy": accuracy}, step=step_val)

    scheduler.step()


end_time = time.time()
training_time = end_time - start_time

# Test the ConvNet on the test set
correct = 0
total = 0
top5_correct = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        _, top5_predicted = torch.topk(outputs.data, 5, dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        top5_correct += (top5_predicted == labels.view(-1, 1)).sum().item()

accuracy = 100 * correct / total
top5_accuracy = 100 * top5_correct / total
wandb.log({"Final-Accuracy": accuracy, "Top-5 Accuracy": top5_accuracy, "Training Time": training_time})