config.py

In [7]:
import argparse

def str2bool(v):
    """Convert a command-line style truthy/falsy value to a ``bool``.

    Real booleans are returned unchanged. Strings are matched
    case-insensitively against 'yes'/'true'/'t'/'y'/'1' (True) and
    'no'/'false'/'f'/'n'/'0' (False).

    Raises:
        argparse.ArgumentTypeError: if ``v`` is none of the accepted values.
    """
    # Handle real booleans first: calling .lower() on them raised AttributeError.
    if isinstance(v, bool):
        return v
    if isinstance(v, str):
        lowered = v.lower()
        if lowered in ('yes', 'true', 't', 'y', '1'):
            return True
        if lowered in ('no', 'false', 'f', 'n', '0'):
            return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


# Command-line configuration for the inpainting model.
# `args` is parsed from an empty argument list at the bottom of this block, so
# only the defaults declared here are used unless parse_args is called again.
parser = argparse.ArgumentParser(description='')

#Image setting
image_size = 128
mask_size = image_size // 2

parser.add_argument('--GPU', type=int, default=0)
parser.add_argument('--CROP_SIZE', type=int, default=image_size+128) # 178
parser.add_argument('--IMG_SIZE', type=int, default=image_size)

# Optimiser learning rates and loss weights
parser.add_argument('--G_LR', type=float, default=1e-4)
parser.add_argument('--D_LR', type=float, default=1e-4)
parser.add_argument('--GLOBAL_WGAN_LOSS_ALPHA', type=float, default=1.)
parser.add_argument('--WGAN_GP_LAMBDA', type=float, default=10)
parser.add_argument('--GAN_LOSS_ALPHA', type=float, default=0.001)
parser.add_argument('--COARSE_L1_ALPHA', type=float, default=2)
parser.add_argument('--L1_LOSS_ALPHA', type=float, default=2)
parser.add_argument('--AE_LOSS_ALPHA', type=float, default=2)
parser.add_argument('--D_TRAIN_REPEAT', type=int, default=5)

# Training settings
parser.add_argument('--DATASET', type=str, default='CelebA', choices=['CelebA'])
parser.add_argument('--NUM_EPOCHS', type=int, default=100)
parser.add_argument('--NUM_EPOCHS_DECAY', type=int, default=10)
parser.add_argument('--NUM_ITERS', type=int, default=200000)
parser.add_argument('--NUM_ITERS_DECAY', type=int, default=100000)
parser.add_argument('--BATCH_SIZE', type=int, default=32)
parser.add_argument('--BETA1', type=float, default=0.5)
parser.add_argument('--BETA2', type=float, default=0.9)
parser.add_argument('--PRETRAINED_MODEL', type=str, default=None)
parser.add_argument('--SPATIAL_DISCOUNTING_GAMMA', type=float, default=0.9)

# Test settings
parser.add_argument('--TEST_MODEL', type=str, default='20_1000')

# Misc
parser.add_argument('--MODE', type=str, default='train', choices=['train', 'test'])
parser.add_argument('--USE_TENSORBOARD', type=str2bool, default=False)

# Path
parser.add_argument('--IMAGE_PATH', type=str, default='./data/CelebA/images')
parser.add_argument('--METADATA_PATH', type=str, default='./data/list_attr_celeba.txt')
parser.add_argument('--LOG_PATH', type=str, default='logs')
parser.add_argument('--MODEL_SAVE_PATH', type=str, default='models')
parser.add_argument('--SAMPLE_PATH', type=str, default='samples')

# Step size
parser.add_argument('--PRINT_EVERY', type=int, default=400)
parser.add_argument('--SAMPLE_STEP', type=int, default=400)
parser.add_argument('--MODEL_SAVE_STEP', type=int, default=400)

