In [1]:
!pip install -q onnx onnxruntime-gpu basicsr  # Use onnxruntime-gpu for GPU systems; use onnxruntime for CPU-only systems

In [2]:
!pip show basicsr

Name: basicsr
Version: 1.4.2
Summary: Open Source Image and Video Super-Resolution Toolbox
Home-page: https://github.com/xinntao/BasicSR
Author: Xintao Wang
Author-email: xintao.wang@outlook.com
License: Apache License 2.0
Location: /usr/local/python/3.12.1/lib/python3.12/site-packages
Requires: addict, future, lmdb, numpy, opencv-python, Pillow, pyyaml, requests, scikit-image, scipy, tb-nightly, torch, torchvision, tqdm, yapf
Required-by: 


In [4]:
%%writefile dependency-fix.sh
#!/bin/bash
# Patch the installed basicsr package in place: rewrite the import of
# rgb_to_grayscale in basicsr/data/degradations.py so that it comes from
# torchvision.transforms.functional instead of
# torchvision.transforms.functional_tensor.
# (Presumably the installed torchvision no longer ships functional_tensor --
# TODO confirm against the torchvision version in use.)
# NOTE: the target below is a hard-coded ABSOLUTE site-packages path, not a
# relative one. It must match the "Location" reported by `pip show basicsr`;
# adjust it when basicsr is installed somewhere else.
# The substitution is idempotent: once patched, the pattern no longer matches.
sed -i 's/from torchvision.transforms.functional_tensor import rgb_to_grayscale/from torchvision.transforms.functional import rgb_to_grayscale/' /usr/local/python/3.12.1/lib/python3.12/site-packages/basicsr/data/degradations.py

Writing dependency-fix.sh


In [3]:
# Make the patch script executable and apply it to the installed basicsr.
# Requires dependency-fix.sh to exist already, i.e. the %%writefile cell above
# must have been executed before this one on a fresh kernel.
!chmod +x dependency-fix.sh
!./dependency-fix.sh

In [7]:
import yaml
import torch
import torch.onnx
from basicsr.archs.rrdbnet_arch import RRDBNet

def main():
    """Export a trained RRDBNet checkpoint to an ONNX file.

    Reads the generator hyper-parameters from ``config.yml``, builds an
    ``RRDBNet``, loads either the ``params_ema`` or the ``params`` weights from
    the ``.pth`` checkpoint, and writes an ONNX graph whose batch, height and
    width axes are dynamic.

    Raises:
        KeyError: if the checkpoint does not contain the requested weight key
            or the configuration lacks a required entry.
    """
    # Configuration
    input_model_path = '../../RealESRGAN/model/net_g_5000.pth'
    output_onnx_path = '../../RealESRGAN/model/net_g_5000.onnx'
    config_path = 'config.yml'
    opset_version = 11
    use_params_ema = True  # Set to False to use params instead of params_ema
    dummy_size = 64  # Spatial size of the tracing input; H and W stay dynamic

    # Load model configuration
    with open(config_path, 'r') as reader:
        # NOTE(review): FullLoader resolves more tags than a plain config
        # needs; consider yaml.safe_load if the config may come from an
        # untrusted source.
        config = yaml.load(reader, Loader=yaml.FullLoader)
    net_cfg = config['network_g']
    print('network_g config:', net_cfg)

    # Initialize model
    model = RRDBNet(
        num_in_ch=net_cfg['num_in_ch'],
        num_out_ch=net_cfg['num_out_ch'],
        num_feat=net_cfg['num_feat'],
        num_block=net_cfg['num_block'],
        num_grow_ch=net_cfg['num_grow_ch'],
        scale=config['scale']
    )

    # Load model weights. map_location='cpu' lets a checkpoint that was saved
    # from a CUDA device be deserialized on a machine without a GPU; the model
    # is exported from the CPU anyway.
    keyname = 'params_ema' if use_params_ema else 'params'
    checkpoint = torch.load(input_model_path, map_location='cpu')
    if keyname not in checkpoint:
        raise KeyError(
            f"'{keyname}' not found in checkpoint '{input_model_path}'; "
            f"available keys: {list(checkpoint.keys())}")
    model.load_state_dict(checkpoint[keyname])
    model.cpu().eval()

    # Create example input: the channel count follows the configuration
    # instead of being hard-coded, and the spatial size is an explicit
    # constant rather than the unrelated feature-channel count.
    x = torch.rand(1, net_cfg['num_in_ch'], dummy_size, dummy_size)

    # Export to ONNX. The output axes get their own symbolic names because the
    # output is an upscaled image whose height/width differ from the input's.
    with torch.no_grad():
        torch.onnx.export(
            model, x, output_onnx_path,
            opset_version=opset_version,
            export_params=True,
            input_names=['input'],
            output_names=['output'],
            dynamic_axes={
                'input': {0: 'batch_size', 2: 'height', 3: 'width'},
                'output': {0: 'batch_size', 2: 'out_height', 3: 'out_width'}
            }
        )
    print(f"Done! Exported to {output_onnx_path}")

