# DEEP PHOTO STYLE TRANSFER
## PAPER CODE: 54

#### GROUP MEMBERS

2017A7PS0122P Rohit Jain

2017A7PS0088P Vaishnavi Kotturu

2017A3PS0267P Khushi Gupta

(The code has been submitted in the partial fulfillment of the course Neural Networks and Fuzzy Logic 2019-20 Semester 1)

In [0]:
# Uncomment the following lines if using google colab

# from google.colab import files
# uploaded = files.upload()

# Header Files

The code has following dependencies:
1. torch
2. matplotlib
3. PIL
4. torchvision

In [0]:
import torch
import torchvision.models as models
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms

import copy

from PIL import Image
from skimage.transform import resize
import numpy as np
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

%matplotlib inline

# Image Loading and Preprocessing

In [0]:
def tensor2img(x):

    '''
    Converts a torch tensor to an array
    '''

    x = x.detach().numpy()
    x = x.transpose(0,2,3,1)
    x = x.squeeze()
    x = x.clip(0,1)
    return x


def img2tensor(x):
    
    '''
    Converts an image to a tensor
    '''

    if x.ndim == 2:
        y=np.expand_dims(x, axis=0)
        y=np.expand_dims(y, axis=0)
        return torch.Tensor(y)
    elif x.ndim == 3:
        y=x.transpose(2, 0, 1)
        y=np.expand_dims(y, axis=0)
        return torch.Tensor(y)
    elif x.ndim == 4:
        y=x.transpose(0, 3, 1, 2)
        return torch.Tensor(y)


def extract_func(seg):

    '''
    Extracts the color masks from the segmented masks:
    Blue, Green, Black, White, Red, Yellow, Grey, Light Blue, Purple
    '''

    ans = []

    # BLUE
    r = seg[..., 0] < 0.1
    g = seg[..., 1] < 0.1
    b = seg[..., 2] > 0.9
    mask = r & g & b
    ans.extend([mask])

    # GREEN
    r = seg[..., 0] < 0.1
    g = seg[..., 1] > 0.9
    b = seg[..., 2] < 0.1
    mask = r & g & b
    ans.extend([mask])

    # BLACK
    r = seg[..., 0] < 0.1
    g = seg[..., 1] < 0.1
    b = seg[..., 2] < 0.1
    mask = r & g & b
    ans.extend([mask])

    # WHITE
    r = seg[..., 0] > 0.9
    g = seg[..., 1] > 0.9
    b = seg[..., 2] > 0.9
    mask = r & g & b
    ans.extend([mask])

    # RED
    r = seg[..., 0] > 0.9
    g = seg[..., 1] < 0.1
    b = seg[..., 2] < 0.1
    mask = r & g & b
    ans.extend([mask])

    # YELLOW
    r = seg[..., 0] > 0.9
    g = seg[..., 1] > 0.9
    b = seg[..., 2] < 0.1
    mask = r & g & b
    ans.extend([mask])

    # GREY
    r = (seg[..., 0] > 0.4) & (seg[..., 0] < 0.6)
    g = (seg[..., 1] > 0.4) & (seg[..., 1] < 0.6)
    b = (seg[..., 2] > 0.4) & (seg[..., 2] < 0.6)
    mask = r & g & b
    ans.extend([mask])

    # LIGHT BLUE
    r = seg[..., 0] < 0.1
    g = seg[..., 1] > 0.9
    b = seg[..., 2] > 0.9
    mask = r & g & b
    ans.extend([mask])

    # PURPLE
    r = seg[..., 0] > 0.9
    g = seg[..., 1] < 0.1
    b = seg[..., 2] > 0.9
    mask = r & g & b
    ans.extend([mask])

    return ans


def masks_func(img):
    
    '''
    Returns the segmentation masks from the segmentated image.
    '''

    image = Image.open(img)
    result = np.array(image, dtype=np.float) / 255
    return extract_func(result)


def get_masks(path1, path2):
    
    '''
    Finds the dense masks to be considered for augmented style loss
    '''

    masks1 = masks_func(path1)
    masks2 = masks_func(path2)

    temp = []
    for c, s in zip(masks2, masks1):
        temp.append(np.mean(c) > 0.01 and np.mean(s) > 0.01)
    
    new1=[]
    for i, j in zip(masks1, temp):
        if j:
            new1.append(i)
        
    new2=[]
    for i, j in zip(masks2, temp):
        if j:
            new2.append(i)

    return new1, new2


