In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from torchvision.utils import save_image
import numpy as np
import torchvision.models as models
from torchvision.models.vgg import VGG19_Weights
import torchvision.transforms.functional as F
import seaborn as sns
import pandas as pd
import os


In [2]:
from SRResNet_Up2 import SRResNet
from EDSR_Extended import EDSR
import matplotlib.pyplot as plt
from torchvision.models.vgg import VGG19_Weights
#from SRResNetY import SRResnet

In [3]:
class VGGContentLoss(nn.Module):
    def __init__(self, layer_ids=[4, 9, 18, 27], use_pretrained=True):
        super(VGGContentLoss, self).__init__()
        # Carregar a VGG19 pré-treinada no ImageNet
        vgg = models.vgg19(pretrained=use_pretrained).features
        # Mantemos as camadas até a layer que queremos (maxpool1, maxpool2, etc.)
        self.vgg_layers = nn.Sequential(*[vgg[i] for i in range(max(layer_ids)+1)])
        # Definir quais camadas serão utilizadas para o cálculo do loss
        self.layer_ids = layer_ids
        # Congelemos os parâmetros da VGG (não será treinada)
        for param in self.vgg_layers.parameters():
            param.requires_grad = False

    def forward(self, img_real, img_fake):
        # Extrair as características das imagens reais e geradas
        features_real = self.extract_features(img_real)
        features_fake = self.extract_features(img_fake)

        # Content Loss: diferença L2 entre as características
        content_loss = 0.0
        for real, fake in zip(features_real, features_fake):
            content_loss += torch.nn.functional.mse_loss(fake, real)

        return content_loss

    def extract_features(self, x):
        features = []
        for i, layer in enumerate(self.vgg_layers):
            x = layer(x)
            if i in self.layer_ids:
                features.append(x)
            
        return features

In [4]:
batch_size = 22
num_workers = 1
num_epochs = 1
save_interval = 20
betas = (0.5, 0.9)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [5]:
data_path = "dataset/"

transform = transforms.Compose([
  transforms.CenterCrop((256, 256)),
  transforms.ToTensor(),
  transforms.Normalize((0.5, 0.5, 0.5), (0.5,0.5,0.5))
])


transform_LR = transforms.Compose([
  transforms.CenterCrop((256, 256)),
  transforms.Resize((128,128)),
  transforms.ToTensor(),
  transforms.Normalize((0.5, 0.5, 0.5), (0.5,0.5,0.5))
])

hr_dataset = ImageFolder(root=data_path + "lr", transform= transform)
lr_dataset = ImageFolder(root=data_path + "lr", transform=transform_LR)

