## 预测


In [4]:
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import os
import re
from tqdm import tqdm
from tools.models.fusion import FusionNet, FusionNetv2
import numpy as np
from skimage.metrics import peak_signal_noise_ratio as psnr
import lpips
# Inference-time preprocessing: ToTensorV2 only turns the HWC numpy image into
# a CHW torch tensor; value scaling to [-1, 1] is done by the caller.
valid_transform = A.Compose(transforms=[ToTensorV2()])


def compute_ssim(img1, img2, data_range=255):
    """Compute the mean SSIM index between two images.

    Uses the Gaussian-window formulation (11x11 kernel, sigma=1.5) with the
    stabilising constants ``C1 = (0.01 * L) ** 2`` and ``C2 = (0.03 * L) ** 2``,
    where ``L`` is ``data_range``. Intended for grayscale (2-D) arrays; convert
    colour images to grayscale first.

    Args:
        img1: First image as a numpy array.
        img2: Second image; must have exactly the same shape as ``img1``.
        data_range: Dynamic range of the pixel values, e.g. 255 for 8-bit
            images or 1.0 for images normalised to [0, 1]. Defaults to 255,
            which reproduces the original 8-bit behaviour.

    Returns:
        The mean of the local SSIM map.

    Raises:
        ValueError: If the two images have different shapes.
    """
    if img1.shape != img2.shape:
        raise ValueError("Input images must have the same dimensions")

    # Work in floating point so squares/products do not overflow uint8.
    img1 = img1.astype(np.float32)
    img2 = img2.astype(np.float32)

    # Stabilising constants scaled by the dynamic range of the data.
    C1 = (0.01 * data_range) ** 2
    C2 = (0.03 * data_range) ** 2

    # Gaussian window parameters.
    kernel_size = (11, 11)
    sigma = 1.5

    # Local means via Gaussian blur.
    mu1 = cv2.GaussianBlur(img1, kernel_size, sigma)
    mu2 = cv2.GaussianBlur(img2, kernel_size, sigma)

    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2

    # Local variances and covariance: E[x^2] - E[x]^2 and E[xy] - E[x]E[y].
    sigma1_sq = cv2.GaussianBlur(img1**2, kernel_size, sigma) - mu1_sq
    sigma2_sq = cv2.GaussianBlur(img2**2, kernel_size, sigma) - mu2_sq
    sigma12 = cv2.GaussianBlur(img1 * img2, kernel_size, sigma) - mu1_mu2

    # Per-pixel SSIM; C1, C2 > 0 keep the denominator non-zero.
    numerator = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
    denominator = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    ssim_map = numerator / denominator

    return np.mean(ssim_map)


def _to_lpips_tensor(img_bgr, device):
    """Convert a cv2 BGR uint8 image (HxWx3) to a 1x3xHxW RGB float tensor in [-1, 1]."""
    # LPIPS networks were trained on RGB input; cv2.imread yields BGR.
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    tensor = torch.from_numpy(img_rgb).permute(2, 0, 1).unsqueeze(0).float()
    # lpips.LPIPS (default normalize=False) expects inputs scaled to [-1, 1].
    return (tensor / 127.5 - 1).to(device)


