In [None]:
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from collections import defaultdict
from tqdm import tqdm
import numpy as np
import pickle as pkl

# CIFAR10

In [8]:
normalize = transforms.Normalize(mean= (0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))
transform = transforms.Compose([transforms.ToTensor(), normalize])

traindataset = datasets.CIFAR10('../../dataset', train=True, download=True,transform=transform)
train_loader = DataLoader(traindataset, batch_size=1, shuffle=False, num_workers=8)

Files already downloaded and verified


In [11]:
label2inds = defaultdict(lambda: [])
for index, (_, label) in enumerate(tqdm(train_loader)):
    label2inds[label.item()].append(index)
label2len = {kk: len(vv) for kk, vv in label2inds.items()}
label2len

  0%|          | 0/50000 [00:00<?, ?it/s]  0%|          | 1/50000 [00:00<5:49:27,  2.38it/s]  0%|          | 3/50000 [00:00<2:02:24,  6.81it/s]  0%|          | 6/50000 [00:00<1:07:06, 12.42it/s]  0%|          | 51/50000 [00:00<06:34, 126.51it/s]  0%|          | 130/50000 [00:00<02:47, 297.49it/s]  0%|          | 197/50000 [00:00<02:05, 396.54it/s]  1%|          | 272/50000 [00:01<01:40, 494.12it/s]  1%|          | 328/50000 [00:01<01:45, 470.13it/s]  1%|          | 414/50000 [00:01<01:26, 576.04it/s]  1%|          | 477/50000 [00:01<01:35, 521.28it/s]  1%|          | 534/50000 [00:01<01:35, 519.40it/s]  1%|          | 589/50000 [00:01<01:42, 483.64it/s]  1%|▏         | 643/50000 [00:01<01:39, 493.61it/s]  1%|▏         | 695/50000 [00:01<01:45, 468.07it/s]  1%|▏         | 746/50000 [00:02<01:42, 478.77it/s]  2%|▏         | 811/50000 [00:02<01:35, 517.61it/s]  2%|▏         | 864/50000 [00:02<01:35, 514.04it/s]  2%|▏         | 945/50000 [00:02<01:23, 585.19it/s]  2%|▏  

{6: 5000,
 9: 5000,
 4: 5000,
 1: 5000,
 2: 5000,
 7: 5000,
 8: 5000,
 3: 5000,
 5: 5000,
 0: 5000}

## sample 1% and 10%

In [None]:
ratio = 1
np.random.seed(42)
index_1 = []
for label, inds in label2inds.items():
    len_tgt = int(round(len(inds) * ratio / 100, 0))
    lst = np.random.choice(inds, len_tgt)
    index_1.extend(list(lst))
np.array(index_1)

In [20]:
pkl.dump(index_1, open('sample/cifar10_1.pkl', 'wb'))

In [25]:
ratio = 10
np.random.seed(42)
index_10 = []
for label, inds in label2inds.items():
    len_tgt = int(round(len(inds) * ratio / 100, 0))
    lst = np.random.choice(inds, len_tgt)
    index_10.extend(list(lst))
np.array(index_10)

array([ 8335, 37579, 30906, ..., 39628,  3063,  9059])

In [26]:
pkl.dump(index_10, open('sample/cifar10_10.pkl', 'wb'))

## 验证是否平衡

In [None]:
from cifar_sample import CIFAR10
new_data = CIFAR10('../../dataset', train=True, download=True,transform=transform, sample_id_path='sample/cifar10_10.pkl')
new_loader = DataLoader(new_data, batch_size=1, shuffle=False, num_workers=8)

In [28]:
nwe_label2inds = defaultdict(lambda: [])
for index, (_, label) in enumerate(tqdm(new_loader)):
    nwe_label2inds[label.item()].append(index)
