In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import numpy as np
import os
import itertools
from tqdm import tqdm
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim

def pixelate_image(image, percent_pixels, rng=None):
    """
    Randomly zero out a fraction of the pixels of a single-channel image.

    The input array is never modified; all writes go to a private copy.

    :param image: numpy array of shape (H, W) or (1, H, W)
    :param percent_pixels: from 0 to 1, fraction of pixels to remove; exactly
        ``int(H * W * percent_pixels)`` distinct pixels are chosen
    :param rng: optional ``numpy.random.Generator`` for reproducible masks;
        when omitted the global ``np.random`` state is used
    :return: tuple of torch tensors, each of shape (H, W):
        the image with the chosen pixels set to 0,
        a uint8 mask that is 0 at removed pixels and 1 elsewhere,
        and an array holding the original values at removed pixels and 0 elsewhere
    :raises ValueError: if the image shape is unsupported or percent_pixels is outside [0, 1]
    """
    if image.ndim == 3 and image.shape[0] == 1:
        image = image.squeeze(0)
    elif image.ndim != 2:
        raise ValueError("Image must be of shape (H, W) or (1, H, W)")
    if not 0.0 <= percent_pixels <= 1.0:
        raise ValueError("percent_pixels must be between 0 and 1")

    # squeeze() returns a view, so writing into it would silently corrupt the
    # caller's array; work on a copy instead.
    image = image.copy()

    total_pixels = image.size
    num_pixels_to_remove = int(total_pixels * percent_pixels)
    sampler = np.random if rng is None else rng
    flat_indices = sampler.choice(total_pixels, num_pixels_to_remove, replace=False)
    indices_to_remove = np.unravel_index(flat_indices, image.shape)

    mask = np.ones_like(image, dtype=np.uint8)
    correct_pixels = np.zeros_like(image)
    correct_pixels[indices_to_remove] = image[indices_to_remove]
    image[indices_to_remove] = 0
    mask[indices_to_remove] = 0

    return torch.from_numpy(image), torch.from_numpy(mask), torch.from_numpy(correct_pixels)

class UNet(nn.Module):
    """Small convolutional encoder/decoder that maps a 1-channel image to a 1-channel image.

    Despite the name, ``forward`` chains encoder -> middle -> decoder with no skip
    connections. Two 2x max-pools are undone by two stride-2 transposed convolutions,
    so the output matches the input's spatial size only when H and W are divisible by 4.
    The last layers are BatchNorm2d(1) followed by ReLU, so outputs are non-negative.
    """

    def __init__(self):
        super().__init__()

        def conv_bn_relu(in_ch, out_ch):
            # 3x3 "same" convolution -> batch norm -> in-place ReLU
            return [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]

        def upsample_bn_relu(in_ch, out_ch):
            # 2x spatial upsampling -> batch norm -> in-place ReLU
            return [
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]

        # Layer order (and therefore Sequential indices / state_dict keys) is fixed so
        # that previously saved checkpoints keep loading.
        self.encoder = nn.Sequential(*conv_bn_relu(1, 64), *conv_bn_relu(64, 64), nn.MaxPool2d(2))
        self.middle = nn.Sequential(*conv_bn_relu(64, 128), *conv_bn_relu(128, 128), nn.MaxPool2d(2))
        self.decoder = nn.Sequential(
            *upsample_bn_relu(128, 64),
            *conv_bn_relu(64, 64),
            *upsample_bn_relu(64, 64),
            *conv_bn_relu(64, 1),
        )

    def forward(self, x):
        """Run ``x`` through encoder, middle and decoder in sequence."""
        return self.decoder(self.middle(self.encoder(x)))

