In [1]:
def setup_collab():
    %pip install torchmetrics
    
    # Download the dataset
    !wget https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/BSDS300-images.tgz
    !tar -xvzf BSDS300-images.tgz
    !rm BSDS300-images.tgz
    
    COLLAB_ININITALIZED = True

In [2]:
# Check whether we are running on Google Colab. If so, install dependencies and
# download the dataset exactly once per kernel session.
try:
    import google.colab  # noqa: F401  -- imported only to detect Colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    print("In Collab")
    # The flag does not exist until setup has completed once, so look it up
    # with a default instead of referencing the (possibly undefined) name.
    if globals().get("COLLAB_INITIALIZED", False):
        print("Colab Already Initialized")
    else:
        print("Initializing Colab")
        setup_collab()
        # Set here as well, so re-running this cell skips setup even if
        # setup_collab itself does not set the flag.
        COLLAB_INITIALIZED = True
        print("Colab Initialized")
else:
    print("Outside Collab")

Outside Collab


In [3]:
import os
import torch
import torchvision
from torch.utils.data import DataLoader, Dataset
import tqdm  # for nice progress bars
from matplotlib import pyplot as plt
from torchmetrics.functional import peak_signal_noise_ratio, structural_similarity_index_measure

from torch import nn
from unet import Unet

from datasets import NoisyCleantDataset, NoisyNoisyDataset

# Dataset

In [4]:
dataset_dir = "./BSDS300"

# os.listdir returns entries in arbitrary, filesystem-dependent order. Sort them
# so the file lists, and in particular the validation/test split below, are
# identical on every machine and every run.
train_set_dir = f"{dataset_dir}/images/train"
train_img_files = [f"{train_set_dir}/{filename}" for filename in sorted(os.listdir(train_set_dir))]
# use this to train with fewer data (requires `import random`; seed it for reproducibility)
# train_img_files = random.sample(train_img_files, 50)

test_set_dir = f"{dataset_dir}/images/test"
test_img_files = [f"{test_set_dir}/{filename}" for filename in sorted(os.listdir(test_set_dir))]
# The first 50 (sorted) test images are held out for validation; the rest form the test set.
val_img_files = test_img_files[:50]
test_img_files = test_img_files[50:]

In [5]:
# more noise makes denoising harder; we suggest you keep this value but you can also experiment with more or less noise
noise_var = 0.015
# depends on your hardware; larger chunks require more memory during gradient computation; we recommend 128
train_chunk_size = 128

# Both training sets draw on the same images with the same noise level and chunking.
train_set_supervised = NoisyCleantDataset(
    img_files=train_img_files,
    noise_var=noise_var,
    chunk_size=train_chunk_size,
)
train_set_sel_supervised = NoisyNoisyDataset(
    img_files=train_img_files,
    noise_var=noise_var,
    chunk_size=train_chunk_size,
)

# Validation and testing compute no gradients, so whole images are used rather than small chunks.
# BSDS300 images are (321, 481) or (481, 321); a chunk size of 321 crops them to (321, 321),
# giving every example the same shape for batching.
val_set = NoisyCleantDataset(
    img_files=val_img_files,
    noise_var=noise_var,
    chunk_size=321,
)
test_set = NoisyCleantDataset(
    img_files=test_img_files,
    noise_var=noise_var,
    chunk_size=321,
)

# Training

