In [1]:
import cv2
import numpy as np
import os
import time, random
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torchvision.models import vgg16, VGG16_Weights

In [2]:
class VGG16(nn.Module):
    def __init__(self, vgg_path="/kaggle/input/vgg-pretrained/pytorch/default/1/vgg16-00b39a1b.pth"):
        super(VGG16, self).__init__()
        vgg16_features = vgg16(weights=None)
        state = torch.load(vgg_path, map_location='cpu', weights_only=False)

        vgg16_features.load_state_dict(state, strict=False)
        self.features = vgg16_features.features

        for param in self.features.parameters():
            param.requires_grad = False

    def forward(self, x):
        layers = {'3': 'relu1_2', 
                  '8': 'relu2_2', 
                  '15': 'relu3_3', 
                  '22': 'relu4_3'}
        
        features = {}
        for name, layer in self.features._modules.items():
            x = layer(x)
            if name in layers:
                features[layers[name]] = x
                if (name=='22'):
                    break

        return features

In [3]:
class ResidualBlock(nn.Module):
    def __init__(self, c):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(c, c, 3, padding='same'),
            nn.InstanceNorm2d(c, affine=True, track_running_stats=False),
            nn.ReLU(),
            nn.Conv2d(c, c, 3, padding='same'),
            nn.InstanceNorm2d(c, affine=True, track_running_stats=False),
        )

    def forward(self, x):
        return x + self.block(x)

class TransformerNetModern(nn.Module):
    def __init__(self, tanh_multiplier=150.0):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 9, stride=1, padding='same'),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.ReLU(),
        )
        self.resblocks = nn.Sequential(*[ResidualBlock(128) for _ in range(5)])
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 3, 9, stride=1, padding='same'),
            nn.Tanh()
        )
        self.tanh_multiplier = tanh_multiplier

    def forward(self, x):
        out = self.decoder(self.resblocks(self.encoder(x))) * self.tanh_multiplier
        return out

In [None]:
def gram(tensor):
    B, C, H, W = tensor.shape
    x = tensor.view(B, C, H*W)
    x_t = x.transpose(1, 2)
    return  torch.bmm(x, x_t) / (C*H*W)

def load_image(path):
    # Images loaded as BGR
    img = cv2.imread(path)
    return img

def saveimg(img, image_path):
    img = img.clip(0, 255)
    cv2.imwrite(image_path, img)

def itot(img, max_size=None):
    if (max_size==None):
        itot_t = transforms.Compose([
            #transforms.ToPILImage(),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.mul(255))
        ])    
    else:
        H, W, C = img.shape
        image_size = tuple([int((float(max_size) / max([H,W]))*x) for x in [H, W]])
        itot_t = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.mul(255))
        ])

    tensor = itot_t(img)
    tensor = tensor.unsqueeze(dim=0)
    return tensor

def ttoi(tensor):

    # Remove the batch_size dimension
    tensor = tensor.squeeze()
    #img = ttoi_t(tensor)
    img = tensor.cpu().numpy()
    
    # Transpose from [C, H, W] -> [H, W, C]
    img = img.transpose(1, 2, 0)
    return img

def plot_loss_hist(c_loss, s_loss, total_loss, title="Loss History", save_dir="/kaggle/working/"):
    os.makedirs(save_dir, exist_ok=True)
    x = [i for i in range(len(total_loss))]
    plt.figure(figsize=[10, 6])
    plt.plot(x, c_loss, label="Content Loss")
    plt.plot(x, s_loss, label="Style Loss")
    plt.plot(x, total_loss, label="Total Loss")
    plt.legend()
    plt.xlabel('Every 500 iterations')
    plt.ylabel('Loss')
    plt.title(title)
    plt.grid(True, linestyle='--', alpha=0.6)  

    save_path = os.path.join(save_dir, f"{title.replace(' ', '_').lower()}.png")
    plt.savefig(save_path, bbox_inches='tight')
    plt.show()
    plt.close()

In [5]:
TRAIN_IMAGE_SIZE = 256
DATASET_PATH = "/kaggle/input/dataset"
NUM_EPOCHS = 4
STYLE_IMAGE_PATH = "/kaggle/input/the-scream/Edvard-Munch-The-Scream.jpg"
BATCH_SIZE = 8 
CONTENT_WEIGHT = 17 
STYLE_WEIGHT = 50 
ADAM_LR = 0.001
SAVE_MODEL_PATH = "/kaggle/working/models/"
SAVE_IMAGE_PATH = "/kaggle/working/images/"
SAVE_MODEL_EVERY = 1000
SEED = 35
PLOT_LOSS = 1

