In [None]:
from ConvNeXtUNet.convnextv2_unet import convnextv2unet_atto

In [None]:
# Three models with predefined sizes are provided; this demo uses the "atto" variant.
# NOTE(review): ms_output=True presumably makes the model return multi-scale outputs
# (later cells index the prediction with pred[-1]) — confirm against the library.
model = convnextv2unet_atto( ms_output=True)

Now, let's do a simple image-denoising task as a demonstration of the U-Net.

In [None]:
import torch
import matplotlib.pyplot as plt

In [None]:
# This function creates a synthetic image
# We can add noise to this to simulate a noisy image

import random
def create_random_pixel_image(channels, height, width):
    """
    Generates a batched image of zeros containing one randomly placed 2x2 square of ones.

    The square's position is drawn uniformly at random, and the same square
    is written into every channel.

    Args:
    channels: The number of channels of the image.
    height: The height of the image (number of rows); must be at least 2.
    width: The width of the image (number of columns); must be at least 2.

    Returns:
    A torch.Tensor of shape (1, channels, height, width) with dtype torch.float32.

    Raises:
    ValueError: If height or width is too small to hold the 2x2 square.
    """
    if height < 2 or width < 2:
        raise ValueError("Height and width must be at least 2 to fit the 2x2 square.")

    image = torch.zeros((channels, height, width), dtype=torch.float32)

    # Top-left corner of the square. randint is inclusive on both ends, so the
    # upper bound of (size - 2) keeps the second row/column inside the image.
    top = random.randint(0, height - 2)
    left = random.randint(0, width - 2)

    # Fill the 2x2 square in all channels at once.
    image[:, top:top + 2, left:left + 2] = 1.0

    # Prepend a batch dimension of size 1.
    return image.unsqueeze(0)

# Build one 3-channel 32x32 sample, then drop the batch dimension and move
# channels last so imshow receives an (H, W, C) array.
sample_image = create_random_pixel_image(channels=3, height=32, width=32)
plt.imshow(sample_image.squeeze(0).permute(1, 2, 0))

In [None]:
def add_noise(image, scale=0.5):
    """
    Returns a noisy copy of `image` with uniform random noise added.

    The noise is drawn from U[0, 1) and multiplied by `scale`, so every
    element is shifted up by a value in [0, scale). The input tensor is left
    unmodified, so callers keep their clean image.

    Args:
    image: A floating-point torch.Tensor of any shape (the notebook passes
        batched image tensors).
    scale: Multiplier applied to the uniform noise.

    Returns:
    A new torch.Tensor with the same shape, dtype and device as `image`.
    """
    # rand_like matches the input's shape, dtype and device, and the
    # out-of-place add avoids mutating the caller's tensor.
    return image + torch.rand_like(image) * scale

# Add noise at the default scale to the sample and display it, with the batch
# dimension dropped and channels moved last for imshow.
noisy_sample = add_noise(sample_image)
plt.imshow(noisy_sample.squeeze(0).permute(1, 2, 0))

Now, let's train our model.

In [None]:
def get_device():
    """
    Selects the torch device to run on.

    Preference order: CUDA, then MPS, then CPU.

    Returns:
    A torch.device of type 'cuda', 'mps' or 'cpu'.
    """
    if torch.cuda.is_available():
        return torch.device('cuda')
    # Use the torch.backends.mps availability check, guarded so that builds
    # without the MPS backend module fall through to the CPU.
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device('mps')
    return torch.device('cpu')
# Resolve the device once; later cells move the model and the batches onto it.
device = get_device()
# Bare expression so the notebook displays the selected device.
device

In [None]:
# These training parameters are not optimized and are merely meant to quickly demonstrate training.
from ConvNeXtUNet.losses.MultiScaleLoss import MultiScaleLoss

num_epochs = 10
num_batches = 500  # synthetic batches generated per epoch
batch_size = 3

lr = 2e-4

# Move the model to the target device before constructing the optimizer, as
# the torch.optim documentation recommends, so the optimizer is built over the
# parameters that will actually be trained.
model = model.to(device)

opt = torch.optim.Adam(model.parameters(), lr=lr)
# The scheduler is stepped once per epoch, so T_max=num_epochs makes a single
# cosine decay span the whole run.
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=num_epochs)
# NOTE(review): MultiScaleLoss presumably applies the criterion to each of the
# model's multi-scale outputs — confirm against the library.
loss_func = MultiScaleLoss(loss_criterion=torch.nn.MSELoss())

In [None]:
for epoch in range(num_epochs):
    # Running sum of the batch losses for this epoch, kept as a plain Python
    # float so that no autograd history is retained between iterations.
    total_loss_per_batch = 0.0
    for batch in range(num_batches):
        # Generate a fresh synthetic batch: clean targets plus a noisy copy.
        ground_truth = torch.cat([create_random_pixel_image(3, 32, 32)
                                  for _ in range(batch_size)])
        # Clone so the clean targets stay clean even if add_noise works in place.
        noisy_images = add_noise(ground_truth.clone())

        ground_truth = ground_truth.to(device)
        noisy_images = noisy_images.to(device)

        pred = model(noisy_images)
        loss = loss_func(pred, ground_truth)

        opt.zero_grad()
        loss.backward()
        opt.step()

        # .item() extracts the scalar value; accumulating the tensor itself
        # would keep a reference to every batch's graph.
        total_loss_per_batch += loss.item()
    # MSELoss averages over elements by default, so dividing by the number of
    # batches reports the mean batch loss for the epoch.
    print(f'Epoch {epoch}: Loss: {total_loss_per_batch/num_batches:.2e} LR: {sched.get_last_lr()[0]:.2e}')
    sched.step()

In [None]:
def format_tensor(tensor):
    """
    Prepares a 3-D tensor for plotting.

    Detaches the tensor from the autograd graph, copies it to the CPU and
    moves its first dimension to the end (the notebook passes (C, H, W)
    images, giving the (H, W, C) layout that imshow expects).
    """
    on_cpu = tensor.detach().cpu()
    channels_last = on_cpu.permute(1, 2, 0)
    return channels_last

# Compare the first image of the last training batch: noisy input, model
# output, clean target, and the difference between input and output.
fig, axes = plt.subplots(nrows=2, ncols=2)

noisy_img = format_tensor(noisy_images[0])
# NOTE(review): pred[-1] is assumed to be the full-resolution entry of the
# model's multi-scale output — confirm against the library.
pred_img = format_tensor(pred[-1][0])
gt_image = format_tensor(ground_truth[0])

panels = [
    ('Noisy image', noisy_img),
    ('Denoised image', pred_img),
    ('Ground truth', gt_image),
    ('Extracted noise', noisy_img - pred_img),
]

for ax, (title, img) in zip(axes.flat, panels):
    # imshow clips float RGB data to [0, 1] and warns about it; clamping
    # explicitly shows the same picture without the warning.
    ax.imshow(img.clamp(0, 1))
    ax.set_title(title)

plt.tight_layout()
plt.show()