In [1]:
import os
from time import perf_counter, time

import numpy as np
import torch
from theconf import Config as C, ConfigArgumentParser
from torchvision.transforms import transforms
from torchvision.utils import save_image

from FastAutoAugment.train import GroupAugloader
from FastAutoAugment.data import get_dataloaders, GrAugCIFAR10, CutoutDefault
from FastAutoAugment.group_search import assign_group
from FastAutoAugment.archive import arsaug_policy, autoaug_policy, autoaug_paper_cifar10, fa_reduced_cifar10, fa_reduced_svhn, fa_resnet50_rimagenet

In [2]:
# Experiment configuration shared by every loader below.
dataset = "cifar10"
batch = 128
# Directory holding the CIFAR-10 files (the loaders are created with download=False, so
# the data must already be there). Set CIFAR_DATAROOT to point elsewhere instead of
# editing this machine-specific default.
dataroot = os.environ.get("CIFAR_DATAROOT", "/home/server32/data/")
# Group-assignment function handed to get_dataloaders / GroupAugloader / GrAugCIFAR10.
gr_assign = assign_group
# Per-channel normalization statistics used by the hand-built transform for trainloader3.
_CIFAR_MEAN, _CIFAR_STD = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)

In [3]:
# Cutout patch size; the hand-built transform for trainloader3 uses the same value
# (get_dataloaders presumably reads it from the config too — TODO confirm).
C.get()["cutout"] = 16
# Group-wise augmentation policy, passed as gr_policies / to GroupAugloader below.
# Keys 0-4 are presumably group ids produced by gr_assign; each value is a list of
# sub-policies, and each sub-policy is a pair of [op_name, float, float] entries
# (presumably [op, probability, magnitude] as in Fast AutoAugment — confirm).
# Group 2 has 4 sub-policies; the other groups have 5.
# NOTE(review): the trainloader2 cell later overwrites C.get()["aug"] with
# fa_reduced_cifar10(); re-run this cell before re-running the trainloader1 or
# trainloader3 cells, or they will silently be built with the wrong policy.
C.get()["aug"] = \
{0: [[['TranslateX', 0.8933767594939269, 0.09389941707439498],
   ['Equalize', 0.009283253282199254, 0.6216709279401148]],
  [['TranslateX', 0.6042254125178238, 0.3840147802903179],
   ['Invert', 0.5884490988202267, 0.26983244239417226]],
  [['TranslateX', 0.86623214046742, 0.8015893993849521],
   ['Cutout', 0.7629023447855925, 0.18040604440693134]],
  [['Color', 0.3242635884647428, 0.8193831710695015],
   ['AutoContrast', 0.8617900583601097, 0.2745968792879573]],
  [['Cutout', 0.9932170309667162, 0.34739437256963557],
   ['Posterize', 0.22013736848753707, 0.6423520078623738]]],
 1: [[['ShearX', 0.9527555981799529, 0.840939180847609],
   ['Brightness', 0.8812351461652247, 0.08476267857828268]],
  [['Rotate', 0.2329661175045321, 0.37779148571576426],
   ['TranslateY', 0.5042632755940204, 0.45960105920394156]],
  [['Solarize', 0.5026188423517557, 0.6115728014024047],
   ['ShearY', 0.3390709029909095, 0.002670085542884504]],
  [['Cutout', 0.9991400014995132, 0.8739501611807547],
   ['Contrast', 0.7952676665284365, 0.8408743613281001]],
  [['Equalize', 0.011156446032402956, 0.13099359833583812],
   ['Color', 0.43931113219589335, 0.6656400172139114]]],
 2: [[['TranslateX', 0.6513948638994308, 0.6204299285398741],
   ['Posterize', 0.7842262934822425, 0.25273070431283734]],
  [['ShearY', 0.5609290367869623, 0.9452655885485695],
   ['AutoContrast', 0.003634851336694944, 0.5115370485649638]],
  [['Contrast', 0.0536986696863798, 0.997987299785687],
   ['Cutout', 0.9333367903585464, 0.5098100423108692]],
  [['Posterize', 0.5047403532237664, 0.8196327208059964],
   ['Color', 0.46892732660093384, 0.4726497928422181]]],
 3: [[['AutoContrast', 0.7213610218628191, 0.8416505099325328],
   ['Rotate', 0.3682722275983766, 0.4179430092062895]],
  [['Posterize', 0.079407408927398, 0.42541337413692293],
   ['Contrast', 0.46058986564574006, 0.5817370708138572]],
  [['Contrast', 0.7549878468501052, 0.4429642761204333],
   ['Posterize', 0.6097943846630738, 0.41517079787104294]],
  [['Cutout', 0.6737586712335524, 0.29603717863325374],
   ['Posterize', 0.5305135986372213, 0.28136589567182224]],
  [['Sharpness', 0.7334435795732214, 0.7550440511925695],
   ['AutoContrast', 0.2499262684273169, 0.5822370638832948]]],
 4: [[['Brightness', 0.25197963676329305, 0.5993017712343321],
   ['TranslateY', 0.6276663840325515, 0.17444942645475542]],
  [['Sharpness', 0.2501497210502566, 0.9994546724708385],
   ['Contrast', 0.1678019343332606, 0.804693027679855]],
  [['Equalize', 0.53438204404003, 0.22042762708232622],
   ['AutoContrast', 0.6228907795771712, 0.9983253108022203]],
  [['ShearY', 0.730964431113039, 0.500783309172612],
   ['Rotate', 0.28839704924929926, 0.8669441267961313]],
  [['TranslateY', 0.05479562529503221, 0.7442538413000629],
   ['AutoContrast', 0.31896164940398547, 0.9264790385929341]]]}

