In [1]:
import numpy as np
import pandas as pd
import os
import random
import time

import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset

from Modelos.generator import Generator
from Modelos.discriminator import Discriminator
from Modelos.vgg19 import FeatureExtractor
from Utils.metrics import rmse_metric

import plotly
import plotly.express as px
import plotly.graph_objects as go
import matplotlib.pyplot as plt

from PIL import Image

  warn(f"Failed to load image Python extension: {e}")


In [2]:
# number of epochs of training
n_epochs = 200

# name of the dataset
dataset_path = "data/"

# adam: learning rate
lr = 1e-4

# epoch from which to start lr decay
decay_epoch = 100

# number of cpu threads to use during batch generation
n_cpu = 8

# high res. image height
hr_height = 1024

# high res. image width
hr_width = 1024

# number of image channels
channels = 1

checkpoint_interval = 500
sample_interval = 100

hr_shape = (hr_height, hr_width)

os.makedirs("images/validation", exist_ok=True)
os.makedirs("saved_models", exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
class ImageDataset(Dataset):
    def __init__(self, files, hr_shape):

        self.files = [os.path.join(files, f) for f in sorted(os.listdir(files))]

        hr_height, hr_width = hr_shape

        # Transforms for low resolution images and high resolution images
        self.lr_transform = transforms.Compose(
            [
                transforms.Resize((hr_height // 4, hr_width // 4)),
                transforms.Grayscale(num_output_channels=1),
                transforms.ToTensor(),
            ]
        )
        self.hr_transform = transforms.Compose(
            [
                transforms.Resize((hr_height, hr_width)),
                transforms.Grayscale(num_output_channels=1),
                transforms.ToTensor(),
            ]
        )

    def __getitem__(self, index):
        img = Image.open(self.files[index % len(self.files)])

        img_lr = self.lr_transform(img)
        img_hr = self.hr_transform(img)

        return {"lr": img_lr, "hr": img_hr}

    def __len__(self):
        return len(self.files)

In [4]:
train_paths = dataset_path + 'Chest_X-Ray_train_HR/'
valid_paths = dataset_path + 'Chest_X-Ray_valid_HR/'
train_dataset = ImageDataset(train_paths, hr_shape=hr_shape)
valid_dataset = ImageDataset(valid_paths, hr_shape=hr_shape)
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=n_cpu)
valid_dataloader = DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=n_cpu)

In [5]:
# Initialize generator and discriminator
generator = Generator().to(device)
discriminator = Discriminator(input_shape=(channels, *hr_shape)).to(device)
feature_extractor = FeatureExtractor().to(device)

# Set feature extractor to inference mode
feature_extractor.eval()

# Losses
criterion_GAN = torch.nn.BCEWithLogitsLoss().to(device)
criterion_content = torch.nn.MSELoss().to(device)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr)

Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor



In [6]:
def train_step(lr, hr):

    optimizer_G.zero_grad()

    # Generate a high resolution image from low resolution input
    gen_hr = generator(lr)

    gen_features = feature_extractor(gen_hr)
    real_features = feature_extractor(hr).detach()
    mse = criterion_content(real_features, gen_features)
    
    gen_loss = criterion_GAN(gen_hr, torch.ones_like(gen_hr)) # Generator Loss
    
    perc_loss = mse + 0.001 * gen_loss # Total generator loss

    perc_loss.backward()
    optimizer_G.step()

    optimizer_D.zero_grad()

    pred_real = discriminator(hr)
    pred_fake = discriminator(gen_hr.detach())

    loss_real = criterion_GAN(pred_real, torch.ones_like(pred_real))
    loss_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))
    disc_loss = loss_real + loss_fake

    disc_loss.backward()
    optimizer_D.step()

    return perc_loss, loss_real, loss_fake, disc_loss

In [7]:
gen_loss_array = np.array([])
disc_loss_array = np.array([])
rmse_error_array = np.array([])
time_array = np.array([])
hr_loss_array = np.array([])
sr_loss_array = np.array([])

start_time = time.time()

step = 0

