In [15]:
import torch
import numpy as np
from PIL import Image
from models import Generator

# Run on the GPU when one is present; fall back to the CPU so the notebook
# still works on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Generator()
# map_location remaps tensors stored for another device (e.g. a checkpoint
# saved from a GPU) onto `device`, so loading works on a CPU-only machine.
model.load_state_dict(torch.load('model_4.pth', map_location=device))
# Switch to evaluation mode (matters if the Generator has dropout/batch-norm).
model.eval()
model = model.to(device)

def split_into_patches(image, patch_size=256):
    """Cut an image into square tiles, scanning rows top-to-bottom.

    The image is converted with ``np.array`` and tiled in row-major order:
    all tiles of the first band of rows left-to-right, then the next band,
    and so on. Tiles on the right and bottom edges are smaller than
    ``patch_size`` when the image size is not an exact multiple of it.

    Args:
        image: A PIL image or array-like that becomes a 3-D ``(H, W, C)``
            array (the channel axis is sliced explicitly).
        patch_size: Side length of a full tile, in pixels.

    Returns:
        A list of ``(h, w, C)`` array slices of the converted image.
    """
    pixels = np.array(image)
    height, width = pixels.shape[0], pixels.shape[1]
    return [
        pixels[top:top + patch_size, left:left + patch_size, :]
        for top in range(0, height, patch_size)
        for left in range(0, width, patch_size)
    ]

def process_patches(patches, model, device=None):
    """Run ``model`` on every patch and return the outputs as CPU tensors.

    Args:
        patches: Iterable of ``(H, W, C)`` arrays, e.g. the output of
            ``split_into_patches``. Values are divided by 255, so 8-bit pixel
            data is assumed.
        model: The network to apply. It is called once per patch with a
            ``(1, C, H, W)`` float tensor.
        device: Device to run inference on. When ``None`` (the default), the
            device of the model's first parameter is used, or the CPU if the
            model has no parameters. Inputs therefore always land where the
            model lives instead of relying on a module-level global.

    Returns:
        List of model outputs with the batch dimension removed, on the CPU.
    """
    if device is None:
        params = getattr(model, "parameters", None)
        first_param = next(params(), None) if callable(params) else None
        device = first_param.device if first_param is not None else torch.device("cpu")

    processed_patches = []
    # Pure inference: disable autograd for the whole loop.
    with torch.no_grad():
        for patch in patches:
            patch_tensor = torch.tensor(patch).unsqueeze(0)  # add batch dim -> [1, H, W, C]
            patch_tensor = patch_tensor.permute(0, 3, 1, 2)  # rearrange to [1, C, H, W]
            # Assumes the model expects float input scaled to [0, 1] — confirm.
            patch_tensor = (patch_tensor.float() / 255.0).to(device)
            processed_patch = model(patch_tensor)
            processed_patches.append(processed_patch.squeeze(0).cpu())  # drop batch dim
    return processed_patches

def reassemble_patches(patches, original_image_size, patch_size=256, magnification_factor=4):
    """Stitch upscaled patch tensors back into one float image.

    Patches are placed in row-major order, matching the tiling order of
    ``split_into_patches``. Each input tile of ``patch_size`` pixels occupies
    ``patch_size * magnification_factor`` pixels in the output. Each patch
    is written at its actual size, so smaller edge tiles fit as well.

    Args:
        patches: Sequence of ``(C, h, w)`` tensors (channel-first), e.g. from
            ``process_patches``. The output canvas has 3 channels.
        original_image_size: ``(width, height)`` of the source image, in the
            order PIL's ``Image.size`` reports it.
        patch_size: Tile side length used when splitting the source image.
        magnification_factor: Upscaling factor applied by the model.

    Returns:
        A ``float32`` array of shape
        ``(height * magnification_factor, width * magnification_factor, 3)``.
        Pixels not covered by any patch stay 0.
    """
    width, height = original_image_size[0], original_image_size[1]
    canvas = np.zeros(
        (height * magnification_factor, width * magnification_factor, 3),
        dtype=np.float32,
    )

    tiles_across = int(np.ceil(width / patch_size))
    stride = patch_size * magnification_factor

    for index, tile in enumerate(patches):
        grid_row, grid_col = divmod(index, tiles_across)
        top = grid_row * stride
        left = grid_col * stride

        # Channel-first tensor -> channel-last NumPy array for the canvas.
        pixels = tile.permute(1, 2, 0).detach().cpu().numpy()
        tile_h, tile_w = pixels.shape[0], pixels.shape[1]

        canvas[top:top + tile_h, left:left + tile_w, :] = pixels

    return canvas


# Example usage
image = Image.open('image.png').convert('RGB')
patches = split_into_patches(image, patch_size=256)
processed_patches = process_patches(patches, model)
reassembled_image = reassemble_patches(processed_patches, image.size, patch_size=256)
# The model output is not guaranteed to stay inside [0, 1]. Casting
# out-of-range floats to uint8 does not saturate: values wrap or are
# platform-dependent, which shows up as speckle artifacts. Clip first, then
# round instead of truncating to avoid a systematic darkening bias.
final_image = Image.fromarray(
    (np.clip(reassembled_image, 0.0, 1.0) * 255).round().astype(np.uint8)
)
final_image.save('shaass.png')


1024