# Imports

In [None]:
import argparse
import os
import copy
import numpy as np

# import PIL.Image as pil_image
from PIL import Image
import matplotlib. pyplot as plt 
import torch
from torch import nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
# import torchvision.transforms as transforms
import time
import tensorflow as tf
import mlflow
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm

# Model

In [None]:
class FSRCNN(nn.Module):
    """FSRCNN-style super-resolution network for 3-channel images.

    The input passes through five stages in order: feature extraction
    (5x5 conv, 3 -> 56 channels), shrinking (1x1 conv, 56 -> 12), a stack of
    five 3x3 convolutions on 12 channels, expanding (1x1 conv, 12 -> 56) and a
    9x9 transposed convolution that maps back to 3 channels while enlarging
    the spatial size by ``scale_factor``. Every convolution except the last
    one is followed by a PReLU.

    Args:
        scale_factor: Integer upscaling factor, used as the stride of the
            final transposed convolution.
    """

    def __init__(self, scale_factor):
        super().__init__()
        self.scale_factor = scale_factor

        wide, narrow, mapping_depth = 56, 12, 5

        self.feature_extraction = nn.Sequential(
            nn.Conv2d(3, wide, kernel_size=5, padding=2),
            nn.PReLU(),
        )
        self.shrinking = nn.Sequential(
            nn.Conv2d(wide, narrow, kernel_size=1),
            nn.PReLU(),
        )

        # Conv/PReLU pairs are appended in order so the sub-module indices
        # (0..9) stay the same as when they are listed out one by one.
        mapping_layers = []
        for _ in range(mapping_depth):
            mapping_layers.append(nn.Conv2d(narrow, narrow, kernel_size=3, padding=1))
            mapping_layers.append(nn.PReLU())
        self.non_linear_mapping = nn.Sequential(*mapping_layers)

        self.expanding = nn.Sequential(
            nn.Conv2d(narrow, wide, kernel_size=1),
            nn.PReLU(),
        )

        # With kernel 9 / padding 4, an output_padding of (stride - 1) makes
        # the output exactly `scale_factor` times the input height and width.
        self.deconvolution = nn.ConvTranspose2d(
            wide,
            3,
            kernel_size=9,
            stride=scale_factor,
            padding=4,
            output_padding=scale_factor - 1,
        )

    def forward(self, x):
        """Return the upscaled version of ``x`` after all five stages."""
        pipeline = (
            self.feature_extraction,
            self.shrinking,
            self.non_linear_mapping,
            self.expanding,
            self.deconvolution,
        )
        for stage in pipeline:
            x = stage(x)
        return x


# Dataloader / preprocessing

In [None]:
class DIV2KDataset(Dataset):
    """Dataset of (high-resolution, low-resolution) image pairs from a folder.

    Every regular file directly inside ``img_dir`` is treated as an image.
    Each image is resized to ``desired_width x desired_height`` to form the
    HR sample; the LR sample is that HR image downscaled by ``scale_factor``
    (integer division) with bicubic interpolation.

    Args:
        img_dir: Directory containing the image files.
        scale_factor: Integer downscaling factor between HR and LR.
        desired_height: Height in pixels of the HR sample.
        desired_width: Width in pixels of the HR sample.
    """

    def __init__(self, img_dir, scale_factor, desired_height, desired_width):
        super().__init__()
        self.img_dir = img_dir
        self.scale_factor = scale_factor
        self.desired_height = desired_height
        self.desired_width = desired_width
        # os.listdir returns entries in arbitrary order and also lists
        # sub-directories (e.g. notebook checkpoint folders). Sorting makes
        # index -> file stable across runs; skipping non-files avoids a crash
        # when such an entry is opened as an image.
        self.img_list = sorted(
            name
            for name in os.listdir(self.img_dir)
            if os.path.isfile(os.path.join(self.img_dir, name))
        )

    def __getitem__(self, index):
        """Return ``(img_hr, img_lr)`` tensors for the image at ``index``."""
        path = os.path.join(self.img_dir, self.img_list[index])

        # Close the file handle promptly, and force 3 channels so grayscale or
        # RGBA files still match the 3-channel network input.
        with Image.open(path) as source:
            img_hr = source.convert("RGB")

        # Resize the HR image to the requested size.
        hr_size = (self.desired_width, self.desired_height)
        img_hr = img_hr.resize(hr_size, Image.BICUBIC)

        # Derive the LR image from the resized HR image using the scale factor.
        lr_size = (
            self.desired_width // self.scale_factor,
            self.desired_height // self.scale_factor,
        )
        img_lr = img_hr.resize(lr_size, Image.BICUBIC)

        to_tensor = T.ToTensor()
        return to_tensor(img_hr), to_tensor(img_lr)

    def __len__(self):
        """Return the number of image files found in ``img_dir``."""
        return len(self.img_list)


