In [2]:
import os
from PIL import Image
from torchvision import transforms

def load_image(image_path, device, output_size=None):
    """Load an image file as a batched RGB tensor on the given device.

    Args:
        image_path: Path of the image file, handed to ``Image.open``.
        device: Target passed to ``Tensor.to`` for the returned tensor.
        output_size: Desired size of the result. ``None`` keeps the image's
            own size, a positive int produces a square image, and a pair
            (tuple or list) of positive ints is read as ``(height, width)``.

    Returns:
        The tensor produced by ``ToTensor`` after resizing, with a leading
        batch dimension of 1 added, moved to ``device``.

    Raises:
        ValueError: If ``output_size`` is neither ``None``, a positive int,
            nor a pair of positive ints.
    """
    def _is_dimension(value):
        # bool is a subclass of int; reject it so True/False are not sizes
        return isinstance(value, int) and not isinstance(value, bool) and value > 0

    # Validate the requested size first so bad arguments fail with a clear
    # ValueError before any file is touched.
    if output_size is None:
        output_dimensions = None
    elif _is_dimension(output_size):
        output_dimensions = (output_size, output_size)  # single value -> square image
    elif (isinstance(output_size, (tuple, list))
          and len(output_size) == 2
          and all(_is_dimension(value) for value in output_size)):
        output_dimensions = tuple(output_size)
    else:
        raise ValueError("Provide dimension value as int, or tuple of int in format (height, width).")

    # The context manager closes the underlying file; convert() returns a
    # fully loaded copy, so the image stays usable afterwards. Converting to
    # RGB guarantees three channels whatever the file's mode is.
    with Image.open(image_path) as source:
        img = source.convert("RGB")

    if output_dimensions is None:
        # PIL reports size as (width, height); Resize is given (height, width)
        output_dimensions = (img.size[1], img.size[0])

    torch_load = transforms.Compose([transforms.Resize(output_dimensions), transforms.ToTensor()])

    return torch_load(img).unsqueeze(0).to(device)


In [3]:
import os
import torch
import torch.nn as nn
import numpy as np
from torchvision.models import vgg19
from torchvision.utils import save_image


class ImageStyleTransfer_VGG19(nn.Module):
    """Truncated VGG19 feature stack that reports the outputs of selected layers."""

    def __init__(self):
        super().__init__()

        # keep only the first 29 layers (indices 0..28) of the VGG19 feature stack
        self.model = vgg19(weights='DEFAULT').features[:29]
        # layer index inside self.model -> name under which its output is returned
        self.chosen_features = {0: 'conv11', 5: 'conv21', 10: 'conv31', 19: 'conv41', 28: 'conv51'}

    def forward(self, x):
        """Run ``x`` through every layer in order and return a dict mapping
        each chosen layer's name to the tensor produced at that layer."""
        captured = {}
        activation = x
        for index, layer in enumerate(self.model):
            activation = layer(activation)
            # NOTE(review): if the layer that follows works in place, the
            # tensor stored here would reflect that later modification —
            # confirm against the torchvision VGG definition.
            if index in self.chosen_features:
                captured[self.chosen_features[index]] = activation

        return captured


def _get_content_loss(content_feature, generated_feature):
    """Compute MSE between content feature map and generated feature map as content loss."""
    return torch.mean(np.square(generated_feature - content_feature))


def _get_style_loss(style_feature, generated_feature):
    """Compute MSE between gram matrix of style feature map and of generated feature map as style loss."""
    _, channel, height, width = generated_feature.shape
    style_gram = style_feature.view(channel, height*width).mm(style_feature.view(channel, height*width).t())
    generated_gram = generated_feature.view(channel, height*width).mm(generated_feature.view(channel, height*width).t())

    return torch.mean(np.square(generated_gram - style_gram))


# names of the feature maps that may be selected for the content / style loss
_VALID_FEATURE_LAYERS = frozenset({'conv11', 'conv21', 'conv31', 'conv41', 'conv51'})


def _config_value(train_config, key, default):
    """Return ``train_config[key]``, or ``default`` when the key is missing or None."""
    value = train_config.get(key)
    return default if value is None else value


