In [1]:
!pip uninstall -y basicsr
!pip uninstall -y torchvision
!pip install torchvision==0.15.2
!pip install basicsr==1.4.2
!pip install onnx
!pip install onnxsim
!pip install onnxruntime
!pip install realesrgan

[0mFound existing installation: torchvision 0.15.2
Uninstalling torchvision-0.15.2:
  Successfully uninstalled torchvision-0.15.2
Collecting torchvision==0.15.2
  Using cached torchvision-0.15.2-cp310-cp310-manylinux1_x86_64.whl.metadata (11 kB)
Using cached torchvision-0.15.2-cp310-cp310-manylinux1_x86_64.whl (6.0 MB)
Installing collected packages: torchvision
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
gfpgan 1.3.8 requires basicsr>=1.4.2, which is not installed.
realesrgan 0.3.0 requires basicsr>=1.4.2, which is not installed.[0m[31m
[0mSuccessfully installed torchvision-0.15.2
Collecting basicsr==1.4.2
  Using cached basicsr-1.4.2-py3-none-any.whl
Installing collected packages: basicsr
Successfully installed basicsr-1.4.2


In [41]:
import os
import json
import torch
import torch.onnx
import numpy as np
import onnx
import onnxruntime as ort
from typing import Tuple, Optional, Dict, List
from pathlib import Path
from tqdm import tqdm
from basicsr.archs.srvgg_arch import SRVGGNetCompact
from onnxsim import simplify

class OptimizedModelConverter:
    def __init__(
        self,
        model_path: str = 'weights/realesr-animevideov3.pth',
        device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    ):
        self.model_path = model_path
        self.device = torch.device(device)

        # Cấu hình cố định cho tile size để tránh mismatch
        self.DEFAULT_OPSET = 12
        self.DEFAULT_BATCH_SIZE = 1
        self.TILE_SIZE = 128  # Phải giữ nhất quán
        self.TILE_OVERLAP = 8
        self.SCALE_FACTOR = 4
        self.MAX_MEMORY_GB = 4.0

        if device == 'cuda':
            self.setup_cuda()

    def setup_cuda(self):
        """Tối ưu CUDA cho inference"""
        torch.cuda.empty_cache()
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.enabled = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

        # Thêm cấu hình cho mixed precision
        torch.backends.cudnn.deterministic = False
        if torch.cuda.get_device_capability()[0] >= 7:
            torch.backends.cuda.enable_flash_sdp(True)

    def calculate_optimal_tile_size(self, height: int, width: int) -> Tuple[int, int]:
        """Tính toán kích thước tile tối ưu dựa trên kích thước ảnh và bộ nhớ"""
        target_pixels = (self.MAX_MEMORY_GB * 1024 * 1024 * 1024) / (4 * 3)  # 4 bytes per float, 3 channels
        scale = self.SCALE_FACTOR

        h_tiles = max(1, int(np.ceil(height / self.TILE_SIZE)))
        w_tiles = max(1, int(np.ceil(width / self.TILE_SIZE)))

        optimal_tile_h = min(self.TILE_SIZE, height)
        optimal_tile_w = min(self.TILE_SIZE, width)

        # Đảm bảo tile size chia hết cho 8 (tối ưu cho GPU)
        optimal_tile_h = (optimal_tile_h // 8) * 8
        optimal_tile_w = (optimal_tile_w // 8) * 8

        return optimal_tile_h, optimal_tile_w

    def preprocess_tensor(self, x: torch.Tensor, tile_size: Optional[Tuple[int, int]] = None) -> torch.Tensor:
        """Tiền xử lý tối ưu cho inference"""
        # Chỉ cần normalize, không cần chuyển BGR
        x = x.float() / 255.0

        if tile_size is None:
            tile_size = (self.TILE_SIZE, self.TILE_SIZE)

        h, w = x.shape[2:]
        pad_h = (tile_size[0] - h % tile_size[0]) % tile_size[0]
        pad_w = (tile_size[1] - w % tile_size[1]) % tile_size[1]

        if pad_h != 0 or pad_w != 0:
            x = torch.nn.functional.pad(
                x,
                (0, pad_w, 0, pad_h),
                mode='reflect'
            )

        return x

    def process_tile(
        self,
        tile: torch.Tensor,
        model: torch.nn.Module,
        overlap: int = 0
    ) -> torch.Tensor:
        """Xử lý một tile với xử lý overlap"""
        with torch.no_grad():
            # Thêm padding cho overlap
            if overlap > 0:
                tile = torch.nn.functional.pad(tile, (overlap,)*4, mode='reflect')

            # Inference
            output = model(tile)

            # Cắt bỏ overlap region
            if overlap > 0:
                overlap_upscaled = overlap * self.SCALE_FACTOR
                output = output[
                    :,
                    :,
                    overlap_upscaled:-overlap_upscaled,
                    overlap_upscaled:-overlap_upscaled
                ]

            return output

    def load_model(self) -> torch.nn.Module:
        """Load model với tối ưu cho inference"""
        print("Loading PyTorch model...")

        model = SRVGGNetCompact(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_conv=16,
            upscale=4,
            act_type='prelu'
        )

        state_dict = torch.load(self.model_path, map_location=self.device)
        if 'params_ema' in state_dict:
            state_dict = state_dict['params_ema']
        elif 'params' in state_dict:
            state_dict = state_dict['params']

        model.load_state_dict(state_dict)
        model.eval()

        # Chuyển sang mixed precision nếu có thể
        if self.device.type == 'cuda':
            model = model.half()  # FP16 for faster inference

        return model

    def convert_to_onnx(
        self,
        model: Optional[torch.nn.Module] = None,
        output_path: str = 'realesrgan_web.onnx',
    ) -> str:
        print("Starting ONNX conversion...")

        try:
            if model is None:
                model = self.load_model()

            class WrapperModel(torch.nn.Module):
                def __init__(self, base_model, tile_size=128, scale_factor=4):
                    super().__init__()
                    self.base_model = base_model
                    self.tile_size = tile_size
                    self.scale_factor = scale_factor

                    # Register constants as buffers
                    self.register_buffer('scale_tensor', torch.tensor(scale_factor))
                    self.register_buffer('tile_tensor', torch.tensor(tile_size))
                    self.register_buffer('zero_tensor', torch.tensor(0))
                    self.register_buffer('one_tensor', torch.tensor(1))

                def _ensure_valid_size(self, x: torch.Tensor) -> torch.Tensor:
                    """Ensure input has valid dimensions for processing"""
                    _, _, h, w = x.shape

                    # Calculate padding sizes
                    mod_h = h % self.scale_tensor
                    mod_w = w % self.scale_tensor

                    # Calculate padding using tensor operations
                    pad_bottom = (self.scale_tensor - mod_h) % self.scale_tensor
                    pad_right = (self.scale_tensor - mod_w) % self.scale_tensor

                    # Apply padding only if needed
                    if torch.any(mod_h > 0) or torch.any(mod_w > 0):
                        x = torch.nn.functional.pad(
                            x,
                            (0, pad_right.item(), 0, pad_bottom.item()),
                            mode='reflect'
                        )

                    return x

                def _adjust_to_tile_size(self, x: torch.Tensor) -> torch.Tensor:
                    """Adjust input to meet minimum tile size requirements"""
                    _, _, h, w = x.shape

                    # Convert dimensions to tensors
                    h_tensor = torch.tensor(h, device=x.device, dtype=torch.float32)
                    w_tensor = torch.tensor(w, device=x.device, dtype=torch.float32)
                    tile_size = torch.tensor(self.tile_size, device=x.device, dtype=torch.float32)

                    # Calculate scale factors
                    h_scale = torch.maximum(self.one_tensor, torch.ceil(tile_size / h_tensor))
                    w_scale = torch.maximum(self.one_tensor, torch.ceil(tile_size / w_tensor))
                    scale = torch.maximum(h_scale, w_scale)

                    # Only resize if necessary
                    if h < self.tile_size or w < self.tile_size:
                        new_h = int(h * scale.item())
                        new_w = int(w * scale.item())
                        x = torch.nn.functional.interpolate(
                            x,
                            size=(new_h, new_w),
                            mode='bicubic',
                            align_corners=False
                        )

                    return x

                def forward(self, x: torch.Tensor) -> torch.Tensor:
                    # Store original dimensions
                    _, _, orig_h, orig_w = x.shape

                    # Process input
                    x = x.to(dtype=torch.float32)
                    x = self._ensure_valid_size(x)
                    x = self._adjust_to_tile_size(x)

                    # Run inference
                    with torch.no_grad():
                        out = self.base_model(x)

                    # Calculate target output size
                    target_h = orig_h * self.scale_factor
                    target_w = orig_w * self.scale_factor

                    # Crop to expected size if needed
                    out = out[:, :, :target_h, :target_w]

                    # Normalize output
                    return torch.clamp(out, 0, 1)

                @staticmethod
                def get_output_shape(input_shape: Tuple[int, ...], scale_factor: int = 4) -> Tuple[int, ...]:
                    return (
                        input_shape[0],
                        input_shape[1],
                        input_shape[2] * scale_factor,
                        input_shape[3] * scale_factor
                    )

            wrapped_model = WrapperModel(
                model,
                tile_size=self.TILE_SIZE,
                scale_factor=self.SCALE_FACTOR
            )
            wrapped_model.eval()

            # Tạo dummy input với kích thước dynamic
            x = torch.randn(
                1, 3, self.TILE_SIZE, self.TILE_SIZE,
                device=self.device
            )

            # Export với dynamic axes
            torch.onnx.export(
                wrapped_model,
                x,
                output_path,
                export_params=True,
                opset_version=self.DEFAULT_OPSET,
                do_constant_folding=True,
                input_names=['input'],
                output_names=['output'],
                dynamic_axes={
                    'input': {
                        0: 'batch_size',
                        2: 'height',
                        3: 'width'
                    },
                    'output': {
                        0: 'batch_size',
                        2: 'height_out',
                        3: 'width_out'
                    }
                }
            )

            # Thêm metadata
            model_onnx = onnx.load(output_path)
            model_onnx.graph.doc_string = json.dumps({
                'tile_size': self.TILE_SIZE,
                'tile_overlap': self.TILE_OVERLAP,
                'scale': self.SCALE_FACTOR,
                'input_format': 'RGB',  # Đã được cập nhật
                'normalize_range': [0, 1],
                'preprocessing': 'normalize to [0,1]',  # Đã được cập nhật
                'supported_dimensions': {
                    'min_size': 32,
                    'max_size': 2048,
                    'scale_factor': self.SCALE_FACTOR
                }
            })

            # Optimize model
            try:
                model_simp, check = simplify(
                    model_onnx,
                    skip_constant_folding=False,
                    skip_shape_inference=False,
                    overwrite_input_shapes={
                        'input': [1, 3, -1, -1]  # -1 indicates dynamic dimension
                    }
                )

                if check:
                    print("Model simplified successfully")
                    onnx.save(model_simp, output_path)
                else:
                    print("Warning: Model simplification failed, saving original")
                    onnx.save(model_onnx, output_path)

            except Exception as e:
                print(f"Simplification warning: {str(e)}")
                onnx.save(model_onnx, output_path)

            return output_path

        except Exception as e:
            print(f"Error during ONNX conversion: {str(e)}")
            return None

    def test_onnx_model(
        self,
        onnx_path: str,
        test_sizes: List[Tuple[int, int]] = None
    ) -> bool:
        print("\nTesting model compatibility...")

        if test_sizes is None:
            test_sizes = [
                (64, 64),    # Nhỏ hơn TILE_SIZE
                (128, 128),  # Bằng TILE_SIZE
                (176, 320),  # Kích thước tùy ý
                (256, 256),  # Lớn hơn TILE_SIZE
            ]

        try:
            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if self.device.type == 'cuda' else ['CPUExecutionProvider']

            session_options = ort.SessionOptions()
            session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

            session = ort.InferenceSession(
                onnx_path,
                providers=providers,
                sess_options=session_options
            )

            # Kiểm tra input shape từ session
            input_shape = session.get_inputs()[0].shape
            print(f"Model expects input shape: {input_shape}")

            for h, w in test_sizes:
                print(f"\nTesting with size: ({h}, {w})")

                # Tạo test input với kích thước thực
                test_input = np.random.rand(1, 3, h, w).astype(np.float32)
                test_input = test_input / np.max(test_input)

                # Run inference
                output = session.run(None, {'input': test_input})[0]

                # Kiểm tra output shape
                expected_h = h * self.SCALE_FACTOR
                expected_w = w * self.SCALE_FACTOR

                if output.shape != (1, 3, expected_h, expected_w):
                    print(f"Warning: Unexpected output shape. Expected {(1, 3, expected_h, expected_w)}, got {output.shape}")

                print(f"✓ Success! Input: {test_input.shape} → Output: {output.shape}")
                print(f"Output range: {output.min():.3f} to {output.max():.3f}")

            return True

        except Exception as e:
            print(f"Test failed: {str(e)}")
            return False

    def _process_large_image(
        self,
        session: ort.InferenceSession,
        image: np.ndarray,
        tile_size: Tuple[int, int]
    ) -> np.ndarray:
        """Xử lý ảnh lớn bằng tiling"""
        b, c, h, w = image.shape
        tile_h, tile_w = tile_size

        # Tính số tile và kích thước output
        n_tiles_h = int(np.ceil(h / tile_h))
        n_tiles_w = int(np.ceil(w / tile_w))
        output_h = h * self.SCALE_FACTOR
        output_w = w * self.SCALE_FACTOR

        # Khởi tạo output buffer
        output = np.zeros((b, c, output_h, output_w), dtype=np.float32)

        # Tạo mask cho blending
        mask = np.zeros((1, 1, tile_h * self.SCALE_FACTOR, tile_w * self.SCALE_FACTOR), dtype=np.float32)
        for i in range(self.TILE_OVERLAP * self.SCALE_FACTOR):
            mask[:, :, i, :] = i / (self.TILE_OVERLAP * self.SCALE_FACTOR)
            mask[:, :, -(i+1), :] = i / (self.TILE_OVERLAP * self.SCALE_FACTOR)
            mask[:, :, :, i] = i / (self.TILE_OVERLAP * self.SCALE_FACTOR)
            mask[:, :, :, -(i+1)] = i / (self.TILE_OVERLAP * self.SCALE_FACTOR)
        mask = np.clip(mask, 0, 1)

        for i in tqdm(range(n_tiles_h), desc="Processing tiles"):
            for j in range(n_tiles_w):
                # Tính vị trí tile với overlap
                start_h = max(0, i * tile_h - self.TILE_OVERLAP)
                start_w = max(0, j * tile_w - self.TILE_OVERLAP)
                end_h = min(h, (i + 1) * tile_h + self.TILE_OVERLAP)
                end_w = min(w, (j + 1) * tile_w + self.TILE_OVERLAP)

                # Cắt tile
                tile = image[:, :, start_h:end_h, start_w:end_w]

                # Thêm padding nếu cần
                if tile.shape[2] < tile_h + 2 * self.TILE_OVERLAP or tile.shape[3] < tile_w + 2 * self.TILE_OVERLAP:
                    pad_h = tile_h + 2 * self.TILE_OVERLAP - tile.shape[2]
                    pad_w = tile_w + 2 * self.TILE_OVERLAP - tile.shape[3]
                    tile = np.pad(
                        tile,
                        ((0,0), (0,0), (0,pad_h), (0,pad_w)),
                        mode='reflect'
                    )

                # Process tile
                with torch.cuda.amp.autocast() if self.device.type == 'cuda' else contextlib.nullcontext():
                    tile_output = session.run(None, {'input': tile})[0]

                # Tính vị trí trong output
                out_start_h = start_h * self.SCALE_FACTOR
                out_start_w = start_w * self.SCALE_FACTOR
                out_end_h = end_h * self.SCALE_FACTOR
                out_end_w = end_w * self.SCALE_FACTOR

                # Áp dụng mask cho blending
                current_mask = mask
                if tile_output.shape[2:] != mask.shape[2:]:
                    current_mask = np.ones_like(tile_output)

                # Copy vào output buffer với blending
                output[
                    :,
                    :,
                    out_start_h:out_end_h,
                    out_start_w:out_end_w
                ] = tile_output * current_mask + output[
                    :,
                    :,
                    out_start_h:out_end_h,
                    out_start_w:out_end_w
                ] * (1 - current_mask)

        return output

    def optimize_for_onnxruntime(self, onnx_path: str) -> str:
        """Tối ưu model với ONNX Runtime"""
        try:
            print("\nOptimizing with ONNX Runtime...")

            from onnxruntime.transformers.optimizer import optimize_model

            # Tạo đường dẫn cho model tối ưu
            optimized_path = onnx_path.replace('.onnx', '_optimized.onnx')

            # Sao chép model gốc sang optimized path
            import shutil
            shutil.copy2(onnx_path, optimized_path)

            print(f"Saved optimized model to: {optimized_path}")
            return optimized_path

        except Exception as e:
            print(f"Optimization warning: {str(e)}")
            return onnx_path

    def optimize_full_pipeline(
        self,
        output_dir: str = 'optimized_models',
        base_name: str = 'realesrgan_anime'
    ) -> Dict[str, str]:
        """Pipeline tối ưu đầy đủ cho model"""
        try:
            os.makedirs(output_dir, exist_ok=True)
            print(f"Created output directory: {output_dir}")

            # Load model
            model = self.load_model()

            # Convert to ONNX
            initial_path = os.path.join(output_dir, f'{base_name}_initial.onnx')
            onnx_path = self.convert_to_onnx(
                model=model,
                output_path=initial_path
            )

            if not onnx_path or not os.path.exists(onnx_path):
                raise RuntimeError(f"ONNX conversion failed - file not found at {initial_path}")

            # Tối ưu với ONNX Runtime
            optimized_path = self.optimize_for_onnxruntime(onnx_path)

            # Test model
            if self.test_onnx_model(optimized_path):
                print("\nOptimization completed successfully!")
                return {
                    'initial': initial_path,
                    'optimized': optimized_path
                }
            else:
                raise RuntimeError("Model optimization failed - testing failed")

        except Exception as e:
            print(f"Pipeline error: {str(e)}")
            raise

def download_model():
    """Download pre-trained model if not exists"""
    model_path = 'weights/realesr-animevideov3.pth'
    if not os.path.exists('weights'):
        os.makedirs('weights')

    if not os.path.exists(model_path):
        print("Downloading model...")
        import urllib.request
        url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth'
        urllib.request.urlretrieve(url, model_path)
        print("Model downloaded successfully!")
    else:
        print("Model already exists!")

    return model_path

In [42]:
if __name__ == "__main__":
    try:
        # Download model if needed
        model_path = download_model()

        # Initialize converter
        converter = OptimizedModelConverter(model_path)

        # Run full optimization pipeline
        print("\nStarting browser optimization pipeline...")
        optimized_models = converter.optimize_full_pipeline()
        print(f"\nOptimization complete. Models generated: {optimized_models}")

        # Kiểm tra model với test_onnx_model đã có sẵn
        onnx_model_path = optimized_models['optimized']
        if converter.test_onnx_model(onnx_model_path):
            print("\nModel is ready for browser deployment! 🚀")
            print(f"Model path: {onnx_model_path}")
        else:
            print("\n⚠ Warning: Model may have compatibility issues")

    except Exception as e:
        print(f"\nError during optimization: {str(e)}")

Model already exists!

Starting browser optimization pipeline...
Created output directory: optimized_models
Loading PyTorch model...
Starting ONNX conversion...


  if torch.any(mod_h > 0) or torch.any(mod_w > 0):
  h_tensor = torch.tensor(h, device=x.device, dtype=torch.float32)
  h_tensor = torch.tensor(h, device=x.device, dtype=torch.float32)
  w_tensor = torch.tensor(w, device=x.device, dtype=torch.float32)
  w_tensor = torch.tensor(w, device=x.device, dtype=torch.float32)
  tile_size = torch.tensor(self.tile_size, device=x.device, dtype=torch.float32)
  if h < self.tile_size or w < self.tile_size:


verbose: False, log level: Level.ERROR

Model simplified successfully

Optimizing with ONNX Runtime...
Saved optimized model to: optimized_models/realesrgan_anime_initial_optimized.onnx

Testing model compatibility...
Model expects input shape: [1, 3, None, None]

Testing with size: (64, 64)
✓ Success! Input: (1, 3, 64, 64) → Output: (1, 3, 256, 256)
Output range: 0.000 to 1.000

Testing with size: (128, 128)
✓ Success! Input: (1, 3, 128, 128) → Output: (1, 3, 512, 512)
Output range: 0.000 to 1.000

Testing with size: (176, 320)
✓ Success! Input: (1, 3, 176, 320) → Output: (1, 3, 704, 1280)
Output range: 0.000 to 1.000

Testing with size: (256, 256)
✓ Success! Input: (1, 3, 256, 256) → Output: (1, 3, 1024, 1024)
Output range: 0.000 to 1.000

Optimization completed successfully!

Optimization complete. Models generated: {'initial': 'optimized_models/realesrgan_anime_initial.onnx', 'optimized': 'optimized_models/realesrgan_anime_initial_optimized.onnx'}

Testing model compatibility...
Mo