# Train and Val

In [None]:
class TrainFSRCNN(object):
    """Drive training and validation of a super-resolution model.

    Each epoch runs one optimisation pass over ``train_loader``, evaluates on
    ``val_loader``, prints a summary line, saves the weights to
    ``best_model.pth`` whenever the summed training loss improves, and logs
    the epoch losses to MLflow. Every 10th epoch one batch of validation
    images is written to TensorBoard.

    Both loaders must yield ``(hr, lr)`` batches, in that order.

    Args:
        model: Network mapping an LR batch to an HR-sized batch.
        criterion: Loss callable invoked as ``criterion(outputs, hr)``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        train_loader: Iterable of training batches (must support ``len``).
        val_loader: Iterable of validation batches (must support ``len``).
        n_epochs: Number of training epochs.
        device: Device the model and batches are moved to.
    """

    def __init__(self, model, criterion, optimizer, train_loader, val_loader, n_epochs, device):
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.n_epochs = n_epochs
        self.device = device
        self.log_dir = '/phoenix/tensorboard/tensorlogs'
        self.mlflow_exp_name = '/phoenix/mlflow'
        # Lowest summed training loss seen so far and the (1-based) epoch that
        # produced it; -1 means no epoch has been recorded yet.
        self.best_metric = float('inf')
        self.best_epoch = -1

    def train(self):
        """Run the full training loop with checkpointing and logging."""
        start_time = time.time()

        writer = SummaryWriter(self.log_dir)
        mlflow.set_experiment(self.mlflow_exp_name)

        self.model = self.model.to(self.device)

        try:
            # One run for the whole training, with `step` set per epoch, keeps
            # all epochs on a single metric curve instead of one run per epoch.
            with mlflow.start_run(run_name=self.mlflow_exp_name) as run:
                print(run.info.run_id)

                for epoch in range(self.n_epochs):
                    running_loss = self._run_training_epoch()
                    mean_train_loss = running_loss / len(self.train_loader)

                    if (epoch + 1) % 10 == 0:
                        self.log_images_to_tensorboard(writer, epoch)

                    val_loss, val_psnr = self.validate()

                    print("Epoch: %d, Loss: %.3f, Validation Loss: %.3f, Validation PSNR: %.2f" %
                          (epoch + 1, mean_train_loss, val_loss, val_psnr))

                    if running_loss < self.best_metric:
                        torch.save(self.model.state_dict(), 'best_model.pth')
                        self.best_metric = running_loss
                        self.best_epoch = epoch + 1

                    # Metric keys are kept as-is for existing dashboards. The
                    # values are per-batch means of the criterion output (no
                    # square root is taken here), so training and validation
                    # are on the same scale.
                    mlflow.log_metric("Training RMSE", mean_train_loss, step=epoch + 1)
                    mlflow.log_metric("Validation RMSE", val_loss, step=epoch + 1)

                # Log the trained model under the "fscnn" artifact path and
                # register it once, so the registered version points at a real
                # artifact of this run.
                mlflow.pytorch.log_model(self.model, "fscnn", registered_model_name="fscnn")
        finally:
            # Flush and release the TensorBoard writer even if training fails.
            writer.close()

        end_time = time.time()
        total_time = end_time - start_time
        print('Tempo total de treinamento: {:.2f} segundos'.format(total_time))

    def _run_training_epoch(self):
        """Run one optimisation pass over ``train_loader``; return the summed loss."""
        # validate() and log_images_to_tensorboard() leave the model in eval
        # mode, so training mode has to be restored at the start of every epoch.
        self.model.train()

        running_loss = 0.0
        for hr, lr in self.train_loader:
            hr = hr.to(self.device)
            lr = lr.to(self.device)

            self.optimizer.zero_grad()

            outputs = self.model(lr)
            loss = self.criterion(outputs, hr)

            loss.backward()
            self.optimizer.step()

            running_loss += loss.item()
        return running_loss

    def log_images_to_tensorboard(self, writer, epoch):
        """Write HR, LR and super-resolved grids of one validation batch.

        Args:
            writer: TensorBoard ``SummaryWriter`` to log to.
            epoch: Zero-based epoch index; tags use ``epoch + 1``.
        """
        self.model.eval()
        with torch.no_grad():
            # Get a batch from the validation set
            hr, lr = next(iter(self.val_loader))
            hr = hr.to(self.device)
            lr = lr.to(self.device)

            # Forward pass
            outputs = self.model(lr)

            # Convert images to a grid for visualization
            grid_hr = make_grid(hr, nrow=1, normalize=True)
            grid_lr = make_grid(lr, nrow=1, normalize=True)
            grid_sr = make_grid(outputs, nrow=1, normalize=True)

            # Log images to TensorBoard
            writer.add_image(f'Original/Epoch_{epoch + 1}', grid_hr, epoch)
            writer.add_image(f'Low Resolution/Epoch_{epoch + 1}', grid_lr, epoch)
            writer.add_image(f'Super-Resolved/Epoch_{epoch + 1}', grid_sr, epoch)

    def validate(self):
        """Evaluate on ``val_loader``.

        Returns:
            Tuple ``(mean_loss, mean_psnr)`` averaged over validation batches.
            PSNR is computed per batch with a peak value of 1.0. The model is
            left in eval mode.
        """
        self.model.eval()
        with torch.no_grad():
            val_loss = 0.0
            val_psnr = 0.0
            for (hr, lr) in self.val_loader:
                hr = hr.to(self.device)
                lr = lr.to(self.device)
                outputs = self.model(lr)
                loss = self.criterion(outputs, hr)
                val_loss += loss.item()

                # Compute PSNR for this batch
                mse = torch.mean((hr - outputs) ** 2)
                psnr = 20 * torch.log10(1.0 / torch.sqrt(mse))
                val_psnr += psnr.item()
            return val_loss / len(self.val_loader), val_psnr / len(self.val_loader)

    def device_validate(self):
        """Move the model to ``self.device`` and return ``validate()``'s result."""
        self.model = self.model.to(self.device)
        return self.validate()