# etc
# `type=list` would split a command-line string into single characters
# (list('128') == ['1', '2', '8']); three ints give the intended [H, W, C].
# Usage: --IMG_SHAPE 128 128 3
parser.add_argument('--IMG_SHAPE', type=int, nargs=3, default=[image_size, image_size, 3])
parser.add_argument('--MASK_HEIGHT', type=int, default=mask_size)
parser.add_argument('--MASK_WIDTH', type=int, default=mask_size)
parser.add_argument('--VERTICAL_MARGIN', type=int, default=0)
parser.add_argument('--HORIZONTAL_MARGIN', type=int, default=0)
parser.add_argument('--MAX_DELTA_HEIGHT', type=int, default=image_size // 8)
parser.add_argument('--MAX_DELTA_WIDTH', type=int, default=image_size // 8)
# Parse an explicit empty list so the defaults above are used.
args = parser.parse_args(args=[])

data_loader.py

In [8]:
import torch
import os
import random
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from PIL import Image


class CelebDataset(Dataset):
    """Image dataset driven by a metadata listing file, with a seeded split.

    The metadata file is read once: its first line is parsed as an integer
    count, the first two lines are skipped as header, and the first
    whitespace-separated token of every remaining line is used as an image
    filename relative to ``image_path``. After a shuffle seeded with 1234,
    the first 19999 entries form the test split and the rest the train split.

    Args:
        image_path: directory containing the image files.
        metadata_path: path of the metadata listing file.
        transform: callable applied to each opened PIL image.
        mode: 'train' or 'test'; selects which split is served.
        crop_size: stored on the instance; not used by this class itself.
    """

    def __init__(self, image_path, metadata_path, transform, mode, crop_size):
        self.image_path = image_path
        self.transform = transform
        self.mode = mode
        # Read everything up front and close the handle immediately
        # (the bare open(...).readlines() left the file open).
        with open(metadata_path, 'r') as metadata_file:
            self.lines = metadata_file.readlines()
        self.num_data = int(self.lines[0])
        self.crop_size = crop_size

        print ('Start preprocessing dataset..!')
        # Fixed seed so the train/test split is reproducible across runs.
        random.seed(1234)
        self.preprocess()
        print ('Finished preprocessing dataset..!')

        if self.mode == 'train':
            self.num_data = len(self.train_filenames)
        elif self.mode == 'test':
            self.num_data = len(self.test_filenames)

    def preprocess(self):
        """Shuffle the listed filenames and split them into test/train lists."""
        self.train_filenames = []
        self.test_filenames = []

        lines = self.lines[2:]
        random.shuffle(lines)   # random shuffling
        for i, line in enumerate(lines):
            filename = line.split()[0]

            if (i+1) < 20000:
                self.test_filenames.append(filename)
            else:
                self.train_filenames.append(filename)

    def __getitem__(self, index):
        """Open the ``index``-th image of the active split and transform it.

        Raises:
            ValueError: if ``self.mode`` is neither 'train' nor 'test'.
        """
        if self.mode == 'train':
            filename = self.train_filenames[index]
        elif self.mode == 'test':
            filename = self.test_filenames[index]
        else:
            # Previously this fell through to an unbound local variable.
            raise ValueError("Unsupported mode: {!r}".format(self.mode))
        image = Image.open(os.path.join(self.image_path, filename))
        return self.transform(image)

    def __len__(self):
        return self.num_data


def get_loader(image_path, metadata_path, crop_size, image_size, batch_size, dataset='CelebA', mode='train'):
    """Build and return data loader.

    Args:
        image_path: directory containing the images.
        metadata_path: path of the metadata listing file.
        crop_size: side length passed to ``transforms.CenterCrop``.
        image_size: target size passed to ``transforms.Resize``.
        batch_size: batch size of the returned loader.
        dataset: dataset name; only 'CelebA' is supported.
        mode: 'train' enables random horizontal flips and shuffling; any
            other value builds the deterministic evaluation pipeline.

    Returns:
        A ``DataLoader`` over the selected dataset.

    Raises:
        ValueError: if ``dataset`` is not a supported name.
    """
    # Both pipelines now use Resize (the test branch used transforms.Scale).
    # Image.LANCZOS is the filter that Image.ANTIALIAS was an alias for.
    steps = [
        transforms.CenterCrop(crop_size),
        transforms.Resize(image_size, interpolation=Image.LANCZOS),
    ]
    if mode == 'train':
        steps.append(transforms.RandomHorizontalFlip())
    # No Normalize step: tensors stay in the range produced by ToTensor.
    steps.append(transforms.ToTensor())
    transform = transforms.Compose(steps)

    if dataset == 'CelebA':
        dataset = CelebDataset(image_path, metadata_path, transform, mode, crop_size)
    else:
        # Previously an unknown name was handed to DataLoader as a plain string.
        raise ValueError("Unsupported dataset: {!r}".format(dataset))

    data_loader = DataLoader(dataset=dataset,
                             batch_size=batch_size,
                             shuffle=(mode == 'train'))
    return data_loader

model.py

In [9]:
import torch
import torch.nn as nn
from model_module import *
from torch.autograd import Variable
from util import *

class Generator(nn.Module):
    """Two-stage generator: a coarse network followed by a refinement network.

    Both stages receive a 5-channel input: the image, a channel of ones and
    the mask channel.
    """

    def __init__(self, first_dim=32):
        super(Generator, self).__init__()
        self.stage_1 = CoarseNet(5, first_dim)
        self.stage_2 = RefinementNet(5, first_dim)

    def forward(self, masked_img, mask): # mask : 1 x 1 x H x W
        """Return ``(stage1_output, stage2_output, offset_flow)``."""
        # Broadcast the single mask over the batch (border, maybe).
        batch = masked_img.size(0)
        height = masked_img.size(2)
        width = masked_img.size(3)
        mask = mask.expand(batch, 1, height, width)
        ones = to_var(torch.ones(mask.size()))
        mask_channel = ones * mask

        # stage 1: coarse prediction (it also returns the down-sampled mask)
        coarse_input = torch.cat([masked_img, ones, mask_channel], dim=1)
        stage1_output, resized_mask = self.stage_1(coarse_input, mask[0].unsqueeze(0))

        # stage 2: keep known pixels, fill the hole with the coarse prediction
        blended = stage1_output * mask + masked_img.clone() * (1. - mask)
        refine_input = torch.cat([blended, ones, mask_channel], dim=1)
        stage2_output, offset_flow = self.stage_2(refine_input, resized_mask[0].unsqueeze(0))

        return stage1_output, stage2_output, offset_flow


class CoarseNet(nn.Module):
    '''
    Coarse stage: down-sampling convs, dilated convs, then up-sampling convs.

    Shape notes from the original author:
    # input: B x 5 x W x H
    # after down: B x 128(32*4) x W/4 x H/4
    # after atrous: same with the output size of the down module
    # after up : same with the input size
    '''
    def __init__(self, in_ch, out_ch):
        super(CoarseNet, self).__init__()
        bottleneck_ch = out_ch * 4
        self.down = Down_Module(in_ch, out_ch)
        self.atrous = Dilation_Module(bottleneck_ch, bottleneck_ch)
        self.up = Up_Module(bottleneck_ch, 3)

    def forward(self, x, mask):
        """Return the coarse prediction and the mask down-sampled by 0.25."""
        resized_mask = down_sample(mask, scale_factor=0.25, mode='nearest')
        features = self.atrous(self.down(x))
        return self.up(features), resized_mask


class RefinementNet(nn.Module):
    '''
    Refinement stage with two parallel encoders (a dilated-conv branch and a
    contextual-attention branch) whose features are concatenated and decoded.

    Shape notes from the original author:
    # input: B x 5 x W x H
    # after down: B x 128(32*4) x W/4 x H/4
    # after atrous: same with the output size of the down module
    # after up : same with the input size
    '''
    def __init__(self, in_ch, out_ch):
        super(RefinementNet, self).__init__()
        branch_ch = out_ch * 4
        self.down_conv_branch = Down_Module(in_ch, out_ch)
        self.down_attn_branch = Down_Module(in_ch, out_ch, activation=nn.ReLU())
        self.atrous = Dilation_Module(branch_ch, branch_ch)
        self.CAttn = Contextual_Attention_Module(branch_ch, branch_ch)
        # The decoder consumes both branches concatenated along channels.
        self.up = Up_Module(branch_ch * 2, 3, isRefine=True)

    def forward(self, x, resized_mask):
        """Return the refined prediction and the attention offset flow."""
        # conv branch
        dilated = self.atrous(self.down_conv_branch(x))

        # attention branch: the same features act as foreground and background
        attn_features = self.down_attn_branch(x)
        attended, offset_flow = self.CAttn(attn_features, attn_features, mask=resized_mask)

        # concat two branches along the channel axis and decode
        merged = torch.cat([dilated, attended], dim=1)
        return self.up(merged), offset_flow


class Discriminator(nn.Module):
    """Pair of critics: one for the whole image, one for the local patch."""

    def __init__(self, first_dim=64):
        super(Discriminator, self).__init__()
        self.global_discriminator = Flatten_Module(3, first_dim, False)
        self.local_discriminator = Flatten_Module(3, first_dim, True)

    def forward(self, global_x, local_x):
        """Return the flattened (global, local) critic features."""
        global_features = self.global_discriminator(global_x)
        local_features = self.local_discriminator(local_x)
        # original note: B x 256*(256 or 512)
        return global_features, local_features


model_module.py

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from util import *


class Conv(nn.Module):
    """A ``nn.Conv2d`` optionally followed by an activation module.

    Args:
        in_ch, out_ch: input / output channel counts.
        K, S, P, D: kernel size, stride, padding and dilation of the conv.
        activation: module appended after the conv; any falsy value
            (e.g. ``0`` or ``None``) disables it.
    """

    def __init__(self, in_ch, out_ch, K=3, S=1, P=1, D=1, activation=nn.ELU()):
        super(Conv, self).__init__()
        stages = [nn.Conv2d(in_ch, out_ch, kernel_size=K, stride=S, padding=P, dilation=D)]
        if activation:
            stages.append(activation)
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)


# conv 1~6
class Down_Module(nn.Module):
    """Encoder: one 5x5 conv, two (stride-2 conv + conv) stages, one final conv.

    The channel count doubles at each stride-2 stage, so the output has
    ``out_ch * 4`` channels. ``activation`` is applied by the final conv only;
    the earlier layers use ``Conv``'s default activation.
    """

    def __init__(self, in_ch, out_ch, activation=nn.ELU()):
        super(Down_Module, self).__init__()
        # NOTE(review): K=5 with Conv's default P=1 trims 2 pixels per spatial
        # dimension in this first layer - confirm that is intended.
        blocks = [Conv(in_ch, out_ch, K=5)]

        width = out_ch
        for _ in range(2):
            doubled = width * 2
            blocks += [Conv(width, doubled, K=3, S=2), Conv(doubled, doubled)]
            width = doubled

        blocks.append(Conv(width, width, activation=activation))
        self.out = nn.Sequential(*blocks)

    def forward(self, x):
        return self.out(x)


# conv 7~10
class Dilation_Module(nn.Module):
    """Four stacked convs with dilation (and matching padding) 2, 4, 8, 16."""

    def __init__(self, in_ch, out_ch):
        super(Dilation_Module, self).__init__()
        # Padding equal to the dilation keeps the spatial size for 3x3 kernels.
        rates = [2 ** (step + 1) for step in range(4)]
        self.out = nn.Sequential(*[Conv(in_ch, out_ch, D=rate, P=rate) for rate in rates])

    def forward(self, x):
        return self.out(x)


# conv 11~17
class Up_Module(nn.Module):
    """Decoder: convs interleaved with two nearest-neighbour 2x up-samplings.

    With ``isRefine`` the first conv halves the channel count; otherwise it
    keeps it. The last conv has no activation and the output is clamped to
    [-1, 1].
    """

    def __init__(self, in_ch, out_ch, isRefine=False):
        super(Up_Module, self).__init__()
        first_out = in_ch // 2 if isRefine else in_ch
        stack = [Conv(in_ch, first_out)]
        width = first_out

        # conv 12~15: each round up-samples by 2 and halves the channels
        for _ in range(2):
            halved = width // 2
            stack.extend([
                Conv(width, width),
                nn.Upsample(scale_factor=2, mode='nearest'),
                Conv(width, halved),
            ])
            width = halved

        stack.append(Conv(width, width // 2))
        stack.append(Conv(width // 2, out_ch, activation=0))

        self.out = nn.Sequential(*stack)

    def forward(self, x):
        return torch.clamp(self.out(x), min=-1., max=1.)


class Flatten_Module(nn.Module):
    """Critic trunk: four stride-2 5x5 convs with LeakyReLU, then flattened.

    The channel count doubles in the second and third convs; the fourth
    doubles it again only when ``isLocal`` is true.
    """

    def __init__(self, in_ch, out_ch, isLocal=True):
        super(Flatten_Module, self).__init__()

        def strided(c_in, c_out):
            # Every layer gets its own LeakyReLU instance.
            return Conv(c_in, c_out, K=5, S=2, P=2, activation=nn.LeakyReLU())

        stages = [strided(in_ch, out_ch)]
        width = out_ch

        for _ in range(2):
            stages.append(strided(width, width * 2))
            width *= 2

        stages.append(strided(width, width * 2 if isLocal else width))

        self.out = nn.Sequential(*stages)

    def forward(self, x):
        features = self.out(x)
        # original note: 2B x 256*(256 or 512); front 256:16*16
        return features.view(features.size(0), -1)


# pmconv 9~10
class Contextual_Attention_Module(nn.Module):
    def __init__(self, in_ch, out_ch, rate=2, stride=1):
        """Set up padding, up-sampling and the two output convs.

        Args:
            in_ch, out_ch: channel counts of the two trailing ``Conv`` layers.
            rate: down-scaling factor used for matching in ``forward``.
            stride: accepted but not used by this constructor.
        """
        super(Contextual_Attention_Module, self).__init__()
        self.rate = rate
        self.padding = nn.ZeroPad2d(1)
        self.up_sample = nn.Upsample(scale_factor=self.rate, mode='nearest')
        self.out = nn.Sequential(*[Conv(in_ch, out_ch) for _ in range(2)])

    def forward(self, f, b, mask=None, ksize=3, stride=1, 
                fuse_k=3, softmax_scale=10., training=True, fuse=True):

        """ Contextual attention layer implementation.

        Contextual attention is first introduced in publication:
            Generative Image Inpainting with Contextual Attention, Yu et al.

        Args:
            f: Input feature to match (foreground).
            b: Input feature for match (background).
            mask: Input mask for b, indicating patches not available.
            ksize: Kernel size for contextual attention.
            stride: Accepted for interface compatibility; not used here.
            fuse_k: Size of the identity kernel used to fuse scores.
            softmax_scale: Scaled softmax for attention.
            training: Accepted for interface compatibility; not used here.
            fuse: Whether to fuse the matching scores before the softmax.

        The matching dilation comes from ``self.rate`` (set in ``__init__``).

        Returns:
            tuple: ``(self.out(y), flow)`` where ``y`` is the feature map
            rebuilt from background patches and ``flow`` is the offset
            visualisation produced by ``flow_to_image``.

        """

        # get shapes
        raw_fs = f.size() # B x 128 x 64 x 64
        raw_int_fs = list(f.size())
        raw_int_bs = list(b.size())

        # extract patches from background with stride and rate
        kernel = 2*self.rate
        raw_w = self.extract_patches(b, kernel=kernel, stride=self.rate)
        raw_w = raw_w.contiguous().view(raw_int_bs[0], -1, raw_int_bs[1], kernel, kernel) # B*HW*C*K*K (B, 32*32, 128, 4, 4)

        # downscaling foreground option: downscaling both foreground and
        # background for matching and use original background for reconstruction.
        f = down_sample(f, scale_factor=1/self.rate, mode='nearest')
        b = down_sample(b, scale_factor=1/self.rate, mode='nearest')
        fs = f.size() # B x 128 x 32 x 32
        int_fs = list(f.size())
        f_groups = torch.split(f, 1, dim=0) # Split tensors by batch dimension; tuple is returned

        # from b(B*H*W*C) to w(b*k*k*c*h*w)
        bs = b.size() # B x 128 x 32 x 32
        int_bs = list(b.size())
        w = self.extract_patches(b)

        w = w.contiguous().view(int_fs[0], -1, int_fs[1], ksize, ksize) # B*HW*C*K*K (B, 32*32, 128, 3, 3)

        # process mask
        if mask is not None:
            mask = down_sample(mask, scale_factor=1./self.rate, mode='nearest')
        else:
            mask = torch.zeros([1, 1, bs[2], bs[3]])

        m = self.extract_patches(mask)

        m = m.contiguous().view(1, 1, -1, ksize, ksize)  # B*C*HW*K*K
        m = m[0] # (1, 32*32, 3, 3)
        m = reduce_mean(m) # smoothing, maybe
        # 1 where the patch mean is exactly zero, 0 otherwise
        mm = m.eq(0.).float() # (1, 32*32, 1, 1)

        w_groups = torch.split(w, 1, dim=0) # Split tensors by batch dimension; tuple is returned
        raw_w_groups = torch.split(raw_w, 1, dim=0) # Split tensors by batch dimension; tuple is returned
        y = []
        offsets = []
        k = fuse_k
        scale = softmax_scale
        fuse_weight = Variable(torch.eye(k).view(1, 1, k, k)).cuda() # 1 x 1 x K x K
        # Lower bound for the patch norm; loop-invariant, so built once.
        escape_NaN = Variable(torch.FloatTensor([1e-4])).cuda()

        for xi, wi, raw_wi in zip(f_groups, w_groups, raw_w_groups):
            '''
            O => output channel as a conv filter
            I => input channel as a conv filter
            xi : separated tensor along batch dimension of front; (B=1, C=128, H=32, W=32)
            wi : separated patch tensor along batch dimension of back; (B=1, O=32*32, I=128, KH=3, KW=3)
            raw_wi : separated tensor along batch dimension of back; (B=1, I=32*32, O=128, KH=4, KW=4)
            '''
            # conv for compare
            wi = wi[0]
            wi_normed = wi / torch.max(l2_norm(wi), escape_NaN)
            yi = F.conv2d(xi, wi_normed, stride=1, padding=1) # yi => (B=1, C=32*32, H=32, W=32)

            # conv implementation for fuse scores to encourage large patches
            if fuse:
                yi = yi.view(1, 1, fs[2]*fs[3], bs[2]*bs[3]) # make all of depth to spatial resolution, (B=1, I=1, H=32*32, W=32*32)
                yi = F.conv2d(yi, fuse_weight, stride=1, padding=1) # (B=1, C=1, H=32*32, W=32*32)

                yi = yi.contiguous().view(1, fs[2], fs[3], bs[2], bs[3]) # (B=1, 32, 32, 32, 32)
                yi = yi.permute(0, 2, 1, 4, 3)
                yi = yi.contiguous().view(1, 1, fs[2]*fs[3], bs[2]*bs[3])

                yi = F.conv2d(yi, fuse_weight, stride=1, padding=1)
                yi = yi.contiguous().view(1, fs[3], fs[2], bs[3], bs[2])
                yi = yi.permute(0, 2, 1, 4, 3)

            yi = yi.contiguous().view(1, bs[2]*bs[3], fs[2], fs[3]) # (B=1, C=32*32, H=32, W=32)

            # softmax to match
            yi = yi * mm  # mm => (1, 32*32, 1, 1)
            yi = F.softmax(yi*scale, dim=1)
            yi = yi * mm  # mask

            _, offset = torch.max(yi, dim=1) # argmax; index
            # Split the flat index into (row, column). The column was computed
            # as `div - div`, which is never the remainder of the division.
            division = torch.div(offset, fs[3]).long()
            remainder = offset - division * fs[3]
            offset = torch.stack([division, remainder], dim=-1)

            # deconv for patch pasting
            # 3.1 paste center
            wi_center = raw_wi[0]
            yi = F.conv_transpose2d(yi, wi_center, stride=self.rate, padding=1) / 4. # (B=1, C=128, H=64, W=64)
            y.append(yi)
            offsets.append(offset)

        y = torch.cat(y, dim=0) # back to the mini-batch
        # Keep the reshaped tensor (the result used to be discarded).
        y = y.contiguous().view(raw_int_fs)
        offsets = torch.cat(offsets, dim=0)
        offsets = offsets.view([int_bs[0]] + [2] + int_bs[2:])

        # case1: visualize optical flow: minus current position
        h_add = Variable(torch.arange(0,float(bs[2]))).cuda().view([1, 1, bs[2], 1])
        h_add = h_add.expand(bs[0], 1, bs[2], bs[3])
        w_add = Variable(torch.arange(0,float(bs[3]))).cuda().view([1, 1, 1, bs[3]])
        w_add = w_add.expand(bs[0], 1, bs[2], bs[3])

        offsets = offsets - torch.cat([h_add, w_add], dim=1).long()

        # to flow image
        flow = torch.from_numpy(flow_to_image(offsets.permute(0,2,3,1).cpu().data.numpy()))
        flow = flow.permute(0,3,1,2)

        # # case2: visualize which pixels are attended
        # flow = torch.from_numpy(highlight_flow((offsets * mask.int()).numpy()))
        if self.rate != 1:
            flow = self.up_sample(flow)
        return self.out(y), flow

    # padding1(16 x 128 x 64 x 64) => (16 x 128 x 64 x 64 x 3 x 3)
    def extract_patches(self, x, kernel=3, stride=1):
        x = self.padding(x)
        all_patches = x.unfold(2, kernel, stride).unfold(3, kernel, stride)
        return all_patches


run.py

In [11]:
import torch as nn
import torch.optim as optim
from torch import cuda
from config import *
from model import *
from data_loader import *
from util import *
import time
import datetime
import os
from torchvision.utils import save_image

class Run(object):
    def __init__(self, args):
        """Copy settings from ``args``, create output dirs and build networks.

        ``args`` is the parsed configuration namespace. When
        ``PRETRAINED_MODEL`` is set, the matching checkpoints are loaded last.
        """
        # Data loader
        if args.DATASET == 'CelebA':
            self.data_loader = get_loader(args.IMAGE_PATH, args.METADATA_PATH,
                                          args.CROP_SIZE, args.IMG_SIZE,
                                          args.BATCH_SIZE, args.DATASET, args.MODE)

        # (attribute on self, option on args) pairs that are copied verbatim.
        copied_options = (
            # Model hyper-parameters
            ('image_shape', 'IMG_SHAPE'),
            # Loss weights and optimiser settings
            ('stage1_lambda_l1', 'COARSE_L1_ALPHA'),
            ('global_wgan_loss_alpha', 'GLOBAL_WGAN_LOSS_ALPHA'),
            ('wgan_gp_lambda', 'WGAN_GP_LAMBDA'),
            ('gan_loss_alpha', 'GAN_LOSS_ALPHA'),
            ('l1_loss_alpha', 'L1_LOSS_ALPHA'),
            ('ae_loss_alpha', 'AE_LOSS_ALPHA'),
            ('g_lr', 'G_LR'),
            ('d_lr', 'D_LR'),
            ('beta1', 'BETA1'),
            ('beta2', 'BETA2'),
            # Training settings
            ('dataset', 'DATASET'),
            ('num_epochs', 'NUM_EPOCHS'),
            ('num_epochs_decay', 'NUM_EPOCHS_DECAY'),
            ('num_iters', 'NUM_ITERS'),
            ('num_iters_decay', 'NUM_ITERS_DECAY'),
            ('batch_size', 'BATCH_SIZE'),
            ('use_tensorboard', 'USE_TENSORBOARD'),
            ('pretrained_model', 'PRETRAINED_MODEL'),
            ('d_train_repeat', 'D_TRAIN_REPEAT'),
            # Test settings
            ('test_model', 'TEST_MODEL'),
            # Path
            ('sample_path', 'SAMPLE_PATH'),
            ('model_save_path', 'MODEL_SAVE_PATH'),
            # Step size
            ('print_every', 'PRINT_EVERY'),
            ('sample_step', 'SAMPLE_STEP'),
            ('model_save_step', 'MODEL_SAVE_STEP'),
        )
        for attr_name, option_name in copied_options:
            setattr(self, attr_name, getattr(args, option_name))

        # etc
        self.make_dir()
        self.init_network(args)
        self.loss = {}

        if self.pretrained_model:
            self.load_pretrained_model()

    def make_dir(self):
        if not os.path.exists(self.model_save_path):
            os.makedirs(self.model_save_path)
        if not os.path.exists(self.sample_path):
            os.makedirs(self.sample_path)

    def init_network(self, args):
        """Build the generator, discriminator, optimizers, losses and helpers.

        Reads ``self.g_lr``, ``self.d_lr``, ``self.beta1`` and ``self.beta2``,
        so it must run after those attributes are set. When CUDA is available
        the networks and loss modules are moved to the GPU.
        """

        # Models
        self.G = Generator()
        self.D = Discriminator()

        # Optimizers
        self.g_optimizer = optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
        self.d_optimizer = optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])

        # Loss
        self.L1 = Discounted_L1(args)
        # Spelled as torch.nn explicitly so this does not depend on which module
        # the name `nn` is bound to (this file contains `import torch as nn`,
        # and torch itself has no L1Loss).
        self.torch_L1 = torch.nn.L1Loss()

        # etc.
        self.util = Util(args)

        # Print networks
        # self.util.print_network(self.G, 'G')
        # self.util.print_network(self.D, 'D')

        if torch.cuda.is_available():
            self.G = self.G.cuda()
            self.D = self.D.cuda()
            self.L1 = self.L1.cuda()
            # Move the existing loss module instead of building a second one.
            self.torch_L1 = self.torch_L1.cuda()

    def load_pretrained_model(self):
        """Load generator and discriminator weights from ``model_save_path``.

        Checkpoint names are built from ``self.pretrained_model`` and
        ``self.l1_loss_alpha``.
        """
        checkpoints = (
            (self.G, 'G_{}_L1_{}.pth'),
            (self.D, 'D_{}_L1_{}.pth'),
        )
        for network, pattern in checkpoints:
            filename = pattern.format(self.pretrained_model, self.l1_loss_alpha)
            state = torch.load(os.path.join(self.model_save_path, filename))
            network.load_state_dict(state)

        print('loaded trained models (step: {})..!'.format(self.pretrained_model))

    def train(self):   
        """Run the adversarial training loop over ``self.data_loader``.

        For every batch a random box mask is drawn, the generator fills the
        masked image, and reconstruction, WGAN and gradient-penalty losses are
        computed. The discriminator is stepped on every batch; the generator
        is stepped only on every ``self.d_train_repeat``-th batch. Progress
        printing, checkpointing and sample saving are triggered by the batch
        index modulo ``print_every`` / ``model_save_step`` / ``sample_step``.
        """

        # The number of iterations per epoch
        iters_per_epoch = len(self.data_loader)

        # lr cache for decaying
        # NOTE(review): g_lr / d_lr are not read again in this method, so no
        # learning-rate decay is applied here.
        g_lr = self.g_lr
        d_lr = self.d_lr

        # Resume the epoch counter from the checkpoint tag: the integer before
        # the first '_' in ``self.pretrained_model``.
        if self.pretrained_model:
            start = int(self.pretrained_model.split('_')[0])
        else:
            start = 0

        start_time = time.time()
        self.G.train()
        self.D.train()
        for epoch in range(start, self.num_epochs):
            for batch, real_image in enumerate(self.data_loader): # real_image : B x 3 x H x W
                
                batch_size = real_image.size(0)
                real_image = 2.*real_image - 1. # [-1,1]
                
                # one bbox for each batch, ( top, left, maxH, maxW )
                # W and H will be reduced at the function bbox2mask
                bbox = self.util.random_bbox()

                # binary_mask marks the hole; inverse_mask the known pixels
                binary_mask = self.util.bbox2mask(bbox)
                inverse_mask = 1.- binary_mask
                masked_image = real_image.clone()*inverse_mask

                binary_mask = to_var(binary_mask)
                inverse_mask = to_var(inverse_mask)
                masked_image = to_var(masked_image)
                real_image = to_var(real_image)

                stage_1, stage_2, offset_flow = self.G(masked_image, binary_mask)
                
                fake_image = stage_2*binary_mask + masked_image*inverse_mask # mask_location: generated, around_mask: ground_truth

                # Crop the bbox region out of every tensor for the local losses
                real_patch = self.util.local_patch(real_image, bbox)
                stage_1_patch = self.util.local_patch(stage_1, bbox)
                stage_2_patch = self.util.local_patch(stage_2, bbox)
                mask_patch = self.util.local_patch(binary_mask, bbox)
                fake_patch = self.util.local_patch(fake_image, bbox)

                l1_alpha = self.stage1_lambda_l1
                self.loss['recon'] = l1_alpha * self.L1(stage_1_patch, real_patch) # Coarse Network reconstruction loss
                self.loss['recon'] = self.loss['recon'] + self.L1(stage_2_patch, real_patch) # Refinement Network reconstruction loss
                
                self.loss['ae_loss'] = l1_alpha * self.torch_L1(stage_1*inverse_mask, real_image*inverse_mask) # recon loss except mask
                self.loss['ae_loss'] = self.loss['ae_loss'] + self.torch_L1(stage_2*inverse_mask, real_image*inverse_mask) # recon loss except mask
                # Normalise by the mean of inverse_mask over its last two dims
                self.loss['ae_loss'] = self.loss['ae_loss'] / torch.mean(torch.mean(inverse_mask, dim=3), dim=2) # 1 x 1 tensor

                # NOTE(review): clone() keeps the autograd graph attached; if the
                # intent on D-only steps was to cut gradients to G, detach()
                # would be needed - confirm.
                if (batch+1) % self.d_train_repeat == 0:
                    global_real_fake_image = torch.cat([real_image, fake_image], dim=0)
                    local_real_fake_image = torch.cat([real_patch, fake_patch], dim=0)
                else:
                    global_real_fake_image = torch.cat([real_image, fake_image.clone()], dim=0)
                    local_real_fake_image = torch.cat([real_patch, fake_patch.clone()], dim=0)

                # Real and fake are scored in one pass, then split back apart
                global_real_fake_vector, local_real_fake_vector = self.D(global_real_fake_image, local_real_fake_image)

                global_real_vector, global_fake_vector = torch.split(global_real_fake_vector, batch_size, dim=0)
                local_real_vector, local_fake_vector = torch.split(local_real_fake_vector, batch_size, dim=0)

                global_G_loss, global_D_loss = self.wgan_loss(global_real_vector, global_fake_vector)
                local_G_loss, local_D_loss = self.wgan_loss(local_real_vector, local_fake_vector)

                self.loss['g_loss'] = self.global_wgan_loss_alpha * (global_G_loss + local_G_loss)
                self.loss['d_loss'] = global_D_loss + local_D_loss

                if (batch+1) % self.d_train_repeat == 0:
                    # gradient penalty
                    global_interpolate = self.random_interpolates(real_image, fake_image) 
                    local_interpolate = self.random_interpolates(real_patch, fake_patch)
                else:
                    global_interpolate = self.random_interpolates(real_image, fake_image.clone()) 
                    local_interpolate = self.random_interpolates(real_patch, fake_patch.clone())

                global_gp_vector, local_gp_vector = self.D(global_interpolate, local_interpolate)

                global_penalty = self.gradient_penalty(global_interpolate, global_gp_vector, mask=binary_mask)
                local_penalty = self.gradient_penalty(local_interpolate, local_gp_vector, mask=mask_patch)

                self.loss['gp_loss'] = self.wgan_gp_lambda * (local_penalty + global_penalty)
                self.loss['d_loss'] = self.loss['d_loss'] + self.loss['gp_loss']

                # Every d_train_repeat-th batch updates both D and G; all
                # other batches update D only and report a zero g_loss.
                if (batch+1) % self.d_train_repeat == 0:             
                    self.loss['g_loss'] = self.gan_loss_alpha * self.loss['g_loss']
                    self.loss['g_loss'] = self.loss['g_loss'] + self.l1_loss_alpha * self.loss['recon'] + self.ae_loss_alpha * self.loss['ae_loss']
                    self.backprop(D=True,G=True)

                else:                  
                    self.loss['g_loss'] = to_var(torch.FloatTensor([0]))
                    self.backprop(D=True,G=False)


                if batch % self.print_every == 0:
                    elapsed = time.time() - start_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))

                    print('=====================================================')
                    print("Elapsed [{}], Epoch [{}/{}], Iter [{}/{}]".format(
                        elapsed, epoch+1, self.num_epochs, batch+1, iters_per_epoch))
                    print('=====================================================')
                    # NOTE(review): `.data[0]` indexing assumes the losses are
                    # indexable 1-element tensors (older PyTorch) - confirm for
                    # the installed version.
                    print('reconstruction loss: ', self.loss['recon'].data[0])
                    print('ae loss: ', self.loss['ae_loss'].data[0][0])
                    print('g loss: ', self.loss['g_loss'].data[0])
                    print('d loss: ', self.loss['d_loss'].data[0])
                    show_image(real_image, (masked_image+binary_mask), stage_1, stage_2, fake_image, offset_flow)

                # Save model checkpoints
                # The file name contains only the epoch number, so saves within
                # the same epoch overwrite each other.
                if batch % self.model_save_step == 0:
                    torch.save(self.G.state_dict(),
                        os.path.join(self.model_save_path, 'G_{}_L1_{}.pth'.format(epoch+1, self.l1_loss_alpha)))
                    torch.save(self.D.state_dict(),
                        os.path.join(self.model_save_path, 'D_{}_L1_{}.pth'.format(epoch+1, self.l1_loss_alpha)))

                # Save sample image
                if batch % self.sample_step == 0:
                    save_image(self.denorm(fake_image.clone().data.cpu()),
                        os.path.join(self.sample_path, '{}_{}_fake.png'.format(epoch+1, batch+1)),nrow=1, padding=0)
                    print('Translated images and saved into {}..!'.format(self.sample_path))

    def backprop(self, D=True, G=True):
        if D:
            self.d_optimizer.zero_grad()
            self.loss['d_loss'].backward(retain_graph=G)
            self.d_optimizer.step()
        if G:
            self.g_optimizer.zero_grad()
            self.loss['g_loss'].backward()
            self.g_optimizer.step()

    def wgan_loss(self, real, fake):
        diff = fake - real
        d_loss = torch.mean(diff)
        g_loss = -torch.mean(fake)
        return g_loss, d_loss

    def random_interpolates(self, real, fake, alpha=None):
        shape = list(real.size())
        real = real.contiguous().view(shape[0], -1, 1, 1)
        fake = fake.contiguous().view(shape[0], -1, 1, 1)
        if alpha is None:
            alpha = Variable(torch.rand(shape[0], 1, 1, 1)).cuda()
        interpolates = fake + alpha*(real - fake)
        return interpolates.view(shape)

    def gradient_penalty(self, x, y, mask=None, norm=1.):
        """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2."""
        weight = Variable(torch.ones(y.size())).cuda()
        dydx = torch.autograd.grad(outputs=y,
                                   inputs=x,
                                   grad_outputs=weight,
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]
        dydx = dydx * mask
        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1))
        return torch.mean((dydx_l2norm-1)**2)

    def denorm(self, x):
        out = (x + 1) / 2
        return out.clamp_(0, 1)