class CombinedLoss(nn.Module):
    """Weighted sum of masked MSE and (1 - SSIM), restricted to corrupted pixels.

    ``loss = alpha * MSE + (1 - alpha) * (1 - SSIM)``. Both terms are computed on
    ``outputs * (1 - mask)`` and ``targets * (1 - mask)``, so only pixels whose mask
    value is 0 contribute non-zero values. The MSE is normalized by the number of such
    pixels across the whole batch.
    """

    def __init__(self, alpha=0.5):
        super(CombinedLoss, self).__init__()
        self.alpha = alpha
        self.mse_loss = nn.MSELoss(reduction='none')

    def gaussian_window(self, size, sigma):
        """Return a normalized 1-D Gaussian kernel of length ``size`` as a float32 tensor."""
        gauss = np.array([np.exp(-(x - size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(size)])
        return torch.tensor(gauss / gauss.sum(), dtype=torch.float32)

    def create_window(self, window_size, channel):
        """Return a (channel, 1, window_size, window_size) Gaussian window (sigma=1.5) for grouped conv."""
        _1D_window = self.gaussian_window(window_size, 1.5).unsqueeze(1)
        _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
        # A plain tensor suffices; torch.autograd.Variable is a deprecated no-op wrapper.
        return _2D_window.expand(channel, 1, window_size, window_size).contiguous()

    def ssim(self, img1, img2, window_size=11, size_average=True):
        """Gaussian-window SSIM between two (N, C, H, W) batches.

        :return: scalar mean SSIM if ``size_average`` else one value per batch element
        """
        (_, channel, _, _) = img1.size()
        # Match the input's device *and* dtype; a float32 window would make conv2d
        # fail for float64 / half inputs.
        window = self.create_window(window_size, channel).to(device=img1.device, dtype=img1.dtype)
        padding = window_size // 2

        mu1 = F.conv2d(img1, window, padding=padding, groups=channel)
        mu2 = F.conv2d(img2, window, padding=padding, groups=channel)

        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2

        sigma1_sq = F.conv2d(img1 * img1, window, padding=padding, groups=channel) - mu1_sq
        sigma2_sq = F.conv2d(img2 * img2, window, padding=padding, groups=channel) - mu2_sq
        sigma12 = F.conv2d(img1 * img2, window, padding=padding, groups=channel) - mu1_mu2

        # Stabilizing constants (K1=0.01, K2=0.03) with a data range of 1.
        C1 = 0.01 ** 2
        C2 = 0.03 ** 2

        ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

        if size_average:
            return ssim_map.mean()
        return ssim_map.mean(1).mean(1).mean(1)

    def forward(self, outputs, targets, mask):
        """Combined loss over pixels where ``mask == 0`` (mask is expected to be binary)."""
        removed = 1 - mask
        masked_outputs = outputs * removed
        masked_targets = targets * removed
        mse = self.mse_loss(masked_outputs, masked_targets)
        # clamp: a mask with no removed pixels would otherwise give 0 / 0 = NaN
        num_corrupted_pixels = torch.sum(removed).clamp(min=1)
        mse = mse.sum() / num_corrupted_pixels
        ssim_value = self.ssim(masked_outputs, masked_targets)
        return self.alpha * mse + (1 - self.alpha) * (1 - ssim_value)

class GrayScaleImageDataset(Dataset):
    """PNG images found recursively under a directory, loaded as grayscale and randomly pixelated.

    Every ``__getitem__`` call draws a fresh random mask via ``pixelate_image``, so the
    same index yields a different corruption each time it is read.
    """

    def __init__(self, image_dir, transform=None, subset_size=None, percent_pixels=0.2):
        """
        :param image_dir: root directory, walked recursively for files ending in ``.png`` (any case)
        :param transform: optional callable applied to the grayscale PIL image before pixelation
        :param subset_size: if truthy, keep only the first ``subset_size`` files (in sorted order)
        :param percent_pixels: fraction of pixels removed per sample, forwarded to ``pixelate_image``
        """
        self.image_dir = image_dir
        # os.walk order depends on the filesystem; sorting makes the file list, and
        # therefore the subset_size selection, reproducible across machines.
        self.image_files = sorted(
            os.path.join(dirpath, name)
            for dirpath, _dirnames, filenames in os.walk(image_dir)
            for name in filenames
            if name.lower().endswith('.png')
        )
        if subset_size:
            self.image_files = self.image_files[:subset_size]
        self.transform = transform
        self.percent_pixels = percent_pixels
        # Last mask drawn per index. NOTE: with DataLoader(num_workers > 0) this dict is
        # filled inside the worker processes, not in the main process's dataset object.
        self.masks = {}

    def __getitem__(self, index):
        """Return ``(image, pixelated_image, mask)``.

        ``image`` is whatever ``transform`` returned (the PIL image if no transform);
        ``pixelated_image`` and ``mask`` are float tensors of shape (1, H, W), with the
        mask 0 at removed pixels and 1 elsewhere.
        """
        img_path = self.image_files[index]
        # convert() produces an independent image, so the file can be closed right away.
        with Image.open(img_path) as source:
            img = source.convert('L')
        if self.transform:
            img = self.transform(img)
        pixelated_img, mask, _correct_pixels = pixelate_image(np.array(img), percent_pixels=self.percent_pixels)
        self.masks[index] = mask
        return img, pixelated_img.float().unsqueeze(0), mask.float().unsqueeze(0)

    def __len__(self):
        return len(self.image_files)

def _plot_reconstruction(original, corrupted, filled_reconstructed):
    """Show the original, pixelated and filled-reconstructed images side by side."""
    panels = (
        (original, 'Original Image'),
        (corrupted, 'Pixelated Image'),
        (filled_reconstructed, 'Filled Reconstructed Image'),
    )
    plt.figure(figsize=(15, 5))
    for position, (picture, title) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(picture, cmap='gray')
        plt.title(title)
        plt.axis('off')
    plt.show()


def _reconstruction_metrics(original, reconstructed, data_range):
    """Return ``(mse, psnr, ssim)`` of ``reconstructed`` against ``original``."""
    mse = np.mean((original - reconstructed) ** 2)
    # A perfect reconstruction has infinite PSNR; computing it would divide by zero.
    psnr = float('inf') if mse == 0 else 20 * np.log10(data_range / np.sqrt(mse))
    # A fixed data_range keeps SSIM comparable across images (and well-defined for a
    # constant image, where max - min would be 0).
    ssim_value = ssim(original, reconstructed, data_range=data_range)
    return mse, psnr, ssim_value


def evaluate_model(model, dataloader, device, num_samples=3, data_range=1.0):
    """
    Plot and score the model's inpainting on the first ``num_samples`` samples of ``dataloader``.

    For every sample, the known pixels are kept from the pixelated input and the model's
    prediction is used only where the mask is 0. The result is shown next to the original
    and the pixelated image, and MSE / PSNR / SSIM against the original are printed.

    :param model: network applied to the pixelated batch; output assumed (N, 1, H, W) — TODO confirm for other models
    :param dataloader: yields ``(originals, pixelated, masks)`` batches
    :param device: device the model lives on
    :param num_samples: number of samples to show and score
    :param data_range: range of valid pixel values (1.0 for ``ToTensor``-scaled images);
        used as the PSNR peak and as SSIM's ``data_range``
    """
    model.eval()
    sample_count = 0

    with torch.no_grad():
        for inputs, pixelated_imgs, masks in dataloader:
            outputs = model(pixelated_imgs.to(device))

            for original, corrupted, reconstructed, mask in zip(inputs, pixelated_imgs, outputs, masks):
                if sample_count >= num_samples:
                    return

                original = original.cpu().squeeze().numpy()
                corrupted = corrupted.cpu().squeeze().numpy()
                reconstructed = reconstructed.cpu().squeeze(0).numpy()
                mask = mask.cpu().squeeze().numpy()

                # apply the reconstructed pixels only where the mask is 0 (where pixels were removed)
                removed = mask == 0
                filled_reconstructed = corrupted.copy()
                filled_reconstructed[removed] = reconstructed[removed]

                print(f'Sample {sample_count + 1}')
                _plot_reconstruction(original, corrupted, filled_reconstructed)

                mse, psnr, ssim_value = _reconstruction_metrics(original, filled_reconstructed, data_range)
                print(f'MSE: {mse:.4f}, PSNR: {psnr:.4f}, SSIM: {ssim_value:.4f}')

                sample_count += 1

            if sample_count >= num_samples:
                return

def _train_one_epoch(model, loader, criterion, optimizer, device, desc):
    """Run one optimization pass over ``loader`` and return the mean per-batch loss."""
    # Validation leaves the model in eval mode, so training mode must be restored every
    # epoch; otherwise BatchNorm would normalize with (and stop updating) running stats.
    model.train()
    running_loss = 0.0
    progress_bar = tqdm(loader, desc=desc, unit='batch')
    for batch_idx, (inputs, pixelated_imgs, masks) in enumerate(progress_bar, start=1):
        inputs = inputs.to(device)
        pixelated_imgs = pixelated_imgs.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        outputs = model(pixelated_imgs)
        # loss only for masked pixels
        loss = criterion(outputs, inputs, masks)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # running mean over the batches processed so far
        progress_bar.set_postfix(loss=running_loss / batch_idx)
    return running_loss / max(len(loader), 1)


def _validate(model, loader, criterion, device):
    """Return the mean per-batch ``criterion`` value over ``loader`` in eval mode, without gradients."""
    if len(loader) == 0:
        raise ValueError("validation loader yields no batches")
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for inputs, pixelated_imgs, masks in loader:
            outputs = model(pixelated_imgs.to(device))
            total_loss += criterion(outputs, inputs.to(device), masks.to(device)).item()
    return total_loss / len(loader)


def train_model(learning_rate, optimizer, activation_function, batch_size, early_stopping_patience=3, num_epochs=10):
    """
    Train a fresh ``UNet`` with ``CombinedLoss`` and keep the best validation checkpoint.

    Uses the module-level globals ``device``, ``train_dataset`` and ``val_dataset``.

    :param learning_rate: initial learning rate (reduced 10x by ReduceLROnPlateau after 5 flat epochs)
    :param optimizer: optimizer class, e.g. ``optim.Adam``; it is instantiated here
    :param activation_function: only recorded in the returned hyperparameters; ``UNet`` hard-codes ReLU
    :param batch_size: batch size for the training and validation loaders
    :param early_stopping_patience: stop after this many consecutive epochs without improvement
    :param num_epochs: maximum number of epochs
    :return: ``(best_val_loss, best_state_dict, best_hyperparams)``. ``best_state_dict`` is an
        independent copy taken at the best epoch, and ``best_hyperparams`` is
        ``(learning_rate, optimizer_instance, activation_function, batch_size)``. Both are
        ``None`` if no epoch improved on ``inf`` (e.g. NaN losses).
    :raises ValueError: if the validation loader yields no batches
    """
    import copy  # local: only needed to snapshot the best checkpoint

    model = UNet().to(device)
    criterion = CombinedLoss(alpha=0.5)
    optimizer = optimizer(model.parameters(), lr=learning_rate)
    # The scheduler's ``verbose`` flag is deprecated; LR reductions are reported below instead.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

    best_val_loss = float('inf')
    best_hyperparams = None
    best_model = None
    epochs_no_improve = 0

    for epoch in range(num_epochs):
        train_loss = _train_one_epoch(
            model, train_loader, criterion, optimizer, device, desc=f'Epoch {epoch + 1}/{num_epochs}'
        )
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {train_loss:.4f}')

        val_loss = _validate(model, val_loader, criterion, device)
        print(f'Validation Loss: {val_loss:.4f}')

        lrs_before = [group['lr'] for group in optimizer.param_groups]
        scheduler.step(val_loss)
        for group_idx, (old_lr, group) in enumerate(zip(lrs_before, optimizer.param_groups)):
            if group['lr'] != old_lr:
                print(f'Reducing learning rate of group {group_idx} to {group["lr"]:.4e}.')

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # state_dict() returns references to the live parameters; copy them so that
            # later epochs do not overwrite the saved "best" checkpoint.
            best_model = copy.deepcopy(model.state_dict())
            best_hyperparams = (learning_rate, optimizer, activation_function, batch_size)
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= early_stopping_patience:
            print("Early stopping!")
            break

    return best_val_loss, best_model, best_hyperparams

# ToTensor scales 8-bit grayscale images to float tensors in [0, 1] with shape (1, H, W).
transform_list = [transforms.ToTensor()]
transform = transforms.Compose(transform_list)

image_dir = 'data/training_preprocessed'
val_dir = 'data/val_data'

train_dataset = GrayScaleImageDataset(image_dir=image_dir, transform=transform)
val_dataset = GrayScaleImageDataset(image_dir=val_dir, transform=transform)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Grid of hyperparameters; every combination is trained once.
hyperparams = {
    'learning_rate': [0.001, 0.01],
    'optimizer': [optim.Adam],
    'activation_function': [F.relu],
    'batch_size': [8, 16]
}

hyperparam_combinations = list(itertools.product(*hyperparams.values()))

# Overall best across the search: (lr, optimizer class, activation, batch size).
best_loss = float('inf')
best_hyperparams = None
best_model = None

os.makedirs("models", exist_ok=True)
os.makedirs("plots", exist_ok=True)

for idx, combination in enumerate(hyperparam_combinations):
    lr, opt, act_func, batch_size = combination
    model_name = f"models/model_v_20{idx + 1}.pth"
    plot_name = f"plots/results_v_20{idx + 1}.png"  # currently unused
    print(f'Training with learning_rate={lr}, optimizer={opt.__name__}, activation_function={act_func.__name__}, batch_size={batch_size}')

    # Keep the per-run hyperparameters in their own name: assigning them to
    # best_hyperparams would overwrite the overall best with the latest run.
    val_loss, model_state, _run_hyperparams = train_model(lr, opt, act_func, batch_size, early_stopping_patience=10)

    if model_state is None:
        # train_model found no epoch that improved on inf (e.g. NaN losses).
        print(f'No usable checkpoint for combination {idx + 1}; skipping save and evaluation.')
        continue

    torch.save(model_state, model_name)

    if val_loss < best_loss:
        best_loss = val_loss
        best_hyperparams = combination
        best_model = model_state

    model = UNet().to(device)
    model.load_state_dict(model_state)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    evaluate_model(model, val_loader, device, num_samples=3)

In [None]:
# Report, save and visualize the best configuration found by the search above.
if best_hyperparams and best_model is not None:
    best_lr, best_opt, best_act, best_bs = best_hyperparams
    # The search records the optimizer *class*, while train_model's tuple holds an
    # optimizer *instance*; name both correctly (a class's __class__ would be `type`).
    opt_name = best_opt.__name__ if isinstance(best_opt, type) else type(best_opt).__name__
    print(f'Best Hyperparameters: learning_rate={best_lr}, optimizer={opt_name}, activation_function={best_act.__name__}, batch_size={best_bs}')
    print(f'Best Validation Loss: {best_loss:.4f}')

    torch.save(best_model, "models/model_64.pth")

    model = UNet().to(device)
    model.load_state_dict(best_model)
    val_loader = DataLoader(val_dataset, batch_size=best_bs, shuffle=False)
    evaluate_model(model, val_loader, device, num_samples=3)
else:
    print("No valid hyperparameter combination found.")