In [2]:
from google.colab import drive

# Mount Google Drive; the data and output paths used in this notebook live under this mount point.
_MOUNT_POINT = '/content/drive'
drive.mount(_MOUNT_POINT)

Mounted at /content/drive


In [3]:
import os

from tqdm import tqdm
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import vgg19
import torch.optim as optim
from torchvision import transforms

In [5]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [4]:
# Experiment configuration.
#   data:    locations of the content and style images on the mounted Drive.
#   trainer: optimizer choice ("LBGFS" -> torch.optim.LBFGS, or "Adam"), learning
#            rate, number of epochs, log/save interval in epochs (step_size), and
#            the content (alpha) and style (beta) weights of the total loss.
cfg = dict(
    data=dict(
        content_image_path="/content/drive/MyDrive/Colab Notebooks/data/content.jpg",
        style_image_path="/content/drive/MyDrive/Colab Notebooks/data/style.jpg",
    ),
    trainer=dict(
        optimizer="LBGFS",  # value the optimizer cell matches on; alternative: "Adam"
        lr=0.1,
        epochs=100,  # 1000 for a longer run
        step_size=10,
        alpha=1.0,
        beta=1e6,
    ),
)

In [6]:
def preprocess_image(image: Image.Image, size=(512, 512)):
    """Convert a PIL image into a normalized, batched tensor for the VGG19 extractor.

    The image is first converted to RGB. Grayscale, RGBA, or palette images would
    otherwise yield 1 or 4 channels, which cannot be normalized with the 3-entry
    mean/std below. It is then resized, scaled to [0, 1] by ``ToTensor``, and
    normalized with the ImageNet mean/std.

    Args:
        image: Input PIL image of any mode.
        size: Target size passed to ``transforms.Resize``; a ``(height, width)``
            tuple. Defaults to ``(512, 512)``.

    Returns:
        Float tensor with a leading batch dimension of 1 and 3 channels.
    """
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # (c, h, w) -> (1, c, h, w): the model is always fed a batch of one image.
    return transform(image.convert("RGB")).unsqueeze(0)

def postprocess_image(tensor: torch.Tensor):
    """Convert a normalized image tensor back into an 8-bit PIL image.

    This inverts ``preprocess_image``: it undoes the mean/std normalization,
    clips to [0, 1], and scales to the 0-255 range.

    Args:
        tensor: Image tensor of shape (1, c, h, w) or (c, h, w), normalized with
            the same mean/std as ``preprocess_image``. It may live on any device
            and may require grad.

    Returns:
        ``PIL.Image.Image`` built from the (h, w, c) uint8 array.
    """
    # Detach before the device copy so the copy is not tracked by autograd.
    # Squeeze only the batch axis, so a 1-pixel-high or 1-pixel-wide image keeps
    # its spatial axes (numpy's bare squeeze() would drop them too).
    image = tensor.detach().cpu().squeeze(0).numpy()  # (c, h, w)
    image = image.transpose(1, 2, 0)  # (h, w, c)
    image = image * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]  # denormalization
    # Clamp pixel values to [0, 1], then multiply by 255 to map to the 0-255 range.
    image = image.clip(0, 1) * 255
    # Round rather than truncate. Float error in the normalize/denormalize round trip
    # (e.g. 254.9999) would otherwise shift pixels down by one intensity level.
    image = np.rint(image).astype(np.uint8)

    return Image.fromarray(image)

In [7]:
class DataModule:
    """Loads the content and style images and converts tensors back to images.

    Args:
        content_image_path: Path of the content image file.
        style_image_path: Path of the style image file.
    """

    def __init__(self, content_image_path, style_image_path):
        self.content_image_path = content_image_path
        self.style_image_path = style_image_path

    def _load_preprocessed_image(self, image_path):
        """Open ``image_path`` and return it as a preprocessed (batched) tensor."""
        # The context manager closes the file handle deterministically.
        # convert("RGB") decodes into a new in-memory image, so nothing reads from
        # the file after it is closed. It also normalizes grayscale/RGBA/palette
        # files to the 3 channels the normalization step expects.
        with Image.open(image_path) as image:
            rgb_image = image.convert("RGB")

        return preprocess_image(rgb_image)

    def get_image_tensors(self):
        """Return ``(content_image_tensor, style_image_tensor)``, both preprocessed."""
        content_image_tensor = self._load_preprocessed_image(self.content_image_path)
        style_image_tensor = self._load_preprocessed_image(self.style_image_path)

        return content_image_tensor, style_image_tensor

    def get_noise_image_tensor_to_image(self, noise_image_tensor):
        """Convert the optimized image tensor back into a PIL image."""
        return postprocess_image(noise_image_tensor)

