# Style Transfer with segmentation

This notebook takes a content image and a style image, segments the content image into foreground and background, and then performs style transfer.

In [None]:
# Install PyTorch and torchvision for the style-transfer cells; Pillow is pinned to 4.0.0.
!pip install torch torchvision
!pip install Pillow==4.0.0

In [3]:
import os
import sys
import random
import math
import numpy as np
import skimage.io
import matplotlib
import matplotlib.pyplot as plt

# Root directory of the project
ROOT_DIR = os.path.abspath("../")

# Import Mask RCNN
sys.path.append(ROOT_DIR)  # To find local version of the library
from mrcnn import utils
import mrcnn.model as modellib
from mrcnn import visualize
# Import COCO config
sys.path.append(os.path.join(ROOT_DIR, "samples/coco/"))  # To find local version

import coco

%matplotlib inline 

# Filesystem layout, everything relative to the project root.
MODEL_DIR = os.path.join(ROOT_DIR, "logs")                     # logs and trained models
COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.h5")  # local trained weights file
IMAGE_DIR = os.path.join(ROOT_DIR, "images")                   # images to run detection on

# Fetch the COCO trained weights from Releases when there is no local copy yet.
if not os.path.exists(COCO_MODEL_PATH):
    utils.download_trained_weights(COCO_MODEL_PATH)

ModuleNotFoundError: No module named 'imgaug'

## Configurations

We'll be using a model trained on the MS-COCO dataset. The configurations of this model are in the ```CocoConfig``` class in ```coco.py```.

For inferencing, modify the configurations a bit to fit the task. To do so, sub-class the ```CocoConfig``` class and override the attributes you need to change.

In [1]:
class InferenceConfig(coco.CocoConfig):
    """COCO configuration specialised for single-image inference.

    Batch size = GPU_COUNT * IMAGES_PER_GPU; both are set to 1 because
    inference is run on one image at a time.
    """
    IMAGES_PER_GPU = 1
    GPU_COUNT = 1

config = InferenceConfig()
config.display()

NameError: name 'coco' is not defined

## Create Model and Load Trained Weights

In [None]:
# Create the Mask R-CNN model object in inference mode, using the log
# directory and the single-image config defined in the earlier cells.
model = modellib.MaskRCNN(mode="inference", model_dir=MODEL_DIR, config=config)

# Load weights trained on MS-COCO.
# NOTE(review): by_name=True presumably matches weights to layers by name —
# confirm against mrcnn.model.
model.load_weights(COCO_MODEL_PATH, by_name=True)

## Class Names

The model classifies objects and returns class IDs, which are integer values that identify each class. Some datasets assign integer values to their classes and some don't. For example, in the MS-COCO dataset, the 'person' class is 1 and 'teddy bear' is 88. The IDs are often sequential, but not always. The COCO dataset, for example, has classes associated with class IDs 70 and 72, but not 71.

To improve consistency, and to support training on data from multiple sources at the same time, our ```Dataset``` class assigns its own sequential integer IDs to each class. For example, if you load the COCO dataset using our ```Dataset``` class, the 'person' class would get class ID = 1 (just like COCO) and the 'teddy bear' class is 78 (different from COCO). Keep that in mind when mapping class IDs to class names.

To get the list of class names, you'd load the dataset and then use the ```class_names``` property like this.
```
# Load COCO dataset
dataset = coco.CocoDataset()
dataset.load_coco(COCO_DIR, "train")
dataset.prepare()

# Print class names
print(dataset.class_names)
```

We don't want to require you to download the COCO dataset just to run this demo, so we're including the list of class names below. The index of the class name in the list represents its ID (first class is 0, second is 1, third is 2, etc.).