def _normalize_layer_names(layers):
    """Coerce a layer selection into a set of valid layer names, or None if it is invalid.

    Accepts a comma-separated string, a dict (its keys are used), or a
    set/frozenset/list/tuple of names.
    """
    if isinstance(layers, str):
        names = {item.strip() for item in layers.split(',')}
    elif isinstance(layers, dict):
        names = set(layers.keys())
    elif isinstance(layers, (set, frozenset, list, tuple)):
        try:
            names = set(layers)
        except TypeError:  # unhashable entries cannot be layer names
            return None
    else:
        return None

    return names if names <= _VALID_FEATURE_LAYERS else None


def train_image(content, style, generated, device, train_config, output_dir, output_img_fmt, content_img_name, style_img_name, verbose=False):
    """Update the output image using pre-trained VGG19 model.

    Args:
        content: Content image tensor fed to the model.
        style: Style image tensor fed to the model.
        generated: Tensor being optimized in place; it is the only parameter
            given to the optimizer.
        device: Device the model is moved to.
        train_config: Mapping (or None) that may hold ``num_epochs``,
            ``learning_rate``, ``alpha``, ``beta``,
            ``capture_content_features_from`` and ``capture_style_features_from``.
            Missing or None entries fall back to the defaults below.
        output_dir: Directory under which the intermediate folder is created
            when ``verbose`` is true.
        output_img_fmt: File extension used for intermediate images.
        content_img_name: Content image name used in output file names.
        style_img_name: Style image name used in output file names.
        verbose: When true, every 100 epochs the current image is saved and
            the loss is printed.

    Returns:
        1 after the optimization loop finishes, 0 when the layer selection in
        ``train_config`` is invalid (a message is printed in that case).
    """
    train_config = train_config or {}   # e.g. an empty YAML file loads as None

    # set default value for each configuration if not specified in train_config
    num_epochs = _config_value(train_config, 'num_epochs', 6000)
    lr = _config_value(train_config, 'learning_rate', 0.001)
    alpha = _config_value(train_config, 'alpha', 1)
    beta = _config_value(train_config, 'beta', 0.01)

    # validate the layer selections before the (expensive) model is built
    capture_content_features_from = _normalize_layer_names(
        _config_value(train_config, 'capture_content_features_from', _VALID_FEATURE_LAYERS))
    if capture_content_features_from is None:
        print("Invalid Capture Content Features")
        return 0

    capture_style_features_from = _normalize_layer_names(
        _config_value(train_config, 'capture_style_features_from', _VALID_FEATURE_LAYERS))
    if capture_style_features_from is None:
        print("Invalid Capture Style Features")
        return 0

    if not capture_content_features_from and not capture_style_features_from:
        # with nothing selected the loss would be the plain int 0, which has no backward()
        print("Invalid Capture Features: no layer selected for content or style loss")
        return 0

    model = ImageStyleTransfer_VGG19().to(device).eval()
    model.requires_grad_(False)    # freeze parameters in the model; only `generated` is optimized

    optimizer = torch.optim.Adam([generated], lr=lr)

    intermediate_dir = None
    if verbose:
        # create a directory to save intermediate results
        intermediate_dir = os.path.join(output_dir, f'nst-{content_img_name}-{style_img_name}-intermediate')
        os.makedirs(intermediate_dir, exist_ok=True)

    # content and style images never change, so their feature maps are
    # computed once, without recording a graph for them
    with torch.no_grad():
        content_features = model(content)
        style_features = model(style)

    for epoch in range(num_epochs):
        generated_features = model(generated)

        content_loss = style_loss = 0

        for layer_name, generated_feature in generated_features.items():
            if layer_name in capture_content_features_from:
                content_loss += _get_content_loss(content_features[layer_name], generated_feature)

            if layer_name in capture_style_features_from:
                style_loss += _get_style_loss(style_features[layer_name], generated_feature)

        # compute loss
        total_loss = alpha * content_loss + beta * style_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # print loss value and save progress every 100 epochs
        if verbose and (epoch + 1) % 100 == 0:
            save_image(generated, os.path.join(intermediate_dir, f'nst-{content_img_name}-{style_img_name}-{epoch + 1}.{output_img_fmt}'))

            print(f"\tEpoch {epoch + 1}/{num_epochs}, loss = {total_loss.item()}")

    if verbose:
        print("\t================================")
        print(f"\tIntermediate images are saved in directory: '{intermediate_dir}'")
        print("\t================================")

    return 1

