In [1]:
# Gỡ cài đặt phiên bản cũ
!pip uninstall -y tensorflow tensorflowjs

# Clear pip cache
!pip cache purge

# Cài đặt phiên bản mới với --no-cache-dir
!pip install --no-cache-dir tensorflow==2.18.0
!pip install --no-cache-dir tensorflowjs==4.22.0

# Cài đặt các dependencies khác
!pip install --no-cache-dir torch torchvision
!git clone https://github.com/XPixelGroup/BasicSR.git
%cd BasicSR
!pip install -r requirements.txt
!python setup.py develop
%cd ..

# Verify versions
import tensorflow as tf
import tensorflowjs as tfjs
print("TensorFlow version:", tf.__version__)
print("TensorFlow.js version:", tfjs.__version__)

Found existing installation: tensorflow 2.18.0
Uninstalling tensorflow-2.18.0:
  Successfully uninstalled tensorflow-2.18.0
Found existing installation: tensorflowjs 4.22.0
Uninstalling tensorflowjs-4.22.0:
  Successfully uninstalled tensorflowjs-4.22.0
[0mFiles removed: 0
Collecting tensorflow==2.18.0
  Downloading tensorflow-2.18.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.1 kB)
Downloading tensorflow-2.18.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (615.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m615.3/615.3 MB[0m [31m81.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tensorflow
Successfully installed tensorflow-2.18.0
Collecting tensorflowjs==4.22.0
  Downloading tensorflowjs-4.22.0-py3-none-any.whl.metadata (3.2 kB)
Downloading tensorflowjs-4.22.0-py3-none-any.whl (89 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m89.1/89.1 kB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m


TensorFlow version: 2.18.0
TensorFlow.js version: 4.22.0


In [112]:
import os
import json
import torch
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
from basicsr.archs.srvgg_arch import SRVGGNetCompact
import tensorflowjs as tfjs

class ScaleLayer(layers.Layer):
    """Custom layer để kiểm soát scale"""
    def __init__(self, scale_factor, **kwargs):
        super().__init__(**kwargs)
        self.scale_factor = scale_factor

    def call(self, inputs):
        return inputs * self.scale_factor

    def get_config(self):
        config = super().get_config()
        config.update({"scale_factor": self.scale_factor})
        return config

class DynamicReshapeLayer(layers.Layer):
    """PixelShuffle layer với xử lý permute chuẩn xác hơn"""
    def __init__(self, scale_factor, **kwargs):
        super().__init__(**kwargs)
        self.scale_factor = scale_factor

    def call(self, inputs):
        # Get dynamic input shape
        batch_size = tf.shape(inputs)[0]
        h = tf.shape(inputs)[1]
        w = tf.shape(inputs)[2]
        c = tf.shape(inputs)[3]

        # Calculate output dimensions
        new_c = c // (self.scale_factor * self.scale_factor)
        new_h = h * self.scale_factor
        new_w = w * self.scale_factor

        # Reshape để match với PyTorch PixelShuffle
        # [B, H, W, C] -> [B, H, W, r, r, C/(r*r)]
        x = tf.reshape(inputs, [batch_size, h, w, self.scale_factor, self.scale_factor, new_c])

        # Transpose để đúng thứ tự như PyTorch
        # [B, H, W, r, r, C/(r*r)] -> [B, H, r, W, r, C/(r*r)]
        x = tf.transpose(x, [0, 1, 3, 2, 4, 5])

        # Reshape final để có được output format mong muốn
        # [B, H, r, W, r, C/(r*r)] -> [B, H*r, W*r, C/(r*r)]
        x = tf.reshape(x, [batch_size, new_h, new_w, new_c])

        return x

    def get_config(self):
        config = super().get_config()
        config.update({"scale_factor": self.scale_factor})
        return config

In [117]:
class TFJSModelConverter:
    def __init__(
        self,
        model_path: str = 'weights/realesr-animevideov3.pth',
        output_path: str = 'web_model',
        device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    ):
        self.model_path = model_path
        self.output_path = output_path
        self.device = torch.device(device)

        # Cấu hình cố định cho model
        self.SCALE_FACTOR = 4
        self.INPUT_CHANNELS = 3
        self.OUTPUT_CHANNELS = 3
        self.TILE_SIZE = 128

        # Định nghĩa tên node thống nhất
        self.MODEL_NAME = "RealESRGAN"
        self.INPUT_TENSOR_NAME = "input_tensor"
        self.OUTPUT_TENSOR_NAME = "output_tensor"
        self.INPUT_SIGNATURE_NAME = f"serving_default_{self.INPUT_TENSOR_NAME}"
        self.OUTPUT_SIGNATURE_NAME = f"StatefulPartitionedCall_{self.OUTPUT_TENSOR_NAME}"

        if device == 'cuda':
            self.setup_cuda()

    def setup_cuda(self):
        """Tối ưu CUDA cho conversion"""
        torch.cuda.empty_cache()
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.enabled = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    def load_pytorch_model(self) -> torch.nn.Module:
        """Load PyTorch model"""
        try:
            model = SRVGGNetCompact(
                num_in_ch=self.INPUT_CHANNELS,
                num_out_ch=self.OUTPUT_CHANNELS,
                num_feat=64,
                num_conv=16,
                upscale=self.SCALE_FACTOR,
                act_type='prelu'
            )

            state_dict = torch.load(self.model_path, map_location=self.device, weights_only=True)
            if 'params_ema' in state_dict:
                state_dict = state_dict['params_ema']
            elif 'params' in state_dict:
                state_dict = state_dict['params']

            model.load_state_dict(state_dict)
            model.eval()

            print("\nPyTorch Model Structure:")
            print(model)

            return model

        except Exception as e:
            print(f"Error in load_pytorch_model: {str(e)}")
            raise e

    def create_tensorflow_model(self, torch_model: torch.nn.Module):
        """Convert PyTorch model với normalize fix"""
        try:
            inputs = layers.Input(shape=(None, None, self.INPUT_CHANNELS),
                               name=self.INPUT_TENSOR_NAME)

            # Normalize về range [0,1]
            x = layers.Lambda(
                lambda x: x / 255.0,
                name='normalize'
            )(inputs)

            # Process through model layers
            for i, module in enumerate(torch_model.body):
                if isinstance(module, torch.nn.Conv2d):
                    weight = module.weight.detach().cpu().numpy()
                    weight = np.transpose(weight, (2, 3, 1, 0))

                    conv = layers.Conv2D(
                        filters=weight.shape[-1],
                        kernel_size=weight.shape[:2],
                        padding='same',
                        use_bias=module.bias is not None,
                        kernel_initializer=tf.keras.initializers.Constant(weight),
                        # Removed the * 0.5 scaling for bias
                        bias_initializer='zeros' if module.bias is None else tf.keras.initializers.Constant(module.bias.detach().cpu().numpy()),
                        name=f'conv_{i}'
                    )(x)
                    x = conv

                elif isinstance(module, torch.nn.PReLU):
                    weight = module.weight.detach().cpu().numpy()
                    weight = np.reshape(weight, (1, 1, -1))

                    prelu = layers.PReLU(
                        alpha_initializer=tf.keras.initializers.Constant(weight),
                        shared_axes=[1, 2],
                        name=f'prelu_{i}'
                    )(x)
                    x = prelu

            # Add scale correction before shuffle
            x = layers.Lambda(
                lambda x: x + 0.5,  # Add offset to match PyTorch
                name='scale_correction'
            )(x)

            # Track state before shuffle
            x = layers.Lambda(
                lambda x: tf.identity(x),
                name='before_shuffle'
            )(x)

            # PixelShuffle operation
            x = DynamicReshapeLayer(
                scale_factor=self.SCALE_FACTOR,
                name='pixel_shuffle'
            )(x)

            # Track state after shuffle
            x = layers.Lambda(
                lambda x: tf.identity(x),
                name='after_shuffle'
            )(x)

            # Denormalize về range [0,255]
            x = layers.Lambda(
                lambda x: x * 255.0,
                name='denormalize'
            )(x)

            # Ensure output range
            outputs = layers.Lambda(
                lambda x: tf.clip_by_value(x, 0.0, 255.0),
                name=self.OUTPUT_TENSOR_NAME
            )(x)

            model = tf.keras.Model(inputs=inputs, outputs=outputs, name=self.MODEL_NAME)
            print("\nTensorFlow Model Summary:")
            model.summary()

                        # Setup detailed debug model
            self.debug_model = tf.keras.Model(
                inputs=model.input,
                outputs={
                    'normalize': model.get_layer('normalize').output,
                    'scale_correction': model.get_layer('scale_correction').output,
                    'before_shuffle': model.get_layer('before_shuffle').output,
                    'after_shuffle': model.get_layer('after_shuffle').output,
                    'denormalize': model.get_layer('denormalize').output,
                    'final': outputs
                }
            )

            return model

        except Exception as e:
            print(f"Error in create_tensorflow_model: {str(e)}")
            raise e

    def save_tfjs_model(self, model, output_path: str):
        """Save model in TFJS GraphModel format"""
        try:
            # Bước 1: Lưu model với signature
            temp_saved_model_path = "temp_saved_model"

            @tf.function(input_signature=[
                tf.TensorSpec(shape=(None, None, None, 3), dtype=tf.float32, name=self.INPUT_TENSOR_NAME)
            ])
            def serving_fn(input_tensor):
                return {self.OUTPUT_TENSOR_NAME: model(input_tensor)}

            tf.saved_model.save(
                model,
                temp_saved_model_path,
                signatures={'serving_default': serving_fn}
            )

            # Bước 2: Convert sang TFJS GraphModel
            tfjs.converters.convert_tf_saved_model(
                temp_saved_model_path,
                output_path
            )

            # Bước 3: Cleanup và update metadata
            import shutil
            if os.path.exists(temp_saved_model_path):
                shutil.rmtree(temp_saved_model_path)

            metadata_path = os.path.join(output_path, 'model.json')
            if os.path.exists(metadata_path):
                with open(metadata_path, 'r') as f:
                    model_json = json.load(f)

                updated_model_json = {
                    "format": "graph-model",
                    "generatedBy": f"tfjs@4.22.0",
                    "convertedBy": "tfjs-converter",
                    "modelTopology": model_json.get("modelTopology", {}),
                    "weightsManifest": model_json.get("weightsManifest", []),
                    "signature": {
                        "inputs": {
                            self.INPUT_SIGNATURE_NAME: {
                                "name": self.INPUT_SIGNATURE_NAME,
                                "dtype": "float32",
                                "shape": [-1, -1, -1, 3]
                            }
                        },
                        "outputs": {
                            self.OUTPUT_SIGNATURE_NAME: {
                                "name": self.OUTPUT_SIGNATURE_NAME,
                                "dtype": "float32",
                                "shape": [-1, -1, -1, 3]
                            }
                        }
                    },
                    "userDefinedMetadata": {
                        "scale_factor": self.SCALE_FACTOR,
                        "tile_size": self.TILE_SIZE,
                        "input_format": "RGB",
                        "preprocessing": "normalize to [0,1]",
                        "postprocessing": "clip and scale to [0,255]",
                        "backend": "webgpu"
                    }
                }

                with open(metadata_path, 'w') as f:
                    json.dump(updated_model_json, f, indent=2)

        except Exception as e:
            print(f"Error in save_tfjs_model: {str(e)}")
            raise e

    def verify_conversion(self, torch_model: torch.nn.Module, tf_model):
        """Verify với debug chi tiết hơn"""
        try:
            # Tạo input test đơn giản hơn
            test_input = np.ones((1, 32, 32, 3), dtype=np.float32) * 127.5  # Use middle value

            print("\n=== PyTorch Pipeline ===")
            with torch.no_grad():
                torch_input = torch.from_numpy(test_input).permute(0, 3, 1, 2)
                if torch.cuda.is_available():
                    torch_input = torch_input.cuda()
                    torch_model = torch_model.cuda()

                # Track intermediate values
                x = torch_input / 255.0  # Normalize
                print(f"After normalize: [{x.min().item():.6f}, {x.max().item():.6f}]")

                # Track first conv output
                first_conv = None
                for module in torch_model.body:
                    if isinstance(module, torch.nn.Conv2d):
                        first_conv = module
                        break
                if first_conv:
                    first_out = first_conv(x)
                    print(f"After first conv: [{first_out.min().item():.6f}, {first_out.max().item():.6f}]")

                x = torch_model(x)  # Full model
                print(f"After model: [{x.min().item():.6f}, {x.max().item():.6f}]")

                x = x * 255.0  # Denormalize
                print(f"After denormalize: [{x.min().item():.6f}, {x.max().item():.6f}]")

                x = torch.clamp(x, 0, 255)  # Clip
                print(f"After clip: [{x.min().item():.6f}, {x.max().item():.6f}]")

                if torch.cuda.is_available():
                    x = x.cpu()
                torch_output = x.numpy().transpose(0, 2, 3, 1)

            print("\n=== TensorFlow Pipeline ===")
            outputs = self.debug_model(test_input, training=False)

            # Get first conv layer output
            first_conv_tf = tf_model.get_layer('conv_0')
            first_conv_out = first_conv_tf(outputs['normalize'])
            print(f"After first conv: [{first_conv_out.numpy().min():.6f}, {first_conv_out.numpy().max():.6f}]")

            print(f"After normalize: [{outputs['normalize'].numpy().min():.6f}, {outputs['normalize'].numpy().max():.6f}]")
            print(f"After scale correction: [{outputs['scale_correction'].numpy().min():.6f}, {outputs['scale_correction'].numpy().max():.6f}]")
            print(f"Before shuffle: [{outputs['before_shuffle'].numpy().min():.6f}, {outputs['before_shuffle'].numpy().max():.6f}]")
            print(f"After shuffle: [{outputs['after_shuffle'].numpy().min():.6f}, {outputs['after_shuffle'].numpy().max():.6f}]")
            print(f"After denormalize: [{outputs['denormalize'].numpy().min():.6f}, {outputs['denormalize'].numpy().max():.6f}]")
            print(f"Final output: [{outputs['final'].numpy().min():.6f}, {outputs['final'].numpy().max():.6f}]")

            tf_output = outputs['final']

            print("\n=== Comparison ===")
            print(f"PyTorch shape: {torch_output.shape}")
            print(f"TF shape: {tf_output.shape}")
            print(f"PyTorch range: [{torch_output.min():.6f}, {torch_output.max():.6f}]")
            print(f"TF range: [{tf_output.numpy().min():.6f}, {tf_output.numpy().max():.6f}]")

            if torch_output.shape == tf_output.shape:
                abs_diff = np.abs(tf_output - torch_output)
                print(f"\nMax difference: {np.max(abs_diff):.6f}")
                print(f"Mean difference: {np.mean(abs_diff):.6f}")
                print(f"Std of difference: {np.std(abs_diff):.6f}")

                # Print more details about first few pixels
                print("\nFirst few pixels comparison:")
                for i in range(min(5, torch_output.shape[1])):
                    pt_val = torch_output[0,i,0,0]
                    tf_val = tf_output[0,i,0,0]
                    print(f"Pixel {i}: PT={pt_val:.6f}, TF={tf_val:.6f}, Diff={abs(pt_val-tf_val):.6f}")

            else:
                print("\nERROR: Output shapes do not match!")

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        except Exception as e:
            print(f"Error in verify_conversion: {str(e)}")
            raise e

    def verify_weights(self, torch_model: torch.nn.Module, tf_model):
        """Verify weights với debug chi tiết"""
        try:
            print("\n=== Weights Verification ===")

            # Lấy layer đầu tiên từ cả hai model
            first_pt_conv = None
            for module in torch_model.body:
                if isinstance(module, torch.nn.Conv2d):
                    first_pt_conv = module
                    break

            if first_pt_conv is not None:
                first_tf_conv = tf_model.get_layer('conv_0')

                # So sánh weights
                pt_weights = first_pt_conv.weight.detach().cpu().numpy()
                tf_weights = first_tf_conv.get_weights()[0]

                # Chuyển đổi format để so sánh
                pt_weights_converted = np.transpose(pt_weights, (2, 3, 1, 0))  # OIHW -> HWIO

                print(f"PyTorch weights shape: {pt_weights.shape}")
                print(f"TF weights shape: {tf_weights.shape}")
                print("\nFirst few weights comparison:")
                print(f"PyTorch (original): {pt_weights.flatten()[:5]}")
                print(f"PyTorch (converted): {pt_weights_converted.flatten()[:5]}")
                print(f"TensorFlow: {tf_weights.flatten()[:5]}")

                weight_diff = np.abs(pt_weights_converted - tf_weights).max()
                print(f"\nMax weight difference: {weight_diff}")

                # So sánh bias nếu có
                if first_pt_conv.bias is not None:
                    pt_bias = first_pt_conv.bias.detach().cpu().numpy()
                    tf_bias = first_tf_conv.get_weights()[1]
                    print("\nBias comparison:")
                    print(f"PyTorch bias: {pt_bias[:5]}")
                    print(f"TF bias: {tf_bias[:5]}")
                    bias_diff = np.abs(pt_bias - tf_bias).max()
                    print(f"Max bias difference: {bias_diff}")

                # Kiểm tra range của weights
                print("\nWeights range:")
                print(f"PyTorch: [{pt_weights.min()}, {pt_weights.max()}]")
                print(f"TF: [{tf_weights.min()}, {tf_weights.max()}]")

        except Exception as e:
            print(f"Error in verify_weights: {str(e)}")
            raise e

    def convert(self):
        """Run full conversion pipeline"""
        try:
            os.makedirs(self.output_path, exist_ok=True)

            torch_model = self.load_pytorch_model()
            tf_model = self.create_tensorflow_model(torch_model)
            self.verify_weights(torch_model, tf_model)
            self.verify_conversion(torch_model, tf_model)
            self.save_tfjs_model(tf_model, self.output_path)

            print(f"\nConversion completed successfully!")
            print(f"Model saved to: {self.output_path}")
            print("\nUse the following in your web app:")
            print("const model = await tf.loadGraphModel('model/model.json');")

            return True

        except Exception as e:
            print(f"Error during conversion: {str(e)}")
            raise e

In [118]:
def download_model():
    """Download pre-trained model if not exists"""
    model_path = 'weights/realesr-animevideov3.pth'
    if not os.path.exists('weights'):
        os.makedirs('weights')

    if not os.path.exists(model_path):
        print("Downloading model...")
        import urllib.request
        url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth'
        urllib.request.urlretrieve(url, model_path)
        print("Model downloaded successfully!")
    else:
        print("Model already exists!")

    return model_path

if __name__ == "__main__":
    try:
        # Download model if needed
        model_path = download_model()

        # Initialize converter
        converter = TFJSModelConverter(
            model_path=model_path,
            output_path='web_model'
        )

        # Run conversion
        converter.convert()

    except Exception as e:
        print(f"Error: {str(e)}")

Model already exists!

PyTorch Model Structure:
SRVGGNetCompact(
  (body): ModuleList(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): PReLU(num_parameters=64)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): PReLU(num_parameters=64)
    (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): PReLU(num_parameters=64)
    (6): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): PReLU(num_parameters=64)
    (8): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): PReLU(num_parameters=64)
    (10): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): PReLU(num_parameters=64)
    (12): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): PReLU(num_parameters=64)
    (14): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): PReLU(num_parameters=64)
    (16): Conv2d(64, 64, kernel_size


=== Weights Verification ===
PyTorch weights shape: (64, 3, 3, 3)
TF weights shape: (3, 3, 3, 64)

First few weights comparison:
PyTorch (original): [-0.0473705  -0.04728249  0.5571999  -1.8980639   0.07729447]
PyTorch (converted): [-0.0473705  -1.0398519  -0.08468015  0.00962272  0.93967235]
TensorFlow: [-0.0473705  -1.0398519  -0.08468015  0.00962272  0.93967235]

Max weight difference: 0.0

Bias comparison:
PyTorch bias: [ 0.21254003  0.07621118  4.5727463  -9.340443    0.38109306]
TF bias: [ 0.21254003  0.07621118  4.5727463  -9.340443    0.38109306]
Max bias difference: 0.0

Weights range:
PyTorch: [-34.91389465332031, 33.27830123901367]
TF: [-34.91389465332031, 33.27830123901367]

=== PyTorch Pipeline ===
After normalize: [0.500000, 0.500000]
After first conv: [-22.445719, 18.227266]
After model: [0.499740, 0.522922]
After denormalize: [127.433640, 133.345123]
After clip: [127.433640, 133.345123]

=== TensorFlow Pipeline ===
After first conv: [-22.445719, 18.227266]
After normal