In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torcheval.metrics import PeakSignalNoiseRatio
import numpy as np
import torchvision.models as models
from torchvision.models.vgg import VGG19_Weights
import torchvision.transforms.functional as F
import torch.nn.functional as Fe
import seaborn as sns
import pandas as pd
from torchvision.utils import make_grid
import os
import cv2

In [2]:
from SRResNet import SRResnet
from EDSR_Extended import EDSR
import matplotlib.pyplot as plt

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
model = SRResnet()
model.load_state_dict(torch.load('SRResNet-VGG/snapshots/epoch_0_batch_100.pt'))
model.to(device=device)
model.name = "SRResNet - VGG"

model2 = SRResnet()
model2.load_state_dict(torch.load('SRResNet-Backup - Actual/snapshots/epoch_0_batch_7200.pt'))
model2.name = "SRResNet"

model3 = SRResnet()
model3.load_state_dict(torch.load('SRResNet-Backup - Actual/snapshots/epoch_12_batch_0.pt'))
model3.name = "SRResNet - 12 Epochs"

model6 = SRResnet()
model6.load_state_dict(torch.load('SRResNet-Backup - Actual/snapshots/epoch_5_batch_0.pt'))
model6.name = "SRResNet - 5 Epochs"

#model4 = EDSR()
#model4.load_state_dict(torch.load('EDSR/snapshots/epoch_2_batch_0.pt'))
#model4.name = "SRResNet - No BN"

model5 = EDSR()
model5.load_state_dict(torch.load('EDSR Extended/snapshots/epoch_3_batch_0.pt'))
model5.name = "SRResNet - No BN - Extended - 3 Epochs"

model7 = EDSR()
model7.load_state_dict(torch.load(('EDSR/snapshots/epoch_8_batch_0.pt')))


models = [model5]

In [5]:
batch_size = 1
num_workers = 1
data_path = "datasetVal/"

transform = transforms.Compose([
  transforms.CenterCrop((256, 256)),
  transforms.ToTensor(),
  transforms.Normalize((0.5, 0.5, 0.5), (0.5,0.5,0.5))
])


transform_LR = transforms.Compose([
  transforms.CenterCrop((256, 256)),
  transforms.Resize((128,128)),
  transforms.ToTensor(),
  transforms.Normalize((0.5, 0.5, 0.5), (0.5,0.5,0.5))
])

transform_resize = transforms.Compose([
  transforms.CenterCrop((256, 256)),
  transforms.Resize((128,128)),
  transforms.Resize((256,256)),
  transforms.ToTensor(),
  transforms.Normalize((0.5, 0.5, 0.5), (0.5,0.5,0.5))
])

hr_dataset = ImageFolder(root=data_path + "lr", transform= transform)
lr_dataset = ImageFolder(root=data_path + "lr", transform=transform_LR)
resize_dataset = ImageFolder(root=data_path + "lr", transform=transform_resize)