In [8]:
# Load Data
# Read both images from the configured paths, preprocess them into batched
# tensors, and move them to the compute device.
datamodule = DataModule(
    content_image_path=cfg["data"]["content_image_path"],
    style_image_path=cfg["data"]["style_image_path"],
)
content_image_tensor, style_image_tensor = (
    image_tensor.to(device) for image_tensor in datamodule.get_image_tensors()
)

In [9]:
# Indices into torchvision's ``vgg19().features`` (an nn.Sequential) of the
# layers whose activations feed the style and content losses.
conv = {
    "conv1_1" : 0,
    "conv2_1" : 5,
    "conv3_1" : 10,
    "conv4_1" : 19,
    "conv5_1" : 28,
    "conv4_2" : 21
}

class StyleTransfer(nn.Module):
    """Frozen VGG19 feature extractor for neural style transfer.

    ``forward(x, mode="style")`` returns the activations tapped at conv1_1,
    conv2_1, conv3_1, conv4_1 and conv5_1, in that order.
    ``forward(x, mode="content")`` returns a one-element list with the conv4_2
    activation.

    NOTE(review): torchvision builds VGG's ReLUs with ``inplace=True`` (confirm
    for your torchvision version). In that case the ReLU right after a tapped
    layer overwrites the stored tensor, so the returned features are effectively
    post-ReLU. ``forward`` runs exactly one layer past the last tap, so the
    returned values are the same as running the whole network either way.
    """

    def __init__(self):
        super(StyleTransfer, self).__init__()
        # Only the convolutional trunk is used. Keeping just ``features`` avoids
        # moving the large, unused classifier head to the device.
        # ``weights=`` replaces the deprecated ``pretrained=True`` (same checkpoint).
        self.vgg19_features = vgg19(weights="IMAGENET1K_V1").features
        # The network is a fixed feature extractor. Freezing it means backward()
        # only computes gradients for the image being optimized, instead of also
        # accumulating never-used gradients in every VGG weight.
        self.vgg19_features.requires_grad_(False)
        self.eval()

        self.style_layer = [conv["conv1_1"], conv["conv2_1"], conv["conv3_1"], conv["conv4_1"], conv["conv5_1"]]
        self.content_layer = conv["conv4_2"]

    def forward(self, x, mode: str):
        """Run ``x`` through VGG19 and collect the activations for ``mode``.

        Args:
            x: Input image batch (presumably normalized with ``preprocess_image``).
            mode: ``"style"`` or ``"content"``.

        Returns:
            List of activation tensors, ordered by layer index.

        Raises:
            ValueError: If ``mode`` is neither ``"style"`` nor ``"content"``.
        """
        if mode == "style":
            taps = self.style_layer
        elif mode == "content":
            taps = [self.content_layer]
        else:
            raise ValueError("Invalid mode")

        tap_set = set(taps)
        # Stop right after the layer that follows the last tap (see class note).
        # Later layers produce new tensors and cannot change the tapped ones.
        stop = max(taps) + 2

        features = []
        for i, layer in enumerate(self.vgg19_features[:stop]):
            x = layer(x)
            if i in tap_set:
                features.append(x)

        return features

In [10]:
# Load Model
model = StyleTransfer().to(device)

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:03<00:00, 164MB/s]


In [11]:
class ContentLoss(nn.Module):
    """Content loss: mean-squared error between two feature maps of equal shape."""

    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        """Return ``mean((input - target) ** 2)`` as a scalar tensor."""
        return F.mse_loss(input, target)