In [None]:
# COCO class names, hard-coded so the dataset does not have to be downloaded.
# The position of a name in the list is its class ID ('BG' is index 0).
# For example, to get the ID of the teddy bear class, use:
#     class_names.index('teddy bear')
class_names = ['BG', 'person', 'bicycle', 'car', 'motorcycle', 'airplane',
               'bus', 'train', 'truck', 'boat', 'traffic light',
               'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird',
               'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear',
               'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
               'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
               'kite', 'baseball bat', 'baseball glove', 'skateboard',
               'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
               'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
               'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
               'donut', 'cake', 'couch', 'potted plant', 'bed',
               'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
               'keyboard', 'cell phone', 'microwave', 'oven', 'toaster',
               'sink', 'refrigerator', 'vase', 'scissors',
               'teddy bear', 'hair drier', 'toothbrush']

## 1. Run Object Detection to get segmented image

In [None]:
import cv2

# Read the content photo from the images folder.
image_path = os.path.join(IMAGE_DIR, "../images/uncle_pic.jpg")
image = skimage.io.imread(image_path)

# Run detection on a one-image batch and keep the first (only) result.
results = model.detect([image], verbose=1)
r = results[0]

In [None]:
## Keep only the pixels covered by a detected instance; everything else stays black.

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

clone = image.copy()
foreground_image = np.zeros_like(clone)

(H, W, C) = clone.shape[:3]

roi = r['rois']

# Union of all instance masks: True where any detection covers the pixel.
detected = np.zeros(clone.shape[:2], dtype=bool)
for i in range(roi.shape[0]):
    detected |= (r['masks'][:, :, i] == 1)

# Copy the first three channels of the covered pixels in a single pass.
foreground_image[:, :, :3] = np.where(detected[:, :, None],
                                      clone[:, :, :3],
                                      foreground_image[:, :, :3])
imgplot = plt.imshow(foreground_image)

## 2. Apply style transfer on the foreground

In [None]:
## Style Transfer on segmented image
%matplotlib inline
import torch
import torch.optim as optim
from torchvision import transforms, models
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import skimage.io

In [None]:
# Only the `.features` part of the pretrained VGG19 is needed here.
vgg = models.vgg19(pretrained=True).features

# Freeze every parameter: the network is used as a fixed feature extractor
# and only the target image is optimised.
for frozen in vgg.parameters():
    frozen.requires_grad_(False)

In [None]:
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
cuda_available = torch.cuda.is_available()
device = torch.device('cuda' if cuda_available else 'cpu')
vgg.to(device)

In [None]:
# Image Transforms

# Load an image file as a normalized, batched tensor.
def load_image(img_path, max_size = 400, shape = None, style=False):
    """Read an image from disk and convert it to a normalized 4-D tensor.

    Args:
        img_path: Path of the image file to open.
        max_size: Currently ignored. No size cap is applied, so the content
            image is not resized; the parameter is kept so existing callers
            keep working.
        shape: Target size for the style image: an int or an (h, w) sequence
            as accepted by ``transforms.Resize``. Only used when ``style`` is
            True. An (h, w) sequence resizes to exactly that size and does
            NOT preserve the aspect ratio.
        style: When True, resize the image to ``shape`` so the style image
            has the same spatial size as the content image.

    Returns:
        A tensor with a leading batch dimension of 1 whose three RGB
        channels are normalized with mean 0.5 and std 0.5.
    """
    # convert() returns a new image, so the file handle can be released
    # right away instead of leaking.
    with Image.open(img_path) as handle:
        image = handle.convert('RGB')

    steps = []
    # Previously style=True with the default shape=None handed None to
    # Resize, which is not a valid size; skip resizing in that case.
    if style and shape is not None:
        size = shape if isinstance(shape, int) else tuple(shape)
        steps.append(transforms.Resize(size))
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))

    in_transform = transforms.Compose(steps)
    return in_transform(image).unsqueeze(0)
 

In [None]:
# The content image is loaded without resizing.
content = load_image('./Lion.jpeg').to(device)

# Only the trailing two dimensions (height, width) of the content tensor are
# needed; the style image is resized to them.
content_hw = content.shape[-2:]
style = load_image('./StarryNight.jpg', shape=content_hw, style=True).to(device)

