In [None]:
%load_ext autoreload
%autoreload 2
from srgan_utils import *

In [None]:
import os
import math
import numpy as np
from PIL import Image
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# Select the compute device once: the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


## TRAIN META-SRGAN

In [None]:
# Rolling checkpoint written by the training loop. If it already exists, restore
# models, optimizers, schedulers and loss history, and continue from the next epoch.
checkpoint_path = 'srgan_checkpoint.pth'
start_epoch = 0
# Loss history appended to once per epoch by the training loop.
# NOTE(review): the *_iter entries presumably hold one list of per-iteration
# losses per epoch (they are flattened that way when plotted) — confirm against train_srgan.
history = {
    'g_loss_epoch': [],
    'd_loss_epoch': [],
    'g_loss_iter': [],
    'd_loss_iter': []
}

if os.path.exists(checkpoint_path):
    # map_location remaps tensors saved on another device (e.g. a GPU session)
    # onto the current `device`, so resuming on a CPU-only kernel does not fail.
    ckpt = torch.load(checkpoint_path, map_location=device)
    netG.load_state_dict(ckpt['netG'])
    netD.load_state_dict(ckpt['netD'])
    optimizerG.load_state_dict(ckpt['optimizerG'])
    optimizerD.load_state_dict(ckpt['optimizerD'])
    schedulerG.load_state_dict(ckpt['schedulerG'])
    schedulerD.load_state_dict(ckpt['schedulerD'])
    history = ckpt['history']
    # The stored epoch is the last one completed, so resume at the one after it.
    start_epoch = ckpt['epoch'] + 1
    print(f"Resuming from epoch {start_epoch}")


In [None]:
# Total number of epochs for the run; a resumed run continues from start_epoch up to this.
num_epochs = 5
for epoch in range(start_epoch, num_epochs):
    # NOTE(review): judging by the names, train_srgan returns (mean G loss, mean D loss,
    # per-iteration G losses, per-iteration D losses) for one epoch — confirm in srgan_utils.
    avg_g, avg_d, iter_g, iter_d = train_srgan(epoch)

    # store losses
    history['g_loss_epoch'].append(avg_g)
    history['d_loss_epoch'].append(avg_d)
    history['g_loss_iter'].append(iter_g)
    history['d_loss_iter'].append(iter_d)

    # step schedulers (once per epoch)
    schedulerG.step()
    schedulerD.step()

    # Save the checkpoint to a temporary file first and then atomically swap it in.
    # If the kernel dies mid-write, the previous checkpoint stays intact instead of
    # being replaced by a truncated file that the resume cell could not load.
    tmp_checkpoint_path = checkpoint_path + '.tmp'
    torch.save({
        'epoch':       epoch,
        'netG':        netG.state_dict(),
        'netD':        netD.state_dict(),
        'optimizerG':  optimizerG.state_dict(),
        'optimizerD':  optimizerD.state_dict(),
        'schedulerG':  schedulerG.state_dict(),
        'schedulerD':  schedulerD.state_dict(),
        'history':     history
    }, tmp_checkpoint_path)
    os.replace(tmp_checkpoint_path, checkpoint_path)


In [None]:
# Once training is done (or the final checkpoint has been loaded), export just the
# generator's weights, without optimizer/scheduler state, to a standalone file.
generator_weights = netG.state_dict()
torch.save(generator_weights, 'generator.pth')

In [None]:
# Load a checkpoint from an earlier run so its loss history can be inspected.
# map_location='cpu' lets a checkpoint saved from a GPU session load on any kernel,
# matching the other checkpoint loads in this notebook.
ckpt = torch.load('/kaggle/input/metasrgan-1/srgan_checkpoint_1.pth', map_location='cpu')

In [None]:
# Replace the in-memory loss history with the one stored in `ckpt`
# (the checkpoint loaded in the previous cell).
history = ckpt['history']


In [None]:
import matplotlib.pyplot as plt

# Number of initial iterations left out of this plot. Early losses can be large
# and compress the y-axis. Set to 0 to plot every iteration.
SKIP_ITERS = 10

try:
    history
except NameError:
    # Fresh kernel: fall back to the training loss history saved in the checkpoint.
    import torch
    ckpt = torch.load('srgan_checkpoint.pth', map_location='cpu')
    history = ckpt['history']

# Flatten iteration losses across all epochs
gen_iter = [loss for epoch_losses in history['g_loss_iter'] for loss in epoch_losses]
disc_iter = [loss for epoch_losses in history['d_loss_iter'] for loss in epoch_losses]

# Give each retained loss its true 1-based iteration number. range(SKIP_ITERS + 1, len + 1)
# has exactly as many entries as series[SKIP_ITERS:], so x and y always have equal
# length, even when a series is shorter than SKIP_ITERS.
gen_x = range(SKIP_ITERS + 1, len(gen_iter) + 1)
disc_x = range(SKIP_ITERS + 1, len(disc_iter) + 1)

fig, ax = plt.subplots()
ax.plot(gen_x, gen_iter[SKIP_ITERS:], label='Generator Loss')
ax.plot(disc_x, disc_iter[SKIP_ITERS:], label='Discriminator Loss')
ax.set_xlabel('Iteration')
ax.set_ylabel('Loss')
ax.set_title('Generator & Discriminator Loss per Iteration')
ax.legend()
fig.tight_layout()
plt.show()


In [None]:
# Reload the loss history from the on-disk checkpoint and plot every iteration.
ckpt = torch.load('srgan_checkpoint.pth', map_location='cpu')
history = ckpt['history']

# Each *_iter entry holds one sequence of losses per epoch. Concatenate them into
# single sequences that span the whole run.
gen_iter = []
for epoch_losses in history['g_loss_iter']:
    gen_iter.extend(epoch_losses)
disc_iter = []
for epoch_losses in history['d_loss_iter']:
    disc_iter.extend(epoch_losses)

# 1-based iteration numbers for the x-axis.
iterations = range(1, len(gen_iter) + 1)

fig, ax = plt.subplots()
for series, series_label in ((gen_iter, 'Generator Loss'), (disc_iter, 'Discriminator Loss')):
    ax.plot(iterations, series, label=series_label)
ax.set_xlabel('Iteration')
ax.set_ylabel('Loss')
ax.set_title('Generator & Discriminator Loss per Iteration')
ax.legend()
fig.tight_layout()
plt.show()