In [6]:
def train_model(model, train_loader, val_loader, optimizer, loss_fn_main, extra_loss_functions_validation, max_epochs, check_val_every_epochs, limit_train_length=None):
    """Train a residual denoiser, validating and checkpointing along the way.

    The model output is subtracted from its input, so the denoised image is
    ``imgs_noisy - model(imgs_noisy)``. ``loss_fn_main`` compares it with the
    target image from ``train_loader``. Every ``check_val_every_epochs``
    epochs, and once more after training, the model is validated with
    ``validate_model`` and a checkpoint is written to
    ``checkpoint_{model.name}.pth``. On return, ``model`` holds the parameters
    with the lowest main validation loss.

    Args:
        model: network to train. Must have a ``name`` attribute, which is used
            in the checkpoint file name. Batches are moved to the device of
            its parameters.
        train_loader: iterable of ``(imgs_noisy, imgs_target)`` tensor batches.
        val_loader: iterable of ``(imgs_noisy, imgs_clean)`` batches, passed to
            ``validate_model``.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn_main: loss minimized during training. Its validation value
            (lower is better) decides which parameters are kept as "best".
        extra_loss_functions_validation: dict ``name -> metric``. These are
            evaluated only during validation and recorded in the history.
        max_epochs: maximum number of passes over ``train_loader``.
        check_val_every_epochs: validate and checkpoint every this many
            epochs. A falsy value validates only once, after training.
        limit_train_length: optional cap on the TOTAL number of training
            examples used across all epochs. The batch that crosses the cap is
            truncated, and training stops once the cap is reached.
    """
    import copy

    # Put every batch on the same device as the model.
    device = next(model.parameters()).device

    best_validation_loss = float("inf")
    best_model_parameters = None
    history_main_validation_losses = []
    history_additional_validation_losses = {
        key: [] for key in extra_loss_functions_validation.keys()
    }

    def _train_step(imgs_noisy, imgs_target):
        # One optimization step on a single batch.
        optimizer.zero_grad()
        imgs_noisy = imgs_noisy.to(device)
        imgs_target = imgs_target.to(device)
        imgs_denoised = imgs_noisy - model(imgs_noisy)  # residual formulation
        loss = loss_fn_main(imgs_denoised, imgs_target)
        loss.backward()
        optimizer.step()

    def _validate_and_checkpoint(when, checkpoint_epoch):
        # Validate, record the history, remember the best parameters, and save a checkpoint.
        nonlocal best_validation_loss, best_model_parameters
        main_loss, additional_losses = validate_model(model, val_loader, loss_fn_main, extra_loss_functions_validation)
        history_main_validation_losses.append(main_loss)
        for key, value in additional_losses.items():
            history_additional_validation_losses[key].append(value)

        print(f"Validation loss (main function) after {when}: {main_loss}")
        print(f"Validation losses (additional functions) after {when}: {additional_losses}")

        if main_loss < best_validation_loss:
            best_validation_loss = main_loss
            # state_dict() returns references to the live parameter tensors, and
            # optimizer.step() keeps updating those in place. Store a snapshot,
            # otherwise the "best" parameters silently become the latest ones.
            best_model_parameters = copy.deepcopy(model.state_dict())

        torch.save({
            "epoch": checkpoint_epoch,
            "model_state_dict": model.state_dict(),
            "best_model_state_dict": best_model_parameters,
            "optimizer_state_dict": optimizer.state_dict(),
            "history_main_validation_loss": history_main_validation_losses,
            "history_additional_validation_losses": history_additional_validation_losses,
        }, f"checkpoint_{model.name}.pth")
        print("Checkpoint saved")

    no_of_training_examples = 0
    epochs_run = 0
    limit_reached = False

    print("Starting training")
    for e in range(max_epochs):
        # Ensure training-mode behaviour for layers such as dropout or batch-norm, if the model has any.
        model.train()
        for imgs_noisy, imgs_clean in tqdm.tqdm(train_loader, desc="Training"):
            if limit_train_length is not None:
                remaining = limit_train_length - no_of_training_examples
                if remaining <= 0:
                    limit_reached = True
                    break
                # Truncate the batch that would cross the limit (a no-op for full batches).
                imgs_noisy = imgs_noisy[:remaining]
                imgs_clean = imgs_clean[:remaining]
                no_of_training_examples += imgs_noisy.shape[0]

            _train_step(imgs_noisy, imgs_clean)

            if limit_train_length is not None and no_of_training_examples >= limit_train_length:
                limit_reached = True
                break

        epochs_run = e + 1
        if limit_reached:
            # Leave the epoch loop as well; further epochs would only re-validate an unchanged model.
            print("Preemptively stopping training because limit_train_length was reached")
            break

        if check_val_every_epochs and (e + 1) % check_val_every_epochs == 0:
            _validate_and_checkpoint(f"epoch {e}", e)

    print("Training finished")

    # one final evaluation at the end of training
    _validate_and_checkpoint("training", epochs_run)

    # load best model parameters
    if best_model_parameters is not None:
        model.load_state_dict(best_model_parameters)
    else:
        # Only possible if no validation loss was below infinity (e.g. all NaN).
        print("No finite validation loss was recorded; keeping the final parameters")