def main(_):
    """Select the GPU, build a ``Run`` and execute the configured mode.

    The positional argument is ignored; configuration is read from the
    module-level ``args`` namespace.

    Raises:
        NotImplementedError: if a non-'train' mode is requested but ``Run``
            provides no ``test`` method.
    """
    # Fail early and clearly: Run defines no test() in this file, so the
    # non-train branch would otherwise end in an AttributeError after the
    # data loader and networks had already been built.
    if args.MODE != 'train' and not hasattr(Run, 'test'):
        raise NotImplementedError(
            "MODE={!r} requires Run.test(), which is not implemented.".format(args.MODE))

    cuda.set_device(args.GPU)
    print("Running on GPU : ", args.GPU)
    run = Run(args)

    if args.MODE == 'train':
        run.train()
    else:
        run.test()

main(args)

RuntimeError: cuda runtime error (35) : CUDA driver version is insufficient for CUDA runtime version at ..\torch\csrc\cuda\Module.cpp:33

util.py

In [None]:
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.backends import cudnn
from random import *
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F

class Util(object):
    """Helpers for printing networks and building rectangular hole masks.

    All sizes are read from the ``args`` object passed to the constructor
    (``IMG_SHAPE``, ``VERTICAL_MARGIN``, ``HORIZONTAL_MARGIN``,
    ``MASK_HEIGHT``, ``MASK_WIDTH``, ``MAX_DELTA_HEIGHT``,
    ``MAX_DELTA_WIDTH``).
    """

    def __init__(self, args):
        self.args = args

    def print_network(self, model, name):
        """Print ``name``, the model and its total number of parameters."""
        total = sum(param.numel() for param in model.parameters())
        print(name)
        print(model)
        print("The number of parameters: {}".format(total))

    def random_bbox(self):
        """Draw a random mask box that respects the configured margins.

        Returns:
            Tuple ``(top, left, height, width)``.  ``height`` and ``width``
            are always ``MASK_HEIGHT`` and ``MASK_WIDTH``; ``top`` and
            ``left`` are drawn with ``randint`` (both bounds included).
        """
        cfg = self.args
        img_height, img_width = cfg.IMG_SHAPE[0], cfg.IMG_SHAPE[1]
        box_h, box_w = cfg.MASK_HEIGHT, cfg.MASK_WIDTH

        top_limit = img_height - cfg.VERTICAL_MARGIN - box_h
        left_limit = img_width - cfg.HORIZONTAL_MARGIN - box_w

        # Top is drawn before left; keep this order so seeded runs reproduce.
        top = randint(cfg.VERTICAL_MARGIN, top_limit)
        left = randint(cfg.HORIZONTAL_MARGIN, left_limit)
        return (top, left, box_h, box_w)

    def bbox2mask(self, bbox):
        """Generate a mask tensor that is 1 inside the box and 0 elsewhere.

        The box is shrunk symmetrically by a random amount of up to
        ``MAX_DELTA_HEIGHT // 2`` rows and ``MAX_DELTA_WIDTH // 2`` columns on
        each side before it is painted.

        Args:
            bbox: sequence ``(top, left, height, width)``.

        Returns:
            ``torch.FloatTensor`` of shape ``[1, 1, H, W]`` where ``H`` and
            ``W`` are ``IMG_SHAPE[0]`` and ``IMG_SHAPE[1]``.
        """
        cfg = self.args
        height = cfg.IMG_SHAPE[0]
        width = cfg.IMG_SHAPE[1]
        max_delta_h, max_delta_w = cfg.MAX_DELTA_HEIGHT, cfg.MAX_DELTA_WIDTH

        canvas = np.zeros((1, 1, height, width), np.float32)
        # Row shrink is drawn before column shrink (matters under a seed).
        shrink_h = np.random.randint(max_delta_h // 2 + 1)
        shrink_w = np.random.randint(max_delta_w // 2 + 1)
        rows = slice(bbox[0] + shrink_h, bbox[0] + bbox[2] - shrink_h)
        cols = slice(bbox[1] + shrink_w, bbox[1] + bbox[3] - shrink_w)
        canvas[:, :, rows, cols] = 1.

        return torch.FloatTensor(canvas)

    def local_patch(self, x, bbox):
        """Crop the box ``(top, left, height, width)`` from the last two dims of ``x``."""
        top, left = bbox[0], bbox[1]
        bottom, right = top + bbox[2], left + bbox[3]
        return x[:, :, top:bottom, left:right]


class Discounted_L1(nn.Module):
    """L1 loss whose per-element errors are weighted by a spatial mask.

    The mask is built once in the constructor by
    ``spatial_discounting_mask`` from ``args.MASK_WIDTH``,
    ``args.MASK_HEIGHT`` and ``args.SPATIAL_DISCOUNTING_GAMMA`` and is
    broadcast against ``|input - target|`` in ``forward``.

    Args:
        args: configuration object providing the three fields above.
        size_average: when reducing, return the mean (True) or the sum
            (False) of the weighted errors.
        reduce: when False, return the weighted error map unreduced.
    """

    def __init__(self, args, size_average=True, reduce=True):
        super(Discounted_L1, self).__init__()
        self.reduce = reduce
        self.discounting_mask = spatial_discounting_mask(args.MASK_WIDTH,
                                                         args.MASK_HEIGHT,
                                                         args.SPATIAL_DISCOUNTING_GAMMA)
        self.size_average = size_average

    def forward(self, input, target):
        """Return the discounted L1 loss between ``input`` and ``target``.

        Raises:
            AssertionError: if ``target`` requires gradients.
        """
        self._assert_no_grad(target)
        return self._pointwise_loss(lambda a, b: torch.abs(a - b), None,
                                    input, target, self.discounting_mask,
                                    self.size_average, self.reduce)

    def _assert_no_grad(self, variable):
        assert not variable.requires_grad, \
        "nn criterions don't compute the gradient w.r.t. targets - please " \
        "mark these variables as volatile or not requiring gradients"

    def _pointwise_loss(self, lambd, lambd_optimized, input, target, discounting_mask, size_average=True, reduce=True):
        """Apply ``lambd`` element-wise, weight by the mask and reduce.

        ``lambd_optimized`` is accepted only to keep the signature stable and
        is not called: the weighting must be applied on every call, and
        ``forward`` guarantees ``target.requires_grad`` is False, so gating
        the weighted path on that flag would disable the mask entirely.
        """
        d = lambd(input, target)
        # The mask is created once at construction time; bring it next to the
        # data if the two ended up on different kinds of device.
        if discounting_mask.is_cuda != d.is_cuda:
            discounting_mask = discounting_mask.cuda() if d.is_cuda else discounting_mask.cpu()
        d = d * discounting_mask
        if not reduce:
            return d
        return torch.mean(d) if size_average else torch.sum(d)


def spatial_discounting_mask(mask_width, mask_height, discounting_gamma):
    """Generate the spatial discounting mask constant.

    Spatial discounting mask is first introduced in publication:
        Generative Image Inpainting with Contextual Attention, Yu et al.

    The weight at row ``r`` and column ``c`` is
    ``max(gamma**min(r, mask_height - r), gamma**min(c, mask_width - c))``.

    Args:
        mask_width: number of mask columns.
        mask_height: number of mask rows.
        discounting_gamma: base ``gamma`` of the discount.

    Returns:
        Float tensor wrapped by ``to_var`` with shape
        ``[1, 1, mask_height, mask_width]``, i.e. rows first, so it
        broadcasts against ``[B, C, H, W]`` patches.  For square masks the
        values are symmetric, so the layout makes no difference there.

    """
    gamma = discounting_gamma
    print('Use spatial discounting l1 loss.')

    # The discount depends on one coordinate at a time, so evaluate the
    # powers once per row / column instead of once per pixel.
    row_decay = [gamma**min(r, mask_height - r) for r in range(mask_height)]
    col_decay = [gamma**min(c, mask_width - c) for c in range(mask_width)]

    mask_values = np.ones((mask_height, mask_width))
    for r in range(mask_height):
        for c in range(mask_width):
            mask_values[r, c] = max(row_decay[r], col_decay[c])

    # Add the leading batch and channel dimensions; the mask is broadcast
    # along both when it is multiplied with a patch.
    mask_values = np.expand_dims(mask_values, 0)
    mask_values = np.expand_dims(mask_values, 1)
    mask_values = torch.from_numpy(mask_values).float()
    return to_var(mask_values)


def down_sample(x, size=None, scale_factor=None, mode='nearest'):
    """Resample ``x`` onto a regular grid with ``F.grid_sample``.

    Args:
        x: 4-D tensor; dims 2 and 3 are treated as height and width.
        size: target ``(height, width)``.  Takes precedence over
            ``scale_factor``.
        scale_factor: multiplier applied to the height and width of ``x``
            when ``size`` is not given.
        mode: interpolation mode forwarded to ``F.grid_sample``.

    Returns:
        Tensor with the batch and channel dims of ``x`` and the requested
        spatial size.

    Raises:
        ValueError: if neither ``size`` nor ``scale_factor`` is given.

    NOTE(review): the grid spans [-1, 1] with both end points included,
    which matches corner-aligned sampling.  Recent PyTorch releases default
    ``grid_sample`` to ``align_corners=False`` - confirm which convention the
    target PyTorch version applies.
    """
    # define size if user has specified scale_factor
    if size is None:
        if scale_factor is None:
            raise ValueError('either size or scale_factor must be given')
        size = (int(scale_factor*x.size(2)), int(scale_factor*x.size(3)))
    out_h, out_w = size[0], size[1]

    def axis_coords(n):
        # .float() keeps this a true division on every PyTorch version;
        # max(..., 1) avoids 0/0 (NaN coordinates) for a single sample.
        return torch.arange(0, n).float() / max(n - 1, 1) * 2 - 1

    # create coordinates
    h = axis_coords(out_h)
    w = axis_coords(out_w)
    # create grid: channel 0 holds x (width), channel 1 holds y (height)
    grid = torch.zeros(out_h, out_w, 2)
    grid[:, :, 0] = w.unsqueeze(0).repeat(out_h, 1)
    grid[:, :, 1] = h.unsqueeze(0).repeat(out_w, 1).transpose(0, 1)
    # expand to match batch size
    grid = Variable(grid.unsqueeze(0).repeat(x.size(0), 1, 1, 1))
    if x.is_cuda:
        grid = grid.cuda()
    # do sampling
    return F.grid_sample(x, grid, mode=mode)


def reduce_mean(x):
    """Average ``x`` over dims 0, 2 and 3, keeping them as size-1 dims.

    Dim 1 is left untouched, so an input of shape ``[B, C, H, W]`` yields
    ``[1, C, 1, 1]``.
    """
    for axis in (0, 2, 3):
        x = torch.mean(x, dim=axis, keepdim=True)
    return x


def l2_norm(x):
    """Return the L2 norm of ``x`` taken over dims 0, 2 and 3.

    The reduced dims are kept with size 1 and dim 1 is preserved, so an
    input of shape ``[B, C, H, W]`` yields ``[1, C, 1, 1]``.
    """
    squared = x**2
    for axis in (0, 2, 3):
        squared = torch.sum(squared, dim=axis, keepdim=True)
    return torch.sqrt(squared)


def show_image(real, masked, stage_1, stage_2, fake, offset_flow):
    """Plot up to the first six samples of a batch, one figure per sample.

    Every figure has five panels: real, masked, stage-1, stage-2 and fake
    image.  All inputs are converted with ``var_to_numpy``; ``offset_flow``
    is converted as well but is not plotted.
    """
    batch_size = real.shape[0]

    real = var_to_numpy(real)
    masked = var_to_numpy(masked)
    stage_1 = var_to_numpy(stage_1)
    stage_2 = var_to_numpy(stage_2)
    fake = var_to_numpy(fake)
    offset_flow = var_to_numpy(offset_flow)

    # (title, image batch) for each panel, left to right.
    panels = (
        ('real image', real),
        ('masked image', masked),
        ('stage_1 image', stage_1),
        ('stage_2 image', stage_2),
        ('fake_image', fake),
    )

    for idx in range(min(batch_size, 6)):
        fig, axs = plt.subplots(ncols=5, figsize=(15,3))
        for ax, (title, images) in zip(axs, panels):
            ax.set_title(title)
            ax.imshow(images[idx])
            ax.axis('off')
        plt.show()


def var_to_numpy(obj, for_vis=True):
    """Convert a tensor/Variable to a NumPy array on the CPU.

    Args:
        obj: tensor to convert.
        for_vis: when True, first permute the dims with ``(0, 2, 3, 1)``
            and map the values through ``(v + 1) / 2``.

    Returns:
        ``numpy.ndarray`` holding the (optionally transformed) data.
    """
    if not for_vis:
        return obj.data.cpu().numpy()
    channels_last = obj.permute(0, 2, 3, 1)
    rescaled = (channels_last + 1) / 2
    return rescaled.data.cpu().numpy()


def to_var(x, volatile=False):
    """Wrap ``x`` in a ``Variable``, moving it to the GPU when CUDA is available.

    ``volatile`` is forwarded unchanged to the ``Variable`` constructor.
    """
    on_gpu = torch.cuda.is_available()
    data = x.cuda() if on_gpu else x
    return Variable(data, volatile=volatile)


def flow_to_image(flow):
    """Transfer flow map to image.
    Part of code forked from flownet.

    ``flow`` is indexed as ``[sample, row, col, component]`` with component 0
    read as ``u`` and component 1 as ``v``.  Entries whose magnitude exceeds
    1e7 are set to 0 in ``flow`` itself (the slices are views).  Each sample
    is divided by the largest radius seen so far - a running maximum over the
    samples processed up to and including the current one - and coloured with
    ``compute_color``.

    Returns:
        ``numpy.float32`` array stacking the images (values pass through
        ``uint8`` first).
    """
    eps = np.finfo(float).eps
    images = []
    running_max_rad = -1
    for idx in range(flow.shape[0]):
        u = flow[idx, :, :, 0]
        v = flow[idx, :, :, 1]

        # Zero out implausibly large ("unknown") flow; this writes through
        # to the caller's array.
        unknown = (abs(u) > 1e7) | (abs(v) > 1e7)
        u[unknown] = 0
        v[unknown] = 0

        rad = np.sqrt(u ** 2 + v ** 2)
        running_max_rad = max(running_max_rad, np.max(rad))
        scale = running_max_rad + eps
        images.append(compute_color(u/scale, v/scale))
    return np.float32(np.uint8(images))


def highlight_flow(flow):
    """Mark every location referenced by the flow in a grey image.

    For each sample an image filled with 144 is created and, for every
    position ``(h, w)``, the pixel at ``(flow[i, h, w, 0], flow[i, h, w, 1])``
    is set to 255.

    Args:
        flow: integer array of shape ``[N, H, W, 2]``; the two components
            are used directly as (row, column) indices into an ``H x W``
            image, so the array must have an integer dtype.

    Returns:
        ``numpy.float32`` array of shape ``[N, H, W, 3]``.
    """
    out = []
    s = flow.shape
    for i in range(s[0]):
        img = np.ones((s[1], s[2], 3)) * 144.
        u = flow[i, :, :, 0]
        v = flow[i, :, :, 1]
        # Index with the full (H, W) arrays so every column is visited; a
        # loop over the height for both axes skips columns when W > H and
        # runs out of bounds when W < H.
        img[u, v, :] = 255.
        out.append(img)
    return np.float32(np.uint8(out))


def compute_color(u,v):
    """Colour a 2-D flow field using the colour wheel.

    The angle of ``(u, v)`` selects a (linearly interpolated) entry of the
    wheel returned by ``make_color_wheel``.  Where the radius is at most 1
    the colour is blended towards white as the radius shrinks; elsewhere it
    is scaled by 0.75.  NaN entries are set to 0 in ``u`` and ``v`` in place
    and produce 0 in the output.

    Args:
        u, v: 2-D arrays of equal shape.

    Returns:
        Array of shape ``[h, w, 3]`` whose values were rounded down to
        integers in 0..255.
    """
    rows, cols = u.shape
    img = np.zeros([rows, cols, 3])

    nan_mask = np.isnan(u) | np.isnan(v)
    u[nan_mask] = 0
    v[nan_mask] = 0

    # colorwheel = COLORWHEEL
    colorwheel = make_color_wheel()
    ncols = np.size(colorwheel, 0)

    rad = np.sqrt(u**2+v**2)
    angle = np.arctan2(-v, -u) / np.pi
    # Fractional, 1-based position on the wheel and its two neighbours.
    fk = (angle+1) / 2 * (ncols - 1) + 1
    k0 = np.floor(fk).astype(int)
    k1 = k0 + 1
    k1[k1 == ncols+1] = 1  # wrap past the last entry back to the first
    frac = fk - k0

    # These do not depend on the channel, so compute them once.
    inside = rad <= 1
    outside = np.logical_not(inside)
    keep = 1-nan_mask

    for channel in range(np.size(colorwheel,1)):
        wheel_column = colorwheel[:, channel]
        col0 = wheel_column[k0-1] / 255
        col1 = wheel_column[k1-1] / 255
        col = (1-frac) * col0 + frac * col1
        col[inside] = 1-rad[inside]*(1-col[inside])
        col[outside] *= 0.75
        img[:, :, channel] = np.uint8(np.floor(255 * col*keep))
    return img


def make_color_wheel():
    """Build the 55-entry colour wheel used to colour flow fields.

    The wheel is made of six consecutive segments (RY, YG, GC, CB, BM, MR).
    In each segment one channel is held at 255 while another ramps linearly,
    either up from 0 or down from 255; the third channel stays 0.

    Returns:
        Float array of shape ``[55, 3]`` with values in 0..255 (channel
        order presumably R, G, B, going by the segment names).
    """
    # (length, channel held at 255, ramping channel, ramp goes down?)
    segments = (
        (15, 0, 1, False),  # RY
        (6, 1, 0, True),    # YG
        (4, 1, 2, False),   # GC
        (11, 2, 1, True),   # CB
        (13, 2, 0, False),  # BM
        (6, 0, 2, True),    # MR
    )
    ncols = sum(segment[0] for segment in segments)
    colorwheel = np.zeros([ncols, 3])

    start = 0
    for length, full_channel, ramp_channel, falling in segments:
        stop = start + length
        ramp = np.floor(255*np.arange(0, length) / length)
        colorwheel[start:stop, full_channel] = 255
        colorwheel[start:stop, ramp_channel] = 255 - ramp if falling else ramp
        start = stop
    return colorwheel


# Colour wheel built once when this module/cell is executed.
# NOTE(review): nothing in the visible part of this file reads it
# (compute_color builds its own wheel on every call) - confirm before removing.
COLORWHEEL = make_color_wheel()