In [None]:
def im_convert(tensor):
    """Turn a normalized image tensor into an HxWxC NumPy array in [0, 1].

    The tensor is copied to the CPU, singleton dimensions are squeezed out,
    channels are moved last, the mean-0.5 / std-0.5 normalization is undone
    and the result is clipped to the displayable range.
    """
    half = np.array((0.5, 0.5, 0.5))
    chw = tensor.cpu().clone().detach().numpy().squeeze()
    hwc = chw.transpose(1, 2, 0)
    return (hwc * half + half).clip(0, 1)

In [None]:
# Show the content image (left) and the style image (right) side by side.
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(20, 10))
ax1.imshow(im_convert(content))
ax2.imshow(im_convert(style))
ax1.axis('off')
ax2.axis('off')

In [None]:
# Style Transfer
def get_features(image, model, layers=None):
    """Run ``image`` through ``model`` and collect selected activations.

    Args:
        image: Input tensor fed to the first sub-module of ``model``.
        model: Module whose direct sub-modules are applied in order
            (e.g. the ``features`` part of VGG19).
        layers: Optional mapping from sub-module name to the key under which
            its output is stored. Defaults to the layers used in the Style
            Transfer paper: conv4_2 for content, the others for style.

    Returns:
        Dict mapping each requested key to that layer's output.
    """
    if layers is None:
        layers = {
            '0': 'conv1_1',
            '5': 'conv2_1',
            '10': 'conv3_1',
            '19': 'conv4_1',
            '21': 'conv4_2',  # content extraction, provides high-depth features
            '28': 'conv5_1'
        }
    features = {}
    for name, layer in model._modules.items():
        image = layer(image)
        if name in layers:
            features[layers[name]] = image
            # Nothing after the last requested layer is used, so stop here
            # instead of running the remaining sub-modules on every call.
            if len(features) == len(layers):
                break

    return features

In [None]:
# Extract the VGG activations of the content and style images.
content_features, style_features = (
    get_features(source, vgg) for source in (content, style)
)

In [None]:
# In the Gram matrix, content information is eliminated while style information is retained.
def gram_matrix(tensor):
    """Return the Gram matrix of a 4-D feature map.

    The last three dimensions (depth, height, width) are flattened to a
    (depth, height * width) matrix, which is multiplied by its own transpose.
    The leading (batch) dimension is ignored by the reshape.
    """
    depth, height, width = tensor.size()[1:]
    flattened = tensor.view(depth, height * width)  # 2D tensor
    return flattened.mm(flattened.t())

In [None]:
# Reduce noise in the image
def tv_loss(y):
    """Total-variation loss: summed absolute difference between neighbours.

    Differences are taken along the last two dimensions, so both a 3-D
    (C, H, W) tensor and a 4-D (N, C, H, W) batch are handled.

    Args:
        y: Image tensor whose last two dimensions are the spatial ones.

    Returns:
        A scalar float32 tensor. It stays attached to ``y``'s autograd graph,
        so it contributes gradients when ``y`` requires grad. Re-wrapping the
        value in ``torch.tensor(...)`` would detach it.
    """
    along_last = torch.sum(torch.abs(y[..., :, :-1] - y[..., :, 1:]))
    along_second_last = torch.sum(torch.abs(y[..., :-1, :] - y[..., 1:, :]))
    # .to() changes the dtype without breaking the graph.
    return (along_last + along_second_last).to(torch.float)

In [None]:
# Gram matrix of the style image for every extracted layer; these are the
# fixed style targets during optimisation.
style_grams = {
    name: gram_matrix(activation) for name, activation in style_features.items()
}

In [None]:
# Per-layer style weights: the earliest layers are weighted most heavily.
style_weights = dict(
    conv1_1=1.,
    conv2_1=0.75,
    conv3_1=0.2,
    conv4_1=0.2,
    conv5_1=0.2,
)