if __name__ == '__main__':
    main()

network_g config: {'num_in_ch': 3, 'num_out_ch': 3, 'num_feat': 64, 'num_block': 23, 'num_grow_ch': 32}
Done! Exported to ../../RealESRGAN/model/net_g_5000.onnx


In [None]:
import math
import time
from typing import Optional, Tuple, List

import cv2
import numpy as np
import onnxruntime as ort


class BaseModel:
    """Thin wrapper around an ONNX Runtime inference session.

    The wrapped model is fed through its first input and its first output is
    returned.
    """
    def __init__(self,
                 model_path: str,
                 intra_op_num_threads: int = -1,
                 providers: Optional[List[str]] = None):
        """ Initializer
        Args:
          model_path (str): path to model
          intra_op_num_threads (int): num threads; values <= 0 keep the
            onnxruntime default. Defaults to -1
          providers (List[str]): onnxruntime providers, defaults to
            ['CPUExecutionProvider'] when None
        """
        # A None sentinel avoids sharing one mutable default list between
        # every instance of the class.
        if providers is None:
            providers = ['CPUExecutionProvider']

        # Session options are only created when a thread count is requested,
        # so the default configuration is left entirely to onnxruntime.
        sess_options = None
        if intra_op_num_threads > 0:
            sess_options = ort.SessionOptions()
            sess_options.intra_op_num_threads = intra_op_num_threads
        self.sess = ort.InferenceSession(model_path, sess_options, providers=list(providers))

        # The graph's I/O names do not change after loading; look them up once
        # instead of on every call.
        self.input_name = self.sess.get_inputs()[0].name
        self.output_name = self.sess.get_outputs()[0].name

    def __call__(self, img: np.ndarray) -> np.ndarray:
        """Run the model on ``img`` and return its first output.

        Args:
          img (np.ndarray): array fed to the model's first input
        Returns:
          np.ndarray: the model's first output
        """
        return self.sess.run([self.output_name], {self.input_name: img})[0]


