In [19]:
import os

from munch import Munch
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image, ImageFilter

In [48]:
# Evaluation configuration; every later cell reads its settings from `opt`.
opt = Munch()

# Dataset root (expects `testA` / `testB_gt` sub-folders) and the checkpoint to evaluate.
opt.dataroot = '../datasets/UIEB_HCLR/'
opt.model_path = '../checkpoints/type_extractor/uiebhclr_cc_sc_add_400/100_net_E.pth'

# Optional input perturbations applied by DataSet, and the network variant to build.
opt.mosaic = False
opt.mosaic_size = (16, 16)
opt.blur = False
opt.use_vgg = False
opt.adv_norm = 'cc_sc(add)'

opt.batch_size = 32
opt.num_workers = 8

# Keep using the second GPU when there is one, but do not request 'cuda:1'
# on a single-GPU machine, where it would fail as soon as the model is moved.
if torch.cuda.device_count() > 1:
    opt.device = torch.device('cuda:1')
elif torch.cuda.is_available():
    opt.device = torch.device('cuda:0')
else:
    opt.device = torch.device('cpu')

In [53]:
class GaussianBlur:
    """Callable transform that blurs an image through its `filter` method.

    `window_size` is forwarded unchanged as the single positional argument of
    `ImageFilter.GaussianBlur` (Pillow calls that argument `radius`).
    """

    def __init__(self, window_size):
        self.window_size = window_size

    def __call__(self, img):
        """Return `img` filtered with a Gaussian blur of the configured size."""
        blur_filter = ImageFilter.GaussianBlur(self.window_size)
        return img.filter(blur_filter)

In [54]:
class Mosaic:
    """Randomly permute the non-overlapping `a` x `b` patches of a (c, H, W) tensor.

    The input is cut into an (H // a) x (W // b) grid of patches, the patches
    are shuffled with `torch.randperm`, and the grid is stitched back into a
    tensor of the original shape. Pixels inside a patch are left untouched.
    """

    def __init__(self, a=16, b=16):
        self.a = a  # patch height
        self.b = b  # patch width

    def __call__(self, input_tensor):
        channels, height, width = input_tensor.size()
        patch_h, patch_w = self.a, self.b
        rows, cols = height // patch_h, width // patch_w

        # (c, H, W) -> (rows, cols, c, patch_h, patch_w) -> one patch per leading index.
        patches = (
            input_tensor
            .view(channels, rows, patch_h, cols, patch_w)
            .permute(1, 3, 0, 2, 4)
            .contiguous()
            .view(rows * cols, channels, patch_h, patch_w)
        )

        shuffled = patches[torch.randperm(rows * cols)]

        # Undo the patch layout: back to (c, rows, patch_h, cols, patch_w), then (c, H, W).
        return (
            shuffled
            .view(rows, cols, channels, patch_h, patch_w)
            .permute(2, 0, 3, 1, 4)
            .contiguous()
            .view(channels, height, width)
        )

In [56]:
from torchvision.models import vgg

class Vgg19(nn.Module):
    """Single-output classifier: frozen VGG-19 feature stack + linear/sigmoid head.

    Only `head` has trainable parameters; every parameter of `self.vgg` has
    `requires_grad` switched off.
    """

    def __init__(self):
        super().__init__()
        pretrained = vgg.vgg19(weights=vgg.VGG19_Weights.DEFAULT)
        self.vgg = pretrained.features
        # Module-level call; equivalent to disabling grad on each parameter in turn.
        self.vgg.requires_grad_(False)
        self.flatten = nn.Flatten(start_dim=1)
        # The head expects 7*7*512 flattened features, which presumably corresponds
        # to the 224x224 resize DataSet applies when opt.use_vgg is set.
        self.head = nn.Sequential(
            nn.Linear(7*7*512, 1),
            nn.Sigmoid())

    def forward(self, img):
        """Return one sigmoid score per sample for the image batch `img`."""
        features = self.flatten(self.vgg(img))
        return self.head(features)

In [21]:
import sys
sys.path.append('../')
import functools
from models.networks import DownSampleLayer, ResnetBlock, ResnetBlockSC, SCNorm