In [None]:
class Args:
    """Static configuration for the training run."""

    # Folders holding the DIV2K high-resolution training / validation images.
    train_dir = '../datafabric/DIV2K/Deep Learning/PyTorch/Computer Vision/DIV2K/DIV2K_train_HR/DIV2K_train_HR'
    val_dir = '../datafabric/DIV2K/Deep Learning/PyTorch/Computer Vision/DIV2K/DIV2K_valid_HR/DIV2K_valid_HR'

    # NOTE(review): alternative paths kept from the original; presumably a
    # small local folder for quick runs - confirm before enabling.
    # train_dir = 'sample'
    # val_dir = 'sample'

    scale = 4        # upscaling factor between LR and HR images
    batch_size = 4
    epochs = 300

args = Args()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Load the training dataset (HR samples are 1020 px high and 2040 px wide).
train_dataset = DIV2KDataset(args.train_dir, args.scale, desired_height=1020, desired_width=2040)

# Load the validation dataset with the same HR size.
val_dataset = DIV2KDataset(args.val_dir, args.scale, desired_height=1020, desired_width=2040)

train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
# Validation is not shuffled: the batches (and the first batch, which is the
# one logged to TensorBoard) then stay the same from epoch to epoch, so
# metrics and sample images are comparable across epochs.
val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)

model = FSRCNN(scale_factor=args.scale)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.00001)
train_fsrcnn = TrainFSRCNN(model, criterion, optimizer, train_loader, val_loader, args.epochs, device)

train_fsrcnn.train()

# Inference

In [None]:
# Persist the weights of the model as they are at the end of training.
torch.save(
    model.state_dict(),
    'FSRCNN_300_epochs.pt',
)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = FSRCNN(4)
model.to(device)
# map_location lets a checkpoint saved from GPU tensors load on a CPU-only
# machine as well.
model.load_state_dict(torch.load('FSRCNN_300_epochs.pt', map_location=device))
model.eval()

class Args:
    """Static configuration for building the data loaders."""

    train_dir = '../datafabric/DIV2K/Deep Learning/PyTorch/Computer Vision/DIV2K/DIV2K_train_HR/DIV2K_train_HR'
    val_dir = '../datafabric/DIV2K/Deep Learning/PyTorch/Computer Vision/DIV2K/DIV2K_valid_HR/DIV2K_valid_HR'
    scale = 4
    batch_size = 4
    epochs = 300

