In [13]:
import os
import math
import imageio

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import io
from matplotlib import pyplot as plt

In [14]:
# Photograph the network is fitted to.
IMAGE_PATH = 'fossil_hill_trail25.jpg'
# Prefix for rendered images; it is spliced straight into f-string paths,
# so the trailing slash is part of the value.
OUTPUT_DIR = "out/"

# Make sure the output directory exists before anything tries to save into it.
os.makedirs(name=OUTPUT_DIR, exist_ok=True)

def load_image(image_path, device):
    """Load an image file as a float RGB tensor in (H, W, C) layout.

    Args:
        image_path: Path to an image file readable by
            ``torchvision.io.read_image``.
        device: Device (string or ``torch.device``) the tensor is moved to.

    Returns:
        Float tensor of shape (H, W, 3) with values scaled to [0, 1].
    """
    # Force 3-channel decoding: the rest of the notebook reshapes pixels to
    # (-1, 3), so grayscale (1 channel) or RGBA (4 channels) files would
    # otherwise break or be silently misinterpreted.
    image = io.read_image(image_path, mode=io.ImageReadMode.RGB).float() / 255.0
    # (C, H, W) -> (H, W, C); materialise the permuted layout so later
    # reshapes/views do not depend on the original channel-first strides.
    image = image.permute(1, 2, 0).contiguous()
    return image.to(device)

In [15]:
def positional_encoding(x, L):
    """Append sinusoidal features at L octave-spaced frequencies to ``x``.

    For every component ``v`` of the last axis the features are
    ``sin(2*pi * 2**k * v)`` for k = 0..L-1 followed by
    ``cos(2*pi * 2**k * v)`` for k = 0..L-1.

    Args:
        x: Tensor of shape (..., D).
        L: Number of frequency bands. ``L == 0`` returns a copy of ``x``
            with no extra features.

    Returns:
        Tensor of shape (..., D + 2 * L * D): the raw ``x`` followed by the
        per-component sin/cos features.
    """
    frequencies = 2.0 ** torch.arange(L, device=x.device).float()
    # (..., D, 1) * (L,) broadcasts to (..., D, L).
    x_in = x.unsqueeze(-1) * frequencies * 2 * torch.pi

    # (..., D, 2L): all sines first, then all cosines, per component.
    encoding = torch.cat([torch.sin(x_in), torch.cos(x_in)], dim=-1)

    # Spell the flattened width out instead of using -1: with L == 0 the
    # tensor is empty and an inferred dimension would be ambiguous.
    flat_width = x.shape[-1] * 2 * L
    encoding = encoding.reshape(*x.shape[:-1], flat_width)

    return torch.cat([x, encoding], dim=-1)

def psnr(image1, image2):
    """Peak signal-to-noise ratio (dB) between two arrays, with peak value 1.0.

    Args:
        image1: Array-like supporting subtraction with ``image2``.
        image2: Array-like broadcastable against ``image1``.

    Returns:
        ``20 * log10(1 / sqrt(MSE))``, or the sentinel ``100`` when the two
        inputs are identical (MSE of exactly zero).
    """
    squared_error = (image1 - image2) ** 2
    mse = np.mean(squared_error)

    if mse != 0:
        rmse = math.sqrt(mse)
        return 20 * math.log10(1.0 / rmse)

    # Identical inputs: avoid dividing by zero and report a fixed ceiling.
    return 100


In [16]:
def predict_image(model, image_shape, device):
    """Render a full image by querying ``model`` at every pixel coordinate.

    Coordinates are normalised as ``column / w`` and ``row / h`` and passed
    to the model as one (h * w, 2) batch ordered row by row.

    Args:
        model: Callable mapping an (N, 2) coordinate tensor to an (N, 3)
            tensor. If it exposes ``training``/``eval()``/``train()`` (as
            ``nn.Module`` does) it is evaluated in eval mode and then put
            back into the mode it was in.
        image_shape: ``(h, w)`` of the image to render.
        device: Device on which the coordinate grid is created.

    Returns:
        NumPy array of shape (h, w, 3) holding the model outputs.
    """
    h, w = image_shape
    # Build the normalised axes directly on the target device.
    xs = torch.linspace(0, w - 1, w, device=device) / w
    ys = torch.linspace(0, h - 1, h, device=device) / h
    all_coords = torch.stack(
        [xs.expand(h, w), ys.unsqueeze(1).expand(h, w)], dim=-1
    ).reshape(-1, 2)

    # Inference must not run in train mode (dropout / batch-norm would
    # behave differently), but the caller's mode has to survive this call.
    was_training = getattr(model, "training", False)
    if was_training:
        model.eval()
    try:
        with torch.no_grad():
            predicted_pixels = model(all_coords)
    finally:
        if was_training:
            model.train()

    return predicted_pixels.reshape(h, w, 3).cpu().numpy()


