# Training a SuperResolution Model with Pytorch

## Introduction

The main goal of this notebook is to illustrate an end-to-end flow for Machine Learning applied to Computer Vision. More specifically, this example considers the case of image transformation/super resolution, where the goal is to build a model capable to get as input a low-resolution image and return as output another image, with higher resolution. To achieve this go, we go through the following steps:
 * Loading and navigating through a dataset of images
 * Defining the CNN architecture in torch, as well as the training steps
 * Running training experiments, using both CPU and GPU, and monitoring
   * Evolution of metrics such as loss and accuracy across each run, using TensorBoard
   * Comparison between the results of different runs, using MLFlow
 * Saving and registering the models resulting of the different runs
 * Deploying a (local) image transformation service, accessible by REST api, using MLFlow
In order to achieve this goal, we will use Torch and Torchvision for creating and training the model, PIL to help in the image manipulation, mlflow and tensorboard for monitoring and iPyWidgets to create visualizations in the Jupyter Notebook

In [1]:
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from  torchvision import transforms, utils

from PIL import Image
import os
from tqdm.notebook import tqdm
import numpy as np
import pandas as pd

import mlflow
import ipywidgets as widgets
from IPython.display import display
import io


2023-08-31 19:37:57.575925: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


## Dataset visualization and image pre-processing

The following cells contain the definition of a fixed Width and Height to simplify the task, as well as a function to extract a region with the given dimensions from the source images. This function consists in rotating images presented in portrait mode (height > width) and cropping a region in the center of the image, if the dimensions exceed the one given by the parameters. Also, a small widget is proposed to visualize the pairs of LR/HR images from the dataset, already applying the proposed function to process the image

In [2]:
HR_WIDTH = 2040
HR_HEIGHT = 1152
LR_WIDTH = 510
LR_HEIGHT = 288

def center_crop(img, new_width, new_height):
    width, height = img.size
    if width < height:
        img = img.rotate(90)     
    width, height = img.size
    if new_width > width:
        img = img.resize((new_width, height))
    if new_height > height:
        img = img.resize((width, new_height))
    width, height = img.size        
    new_left = (width - new_width) / 2
    new_top = (height - new_height) / 2
    return img.crop((new_left, new_top, new_left + new_width, new_top + new_height))

In [3]:
class ImageVisualizer(object):
    # Class that shows three images, side by side horizontally, or vertically
    # The images are specified by a function that receives an index and returns a tuple with 3 PIL images
    def __init__(self, width, height, is_horizontal = True, image_getter = None, number_of_images = 0):
        self.img_index = 0
        self.width = width
        self.height = height
        self.is_horizontal = is_horizontal
        self.images = []
        if image_getter == None:
            self.image_getter = self.default_image_getter
        else:
            self.image_getter = image_getter
        self.number_of_images = number_of_images
        self.output = widgets.Output()
        self.image_widgets = None
        self.image_box = None
        self.main_box = None
        self.create_button_box()
        self.load_display()
        
    
    def default_image_getter(self, index):
        return(Image.new("RGB", (self.width, self.height)), 
               Image.new("RGB", (self.width, self.height)),
               Image.new("RGB", (self.width, self.height)))
    
    def previous_button_click(self, button):
        if self.img_index > 0:
            self.img_index = self.img_index - 1
            self.current_button.description = str(self.img_index)
            self.load_display()    

    def next_button_click(self, button):
        if self.img_index < self.number_of_images - 1:
            self.img_index = self.img_index + 1
            self.current_button.description = str(self.img_index)
            self.load_display()
            
    def create_button_box(self):
        self.previous_button = widgets.Button(
            description='Previous',
            disabled=False,
            button_style='', # 'success', 'info', 'warning', 'danger' or ''
            tooltip='Previous Image',
        )
        
        self.current_button = widgets.Button(
            description='0',
            disabled=True,
            button_style='', # 'success', 'info', 'warning', 'danger' or ''
            tooltip='Current image',
        )

        self.next_button = widgets.Button(
            description='Next',
            disabled=False,
            button_style='', # 'success', 'info', 'warning', 'danger' or ''
            tooltip='Next Image',
        )
        self.previous_button.on_click(self.previous_button_click)
        self.next_button.on_click(self.next_button_click)
        self.button_box = widgets.HBox([self.previous_button, self.current_button, self.next_button])
    
    def load_display(self):
        self.images = self.image_getter(self.img_index)
        if self.image_widgets is None:
            self.image_widgets = [None] * len(self.images)
        
        for index in range(0, len(self.images)):
            byte_array = io.BytesIO()
            self.images[index].save(byte_array, format='PNG')
            if self.image_widgets[index] is None:
                self.image_widgets[index] = widgets.Image(
                    value = byte_array.getvalue(),
                    format = 'PNG',
                    width = self.width,
                    height = self.height,
                )
            else:
                self.image_widgets[index].value = byte_array.getvalue()

        if self.image_box is None:
            if self.is_horizontal:
                self.image_box = widgets.HBox(self.image_widgets)
            else:
                self.image_box = widgets.VBox(self.image_widgets)
            self.main_box = widgets.VBox([self.button_box, self.image_box])
            display(self.main_box, self.output)