hr_data_loader = DataLoader(hr_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
lr_data_loader = DataLoader(lr_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
rs_data_loader = DataLoader(resize_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

In [None]:
imagensSelecionadas = [16, 18, 20 ,22, 25,27, 30, 32, 31, 45,46, 90, 52, 56, 50, 70, 100, 120, 150, 170, 200]

In [6]:
def criaPlot2(imagens, text, output):
    newImage = [cv2.cvtColor(cv2.imread(imagem), cv2.COLOR_BGR2RGB) for imagem in imagens]
    rows = 2
    cols = 4
    fig = plt.figure(figsize=(16,8))
    for i, (image, title) in enumerate(zip(newImage, text)):
        plt.subplot(2, 4, i + 1)
        plt.imshow(image)
        plt.axis('off')
        plt.title(title)
    plt.savefig(output)

In [None]:
for i, (hr_images, lr_images, rs_images) in enumerate(zip(hr_data_loader, lr_data_loader, rs_data_loader)):
    
    folderName = str(i)+"folder"
    os.mkdir(f'Test/{folderName}')
    hr_images = hr_images[0].to(device).float()
    lr_images = lr_images[0].to(device).float()
    rs_images = rs_images[0].to(device).float()
    
    save_image(hr_images, f'Test/{folderName}/original.png', normalize = True)
    save_image(lr_images, f'Test/{folderName}/low_resolution.png', normalize = True)
    save_image(rs_images, f'Test/{folderName}/bilinear.png', normalize = True)
    
    # parte que cria dados pro gráfico de imagens
    text = ["High Resolution - HR", "Bilinear Upscale"]
    imagens = [f'Test/{folderName}/original.png',f'Test/{folderName}/bilinear.png']
    
    for x in models:
        x.to(device)
        sr_images = x(lr_images)
        save_image(sr_images, f"Test/{folderName}/{x.name}.png", normalize = True)
        imagens.append(f"Test/{folderName}/{x.name}.png")
        text.append(f"{x.name}")
    criaPlot2(imagens, text, f"Test/{folderName}/" + "plot.png")    
                

In [7]:
def testModel(model, loss_function):
  totalError = 0
  error_track = []
  psnr_loss = []
  psnr = PeakSignalNoiseRatio()
  model.to(device)
  
  for i, (hr_images, lr_images) in enumerate(zip(hr_data_loader, lr_data_loader)):
    hr_images = hr_images[0].to(device).float()
    lr_images = lr_images[0].to(device).float()
    
    sr_images = model(lr_images)
    g_loss_content = loss_function(sr_images, hr_images)
    
    psnr.update(sr_images, hr_images)
    psnr_loss.append(psnr.compute().to('cpu'))
    
    totalError += g_loss_content.item()
    error_track.append(g_loss_content.item())
    
    #print(f"{i} batch\nPsnr:{psnr_loss[i]}")
  return totalError, error_track, np.asarray(psnr_loss).mean()

##### Testando todos os modelos

In [15]:
totalError = 0
error_track = []
psnr_loss = []
psnr = PeakSignalNoiseRatio()
model.to(device)
loss_function = nn.MSELoss()

for i, (hr_images, sr_images) in enumerate(zip(hr_data_loader, rs_data_loader)):
    hr_images = hr_images[0].to(device).float()
    sr_images = sr_images[0].to(device).float()

    g_loss_content = loss_function(sr_images, hr_images)

    psnr.update(sr_images, hr_images)
    psnr_loss.append(psnr.compute().to('cpu'))

    totalError += g_loss_content.item()
    error_track.append(g_loss_content.item())

#print(f"{i} batch\nPsnr:{psnr_loss[i]}")
totalError, error_track, np.asarray(psnr_loss).mean()

(85.21737998227036,
 [0.008866239339113235,
  0.02348390221595764,
  0.039528846740722656,
  0.02006634883582592,
  0.014680149964988232,
  0.012719947844743729,
  0.007905548438429832,
  0.01209353469312191,
  0.006675140466541052,
  0.02079123631119728,
  0.018358666449785233,
  0.009387647733092308,
  0.002978375181555748,
  0.006824645213782787,
  0.019210491329431534,
  0.0077764010056853294,
  0.00902366079390049,
  0.009718842804431915,
  0.011545248329639435,
  0.002876361832022667,
  0.04611989110708237,
  0.012325970456004143,
  0.010220265947282314,
  0.0007507880218327045,
  0.011414158158004284,
  0.01066848635673523,
  0.014627466909587383,
  0.050189100205898285,
  0.006325410213321447,
  0.016336344182491302,
  0.026310821995139122,
  0.0031594661995768547,
  0.08543535321950912,
  0.009346283040940762,
  0.016411691904067993,
  0.012507490813732147,
  0.04128344729542732,
  0.00898092333227396,
  0.0069645605981349945,
  0.00038480921648442745,
  0.008757134899497032,


In [21]:
np.asarray(psnr_loss).mean()

23.720375

In [24]:
np.asarray(error_track).sum()

85.21737998227036

In [8]:
for x in models:
    erro_REDE, track_REDE, psnr_loss = testModel(x, nn.MSELoss())
    print(f"Modelo {x.name}:\n-PSNR:{psnr_loss}\n-MSE total:{erro_REDE}")

  return F.conv2d(input, weight, bias, self.stride,


Modelo SRResNet - No BN - Extended - 3 Epochs:
-PSNR:26.693490982055664
-MSE total:42.88871908938745


In [13]:
erro_REDE, track_REDE, psnr_loss = testModel(model3, nn.MSELoss())

In [11]:
np.asarray(track_REDE).sum()

89.46096962445881

In [None]:
print(f"PSRN Médio da Rede: {np.asarray(psnr_loss).mean()}")

#### PSRN Médio da Rede: 24.983320236206055 50 Batchs
#### PSRN Médio da Rede: 24.83488655090332 100 Batchs



In [None]:
totalError = 0
error_track = []
psnr_loss = []
loss = nn.MSELoss()
psnr = PeakSignalNoiseRatio()
for i, (hr_images, rs_images) in enumerate(zip(hr_data_loader, rs_data_loader)):
  hr_images = hr_images[0].to(device).float()
  rs_images = rs_images[0].to(device).float()
  
  g_loss_content = loss(hr_images, rs_images)
  psnr.update(rs_images, hr_images)

  totalError += g_loss_content.item()
  psnr_loss.append(psnr.compute().to('cpu'))
  error_track.append(g_loss_content.item())
  print(f"{i} batch")

In [None]:
print(f"PSRN Médio do Scaler: {np.asarray(psnr_loss).mean()}")

In [None]:
totalError

In [None]:
error_track

In [None]:
track_REDE