# Test SR model

In [None]:
import numpy as np 
import torch
from PIL import Image
import matplotlib.pyplot as plt

# 假设你的项目路径配置正确
from src.emcfsys.EMCellFiner.hat.models.hat_model import HATModel
from src.emcfsys.EMCellFiner.hat.models.img_utils import tensor2img

# 1. 初始化模型
path = r"D:\napari_EMCF\EMCFsys\models\EMCellFiner.pth"
# 显式指定 tile_size，防止显存溢出；对于小图可以不用 tile
model = HATModel(scale=4, tile_size=512) 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# 2. 读取与预处理
img_path = r"D:\napari_EMCF\EMCFsys\emcfsys\image\Bock2011_2951_XrV1ciGgTWHjepNf.tif"
img = Image.open(img_path).convert("RGB")

# 转换为 Numpy 并归一化
img_np = np.array(img).astype(np.float32) / 255.
# 转换为 Tensor: (H, W, C) -> (C, H, W) -> (1, C, H, W)
img_torch = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(device)
print(f"Input Shape: {img_torch.shape}")
# 3. 推理
with torch.no_grad():
    output = model(img_torch) 
print(f"Output Shape: {output.shape}")
# 4. 后处理
output = output.cpu()
img_out = tensor2img(output, rgb2bgr=False, min_max=(0, 1))
# 5. 转回 PIL 图片
img_final = Image.fromarray(img_out)

# 验证结果
# img_final.show() 
# img_final.save("result_sr.png")
print("Done.")

In [None]:
from src.emcfsys.EMCellFiner.hat.models.inference_hat import hat_infer_numpy
from src.emcfsys.EMCellFiner.hat.models.hat_model import HATModel
from PIL import Image
import numpy as np
import torch
import time
start = time.time()
img_path = r"D:\napari_EMCF\EMCFsys\emcfsys\image\Bock2011_2951_XrV1ciGgTWHjepNf.tif"
img = Image.open(img_path).convert("RGB").crop([0,0,512,512])
img_np = np.array(img)

model = HATModel(scale=4, tile_size=512)
device = "cpu" #torch.device("cuda" if torch.cuda.is_available() else "cpu")

out = hat_infer_numpy(
    model= model,
    image= img_np,
    device=device,
)

end = time.time()
print("inference time: ", end - start)

# add torch.compile

In [1]:
from src.emcfsys.EMCellFiner.hat.models.inference_hat import hat_infer_numpy
from src.emcfsys.EMCellFiner.hat.models.hat_model import HATModel
from PIL import Image
import numpy as np
import torch
import time
start = time.time()
img_path = r"D:\napari_EMCF\EMCFsys\emcfsys\image\Bock2011_2951_XrV1ciGgTWHjepNf.tif"
img = Image.open(img_path).convert("RGB").crop([0,0,512,512])
img_np = np.array(img)

model = HATModel(scale=4, tile_size=512)
device = "cuda" #torch.device("cuda" if torch.cuda.is_available() else "cpu")

out = hat_infer_numpy(
    model= model,
    image= img_np,
    device=device,
)

end = time.time()
print("inference time: ", end - start)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Using the model from torch hub : https://github.com/yzy0102/emcfsys/releases/latest/download/EMCellFiner.pth
	Tile 1/1
inference time:  4.851733922958374


# Export model to ONNX

## ONNX can't speed up many times1!

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
from src.emcfsys.EMCellFiner.hat.models.hat_arch import HAT

model = HAT()

state_dict = torch.load(r"D:\napari_EMCF\EMCFsys\models\EMCellFiner.pth", map_location='cpu')
model.load_state_dict(state_dict['params'], strict=True)


def export_hat_to_onnx(model, onnx_path="HATModel.onnx", device="cpu"):
    """
    导出 HAT 网络到 ONNX
    注意：tile_process 不会导出，只能做整张图推理
    """
    model.eval()
    model.to(device)

    # 示例输入
    dummy_input = torch.randn(1, 3, 64, 64, device=device)  # B,C,H,W

    # 导出 ONNX
    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        export_params=True,
        opset_version=17,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={
            'input': {2: 'height', 3: 'width'},
            'output': {2: 'height', 3: 'width'},
        }
    )
    print(f"✅ ONNX model saved at {onnx_path}")

export_hat_to_onnx(model, onnx_path="HATModel.onnx", device="cpu")

In [None]:
import onnxruntime as ort
import numpy as np
from PIL import Image
from src.emcfsys.EMCellFiner.hat.models.img_utils import tensor2img
import torch
import time
start = time.time()
ort_sess = ort.InferenceSession("HATModel.onnx")

input_name = ort_sess.get_inputs()[0].name
output_name = ort_sess.get_outputs()[0].name

img_path = r"src\emcfsys\test_imgs\test_img.tif"
image = np.array(Image.open(img_path).convert("RGB").crop([0,0,512,512]))/255.

img_tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).numpy().astype(np.float32)

out = ort_sess.run([output_name], {input_name: img_tensor})
# out[0] to image 1, 3, 256, 256 -> 256, 256, 3
out = torch.from_numpy(np.array(out))
out = tensor2img(out, rgb2bgr=False, min_max=(0, 1))
end = time.time()

print("inference time: ", end - start)
print(out.shape)
print(out.max())

In [None]:
Image.fromarray(out)