def calculate_metrics(pred_dir, gt_dir):
    """Compute PSNR, SSIM, RMSE and LPIPS between predicted and ground-truth images.

    Only ``.bmp`` files (case-insensitive extension) that exist under the same
    name in both directories are compared. A file is skipped, with a printed
    message, if it cannot be read, if the shapes differ, or if any metric
    computation raises. Skipped files contribute to none of the averages, so
    every metric is averaged over the same set of images.

    Args:
        pred_dir (str): Directory containing the predicted images.
        gt_dir (str): Directory containing the ground-truth images.

    Returns:
        dict | None: Mean ``PSNR``, ``SSIM``, ``RMSE``, ``LPIPS_alex``,
        ``LPIPS_vgg`` and ``LPIPS_squeeze`` values, or ``None`` when there are
        no common files or no image could be processed.
    """
    pred_files = {f for f in os.listdir(pred_dir) if f.lower().endswith(".bmp")}
    gt_files = {f for f in os.listdir(gt_dir) if f.lower().endswith(".bmp")}
    common_files = sorted(pred_files & gt_files)

    if not common_files:
        print("No common BMP files found between the two directories.")
        return None

    # Choose the device once and build each LPIPS network a single time.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    lpips_nets = {
        "LPIPS_alex": lpips.LPIPS(net="alex").to(device),
        "LPIPS_vgg": lpips.LPIPS(net="vgg").to(device),
        "LPIPS_squeeze": lpips.LPIPS(net="squeeze").to(device),
    }
    values = {name: [] for name in ("PSNR", "SSIM", "RMSE", *lpips_nets)}

    for filename in common_files:
        pred_path = os.path.join(pred_dir, filename)
        gt_path = os.path.join(gt_dir, filename)

        try:
            pred_img = cv2.imread(pred_path, cv2.IMREAD_COLOR)
            gt_img = cv2.imread(gt_path, cv2.IMREAD_COLOR)

            if pred_img is None or gt_img is None:
                print(f"Skipping {filename} due to read error.")
                continue

            if pred_img.shape != gt_img.shape:
                print(f"Skipping {filename} due to the different image shape")
                continue

            # compute_ssim is evaluated on grayscale versions of the images.
            pred_gray = cv2.cvtColor(pred_img, cv2.COLOR_BGR2GRAY)
            gt_gray = cv2.cvtColor(gt_img, cv2.COLOR_BGR2GRAY)

            # RMSE on images scaled to [0, 1].
            rmse = np.sqrt(
                np.mean(
                    (
                        gt_img.astype(np.float64) / 255.0
                        - pred_img.astype(np.float64) / 255.0
                    )
                    ** 2
                )
            )
            scores = {
                "PSNR": psnr(gt_img, pred_img, data_range=255),
                "SSIM": compute_ssim(gt_gray, pred_gray),
                "RMSE": rmse,
            }

            gt_tensor = _to_lpips_tensor(gt_img, device)
            pred_tensor = _to_lpips_tensor(pred_img, device)
            # Evaluation only: no autograd graph is needed.
            with torch.no_grad():
                for name, net in lpips_nets.items():
                    scores[name] = net(gt_tensor, pred_tensor).item()

        except Exception as e:
            # Best effort: report and move on to the next file.
            print(f"Error processing {filename}: {e}")
            continue

        # Record only once every metric succeeded so all lists stay aligned.
        for name, value in scores.items():
            values[name].append(value)

    if not values["PSNR"]:  # Nothing was successfully processed.
        print("No images were successfully processed.")
        return None

    return {name: np.mean(vals) for name, vals in values.items()}


def _select_checkpoints(ckpt_path, num_ckpts, select_by):
    """Return the file names of the ``num_ckpts`` checkpoints ranked by filename loss."""
    ckpt_files = [f for f in os.listdir(ckpt_path) if f.endswith(".ckpt")]
    if not ckpt_files:
        raise ValueError(f"No checkpoint found in {ckpt_path}")

    # Parse the "loss_<float>" tag embedded in each checkpoint name.
    parsed = []
    for fname in ckpt_files:
        match = re.search(r"loss_([0-9]+\.[0-9]+)", fname)
        if match:
            parsed.append((fname, float(match.group(1))))

    if not parsed:
        raise ValueError("No valid checkpoint files with loss in name")

    parsed.sort(key=lambda item: item[1], reverse=(select_by == "max_loss"))
    return [fname for fname, _ in parsed[:num_ckpts]]