args = Args()

# Load the training dataset
train_dataset = DIV2KDataset(args.train_dir, args.scale, 1020, 2040)

# Load the validation dataset
val_dataset = DIV2KDataset(args.val_dir, args.scale, 1020, 2040)

train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True)

# The loader yields (hr_batch, lr_batch): x_batch holds the HR images and
# y_batch the LR images, so `y` is the HR target and `x` the LR model input.
x_batch, y_batch = next(iter(val_loader))
y, x = x_batch[0], y_batch[0]
x = x.to(device)
# Inference needs no autograd graph; add and remove an explicit batch
# dimension so the convolutions receive the usual 4-D input.
with torch.no_grad():
    pred = model(x.unsqueeze(0)).squeeze(0)
pred = pred.cpu()
print('imagem PREDITA', pred.shape)
print('imagem HR', y.shape)

# Convert tensors to HWC numpy arrays for plotting
y = y.numpy().transpose(1, 2, 0)
# The network output is unbounded; clamp to the [0, 1] range imshow expects
# for float RGB data.
pred = pred.clamp(0.0, 1.0).numpy().transpose(1, 2, 0)

# Plot the images side by side
fig, axes = plt.subplots(1, 2, figsize=(10, 5))

axes[0].imshow(y)
axes[0].set_title('Imagem de Alta Resolução')
axes[0].axis('off')

axes[1].imshow(pred)
axes[1].set_title('Imagem Predita')
axes[1].axis('off')

plt.show()


In [None]:
# The loader yields (hr_batch, lr_batch): x_batch holds the HR images and
# y_batch the LR images, so `y` is the HR target and `x` the LR model input.
x_batch, y_batch = next(iter(val_loader))
y, x = x_batch[0], y_batch[0]
x = x.to(device)
# Inference needs no autograd graph; add and remove an explicit batch
# dimension so the convolutions receive the usual 4-D input.
with torch.no_grad():
    pred = model(x.unsqueeze(0)).squeeze(0)
pred = pred.cpu()
print('imagem PREDITA', pred.shape)
print('imagem HR', y.shape)

# Convert tensors to HWC numpy arrays for plotting
y = y.numpy().transpose(1, 2, 0)
# The network output is unbounded; clamp to the [0, 1] range imshow expects
# for float RGB data.
pred = pred.clamp(0.0, 1.0).numpy().transpose(1, 2, 0)

# Plot the images side by side
fig, axes = plt.subplots(1, 2, figsize=(10, 5))

axes[0].imshow(y)
axes[0].set_title('Imagem de Alta Resolução')
axes[0].axis('off')

axes[1].imshow(pred)
axes[1].set_title('Imagem Predita')
axes[1].axis('off')

plt.show()


# HR and LR image comparison

In [None]:
def train_imgs_visualization():
    """Plot the first HR/LR pair of the first training batch and print their shapes.

    Reads the module-level ``train_loader``; does nothing if it yields no batch.
    """
    first_batch = next(iter(train_loader), None)
    if first_batch is None:
        return

    figure = plt.figure(figsize=(10, 7))

    # first_batch[0] is the HR batch, first_batch[1] the LR batch.
    for slot, label in enumerate(('imagem HR', 'imagem LR')):
        # CHW -> HWC, the layout imshow expects.
        picture = first_batch[slot][0].permute(1, 2, 0)
        figure.add_subplot(1, 2, slot + 1)
        plt.imshow(picture)
        print(label, picture.shape)

def val_imgs_visualization():
    """Plot the first HR/LR pair of the first validation batch side by side.

    Reads the module-level ``val_loader``; does nothing if it yields no batch.
    """
    first_batch = next(iter(val_loader), None)
    if first_batch is None:
        return

    figure = plt.figure(figsize=(10, 7))

    # first_batch[0] is the HR batch, first_batch[1] the LR batch.
    for slot in (0, 1):
        # CHW -> HWC, the layout imshow expects.
        picture = first_batch[slot][0].permute(1, 2, 0)
        figure.add_subplot(1, 2, slot + 1)
        plt.imshow(picture)

# Show one HR/LR pair from the training loader.
train_imgs_visualization()