new_label2len = {kk: len(vv) for kk, vv in nwe_label2inds.items()}
new_label2len

  0%|          | 0/5000 [00:00<?, ?it/s]  0%|          | 1/5000 [00:00<30:26,  2.74it/s]  0%|          | 19/5000 [00:00<01:35, 52.43it/s]  1%|▏         | 74/5000 [00:00<00:25, 194.49it/s]  2%|▏         | 105/5000 [00:00<00:22, 215.72it/s]  3%|▎         | 160/5000 [00:00<00:15, 306.91it/s]  5%|▍         | 228/5000 [00:00<00:11, 409.44it/s]  6%|▌         | 307/5000 [00:00<00:09, 517.60it/s]  7%|▋         | 365/5000 [00:01<00:09, 497.52it/s]  8%|▊         | 419/5000 [00:01<00:09, 483.74it/s] 10%|▉         | 488/5000 [00:01<00:08, 539.90it/s] 11%|█         | 549/5000 [00:01<00:08, 549.30it/s] 13%|█▎        | 630/5000 [00:01<00:07, 621.82it/s] 14%|█▍        | 694/5000 [00:01<00:07, 610.74it/s] 15%|█▌        | 757/5000 [00:01<00:07, 599.68it/s] 17%|█▋        | 828/5000 [00:01<00:06, 629.89it/s] 18%|█▊        | 913/5000 [00:01<00:05, 693.13it/s] 20%|█▉        | 984/5000 [00:02<00:06, 611.89it/s] 21%|██        | 1048/5000 [00:02<00:07, 536.52it/s] 23%|██▎       | 1140/5000 [

{6: 500,
 9: 500,
 4: 500,
 1: 500,
 2: 500,
 7: 500,
 8: 500,
 3: 500,
 5: 500,
 0: 500}

# CIFAR100

## loader

In [29]:
normalize = transforms.Normalize(mean= (0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))
transform = transforms.Compose([transforms.ToTensor(), normalize])

traindataset = datasets.CIFAR100('../../dataset', train=True, download=True,transform=transform)
train_loader = DataLoader(traindataset, batch_size=1, shuffle=False, num_workers=8)

Files already downloaded and verified


In [30]:
label2inds = defaultdict(lambda: [])
for index, (_, label) in enumerate(tqdm(train_loader)):
    label2inds[label.item()].append(index)
label2len = {kk: len(vv) for kk, vv in label2inds.items()}
label2len

  0%|          | 0/50000 [00:00<?, ?it/s]  0%|          | 1/50000 [00:00<4:29:45,  3.09it/s]  0%|          | 7/50000 [00:00<41:35, 20.04it/s]    0%|          | 80/50000 [00:00<03:35, 231.16it/s]  0%|          | 133/50000 [00:00<02:36, 319.65it/s]  0%|          | 175/50000 [00:00<02:27, 336.93it/s]  0%|          | 243/50000 [00:00<01:54, 436.21it/s]  1%|          | 293/50000 [00:00<01:56, 426.85it/s]  1%|          | 340/50000 [00:01<01:59, 414.79it/s]  1%|          | 408/50000 [00:01<01:43, 479.67it/s]  1%|          | 473/50000 [00:01<01:34, 524.81it/s]  1%|          | 528/50000 [00:01<01:33, 528.22it/s]  1%|          | 599/50000 [00:01<01:25, 575.27it/s]  1%|▏         | 661/50000 [00:01<01:25, 578.10it/s]  1%|▏         | 720/50000 [00:01<01:29, 551.50it/s]  2%|▏         | 792/50000 [00:01<01:22, 598.53it/s]  2%|▏         | 853/50000 [00:01<01:24, 584.90it/s]  2%|▏         | 932/50000 [00:02<01:17, 636.93it/s]  2%|▏         | 1011/50000 [00:02<01:12, 674.03it/s]  2%|▏

{19: 500,
 29: 500,
 0: 500,
 11: 500,
 1: 500,
 86: 500,
 90: 500,
 28: 500,
 23: 500,
 31: 500,
 39: 500,
 96: 500,
 82: 500,
 17: 500,
 71: 500,
 8: 500,
 97: 500,
 80: 500,
 74: 500,
 59: 500,
 70: 500,
 87: 500,
 84: 500,
 64: 500,
 52: 500,
 42: 500,
 47: 500,
 65: 500,
 21: 500,
 22: 500,
 81: 500,
 24: 500,
 78: 500,
 45: 500,
 49: 500,
 56: 500,
 76: 500,
 89: 500,
 73: 500,
 14: 500,
 9: 500,
 6: 500,
 20: 500,
 98: 500,
 36: 500,
 55: 500,
 72: 500,
 43: 500,
 51: 500,
 35: 500,
 83: 500,
 33: 500,
 27: 500,
 53: 500,
 92: 500,
 50: 500,
 15: 500,
 18: 500,
 46: 500,
 75: 500,
 38: 500,
 66: 500,
 77: 500,
 69: 500,
 95: 500,
 99: 500,
 93: 500,
 4: 500,
 61: 500,
 94: 500,
 68: 500,
 34: 500,
 32: 500,
 88: 500,
 67: 500,
 30: 500,
 62: 500,
 63: 500,
 40: 500,
 26: 500,
 48: 500,
 79: 500,
 85: 500,
 54: 500,
 44: 500,
 7: 500,
 12: 500,
 2: 500,
 41: 500,
 37: 500,
 13: 500,
 25: 500,
 10: 500,
 57: 500,
 5: 500,
 60: 500,
 91: 500,
 3: 500,
 58: 500,
 16: 500}

## sample 1% 10%

In [None]:
ratio = 1
np.random.seed(42)
index_1 = []
for label, inds in label2inds.items():
    len_tgt = int(round(len(inds) * ratio / 100, 0))
    lst = np.random.choice(inds, len_tgt)
    index_1.extend(list(lst))
np.array(index_1)

In [32]:
pkl.dump(index_1, open('sample/cifar100_1.pkl', 'wb'))

In [35]:
ratio = 10
np.random.seed(42)
index_10 = []
for label, inds in label2inds.items():
    len_tgt = int(round(len(inds) * ratio / 100, 0))
    lst = np.random.choice(inds, len_tgt)
    index_10.extend(list(lst))
np.array(index_10)

array([ 9502, 42182, 33564, ..., 26346, 26710, 44376])

In [36]:
pkl.dump(index_10, open('sample/cifar100_10.pkl', 'wb'))

## 验证是否正确

In [None]:
from cifar_sample import CIFAR100
new_data = CIFAR100('../../dataset', train=True, download=True,transform=transform, sample_id_path='sample/cifar100_10.pkl')
new_loader = DataLoader(new_data, batch_size=1, shuffle=True, num_workers=8)

In [38]:
nwe_label2inds = defaultdict(lambda: [])
for index, (_, label) in enumerate(tqdm(new_loader)):
    nwe_label2inds[label.item()].append(index)
new_label2len = {kk: len(vv) for kk, vv in nwe_label2inds.items()}
new_label2len

  0%|          | 0/5000 [00:00<?, ?it/s]  0%|          | 1/5000 [00:00<28:05,  2.97it/s]  0%|          | 4/5000 [00:00<08:51,  9.40it/s]  1%|▏         | 67/5000 [00:00<00:28, 172.13it/s]  2%|▏         | 117/5000 [00:00<00:19, 255.61it/s]  4%|▎         | 183/5000 [00:00<00:13, 359.05it/s]  5%|▍         | 231/5000 [00:00<00:12, 388.64it/s]  6%|▌         | 291/5000 [00:01<00:10, 447.49it/s]  7%|▋         | 358/5000 [00:01<00:09, 510.91it/s]  8%|▊         | 425/5000 [00:01<00:08, 555.01it/s] 10%|▉         | 484/5000 [00:01<00:08, 560.48it/s] 11%|█         | 543/5000 [00:01<00:08, 504.55it/s] 12%|█▏        | 611/5000 [00:01<00:07, 551.84it/s] 14%|█▎        | 686/5000 [00:01<00:07, 605.29it/s] 15%|█▌        | 771/5000 [00:01<00:06, 670.41it/s] 17%|█▋        | 840/5000 [00:01<00:07, 558.97it/s] 18%|█▊        | 901/5000 [00:02<00:07, 542.49it/s] 19%|█▉        | 964/5000 [00:02<00:07, 561.16it/s] 20%|██        | 1023/5000 [00:02<00:07, 539.87it/s] 22%|██▏       | 1079/5000 [0

{40: 50,
 22: 50,
 56: 50,
 83: 50,
 4: 50,
 78: 50,
 42: 50,
 59: 50,
 55: 50,
 69: 50,
 17: 50,
 57: 50,
 19: 50,
 90: 50,
 15: 50,
 36: 50,
 99: 50,
 12: 50,
 39: 50,
 24: 50,
 20: 50,
 64: 50,
 60: 50,
 43: 50,
 76: 50,
 3: 50,
 27: 50,
 8: 50,
 1: 50,
 81: 50,
 2: 50,
 79: 50,
 92: 50,
 6: 50,
 97: 50,
 87: 50,
 9: 50,
 50: 50,
 33: 50,
 51: 50,
 94: 50,
 18: 50,
 46: 50,
 80: 50,
 67: 50,
 34: 50,
 53: 50,
 25: 50,
 75: 50,
 93: 50,
 58: 50,
 68: 50,
 49: 50,
 26: 50,
 98: 50,
 63: 50,
 73: 50,
 71: 50,
 38: 50,
 37: 50,
 66: 50,
 16: 50,
 48: 50,
 96: 50,
 29: 50,
 35: 50,
 52: 50,
 0: 50,
 5: 50,
 14: 50,
 32: 50,
 82: 50,
 61: 50,
 41: 50,
 21: 50,
 84: 50,
 89: 50,
 11: 50,
 72: 50,
 45: 50,
 88: 50,
 77: 50,
 74: 50,
 7: 50,
 70: 50,
 10: 50,
 28: 50,
 65: 50,
 91: 50,
 95: 50,
 62: 50,
 23: 50,
 31: 50,
 30: 50,
 47: 50,
 86: 50,
 54: 50,
 85: 50,
 44: 50,
 13: 50}