def _load_model(config, ckpt_file, device):
    """Build a FusionNet from ``config``, load ``ckpt_file`` weights and switch to eval mode."""
    model = FusionNet(
        dim=config.dim,
        n_blocks=config.n_blocks,
        upscaling_factor=config.upscaling_factor,
        fmb_params={
            "smfa_growth": config.smfa_growth,
            "pcfn_growth": config.pcfn_growth,
            "snfa_dropout": config.snfa_dropout,
            "pcfn_dropout": config.pcfn_dropout,
            "p_rate": config.p_rate,
        },
    ).cuda(device)

    # weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
    state_dict = torch.load(
        ckpt_file,
        map_location=f"cuda:{device}",
        weights_only=False,
    )["state_dict"]

    # Strip only the leading "model." wrapper prefix (str.replace would also
    # mangle any inner "...model." segment), and drop keys containing "loss"
    # (presumably the training criterion's state, which FusionNet does not own).
    prefix = "model."
    cleaned = {}
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        if "loss" not in new_key:
            cleaned[new_key] = value

    model.load_state_dict(cleaned, strict=True)
    model.eval()
    return model


def _predict_ensemble_tta(models, img_rgb, device):
    """Average 4-way flip-TTA predictions of every model; return an HxWxC uint8 RGB image."""
    orig = valid_transform(image=img_rgb)["image"] / 127.5 - 1  # CHW in [-1, 1]

    # (dims flipped on the CHW input, dims flipped back on the NCHW output)
    tta_flips = [
        ((), ()),  # identity
        ((2,), (3,)),  # horizontal flip
        ((1,), (2,)),  # vertical flip
        ((1, 2), (2, 3)),  # horizontal + vertical flip
    ]
    # Build and upload the four variants once; every model reuses them.
    inputs = [
        (torch.flip(orig, in_dims) if in_dims else orig).unsqueeze(0).cuda(device)
        for in_dims, _ in tta_flips
    ]

    all_preds = []
    with torch.no_grad():
        for model in models:
            model_preds = []
            for inp, (_, out_dims) in zip(inputs, tta_flips):
                output = model(inp)
                # Undo the input flip so all predictions are spatially aligned.
                model_preds.append(torch.flip(output, out_dims) if out_dims else output)
            # Per-model TTA average.
            all_preds.append(torch.mean(torch.stack(model_preds), dim=0))

    # Cross-model average.
    final_pred = torch.mean(torch.stack(all_preds), dim=0)

    # squeeze(0) removes only the batch axis; a bare squeeze() would also drop
    # a singleton channel axis and break the permute.
    final_pred = final_pred.squeeze(0).permute(1, 2, 0).cpu().numpy()
    return (final_pred * 127.5 + 127.5).clip(0, 255).astype("uint8")


def _zip_results(out_dir, ckpt_path, num_ckpts, select_by):
    """Add every ``.bmp`` in ``out_dir`` to ``../../../PBVS/<zip_name>`` (relative to ``out_dir``)."""
    import subprocess

    ckpt_tag = os.path.basename(os.path.normpath(ckpt_path))
    zip_name = f"PRED_{ckpt_tag}_top{num_ckpts}_{select_by}.zip"
    bmp_files = sorted(f for f in os.listdir(out_dir) if f.endswith(".bmp"))
    # An argument list with cwd= replaces the interpolated shell string, so
    # paths containing spaces or quotes can neither break nor inject commands.
    try:
        result = subprocess.run(
            ["zip", os.path.join("../../../PBVS", zip_name), *bmp_files],
            cwd=out_dir,
            check=False,
        )
    except FileNotFoundError:
        print("zip executable not found; skipping archive creation.")
        return
    if result.returncode != 0:
        print(f"zip exited with status {result.returncode}")


