In [1]:
import os
import random
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
import cv2

import vgg
import transformer

In [2]:
# Gram Matrix
def gram(tensor):
    """Batched Gram matrix of a feature map, normalised by its element count.

    Args:
        tensor: 4-D tensor of shape (B, C, H, W). It is reshaped with
            ``.view``, so its memory layout must be compatible with that.

    Returns:
        Tensor of shape (B, C, C) whose entry [b, i, j] is the dot product of
        the flattened channels i and j of sample b, divided by C * H * W.
    """
    batch, channels, height, width = tensor.shape
    flat = tensor.view(batch, channels, height * width)
    norm = channels * height * width
    return torch.bmm(flat, flat.transpose(1, 2)) / norm

In [3]:
def imgtotensor(img):
    """Convert an image into a float tensor with a leading batch dimension.

    The image goes through ``ToPILImage`` -> ``ToTensor`` and is then
    multiplied by 255. For uint8 input, this gives pixel values in [0, 255],
    the same range the dataset transform in ``train`` produces.

    Args:
        img: anything ``transforms.ToPILImage`` accepts. In this notebook it
            is the array returned by ``cv2.imread``. No channel reordering
            happens here, so cv2's BGR order is preserved.

    Returns:
        Float tensor of shape (1, C, H, W).
    """
    # Previously wrapped in a no-op ``if (True == True):``; build it directly.
    to_tensor_255 = transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),
    ])

    # Convert, then add the batch dimension expected by the networks.
    return to_tensor_255(img).unsqueeze(dim=0)


def tensortoimg(tensor):
    """Convert a (1, C, H, W) or (C, H, W) tensor into an (H, W, C) numpy array.

    Only the leading batch dimension is removed. A bare ``squeeze()`` would
    also collapse a size-1 channel, height or width dimension and make the
    transpose below fail.

    Args:
        tensor: image tensor, with or without a batch dimension of size 1.
            It may live on any device and may require grad.

    Returns:
        numpy array laid out as (H, W, C), with the tensor's dtype and values
        (no scaling or clipping).
    """
    if tensor.dim() == 4:
        # Remove just the batch dimension (no-op if its size isn't 1).
        tensor = tensor.squeeze(0)
    # detach() so tensors that still require grad can be converted to numpy.
    img = tensor.detach().cpu().numpy()
    # Transpose from [C, H, W] -> [H, W, C]
    return img.transpose(1, 2, 0)

