In [None]:
import os
from os import path
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
import torch
from generator import *
from discriminator import *
from feature_extractor import *
from dataset import *

from torch.utils.data import DataLoader
from torch.autograd import Variable
import sys
from torchvision.utils import save_image, make_grid
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary
import itertools

# Show the current working directory: the relative paths used in later cells
# ('data', 'images', 'saved_models', 'runs/model_vis') are resolved against it.
print(os.getcwd())

In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Load one image pair by its number and preview a single slice of the
# low-resolution volume.
num = 100206
img = ImagePair(number=num, root_dir='data')
# Named `lr_slice` rather than `slice` so the Python builtin `slice` is not
# shadowed for the rest of the kernel session.
lr_slice = img.img()['LR'][:, :, 25]  # index 25 along the last axis
plt.imshow(lr_slice, cmap='gray')

In [None]:
# Preprocessing pipeline shared by both splits. `ToTensor` here is the bare
# name brought in by one of the star imports, not `transforms.ToTensor`.
transform = transforms.Compose([ToTensor()])

# Training split first, then validation split, both using the same transform.
tra_set, val_set = (
    ImagePairDataset(split, transform=transform)
    for split in ('training', 'validation')
)

In [None]:
# Report how many samples each split contains.
print(
    'Length of training set: \t{}\nLength of validation set: \t{}'.format(
        len(tra_set), len(val_set)
    )
)

# Pick one training sample and show its LR and HR images side by side.
num = 25
sample = tra_set[num]
title = 'Image pair {}'.format(sample['id'])

fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
fig.set_facecolor('white')
fig.suptitle(title)

# Left panel: low-resolution input (singleton dimensions squeezed away).
lr_image = np.squeeze(sample['LR'])
ax1.imshow(lr_image, cmap='gray')
ax1.set_title('LR')
ax1.axis('off')

# Right panel: high-resolution target.
hr_image = np.squeeze(sample['HR'])
ax2.imshow(hr_image, cmap='gray')
ax2.set_title('HR')
ax2.axis('off')

In [None]:
# Loader configuration: samples per batch and number of worker processes.
batch_size = 4
n_cpu = 2

# Both loaders use exactly the same settings (shuffled, `n_cpu` workers).
loader_options = dict(batch_size=batch_size, shuffle=True, num_workers=n_cpu)

tra_dataloader = DataLoader(tra_set, **loader_options)
val_dataloader = DataLoader(val_set, **loader_options)



In [None]:
# Build the generator on the GPU when one is available and on the CPU
# otherwise, instead of calling `.cuda()` unconditionally (which raises on
# CPU-only machines and disagrees with the device fallback used further down).
generator = GeneratorRRDB(channels=1, filters=64, num_res_blocks=1).to(
    "cuda" if torch.cuda.is_available() else "cpu"
)
# Print a per-layer summary for a single-channel 224x224 input.
summary(generator, (1, 224, 224))

