In [1]:
import torch
from models import create_model


class Options:
    """Minimal stand-in for the pix2pix option namespace, configured for inference.

    ``create_model`` only needs an object exposing these attributes, so instead
    of parsing command-line flags this class sets them directly on the instance.
    A fresh defaults table is built on every instantiation, so mutable values
    (e.g. ``gpu_ids``) are never shared between instances.
    """

    def __init__(self):
        defaults = {
            # Which model class to build, and whether we are training.
            'model': 'pix2pix',
            'isTrain': False,
            # Empty list -> no GPU ids are requested.
            'gpu_ids': [],
            # Where checkpoints live and the experiment name.
            'checkpoints_dir': './checkpoints/floorplan_model',
            'name': 'floorplan',
            'preprocess': None,
            # Channel counts of generator input/output images.
            'input_nc': 3,
            'output_nc': 3,
            # Network architecture choices.
            'ngf': 64,
            'netG': 'unet_256',
            'netD': 'basic',
            'norm': 'batch',
            'no_dropout': True,
            # Weight-initialisation scheme used when the networks are built.
            'init_type': 'normal',
            'init_gain': 0.02,
        }
        for attr_name, attr_value in defaults.items():
            setattr(self, attr_name, attr_value)

# Build the inference options and pass them to the project's `create_model`
# factory, which constructs the pix2pix model object. Its networks are only
# weight-initialised here (it logs "initialize network with normal"); trained
# generator weights still have to be loaded separately.
opt = Options()
model = create_model(opt)

initialize network with normal
model [Pix2PixModel] was created


In [None]:
import os

# Load trained generator weights into the freshly built model.
# The number prefix of the checkpoint file (presumably the training epoch).
CHECKPOINT_EPOCH = 130
# Reuse the directory from `opt` so the path is defined in one place only.
checkpoint_path = os.path.join(opt.checkpoints_dir, f'{CHECKPOINT_EPOCH}_net_G.pth')

# map_location remaps stored tensors onto the generator's current device, so a
# checkpoint saved from GPU tensors still loads into this CPU-only model
# (gpu_ids = []).
generator_device = next(model.netG.parameters()).device
generator_state = torch.load(checkpoint_path, map_location=generator_device)
model.netG.load_state_dict(generator_state)

In [None]:
from PIL import Image
import torchvision.transforms as transforms
import torch

def gen_image(input_image, netG=None, image_size=(256, 256)):
    """Run one PIL image through the pix2pix generator and return a PIL image.

    Args:
        input_image: PIL image to translate. It is converted to RGB first
            because the normalisation below (and ``input_nc = 3``) expects
            exactly three channels; RGBA / greyscale PNGs would otherwise fail.
        netG: generator module to use. Defaults to the global ``model.netG``.
        image_size: ``(height, width)`` the input is resized to before
            inference. Defaults to 256x256, matching ``netG = 'unet_256'``.

    Returns:
        The generated PIL image, at ``image_size`` resolution.
    """
    if netG is None:
        netG = model.netG

    transform = transforms.Compose([
        transforms.Resize(image_size),                           # size the generator expects
        transforms.ToTensor(),                                   # uint8 [0, 255] -> float [0, 1]
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # [0, 1] -> [-1, 1]
    ])

    # Add a batch dimension (the network consumes batches) and move the input
    # to the same device as the generator's weights.
    input_tensor = transform(input_image.convert('RGB')).unsqueeze(0)
    input_tensor = input_tensor.to(next(netG.parameters()).device)

    with torch.no_grad():
        output_tensor = netG(input_tensor)

    # The generator is trained to produce images in the same normalised range as
    # its input, i.e. [-1, 1] (NOTE(review): assumes a tanh output head, as in
    # pix2pix's unet generators — confirm). ToPILImage multiplies float tensors
    # by 255 and casts to uint8, so undo the normalisation first and clamp.
    # Without this, negative values wrap around and corrupt the image.
    output_tensor = ((output_tensor.squeeze(0) + 1) / 2).clamp(0, 1).cpu()

    return transforms.ToPILImage()(output_tensor)

In [None]:
import os
import random
import matplotlib.pyplot as plt

# Directory of input images to sample one example from.
INPUT_DIR = 'predicted_wob_imgs'
# Set to an int for a reproducible pick; None draws a different image each run.
SEED = None
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')

# Keep only image files: stray entries such as `.ipynb_checkpoints` or
# `.DS_Store` would make Image.open fail. Sorting makes the pick depend only on
# SEED, not on filesystem listing order.
list_images = sorted(
    filename for filename in os.listdir(INPUT_DIR)
    if filename.lower().endswith(IMAGE_EXTENSIONS)
)
if not list_images:
    raise FileNotFoundError(f'No image files found in {INPUT_DIR!r}')

input_path = os.path.join(INPUT_DIR, random.Random(SEED).choice(list_images))
input_image = Image.open(input_path)

# Display the input next to the generator's output.
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
fig.suptitle(os.path.basename(input_path))

ax[0].imshow(input_image)
ax[0].set_title('Input Image')
ax[0].axis('off')

output_image = gen_image(input_image)
ax[1].imshow(output_image)
ax[1].set_title('Output Image')
ax[1].axis('off')
plt.show()