In [4]:
# save image to watch the progress
def saveimg(img, image_path):
    """Clip an image array to [0, 255] and write it to ``image_path`` with cv2.

    The parent directory is created if it does not exist yet.

    Args:
        img: numpy image array, e.g. the output of ``tensortoimg``. Its
            channel order is passed to cv2 unchanged.
        image_path: destination file; cv2 picks the format from the extension.

    Raises:
        OSError: if ``cv2.imwrite`` reports that it could not write the file.
            cv2 signals this only through its return value, so without this
            check the failure would be silent.
    """
    directory = os.path.dirname(image_path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    img = img.clip(0, 255)
    if not cv2.imwrite(image_path, img):
        raise OSError("cv2.imwrite failed to write {}".format(image_path))

In [11]:
# train method

def _style_gram_targets(vgg_net, style_image_path, batch_size, device):
    """Return a {layer_name: Gram matrix} dict of style targets for ``vgg_net``.

    The style image is read with ``cv2.imread`` (BGR order) and converted with
    ``imgtotensor``. It is then repeated ``batch_size`` times along the batch
    dimension, so the targets can later be sliced to match a shorter final
    batch. Features are computed without autograd because they are fixed
    targets.

    Raises:
        FileNotFoundError: if cv2 cannot read ``style_image_path``.
    """
    style_image = cv2.imread(style_image_path)
    if style_image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError("Could not read style image: {}".format(style_image_path))
    style_tensor = imgtotensor(style_image).to(device)
    _, channels, height, width = style_tensor.shape
    with torch.no_grad():
        style_features = vgg_net(style_tensor.expand([batch_size, channels, height, width]))
        return {key: gram(value) for key, value in style_features.items()}


def train(
    dataset_dir="E:/UNH/Sem3/DL/Pro/Multi-Style-Transfer/dataset",
    style_image_path="images/rain_princess.jpg",
    device="cpu",
    num_epochs=1,
    batch_size=4,
    save_every=500,
    models_dir="models",
    sample_dir="images/out",
):
    """Train a TransformerNetwork to reproduce one style image's style.

    Each content batch is passed through the transformer. The generated batch
    is then compared against VGG16 features:

    * content loss: MSE between the ``relu2_2`` features of the generated and
      content batches, times CONTENT_WEIGHT;
    * style loss: sum over every VGG output key of the MSE between the Gram
      matrices of the generated features and of the style image, times
      STYLE_WEIGHT.

    On the first iteration, then every ``save_every`` iterations, and on the
    last iteration, the running mean losses are printed, a checkpoint and a
    sample image are written, and the loss histories are extended. After
    training, the final weights are saved to
    ``<models_dir>/transformer_generated_model.pth``.

    Every parameter defaults to the value that used to be hard-coded, so
    ``train()`` keeps working unchanged.

    Args:
        dataset_dir: root folder passed to ``torchvision.datasets.ImageFolder``.
        style_image_path: style image path, read with ``cv2.imread``.
        device: torch device string (pass ``"cuda"`` explicitly to use a GPU).
        num_epochs: number of passes over the dataset.
        batch_size: DataLoader batch size. It is also the number of copies of
            the style image used to build the style Gram targets.
        save_every: logging/checkpoint interval, in iterations.
        models_dir: directory for checkpoints and the final model (created).
        sample_dir: directory for sample output images (created).

    Raises:
        FileNotFoundError: if the style image cannot be read.
    """
    SEED = 25
    CONTENT_WEIGHT = 15
    STYLE_WEIGHT = 30

    # Seeds to get same results
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)

    # torch.save / cv2.imwrite fail (or silently do nothing) on missing dirs.
    os.makedirs(models_dir, exist_ok=True)
    os.makedirs(sample_dir, exist_ok=True)

    # Dataset and Dataloader: 256x256 crops, pixels scaled to [0, 255].
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),
    ])
    train_dataset = datasets.ImageFolder(dataset_dir, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    total_iterations = num_epochs * len(train_loader)

    # Load networks to device
    transformer_net = transformer.TransformerNetwork().to(device)
    vgg_net = vgg.VGG16().to(device)

    # Style targets are computed once; they never change during training.
    style_gram = _style_gram_targets(vgg_net, style_image_path, batch_size, device)

    optimizer = optim.Adam(transformer_net.parameters(), lr=0.001)
    mse_loss = nn.MSELoss().to(device)

    # Loss trackers. The running sums hold plain floats (via .item()) so no
    # autograd graph is kept alive across iterations.
    content_loss_history = []
    style_loss_history = []
    total_loss_history = []
    batch_content_loss_sum = 0.0
    batch_style_loss_sum = 0.0
    batch_total_loss_sum = 0.0

    # Optimization/Training Loop
    batch_count = 1
    start_time = time.time()

    for epoch in range(num_epochs):
        print("========Epoch {}/{}========".format(epoch+1, num_epochs))
        for content_batch, _ in train_loader:
            # The last batch may be smaller than batch_size.
            curr_batch_size = content_batch.shape[0]

            # Free-up unneeded cuda memory
            torch.cuda.empty_cache()

            optimizer.zero_grad()

            # ImageFolder yields RGB; reorder to BGR to match the cv2-loaded
            # style image.
            content_batch = content_batch[:, [2, 1, 0]].to(device)
            generated_batch = transformer_net(content_batch)
            # Content features are only a target, so no graph is needed.
            with torch.no_grad():
                content_features = vgg_net(content_batch)
            generated_features = vgg_net(generated_batch)

            # Content Loss
            content_loss = CONTENT_WEIGHT * mse_loss(generated_features['relu2_2'], content_features['relu2_2'])
            batch_content_loss_sum += content_loss.item()

            # Style Loss
            style_loss = 0
            for key, value in generated_features.items():
                style_loss += mse_loss(gram(value), style_gram[key][:curr_batch_size])
            style_loss *= STYLE_WEIGHT
            batch_style_loss_sum += style_loss.item()

            # Total Loss
            total_loss = content_loss + style_loss
            batch_total_loss_sum += total_loss.item()

            # Backprop and Weight Update
            total_loss.backward()
            optimizer.step()

            # Save Model and Print Losses
            if ((batch_count - 1) % save_every == 0) or (batch_count == total_iterations):
                print("********Iteration {}/{}********".format(batch_count, total_iterations))
                print("\tContent Loss:\t{:.2f}".format(batch_content_loss_sum/batch_count))
                print("\tStyle Loss:\t{:.2f}".format(batch_style_loss_sum/batch_count))
                print("\tTotal Loss:\t{:.2f}".format(batch_total_loss_sum/batch_count))

                checkpoint_path = os.path.join(models_dir, "checkpoint_" + str(batch_count-1) + ".pth")
                torch.save(transformer_net.state_dict(), checkpoint_path)
                print("Saved TransformerNetwork Checkpoint at {}".format(checkpoint_path))

                # Save the first generated image of the batch as a sample
                sample_tensor = generated_batch[0].detach().clone().unsqueeze(dim=0)
                sample_image = tensortoimg(sample_tensor)
                sample_image_path = os.path.join(sample_dir, "sample0_" + str(batch_count-1) + ".png")
                saveimg(sample_image, sample_image_path)
                print("Saved sample tranformed image at {}".format(sample_image_path))

                # Save loss histories (running means)
                content_loss_history.append(batch_content_loss_sum/batch_count)
                style_loss_history.append(batch_style_loss_sum/batch_count)
                total_loss_history.append(batch_total_loss_sum/batch_count)

            batch_count += 1

    stop_time = time.time()
    # Print loss histories
    print("Finished Training the Transformer Network!")
    print("Training Time: {} seconds".format(stop_time-start_time))
    print("********Content Loss********")
    print(content_loss_history)
    print("********Style Loss********")
    print(style_loss_history)
    print("********Total Loss********")
    print(total_loss_history)

    # Save TransformerNetwork weights
    transformer_net.eval()
    transformer_net.cpu()
    final_path = os.path.join(models_dir, "transformer_generated_model.pth")
    print("Saving Transformer Network Model at {}".format(final_path))
    torch.save(transformer_net.state_dict(), final_path)
    print("Done saving final model")



In [12]:
# Run training with train()'s default configuration. This is long-running:
# one epoch is tens of thousands of iterations on the dataset used here.
train()

********Iteration 1/30822********
	Content Loss:	441839.75
	Style Loss:	603995.75
	Total Loss:	1045835.50
Saved TransformerNetwork Checkpoint at models
Saved sample tranformed image at images/out/sample0_0.png


KeyboardInterrupt: 