In [8]:
# Build the discriminator on the GPU when one is available and on the CPU
# otherwise, instead of calling `.cuda()` unconditionally (which raises on
# CPU-only machines and disagrees with the device fallback used further down).
discriminator = Discriminator(input_shape=(1, 224, 224)).to(
    "cuda" if torch.cuda.is_available() else "cpu"
)
# Print a per-layer summary for a single-channel 224x224 input.
summary(discriminator, (1, 224, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 224, 224]             640
         LeakyReLU-2         [-1, 64, 224, 224]               0
            Conv2d-3         [-1, 64, 112, 112]          36,928
       BatchNorm2d-4         [-1, 64, 112, 112]             128
         LeakyReLU-5         [-1, 64, 112, 112]               0
            Conv2d-6        [-1, 128, 112, 112]          73,856
       BatchNorm2d-7        [-1, 128, 112, 112]             256
         LeakyReLU-8        [-1, 128, 112, 112]               0
            Conv2d-9          [-1, 128, 56, 56]         147,584
      BatchNorm2d-10          [-1, 128, 56, 56]             256
        LeakyReLU-11          [-1, 128, 56, 56]               0
           Conv2d-12          [-1, 256, 56, 56]         295,168
      BatchNorm2d-13          [-1, 256, 56, 56]             512
        LeakyReLU-14          [-1, 256,

In [10]:
# Export both network graphs to TensorBoard under runs/model_vis.
writer = SummaryWriter('runs/model_vis')
# Move the example input to whatever device each model currently lives on
# rather than hard-coding `.cuda()`, so tracing works on CPU-only machines and
# never mixes devices between the model and its input.
# NOTE(review): `sample['LR']` is passed exactly as the dataset returns it;
# confirm it already has the batch dimension the networks expect.
writer.add_graph(generator, sample['LR'].to(next(generator.parameters()).device))
writer.add_graph(discriminator, sample['LR'].to(next(discriminator.parameters()).device))

In [9]:
# Choose the compute device once: GPU when available, CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Create fresh network instances and place them on the chosen device.
generator = GeneratorRRDB(
    channels=1,
    filters=64,
    num_res_blocks=1,
).to(device)
discriminator = Discriminator(
    input_shape=(1, 224, 224),
).to(device)
feature_extractor = FeatureExtractor().to(device)

# The feature extractor is only used for inference, so switch it to eval mode.
feature_extractor.eval()

In [10]:
# Make sure the output folders for sample images and model checkpoints exist;
# already-existing folders are left untouched.
for output_dir in ("images", "saved_models"):
    os.makedirs(output_dir, exist_ok=True)

In [26]:
class Trainer:
    """Adversarial training loop for a super-resolution generator/discriminator pair.

    Every training step first updates the generator on
    ``content + lambda_adv * adversarial + lambda_pixel * pixel`` and then
    updates the discriminator on the mean of its real/fake relativistic average
    losses. Per-step losses are appended to ``self.metric``.

    Args:
        generator: Network mapping a low-resolution batch to a high-resolution one.
        discriminator: Network scoring images; must expose ``output_shape``
            (the per-sample shape of its output), which sizes the targets.
        training_loader: Iterable of batches, each a mapping with ``'LR'`` and
            ``'HR'`` tensors.
        validation_loader: Iterable of batches with the same layout.
        feature_extractor: Network used for the content loss. It is fed the
            images repeated three times along dimension 1.
        lr: Learning rate of both Adam optimisers.
        b1: First Adam beta.
        b2: Second Adam beta.
        epochs: Default number of epochs used by ``fit`` when none is given.
        warmup_batches: Accepted for API compatibility; currently unused
            (pixel-only warm-up is not implemented).
        lambda_adv: Weight of the adversarial term in the generator loss.
        lambda_pixel: Weight of the pixel-wise term in the generator loss.
        sample_interval: Save a comparison image every this many batches.
        checkpoint_interval: Accepted for API compatibility; currently unused
            (checkpoints are written once per epoch).
    """

    def __init__(self,
                 generator,
                 discriminator,
                 training_loader,
                 validation_loader,
                 feature_extractor,
                 lr = 0.0002,
                 b1 = 0.9,
                 b2 = 0.999,
                 epochs = 10,
                 warmup_batches = 100,
                 lambda_adv = 5e-3,
                 lambda_pixel = 1e-2,
                 sample_interval = 100,
                 checkpoint_interval = 1000,
                 ):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Kept only for backward compatibility with code that reads `trainer.Tensor`;
        # the trainer itself moves data with `.to(self.device)`.
        self.Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor
        self.netG = generator.to(self.device)
        self.netD = discriminator.to(self.device)
        self.netF = feature_extractor.to(self.device)

        self.criterion_GAN = torch.nn.BCEWithLogitsLoss().to(self.device)
        self.criterion_content = torch.nn.L1Loss().to(self.device)
        self.criterion_pixel = torch.nn.L1Loss().to(self.device)

        # Plain floats: a trailing comma here would silently create 1-tuples and
        # break the weighted loss sum.
        self.lambda_adv = lambda_adv
        self.lambda_pixel = lambda_pixel
        self.sample_interval = sample_interval
        self.epochs = epochs

        self.training_loader = training_loader
        self.validation_loader = validation_loader
        # Read the batch size from the loader itself instead of a notebook global.
        self.batch_size = getattr(training_loader, 'batch_size', None)
        self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=lr, betas=(b1, b2))
        self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=lr, betas=(b1, b2))
        self.metric = {
            'train_loss_G': [],
            'train_loss_content_G': [],
            'train_loss_adversarial_G': [],
            'train_loss_pixel_G': [],
            'train_loss_D': [],
            'val_loss_G': [],
            'val_loss_content_G': [],
            'val_loss_adversarial_G': [],
            'val_loss_pixel_G': [],
            'val_loss_D': [],
        }

    def _prepare_batch(self, batch):
        """Move a batch to the trainer's device as float32 and build the adversarial targets."""
        imgs_lr = batch['LR'].to(self.device, dtype=torch.float32)
        imgs_hr = batch['HR'].to(self.device, dtype=torch.float32)
        target_shape = (imgs_lr.size(0), *self.netD.output_shape)
        valid = torch.ones(target_shape, device=self.device)
        fake = torch.zeros(target_shape, device=self.device)
        return imgs_lr, imgs_hr, valid, fake

    def _generator_losses(self, gen_hr, imgs_hr, valid):
        """Return ``(total, content, adversarial, pixel)`` generator losses."""
        # Pixel-wise loss against the ground truth.
        loss_pixel = self.criterion_pixel(gen_hr, imgs_hr)

        # The real-image reference values never receive gradients, so compute
        # them without recording an autograd graph.
        with torch.no_grad():
            pred_real = self.netD(imgs_hr)
        pred_fake = self.netD(gen_hr)

        # Adversarial loss (relativistic average GAN).
        loss_GAN = self.criterion_GAN(pred_fake - pred_real.mean(0, keepdim=True), valid)

        # Content loss; images are repeated three times along dimension 1
        # before being passed to the feature extractor.
        gen_features = self.netF(torch.repeat_interleave(gen_hr, 3, 1))
        with torch.no_grad():
            real_features = self.netF(torch.repeat_interleave(imgs_hr, 3, 1))
        loss_content = self.criterion_content(gen_features, real_features)

        loss_G = loss_content + self.lambda_adv * loss_GAN + self.lambda_pixel * loss_pixel
        return loss_G, loss_content, loss_GAN, loss_pixel

    def _discriminator_loss(self, gen_hr, imgs_hr, valid, fake):
        """Return the discriminator loss (mean of the real and fake relativistic terms)."""
        pred_real = self.netD(imgs_hr)
        # Detach so the discriminator update never back-propagates into the generator.
        pred_fake = self.netD(gen_hr.detach())

        loss_real = self.criterion_GAN(pred_real - pred_fake.mean(0, keepdim=True), valid)
        loss_fake = self.criterion_GAN(pred_fake - pred_real.mean(0, keepdim=True), fake)
        return (loss_real + loss_fake) / 2

    def _record(self, prefix, loss_G, loss_content, loss_GAN, loss_pixel, loss_D):
        """Append the scalar losses of one step to ``self.metric`` under ``prefix``."""
        self.metric[prefix + '_loss_G'].append(loss_G.item())
        self.metric[prefix + '_loss_content_G'].append(loss_content.item())
        self.metric[prefix + '_loss_adversarial_G'].append(loss_GAN.item())
        self.metric[prefix + '_loss_pixel_G'].append(loss_pixel.item())
        self.metric[prefix + '_loss_D'].append(loss_D.item())

    def _save_sample(self, batch, gen_hr, batches_done):
        """Write an image grid ``LR | HR | generated | 2*|LR - generated|`` to ``images/``."""
        imgs_lr = batch['LR'].to(gen_hr.device, dtype=gen_hr.dtype)
        imgs_hr = batch['HR'].to(gen_hr.device, dtype=gen_hr.dtype)
        # NOTE(review): the difference panel compares against LR, as in the
        # original code; confirm HR was not intended.
        difference = torch.abs(imgs_lr - gen_hr) * 2
        img_grid = torch.cat((imgs_lr, imgs_hr, gen_hr, difference), -1)
        out_path = os.path.join(os.getcwd(), 'images', '%d.png' % batches_done)
        save_image(img_grid, out_path, nrow=1, normalize=False)

    def training(self, train_batch):
        """Run one optimisation step of generator and discriminator on ``train_batch``.

        Args:
            train_batch: Mapping with ``'LR'`` and ``'HR'`` tensors.

        Returns:
            The generated high-resolution batch, detached from the autograd
            graph so that holding on to it does not keep the graph alive.
        """
        self.netG.train()
        self.netD.train()
        imgs_lr, imgs_hr, valid, fake = self._prepare_batch(train_batch)

        # ------------------
        #  Train Generators
        # ------------------
        self.optimizer_G.zero_grad()
        gen_hr = self.netG(imgs_lr)
        loss_G, loss_content, loss_GAN, loss_pixel = self._generator_losses(gen_hr, imgs_hr, valid)
        loss_G.backward()
        self.optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------
        self.optimizer_D.zero_grad()
        loss_D = self._discriminator_loss(gen_hr, imgs_hr, valid, fake)
        loss_D.backward()
        self.optimizer_D.step()

        self._record('train', loss_G, loss_content, loss_GAN, loss_pixel, loss_D)
        return gen_hr.detach()

    @torch.no_grad()
    def validate(self, validation_batch):
        """Evaluate both networks on ``validation_batch`` without updating them.

        Args:
            validation_batch: Mapping with ``'LR'`` and ``'HR'`` tensors.

        Returns:
            The generated high-resolution batch.
        """
        self.netG.eval()
        self.netD.eval()
        imgs_lr, imgs_hr, valid, fake = self._prepare_batch(validation_batch)

        gen_hr = self.netG(imgs_lr)
        loss_G, loss_content, loss_GAN, loss_pixel = self._generator_losses(gen_hr, imgs_hr, valid)
        loss_D = self._discriminator_loss(gen_hr, imgs_hr, valid, fake)

        self._record('val', loss_G, loss_content, loss_GAN, loss_pixel, loss_D)
        return gen_hr

    def fit(self, epochs=None):
        """Train for ``epochs`` epochs, validating alongside, and checkpoint each epoch.

        Training and validation batches are paired with ``itertools.zip_longest``;
        once the shorter loader is exhausted its side of the pair is ``None``
        and that step is skipped for it.

        Args:
            epochs: Number of epochs; defaults to the value given to the constructor.

        Returns:
            The ``self.metric`` dictionary of recorded losses.
        """
        if epochs is None:
            epochs = self.epochs
        training_loader = self.training_loader
        validation_loader = self.validation_loader

        # Make sure the output folders exist before anything is written to them.
        os.makedirs(os.path.join(os.getcwd(), 'images'), exist_ok=True)
        os.makedirs("saved_models", exist_ok=True)

        sys.stdout.flush()
        for epoch in range(epochs):
            print('Epoch %d'%(epoch+1))
            # NOTE(review): `tqdm` is not imported explicitly in this notebook;
            # presumably it arrives through one of the star imports - confirm.
            with tqdm(desc=('Training'), total=len(training_loader)) as pbar:
                batch_pairs = itertools.zip_longest(training_loader, validation_loader)
                for i, (training_batch, validation_batch) in enumerate(batch_pairs):
                    batches_done = epoch * len(training_loader) + i

                    if training_batch is not None:
                        self.training(training_batch)

                    gen_hr_val = None
                    if validation_batch is not None:
                        gen_hr_val = self.validate(validation_batch)

                    if gen_hr_val is not None and batches_done % self.sample_interval == 0:
                        self._save_sample(validation_batch, gen_hr_val, batches_done)

                    # Only report metrics that have at least one recorded value.
                    it_metrics = {
                        key: self.metric[key][-1]
                        for key in ("train_loss_G", "train_loss_D")
                        if self.metric[key]
                    }
                    pbar.set_postfix(**it_metrics)
                    pbar.update()
            # Checkpoint the networks owned by this trainer, not notebook globals.
            torch.save(self.netG.state_dict(), "saved_models/generator_%d.pth" % epoch)
            torch.save(self.netD.state_dict(), "saved_models/discriminator_%d.pth" %epoch)
        sys.stdout.flush()
        return self.metric

In [27]:
# Wire the networks and data loaders into a Trainer; every other
# hyper-parameter keeps its default value.
trainer = Trainer(
    generator=generator,
    discriminator=discriminator,
    training_loader=tra_dataloader,
    validation_loader=val_dataloader,
    feature_extractor=feature_extractor,
)

# Train for two epochs and keep the recorded loss history.
metrics = trainer.fit(epochs=2)

Epoch 1


Training:   0%|                                        | 0/1750 [00:00<?, ?it/s]


RuntimeError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 10.76 GiB total capacity; 9.40 GiB already allocated; 21.69 MiB free; 9.44 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:

# Close the SummaryWriter that was opened for 'runs/model_vis'.
writer.close()