In [13]:
import argparse
import os
import time
from pathlib import Path

import numpy as np
import pytorch_lightning as pl
import torch
import wandb
from torch import nn
from torch.nn import functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split, ConcatDataset
from torchvision import models
import timm # state-of-the-art models (e.g. vit...)

from tqdm import tqdm

In [3]:
!nvidia-smi

Fri Oct 25 01:13:07 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.90.07              Driver Version: 550.90.07      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A40                     On  |   00000000:01:00.0 Off |                    0 |
|  0%   25C    P8             24W /  300W |       1MiB /  46068MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A40                     On  |   00

In [4]:
!conda info --envs

# conda environments:
#
base                  *  /coc/scratch/debopam/env
MIA                      /coc/scratch/debopam/env/envs/MIA
adversarial-robustness-toolbox     /coc/scratch/debopam/env/envs/adversarial-robustness-toolbox
bild                     /coc/scratch/debopam/env/envs/bild
clipper                  /coc/scratch/debopam/env/envs/clipper
clockwork                /coc/scratch/debopam/env/envs/clockwork
clockwork-new            /coc/scratch/debopam/env/envs/clockwork-new
erdos                    /coc/scratch/debopam/env/envs/erdos
flame                    /coc/scratch/debopam/env/envs/flame
graphing                 /coc/scratch/debopam/env/envs/graphing
psml                     /coc/scratch/debopam/env/envs/psml
ray_server               /coc/scratch/debopam/env/envs/ray_server
snnet                    /coc/scratch/debopam/env/envs/snnet
snnet_manav              /coc/scratch/debopam/env/envs/snnet_manav
testing_pruning          /coc/scratch/debopam/env/envs/tes

In [6]:
SEED = 1583745484
pl.seed_everything(SEED)

Seed set to 1583745484


1583745484

In [10]:
DATA_DIR = '/serenity/scratch/psml/data/ILSVRC2012'

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

## for validation set

In [21]:


imagenet = datasets.ImageNet(root=DATA_DIR, split='val', transform=transform)
train_ds, test_ds = random_split(imagenet, [0.8, 0.2])

In [8]:
imagenet

Dataset ImageNet
    Number of datapoints: 50000
    Root location: /serenity/scratch/psml/data/ILSVRC2012
    Split: val
    StandardTransform
Transform: Compose(
               Resize(size=256, interpolation=bilinear, max_size=None, antialias=True)
               CenterCrop(size=(224, 224))
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )

In [22]:
len(train_ds)

40000

In [23]:
len(test_ds)

10000

In [29]:
imagenet = datasets.ImageNet(root=DATA_DIR, split='val', transform=transform)
train_ds, test_ds = random_split(imagenet, [0.8, 0.2])

targets_ = [train_ds.dataset.targets[i] for i in train_ds.indices]
targets_ = np.array(targets_)

array([ 66, 833, 788, ..., 757, 146, 521])

In [32]:
np.max(targets_)

999

In [26]:
n_shadows = 64
size = len(train_ds)
pkeep = 0.5 
shadow_id = 6

np.random.seed(SEED)
keep = np.random.uniform(0, 1, size=(n_shadows, size))
order = keep.argsort(0)
keep = order < int(pkeep * n_shadows)
keep = np.array(keep[shadow_id], dtype=bool)
keep = keep.nonzero()[0]

In [29]:
keep.shape

(19910,)

In [30]:
labels = [imagenet[i][1] for i in range(len(imagenet))]
class_counts = Counter(labels)

print(class_counts)

KeyboardInterrupt: 

In [44]:
labels = [train_ds[i][1] for i in range(len(train_ds))]
class_counts = Counter(labels)

print(class_counts)

NameError: name 'Counter' is not defined

In [45]:
labels