class RealESRGAN:
    """Upscale OpenCV-style images with a Real-ESRGAN model run through ONNX.

    The pipeline is: normalise to [0, 1] -> pad -> run the model (optionally
    tile by tile) -> crop the padding away -> convert back to an integer image.
    """

    def __init__(self,
                 model_path: str,
                 scale: int = 4,
                 tile: int = 0,
                 tile_pad: int = 10,
                 pre_pad: int = 10,
                 verbose: bool = True,
                 **kwargs):
        """A helper class for upsampling images with RealESRGAN.
        Args:
            model_path (str): The path to the pretrained model.
            scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
            tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop
                input images into tiles, and then process each of them. Finally, they will be merged into one image.
                0 denotes for do not use tile. Default: 0.
            tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
            pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
            verbose (bool):  whether to verbose log. Default: True
            **kwargs: forwarded unchanged to ``BaseModel``.
        """
        self.scale = scale
        self.tile_size = tile
        self.tile_pad = tile_pad
        self.pre_pad = pre_pad
        self.verbose = verbose
        # Height/width are padded up to a multiple of mod_scale for scales 2
        # and 1; other scales need no such padding.
        if self.scale == 2:
            self.mod_scale = 2
        elif self.scale == 1:
            self.mod_scale = 4
        else:
            self.mod_scale = None
        # Written by pre_process and read back by post_process.
        self.mod_pad_h, self.mod_pad_w = 0, 0
        self.model = BaseModel(model_path, **kwargs)

    def pre_process(self, img: np.ndarray) -> np.ndarray:
        """Pre-process, such as pre-pad and mod pad, so that the images can be divisible

        Args:
            img (np.ndarray): image of shape (H, W, C).
        Returns:
            np.ndarray: float32 array of shape (1, C, H', W'), padded on the
            bottom and right only. The mod padding that was applied is
            remembered in ``self.mod_pad_h`` / ``self.mod_pad_w``.
        """
        img = np.transpose(img, (2, 0, 1)).astype('float32')
        img = np.expand_dims(img, 0)

        # pre_pad
        if self.pre_pad != 0:
            img = np.pad(img, [(0, 0), (0, 0), (0, self.pre_pad), (0, self.pre_pad)], 'reflect')
        # mod pad for divisible borders
        if self.mod_scale is not None:
            self.mod_pad_h, self.mod_pad_w = 0, 0
            _, _, h, w = img.shape
            if (h % self.mod_scale != 0):
                self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
            if (w % self.mod_scale != 0):
                self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
            img = np.pad(img, [(0, 0), (0, 0), (0, self.mod_pad_h), (0, self.mod_pad_w)], 'reflect')
        return img

    def predict(self, img: np.ndarray) -> np.ndarray:
        """Run the model on the whole (pre-processed) image at once."""
        # model inference
        return self.model(img)

    def tile_predict(self, img: np.ndarray) -> np.ndarray:
        """It will first crop input images to tiles, and then process each tile.
        Finally, all the processed tiles are merged into one images.
        Modified from: https://github.com/ata4/esrgan-launcher

        Args:
            img (np.ndarray): pre-processed array of shape (N, C, H, W).
        Returns:
            np.ndarray: array of shape (N, C, H * scale, W * scale) with the
            same dtype as ``img``.
        Raises:
            RuntimeError: if the model fails on a tile.
        """
        batch, channel, height, width = img.shape
        output_height = height * self.scale
        output_width = width * self.scale
        output_shape = (batch, channel, output_height, output_width)

        # start with black image; use the input dtype so the result matches
        # the non-tiled path (a float64 canvas is rejected by cv2.cvtColor)
        output = np.zeros(output_shape, dtype=img.dtype)
        tiles_x = math.ceil(width / self.tile_size)
        tiles_y = math.ceil(height / self.tile_size)

        # loop over all tiles
        for y in range(tiles_y):
            for x in range(tiles_x):
                # extract tile from input image
                ofs_x = x * self.tile_size
                ofs_y = y * self.tile_size
                # input tile area on total image
                input_start_x = ofs_x
                input_end_x = min(ofs_x + self.tile_size, width)
                input_start_y = ofs_y
                input_end_y = min(ofs_y + self.tile_size, height)

                # input tile area on total image with padding
                input_start_x_pad = max(input_start_x - self.tile_pad, 0)
                input_end_x_pad = min(input_end_x + self.tile_pad, width)
                input_start_y_pad = max(input_start_y - self.tile_pad, 0)
                input_end_y_pad = min(input_end_y + self.tile_pad, height)

                # input tile dimensions
                input_tile_width = input_end_x - input_start_x
                input_tile_height = input_end_y - input_start_y
                tile_idx = y * tiles_x + x + 1
                input_tile = img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]

                # upscale tile; a failed tile must abort the run, otherwise
                # the merge below would use an undefined or stale result
                try:
                    output_tile = self.model(input_tile)
                except RuntimeError as error:
                    raise RuntimeError(
                        f'Inference failed on tile {tile_idx}/{tiles_x * tiles_y}: {error}') from error
                if self.verbose:
                    print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')

                # output tile area on total image
                output_start_x = input_start_x * self.scale
                output_end_x = input_end_x * self.scale
                output_start_y = input_start_y * self.scale
                output_end_y = input_end_y * self.scale

                # output tile area without padding
                output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
                output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
                output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
                output_end_y_tile = output_start_y_tile + input_tile_height * self.scale

                # put tile into output image
                output[:, :, output_start_y:output_end_y,
                       output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
                                                                  output_start_x_tile:output_end_x_tile]
        return output

    def post_process(self, output: np.ndarray) -> np.ndarray:
        """Crop away the (upscaled) padding that ``pre_process`` added.

        Args:
            output (np.ndarray): model output of shape (N, C, H, W).
        Returns:
            np.ndarray: view of ``output`` without mod padding and pre-padding.
        """
        # remove extra pad
        if self.mod_scale is not None:
            _, _, h, w = output.shape
            output = output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
        # remove prepad; the shape is re-read because the crop above may have
        # already shrunk the array
        if self.pre_pad != 0:
            _, _, h, w = output.shape
            output = output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
        return output

    def _upscale(self, img: np.ndarray) -> np.ndarray:
        """Run one RGB (H, W, 3) image through the model; return BGR (H', W', 3) clipped to [0, 1]."""
        tensor = self.pre_process(img)
        if self.tile_size > 0:
            logits = self.tile_predict(tensor)
        else:
            logits = self.predict(tensor)
        result = np.clip(np.squeeze(self.post_process(logits), 0), 0, 1)
        # channel order RGB -> BGR, layout CHW -> HWC
        return np.transpose(result[[2, 1, 0], :, :], (1, 2, 0))

    def enhance(self,
                img: np.ndarray,
                outscale: Optional[int] = None,
                alpha_upsampler: str = 'realesrgan') -> Tuple[np.ndarray, str]:
        """Upscale an image given in OpenCV channel order.

        Args:
            img (np.ndarray): gray (H, W), BGR (H, W, 3) or BGRA (H, W, 4)
                image. Values above 256 make the image count as 16-bit.
            outscale: final scale relative to the input size; when it differs
                from the model scale the result is resized with Lanczos.
                None keeps the model's native scale.
            alpha_upsampler (str): 'realesrgan' upscales the alpha channel
                with the model; any other value uses ``cv2.resize``.
        Returns:
            Tuple[np.ndarray, str]: the enhanced image (uint16 for 16-bit
            input, else uint8; same channel layout as the input) and the
            detected mode: 'L', 'RGB' or 'RGBA'.
        """
        h_input, w_input = img.shape[0:2]
        # img: numpy
        img = img.astype(np.float32)
        max_range = 65535 if np.max(img) > 256 else 255
        img = img / max_range
        if len(img.shape) == 2:  # gray image
            img_mode = 'L'
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif img.shape[2] == 4:  # RGBA image with alpha channel
            img_mode = 'RGBA'
            alpha = img[:, :, 3]
            img = img[:, :, 0:3]
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            if alpha_upsampler == 'realesrgan':
                alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
        else:
            img_mode = 'RGB'
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # process image (without the alpha channel)
        output_img = self._upscale(img)
        if img_mode == 'L':
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)

        # process the alpha channel if necessary
        if img_mode == 'RGBA':
            if alpha_upsampler == 'realesrgan':
                # _upscale clips to [0, 1]; without that, overshoot in the
                # alpha output would wrap around in the integer cast below
                output_alpha = self._upscale(alpha)
                output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
            else:  # use the cv2 resize for alpha channel
                h, w = alpha.shape[0:2]
                output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)

            # merge the alpha channel; output_img is already BGR here, so only
            # a fourth channel is appended (no further red/blue swap)
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
            output_img[:, :, 3] = output_alpha

        if max_range == 65535:  # 16-bit image
            output = (output_img * 65535.0).round().astype(np.uint16)
        else:
            output = (output_img * 255.0).round().astype(np.uint8)

        if outscale is not None and outscale != float(self.scale):
            output = cv2.resize(
                output, (
                    int(w_input * outscale),
                    int(h_input * outscale),
                ), interpolation=cv2.INTER_LANCZOS4)

        return output, img_mode


