# Ba-Nanos Art Generator

### Import the necessary libraries

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from pymongo import MongoClient
from PIL import Image
import matplotlib.pyplot as plt
import io
import os

import torchvision.transforms as transforms
import torchvision.models as models

ModuleNotFoundError: No module named 'torch'

## Define the style and content layers to be used later in the code

In [None]:
# Layer names follow the "conv_<k>" scheme assigned by
# get_style_model_and_losses, where k counts the Conv2d layers of the
# feature stack starting at 1. Content is compared at the 4th convolution,
# style at the first five.
content_layers_default = ['conv_4']
style_layers_default = [f'conv_{k}' for k in range(1, 6)]

## Establish the connection to the MongoDB server

In [None]:
# MongoDB connection.
# Credentials must not live in source control: supply the full connection
# string through the BANANOS_MONGODB_URI environment variable. The localhost
# fallback is only a convenience for local development.
# NOTE(review): a password was previously committed on this line; it is in the
# repository history and should be rotated.
MONGODB_URI = os.environ.get("BANANOS_MONGODB_URI", "mongodb://localhost:27017")

# Constructing the client does not fail when the server is unreachable;
# errors such as ServerSelectionTimeoutError surface on the first query.
client = MongoClient(MONGODB_URI)
db = client['image_database']
images_collection = db['images']

## Defining Functions for Loading Images and Establishing Classes

### `fetch_image_from_database` currently returns the stored image whose `image_id` matches the requested ID. A possible future improvement is to generate an AI image composited from all of the images in the database instead.

In [None]:
def fetch_image_from_database(image_id):
    """Return the raw image bytes stored under ``image_id``.

    Looks up the first document in ``images_collection`` whose ``image_id``
    field equals *image_id*.

    Returns:
        The matching document's ``image_data`` field, or ``None`` when no
        document matches.

    Raises:
        KeyError: if a matching document has no ``image_data`` field.
    """
    # TODO: still needs to be tested against the live database.
    document = images_collection.find_one({"image_id": image_id})
    if not document:
        return None
    return document["image_data"]

### Decode image bytes fetched from the database into a tensor the network can consume

In [None]:
def image_loader(image_data):
    """Decode encoded image bytes into a batched float tensor on ``device``.

    Relies on the module-level ``loader`` transform and ``device`` that the
    ``__main__`` block defines.

    Args:
        image_data: Encoded image bytes, e.g. the ``image_data`` field
            returned by ``fetch_image_from_database``.

    Returns:
        A float tensor with a leading batch dimension of 1 and 3 (RGB)
        channels, on ``device``.
    """
    with Image.open(io.BytesIO(image_data)) as raw:
        # Force 3-channel RGB. RGBA, greyscale or palette images would
        # otherwise produce a channel count that the 3-channel normalization
        # statistics and VGG19's first convolution cannot handle.
        image = raw.convert("RGB")
    # Fake batch dimension required to fit network's input dimensions
    image = loader(image).unsqueeze(0)
    return image.to(device, torch.float)

### Load a user-supplied image file from disk (it is used as the content image)

In [None]:
def load_user_image(image_path):
    """Load an image file from disk as a batched float tensor on ``device``.

    Relies on the module-level ``loader`` transform and ``device`` that the
    ``__main__`` block defines.

    Args:
        image_path: Path to a ``.jpg``, ``.jpeg`` or ``.png`` file; the
            extension check is case-insensitive.

    Returns:
        A float tensor with a leading batch dimension of 1 and 3 (RGB)
        channels, on ``device``.

    Raises:
        ValueError: if the file extension is not supported.
        OSError: propagated from Pillow if the file cannot be opened or read.
    """
    # ".pdf" is deliberately not accepted: Pillow can write PDF files but
    # cannot read them.
    valid_extensions = ('.jpg', '.jpeg', '.png')
    _, extension = os.path.splitext(image_path)
    if extension.lower() not in valid_extensions:
        raise ValueError("Invalid file format. Supported formats: .jpg, .jpeg, .png")

    # Context manager so the file handle is released; convert to RGB so
    # RGBA/greyscale/palette PNGs yield the 3 channels the network expects.
    with Image.open(image_path) as raw:
        image = raw.convert("RGB")
    image = loader(image).unsqueeze(0)
    return image.to(device, torch.float)

## Establishing Classes

### The ContentLoss class is a pass-through layer that records the mean-squared error between the current feature map and the corresponding feature map of the content image (the image supplied by the user)

In [None]:
class ContentLoss(nn.Module):
    """Pass-through layer that measures distance to a fixed feature map.

    Each forward pass stores the mean-squared error between the incoming
    tensor and the target in ``self.loss``. It then returns the incoming
    tensor untouched, so the probe can be spliced into an ``nn.Sequential``.
    """

    def __init__(self, target):
        super().__init__()
        # Detach so the target acts as a constant rather than a node that
        # gradients would flow back into.
        frozen_target = target.detach()
        self.target = frozen_target

    def forward(self, input):
        distance = F.mse_loss(input, self.target)
        self.loss = distance
        return input

### The StyleLoss class is a pass-through layer that records the mean-squared error between the Gram matrix of the current feature map and the Gram matrix of the style image's feature map (the image fetched from the database)

In [None]:
class StyleLoss(nn.Module):
    """Pass-through layer that scores a feature map against a style target.

    The target feature map is reduced once to its Gram matrix (detached).
    Each forward pass compares the Gram matrix of the incoming tensor with it
    via mean-squared error, stores the result in ``self.loss``, and returns
    the incoming tensor unchanged.
    """

    def __init__(self, target_feature):
        super().__init__()
        target_gram = gram_matrix(target_feature)
        self.target = target_gram.detach()

    def forward(self, input):
        self.loss = F.mse_loss(gram_matrix(input), self.target)
        return input

### Compute the Gram matrix of a feature map. A Gram matrix records which features activate together, regardless of where in the image they occur; matching Gram matrices is how the style of one image is transferred onto another.

In [None]:
def gram_matrix(input):
    """Return the normalized Gram matrix of a batch of feature maps.

    Every (batch, channel) slice is flattened into one row of length H*W.
    Entry (i, j) of the result is the inner product of rows i and j. The
    matrix is divided by the total element count N*C*H*W, so its magnitude
    does not grow with the size of the feature map.

    Args:
        input: 4-D tensor of shape ``(N, C, H, W)``.

    Returns:
        Tensor of shape ``(N*C, N*C)``.
    """
    batch, channels, height, width = input.size()
    # reshape (unlike view) also accepts non-contiguous tensors such as
    # permuted or channels_last feature maps; it copies only when required.
    features = input.reshape(batch * channels, height * width)
    gram = torch.mm(features, features.t())
    return gram.div(batch * channels * height * width)

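### The Normalization module standardizes an image channel by channel, subtracting the per-channel mean and dividing by the per-channel standard deviation, so the input matches the statistics the pretrained VGG19 expects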

In [None]:
class Normalization(nn.Module):
    """Standardize an image tensor channel-wise: ``(img - mean) / std``.

    Args:
        mean: Per-channel means as a list, tuple or tensor.
        std: Per-channel standard deviations as a list, tuple or tensor.
    """

    def __init__(self, mean, std):
        super(Normalization, self).__init__()
        # torch.as_tensor accepts both sequences and existing tensors. Calling
        # torch.tensor on a tensor (as the __main__ block's arguments are)
        # emits a copy-construct UserWarning. detach().clone() keeps the
        # original's semantics of an independent, gradient-free copy. The
        # (C, 1, 1) view lets the stats broadcast over (N, C, H, W) images.
        # Registering them as buffers makes module.to(device) move them too.
        self.register_buffer("mean", torch.as_tensor(mean).detach().clone().view(-1, 1, 1))
        self.register_buffer("std", torch.as_tensor(std).detach().clone().view(-1, 1, 1))

    def forward(self, img):
        return (img - self.mean) / self.std

## Load the Pretrained Network (no training happens here)

In [None]:
# Pretrained VGG19 convolutional stack (``features``, without the classifier
# head), put in eval mode because it is used as a fixed feature extractor.
# ``weights=`` replaces the deprecated ``pretrained=True`` flag; IMAGENET1K_V1
# is the same checkpoint that ``pretrained=True`` selected.
cnn = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features.eval()

### Build the style-transfer model: prepend a Normalization layer to the VGG19 feature stack, insert ContentLoss and StyleLoss layers right after the chosen content and style layers, and drop every layer after the last loss layer

In [None]:
def get_style_model_and_losses(cnn, normalization_mean, normalization_std,
                               style_img, content_img,
                               content_layers=content_layers_default,
                               style_layers=style_layers_default):
    """Assemble a truncated, loss-instrumented copy of ``cnn``.

    The direct children of ``cnn`` are walked in order. Each child is named
    after the index of the most recent Conv2d (``conv_k``, ``relu_k``,
    ``pool_k``, ``bn_k``) and appended to a new ``nn.Sequential`` that starts
    with a Normalization layer.

    Right after every layer whose name appears in ``content_layers`` or
    ``style_layers``, a ContentLoss or StyleLoss probe is inserted. Its target
    is the partially built model's output on ``content_img`` or
    ``style_img``. Everything after the last probe is dropped.

    Returns:
        ``(model, style_losses, content_losses)``: the instrumented
        Sequential and the inserted StyleLoss / ContentLoss modules, in
        insertion order.

    Raises:
        RuntimeError: if ``cnn`` has a child that is not Conv2d, ReLU,
            MaxPool2d or BatchNorm2d.
    """
    model = nn.Sequential(Normalization(normalization_mean, normalization_std))
    content_losses = []
    style_losses = []

    conv_index = 0
    for child in cnn.children():
        if isinstance(child, nn.Conv2d):
            conv_index += 1
            label = f"conv_{conv_index}"
        elif isinstance(child, nn.ReLU):
            label = f"relu_{conv_index}"
            # The loss probes return their input tensor as-is. An in-place
            # ReLU would overwrite that tensor after a probe's loss was
            # computed from it, so use the out-of-place variant.
            child = nn.ReLU(inplace=False)
        elif isinstance(child, nn.MaxPool2d):
            label = f"pool_{conv_index}"
        elif isinstance(child, nn.BatchNorm2d):
            label = f"bn_{conv_index}"
        else:
            raise RuntimeError(f"Unrecognized layer: {type(child).__name__}")

        model.add_module(label, child)

        if label in content_layers:
            probe = ContentLoss(model(content_img).detach())
            model.add_module(f"content_loss_{conv_index}", probe)
            content_losses.append(probe)

        if label in style_layers:
            probe = StyleLoss(model(style_img).detach())
            model.add_module(f"style_loss_{conv_index}", probe)
            style_losses.append(probe)

    # Keep modules up to and including the last probe; later layers never
    # influence any loss. With no probes, only the Normalization layer is kept.
    last_probe = max(
        (idx for idx, module in enumerate(model)
         if isinstance(module, (ContentLoss, StyleLoss))),
        default=0,
    )
    model = model[:last_probe + 1]

    return model, style_losses, content_losses

### Initialize optimizer for the input image

In [None]:
def get_input_optimizer(input_img):
    """Create the L-BFGS optimizer that updates the image pixels directly.

    ``input_img`` is the only registered parameter, so ``step()`` never
    touches any network weights.
    """
    return optim.LBFGS(params=[input_img])

### Run the style transfer: optimize the input image with the optimizer above for the given number of steps, weighting the style and content losses as specified

In [None]:
def run_style_transfer(cnn, normalization_mean, normalization_std,
                       content_img, style_img, input_img, num_steps=300,
                       style_weight=1000000, content_weight=1):
    """Run the style transfer.

    Optimizes the pixels of ``input_img`` with L-BFGS. Its features should
    match ``content_img`` at the content layers, and its Gram matrices should
    match ``style_img`` at the style layers.

    Args:
        cnn: Feature extractor to instrument (e.g. VGG19 ``features``).
        normalization_mean: Per-channel mean passed to Normalization.
        normalization_std: Per-channel std passed to Normalization.
        content_img: Image whose content should be preserved.
        style_img: Image whose style should be transferred.
        input_img: Starting image; it is optimized in place.
        num_steps: Target number of loss evaluations. L-BFGS may evaluate
            the closure several times inside one ``step()``, so the final
            count can overshoot this by up to one step's worth.
        style_weight: Multiplier for the summed style losses.
        content_weight: Multiplier for the summed content losses.

    Returns:
        ``input_img`` (the same tensor object), clamped to [0, 1].
    """
    print('Building the style transfer model..')
    model, style_losses, content_losses = get_style_model_and_losses(cnn,
        normalization_mean, normalization_std, style_img, content_img)

    # Optimize the image, not the model: only the input requires gradients.
    input_img.requires_grad_(True)
    # Evaluation mode so that layers such as dropout or batch normalization
    # behave correctly; the network weights stay frozen.
    model.eval()
    model.requires_grad_(False)

    optimizer = get_input_optimizer(input_img)
    evaluations = 0

    def closure():
        nonlocal evaluations
        # Optimizer updates can push pixels outside [0, 1]; pull them back
        # before each evaluation.
        with torch.no_grad():
            input_img.clamp_(0, 1)

        optimizer.zero_grad()
        # Run for the side effect: each loss probe stores its ``.loss``.
        model(input_img)

        style_score = style_weight * sum(sl.loss for sl in style_losses)
        content_score = content_weight * sum(cl.loss for cl in content_losses)

        loss = style_score + content_score
        loss.backward()

        evaluations += 1
        if evaluations % 50 == 0:
            print("run {}:".format(evaluations))
            print('Style Loss : {:4f} Content Loss: {:4f}'.format(
                style_score.item(), content_score.item()))
            print()

        return loss

    print('Optimizing..')
    while evaluations <= num_steps:
        optimizer.step(closure)

    # a last correction...
    with torch.no_grad():
        input_img.clamp_(0, 1)

    return input_img

### Main entry point: load the style image from the database and the content image from disk, run the style transfer, and display the input, style, content and output images (pixel values are kept in the range [0, 1] throughout)

In [None]:
if __name__ == "__main__":
    # Set device for PyTorch. set_default_device also places tensors created
    # without an explicit device (e.g. the normalization stats below) on it.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.set_default_device(device)

    # Reuse the VGG19 feature stack already loaded at module level (``cnn``)
    # instead of loading the pretrained weights a second time.
    cnn = cnn.to(device).eval()

    # ImageNet per-channel statistics used to normalize VGG19 inputs.
    cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406])
    cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225])

    # Desired size of the output image
    imsize = 512 if torch.cuda.is_available() else 128  # use small size if no GPU

    # ``loader`` and ``device`` are module-level globals read by
    # image_loader() and load_user_image(). Resize scales the shorter side to
    # ``imsize``; CenterCrop then cuts an imsize x imsize square. Together
    # they give the style and content tensors the same spatial size whatever
    # the source images' aspect ratios.
    loader = transforms.Compose([
        transforms.Resize(imsize),  # scale imported image
        transforms.CenterCrop(imsize),  # make it square so both images match
        transforms.ToTensor()])  # transform it into a torch tensor
    unloader = transforms.ToPILImage()  # reconvert into PIL image

    def imshow(tensor, title=None):
        """Draw a batched image tensor into the current matplotlib figure."""
        image = tensor.detach().cpu().clone().squeeze(0)  # drop batch dimension
        plt.imshow(unloader(image))
        if title is not None:
            plt.title(title)
        plt.pause(0.001)  # give the GUI event loop a chance to draw

    # TODO: replace with the appropriate image ID from your database
    image_id = 1
    style_img_data = fetch_image_from_database(image_id)

    if style_img_data:
        style_img = image_loader(style_img_data)

        # Get user-provided content image; strip stray whitespace, e.g. from
        # a path dragged into the terminal.
        content_image_path = input(
            "Enter path to content image (.jpg, .jpeg, or .png): ").strip()
        try:
            content_img = load_user_image(content_image_path)

            # Explicit check rather than ``assert``, which is stripped under -O.
            if style_img.size() != content_img.size():
                raise ValueError(
                    "We need to import style and content images of the same size")

            # Start the optimization from a copy of the content image.
            input_img = content_img.clone()

            # Display input, style, and content images using Matplotlib
            plt.ion()

            plt.figure()
            imshow(input_img, title='Input Image')

            plt.figure()
            imshow(style_img, title='Style Image')

            plt.figure()
            imshow(content_img, title='Content Image')

            # Run style transfer
            output = run_style_transfer(
                cnn, cnn_normalization_mean, cnn_normalization_std,
                content_img, style_img, input_img
            )

            # Display the final output image once, then block until the
            # windows are closed.
            plt.figure()
            imshow(output, title='Output Image')
            plt.ioff()
            plt.show()

        except Exception as e:
            # Top-level boundary: report the failure instead of a traceback.
            print("Error:", str(e))
    else:
        print("Image not found in the database.")

ServerSelectionTimeoutError: localhost:27017: [Errno 61] Connection refused, Timeout: 30s, Topology Description: <TopologyDescription id: 64e3f2c1dd8ee1f015a268d9, topology_type: Unknown, servers: [<ServerDescription ('localhost', 27017) server_type: Unknown, rtt: None, error=AutoReconnect('localhost:27017: [Errno 61] Connection refused')>]>