# Import Libraries

In [None]:
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.utils import make_grid, save_image

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

# Load Data

In [None]:
# Path of the input image (relative to the current working directory).
image_path = "car-cartoon.jpg"

In [None]:
# Load the picture and force three 8-bit RGB channels: the rest of the
# notebook (and the network's first convolution) assumes exactly 3 channels,
# which grayscale, palette, RGBA or CMYK files would not provide.
image = Image.open(image_path).convert("RGB")

In [None]:
# Pixel dimensions of the loaded picture.
w = image.width
h = image.height

In [None]:
# Convert the PIL image to a float tensor with values in [0, 1].
# No Normalize step is applied: the noisy target is clipped to [0, 1] and the
# network ends in a sigmoid, so the image must stay in that range. Mapping it
# to [-1, 1] would make the clip discard every value below zero.
transform = transforms.Compose([
    # The target size equals the source size, so the resolution is unchanged.
    transforms.Resize((h, w), interpolation=transforms.InterpolationMode.LANCZOS),
    transforms.ToTensor(),
])

In [None]:
# PIL image -> (C, H, W) tensor -> channels-last (H, W, C) -> add a leading
# batch axis, giving (1, H, W, C).
image = transform(image).permute(1, 2, 0).unsqueeze(0)

In [None]:
# Inspect the tensor layout; the preceding cell produces (1, H, W, C).
image.shape

In [None]:
# Preview the single batch element; it is stored channels-last, so it can be
# handed to imshow without rearranging axes.
plt.imshow(image[0])

In [None]:
# Add Gaussian noise (std 0.1), clamp to [0, 1], then rearrange from
# channels-last (N, H, W, C) to channels-first (N, C, H, W) for the network.
noise = torch.randn_like(image) * 0.1
noisy = torch.clamp(image + noise, 0, 1)
corrupted_image = noisy.permute(0, 3, 1, 2)

In [None]:
# Bring the noisy image back to channels-last (H, W, C) as a NumPy array
# so that matplotlib can display it.
noisy_preview = corrupted_image[0].detach().permute(1, 2, 0).numpy()
plt.imshow(noisy_preview)

In [None]:
# Move the noisy target onto the selected device (the network and its input
# are placed on the same device further down).
corrupted_image = corrupted_image.to(device)

In [None]:
# Fixed network input: Gaussian noise (std 0.1) with the same shape as the
# target, sampled on the CPU and then moved to the training device.
noise_shape = corrupted_image.shape
z = (torch.randn(noise_shape) * 0.1).to(device)

# Deep Image Prior

