In [None]:
import os
import torch
import numpy as np
import cv2
from models.vmunet.vmunet import VMUNet
from configs.config_setting_cbis import setting_config
from utils import *

In [None]:
def load_image(img_path, input_size=(256, 256)):
    """Read an image from disk and convert it to a normalized input tensor.

    Args:
        img_path: Path of the image file to read.
        input_size: Target size handed to ``cv2.resize`` as its ``dsize``
            argument, which OpenCV interprets as ``(width, height)``.

    Returns:
        A float32 ``torch.Tensor`` of shape ``(1, 3, H, W)`` in RGB channel
        order with values scaled to ``[0, 1]``.

    Raises:
        FileNotFoundError: If ``img_path`` does not point to an existing file.
        ValueError: If the file exists but OpenCV cannot decode it as an image.
    """
    if not os.path.isfile(img_path):
        raise FileNotFoundError(f"Image file not found: {img_path}")

    # cv2.imread reports failure by returning None rather than raising, which
    # would otherwise surface later as an opaque assertion inside cvtColor.
    img = cv2.imread(img_path)
    if img is None:
        raise ValueError(f"Could not decode image file: {img_path}")

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; convert to RGB
    img = cv2.resize(img, input_size)
    img = img.astype(np.float32) / 255.0  # normalize to [0, 1]

    # HWC -> CHW, then add a leading batch dimension: (1, 3, H, W)
    img_tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
    return img_tensor

In [None]:
def run_inference(config, image_path, output_path):
    """Run the trained VMUNet on a single image and save the prediction as an image.

    The model is built from ``config.model_config``, the weights are read from
    ``<config.work_dir>/checkpoints/best.pth``, and both the model and the
    input are moved to the GPU with ``.cuda()``, so a CUDA device is required.

    Args:
        config: Settings object exposing ``model_config`` (a mapping with the
            keys ``num_classes``, ``input_channels``, ``depths``,
            ``depths_decoder``, ``drop_path_rate``, ``load_ckpt_path``) and
            ``work_dir``.
        image_path: Path of the input image, read through ``load_image``.
        output_path: Destination file for the predicted mask. Its parent
            directory is created if it does not exist.

    Returns:
        None. The prediction is written to ``output_path`` as an 8-bit image.

    Raises:
        OSError: If OpenCV fails to write the mask to ``output_path``.
    """
    model_cfg = config.model_config
    model = VMUNet(
        num_classes=model_cfg['num_classes'],
        input_channels=model_cfg['input_channels'],
        depths=model_cfg['depths'],
        depths_decoder=model_cfg['depths_decoder'],
        drop_path_rate=model_cfg['drop_path_rate'],
        load_ckpt_path=model_cfg['load_ckpt_path'],
    )
    # NOTE(review): presumably loads pretrained weights from load_ckpt_path;
    # the state dict loaded below may make this redundant — confirm.
    model.load_from()

    # torch.load unpickles the file, so only load checkpoints from a trusted source.
    checkpoint_path = os.path.join(config.work_dir, 'checkpoints/best.pth')
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint)
    model = model.cuda()
    model.eval()

    img_tensor = load_image(image_path).cuda()

    with torch.no_grad():
        output = model(img_tensor)
        # NOTE(review): if VMUNet.forward already applies a sigmoid for
        # single-class output, this applies it a second time — confirm
        # against the model code.
        # [0, 0] selects the first channel of the first batch element.
        pred = torch.sigmoid(output).cpu().numpy()[0, 0]

    # Scale the sigmoid output to 0-255; no threshold is applied, so the saved
    # image is a grayscale probability map rather than a strictly binary mask.
    pred_mask = (pred * 255).astype(np.uint8)

    # cv2.imwrite returns False instead of raising when it cannot write
    # (e.g. the target directory is missing), so create the directory first
    # and check the result explicitly.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    if not cv2.imwrite(output_path, pred_mask):
        raise OSError(f"Failed to write prediction mask to: {output_path}")


In [None]:
# Inference settings: experiment config plus input/output file locations
config = setting_config
test_image_path = './test_images/image1.png'  # input image; edit to point at your own file
output_mask_path = './outputs/mask1.png'      # file the predicted mask is written to

# Predict the mask for the single image configured above
run_inference(
    config=config,
    image_path=test_image_path,
    output_path=output_mask_path,
)