hr_data_loader = DataLoader(hr_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
lr_data_loader = DataLoader(lr_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

#vgg_loss = VGGContentLoss(layer_ids=[4, 9, 18, 27]).to(device)
    

In [6]:
def makeFolders(name):
  os.mkdir(name)
  os.mkdir(name + '/training_images')
  os.mkdir(name + '/snapshots')

In [7]:
model = EDSR()
#model.load_state_dict(torch.load('SRResNet-Backup - Actuual/model.pt'))
models = [model] #[ESRT(),RNAN(), SRResnet(), SRResnetExtended(), RNANExtended()]

In [8]:

def trainModel(model, loss_function, LRS, num_epochs):
  error_track = []
  model.to(device)
  makeFolders(model.name)
  for epoch, lr in zip(range(0, num_epochs), LRS):
    optimizer_gen = optim.Adam(model.parameters(), lr = lr, betas = betas)
    for i, (hr_images, lr_images) in enumerate(zip(hr_data_loader, lr_data_loader)):
      hr_images = hr_images[0].to(device).float()
      lr_images = lr_images[0].to(device).float()
      
      sr_images = model(lr_images)
      g_loss_content = loss_function(sr_images, hr_images)
      #g_loss_content = vgg_loss(sr_images, hr_images)  #+ (loss_function(sr_images, hr_images) * 0.1)
      g_loss_content.backward()
      
      optimizer_gen.step()
      
      model.zero_grad()
      optimizer_gen.zero_grad()
      
      if (i + 1) % save_interval == 1:
        print(g_loss_content.item())
        error_track.append(g_loss_content.item())
      if (i) % 50 == 0:
        torch.save(model.state_dict(), f"{model.name}/snapshots/epoch_{epoch}_batch_{i}.pt")
        save_image(sr_images, f"{model.name}/training_images/SR_epoch_{epoch}_batch_{i}.png", normalize = True)
        save_image(hr_images, f"{model.name}/training_images/HR_epoch_{epoch}_batch_{i}.png", normalize = True)
        save_image(lr_images, f"{model.name}/training_images/LR_epoch_{epoch}_batch_{i}.png", normalize = True)
  return error_track

In [9]:
def makeLRList(initial, destiny, epochs):
  dif = destiny - initial
  dif /= epochs
  return [initial + dif * i for i in range(0, epochs)]  

In [None]:
error_tracks = {}
for model in models:
  error_tracks[model.name] = trainModel(model, nn.MSELoss(), makeLRList(0.0001, 0.000001, 5), 5)

  return F.conv2d(input, weight, bias, self.stride,


0.3264651596546173
0.036086611449718475
0.028496479615569115
0.03048541024327278
0.02195902355015278
0.02428966574370861
0.022302910685539246
0.02014380507171154
0.02013089507818222
0.018002958968281746
0.016507795080542564
0.015390399843454361
0.024000216275453568
0.02085338719189167
0.01864708960056305
0.01859557442367077
0.010455998592078686
0.01255401223897934
0.0136759327724576
0.015421737916767597
0.013191530480980873
0.013690485619008541
0.022672588005661964
0.01504336018115282
0.012153388001024723
0.015066048130393028
0.018631551414728165
0.012017378583550453
0.012952769175171852
0.010023247450590134
0.01201297715306282
0.012895997613668442
0.014828398823738098
0.008437300100922585
0.015218903310596943
0.013938026502728462
0.0137669388204813
0.01197677943855524
0.01400416623800993
0.009313098154962063
0.012113920412957668
0.012830597348511219
0.01725752465426922
0.010448230430483818
0.00981107261031866
0.014631197787821293
0.01048300415277481
0.017140096053481102
0.011465269140

In [None]:
df = pd.DataFrame(error_tracks)
df.head(5)

In [None]:
df = pd.DataFrame(error_tracks)
df['Batch'] = np.asarray(df.index)
df.to_csv('Relatorio de Treino EDSR.csv')

In [None]:
df

In [None]:
sns.set_theme(palette='deep')
ax = sns.lineplot(df, x='Batch', y='EDSR')
ax.set_title('SRResNet sem BN')
ax.set_ylabel('')
plt.savefig('EDSR/ErrorCurve.png')

In [None]:
sns.set_theme(palette='deep')
ax = sns.lineplot(df, x='Batch', y='EDSR')
ax.set_title('SRResNet-Extended')
ax.set_ylabel('')
plt.savefig('SRResNet-Extended/ErrorCurve.png')

df = pd.DataFrame({'Iteração do Batch': range(0, len(error_track)), 'MSE' : error_track_array})
sns.set_theme()
ax = sns.lineplot(df, x='Iteração do Batch', y='MSE')
ax.set_title(f'{model.name}')
plt.savefig(f'{model.name}/ErrorCurve')

for model in models:
   Modelo na GPU
  model.to(device)
  content_loss = nn.MSELoss()
   Otimizador
  optimizer_gen = optim.Adam(model.parameters(), lr = lr, betas = betas)
  error_track = [ ]
  error_lines = []
  makeFolders(model.name)
  Treino
  for epoch in range(0, num_epochs):
    for i, (hr_images, lr_images) in enumerate(zip(hr_data_loader, lr_data_loader)):
      hr_images = hr_images[0].to(device).float()
      lr_images = lr_images[0].to(device).float()
      
      sr_images = model(lr_images)
      g_loss_content = content_loss(sr_images, hr_images)
      g_loss_content.backward()
      
      optimizer_gen.step()
      
      model.zero_grad()
      optimizer_gen.zero_grad()
      
      if (i + 1) % save_interval == 1:
        print(g_loss_content.item())
        error_track.append(g_loss_content.item())
        torch.save(model.state_dict(), f"{model.name}/snapshots/epoch{epoch+1}_batch{i+1}.pt")
        print(f"Saved {model.name} model and images at epoch {epoch+1}, batch {i+1} / {len(hr_data_loader)}.")
        save_image(sr_images, f"{model.name}/training_images/image_epoch{epoch + 1}_batch{i+1}_sr.png", normalize = True)
        save_image(hr_images, f"{model.name}/training_images/image_epoch{epoch + 1}_batch{i+1}_hr.png", normalize = True)
        save_image(lr_images, f"{model.name}/training_images/image_epoch{epoch + 1}_batch{i+1}_lr.png", normalize = True)
  error_track_array = np.asarray(error_track)
  error_lines.append(error_track_array)