def train():
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)
    os.makedirs(SAVE_MODEL_PATH, exist_ok=True)
    os.makedirs(SAVE_IMAGE_PATH, exist_ok=True)

    device = ("cuda" if torch.cuda.is_available() else "cpu")
    transform = transforms.Compose([
        transforms.Resize(TRAIN_IMAGE_SIZE),
        transforms.CenterCrop(TRAIN_IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(DATASET_PATH, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

    TransformerNetwork = TransformerNetModern().to(device)
    VGG = VGG16().to(device)

    MSELoss = nn.MSELoss().to(device)
    optimizer = optim.Adam(TransformerNetwork.parameters(), lr=ADAM_LR)

    imagenet_neg_mean = torch.tensor(
        [-103.939, -116.779, -123.68], 
        dtype=torch.float32).reshape(1,3,1,1).to(device)
    
    style_image = load_image(STYLE_IMAGE_PATH)
    style_tensor = itot(style_image).to(device)
    style_tensor = style_tensor.add(imagenet_neg_mean)
    B, C, H, W = style_tensor.shape
    style_features = VGG(style_tensor.expand([BATCH_SIZE, C, H, W]))
    style_gram = {}
    for key, value in style_features.items():
        style_gram[key] = gram(value)

    content_loss_history = []
    style_loss_history = []
    total_loss_history = []
    batch_content_loss_sum = 0
    batch_style_loss_sum = 0
    batch_total_loss_sum = 0

    batch_count = 1
    start_time = time.time()
    for epoch in range(NUM_EPOCHS):
        print("========Epoch {}/{}========".format(epoch+1, NUM_EPOCHS))
        for content_batch, _ in train_loader:
            curr_batch_size = content_batch.shape[0]
            torch.cuda.empty_cache()
            optimizer.zero_grad()

            content_batch = content_batch[:,[2,1,0]].to(device)
            generated_batch = TransformerNetwork(content_batch)
            content_features = VGG(content_batch.add(imagenet_neg_mean))
            generated_features = VGG(generated_batch.add(imagenet_neg_mean))

            content_loss = CONTENT_WEIGHT * MSELoss(generated_features['relu2_2'], content_features['relu2_2'])            
            batch_content_loss_sum += content_loss.item()

            style_loss = 0.0
            for key, value in generated_features.items():
                s_loss = MSELoss(gram(value), style_gram[key][:curr_batch_size])
                style_loss += s_loss
            style_loss *= STYLE_WEIGHT
            batch_style_loss_sum += style_loss.item()

            total_loss = content_loss + style_loss
            batch_total_loss_sum += total_loss.item()

            total_loss.backward()
            optimizer.step()

            if (((batch_count-1)%SAVE_MODEL_EVERY == 0) or (batch_count==NUM_EPOCHS*len(train_loader))):
                print("========Iteration {}/{}========".format(batch_count, NUM_EPOCHS*len(train_loader)))
                print("\tContent Loss:\t{:.2f}".format(batch_content_loss_sum/batch_count))
                print("\tStyle Loss:\t{:.2f}".format(batch_style_loss_sum/batch_count))
                print("\tTotal Loss:\t{:.2f}".format(batch_total_loss_sum/batch_count))
                print("Time elapsed:\t{} seconds".format(time.time()-start_time))

                # Save Model
                checkpoint_path = os.path.join(
                    SAVE_MODEL_PATH, f"checkpoint_{batch_count-1}.pth"
                )
                torch.save(TransformerNetwork.state_dict(), checkpoint_path)
                print("Saved TransformerNetwork checkpoint file at {}".format(checkpoint_path))

                # Save sample generated image
                sample_tensor = generated_batch[0].clone().detach().unsqueeze(dim=0)
                sample_image = ttoi(sample_tensor.clone().detach())
                sample_image_path = os.path.join(
                    SAVE_IMAGE_PATH, f"sample0_{batch_count-1}.png"
                )
                saveimg(sample_image, sample_image_path)
                print("Saved sample tranformed image at {}".format(sample_image_path))

                content_loss_history.append(batch_content_loss_sum/batch_count)
                style_loss_history.append(batch_style_loss_sum/batch_count)
                total_loss_history.append(batch_total_loss_sum/batch_count)

            batch_count+=1

    stop_time = time.time()
    print("Done Training the Transformer Network!")
    print("Training Time: {} seconds".format(stop_time-start_time))
    print("========Content Loss========")
    print(content_loss_history) 
    print("========Style Loss========")
    print(style_loss_history) 
    print("========Total Loss========")
    print(total_loss_history) 

    TransformerNetwork.eval()
    TransformerNetwork.cpu()
    final_path = os.path.join(SAVE_MODEL_PATH, "transformer_weight.pth")
    print("Saving TransformerNetwork weights at {}".format(final_path))
    torch.save(TransformerNetwork.state_dict(), final_path)
    print("Done saving final model")

    if (PLOT_LOSS):
        plot_loss_hist(content_loss_history, style_loss_history, total_loss_history)

In [None]:
train()

	Content Loss:	5231780.50
	Style Loss:	135423712.00
	Total Loss:	140655488.00
Time elapsed:	0.7170536518096924 seconds
Saved TransformerNetwork checkpoint file at /kaggle/working/models/checkpoint_0.pth
Saved sample tranformed image at /kaggle/working/images/sample0_0.png


[ WARN:0@106.580] global loadsave.cpp:1063 imwrite_ Unsupported depth image for selected encoder is fallbacked to CV_8U.


	Content Loss:	2248715.04
	Style Loss:	4402975.79
	Total Loss:	6651690.76
Time elapsed:	288.1394329071045 seconds
Saved TransformerNetwork checkpoint file at /kaggle/working/models/checkpoint_1000.pth
Saved sample tranformed image at /kaggle/working/images/sample0_1000.png
	Content Loss:	1937235.20
	Style Loss:	2899472.34
	Total Loss:	4836707.50
Time elapsed:	575.3273952007294 seconds
Saved TransformerNetwork checkpoint file at /kaggle/working/models/checkpoint_2000.pth
Saved sample tranformed image at /kaggle/working/images/sample0_2000.png
	Content Loss:	1777177.39
	Style Loss:	2380769.65
	Total Loss:	4157947.02
Time elapsed:	862.3278906345367 seconds
Saved TransformerNetwork checkpoint file at /kaggle/working/models/checkpoint_3000.pth
Saved sample tranformed image at /kaggle/working/images/sample0_3000.png
	Content Loss:	1676921.56
	Style Loss:	2115106.05
	Total Loss:	3792027.59
Time elapsed:	1149.2903530597687 seconds
Saved TransformerNetwork checkpoint file at /kaggle/working/mod