def validate_model(model, val_loader, main_loss_function, additional_loss_functions):
    """Evaluate a residual denoiser on a validation set without gradients.

    The denoised image is ``imgs_noisy - model(imgs_noisy)``. Each loss is
    computed per batch and then averaged over batches weighted by batch size,
    so a smaller final batch does not get disproportionate weight. The model is
    evaluated in eval mode, and its previous train/eval mode is restored
    afterwards.

    Args:
        model: network to evaluate. Batches are moved to the device of its
            parameters.
        val_loader: iterable of ``(imgs_noisy, imgs_clean)`` tensor batches.
        main_loss_function: callable ``(denoised, clean) -> scalar tensor``.
        additional_loss_functions: dict ``name -> callable`` with the same
            signature.

    Returns:
        ``(main_loss_average, {name: average})``.

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    device = next(model.parameters()).device
    was_training = model.training
    # Inference behaviour for layers such as dropout or batch-norm, if the model has any.
    model.eval()

    main_total = 0.0
    additional_totals = {key: 0.0 for key in additional_loss_functions.keys()}
    no_of_examples = 0
    try:
        # disable gradient computation for validation
        with torch.no_grad():
            for imgs_noisy, imgs_clean in tqdm.tqdm(val_loader, desc="Validation"):
                imgs_noisy = imgs_noisy.to(device)
                imgs_clean = imgs_clean.to(device)
                batch_size = imgs_noisy.shape[0]

                imgs_denoised = imgs_noisy - model(imgs_noisy)

                main_total += main_loss_function(imgs_denoised, imgs_clean).item() * batch_size
                for key, loss_fn in additional_loss_functions.items():
                    additional_totals[key] += loss_fn(imgs_denoised, imgs_clean).item() * batch_size
                no_of_examples += batch_size
    finally:
        model.train(was_training)

    if no_of_examples == 0:
        raise ValueError("val_loader yielded no batches; cannot compute validation losses")

    loss_main_loss_function_averaged = main_total / no_of_examples
    losses_additional_functions_averaged = {
        key: total / no_of_examples for key, total in additional_totals.items()
    }
    return loss_main_loss_function_averaged, losses_additional_functions_averaged



def test_model(model, test_loader, loss_functions):
    losses = {
        key: [] for key in loss_functions.keys()
    }

    with torch.no_grad():
        for imgs_noisy, imgs_clean in tqdm.tqdm(test_loader, desc="Test"):
            imgs_noisy = imgs_noisy.to(device)
            imgs_clean = imgs_clean.to(device)

            out = model(imgs_noisy)  # forward pass
            imgs_denoised = imgs_noisy - out

            # generate all losses
            for key, loss_fn in loss_functions.items():
                losses[key].append(loss_fn(imgs_denoised, imgs_clean).item())
    
    # compute average losses
    for key, loss_list in losses.items():
        losses[key] = sum(loss_list) / len(loss_list)
    
    return losses

In [7]:
def print_training_curves(checkpoint_path):
    """Plot the validation histories stored in a ``train_model`` checkpoint.

    Draws one panel for the main loss and one per additional validation
    metric, with three panels per row, hides any unused panels, and shows the
    figure.

    Args:
        checkpoint_path: path to a checkpoint dict containing the keys
            ``history_main_validation_loss`` (list of floats) and
            ``history_additional_validation_losses`` (dict name -> list).
    """
    # map_location="cpu" lets a checkpoint saved on a GPU machine be plotted on a CPU-only one.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    history_main_validation_losses = checkpoint["history_main_validation_loss"]
    history_additional_validation_losses = checkpoint["history_additional_validation_losses"]

    # (title, y-axis label, values). Extra metrics such as PSNR/SSIM are not
    # losses, so each is labelled with its own name.
    panels = [("Main loss function", "Loss", history_main_validation_losses)]
    panels += [(key, key, values) for key, values in history_additional_validation_losses.items()]

    max_plots_per_row = 3
    no_of_rows = (len(panels) + max_plots_per_row - 1) // max_plots_per_row  # ceil division
    fig, axes = plt.subplots(no_of_rows, max_plots_per_row, figsize=(15, 5 * no_of_rows))
    axes = axes.flatten()

    for ax, (title, ylabel, values) in zip(axes, panels):
        ax.plot(values)
        ax.set_title(title)
        # train_model records one point per validation run (every check_val_every_epochs
        # epochs, plus a final run after training), so points are not necessarily epochs.
        ax.set_xlabel("Validation run")
        ax.set_ylabel(ylabel)
        ax.grid()

    # Hide the empty panels in the last row.
    for ax in axes[len(panels):]:
        ax.set_visible(False)

    plt.show()

# Model and Training Configuration

In [8]:
# more pooling layers and convolutional kernels increase the complexity of the U-Net (see lecture notes)
num_pool_layers = 4
chans = 64

# Seed torch's global RNG so that DataLoader shuffling and U-Net weight
# initialization are reproducible from run to run. The datasets are built in an
# earlier cell, so any randomness they draw at construction time (if they draw
# any) is not covered by this seed.
SEED = 0
torch.manual_seed(SEED)

# check if cuda is available
device = "cuda" if torch.cuda.is_available() else "cpu"

batch_size = 32 if device == "cuda" else 4  # depends on your hardware

train_loader_supervised = DataLoader(train_set_supervised, batch_size=batch_size, shuffle=True)
train_loader_sel_supervised = DataLoader(train_set_sel_supervised, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

# define loss function
mse = torch.nn.MSELoss()  # use the mean squared error loss function
# NOTE(review): without an explicit data_range, torchmetrics infers it from each
# target batch. If the images are known to lie in [0, 1], consider passing
# data_range=1.0 (e.g. via functools.partial) so scores are comparable across batches.
psnr = peak_signal_noise_ratio  # peak signal-to-noise ratio
ssim = structural_similarity_index_measure  # structural similarity index measure

learning_rate = 1e-3
epochs = 50
check_val_every_epochs = 1

## Train supervised model

In [9]:
# Supervised U-Net: one (grayscale) channel in and one (grayscale) channel out,
# moved straight to the selected device.
model_supervised = Unet(
    name="Unet_supervised",
    in_chans=1,
    out_chans=1,
    num_pool_layers=num_pool_layers,
    chans=chans,
).to(device)
optimizer_supervised = torch.optim.Adam(model_supervised.parameters(), lr=learning_rate)

In [10]:
# Supervised training run. Keyword arguments are listed in train_model's parameter order.
# NOTE(review): only this run passes limit_train_length; the self-supervised run below
# does not, so the two results are not trained on the same amount of data.
train_model(
    model=model_supervised,
    train_loader=train_loader_supervised,
    val_loader=val_loader,
    optimizer=optimizer_supervised,
    loss_fn_main=mse,
    extra_loss_functions_validation={"PSNR": psnr, "SSIM": ssim},
    max_epochs=epochs,
    check_val_every_epochs=check_val_every_epochs,
    limit_train_length=300,  # caps the number of training examples used
)

Starting training


Training:  15%|█▌        | 46/300 [03:17<18:12,  4.30s/it]


KeyboardInterrupt: 

In [None]:
# Plot the validation history recorded in the supervised model's checkpoint.
model_path = "./checkpoint_Unet_supervised.pth"
print_training_curves(checkpoint_path=model_path)

## Train self-supervised model

In [None]:
# Self-supervised U-Net with the same architecture as the supervised one:
# one (grayscale) channel in and one (grayscale) channel out.
model_self_supervised = Unet(
    name="Unet_self_supervised",
    in_chans=1,
    out_chans=1,
    num_pool_layers=num_pool_layers,
    chans=chans,
).to(device)
optimizer_self_supervised = torch.optim.Adam(model_self_supervised.parameters(), lr=learning_rate)

In [None]:
# Self-supervised training run. Validation still compares against clean images.
# Keyword arguments are listed in train_model's parameter order.
# NOTE(review): unlike the supervised run, no limit_train_length is passed here.
train_model(
    model=model_self_supervised,
    train_loader=train_loader_sel_supervised,
    val_loader=val_loader,
    optimizer=optimizer_self_supervised,
    loss_fn_main=mse,
    extra_loss_functions_validation={"PSNR": psnr, "SSIM": ssim},
    max_epochs=epochs,
    check_val_every_epochs=check_val_every_epochs,
)

In [None]:
# Plot the validation history of the SELF-supervised model. train_model names the
# file checkpoint_{model.name}.pth, and this model was created with name="Unet_self_supervised".
model_path = "./checkpoint_Unet_self_supervised.pth"
print_training_curves(model_path)