In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.models import vgg16
from PIL import Image
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim
from lpips import LPIPS

def calculate_psnr(img1, img2, max_val=1.0):
    """Return the peak signal-to-noise ratio between two images, in dB.

    The MSE is taken over every element of ``img1`` and ``img2`` (for a batch
    this is the PSNR of the pooled batch, not an average of per-image PSNRs).

    Args:
        img1: Image tensor.
        img2: Tensor with the same shape as ``img1``.
        max_val: Peak possible pixel value. The default 1.0 corresponds to
            images scaled to [0, 1].

    Returns:
        A 0-dim tensor on the inputs' device. Identical inputs yield ``inf``,
        also as a tensor, so callers can uniformly call ``.item()`` on it.
    """
    mse = F.mse_loss(img1, img2)
    if mse == 0:
        # Keep the return type consistent (tensor, same dtype/device as mse).
        return torch.full_like(mse, float('inf'))
    # 20*log10(MAX/sqrt(MSE)) == 10*log10(MAX**2/MSE); the latter skips the sqrt.
    return 10 * torch.log10(max_val ** 2 / mse)

def _batch_ssim(output_image, ref_image):
    """Return the mean SSIM over the samples of an (N, C, H, W) batch.

    Assumes pixel values in [0, 1], matching the 1.0 peak used for PSNR
    (TODO confirm against the dataloader's normalization).
    """
    # Channels-last numpy arrays, one per sample, as scikit-image expects.
    out_np = output_image.detach().permute(0, 2, 3, 1).cpu().numpy()
    ref_np = ref_image.detach().permute(0, 2, 3, 1).cpu().numpy()
    scores = [
        ssim(out_hwc, ref_hwc, channel_axis=-1, data_range=1.0)
        for out_hwc, ref_hwc in zip(out_np, ref_np)
    ]
    return float(sum(scores) / len(scores))


def _batch_metrics(output_image, ref_image, lpips_fn):
    """Return (PSNR, mean SSIM, mean LPIPS) for one batch as Python floats."""
    # Metrics are for reporting only: detach so no autograd graph is built and
    # so the numpy conversion does not fail on tensors that require grad.
    output_image = output_image.detach()
    ref_image = ref_image.detach()
    with torch.no_grad():
        # float() accepts both a 0-dim tensor and a plain float('inf').
        psnr_value = float(calculate_psnr(output_image, ref_image))
        # LPIPS expects [-1, 1] by default; normalize=True rescales [0, 1] inputs.
        # It returns one score per sample, so average over the batch.
        lpips_value = lpips_fn(output_image, ref_image, normalize=True).mean().item()
    ssim_value = _batch_ssim(output_image, ref_image)
    return psnr_value, ssim_value, lpips_value


def train_model(model, loss_fn, optimizer, dataloader, num_epochs, device):
    """Train ``model`` and print per-epoch average loss, PSNR, SSIM and LPIPS.

    Each batch from ``dataloader`` is unpacked as
    ``(input_image1, input_image2, ref_image)``. The model is called as
    ``model(input_image1, input_image2)`` and ``loss_fn`` receives
    ``(output_image, ref_image, input_image1, input_image2)``. Metrics are
    measured on each batch's output before the optimizer step and averaged
    over the batches of the epoch.

    Args:
        model: Network taking two input image tensors.
        loss_fn: Callable returning a scalar loss tensor.
        optimizer: Optimizer over ``model``'s parameters.
        dataloader: Iterable of 3-tuples of tensors.
        num_epochs: Number of passes over ``dataloader``.
        device: Device that the model, batches and LPIPS network are moved to.

    Raises:
        ValueError: If ``dataloader`` yields no batches in an epoch.
    """
    lpips_fn = LPIPS(net='alex').to(device)
    lpips_fn.eval()  # used purely as a metric; never trained here
    model.to(device)
    for epoch in range(num_epochs):
        model.train()
        total_loss = total_psnr = total_ssim = total_lpips = 0.0
        # Count batches ourselves: len() is not defined for every iterable loader.
        num_batches = 0
        for batch in dataloader:
            input_image1, input_image2, ref_image = [item.to(device) for item in batch]
            output_image = model(input_image1, input_image2)
            loss = loss_fn(output_image, ref_image, input_image1, input_image2)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            psnr_value, ssim_value, lpips_value = _batch_metrics(output_image, ref_image, lpips_fn)
            total_loss += loss.item()
            total_psnr += psnr_value
            total_ssim += ssim_value
            total_lpips += lpips_value
            num_batches += 1

        if num_batches == 0:
            raise ValueError('dataloader yielded no batches')

        avg_loss = total_loss / num_batches
        avg_psnr = total_psnr / num_batches
        avg_ssim = total_ssim / num_batches
        avg_lpips = total_lpips / num_batches

        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}, PSNR: {avg_psnr:.4f}, SSIM: {avg_ssim:.4f}, LPIPS: {avg_lpips:.4f}')