In [4]:
def load_images(index):
    lr_folder = "sample_data/DIV2K_train_LR_mild/"
    hr_folder = "sample_data/DIV2K_train_HR/"
    lr_suffix = "x4m"
    hr_list = os.listdir(hr_folder)
    hr_name, hr_ext = os.path.splitext(hr_list[index])
    lr_name =  f"{hr_name}{lr_suffix}{hr_ext}"
    hr_image = Image.open(os.path.join(hr_folder, hr_list[index]))
    lr_image = Image.open(os.path.join(lr_folder, lr_name))    
    return (
        center_crop(lr_image, LR_WIDTH, LR_HEIGHT),
        center_crop(hr_image, HR_WIDTH, HR_HEIGHT)
    )

number_of_images = len(os.listdir("sample_data/DIV2K_train_HR/"))
ImageVisualizer(HR_WIDTH / 2, HR_HEIGHT / 2, is_horizontal = False, image_getter = load_images, number_of_images = number_of_images)

VBox(children=(HBox(children=(Button(description='Previous', style=ButtonStyle(), tooltip='Previous Image'), B…

Output()

<__main__.ImageVisualizer at 0x7f8e442dfb10>

# Defining the CNN model and the Training process



In [5]:
class FSRCNN(nn.Module):
    def __init__(self, scale_factor):

        # torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')
        # paper(filter_size, number_of_filters, number_of_channels)
        # Correspondência com os parâmetros do PyTorch:
        # torch.nn.Conv2d(in_channels, out_channels, kernel_size)
        #        number_of_chennels, number_of_filters, filter_size

        super(FSRCNN, self).__init__()
        self.scale_factor = scale_factor

        self.feature_extraction = nn.Sequential(
            nn.Conv2d(3, 56, kernel_size=5, padding=2),
            nn.PReLU()
        )
        self.shrinking = nn.Sequential(
            nn.Conv2d(56, 12, kernel_size=1),
            nn.PReLU()
        )
        self.non_linear_mapping = nn.Sequential(
            nn.Conv2d(12, 12, kernel_size=3, padding=1),
            nn.PReLU(),
            nn.Conv2d(12, 12, kernel_size=3, padding=1),
            nn.PReLU(),
            nn.Conv2d(12, 12, kernel_size=3, padding=1),
            nn.PReLU(),
            nn.Conv2d(12, 12, kernel_size=3, padding=1),
            nn.PReLU(),
            nn.Conv2d(12, 12, kernel_size=3, padding=1),
            nn.PReLU()
        )
        self.expanding = nn.Sequential(
            nn.Conv2d(12, 56, kernel_size=1),
            nn.PReLU()
        )
        self.deconvolution = nn.ConvTranspose2d(56, 3, kernel_size=9, stride=scale_factor, padding=4, output_padding=scale_factor-1)

    def forward(self, x):
        x = self.feature_extraction(x)
        x = self.shrinking(x)
        x = self.non_linear_mapping(x)
        x = self.expanding(x)
        x = self.deconvolution(x)
        return x


In [6]:
class DIV2KDataset(Dataset):
    def __init__(self, hr_dir, lr_dir, lr_suffix):
        super(DIV2KDataset, self).__init__()
        self.hr_dir = hr_dir
        self.hr_list = os.listdir(self.hr_dir)
        self.lr_dir = lr_dir
        self.lr_suffix = lr_suffix
        
    
    def __getitem__(self, index):
        img_hr = Image.open(os.path.join(self.hr_dir, self.hr_list[index]))
        hr_name, hr_ext = os.path.splitext(self.hr_list[index])
        img_lr_name =  f"{hr_name}{self.lr_suffix}{hr_ext}"
        img_lr = Image.open(os.path.join(self.lr_dir, img_lr_name))
        transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        img_hr = center_crop(img_hr, 2040, 1152)
        img_lr = center_crop(img_lr, 510, 288)
        img_hr = transform(img_hr)
        img_lr = transform(img_lr)
        return img_hr, img_lr

    def __len__(self):
        return len(self.hr_list)

In [7]:
class TrainFSRCNN(object):
    def __init__(self, model, criterion, optimizer, train_loader, val_loader, n_epochs, device, run_name, epochs_until_display=20):
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.n_epochs = n_epochs
        self.device = device
        self.tb_writer = torch.utils.tensorboard.SummaryWriter()
        self.run_name = run_name
        self.sample = ()
        self.epochs_until_display = epochs_until_display

    def train(self):
        # Adicionar um temporizador para calcular o tempo de treinamento
        #start_time = time.time()
        self.model = self.model.to(self.device)
        self.model.train()
        writer = SummaryWriter(os.path.join(os.environ["TENSORBOARD_LOGDIR"], self.run_name))
        tr_loss, val_loss, val_psnr = (0, 0, 0)
        outer_loop = tqdm(range(self.n_epochs))
        for epoch in outer_loop:
            should_add_image = (epoch % self.epochs_until_display) == 0 
            outer_loop.set_description(f"Epoch [{epoch}/{self.n_epochs}]")
            running_loss = 0.0
            train_loop = tqdm(self.train_loader)
            for i, (hr, lr) in enumerate(train_loop):
                self.sample = (hr, lr)
                hr = hr.to(self.device)
                lr = lr.to(self.device)
                self.optimizer.zero_grad()
                outputs = self.model(lr)
                loss = self.criterion(outputs, hr)
                loss.backward()
                self.optimizer.step()
                running_loss += loss.item()
                train_loop.set_description(f"Training: Loss = {running_loss / (i + 1)}")
            val_loss, val_psnr, val_grid = self.validate(should_add_image)
            tr_loss = running_loss / len(self.train_loader)
            writer.add_scalar("Loss (Train)", tr_loss, epoch)
            writer.add_scalar("Loss (Validation)", val_loss, epoch)
            writer.add_scalar("PNSR (Validation)", val_psnr, epoch)
            if should_add_image:
                writer.add_image("Sample of validation outputs", val_grid, epoch)
            train_loop.close()
        if device.type == "cuda":
            torch.cuda.empty_cache()
        writer.flush()
        writer.close()

        with mlflow.start_run(run_name=self.run_name) as run:
            print(run.info.run_id)
            mlflow.log_metric("Training RMSE", tr_loss)
            mlflow.log_metric("Validation RMSE", val_loss)
            mlflow.log_metric("Validation PSNR", val_psnr)
            #signature = mlflow.models.signature.infer_signature(self.sample[1], self.sample[0])
            SuperResolutionModel.log_sr_model(model, "", "fscnn")
            #mlflow.pytorch.log_model(self.model, "fscnn")
            mlflow.register_model(model_uri = f"runs:/{run.info.run_id}/fscnn", name="fscnn")


        # Adicionar um temporizador para calcular o tempo de treinamento
        #end_time = time.time()
        #total_time = end_time - start_time
        #print('Tempo total de treinamento: {:.2f} segundos'.format(total_time))

    def validate(self, batch_as_grid = False):
        self.model.eval()
        grid = None
        with torch.no_grad():
            val_loss = 0.0
            val_psnr = 0.0
            val_loop = tqdm(self.val_loader)
            for i, (hr, lr) in enumerate(val_loop):
                hr = hr.to(self.device)
                lr = lr.to(self.device)
                outputs = self.model(lr)
                if batch_as_grid:
                    grid = utils.make_grid(outputs)
                loss = self.criterion(outputs, hr)
                val_loss += loss.item()

                # Calcular PSNR
                mse = torch.mean((hr - outputs) ** 2)
                psnr = 20 * torch.log10(1.0 / torch.sqrt(mse))
                val_psnr += psnr.item()
                val_loop.set_description(f"Validation: Loss = {val_loss / (i + 1)}")
            val_loop.close()
            return val_loss / len(self.val_loader), val_psnr / len(self.val_loader), grid
        
    def device_validate(self):
        self.model = self.model.to(self.device)
        return self.validate()


In [8]:
class SuperResolutionModel(mlflow.pyfunc.PythonModel):
    def _preprocess(self, img):
        lr_width = 510
        lr_height = 288
        width, height = img.size
        rotated = False
        if width < height:
            img = img.rotate(90)     
            rotated = True
        width, height = img.size
        if lr_width > width:
            img = img.resize((lr_width, height))
        if lr_height > height:
            img = img.resize((width, lr_height))
        width, height = img.size        
        new_left = (width - lr_width) / 2
        new_top = (height - lr_height) / 2
        input_img = img.crop((new_left, new_top, new_left + lr_width, new_top + lr_height))
        transform = transforms.ToTensor()
        return transform(input_img), rotated

    def _postprocess(self, tensor, rotated):
        transform = transforms.ToPILImage()
        output_image = transform(tensor)
        if rotated:
            output_image = output_image.rotate(270)
        return output_image
        
    def load_context(self, context):
        import torch
        self.model = FSRCNN(scale_factor=4)
        self.model.load_state_dict(torch.load(context.artifacts["model"]))
        
    def predict(self, context, model_input, params=None):
        input_list = model_input.iloc[0, 0]
        print(len(input_list))
        input_np_array = np.array(input_list)
        print(input_np_array.shape)
        input_image = Image.fromarray(np.uint8(input_np_array))
        input_tensor, rotated = self._preprocess(input_image)
        output_tensor = self.model(input_tensor)
        output_image = self._postprocess(output_tensor, rotated)
        return np.array(output_image).tolist()

    @classmethod
    def log_sr_model(cls, torch_model, base_path, model_name):
        torch_model_path = os.path.join(base_path, model_name)
        torch.save(model.state_dict(), model_name)
        requirements = [
            "mlflow==2.5.0",
            "astunparse==1.6.3",
            "cloudpickle==2.2.1",
            "numpy==1.24.3",
            "opt-einsum==3.3.0",
            "torch==2.0.0",
            "tqdm==4.65.0",
            "torchvision"
        ]
        mlflow.pyfunc.log_model(
            model_name,
            python_model=cls(),
            artifacts={"model": model_name},
            pip_requirements=requirements
        )

        

In [10]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

mlflow.set_experiment("Super Resolution 4x")

batch_size = 8
epochs = 200

train_dataset = DIV2KDataset("sample_data/DIV2K_train_HR", "sample_data/DIV2K_train_LR_mild", "x4m")
val_dataset = DIV2KDataset("sample_data/DIV2K_valid_HR", "sample_data/DIV2K_valid_LR_mild", "x4m")

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

run_name = "Test2"

model = FSRCNN(scale_factor=4)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.00001)
train_fsrcnn = TrainFSRCNN(model, criterion, optimizer, train_loader, val_loader, epochs, device, run_name)

train_fsrcnn.train()


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:08<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

abba8ccc5d9249288b6b472ec1a92f8a


Registered model 'fscnn' already exists. Creating a new version of this model...
2023/08/20 22:09:07 INFO mlflow.tracking._model_registry.client: Waiting up to 300 seconds for model version to finish creation. Model name: fscnn, version 38
Created version '38' of model 'fscnn'.


In [30]:
def process_image_to_mlflow_model(image):
     return pd.DataFrame([{"image": np.array(image).tolist()}])

def return_image_from_mlflow_model(list):
    return Image.fromarray(np.uint8(list))

def load_images(index):
    in_transform = transforms.ToTensor()
    #out_transform = transforms.ToPILImage()
        
    lr_folder = "sample_data/DIV2K_train_LR_mild/"
    hr_folder = "sample_data/DIV2K_train_HR/"
    lr_suffix = "x4m"
    hr_list = os.listdir(hr_folder)
    hr_name, hr_ext = os.path.splitext(hr_list[index])
    lr_name =  f"{hr_name}{lr_suffix}{hr_ext}"
    hr_image = Image.open(os.path.join(hr_folder, hr_list[index]))
    lr_image = Image.open(os.path.join(lr_folder, lr_name))    
    lr_tensor = in_transform(center_crop(lr_image, LR_WIDTH, LR_HEIGHT))   #.to(device)
    output = loaded_model.predict(process_image_to_mlflow_model(lr_image))
    out_image = return_image_from_mlflow_model(output)
    
    return (
        center_crop(lr_image, LR_WIDTH, LR_HEIGHT),
        center_crop(hr_image, HR_WIDTH, HR_HEIGHT),
        out_image
    )

ImageVisualizer(HR_WIDTH / 2, HR_HEIGHT / 2, is_horizontal = False, image_getter = load_images, number_of_images = 10)

Got here
This is a test
324
(324, 510, 3)
Loaded Image
....Started preprocessing
....Loaded Variables
....Processed Rotation
....Processed Resize
....Processed Cropping
....Loaded Transform
Preprocessed
Predicted
PostProcessed


VBox(children=(HBox(children=(Button(description='Previous', style=ButtonStyle(), tooltip='Previous Image'), B…

Output()

<__main__.ImageVisualizer at 0x7fb47e1cae10>

In [5]:
!mlflow models serve -m mlflow/671370789459351034/abba8ccc5d9249288b6b472ec1a92f8a/artifacts/fscnn --env-manager conda --port 5001

  value = self.callback(ctx, self, value)
2023/08/17 19:44:18 INFO mlflow.models.flavor_backend_registry: Selected backend for flavor 'python_function'
2023/08/17 19:44:19 INFO mlflow.utils.conda: Conda environment mlflow-629041bb6f3b7f2a5ba86f6ea07eb4ee4363ca4d already exists.
2023/08/17 19:44:19 INFO mlflow.utils.environment: === Running command '['bash', '-c', 'source activate mlflow-629041bb6f3b7f2a5ba86f6ea07eb4ee4363ca4d 1>&2 && python -c ""']'
2023/08/17 19:44:20 INFO mlflow.utils.environment: === Running command '['bash', '-c', 'source activate mlflow-629041bb6f3b7f2a5ba86f6ea07eb4ee4363ca4d 1>&2 && exec gunicorn --timeout=60 -b 127.0.0.1:5001 -w 1 ${GUNICORN_CMD_ARGS} -- mlflow.pyfunc.scoring_server.wsgi:app']'
[2023-08-17 19:44:20 +0000] [255] [INFO] Starting gunicorn 20.1.0
[2023-08-17 19:44:20 +0000] [255] [INFO] Listening at: http://127.0.0.1:5001 (255)
[2023-08-17 19:44:20 +0000] [255] [INFO] Using worker: sync
[2023-08-17 19:44:20 +0000] [260] [INFO] Booting worker with 

In [None]:
mlflow deployments create -t sagemaker -m mlflow/671370789459351034/abba8ccc5d9249288b6b472ec1a92f8a/artifacts/fscnn --name fscnn_test