class Extractor(nn.Module):
    """Convolutional feature extractor with a single sigmoid output per image.

    Layout: stem (reflection pad, 7x7 stride-2 conv, norm, ReLU, max-pool),
    a residual stage, then three rounds of "down-sample (channels x2) +
    residual stage", global average pooling and a linear + sigmoid head.

    `adv_norm` is matched by substring:

    * residual stage
        - ``''``                    -> one ``ResnetBlock``
        - contains ``'sc(add)'``    -> ``ResnetBlock`` followed by ``SCNorm``
        - contains ``'sc'`` (other) -> one ``ResnetBlockSC``
        - anything else             -> no residual layers
    * down-sampling
        - no ``'cc'``                        -> conv(3x3, stride 2) + norm + ReLU
        - contains ``'cc_'`` or ``'cc(bn)'`` -> ``DownSampleLayer(residual_norm='bn')``
        - contains ``'cc(in)'``              -> ``DownSampleLayer(residual_norm='in')``

    Args:
        norm_layer: normalisation layer class, or a ``functools.partial`` of one.
        use_dropout: forwarded to the residual blocks.
        ngf: number of channels produced by the stem; the head sees ``ngf * 8``.
        padding_type: forwarded to the residual blocks.
        adv_norm: variant selector, see above.

    Raises:
        ValueError: if `adv_norm` contains ``'cc'`` without a recognised suffix.
            Previously no down-sampling layer was added in that case while the
            channel count was still doubled, so the network could only fail
            later, inside ``forward``.
    """

    def __init__(self, norm_layer=nn.InstanceNorm2d, use_dropout=False, ngf=64, padding_type='reflect', adv_norm='cc_sc'):
        super().__init__()

        # Unwrap functools.partial so the bias decision looks at the real norm class.
        norm_cls = norm_layer.func if isinstance(norm_layer, functools.partial) else norm_layer
        use_bias = norm_cls == nn.InstanceNorm2d

        block_kwargs = dict(padding_type=padding_type, norm_layer=norm_layer,
                            use_dropout=use_dropout, use_bias=use_bias)

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(3, ngf, kernel_size=7, padding=0, stride=2, bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(inplace=True),
                 nn.MaxPool2d(kernel_size=3, stride=2, padding=1)]

        channels = ngf
        model += self._residual_layers(channels, adv_norm, block_kwargs)

        for _ in range(3):
            model += self._downsample_layers(channels, channels * 2, adv_norm, norm_layer, use_bias)
            channels *= 2
            model += self._residual_layers(channels, adv_norm, block_kwargs)

        model.append(nn.AdaptiveAvgPool2d(output_size=(1, 1)))
        self.backbone = nn.Sequential(*model)
        self.flatten = nn.Flatten()
        # `channels` is ngf * 8 here (512 for the default ngf=64), so the head
        # also fits non-default widths. One sigmoid unit = binary decision.
        self.head = nn.Sequential(
            nn.Linear(channels, 1),
            nn.Sigmoid()
        )

    @staticmethod
    def _residual_layers(channels, adv_norm, block_kwargs):
        """Return the residual layers that `adv_norm` selects for `channels` feature maps."""
        if adv_norm == '':
            return [ResnetBlock(channels, **block_kwargs)]
        if 'sc(add)' in adv_norm:
            return [ResnetBlock(channels, **block_kwargs), SCNorm(channels)]
        if 'sc' in adv_norm:
            return [ResnetBlockSC(channels, **block_kwargs)]
        return []

    @staticmethod
    def _downsample_layers(in_channels, out_channels, adv_norm, norm_layer, use_bias):
        """Return the stride-2 layers mapping `in_channels` to `out_channels` for `adv_norm`."""
        if 'cc' not in adv_norm:
            return [nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=use_bias),
                    norm_layer(out_channels),
                    nn.ReLU(True)]
        if 'cc_' in adv_norm or 'cc(bn)' in adv_norm:
            residual_norm = 'bn'
        elif 'cc(in)' in adv_norm:
            residual_norm = 'in'
        else:
            raise ValueError(
                "unsupported adv_norm %r: 'cc' must appear as 'cc_', 'cc(bn)' or 'cc(in)'" % (adv_norm,))
        return [DownSampleLayer(in_channels, out_channels, norm_layer, use_bias, residual_norm=residual_norm)]

    def forward(self, input):
        """Return one sigmoid score per sample for the image batch `input`."""
        output = self.backbone(input)
        output = self.flatten(output)
        return self.head(output)