[363,
 838,
 785,
 701,
 610,
 757,
 993,
 101,
 345,
 294,
 554,
 289,
 905,
 546,
 127,
 494,
 570,
 108,
 884,
 18,
 660,
 42,
 153,
 763,
 725,
 661,
 869,
 63,
 888,
 718,
 912,
 654,
 943,
 586,
 514,
 480,
 10,
 396,
 621,
 988,
 175,
 759,
 17,
 13,
 162,
 575,
 349,
 672,
 371,
 758,
 636,
 370,
 268,
 316,
 473,
 400,
 488,
 325,
 532,
 614,
 689,
 30,
 465,
 183,
 97,
 193,
 127,
 103,
 139,
 144,
 264,
 326,
 803,
 624,
 839,
 102,
 768,
 86,
 779,
 983,
 663,
 795,
 791,
 633,
 319,
 699,
 821,
 211,
 827,
 566,
 946,
 858,
 430,
 756,
 203,
 380,
 832,
 739,
 861,
 404,
 213,
 291,
 424,
 917,
 421,
 222,
 354,
 354,
 774,
 100,
 255,
 562,
 473,
 896,
 715,
 170,
 346,
 857,
 951,
 388,
 728,
 815,
 217,
 508,
 855,
 463,
 819,
 426,
 623,
 809,
 879,
 972,
 631,
 695,
 57,
 418,
 702,
 774,
 596,
 624,
 878,
 807,
 443,
 542,
 836,
 613,
 894,
 628,
 66,
 361,
 249,
 619,
 745,
 773,
 476,
 809,
 949,
 972,
 192,
 3,
 660,
 828,
 513,
 961,
 345,
 985,
 90,
 998,
 992,


## for training set

In [46]:
imagenet_ = datasets.ImageNet(root='/serenity/scratch/psml/data/ILSVRC2012', split='train', transform=transform)

In [51]:
T1_SIZE = 40000
T2_SIZE = 40000
OTHERS_SIZE = len(imagenet_) - T1_SIZE - T2_SIZE

T1, T2, _ = random_split(imagenet_, [T1_SIZE, T2_SIZE, OTHERS_SIZE])

In [47]:
imagenet_

Dataset ImageNet
    Number of datapoints: 1281167
    Root location: /serenity/scratch/psml/data/ILSVRC2012
    Split: train
    StandardTransform
Transform: Compose(
               Resize(size=256, interpolation=bilinear, max_size=None, antialias=True)
               CenterCrop(size=(224, 224))
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )

In [52]:
len(T1)

40000

In [58]:
T1.indices

[475994,
 991579,
 279253,
 1113665,
 1061223,
 780724,
 356204,
 274112,
 294003,
 200519,
 409446,
 930179,
 165375,
 953686,
 1051205,
 1034824,
 1006158,
 342886,
 266850,
 85731,
 504803,
 1228127,
 1257092,
 209905,
 767982,
 128677,
 861196,
 1174381,
 56519,
 281088,
 787991,
 142596,
 1142400,
 756841,
 774292,
 888563,
 171122,
 738137,
 435342,
 1143031,
 94953,
 100415,
 9471,
 751313,
 817865,
 201134,
 581443,
 836191,
 628339,
 722662,
 1020551,
 919390,
 434434,
 708318,
 664468,
 738052,
 595081,
 716771,
 58895,
 1058905,
 1223750,
 466535,
 55731,
 41591,
 1257935,
 796576,
 280742,
 833559,
 509615,
 1203303,
 578195,
 852176,
 187518,
 723933,
 708370,
 761856,
 487983,
 935134,
 311142,
 396723,
 919627,
 857668,
 992322,
 1086502,
 686042,
 894082,
 267494,
 1132159,
 367477,
 850359,
 226651,
 314080,
 1232316,
 507004,
 877226,
 861969,
 235531,
 312303,
 1046490,
 875375,
 170387,
 781259,
 730595,
 1155631,
 278197,
 950459,
 1017987,
 724099,
 497758,
 12767

In [56]:
len(T2)

40000

In [55]:
T2.indices

[99054,
 561533,
 268692,
 1020652,
 721898,
 673526,
 215483,
 662742,
 490875,
 762736,
 211745,
 1001833,
 443634,
 590595,
 552507,
 640105,
 815621,
 689667,
 865696,
 1233211,
 556229,
 1079667,
 450235,
 1215998,
 270713,
 1120288,
 419445,
 709633,
 363599,
 369862,
 1079477,
 79695,
 664352,
 878716,
 823172,
 433557,
 355142,
 342912,
 744025,
 139638,
 1171744,
 230452,
 178416,
 145652,
 785630,
 337329,
 7742,
 932003,
 753841,
 167489,
 393941,
 674070,
 884389,
 27786,
 251158,
 1205953,
 710651,
 343605,
 218597,
 1003892,
 175615,
 1073692,
 303742,
 658427,
 16276,
 939148,
 128793,
 651048,
 551305,
 447891,
 414138,
 809143,
 603727,
 563918,
 935181,
 510084,
 344761,
 962300,
 435588,
 500936,
 123711,
 519767,
 604306,
 1186752,
 371931,
 918515,
 318932,
 166054,
 138641,
 578489,
 260530,
 1160867,
 879727,
 1223011,
 546131,
 755616,
 448669,
 465499,
 316756,
 803882,
 190546,
 1259045,
 939128,
 64224,
 834397,
 792357,
 441209,
 1008885,
 809549,
 723138,
 

In [None]:
# train_dl = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=4)
# test_dl = DataLoader(test_ds, batch_size=128, shuffle=False, num_workers=4)

# for impl, 

In [11]:
imagenet_t = datasets.ImageNet(root=DATA_DIR, split='train', transform=transform)
T_train_SIZE = 20000
T_eval_SIZE = 5000
T_others_SIZE = len(imagenet_t) - T_train_SIZE - T_eval_SIZE
T_train, T_eval, _ = random_split(imagenet_t, [T_train_SIZE, T_eval_SIZE, T_others_SIZE])

imagenet_s = datasets.ImageNet(root=DATA_DIR, split='val', transform=transform)
S_train_SIZE = 20000
S_eval_SIZE = 5000
S_others_SIZE = len(imagenet_s) - S_train_SIZE - S_eval_SIZE
S_train, S_eval, _ = random_split(imagenet_s, [S_train_SIZE, S_eval_SIZE, S_others_SIZE])


In [33]:
train_ds = ConcatDataset([T_train, S_train])
test_ds = ConcatDataset([T_eval, S_eval])

In [47]:
train_ds

<torch.utils.data.dataset.ConcatDataset at 0x7f43d90c6980>

In [44]:
img, label = train_ds.datasets[0][50]

In [46]:
label

410

In [53]:
def extract_labels(dataset):
    labels = []
    
    for i in tqdm(range(len(dataset)), desc=f'iterating {dataset}'):
        _, label = dataset[i]
        labels.append(label)
        
    return labels

In [54]:
labels = []

for dataset in train_ds.datasets:
    labels.extend(extract_labels(dataset))

iterating <torch.utils.data.dataset.Subset object at 0x7f40ddbf3c40>: 100%|████| 20000/20000 [01:44<00:00, 190.55it/s]
iterating <torch.utils.data.dataset.Subset object at 0x7f46a3f859c0>: 100%|████| 20000/20000 [01:48<00:00, 183.64it/s]


In [55]:
labels


[622,
 305,
 805,
 223,
 718,
 560,
 680,
 252,
 34,
 581,
 70,
 171,
 784,
 903,
 625,
 287,
 989,
 496,
 703,
 427,
 615,
 875,
 79,
 732,
 21,
 184,
 823,
 601,
 66,
 228,
 448,
 502,
 866,
 905,
 909,
 665,
 188,
 73,
 96,
 605,
 660,
 553,
 944,
 173,
 141,
 477,
 365,
 725,
 404,
 442,
 410,
 771,
 51,
 1,
 835,
 898,
 699,
 718,
 705,
 670,
 30,
 263,
 354,
 900,
 610,
 262,
 510,
 203,
 8,
 600,
 68,
 736,
 191,
 349,
 536,
 98,
 499,
 289,
 953,
 712,
 629,
 663,
 708,
 433,
 680,
 222,
 273,
 680,
 883,
 144,
 380,
 779,
 451,
 960,
 101,
 162,
 569,
 948,
 212,
 408,
 445,
 617,
 711,
 960,
 143,
 611,
 263,
 935,
 622,
 575,
 265,
 101,
 138,
 435,
 994,
 410,
 607,
 70,
 12,
 687,
 656,
 113,
 742,
 465,
 939,
 131,
 876,
 742,
 34,
 589,
 729,
 747,
 434,
 586,
 728,
 819,
 479,
 386,
 11,
 822,
 506,
 546,
 514,
 432,
 55,
 967,
 850,
 204,
 563,
 415,
 366,
 687,
 574,
 548,
 560,
 274,
 415,
 397,
 443,
 809,
 184,
 921,
 476,
 523,
 812,
 305,
 4,
 544,
 413,
 500,
 3

In [17]:
len(test_ds)

10000