# Relative importance of the three terms of the total loss.
content_weight, style_weight, tv_weight = 5, 1e3, 1e-3

In [None]:
# Target image is a clone of the content image.
# Move it to the device BEFORE enabling gradients: a .to() that actually
# copies a tensor which already requires grad returns a non-leaf tensor,
# and the optimizer can only update leaf tensors.
target = content.clone().to(device).requires_grad_(True)

In [None]:
show_every = 300   # display an intermediate result every `show_every` steps
optimizer = optim.Adam([target], lr=0.03)
steps = 1000
height, width, channels = im_convert(target).shape

# Buffer for intermediate frames of the optimisation.
num_frames = 300
image_array = np.empty(shape = (num_frames, height, width, channels))

# Integer ceiling of steps / num_frames. The former float value (3.33...)
# never divides an integer step evenly, so `ii % capture_frame == 0` was
# never true and no frame was stored. Rounding up also guarantees at most
# `num_frames` captures, so the buffer cannot overflow.
capture_frame = -(-steps // num_frames)
counter = 0

In [None]:
# Optimise the target image against the content, style and total-variation terms.
for ii in range(1, steps + 1):
    target_features = get_features(target, vgg)

    # Content term: mean squared difference of the conv4_2 activations.
    content_loss = torch.mean((target_features['conv4_2'] - content_features['conv4_2']) ** 2)

    # Style term: weighted Gram-matrix differences, normalised by layer size.
    style_loss = 0
    for layer in style_weights:
        target_feature = target_features[layer]
        target_gram = gram_matrix(target_feature)
        style_gram = style_grams[layer]
        layer_style_loss = style_weights[layer] * torch.mean((target_gram - style_gram) ** 2)
        _, d, h, w = target_feature.shape
        style_loss += layer_style_loss / (d * h * w)

    # Total-variation term, computed on the tensor itself as (C, H, W) with
    # the 0.5/0.5 normalization undone. The former round trip through
    # im_convert/NumPy cut the value off from autograd, copied the image to
    # the host on every step and handed tv_loss a channels-last array.
    t_loss = tv_loss(target.squeeze(0) * 0.5 + 0.5)

    total_loss = content_weight * content_loss + style_weight * style_loss + tv_weight * t_loss
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    if ii % show_every == 0:
        print('Total loss: ', total_loss.item())
        print('Iterations: ', ii)
        plt.imshow(im_convert(target))
        plt.axis('off')
        plt.show()

    # Store an intermediate frame, never writing past the end of the buffer.
    if ii % capture_frame == 0 and counter < len(image_array):
        image_array[counter] = im_convert(target)
        counter += 1
    

In [None]:
# JPEG stores 8-bit channels, so convert the float image in [0, 1] explicitly
# instead of relying on the writer's implicit (version-dependent) conversion.
stylized_uint8 = (im_convert(target) * 255).round().astype(np.uint8)
skimage.io.imsave("sample.jpg", stylized_uint8)

## 3. Add background to the stylized image

In [None]:
## Paste the stylized pixels onto the original photo wherever an object was
## detected; the remaining pixels keep the original background.
## The result of the style transfer is read from disk here.

import matplotlib.pyplot as plt
import matplotlib.image as mpimg


style_output = skimage.io.imread(os.path.join(IMAGE_DIR, "../images/sample-3.jpg"))

image_out = image.copy()

roi = r['rois']

# True wherever at least one instance mask is non-zero.
stylize_here = np.zeros(image_out.shape[:2], dtype=bool)
for i in range(roi.shape[0]):
    stylize_here |= (r['masks'][:, :, i] != 0)

# Only touch the stylized image when something was detected.
if roi.shape[0] > 0:
    image_out[:, :, :3] = np.where(stylize_here[:, :, None],
                                   style_output[:, :, :3],
                                   image_out[:, :, :3])
imgplot = plt.imshow(image_out)