class MLP(nn.Module):
    """Coordinate network: positional encoding -> ReLU MLP -> sigmoid colour.

    Args:
        L: Number of positional-encoding frequency bands. The first linear
            layer takes ``2 * 2 * L + 2`` inputs (two coordinates, each with
            L sines and L cosines, plus the two raw coordinates).
        hidden_dim: Width of every hidden layer.
        num_hidden_layers: Number of ``Linear -> ReLU`` pairs before the
            output layer.

    The defaults build the same layers, in the same order, as the original
    fixed architecture (three hidden layers of width 256).
    """

    def __init__(self, L=10, hidden_dim=256, num_hidden_layers=3):
        super().__init__()
        if hidden_dim < 1:
            raise ValueError("hidden_dim must be at least 1")
        if num_hidden_layers < 0:
            raise ValueError("num_hidden_layers must be non-negative")

        self.L = L

        # Layer widths from the encoded input through every hidden layer.
        widths = [2 * 2 * L + 2] + [hidden_dim] * num_hidden_layers
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.append(nn.Linear(fan_in, fan_out))
            modules.append(nn.ReLU())
        # Three output channels squashed into (0, 1).
        modules.append(nn.Linear(widths[-1], 3))
        modules.append(nn.Sigmoid())

        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Encode the input coordinates and map them through the layers."""
        x = positional_encoding(x, L=self.L)
        return self.layers(x)


def train_model(model, image, optimizer, criterion, iters=3200, batch_size=10000, device='cuda'):
    """Fit ``model`` to ``image`` by regressing pixel colours from coordinates.

    Every iteration draws ``batch_size`` random pixels (with replacement) and
    takes one optimizer step. Every 100 iterations the full image is
    rendered once, its PSNR against ``image`` is printed and recorded, and
    the render is saved as ``{OUTPUT_DIR}iter<N>.jpg``. After training, PSNR
    and loss curves are written to ``psnr.png`` and ``loss.png`` in the
    working directory and the final render to ``{OUTPUT_DIR}final.png``.

    Args:
        model: Network mapping (N, 2) normalised coordinates to (N, 3) colours.
        image: Tensor whose first two dims are (h, w) and which holds three
            values per pixel.
        optimizer: Optimizer already bound to ``model``'s parameters.
        criterion: Loss called as ``criterion(prediction, target)``.
        iters: Number of optimisation steps.
        batch_size: Pixels sampled per step.
        device: Device for the coordinate grid, targets and sampled indices.

    Returns:
        None. The model is left in train mode, as before.
    """
    psnr_scores = []
    loss_values = []

    h, w = image.shape[:2]

    # Normalize coordinates
    x = torch.linspace(0, w - 1, w).repeat(h, 1).to(device) / w
    y = torch.linspace(0, h - 1, h).repeat(w, 1).transpose(0, 1).to(device) / h
    all_coords = torch.stack([x, y], dim=-1).view(-1, 2)
    # reshape (not view) so this works whatever stride layout the caller's
    # image tensor has.
    all_pixels = image.reshape(-1, 3).to(device)

    # The ground truth never changes: copy it to the host once instead of
    # at every checkpoint.
    target_np = image.cpu().numpy()

    model.train()
    for iteration in range(iters):
        # Sample random batch of pixels and their coordinates
        idx = torch.randint(0, h * w, (batch_size,), device=device)
        coords_batch = all_coords[idx]
        pixel_batch = all_pixels[idx]

        optimizer.zero_grad()
        outputs = model(coords_batch)
        loss = criterion(outputs, pixel_batch)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        loss_values.append(loss_value)

        if (iteration + 1) % 100 == 0:
            print(f"Iteration {iteration + 1}/{iters}, Loss: {loss_value:.6f}")

            # Render the full image once, in eval mode, and use that single
            # render for both the PSNR and the saved snapshot.
            model.eval()
            predicted_image = predict_image(model, (h, w), device)
            model.train()

            curr_psnr = psnr(target_np, predicted_image)
            print(f"PSNR: {curr_psnr:.2f} dB")
            psnr_scores.append(curr_psnr)

            # Save intermediate outputs
            plt.imsave(f"{OUTPUT_DIR}iter{iteration + 1}.jpg", predicted_image)

    # Figures are deliberately left open so the notebook still shows them
    # inline at the end of the cell.
    plt.figure()
    plt.plot(range(100, iters + 1, 100), psnr_scores)
    plt.xlabel('Iteration')
    plt.ylabel('PSNR (dB)')
    plt.title('PSNR vs. Iteration')
    plt.savefig('psnr.png')

    plt.figure()
    plt.plot(range(1, iters + 1), loss_values)
    plt.xlabel('Iteration')
    plt.ylabel('Loss')
    plt.title('Loss vs. Iteration')
    plt.savefig('loss.png')

    # Final render in eval mode, then restore train mode for the caller.
    model.eval()
    final_image = predict_image(model, (h, w), device)
    model.train()
    plt.imsave(f"{OUTPUT_DIR}final.png", final_image)

    print(f"Final PSNR: {psnr(target_np, final_image):.2f} dB")


In [17]:
# Training hyper-parameters.
BATCH_SIZE = 10000
ITERS = 3200
LEARNING_RATE = 1e-2
# Seed for weight initialisation and random pixel sampling.
SEED = 0

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed before the model is built so both its initial weights and the batches
# drawn during training come from a fixed random stream on re-run.
# NOTE(review): some GPU kernels may still be non-deterministic.
torch.manual_seed(SEED)

image = load_image(IMAGE_PATH, device)

model = MLP().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)


In [18]:
# Run the optimisation with the hyper-parameters configured in the previous cell.
train_model(
    model,
    image,
    optimizer,
    criterion,
    iters=ITERS,
    batch_size=BATCH_SIZE,
    device=device,
)

Iteration 100/3200, Loss: 0.019793
PSNR: 16.92 dB
Iteration 200/3200, Loss: 0.019948
PSNR: 17.17 dB
Iteration 300/3200, Loss: 0.018185
PSNR: 17.31 dB
Iteration 400/3200, Loss: 0.017625
PSNR: 17.41 dB
Iteration 500/3200, Loss: 0.017541
PSNR: 17.52 dB
Iteration 600/3200, Loss: 0.018477
PSNR: 17.60 dB
Iteration 700/3200, Loss: 0.017526
PSNR: 17.61 dB
Iteration 800/3200, Loss: 0.018052
PSNR: 17.68 dB


KeyboardInterrupt: 