Set-up and model loading

In [None]:
import sys
import os
import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import numpy as np
from matplotlib import pyplot as plt

# Make the `pytorch-CycleGAN-and-pix2pix` repository importable.
repo_path = '/Users/ls/Sites/pytorch-CycleGAN-and-pix2pix'  # Update this to your repository path
# Insert at the front so the repo's generically named `models` package wins over any other
# `models` module already on the path (e.g. one next to this notebook), and only insert once
# so re-running the cell does not keep growing sys.path with duplicates.
if repo_path not in sys.path:
    sys.path.insert(0, repo_path)

# Import the necessary modules from the repository
from models.pix2pix_model import Pix2PixModel
from models.networks import define_G

# Shared-drive folder that holds one sub-directory per trained pix2pix run.
trained_models_dir = ('/Users/ls/Library/CloudStorage/GoogleDrive-l.schrage@northeastern.edu/'
                      'Shared drives/Drawing Participation/Million Neighborhoods/Trained Models')
# Generator weights file loaded from each run directory.
checkpoint_name = '500_net_G.pth'

# One generator checkpoint per trained run; the ensemble below uses them in this order.
model1_path = f'{trained_models_dir}/ma-boston-p2p-500-150-v100/{checkpoint_name}'
model2_path = f'{trained_models_dir}/nc-charlotte-500-150-v100/{checkpoint_name}'
model3_path = f'{trained_models_dir}/ny-manhattan-p2p-500-150-v100/{checkpoint_name}'
model4_path = f'{trained_models_dir}/pa-pittsburgh-p2p-500-150-v100/{checkpoint_name}'

model_paths = [model1_path, model2_path, model3_path, model4_path]

### Step 3: Load generators

In [None]:
def load_generator(model_path, input_nc=3, output_nc=3, ngf=64,
                   architecture='unet_256', norm='batch'):
    """Build a generator with ``define_G`` and load the weights stored at ``model_path``.

    The keyword defaults reproduce the configuration this notebook always used
    (3-channel input/output, ngf=64, ``unet_256``, batch norm). They must match the
    settings the checkpoint was trained with, otherwise ``load_state_dict`` raises
    on missing, unexpected or mis-shaped keys.

    Args:
        model_path: path to a generator state-dict file (e.g. ``500_net_G.pth``).
        input_nc: number of input image channels, passed to ``define_G``.
        output_nc: number of output image channels, passed to ``define_G``.
        ngf: base filter count, passed to ``define_G`` as ``ngf``.
        architecture: generator architecture name, passed to ``define_G``.
        norm: normalization layer type, passed to ``define_G``.

    Returns:
        The generator module with the checkpoint weights loaded. This function does
        not call ``.eval()``; the module keeps whatever mode ``define_G`` left it in.
    """
    # gpu_ids=[] asks define_G not to place the net on a GPU (presumably it stays on the CPU).
    generator = define_G(input_nc, output_nc, ngf, architecture, norm=norm, use_dropout=False,
                         init_type='normal', init_gain=0.02, gpu_ids=[])
    # By default torch.load restores tensors onto the device they were saved from, which
    # fails on a CPU-only machine for checkpoints written from CUDA; map them to the CPU
    # to match the CPU-resident network built above.
    state_dict = torch.load(model_path, map_location='cpu')
    generator.load_state_dict(state_dict)
    return generator

# One loaded generator per checkpoint, in the same order as model_paths.
generators = list(map(load_generator, model_paths))

### Step 4: Define the ensemble method

def ensemble_output(generators, input_image):
    """Average the outputs of several generators for the same input.

    Every generator is applied to ``input_image``. The results are stacked along a
    new leading axis and averaged element-wise, so all generators must return
    tensors of the same shape.

    The ensemble is used for inference only, so it runs under ``torch.no_grad()``.
    That avoids building and holding an autograd graph for every generator. The
    returned tensor therefore does not require grad.

    Args:
        generators: iterable of callables (e.g. ``nn.Module``) mapping an input batch
            to an output batch.
        input_image: input tensor passed unchanged to each generator.

    Returns:
        Tensor with the shape of a single generator's output, holding the mean.
    """
    with torch.no_grad():
        outputs = [generator(input_image) for generator in generators]
        return torch.stack(outputs).mean(dim=0)