class StyleLoss(nn.Module):
    """Style loss: mean-squared error between the Gram matrices of two feature maps."""

    def __init__(self):
        super(StyleLoss, self).__init__()

    def gram_matrix(self, input):
        """Per-sample Gram matrix of a (b, c, h, w) feature map.

        Returns:
            Tensor of shape (b, c, c). Entry ``[k, i, j]`` is the inner product of
            channels ``i`` and ``j`` of sample ``k``, divided by ``c * h * w``.
        """
        b, c, h, w = input.size()
        # reshape (not view) so non-contiguous inputs are also accepted.
        features = input.reshape(b, c, h * w)
        gram = torch.bmm(features, features.transpose(1, 2))
        # bmm sums only over the h*w positions of each sample, so normalize by
        # c*h*w. Dividing by b as well would make the scale depend on batch size.
        # With b == 1 (as in this notebook) the result is unchanged.
        return gram.div(c * h * w)

    def forward(self, input, target):
        """Return the MSE between the Gram matrices of ``input`` and ``target``."""
        return F.mse_loss(self.gram_matrix(input), self.gram_matrix(target))

In [12]:
# Load Loss
# Loss modules: plain feature MSE for content, Gram-matrix MSE for style.
content_loss_fn, style_loss_fn = ContentLoss(), StyleLoss()

In [13]:
# Initialize Noise Image
# The image being optimized starts as a copy of the content image, despite the
# "noise" name. It is the leaf tensor the optimizer updates.
noise_image_tensor = content_image_tensor.clone()
noise_image_tensor.requires_grad_(True)

In [14]:
# Build the optimizer over the single tensor being optimized (the generated image).
# The config historically spells L-BFGS as "LBGFS"; the correct "LBFGS" is accepted too.
optimizer_name = cfg["trainer"]["optimizer"]
if optimizer_name == "Adam":
    optimizer = optim.Adam([noise_image_tensor], lr=cfg["trainer"]["lr"])
elif optimizer_name in ("LBFGS", "LBGFS"):
    optimizer = optim.LBFGS([noise_image_tensor], lr=cfg["trainer"]["lr"])
else:
    raise ValueError(f"Unknown optimizer {optimizer_name!r}; select from Adam and LBFGS")

In [15]:
def compute_loss(
    model,
    content_image_tensor,
    style_image_tensor,
    noise_image_tensor,
    content_loss_fn,
    style_loss_fn,
    cfg,
):
    """Compute the weighted style-transfer objective for the generated image.

    Args:
        model: Callable ``model(x, mode=...)`` that returns a list of feature
            tensors for ``mode`` in ``{"content", "style"}``.
        content_image_tensor: Target image for the content term.
        style_image_tensor: Target image for the style term.
        noise_image_tensor: Image being optimized; gradients flow back to it.
        content_loss_fn: ``(input_features, target_features) -> scalar`` for content.
        style_loss_fn: ``(input_features, target_features) -> scalar`` for style.
        cfg: Config dict; ``cfg["trainer"]["alpha"]`` and ``cfg["trainer"]["beta"]``
            weight the content and style terms.

    Returns:
        ``(content_loss, style_loss, total_loss)`` where
        ``total_loss = alpha * content_loss + beta * style_loss``. Each term is
        summed over the paired feature lists.
    """
    # Features of the image being optimized; these must stay in the autograd graph.
    x_content_list = model(noise_image_tensor, mode="content")
    x_style_list = model(noise_image_tensor, mode="style")

    # The targets come from the fixed content/style images. Computing them without
    # autograd avoids building and back-propagating a graph that contributes
    # nothing to the gradient of the optimized image.
    with torch.no_grad():
        y_content_list = model(content_image_tensor, mode="content")
        y_style_list = model(style_image_tensor, mode="style")

    content_loss = sum(
        (content_loss_fn(x, y) for x, y in zip(x_content_list, y_content_list)), 0.0
    )
    style_loss = sum(
        (style_loss_fn(x, y) for x, y in zip(x_style_list, y_style_list)), 0.0
    )

    trainer_cfg = cfg["trainer"]
    total_loss = trainer_cfg["alpha"] * content_loss + trainer_cfg["beta"] * style_loss

    return content_loss, style_loss, total_loss

