In [2]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.utils.data
import torchvision
import timm
import kornia
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np

class NormalizedModel(torch.nn.Module):

    def __init__(self, model, mean, std):
        super(NormalizedModel, self).__init__()
        self.model = model
        self.mean = torch.nn.Parameter(torch.Tensor(mean).view(-1, 1, 1), requires_grad=False)
        self.std = torch.nn.Parameter(torch.Tensor(std).view(-1, 1, 1), requires_grad=False)

    def forward(self, x):
        out = (x - self.mean) / self.std 
        out = self.model(out)
        return out

In [5]:
dataset = torchvision.datasets.ImageNet("/home/SSD/ImageNet/", 
                                        split="val", 
                                        transform=torchvision.transforms.Compose(
                                            [
                                                torchvision.transforms.Resize(256), 
                                                torchvision.transforms.CenterCrop(224), 
                                                torchvision.transforms.ToTensor()
                                            ]
                                        )
                                       )

dataloader = DataLoader(dataset, batch_size=128, shuffle=True, pin_memory=True, num_workers=1)

for x, y in (dataloader):
    bx = x.cuda()
    by = y.cuda()
    break
    
    
model = timm.create_model("vgg16_bn", pretrained=True)
model = NormalizedModel(model, model.default_cfg.get("mean"), model.default_cfg.get("std"))
model.cuda()
model.eval()
pass

Downloading: "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth" to /home/pgavrikov/.cache/torch/hub/checkpoints/vgg16_bn-6c64b313.pth


In [6]:
correct = 0
total = 0

with torch.no_grad():
    
    for x, y in tqdm(dataloader):
        bx = x.cuda()
        by = y.cuda()

        a = torch.empty(len(by))
        is_correct = torch.ones(len(by)).bool().cuda()

        for _ in range(10):
            a.data[is_correct] = torch.tensor(np.random.uniform(0, 1, len(a.data[is_correct]))).float()

            x_aug = kornia.enhance.solarize(bx, a)
            logits = model(x_aug)
            is_correct.data = (logits.argmax(dim=1) == by).detach()
            acc = is_correct.float().mean().item()
        
        correct += is_correct.float().sum().item()
        total += len(by)
        
print(correct / total)

100%|██████████| 391/391 [15:02<00:00,  2.31s/it]

0.11796