### Step 5: Evaluation

In [None]:
def evaluate_meta_model(meta_model, generators, dataloader, criterion=None):
    """Return the mean per-batch loss of ``meta_model`` stacked on top of ``generators``.

    For each batch, every generator is applied to ``data['A']``. Their outputs are
    concatenated along dim 1 (assumed to be the channel axis of NCHW tensors; TODO
    confirm). The result is fed to ``meta_model``, whose output is scored against
    ``data['B']``.

    Args:
        meta_model: module combining the concatenated generator outputs. It is switched
            to eval mode here and left in eval mode.
        generators: iterable of callables mapping an input batch to an output batch.
            It is materialized once, so a one-shot iterator works across all batches.
        dataloader: any iterable of dicts with ``'A'`` (input) and ``'B'`` (target)
            tensors; it does not need to support ``len()``.
        criterion: callable ``criterion(output, target)`` returning a scalar tensor.
            Defaults to ``nn.MSELoss()``.

    Returns:
        The average of the per-batch losses, as a float.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    if criterion is None:
        criterion = nn.MSELoss()
    # Reused for every batch, so a generator expression must not be exhausted after batch one.
    generators = list(generators)

    meta_model.eval()
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for data in dataloader:
            input_image = data['A']
            target_image = data['B']
            generator_outputs = [generator(input_image) for generator in generators]
            meta_input = torch.cat(generator_outputs, dim=1)
            meta_output = meta_model(meta_input)

            total_loss += criterion(meta_output, target_image).item()
            # Count while iterating instead of len(dataloader): works for iterables
            # without __len__ and lets an empty loader be reported clearly.
            num_batches += 1

    if num_batches == 0:
        raise ValueError("dataloader yielded no batches; cannot compute an average loss")
    return total_loss / num_batches

# Build the validation DataLoader and report the meta-model's average validation loss.
# NOTE(review): `Pix2pixDataset` is not imported or defined anywhere in this notebook, so this
# line raises NameError as written. The pix2pix repository's paired dataset is presumably
# `data.aligned_dataset.AlignedDataset`, which is built from an options object rather than these
# keyword arguments — TODO confirm and replace. Whatever dataset is used must yield dicts with
# 'A' (input) and 'B' (target) tensors, because evaluate_meta_model reads exactly those keys.
# TODO: replace the placeholder dataroot with the real validation-set path.
val_dataset = Pix2pixDataset(dataroot='path_to_val_dataset', phase='val', transform=None)
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False)

# NOTE(review): `meta_model` is never created in this notebook; it must be defined (and trained)
# before this call. It receives the generator outputs concatenated along dim 1, i.e. 4 x 3 = 12
# channels for the four 3-channel generators loaded above (assuming NCHW) — TODO confirm.
validation_loss = evaluate_meta_model(meta_model, generators, val_dataloader)
print(f'Validation Loss: {validation_loss}')

### Step 6: Final Output

In [None]:
from PIL import Image
from torchvision import transforms

# Preprocess the input: resize to 256x256, convert to a [0, 1] float tensor, then map each
# channel to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load and preprocess your input image
input_image_path = '/path/to/your/input/image.jpg'  # Update with your input image path
input_image = Image.open(input_image_path).convert('RGB')
input_tensor = transform(input_image).unsqueeze(0)  # Add batch dimension

# Generate the final output. This is pure inference, so skip autograd bookkeeping for the
# generators (saves memory and time).
with torch.no_grad():
    final_output = ensemble_output(generators, input_tensor)

# Convert the single image in the batch from a (C, H, W) tensor to an (H, W, C) uint8 array.
# squeeze(0) removes only the batch dimension, never the channel axis.
output_image = final_output.squeeze(0).detach().cpu().numpy().transpose(1, 2, 0)
# Invert the Normalize above: [-1, 1] -> [0, 255].
output_image = (output_image * 0.5 + 0.5) * 255
# Clip before casting: float -> uint8 conversion of values outside [0, 255] is not well-defined.
output_image = np.clip(output_image, 0, 255).astype('uint8')

# Display the output image
from matplotlib import pyplot as plt
plt.imshow(output_image)
plt.axis('off')
plt.show()