In [4]:
# Variant 1: get_dataloaders is told about the group assignment, and its training loader
# is then wrapped in GroupAugloader together with the group policy dict.
# NOTE(review): C.get()["aug"] must still hold the group policy from the policy cell —
# the trainloader2 cell overwrites it, so re-run the policy cell before re-running this one.
dataloaders1 = get_dataloaders(
    dataset,
    batch,
    dataroot,
    split=0.0,
    split_idx=0,
    gr_assign=gr_assign,
)
(trainsampler1, trainloader1, validloader1, testloader1) = dataloaders1
trainloader1 = GroupAugloader(
    trainloader1,
    gr_assign,
    C.get()["aug"],
)

[2020-09-28 00:34:57,757] [Fast AutoAugment] [DEBUG] group augmentation provided.


In [5]:
# Variant 3: the group policy is applied inside the dataset (GrAugCIFAR10), so it presumably
# runs in the DataLoader worker processes rather than in the main process.
# NOTE(review): C.get()["aug"] must still hold the group policy dict from the policy cell.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
    # Read the Cutout size from the config instead of repeating the literal 16,
    # so this pipeline cannot drift from C.get()["cutout"].
    CutoutDefault(C.get()["cutout"]),
])
total_trainset = GrAugCIFAR10(
    root=dataroot,
    gr_assign=gr_assign,
    gr_policies=C.get()["aug"],
    train=True,
    download=False,
    transform=transform_train,
)
train_sampler = None
trainloader3 = torch.utils.data.DataLoader(
    total_trainset,
    batch_size=batch,
    # DataLoader forbids shuffle=True together with a custom sampler.
    shuffle=train_sampler is None,
    num_workers=8,
    pin_memory=True,
    sampler=train_sampler,
    drop_last=True,
)

In [6]:
# Variant 2 (baseline): the standard, non-grouped fa_reduced_cifar10 policy from the archive.
# NOTE(review): this REPLACES the group policy dict stored in C.get()["aug"]; re-run the
# policy cell before re-running the trainloader1/trainloader3 cells.
C.get()["aug"] = fa_reduced_cifar10()
dataloaders2 = get_dataloaders(
    dataset,
    batch,
    dataroot,
    split=0.0,
    split_idx=0,
    gr_assign=None,  # no group assignment for the baseline
)
(trainsampler2, trainloader2, validloader2, testloader2) = dataloaders2

[2020-09-28 00:34:59,376] [Fast AutoAugment] [DEBUG] augmentation provided.


In [7]:
# Time one full epoch through the baseline loader (trainloader2).
# perf_counter is monotonic and high-resolution, unlike wall-clock time().
# image2/label2 keep the last batch for the save_image cell.
t1 = perf_counter()
for image2, label2 in trainloader2:
    pass
t2 = perf_counter()
print(t2 - t1)

6.504937171936035


In [8]:
# Time one full epoch through the GroupAugloader-wrapped loader (trainloader1).
# perf_counter is monotonic and high-resolution, unlike wall-clock time().
# image1/label1 keep the last batch for the save_image cell.
t1 = perf_counter()
for image1, label1 in trainloader1:
    pass
t2 = perf_counter()
print(t2 - t1)

29.220066785812378


In [9]:
# Time one full epoch through the in-dataset group-augmentation loader (trainloader3).
# perf_counter is monotonic and high-resolution, unlike wall-clock time().
t1 = perf_counter()
for image3, label3 in trainloader3:
    pass
t2 = perf_counter()
print(t2 - t1)

4.020648717880249


In [10]:
# Save the last batch from the trainloader1 and trainloader2 epoch loops as image grids.
# The batches are presumably normalized with _CIFAR_MEAN/_CIFAR_STD, so values fall outside
# [0, 1] and save_image would clip them; normalize=True min-max rescales each grid into
# [0, 1] first, so the PNGs stay viewable whatever the input range.
save_image(image1, "tmp1.png", normalize=True)
save_image(image2, "tmp2.png", normalize=True)

In [11]:
# Time a pass over trainloader2's Dataset directly: single process, no batching/collation,
# so this isolates per-sample loading + transform cost.
# The loop target is `_` so the image2 batch used by the save_image cell is not
# overwritten with a single sample.
t1 = perf_counter()
for _ in trainloader2.dataset:
    pass
t2 = perf_counter()
print(t2 - t1)

23.654066801071167


In [12]:
# Time a pass over the Dataset underneath GroupAugloader (presumably the loader it wraps,
# exposed as .dataloader). This bypasses GroupAugloader itself, so any augmentation it
# applies is not included in this number.
# The loop target is `_` so the image1 batch used by the save_image cell is not
# overwritten with a single sample.
t1 = perf_counter()
for _ in trainloader1.dataloader.dataset:
    pass
t2 = perf_counter()
print(t2 - t1)

3.5268869400024414


In [13]:
# Time a pass over trainloader3's Dataset (GrAugCIFAR10) directly, single process.
# Previously this loop bound image2/label2 — trainloader2's names — clobbering that batch;
# the loop target is now `_` since the samples are discarded.
t1 = perf_counter()
for _ in trainloader3.dataset:
    pass
t2 = perf_counter()
print(t2 - t1)

23.589488983154297