In [16]:
def train(
    model,
    datamodule,
    content_image_tensor,
    style_image_tensor,
    noise_image_tensor,
    optimizer,
    content_loss_fn,
    style_loss_fn,
    current_epoch,
    cfg,
    output_dir="/content/drive/MyDrive/Colab Notebooks/output",
):
    """Run one optimizer step on the generated image; periodically log and save it.

    Every ``cfg["trainer"]["step_size"]`` epochs this prints the content, style,
    and total losses, and writes the current image to
    ``<output_dir>/<optimizer name>_epoch_<current_epoch>.jpg``.

    Args:
        model: Feature extractor passed to ``compute_loss``; put in eval mode here.
        datamodule: Provides ``get_noise_image_tensor_to_image`` for saving.
        content_image_tensor, style_image_tensor: Fixed target images.
        noise_image_tensor: Image tensor being optimized (updated in place).
        optimizer: ``torch.optim`` optimizer over ``noise_image_tensor``. LBFGS is
            driven through a closure; any other optimizer takes a plain step.
        content_loss_fn, style_loss_fn: Loss callables passed to ``compute_loss``.
        current_epoch: Epoch index, used for the logging interval and file name.
        cfg: Config dict (reads ``trainer.step_size`` and ``trainer.optimizer``,
            and whatever ``compute_loss`` reads).
        output_dir: Directory for saved images; created if missing.
    """
    model.eval()

    def losses():
        # One evaluation of (content_loss, style_loss, total_loss) for the current image.
        return compute_loss(
            model,
            content_image_tensor,
            style_image_tensor,
            noise_image_tensor,
            content_loss_fn,
            style_loss_fn,
            cfg,
        )

    if isinstance(optimizer, optim.LBFGS):
        # LBFGS re-evaluates the objective itself, so it needs a closure.
        def closure():
            optimizer.zero_grad()
            _, _, total_loss = losses()
            total_loss.backward()
            return total_loss

        optimizer.step(closure)
        reported_losses = None  # computed after the step, only if we log this epoch
    else:
        content_loss, style_loss, total_loss = losses()
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        # Non-LBFGS optimizers report the losses from before this step.
        reported_losses = (content_loss, style_loss, total_loss)

    if current_epoch % cfg["trainer"]["step_size"] != 0:
        return

    if reported_losses is None:
        with torch.no_grad():
            reported_losses = losses()
    content_loss, style_loss, total_loss = reported_losses
    print(
        f"Content Loss: {content_loss.cpu().item()}, Style Loss: {style_loss.cpu().item()}, Total Loss: {total_loss.cpu().item()}"
    )

    gen_image = datamodule.get_noise_image_tensor_to_image(noise_image_tensor)
    # exist_ok=True already covers the "directory exists" case; no pre-check needed.
    os.makedirs(output_dir, exist_ok=True)
    gen_image.save(
        os.path.join(output_dir, f"{cfg['trainer']['optimizer']}_epoch_{current_epoch}.jpg")
    )

In [None]:
# Main optimization loop: one call to `train` (one optimizer step) per epoch,
# with a tqdm progress bar.
num_epochs = cfg["trainer"]["epochs"]
for epoch in tqdm(range(num_epochs)):
    train(
        model=model,
        datamodule=datamodule,
        content_image_tensor=content_image_tensor,
        style_image_tensor=style_image_tensor,
        noise_image_tensor=noise_image_tensor,
        optimizer=optimizer,
        content_loss_fn=content_loss_fn,
        style_loss_fn=style_loss_fn,
        current_epoch=epoch,
        cfg=cfg,
    )

  1%|          | 1/100 [00:09<15:17,  9.27s/it]

Content Loss: 4.194981575012207, Style Loss: 0.0002535551320761442, Total Loss: 257.7501220703125


  3%|▎         | 3/100 [00:25<13:14,  8.19s/it]