def perform_inference(
    config,
    ckpt_path,
    test_dir="/media/hdd/sonwe1e/Competition/PBVS_Thermal/Data/valid/CUBIC_x8",
    out_dir=None,
    device=0,
    num_ckpts=1,  # number of checkpoints to ensemble
    select_by="max_loss",  # checkpoint ranking: "max_loss" or "min_loss"
):
    """Run multi-checkpoint, flip-TTA inference over every ``.bmp`` in ``test_dir``.

    Checkpoint ranking and loading:
    - Checkpoints in ``ckpt_path`` are ranked by the ``loss_<float>`` tag in
      their file names.
    - The top ``num_ckpts`` are loaded as ``FusionNet`` models.

    Prediction:
    - Each image is predicted with four flip variants per model.
    - Outputs are un-flipped, averaged per model, then averaged across models.
    - The result is written to ``out_dir`` under the same file name.

    Finally all ``.bmp`` files in ``out_dir`` are zipped.

    Args:
        config: Object exposing the FusionNet hyper-parameters (``dim``,
            ``n_blocks``, ``upscaling_factor``, ``smfa_growth``,
            ``pcfn_growth``, ``snfa_dropout``, ``pcfn_dropout``, ``p_rate``).
        ckpt_path (str): Directory containing ``*.ckpt`` files.
        test_dir (str): Directory of input ``.bmp`` images.
        out_dir (str | None): Output directory. Defaults to a ``PRED``
            directory next to ``test_dir``.
        device (int): CUDA device index.
        num_ckpts (int): Number of checkpoints to use (sorted by loss).
        select_by (str): ``"max_loss"`` keeps the largest losses;
            ``"min_loss"`` keeps the smallest.

    Returns:
        str: ``out_dir``.

    Raises:
        ValueError: Raised in any of these cases:
            - ``select_by`` or ``num_ckpts`` is invalid.
            - ``ckpt_path`` has no ``.ckpt`` file.
            - No checkpoint name carries a ``loss_<float>`` tag.
    """
    if select_by not in ("max_loss", "min_loss"):
        raise ValueError(f"select_by must be 'max_loss' or 'min_loss', got {select_by!r}")
    if num_ckpts < 1:
        raise ValueError(f"num_ckpts must be >= 1, got {num_ckpts}")

    if out_dir is None:
        # normpath drops a trailing separator so dirname() yields the parent.
        out_dir = os.path.join(os.path.dirname(os.path.normpath(test_dir)), "PRED")
    os.makedirs(out_dir, exist_ok=True)

    # 1. Pick checkpoints and load one model per checkpoint.
    selected_files = _select_checkpoints(ckpt_path, num_ckpts, select_by)
    models = [
        _load_model(config, os.path.join(ckpt_path, fname), device)
        for fname in selected_files
    ]

    # 2. Multi-model + multi-TTA inference.
    for img_name in tqdm(sorted(os.listdir(test_dir))):
        if not img_name.endswith(".bmp"):
            continue

        img_bgr = cv2.imread(os.path.join(test_dir, img_name))
        if img_bgr is None:
            print(f"Skipping {img_name}: could not read image.")
            continue
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

        pred = _predict_ensemble_tta(models, img_rgb, device)
        if pred.ndim == 3 and pred.shape[2] == 3:
            # The model sees RGB input, but cv2.imwrite expects BGR.
            pred = cv2.cvtColor(pred, cv2.COLOR_RGB2BGR)
        cv2.imwrite(os.path.join(out_dir, img_name), pred)

    # 3. Archive the results.
    print("Zipping results...")
    _zip_results(out_dir, ckpt_path, num_ckpts, select_by)
    print(f"Results saved to: {out_dir}")
    return out_dir