for epoch in range(n_epochs):

    # ----------
    #  Training
    # ----------
    gen_loss, disc_loss = 0, 0
    
    for i, imgs in enumerate(train_dataloader):
        
        # Configure model input
        imgs_lr = Variable(imgs["lr"].type(Tensor))
        imgs_hr = Variable(imgs["hr"].type(Tensor))
        
        perceptual_loss, hr_loss, sr_loss, discriminator_loss = train_step(imgs_lr, imgs_hr)

        step = step + 1
        
        if step % sample_interval == 0:
            # Evaluate on validation set
            generator.eval()  # Set generator to evaluation mode
            val_gen_loss = 0
            with torch.no_grad():
                for j, imgs_val in enumerate(valid_dataloader):
                    filename = valid_dataset.files[j % len(valid_dataset.files)]
                    if filename == 'data/Chest_X-Ray_valid_HR/10.jpeg':
                        val_lr = Variable(imgs_val["lr"].type(Tensor))
                        val_hr = Variable(imgs_val["hr"].type(Tensor))
                        break

                # Generate a high resolution image from low resolution input
                gen_hr = generator(val_lr)

                gen_hr = gen_hr.squeeze()
                gen_hr = (gen_hr - gen_hr.min()) / (gen_hr.max() - gen_hr.min())    # Escala de 0 a 1
                gen_hr = (gen_hr * 255).clamp(0, 255).to(torch.uint8)               # Escala de 0 a 255

                imagen_gen_hr = Image.fromarray(gen_hr.cpu().numpy(), mode='L')
                imagen_gen_hr.save("images/validation/%d.jpeg" % step)

                if step == 100:

                    val_lr = val_lr.squeeze()
                    val_lr = (val_lr - val_lr.min()) / (val_lr.max() - val_lr.min())    # Escala de 0 a 1
                    val_lr = (val_lr * 255).clamp(0, 255).to(torch.uint8)               # Escala de 0 a 255        

                    val_hr = val_hr.squeeze()
                    val_hr = (val_hr - val_hr.min()) / (val_hr.max() - val_hr.min())    # Escala de 0 a 1
                    val_hr = (val_hr * 255).clamp(0, 255).to(torch.uint8)               # Escala de 0 a 255
                    
                    imagen_val_lr = Image.fromarray(val_lr.cpu().numpy(), mode='L')
                    imagen_val_lr.save("images/validation/low_res_image.jpeg")
                    
                    imagen_val_hr = Image.fromarray(val_hr.cpu().numpy(), mode='L')
                    imagen_val_hr.save("images/validation/high_res_image.jpeg")
                
            rmse = rmse_metric(imagen_gen_hr, imagen_val_hr)
            
            train_time = time.time()
            
            rmse_error_array = np.append(rmse_error_array, rmse)
            disc_loss_array = np.append(disc_loss_array, discriminator_loss.cpu().detach().numpy())
            hr_loss_array = np.append(hr_loss_array, hr_loss.cpu().detach().numpy())
            sr_loss_array = np.append(sr_loss_array, sr_loss.cpu().detach().numpy())
            gen_loss_array = np.append(gen_loss_array, np.multiply(perceptual_loss.cpu().detach().numpy(), 100))
            current_time = train_time - start_time
            time_array = np.append(time_array, current_time)
            
            generator.train()  # Set generator back to training mode
            
            print(
                "[Step %d/%d] [D loss: %f] [G loss: %f] [RMSE: %f] (%f s)"
                % (
                    step,
                    n_epochs*100,
                    discriminator_loss.item(),
                    np.multiply(perceptual_loss.item(), 100),
                    rmse,
                    current_time
                )
            )