In [22]:
class DataSet:
    """Two-class image dataset built from two groups of directories.

    Images found under `dir_A_list` get label 0 and images under `dir_B_list`
    get label 1. Each item is a dict with keys 'img' (transformed image),
    'label' (0 or 1) and 'path' (source file).

    Preprocessing depends on `opt`:
      * opt.use_vgg -> 224x224 input with the VGG mean/std below, otherwise
        256x256 with mean = std = 0.5;
      * opt.blur    -> a GaussianBlur step between resize and ToTensor;
      * opt.mosaic  -> a Mosaic(*opt.mosaic_size) step after normalisation.
    """
    IMG_EXTENSIONS = [
            '.jpg', '.JPG', '.jpeg', '.JPEG',
            '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
            '.tif', '.TIF', '.tiff', '.TIFF',
        ]

    def __init__(self, dir_A_list, dir_B_list, opt):
        self.dir_A_list = dir_A_list
        self.dir_B_list = dir_B_list
        self.opt = opt

        self.paths_A = self._collect(dir_A_list)
        self.paths_B = self._collect(dir_B_list)
        self.num_A = len(self.paths_A)
        self.paths = self.paths_A + self.paths_B
        # Parallel to `paths`, so any index valid for `paths` (negative ones
        # included) yields the matching label.
        self.labels = [0] * self.num_A + [1] * len(self.paths_B)

        if not opt.use_vgg:
            img_size = 256
            mean = (0.5, 0.5, 0.5)
            std = (0.5, 0.5, 0.5)
        else:
            img_size = 224
            mean = (0.485, 0.456, 0.406)
            std = (0.229, 0.224, 0.225)
        transform_list = [
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ]
        if opt.blur:
            # Index 1: blur the resized PIL image, before it becomes a tensor.
            transform_list.insert(1, GaussianBlur(img_size))
        if opt.mosaic:
            transform_list.append(Mosaic(*opt.mosaic_size))
        self.transform = transforms.Compose(transform_list)

    def _collect(self, directories):
        """Return the image paths of every directory, sorted within each directory."""
        paths = []
        for directory in directories:
            paths.extend(sorted(self.make_dataset(directory)))
        return paths

    @classmethod
    def is_image_file(cls, filename):
        """Return True if `filename` ends with a known image extension, ignoring case."""
        suffixes = tuple(extension.lower() for extension in cls.IMG_EXTENSIONS)
        return filename.lower().endswith(suffixes)

    def make_dataset(self, dir):
        """Recursively list the image files below `dir`.

        Raises:
            AssertionError: if `dir` is not an existing directory.
        """
        assert os.path.isdir(dir), '%s is not a valid directory' % dir
        return [
            os.path.join(root, fname)
            for root, _, fnames in sorted(os.walk(dir))
            for fname in fnames
            if self.is_image_file(fname)
        ]

    def __getitem__(self, index):
        path = self.paths[index]
        # Close the file handle deterministically; convert() returns an
        # independent image that stays valid after the file is closed.
        with Image.open(path) as raw:
            img = raw.convert('RGB')

        return {'img': self.transform(img), 'label': self.labels[index], 'path': path}

    def __len__(self):
        return len(self.paths)

In [49]:
# Build the network variant selected in the config cell.
if opt.use_vgg:
    model = Vgg19()
else:
    model = Extractor(adv_norm=opt.adv_norm)

# map_location='cpu' lets the checkpoint load even when the device it was
# saved from is not present on this machine; the model is moved to the target
# device afterwards.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
model.load_state_dict(torch.load(opt.model_path, map_location='cpu'))
model = model.to(opt.device)
model = model.eval()

In [24]:
# Images under testA get label 0, images under testB_gt get label 1.
dataset = DataSet(
    [os.path.join(opt.dataroot, 'testA')],
    [os.path.join(opt.dataroot, 'testB_gt')],
    opt=opt,
)
# Evaluation gains nothing from shuffling; a fixed order keeps the
# misclassification listing identical between runs.
dataloader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=False, num_workers=opt.num_workers)
print(len(dataset))

180


In [None]:
# Run the model over the test loader, list every misclassified sample as
# (path, true label, predicted flag), then report the overall accuracy.
correct_predictions = 0
total_samples = 0
with torch.no_grad():
    for data in dataloader:
        imgs, labels = data['img'].to(opt.device), data['label'].to(opt.device)
        outs = model(imgs)
        # Threshold the sigmoid scores at 0.5. reshape(-1) always keeps the
        # batch axis; a bare squeeze() turns a final batch of one sample into
        # a 0-d tensor whose tolist() is a plain bool, which breaks the zip below.
        predicted = outs.ge(0.5).reshape(-1)

        for item in zip(data['path'], data['label'].tolist(), predicted.tolist()):
            if item[1] != item[2]:
                print(item)

        total_samples += labels.size(0)
        correct_predictions += (predicted == labels).sum().item()

if total_samples:
    accuracy = correct_predictions / total_samples
    print(f'Test Accuracy: {accuracy * 100:.2f}%')
else:
    # An empty loader would otherwise end in a ZeroDivisionError.
    accuracy = float('nan')
    print('Test Accuracy: n/a (the dataloader yielded no samples)')