In [8]:
# Example configuration.
class SMFANetConfig:
    """Hyper-parameters read by ``perform_inference`` to build ``FusionNet``.

    The defaults are the dim=192 / 12-block configuration. Every value can be
    overridden by keyword; for example, the smaller variant that used to be
    commented out is ``SMFANetConfig(dim=96, n_blocks=8, pcfn_growth=4)``.
    """

    def __init__(
        self,
        dim=192,
        n_blocks=12,
        pcfn_growth=8,
        upscaling_factor=8,
        smfa_growth=4,
        snfa_dropout=0.0,
        pcfn_dropout=0.12,
        p_rate=0.25,
    ):
        self.dim = dim
        self.n_blocks = n_blocks
        self.pcfn_growth = pcfn_growth
        self.upscaling_factor = upscaling_factor
        self.smfa_growth = smfa_growth
        self.snfa_dropout = snfa_dropout
        self.pcfn_dropout = pcfn_dropout
        self.p_rate = p_rate


opt = SMFANetConfig()
perform_inference(
    opt,
    ckpt_path="./checkpoints/v3_dim192block12gr8dp0.16_df2kost300gray",
    num_ckpts=3,  # ensemble the 3 checkpoints with the largest loss in their names
    select_by="max_loss",
)


100%|██████████| 20/20 [00:06<00:00,  3.17it/s]


Zipping results...
updating: 001_01_D2_th.bmp (deflated 80%)
updating: 002_01_D4_th.bmp (deflated 72%)
updating: 002_02_D1_th.bmp (deflated 77%)
updating: 004_02_D1_th.bmp (deflated 85%)
updating: 006_01_D4_th.bmp (deflated 74%)
updating: 008_01_D4_th.bmp (deflated 76%)
updating: 009_01_D2_th.bmp (deflated 80%)
updating: 009_02_D1_th.bmp (deflated 79%)
updating: 013_02_D1_th.bmp (deflated 83%)
updating: 014_01_D4_th.bmp (deflated 75%)
updating: 015_01_D2_th.bmp (deflated 75%)
updating: 016_02_D1_th.bmp (deflated 83%)
updating: 017_02_D1_th.bmp (deflated 80%)
updating: 019_01_D1_th.bmp (deflated 78%)
updating: 021_01_D2_th.bmp (deflated 76%)
updating: 026_01_D3_th.bmp (deflated 76%)
updating: 026_02_D1_th.bmp (deflated 76%)
updating: 027_01_D1_th.bmp (deflated 84%)
updating: 031_02_D1_th.bmp (deflated 82%)
updating: 034_01_D1_th.bmp (deflated 74%)
Results saved to: /media/hdd/sonwe1e/Competition/PBVS_Thermal/Data/valid/PRED


'/media/hdd/sonwe1e/Competition/PBVS_Thermal/Data/valid/PRED'

In [9]:
pred_dir = "/media/hdd/sonwe1e/Competition/PBVS_Thermal/Data/valid/PRED"
gt_dir = "/media/hdd/sonwe1e/Competition/PBVS_Thermal/Data/valid/GT"

# Score the saved predictions against the ground truth. A None result means
# nothing could be compared; calculate_metrics has already printed why.
metrics = calculate_metrics(pred_dir, gt_dir)

if metrics:
    report_lines = ["Average Metrics:"]
    report_lines.extend(f"{key}: {value:.4f}" for key, value in metrics.items())
    print("\n".join(report_lines))


Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /home/ubuntu/miniconda3/envs/swv2/lib/python3.11/site-packages/lpips/weights/v0.1/alex.pth
Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]
Loading model from: /home/ubuntu/miniconda3/envs/swv2/lib/python3.11/site-packages/lpips/weights/v0.1/vgg.pth
Setting up [LPIPS] perceptual loss: trunk [squeeze], v[0.1], spatial [off]
Loading model from: /home/ubuntu/miniconda3/envs/swv2/lib/python3.11/site-packages/lpips/weights/v0.1/squeeze.pth
Average Metrics:
PSNR: 27.4301
SSIM: 0.8313
RMSE: 0.0443
LPIPS_alex: 0.2716
LPIPS_vgg: 0.3615
LPIPS_squeeze: 0.2374


In [39]:
import heavyball
heavyball.utils.compile_mode = False