[Step 100/20000] [D loss: 1.053404] [G loss: 54.817975] [RMSE: 0.245855] (43.602375 s)
[Step 200/20000] [D loss: 1.033267] [G loss: 44.147399] [RMSE: 0.281122] (86.901388 s)
[Step 300/20000] [D loss: 1.008616] [G loss: 11.982422] [RMSE: 0.248959] (129.971481 s)
[Step 400/20000] [D loss: 1.007506] [G loss: 17.024611] [RMSE: 0.217545] (173.474516 s)
[Step 500/20000] [D loss: 1.006924] [G loss: 26.410806] [RMSE: 0.216819] (216.556970 s)
[Step 600/20000] [D loss: 1.006771] [G loss: 11.808009] [RMSE: 0.222058] (259.702239 s)
[Step 700/20000] [D loss: 1.006709] [G loss: 18.360469] [RMSE: 0.228524] (302.785100 s)
[Step 800/20000] [D loss: 1.006757] [G loss: 19.036126] [RMSE: 0.263577] (345.817084 s)
[Step 900/20000] [D loss: 1.006646] [G loss: 16.744851] [RMSE: 0.213912] (388.716626 s)
[Step 1000/20000] [D loss: 1.006592] [G loss: 12.992120] [RMSE: 0.226862] (431.515910 s)
[Step 1100/20000] [D loss: 1.006624] [G loss: 13.042663] [RMSE: 0.245001] (474.293986 s)
[Step 1200/20000] [D loss: 1.006

Traceback (most recent call last):
  File "/home/pablodonav/miniconda3/envs/torch/lib/python3.9/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/home/pablodonav/miniconda3/envs/torch/lib/python3.9/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/home/pablodonav/miniconda3/envs/torch/lib/python3.9/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/home/pablodonav/miniconda3/envs/torch/lib/python3.9/shutil.py", line 740, in rmtree
    onerror(os.rmdir, path, sys.exc_info())
  File "/home/pablodonav/miniconda3/envs/torch/lib/python3.9/shutil.py", line 738, in rmtree
    os.rmdir(path)
OSError: [Errno 39] Directory not empty: '/tmp/pymp-_ugw1k06'


[Step 5800/20000] [D loss: 1.006901] [G loss: 8.085157] [RMSE: 0.036997] (2355.169229 s)
[Step 5900/20000] [D loss: 1.017564] [G loss: 5.564396] [RMSE: 0.069613] (2395.007697 s)
[Step 6000/20000] [D loss: 1.119776] [G loss: 25.295499] [RMSE: 0.121206] (2434.809464 s)
[Step 6100/20000] [D loss: 1.010320] [G loss: 5.626583] [RMSE: 0.088034] (2474.605409 s)
[Step 6200/20000] [D loss: 1.027138] [G loss: 9.991968] [RMSE: 0.079260] (2514.411633 s)
[Step 6300/20000] [D loss: 1.007472] [G loss: 19.840378] [RMSE: 0.070109] (2554.221030 s)
[Step 6400/20000] [D loss: 1.009128] [G loss: 20.214723] [RMSE: 0.119384] (2594.109306 s)
[Step 6500/20000] [D loss: 1.014060] [G loss: 7.883669] [RMSE: 0.085722] (2633.917382 s)
[Step 6600/20000] [D loss: 1.006656] [G loss: 5.887897] [RMSE: 0.090308] (2673.727709 s)
[Step 6700/20000] [D loss: 1.037934] [G loss: 6.349365] [RMSE: 0.042009] (2713.547284 s)
[Step 6800/20000] [D loss: 1.008094] [G loss: 9.844044] [RMSE: 0.086153] (2753.353794 s)
[Step 6900/20000] 

Traceback (most recent call last):
  File "/home/pablodonav/miniconda3/envs/torch/lib/python3.9/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/home/pablodonav/miniconda3/envs/torch/lib/python3.9/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/home/pablodonav/miniconda3/envs/torch/lib/python3.9/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/home/pablodonav/miniconda3/envs/torch/lib/python3.9/shutil.py", line 740, in rmtree
    onerror(os.rmdir, path, sys.exc_info())
  File "/home/pablodonav/miniconda3/envs/torch/lib/python3.9/shutil.py", line 738, in rmtree
    os.rmdir(path)
OSError: [Errno 39] Directory not empty: '/tmp/pymp-qo2qpsql'


[Step 30900/20000] [D loss: 1.006411] [G loss: 5.957285] [RMSE: 0.095552] (12349.221573 s)
[Step 31000/20000] [D loss: 1.006409] [G loss: 4.688874] [RMSE: 0.082680] (12389.040317 s)
[Step 31100/20000] [D loss: 1.006410] [G loss: 4.404408] [RMSE: 0.095192] (12428.867156 s)
[Step 31200/20000] [D loss: 1.006409] [G loss: 6.510008] [RMSE: 0.090027] (12468.681418 s)
[Step 31300/20000] [D loss: 1.006409] [G loss: 3.913450] [RMSE: 0.111647] (12508.490399 s)
[Step 31400/20000] [D loss: 1.006410] [G loss: 7.333215] [RMSE: 0.086295] (12548.313565 s)
[Step 31500/20000] [D loss: 1.006411] [G loss: 4.023213] [RMSE: 0.105802] (12588.123382 s)
[Step 31600/20000] [D loss: 1.006411] [G loss: 3.438249] [RMSE: 0.091104] (12627.920341 s)
[Step 31700/20000] [D loss: 1.006410] [G loss: 7.392535] [RMSE: 0.093799] (12667.716691 s)
[Step 31800/20000] [D loss: 1.006409] [G loss: 6.157851] [RMSE: 0.087377] (12707.530257 s)
[Step 31900/20000] [D loss: 1.006410] [G loss: 11.244253] [RMSE: 0.095059] (12747.406575 s

Traceback (most recent call last):
  File "/home/pablodonav/miniconda3/envs/torch/lib/python3.9/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/home/pablodonav/miniconda3/envs/torch/lib/python3.9/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/home/pablodonav/miniconda3/envs/torch/lib/python3.9/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/home/pablodonav/miniconda3/envs/torch/lib/python3.9/shutil.py", line 740, in rmtree
    onerror(os.rmdir, path, sys.exc_info())
  File "/home/pablodonav/miniconda3/envs/torch/lib/python3.9/shutil.py", line 738, in rmtree
    os.rmdir(path)
OSError: [Errno 39] Directory not empty: '/tmp/pymp-009w_v07'


[Step 35200/20000] [D loss: 1.006410] [G loss: 18.280832] [RMSE: 0.083101] (14061.424779 s)
[Step 35300/20000] [D loss: 1.006410] [G loss: 5.371556] [RMSE: 0.107631] (14101.235566 s)
[Step 35400/20000] [D loss: 1.006411] [G loss: 16.039217] [RMSE: 0.075648] (14141.051608 s)
[Step 35500/20000] [D loss: 1.006410] [G loss: 5.664241] [RMSE: 0.076086] (14180.867064 s)
[Step 35600/20000] [D loss: 1.006410] [G loss: 8.801744] [RMSE: 0.090479] (14220.694936 s)
[Step 35700/20000] [D loss: 1.006410] [G loss: 2.732012] [RMSE: 0.116939] (14260.499904 s)
[Step 35800/20000] [D loss: 1.006413] [G loss: 6.486723] [RMSE: 0.106361] (14300.297853 s)
[Step 35900/20000] [D loss: 1.006411] [G loss: 5.732686] [RMSE: 0.106772] (14340.109691 s)
[Step 36000/20000] [D loss: 1.006410] [G loss: 8.392134] [RMSE: 0.091638] (14379.912996 s)
[Step 36100/20000] [D loss: 1.006411] [G loss: 7.221942] [RMSE: 0.108887] (14419.731683 s)
[Step 36200/20000] [D loss: 1.006409] [G loss: 4.011481] [RMSE: 0.124161] (14459.538858 

Traceback (most recent call last):
  File "/home/pablodonav/miniconda3/envs/torch/lib/python3.9/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/home/pablodonav/miniconda3/envs/torch/lib/python3.9/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/home/pablodonav/miniconda3/envs/torch/lib/python3.9/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/home/pablodonav/miniconda3/envs/torch/lib/python3.9/shutil.py", line 740, in rmtree
    onerror(os.rmdir, path, sys.exc_info())
  File "/home/pablodonav/miniconda3/envs/torch/lib/python3.9/shutil.py", line 738, in rmtree
    os.rmdir(path)
OSError: [Errno 39] Directory not empty: '/tmp/pymp-su6643wp'


[Step 54900/20000] [D loss: 1.006409] [G loss: 9.209316] [RMSE: 0.095802] (21904.268481 s)
[Step 55000/20000] [D loss: 1.006409] [G loss: 4.975554] [RMSE: 0.072628] (21944.085308 s)
[Step 55100/20000] [D loss: 1.006409] [G loss: 4.423204] [RMSE: 0.087521] (21983.906149 s)
[Step 55200/20000] [D loss: 1.006409] [G loss: 3.281706] [RMSE: 0.093147] (22023.735642 s)
[Step 55300/20000] [D loss: 1.006409] [G loss: 7.338740] [RMSE: 0.092109] (22063.556592 s)
[Step 55400/20000] [D loss: 1.006409] [G loss: 5.274210] [RMSE: 0.074427] (22103.361411 s)
[Step 55500/20000] [D loss: 1.006409] [G loss: 5.791990] [RMSE: 0.070782] (22143.173176 s)
[Step 55600/20000] [D loss: 1.006409] [G loss: 5.292432] [RMSE: 0.101456] (22182.970079 s)
[Step 55700/20000] [D loss: 1.006409] [G loss: 6.302993] [RMSE: 0.090263] (22222.762499 s)
[Step 55800/20000] [D loss: 1.006409] [G loss: 10.633359] [RMSE: 0.093823] (22262.581111 s)
[Step 55900/20000] [D loss: 1.006409] [G loss: 7.594226] [RMSE: 0.101005] (22302.378015 s

KeyboardInterrupt: 

In [8]:
torch.save(generator.state_dict(), "saved_models/generator_%d.pth" % epoch)
torch.save(discriminator.state_dict(), "saved_models/discriminator_%d.pth" %epoch)

In [9]:
data = {
    'Time': time_array,
    'RMSE Error': rmse_error_array,
    'Generator Loss': gen_loss_array,
    'Discriminator Loss': disc_loss_array,
    'HR Loss': hr_loss_array,
    'SR Loss': sr_loss_array
}

df = pd.DataFrame(data)
df.to_csv('metrics_pytorch.csv', index=False)