In [7]:
import argparse
import os
import sys
import PIL
import yaml
import torch
import torch.nn as nn
from torchvision.utils import save_image
from PIL import Image


def _normalize_output_size(raw_size):
    """Convert the configured output size into the int / (height, width) form load_image takes.

    Accepts None, a single int, or a sequence of one or two positive ints.
    Raises ValueError for anything else.
    """
    if raw_size is None:
        return None

    if isinstance(raw_size, int):
        values = [raw_size]
    else:
        try:
            values = list(raw_size)
        except TypeError:
            values = []

    def _is_dimension(value):
        return isinstance(value, int) and not isinstance(value, bool) and value > 0

    if len(values) not in (1, 2) or not all(_is_dimension(value) for value in values):
        raise ValueError("output_image_size must be one or two positive integers.")

    return values[0] if len(values) == 1 else tuple(values)


def _can_open_image(path):
    """Return True if ``path`` opens as an image; otherwise print the error and return False."""
    try:
        # the context manager closes the file again; this is only a probe
        with Image.open(path):
            pass
    except FileNotFoundError:
        print(f"ERROR: could not find such file: '{path}'.")
        return False
    except PIL.UnidentifiedImageError:
        print(f"ERROR: could not identify image file: '{path}'.")
        return False
    return True


def image_style_transfer(config):
    """Implements neural style transfer on a content image using a style image, applying provided configuration.

    Args:
        config: Mapping of options. Keys read here: ``image_dir``,
            ``content_filename``, ``style_filename``, ``content_filepath``,
            ``style_filepath`` (``style_path`` is still accepted as a fallback),
            ``output_dir``, ``output_image_size``, ``output_image_format``,
            ``train_config_path`` and ``quiet``.

    Returns:
        None. Problems with the inputs are reported with an ``ERROR:`` message
        on stdout and end the call early; on success the final image is saved
        into the output directory.
    """
    image_dir = config.get('image_dir')
    if image_dir is not None:
        # file names default to the same values the command line uses
        content_path = os.path.join(image_dir, config.get('content_filename') or 'content.jpg')
        style_path = os.path.join(image_dir, config.get('style_filename') or 'style.jpg')
        output_dir = config.get('output_dir') if config.get('output_dir') is not None else image_dir
    else:
        output_dir = config.get('output_dir')
        content_path = config.get('content_filepath')
        # 'style_filepath' is the key the CLI produces; 'style_path' kept for older callers
        style_path = config.get('style_filepath') or config.get('style_path')

    for label, value in (('content image path', content_path),
                         ('style image path', style_path),
                         ('output directory', output_dir)):
        if not value:
            print(f"ERROR: no {label} provided.")
            return

    try:
        output_size = _normalize_output_size(config.get('output_image_size'))
    except ValueError as error:
        print(f"ERROR: invalid output_image_size: {error}")
        return

    if not _can_open_image(content_path):
        return
    if not _can_open_image(style_path):
        return

    os.makedirs(output_dir, exist_ok=True)

    # load content and style images
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    content_tensor = load_image(content_path, device, output_size=output_size)
    # the style image is resized to exactly the content tensor's height and width
    style_size = (content_tensor.shape[2], content_tensor.shape[3])
    style_tensor = load_image(style_path, device, output_size=style_size)

    # initialize output image
    generated_tensor = content_tensor.clone().requires_grad_(True)

    # load training configuration if provided
    train_config = dict()
    train_config_path = config.get('train_config_path')
    if train_config_path is not None:
        try:
            with open(train_config_path, 'r') as f:
                loaded_config = yaml.safe_load(f)
        except FileNotFoundError:
            print(f"ERROR: could not find such file: '{train_config_path}'.")
            return
        except yaml.YAMLError:
            print(f"ERROR: fail to load yaml file: '{train_config_path}'.")
            return

        if loaded_config is None:
            loaded_config = dict()   # an empty YAML document loads as None
        if not isinstance(loaded_config, dict):
            print(f"ERROR: fail to load yaml file: '{train_config_path}'.")
            return
        train_config = loaded_config

        print("Training configuration file successfully loaded.")
        print()

    print("Training Model Now: ")

    content_img_name, content_ext = os.path.splitext(os.path.basename(content_path))
    style_img_name = os.path.splitext(os.path.basename(style_path))[0]
    content_img_fmt = content_ext[1:]   # drop the leading dot

    output_img_fmt = config.get('output_image_format') or 'jpg'
    if output_img_fmt == 'same':
        # fall back to jpg when the content file has no extension
        output_img_fmt = content_img_fmt or 'jpg'

    # "quiet" switches off progress messages and intermediate images
    verbose = not config.get('quiet', False)

    # train model
    success = train_image(content_tensor, style_tensor, generated_tensor, device, train_config, output_dir, output_img_fmt, content_img_name, style_img_name, verbose=verbose)

    # save output image to specified directory
    if success:
        final_path = os.path.join(output_dir, f'nst-{content_img_name}-{style_img_name}-final.{output_img_fmt}')
        save_image(generated_tensor, final_path)
        print(f"Output image successfully generated as {final_path}.")