def load_masks(path1, path2, size):
    
    '''
    Loads the tensors for the mask images
    '''

    arr1, arr2 = get_masks(path1, path2)
    resize_f = lambda x: resize(x, size, mode="reflect")

    arr1 = map(resize_f,arr1)
    arr2 = map(resize_f,arr2)
    arr1 = map(img2tensor,arr1)
    arr2 = map(img2tensor,arr2)

    return arr1, arr2

def load_images(img_name, size):

    '''
    Loads images
    '''

    func = transforms.Compose([transforms.Resize(size), transforms.ToTensor()])
    output = func(Image.open(img_name)).unsqueeze(0)
    
    return output


def plot_output(content, style, output):
    
    '''
    Utility function to plot the content, style and output image
    '''

    plt.figure(figsize=(12, 4))

    plt.subplot(1, 3, 1)
    plt.imshow(tensor2img(style))
    plt.title("Style Image")

    plt.subplot(1, 3, 2)
    plt.imshow(tensor2img(output))
    plt.title("Output Image")

    plt.subplot(1, 3, 3)
    plt.imshow(tensor2img(content))
    plt.title("Content Image")

    plt.tight_layout()
    plt.show()

# Deep Photo Style Transfer

In [0]:
class Normalization(nn.Module):
    
    '''
    Normalizes the input image to be put in our model
    '''
    
    def __init__(self, mean, std):
        super(Normalization, self).__init__()
        self.mean = torch.tensor(mean).view(-1, 1, 1)
        self.std = torch.tensor(std).view(-1, 1, 1)

    def forward(self, img):
        return (img - self.mean) / self.std


class Content_loss(nn.Module):
    
    '''
    Calculates the content variation between the content image and the output image at the given layer
    '''
    
    def __init__(self, target):
        super(Content_loss, self).__init__()
        self.target = target.detach()

    def forward(self, input):
        self.loss = F.mse_loss(input, self.target)
        return input


def gram_matrix(input_image):

    '''
    Measures the style in the input image by calculating the correlation between its channels
    '''
    
    batch_size, number_of_features, feature_dim1, feature_dim2 = input_image.size() 
    features = input_image.view(batch_size * number_of_features, feature_dim1 * feature_dim2)
    G = torch.mm(features, features.t())
    G = G.div(batch_size * number_of_features* feature_dim1 * feature_dim2)
    
    return G
    
    

class augmented_style_loss(nn.Module):

    """
    Exploits the semantic information of the style image  and calculates the style variation segment wise rather than calculating on the entire image at once 
    """

    def __init__(self, target_feature, target_masks, input_masks):
        super(augmented_style_loss, self).__init__()
        self.input_masks = [mask.detach() for mask in input_masks]
        self.targets = [
            gram_matrix(target_feature * mask).detach() for mask in target_masks
        ]

    def forward(self, input):
        gram_matrices = [
            gram_matrix(input * mask.detach()) for mask in self.input_masks
        ]
        self.loss = sum(
            F.mse_loss(gram, target) for gram, target in zip(gram_matrices, self.targets)
        )
        return input


def model_and_losses(vgg, mean, std, content_image, style_image, content_layers, style_layers, content_masks, style_masks):

    '''
    Constructs and returns the models and losses while stylizing the content image
    '''
    
    vgg = copy.deepcopy(vgg)
    normalization = Normalization(mean, std)

    content_losses = []
    style_losses = []

    model = nn.Sequential(normalization)

    num_pool_layers = 0 
    num_conv_layers = 0
    
    for model_layer in vgg.children():
        if isinstance(model_layer, nn.Conv2d):
            num_conv_layers += 1
            name = "conv{}_{}".format(num_pool_layers, num_conv_layers)

        elif isinstance(model_layer, nn.ReLU):
            name = "relu{}_{}".format(num_pool_layers, num_conv_layers)
            model_layer = nn.ReLU(inplace=False)

        elif isinstance(model_layer, nn.MaxPool2d):
            num_pool_layers += 1
            num_conv_layers = 0
            name = "pool_{}".format(num_pool_layers)
            model_layer = nn.AvgPool2d(kernel_size=model_layer.kernel_size, stride=model_layer.stride, padding=model_layer.padding)
            style_masks = [model_layer(mask) for mask in style_masks]
            content_masks = [model_layer(mask) for mask in content_masks]

        elif isinstance(model_layer, nn.BatchNorm2d):
            name = "bn{}_{}".format(num_pool_layers, num_conv_layers)

        model.add_module(name, model_layer)

        if name in content_layers:
            target = model(content_image).detach()
            content_loss = Content_loss(target)
            model.add_module("content_loss_{}".format(num_pool_layers), content_loss)
            content_losses.append(content_loss)

        if name in style_layers:
            target_feature = model(style_image).detach()

            style_loss = augmented_style_loss(target_feature, style_masks, content_masks)
            model.add_module("style_loss_{}".format(num_pool_layers), style_loss)
            style_losses.append(style_loss)

    return model, content_losses, style_losses


