# **0. Installations and Imports**

In [None]:
# Mount Google Drive into the Colab filesystem at /content/drive (Colab-only).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!mkdir data

In [None]:
# Bind-mount the Drive "data" folder onto /content/data, then make that
# folder the working directory.
!mount --bind /content/drive/My\ Drive/data /content/data/
%cd /content/data

In [None]:
# Switch to the system site-packages directory; the next cell's `git clone`
# runs from here.
%cd /usr/local/lib/python3.7/dist-packages/

/usr/local/lib/python3.7/dist-packages


In [None]:
!pip install pyyaml==5.1
!pip uninstall fastai
!git clone https://github.com/fastai/fastai
!pip install -e "fastai[dev]"

Collecting pyyaml==5.1
[?25l  Downloading https://files.pythonhosted.org/packages/9f/2c/9417b5c774792634834e730932745bc09a7d36754ca00acf1ccd1ac2594d/PyYAML-5.1.tar.gz (274kB)
[K     |█▏                              | 10kB 19.0MB/s eta 0:00:01[K     |██▍                             | 20kB 10.8MB/s eta 0:00:01[K     |███▋                            | 30kB 9.1MB/s eta 0:00:01[K     |████▉                           | 40kB 8.1MB/s eta 0:00:01[K     |██████                          | 51kB 5.4MB/s eta 0:00:01[K     |███████▏                        | 61kB 5.9MB/s eta 0:00:01[K     |████████▍                       | 71kB 5.9MB/s eta 0:00:01[K     |█████████▋                      | 81kB 6.6MB/s eta 0:00:01[K     |██████████▊                     | 92kB 6.9MB/s eta 0:00:01[K     |████████████                    | 102kB 6.8MB/s eta 0:00:01[K     |█████████████▏                  | 112kB 6.8MB/s eta 0:00:01[K     |██████████████▍                 | 122kB 6.8MB/s eta 0:00:01[

# **1. Style Transfer**

In [None]:
import torch
import cv2
import os
from PIL import Image, ImageFile
import numpy as np
import torch.nn as nn
from torch import Tensor
from torchvision.models import vgg19_bn, resnet18
from torchvision.models.resnet import BasicBlock
from torchvision.transforms import ToTensor, Normalize, Compose, Resize as TResize
from torch.nn import functional as F
from torch.nn import AvgPool2d, Conv2d, Module, Sigmoid
from torchvision.utils import save_image
from fastai.data.block import DataBlock
from fastai.callback.schedule import fine_tune
from fastai.vision.learner import unet_learner, create_body
from fastai.vision.data import ImageBlock
from fastai.vision.augment import Resize, RandomCrop, setup_aug_tfms
from fastai.vision.models.unet import DynamicUnet
from fastai.data.transforms import get_image_files, RandomSplitter, parent_label
from fastai.data.block import RegressionBlock
from fastai.torch_core import flatten_check, TensorImage, Module
from fastai.metrics import mse, accuracy
from fastai.layers import PixelShuffle_ICNR, ConvLayer, SigmoidRange
from fastai.learner import Learner
from sys import getsizeof
from inspect import getsource
from pathlib import Path
from types import MethodType
from collections.abc import Iterable
from matplotlib import pyplot as plt
from math import ceil
from enum import Enum

In [None]:
# Return to the data directory so the relative file names used below resolve there.
%cd /content/data

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# VGG-19 (batch-norm variant) with pretrained weights; progress=True shows a
# progress bar while the weights download.
vgg_net = vgg19_bn(pretrained=True, progress=True)

In [None]:
# Keep only VGG's feature stack and place it on the target device.
# nn.Module.to() moves the module in place and returns that same module, so
# one assignment does both jobs and the cell echoes nothing.
style_net = vgg_net.features.to(device)

In [None]:
def create_gram_array(self, layer):
    """Return the pairwise channel inner products of a feature map.

    For every unordered pair of distinct channels (i, j) with i < j, the two
    channel maps are multiplied element-wise and summed over the last two
    axes.

    Args:
        self: The module this function is bound to; not used.
        layer: 4-D tensor whose axis 1 indexes channels and whose axes 2 and 3
            are summed over.

    Returns:
        Tensor of shape (batch, C * (C - 1) / 2), where C is the number of
        channels, with pairs ordered as produced by ``torch.combinations``.
    """
    num_channels = layer.size(1)
    # Build the index pairs as int64 directly on the layer's own device.
    # torch.range is deprecated and yields floats; torch.arange gives the same
    # 0..C-1 sequence without the casts and without relying on a global device.
    channel_ids = torch.arange(num_channels, device=layer.device)
    first, second = torch.combinations(channel_ids).unbind(dim=1)
    filter_i = torch.index_select(layer, 1, first)
    filter_j = torch.index_select(layer, 1, second)
    return torch.sum(filter_i * filter_j, (2, 3))

# Attach the Gram helper to the network as a bound method.
style_net.create_gram_array = MethodType(create_gram_array, style_net)
# Output selector read by the patched forward pass: 0 returns the concatenated
# Gram arrays, any other value returns the flattened layer-28 activations.
style_net.switch = 0

# Every input is resized to 512x512 before it enters the network.
transf = TResize((512, 512))


def forward(self, input):
    """Run the feature stack and return style or structure features.

    The mode is chosen by ``self.switch``:

    * ``0``: Gram arrays are taken after the modules at indices 1, 8, 15, 28
      and 41 and concatenated with ``torch.hstack``.
    * anything else: the output of module 28, flattened from dim 1 onwards.

    Args:
        input: Image batch; it is resized to 512x512 first.

    Returns:
        The tensor for the selected mode, or ``None`` if none of the selected
        module indices exists in the network.
    """
    style_mode = self.switch == 0
    layer_nums = {1, 8, 15, 28, 41} if style_mode else {28}

    input = transf(input)
    # Move the input once, up front, when it lives on a different kind of
    # device than the network. `device.type` is "cpu" or "cuda" and never
    # carries an index, so comparing it with "cuda:0" could never match.
    if input.device.type != device.type:
        input = input.float().to(device)

    output = None
    for l, module in enumerate(self):
        input = module(input)
        if l not in layer_nums:
            continue
        if style_mode:
            gram_array = self.create_gram_array(input)
            output = gram_array if output is None else torch.hstack((output, gram_array))
        else:
            output = torch.flatten(input, 1)
    return output


style_net.forward = MethodType(forward, style_net)

In [None]:
# Folder of training images, relative to the current working directory.
# NOTE(review): `dir` shadows the Python builtin of the same name; later cells
# read this variable, so any rename has to be applied everywhere at once.
dir = 'data/images'
path = Path(dir)
# Batch size for the dataloaders; the style image is tiled to this many copies.
BATCH_SIZE = 1

## Dataset preparation: Only needed once per dataset

Generate the Gram arrays for the style image and save them to a tensor file, so that they do not have to be recomputed on every pass through the network.

In [None]:
# This is the point to swap in a different style image
# Currently using Starry Night by Van Gogh jpeg
# Force three RGB channels (grayscale, CMYK or RGBA files would otherwise break
# the permute / convolution below) and close the file once the pixels are read.
with Image.open("starry.jpg") as style_file:
    img = style_file.convert("RGB").resize((512, 512))
arr = np.asarray(img) / 255  # scale 8-bit pixel values to [0, 1]
# HWC -> CHW, then add a leading batch axis.
style_tensor = TensorImage(arr).permute(2, 0, 1).unsqueeze(0)
style_batch = torch.cat((style_tensor,)*BATCH_SIZE, dim=0).float()#.to(device)
style_net.switch = 0  # Gram-array (style) output
with torch.no_grad():
    with torch.cuda.amp.autocast():
        targ_style = style_net(style_batch)
        torch.save(targ_style, "starry_night.pt")

  This is separate from the ipykernel package so we can avoid doing imports until


For each image in the training set, compute the structure features (the style network's flattened layer-28 activations) and save them to a NumPy file

In [None]:
# Make sure the output folder exists (no-op when it is already there);
# np.save does not create missing directories.
os.makedirs(dir + "_tensors", exist_ok=True)
style_net.switch = 1  # flattened layer-28 activations (structure features)
with torch.no_grad():
    with torch.cuda.amp.autocast():
        for filename in os.listdir(dir):
            # Dropping the last four characters assumes a three-letter
            # extension such as ".jpg"; the label getter that reads these
            # files builds the name the same way.
            dest = dir + "_tensors/" + filename[:-4] + ".npy"
            # Force three RGB channels and release the file handle promptly.
            with Image.open(dir + "/" + filename) as image_file:
                img = image_file.convert("RGB")
            arr = np.asarray(img) / 255
            tensor = TensorImage(arr).permute(2, 0, 1).unsqueeze(0)
            output = style_net(tensor).cpu().numpy()
            np.save(dest, output)

### Autoencoder Label Getter

In [None]:
# Mapping from filename to pixelmaps for autoencoder priming
# NOTE: a later cell defines another `get_labels`; whichever of the two cells
# was run last is the one the DataBlock picks up.
def get_labels(o):
    """Load an image file as a float tensor in channel-first layout.

    Args:
        o: Path (or path string) of the image file.

    Returns:
        float32 tensor of shape (3, height, width) holding the RGB pixel
        values divided by 255.

    Raises:
        FileNotFoundError: If OpenCV cannot read the file.
    """
    img = cv2.imread(str(o))
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {o}")
    arr = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return torch.from_numpy(arr / 255).float().permute(2, 0, 1)

### Main Model Label Getter

In [None]:
# Mapping from an image file to its precomputed .npy array for full training
def get_labels(o):
    """Load the precomputed array that belongs to an image file.

    The array is read from ``<parent folder>_tensors/<file name without its
    last four characters>.npy``.

    Args:
        o: ``pathlib.Path`` of the image file.

    Returns:
        The stored array as a tensor, with axis 0 removed when its length is 1.
    """
    stem = o.name[:-4]
    npy_file = f"{o.parent}_tensors/{stem}.npy"
    return torch.from_numpy(np.load(npy_file)).squeeze(0)

### Dataloaders

In [None]:
# Inputs are images; targets are whatever the currently defined `get_labels`
# returns for each image path.
pics = DataBlock(
    blocks=(ImageBlock, RegressionBlock), 
    get_items=get_image_files, 
    splitter=RandomSplitter(valid_pct=0.2, seed=42),  # fixed seed -> repeatable 80/20 split
    get_y=get_labels)

dls = pics.dataloaders(path, bs=BATCH_SIZE)

### Loss Functions

In [None]:
# Reload the style target that the dataset-preparation cell wrote to disk.
# NOTE(review): torch.load is called without map_location, so the tensor is
# restored to the device it was saved from — confirm that device is available.
targ_style = torch.load("starry_night.pt")

# Loss function for full network
def style_and_structure_mse(inp, targ_struct):
    """Combined structure + style loss for the full network.

    Args:
        inp: Pair ``(structure, style)``. ``structure`` is compared with
            ``targ_struct``; ``style`` holds five Gram arrays concatenated
            along dim 1.
        targ_struct: Target structure features.

    Returns:
        Scalar tensor: ``10 * MSE(structure, targ_struct)`` plus the weighted
        sum of the five per-segment style MSEs against the global
        ``targ_style``.
    """
    pred_struct, pred_style = inp[0], inp[1]

    # Width of each Gram segment along dim 1, paired with the weight of its
    # MSE term. The style MSEs are much larger than the structural MSE, hence
    # the small weights.
    segment_sizes = [2016, 8128, 32640, 130816, 130816]
    segment_weights = [2e-7, 2e-6, 2e-5, 2e-4, 2e-3]

    pred_parts = torch.split(pred_style, segment_sizes, dim=1)
    targ_parts = torch.split(targ_style, segment_sizes, dim=1)

    # Accumulate the weighted terms left to right.
    style_loss = None
    for pred, targ, weight in zip(pred_parts, targ_parts, segment_weights):
        term = F.mse_loss(pred, targ) * weight
        style_loss = term if style_loss is None else style_loss + term

    loss = F.mse_loss(pred_struct, targ_struct) * 1e1
    loss += style_loss
    return loss

# Loss function for autoencoder priming
def autoencoder_mse(inp, targ):
    """MSE between the prediction and the target resized to the prediction's size.

    Args:
        inp: Predicted tensor; its last two dimensions give the resize size.
        targ: Target tensor, resized before the comparison.

    Returns:
        Scalar MSE tensor.
    """
    height, width = inp.size()[-2], inp.size()[-1]
    resized_targ = TResize((height, width))(targ)
    return F.mse_loss(inp, resized_targ)

### Model Setup

### Training the UNet

In [None]:
# Build a U-Net learner on a ResNet-18 architecture with 3 output channels and
# y_range [0, 1]. The commented-out line is the same learner with the
# autoencoder-priming loss instead of the full style/structure loss.
#learn = unet_learner(dls, arch=resnet18, n_out=3, y_range=[0, 1], loss_func=autoencoder_mse)
learn = unet_learner(dls, arch=resnet18, n_out=3, y_range=[0, 1], loss_func=style_and_structure_mse)

# Remove the residual connection on the highest layer of the UNet
def forward(self, up_in):
    """Replacement forward for one U-Net block.

    Applies ``self.shuf`` to ``up_in``, then ``self.relu``, ``self.conv1`` and
    ``self.conv2`` in that order. Only ``up_in`` is read; no second
    (skip-connection) tensor takes part.
    """
    up_out = self.shuf(up_in)
    return self.conv2(self.conv1(self.relu(up_out)))

# Patch block 7 with the forward above and replace the first layer of its
# conv1 with a 3x3 convolution taking 128 input channels and producing 96.
learn.model[7].forward = MethodType(forward, learn.model[7])
learn.model[7].conv1[0] = Conv2d(128, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

Only run the following cell when training the full model, not when priming the autoencoder 

In [None]:
# Combine the U-Net and the style network into one Sequential model.
# NOTE: this wraps whatever learn.model currently is, so running the cell a
# second time nests the already-combined model again.
learn.model = torch.nn.Sequential(learn.model, style_net)

### Loading Weights

Load the vanilla autoencoder weights if the initial priming is finished but full-model training has not yet started

In [None]:
# map_location remaps the stored tensors onto the current device, so a
# checkpoint saved on a GPU runtime also loads on a CPU-only one.
# torch.load unpickles the file: only load checkpoints from a trusted source.
checkpoint = torch.load("snipped_unet.pth", map_location=device)
# The autoencoder checkpoint holds the U-Net alone, i.e. element 0 of the
# combined model.
learn.model[0].load_state_dict(checkpoint["model_state_dict"])
pass  # keeps the cell from echoing load_state_dict's return value

Load the model weights if resuming full training

In [None]:
# map_location remaps the stored tensors onto the current device, so a
# checkpoint saved on a GPU runtime also loads on a CPU-only one.
# torch.load unpickles the file: only load checkpoints from a trusted source.
checkpoint = torch.load("snipped_unet_starry.pth", map_location=device)
learn.model.load_state_dict(checkpoint["model_state_dict"])
learn.model.to(device)
pass  # keeps the cell from echoing the module repr

### Forward Method and Optimizer & Training Steps



In [None]:
# Modified version
def forward(self, x):
    """Forward pass of the combined model, returning ``(structure, style)``.

    Iterates over the sub-modules of ``self``. A ``DynamicUnet`` sub-module is
    simply applied to the running result. Any other sub-module is called
    twice on the same input: first with ``switch = 0`` (its output becomes
    ``style``) and then with ``switch = 1`` (its output becomes
    ``structure``).

    NOTE: ``structure`` and ``style`` are only assigned inside the
    non-``DynamicUnet`` branch, so the final ``return`` fails if ``self``
    contains no such sub-module.
    """
    # Release cached, unused GPU memory before every forward pass.
    torch.cuda.empty_cache()
    res = x
    #noise_tensor = TensorImage(torch.randn(*x.size())).float().to(device)
    #res = noise_batch
    #for l in self.layers:
    for l in self: 
        #print(f"Bytes Allocated {torch.cuda.memory_allocated()}\n")
        #print(torch.cuda.memory_summary(device=device))
        #print(torch.cuda.list_gpu_processes(device=device))
        # Expose the original input on the tensor handed to the next module.
        res.orig = x
        #res.orig = noise_tensor
        if not isinstance(l, DynamicUnet):
            # First call of this module: switch = 0. Keep its input so the
            # module can be called again on the same tensor further down.
            l.switch = 0 
            self.stored = res

        nres = l(res)
        # We have to remove res.orig to avoid hanging refs and therefore memory leaks
        res.orig, nres.orig = None, None
        res = nres

        if not isinstance(l, DynamicUnet):
            style = res
            # Second call on the stored input, now with switch = 1.
            l.switch = 1
            structure = l(self.stored)
            del self.stored

    #del noise_tensor
    return structure, style

learn.model.forward = MethodType(forward, learn.model)

The following two cells enable mixed-precision training: tensor values are converted to 16-bit half-precision types to reduce memory use and keep the GPU from running out of memory

In [None]:
# Gradient scaler shared by the patched batch routine.
scaler = torch.cuda.amp.GradScaler()

class OptState(Enum):
    """Stage markers compared against and written to the per-optimizer state in step()."""
    READY = 0
    UNSCALED = 1
    STEPPED = 2

def step(self, model):
    """Replacement for the scaler's ``step`` that steps through ``model``.

    Args:
        self: The gradient scaler this function is bound to.
        model: Object exposing ``opt`` and ``_with_events``; the patched batch
            routine in this notebook passes the Learner itself.

    The optimizer step always goes through
    ``model._with_events(model.opt.step, 'step', CancelStepException)``.
    ``CancelStepException`` is a global looked up at call time.

    NOTE(review): the stage checks use identity (``is``) against the Enum
    defined in this cell. If the scaler stores stages using a different enum
    class, these comparisons never match — confirm against the installed
    torch version.
    """

    # Scaling disabled: plain optimizer step.
    if (not self._enabled):
        model._with_events(model.opt.step, 'step', CancelStepException)
        return

    #if "closure" in kwargs:
    #    raise RuntimeError("Closure use is not currently supported if GradScaler is enabled.")

    self._check_scale_growth_tracker("step")

    # Per-optimizer bookkeeping is keyed by the id of the optimizer object.
    optimizer_state = self._per_optimizer_states[id(model.opt)]

    if optimizer_state["stage"] is OptState.STEPPED:
        raise RuntimeError("step() has already been called since the last update().")

    # Never read afterwards in this function.
    retval = None

    if (hasattr(model.opt, "_step_supports_amp_scaling") and model.opt._step_supports_amp_scaling):
        # it can query its own state, invoke unscale_ on itself, etc
        model._with_events(model.opt.step, 'step', CancelStepException)
        optimizer_state["stage"] = OptState.STEPPED
        return

    if optimizer_state["stage"] is OptState.READY:
        self.unscale_(model.opt)

    #assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer."

    # Step only when the recorded found-inf values sum to zero. Note that this
    # path returns without marking the stage as STEPPED.
    if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()):
        model._with_events(model.opt.step, 'step', CancelStepException)
        return

    # Reached only when the step above was skipped.
    optimizer_state["stage"] = OptState.STEPPED

scaler.step = MethodType(step, scaler)

In [None]:
# The third argument of Learner._with_events is used as the target of an
# `except` clause, so it has to be a real exception class. With a placeholder
# such as "" any error raised during the optimizer step would be replaced by
# "TypeError: catching classes that do not inherit from BaseException".
try:
    from fastai.callback.core import CancelStepException
except ImportError:  # fastai versions that do not ship this exception
    class CancelStepException(Exception):
        """Raised to skip the optimizer step of the current batch."""

def _do_one_batch(self):
    """Process one batch with mixed precision, via the global ``scaler``.

    The forward pass, the loss and the 'after_pred', 'after_loss' and
    'before_backward' events run inside ``torch.cuda.amp.autocast``. When the
    learner is not training, or the batch has no targets, the method returns
    right after 'after_loss'. Otherwise the loss is scaled and
    back-propagated, the gradients are unscaled, and the optimizer is stepped
    through ``scaler.step``.
    """
    with torch.cuda.amp.autocast():
        self.pred = self.model(*self.xb)
        self('after_pred')
        if len(self.yb):
            self.loss_grad = self.loss_func(self.pred, *self.yb)
            self.loss = self.loss_grad.clone()
        self('after_loss')
        if not self.training or not len(self.yb): return
        self('before_backward')
    # Backward runs outside autocast, on the scaled loss.
    scaler.scale(self.loss_grad).backward()
    scaler.unscale_(self.opt)
    # The patched scaler.step takes the learner itself, not the optimizer.
    scaler.step(self)
    scaler.update()
    self.opt.zero_grad()
    # Drop the references to the prediction and the inputs.
    del self.pred, self.xb

learn._do_one_batch = MethodType(_do_one_batch, learn)

### Training

In [None]:
# Use for either autoencoder priming or full model training
# changing the number of epochs as appropriate
num_epochs = 120
# NOTE(review): fastai's fine_tune by default runs a frozen phase before the
# `num_epochs` unfrozen epochs — confirm against the fastai version in use.
learn.fine_tune(num_epochs)

Save the autoencoder

In [None]:
# Checkpoint = model weights + optimizer state, written to one file.
torch.save(
    dict(
        model_state_dict=learn.model.state_dict(),
        optimizer_state_dict=learn.opt.state_dict(),
    ),
    "snipped_unet.pth",
)

Save the full model

In [None]:
# Checkpoint = model weights + optimizer state, written to one file.
torch.save(
    dict(
        model_state_dict=learn.model.state_dict(),
        optimizer_state_dict=learn.opt.state_dict(),
    ),
    "snipped_unet_starry.pth",
)

### Testing

In [None]:
# Grab one training batch (inputs x, targets y) for the checks below.
x, y = dls.train.one_batch()

In [None]:
# Save a 512x512 copy of the style image.
img = Image.open("starry.tif").resize((512, 512))
arr = np.asarray(img) / 255
# HWC -> CHW, the layout save_image expects.
tensor = TensorImage(arr).permute(2, 0, 1)
save_image(tensor, "starry_resized.jpg")

In [None]:
# Write the input batch to disk as an image.
save_image(x, "orig8.jpg")

In [None]:
#learn.model.eval()
#save_image(x.squeeze(0), "prestarrified.jpg")
#save_image(y[0].squeeze(0), "target.jpg")
# Inference only: no_grad skips building the autograd graph, which saves GPU
# memory. learn.model[0] is the first element of the combined model (the U-Net).
with torch.no_grad(), torch.cuda.amp.autocast():
    res = learn.model[0](x)
    save_image(res[0].squeeze(0), "snipped_unet8.jpg")