In [None]:
def _str_to_bool(text):
    """argparse ``type`` callable: parse common true/false spellings into a bool."""
    lowered = text.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got '{text}'")


def main(argv=None):
    """Entry point of the program.

    Args:
        argv: Optional list of command-line arguments. When None, argparse
            reads them from ``sys.argv``; passing a list lets the function be
            called directly, e.g. from a notebook cell or a test.

    Parses the arguments, checks that the file paths and output directory are
    given whenever ``--image_dir`` is not, and hands the resulting options to
    ``image_style_transfer`` as a dict.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--image_dir", type=str, help="Path to the directory where content image and style image are stored.")
    parser.add_argument("--content_filename", type=str, default="content.jpg", help="File name of the content image in image_dir. Will use \"content.jpg\" if not provided.")
    parser.add_argument("--style_filename", type=str, default="style.jpg", help="File name of the style image in image_dir. Will use \"style.jpg\" if not provided.")
    parser.add_argument("--content_filepath", type=str, help="Path to the content image if image_dir not provided.")
    parser.add_argument("--style_filepath", type=str, help="Path to the style image if image_dir not provided.")
    parser.add_argument("--output_dir", type=str, help="Directory that stores the output image. Will be the same as image_dir if not provided while image_dir provided.")
    parser.add_argument("--output_image_size", nargs="+", type=int, help="Size of the output image. Either one integer or two integers separated by space is accepted. Will use the dimensions of content image if not provided.")
    parser.add_argument("--output_image_format", choices=["jpg", "png", "jpeg", "same"], default="jpg", help="Format of the output image. Can be either \"jpg\", \"png\", \"jpeg\", or \"same\". If \"same\", output image will have the same format as the content image. \"jpg\" will be the default format.")
    parser.add_argument("--train_config_path", type=str, help="Path to training configuration file in .yaml format. May include: num_epochs, learning_rate, alpha, beta, capture_content_features_from, capture_style_features_from.")
    # type=bool would turn any non-empty string (including "False") into True,
    # so the value is parsed explicitly; a bare --quiet means True.
    parser.add_argument("--quiet", type=_str_to_bool, nargs="?", const=True, default=False, help="True stops showing debugging messages, loss function values during training process, and stops generating intermediate images.")

    args = parser.parse_args(argv)

    # Decide what is required from the parsed values rather than from the raw
    # argv text, so "--image_dir=path" and abbreviated flags are handled too.
    if args.image_dir is None:
        missing = [flag for flag, value in (("--content_filepath", args.content_filepath),
                                            ("--style_filepath", args.style_filepath),
                                            ("--output_dir", args.output_dir))
                   if value is None]
        if missing:
            parser.error("the following arguments are required when --image_dir is not provided: " + ", ".join(missing))

    config = dict(vars(args))

    image_style_transfer(config)


# Run the command-line entry point only when this code is executed as the
# main module, not when it is imported.
# NOTE(review): inside a notebook kernel __name__ is presumably '__main__' as
# well, so this cell would call main() with the kernel's own argv — confirm
# that is intended.
if __name__ == '__main__':
    main()
    