def get_input_optimizer(input_image):

    '''
    Returns the LBFGS optimizer to be used for backpropagation.
    '''
    
    optimizer = optim.LBFGS([input_image.requires_grad_()])
    return optimizer


def photoreg(input_image, L):
    
    '''
    Calculates the the losses and gradients for photorealism regularization
    '''
    
    image = tensor2img(input_image)
    gradient = L.dot(image.reshape(-1, 3))
    loss = (gradient * image.reshape(-1, 3))
    gradient = gradient.reshape(image.shape)
    loss = loss.sum()
    
    return loss, 2.0 * gradient


def stylize(vgg, mean, std, content_image, style_image, output_image, content_layers, style_layers, content_masks, style_masks, num_steps=300, style_weight=100000, content_weight=1000, reg_weight=1000):

    """
    Runs the style transfer by using:
    1. Matting Laplacian to locally affine in color space
    2. Semantic Segmentation for augmented style loss
    3. Photorealism Regularization to make photos look more natural
    """
    
    model, content_losses, style_losses = model_and_losses(vgg, mean, std, content_image, style_image, content_layers, style_layers, content_masks, style_masks)
    optimizer = get_input_optimizer(output_image)
    L = compute_laplacian(tensor2img(content_image))

    step = 0
    while step <= num_steps:

        def closure():
            output_image.data.clamp_(0, 1)

            optimizer.zero_grad()
            model(output_image)

            get_loss = lambda x: x.loss
            style_score = style_weight * sum(map(get_loss, style_losses))
            content_score = content_weight * sum(map(get_loss, content_losses))

            loss = style_score + content_score
            loss.backward()

            reg_loss, reg_grad = photoreg(output_image, L)
            reg_grad_tensor = img2tensor(reg_grad)

            output_image.grad += reg_weight * reg_grad_tensor

            loss += reg_weight * reg_loss

            nonlocal step
            step += 1

            if step % 50 == 0:
                print("iteration {:>4d}:".format(step), "Style Loss = {:.3f} Content Loss = : {:.3f} Reg Loss = :{:.3f}".format(style_score.item(), content_score.item(), reg_loss))

            return loss

        optimizer.step(closure)

    output_image.data.clamp_(0, 1)

    return output_image

# Matting Laplacian

In [0]:
from __future__ import division

import logging

import cv2
import numpy as np
from numpy.lib.stride_tricks import as_strided
import scipy.sparse
import scipy.sparse.linalg


def _rolling_block(A, block=(3, 3)):
    """Applies sliding window to given matrix."""
    shape = (A.shape[0] - block[0] + 1, A.shape[1] - block[1] + 1) + block
    strides = (A.strides[0], A.strides[1]) + A.strides
    return as_strided(A, shape=shape, strides=strides)


def compute_laplacian(img, mask=None, eps=10 ** (-7), win_rad=1):
    """Computes Matting Laplacian for a given image.
    Args:
        img: 3-dim numpy matrix with input image
        mask: mask of pixels for which Laplacian will be computed.
            If not set Laplacian will be computed for all pixels.
        eps: regularization parameter controlling alpha smoothness
            from Eq. 12 of the original paper. Defaults to 1e-7.
        win_rad: radius of window used to build Matting Laplacian (i.e.
            radius of omega_k in Eq. 12).
    Returns: sparse matrix holding Matting Laplacian.
    """

    win_size = (win_rad * 2 + 1) ** 2
    h, w, d = img.shape
    # Number of window centre indices in h, w axes
    c_h, c_w = h - 2 * win_rad, w - 2 * win_rad
    win_diam = win_rad * 2 + 1

    indsM = np.arange(h * w).reshape((h, w))
    ravelImg = img.reshape(h * w, d)
    win_inds = _rolling_block(indsM, block=(win_diam, win_diam))

    win_inds = win_inds.reshape(c_h, c_w, win_size)
    if mask is not None:
        mask = cv2.dilate(
            mask.astype(np.uint8), np.ones((win_diam, win_diam), np.uint8)
        ).astype(np.bool)
        win_mask = np.sum(mask.ravel()[win_inds], axis=2)
        win_inds = win_inds[win_mask > 0, :]
    else:
        win_inds = win_inds.reshape(-1, win_size)

    winI = ravelImg[win_inds]

    win_mu = np.mean(winI, axis=1, keepdims=True)
    win_var = np.einsum("...ji,...jk ->...ik", winI, winI) / win_size - np.einsum(
        "...ji,...jk ->...ik", win_mu, win_mu
    )

    inv = np.linalg.inv(win_var + (eps / win_size) * np.eye(3))

    X = np.einsum("...ij,...jk->...ik", winI - win_mu, inv)
    vals = np.eye(win_size) - (1.0 / win_size) * (
        1 + np.einsum("...ij,...kj->...ik", X, winI - win_mu)
    )

    nz_indsCol = np.tile(win_inds, win_size).ravel()
    nz_indsRow = np.repeat(win_inds, win_size).ravel()
    nz_indsVal = vals.ravel()
    L = scipy.sparse.coo_matrix(
        (nz_indsVal, (nz_indsRow, nz_indsCol)), shape=(h * w, h * w)
    )
    return L


