In [9]:
%matplotlib notebook
import os
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
import torch.utils.data as td
import torchvision as tv
import pandas as pd
from PIL import Image
from matplotlib import pyplot as plt
import nntools as nt

In [10]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


In [22]:
class NNRegressor(nt.NeuralNetwork):
    def __init__(self):
        super(NNRegressor, self).__init__()
        self.MSE = nn.MSELoss(size_average=None, reduce=None, reduction='sum')
    def criterion(self, y, d):
        return self.MSE(y, d)

In [23]:
vgg = tv.models.vgg16(pretrained=True)

In [24]:
class ImageTransform(NNClassifier):
    def __init__(self):
        super(ImageTransform, self).__init__()
    
        self.pad = torch.nn.ReflectionPad2d(40)
        self.conv1 = torch.nn.Conv2d(3,32,9)
        self.bn1 = torch.nn.BatchNorm2d(32)
        self.conv2 = torch.nn.Conv2d(32,64,3,stride=2)
        self.bn2 = torch.nn.BatchNorm2d(64)
        self.conv3 = torch.nn.Conv2d(64,128,3,stride=2)
        self.bn3 = torch.nn.BatchNorm2d(128)
        self.res1 = nn.Sequential(
            torch.nn.Conv2d(128,128,3),
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),
            torch.nn.Conv2d(128,128,3),
            torch.nn.BatchNorm2d(128)
        )
        self.res2 = nn.Sequential(
            torch.nn.Conv2d(128,128,3),
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),
            torch.nn.Conv2d(128,128,3),
            torch.nn.BatchNorm2d(128)
        )
        self.res3 = nn.Sequential(
            torch.nn.Conv2d(128,128,3),
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),
            torch.nn.Conv2d(128,128,3),
            torch.nn.BatchNorm2d(128)
        )
        self.res4 = nn.Sequential(
            torch.nn.Conv2d(128,128,3),
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),
            torch.nn.Conv2d(128,128,3),
            torch.nn.BatchNorm2d(128)
        )
        self.conv4 = torch.nn.Conv2d(128,64,3,stride=0.5)
        self.bn4 = torch.nn.BatchNorm2d(64)
        self.conv5 = torch.nn.Conv2d(64,32,3,stride=0.5)
        self.bn5 = torch.nn.BatchNorm2d(32)
        self.conv6 = torch.nn.Conv2d(32,3,3)
        self.tanh = torch.nn.Tanh()
        
    def forward(self, x):
        h = self.pad(x)
        h = F.relu(self.bn1(self.conv1(h)))
        h = F.relu(self.bn2(self.conv2(h)))
        h = F.relu(self.bn3(self.conv3(h)))
        h = F.relu(h + self.res1(h))
        h = F.relu(h + self.res2(h))
        h = F.relu(h + self.res3(h))
        h = F.relu(h + self.res4(h))
        h = F.relu(self.bn4(self.conv4(h)))
        h = F.relu(self.bn5(self.conv5(h)))
        h = self.conv6(h)
        y = (self.tanh(h) + 1) * 255 / 2 # scaled tanh from paper
        return y
i = ImageTransform()

In [25]:
print(i)

ImageTransform(
  (MSE): MSELoss()
  (pad): ReflectionPad2d((40, 40, 40, 40))
  (conv1): Conv2d(3, 32, kernel_size=(9, 9), stride=(1, 1))
  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2))
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2))
  (bn3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (res1): Sequential(
    (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (res2): Sequential(
    (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentu