In [30]:
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.utils import save_image
import matplotlib.pyplot as plt

## Hyperparameters

In [31]:
total_steps = 6000
learning_rate = 0.001
alpha = 1 # content
beta = 0.3 # style

device = torch.device('cpu')
image_size = 328
LOAD_DATA = True
LOAD_MODEL_FILE = "overfit.pth.tar"

## Transforms

In [32]:
loader = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    # transforms.Normalize(mean=[], std=[]) // we don't do it here
])

### VGG Model - Transfer Learning

Remember: We're taking a number of "intermediate" output from the VGG19 models: Take 0, 5, 10, 19, 28 layers

In [33]:
class VGG(nn.Module):
    def __init__(self):
        super(VGG, self).__init__()
        self.chosen_features = ['0','5','10','19','28']
        self.model = models.vgg19(pretrained=True).features[:29] # we don't need 29~ layers
    
    def forward(self, x):
        features = []
        
        for layer_num, layer in enumerate(self.model):
            x = layer(x)
            if str(layer_num) in self.chosen_features:
                features.append(x)
                
        return features

In [34]:
def load_image(image_name):
    image = Image.open(image_name) # PIL library
    image = loader(image).unsqueeze(0) # we need additional dimension
    return image

## Content/Style/Generated Images

### Content & Style Images

In [35]:
content_img = load_image("mom1.jpeg")
style_img = load_image('filter2.jpeg')

In [36]:
content_img.shape

torch.Size([1, 3, 328, 328])

### Generated Images - we perform gradient descent on this image

In [37]:
generated = torch.randn(content_img.shape, device=device, requires_grad = True)

if LOAD_DATA == True:
    generated = load_image('generated.png')
else:
    generated = content_img.clone().requires_grad_(True)

### Model

In [38]:
model = VGG().to(device).eval() #VGG().to(device=device).eval() - freeze parameters

### Load & Save checkpoint

In [39]:
def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    torch.save(state, filename)


def load_checkpoint(checkpoint, optimizer):
    print("=> Loading checkpoint")
    optimizer.load_state_dict(checkpoint["optimizer"])

### Optimizer

In [40]:
optimizer = optim.Adam([generated], lr=learning_rate)

In [41]:
# if LOAD_DATA:
#     load_checkpoint(torch.load(LOAD_MODEL_FILE), optimizer)

In [42]:
for step in range(total_steps):
    generated_features = model(generated)
    content_features = model(content_img)
    style_features = model(style_img)
    
    style_loss = content_loss = 0
    
    for generated_F, content_F, style_F in zip(
        generated_features, content_features, style_features
    ):
        batch_size, channel, height, width = generated_F.shape # batch_size = 1
        
        # Content Loss
        content_loss += torch.mean((generated_F - content_F) ** 2) # Here, we have 2 so torch.mean() is just divide / 2
        
        # Style Loss
        
        # Gram Matrix - (nc x nc)
        G = generated_F.view(channel, height * width).mm(
            generated_F.view(channel, height * width).t()
        )
        
        S = style_F.view(channel, height * width).mm(
            style_F.view(channel, height * width).t()
        )
        
        style_loss += torch.mean((G - S)**2) # torch.mean <=> divide / 2
    
    # Total Loss
    total_loss = (alpha * content_loss) + (beta * style_loss)

    # Backprop
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
        
    if step % 10 == 0:
        print(f"--------Step {step}--------")
        print(f"Total Loss: {total_loss.item()}")
        
        # Save Checkpoint
        checkpoint = {
               "optimizer": optimizer.state_dict(),
        }
        save_checkpoint(checkpoint, filename=LOAD_MODEL_FILE)
        
        save_image(generated, 'generated.png')
        print("---------Image Saved---------\n")

--------Step 0--------
Total Loss: 7881.265625
=> Saving checkpoint
---------Image Saved---------

--------Step 10--------
Total Loss: 7881.265625
=> Saving checkpoint
---------Image Saved---------



KeyboardInterrupt: 