In [None]:
class DeepImagePrior(nn.Module):
    """Encoder-decoder CNN used as a deep image prior.

    The encoder is five stride-2 convolutional blocks
    (3 -> 8 -> 16 -> 32 -> 64 -> 128 channels). The decoder is five blocks
    that each end in a 2x bilinear upsample. Two 1x1 "skip" branches tap the
    encoder after its third and fourth blocks; each yields 4 channels that
    are centre-cropped and concatenated onto the input of the first two
    decoder blocks. A final 1x1 convolution followed by a sigmoid produces a
    3-channel output with values in [0, 1].

    The output has the same spatial size as the input when the input height
    and width are multiples of 32.
    """

    def __init__(self):
        super().__init__()

        # Encoder: every block halves the spatial resolution.
        self.down_blocks = nn.ModuleList([
            self._down_block(3, 8, 3),
            self._down_block(8, 16, 3),
            self._down_block(16, 32, 3),
            self._down_block(32, 64, 3),
            self._down_block(64, 128, 3)
        ])

        # 1x1 projections of the 32- and 64-channel encoder outputs to 4
        # channels each.
        self.skip_blocks = nn.ModuleList([
            self._skip_block(32, 4, 1),
            self._skip_block(64, 4, 1)
        ])

        # Decoder: the first two blocks receive 4 extra skip channels.
        self.up_blocks = nn.ModuleList([
            self._up_block(128 + 4, 128, 3),
            self._up_block(128 + 4, 64, 3),
            self._up_block(64, 32, 3),
            self._up_block(32, 16, 3),
            self._up_block(16, 8, 3)
        ])

        self.conv_out = nn.Conv2d(8, 3, 1, stride=1, padding=0, bias=True)

    def _down_block(self, in_channels, out_channels, kernel_size):
        """Build a stride-2 conv followed by a stride-1 conv, each with BN + LeakyReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride=2, padding=1, bias=True),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True)
        )

    def _skip_block(self, in_channels, out_channels, kernel_size):
        """Build an unpadded conv + BN + LeakyReLU used on a skip branch."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True)
        )

    def _up_block(self, in_channels, out_channels, kernel_size):
        """Build BN, a padded conv, a 1x1 conv (each with BN + LeakyReLU) and a 2x upsample."""
        return nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
        )

    @staticmethod
    def _center_crop(feature_map, height, width):
        """Return the central ``height`` x ``width`` window of an (N, C, H, W) tensor."""
        top = (feature_map.shape[-2] - height) // 2
        left = (feature_map.shape[-1] - width) // 2
        return feature_map[:, :, top:top + height, left:left + width]

    def forward(self, x):
        """Map an (N, 3, H, W) input tensor to an (N, 3, H', W') image in [0, 1].

        Skip features are centre-cropped to the spatial size of the decoder
        tensor they are concatenated with. The crop is derived from the
        tensor shapes, so the network is not tied to one input resolution.
        """
        # Encoder block index -> index of the skip branch fed by its output.
        skip_after = {2: 0, 3: 1}

        skip_connections = []
        for i, down in enumerate(self.down_blocks):
            x = down(x)
            if i in skip_after:
                skip_connections.append(self.skip_blocks[skip_after[i]](x))

        # Deepest skip first: it is the one closest in size to the bottleneck.
        skip = self._center_crop(skip_connections[1], x.shape[-2], x.shape[-1])
        x = self.up_blocks[0](torch.cat((skip, x), dim=1))

        skip = self._center_crop(skip_connections[0], x.shape[-2], x.shape[-1])
        x = self.up_blocks[1](torch.cat((skip, x), dim=1))

        x = self.up_blocks[2](x)
        x = self.up_blocks[3](x)
        x = self.up_blocks[4](x)

        return torch.sigmoid(self.conv_out(x))

In [None]:
# Build the network and place its parameters on the selected device.
deep_image_prior = DeepImagePrior().to(device)
deep_image_prior

# Train

In [None]:
# Learning rate used for training.
lr = 0.01

In [None]:
# Adam over all network parameters. The step size comes from `lr` rather
# than a duplicated literal, so changing `lr` actually takes effect.
optimizer = optim.Adam(deep_image_prior.parameters(), lr=lr)

In [None]:
# Number of optimisation steps; each "epoch" in the training loop is one
# gradient update.
num_epochs = 2500

In [None]:
# Loss value recorded at every training step, used for the loss curve.
losses = []

In [None]:
# Fit the network so that its output for the fixed noise input `z` matches
# the corrupted image.
for epoch in range(1, num_epochs + 1):
    # Call the module itself rather than .forward() so that any registered
    # hooks run as part of the forward pass.
    predicted_image = deep_image_prior(z)
    loss = F.mse_loss(predicted_image, corrupted_image)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Read the scalar once; it is both recorded and (periodically) printed.
    loss_value = loss.item()
    losses.append(loss_value)
    if epoch % 100 == 0:
        print(f"[Epoch {epoch}/{num_epochs}], [Loss: {loss_value:.4f}]")

# Results

In [None]:
# Training-loss history, one point per optimisation step.
loss_fig, loss_ax = plt.subplots(figsize=(8, 6))
loss_ax.plot(losses)
loss_ax.set_title("Loss Curve")
plt.show()

In [None]:
# Side-by-side comparison of the clean, noisy and reconstructed images.
fig, axes = plt.subplots(ncols=3, nrows=1, figsize=(15, 5))

# `image` is already channels-last. The other two are (N, C, H, W) tensors
# that may live on the GPU and carry gradients, so they are detached, moved
# to the CPU and permuted to (H, W, C) explicitly before plotting.
panels = [
    ("Original Image", image[0]),
    ("Corrupted Image", corrupted_image[0].detach().cpu().permute(1, 2, 0).numpy()),
    ("Predicted Image", predicted_image[0].detach().cpu().permute(1, 2, 0).numpy()),
]

for ax, (title, picture) in zip(axes, panels):
    ax.set_title(title)
    ax.imshow(picture)
    ax.axis("off")

plt.show()