if __name__ == '__main__':
    # Upscale one image with the exported ONNX model and write the result.

    # Configuration
    model_path = '../../RealESRGAN/model/net_g_5000.onnx'
    input_path = 'input.png'
    output_path = 'output.png'
    output_scale = 4  # used both as the model scale and as the final outscale
    tile_size = 400  # 0 disables tiling
    num_threads = -1  # <= 0 keeps the onnxruntime default
    providers = ['CPUExecutionProvider']

    print(f"[INFO] Loading model from: {model_path}")
    model = RealESRGAN(model_path, scale=output_scale, tile=tile_size, intra_op_num_threads=num_threads, providers=providers)

    print(f"[INFO] Reading input image: {input_path}")
    # IMREAD_UNCHANGED keeps the alpha channel and 16-bit depth, if present
    img = cv2.imread(input_path, cv2.IMREAD_UNCHANGED)
    if img is None:
        raise FileNotFoundError(f"Could not read the input image from '{input_path}'")

    print("[INFO] Starting enhancement...")
    # perf_counter is monotonic, so the measurement is immune to clock changes
    start_time = time.perf_counter()
    output, _ = model.enhance(img, outscale=output_scale)
    elapsed_time = time.perf_counter() - start_time
    print(f"[INFO] Enhancement completed in {elapsed_time:.2f} seconds")

    # cv2.imwrite signals failure through its return value, not an exception
    if not cv2.imwrite(output_path, output):
        raise IOError(f"Could not write the enhanced image to '{output_path}'")
    print(f"[SUCCESS] Enhanced image saved to: {output_path}")

[INFO] Loading model from: ../../RealESRGAN/model/net_g_5000.onnx
[INFO] Reading input image: input.png
[INFO] Starting enhancement...
	Tile 1/6
	Tile 2/6
	Tile 3/6
	Tile 4/6
	Tile 5/6
	Tile 6/6
[INFO] Enhancement completed in 275.85 seconds
[SUCCESS] Enhanced image saved to: output.png