def closed_form_matting_with_prior(image, prior, prior_confidence, consts_map=None):
    """Applies closed form matting with prior alpha map to image.
    Args:
        image: 3-dim numpy matrix with input image.
        prior: matrix of same width and height as input image holding apriori alpha map.
        prior_confidence: matrix of the same shape as prior hodling confidence of prior alpha.
        consts_map: binary mask of pixels that aren't expected to change due to high
            prior confidence.
    Returns: 2-dim matrix holding computed alpha map.
    """

    assert image.shape[:2] == prior.shape, (
        "prior must be 2D matrix with height and width equal " "to image."
    )
    assert image.shape[:2] == prior_confidence.shape, (
        "prior_confidence must be 2D matrix with " "height and width equal to image."
    )
    assert (consts_map is not None) or image.shape[
        :2
    ] == consts_map.shape, (
        "consts_map must be 2D matrix with height and width equal to image."
    )

    logging.info("Computing Matting Laplacian.")
    laplacian = compute_laplacian(
        image, ~consts_map if consts_map is not None else None
    )

    confidence = scipy.sparse.diags(prior_confidence.ravel())
    logging.info("Solving for alpha.")
    solution = scipy.sparse.linalg.spsolve(
        laplacian + confidence, prior.ravel() * prior_confidence.ravel()
    )
    alpha = np.minimum(np.maximum(solution.reshape(prior.shape), 0), 1)
    return alpha


def closed_form_matting_with_trimap(image, trimap, trimap_confidence=100.0):
    """Apply Closed-Form matting to given image using trimap."""

    assert image.shape[:2] == trimap.shape, (
        "trimap must be 2D matrix with height and width equal " "to image."
    )
    consts_map = (trimap < 0.1) | (trimap > 0.9)
    return closed_form_matting_with_prior(
        image, trimap, trimap_confidence * consts_map, consts_map
    )


def closed_form_matting_with_scribbles(image, scribbles, scribbles_confidence=100.0):
    """Apply Closed-Form matting to given image using scribbles image."""

    assert (
        image.shape == scribbles.shape
    ), "scribbles must have exactly same shape as image."
    consts_map = np.sum(abs(image - scribbles), axis=-1) > 0.001
    return closed_form_matting_with_prior(
        image, scribbles[:, :, 0], scribbles_confidence * consts_map, consts_map
    )


closed_form_matting = closed_form_matting_with_trimap



# Code in Execution

In [0]:
imsize = (128,128)

# Please pass the image file path in the corresponding fields

style_image = load_images("tar16.png", imsize)
content_image = load_images("in16.png", imsize)
output_image = content_image.clone()

style_masks, content_masks = load_masks("segtar16.png","segin16.png",imsize)

vgg = models.vgg19(pretrained=True).features.eval()

mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])

style_layers = ["conv1_1", "conv2_1", "conv3_1", "conv4_1", "conv5_1"]
content_layers = ["conv4_2"]

output = stylize(vgg, mean, std, content_image, style_image, output_image, content_layers, style_layers, content_masks, style_masks,style_weight=1e6,
    content_weight=1e4,
    reg_weight=1e-4,
    num_steps=500)



  if __name__ == '__main__':
  # Remove the CWD from sys.path while we load stuff.


iteration   50: Style Loss = 73.587 Content Loss = : 3.396 Reg Loss = :16.821
iteration  100: Style Loss = 62.521 Content Loss = : 3.163 Reg Loss = :54.484


In [0]:
plot_output(content_image, style_image, output_image)