## 预测


In [None]:
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import os
import re
from tqdm import tqdm
from tools.models.competition_backup import FusionNet
import numpy as np
from skimage.metrics import peak_signal_noise_ratio as psnr

# Validation-time preprocessing: the pipeline contains only ToTensorV2, so no
# augmentation or normalisation happens here; value scaling is done by the callers.
valid_transform = A.Compose([ToTensorV2()])

# Expose only physical GPU "2" to this process. The `device=0` / `.cuda(0)`
# indices used further down are therefore relative to this visible set.
# NOTE(review): this runs after `import torch`; presumably fine as long as CUDA
# has not been initialised yet — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"


def compute_ssim(img1, img2):
    """
    Return the mean SSIM index between two equally-shaped images.

    Local means, variances and the covariance are estimated with an 11x11
    Gaussian window (sigma 1.5). The stabilising constants assume an 8-bit
    dynamic range (0-255). Callers in this file pass grayscale images.

    Args:
        img1: First image as a numpy array.
        img2: Second image as a numpy array with the same shape as ``img1``.

    Returns:
        The SSIM map averaged over all positions.

    Raises:
        ValueError: If the two arrays do not have the same shape.
    """
    if img1.shape != img2.shape:
        raise ValueError("Input images must have the same dimensions")

    def _smooth(arr):
        # Gaussian-weighted local average used for every statistic below.
        return cv2.GaussianBlur(arr, (11, 11), 1.5)

    # Work in float32 so the squares and products below do not overflow uint8.
    x = img1.astype(np.float32)
    y = img2.astype(np.float32)

    # Stabilising constants for a 0-255 dynamic range.
    c1 = (0.01 * 255) ** 2
    c2 = (0.03 * 255) ** 2

    # Local means and their products.
    mean_x = _smooth(x)
    mean_y = _smooth(y)
    mean_xx = mean_x**2
    mean_yy = mean_y**2
    mean_xy = mean_x * mean_y

    # Local variances and covariance: E[ab] - E[a]E[b].
    var_x = _smooth(x**2) - mean_xx
    var_y = _smooth(y**2) - mean_yy
    cov_xy = _smooth(x * y) - mean_xy

    top = (2 * mean_xy + c1) * (2 * cov_xy + c2)
    bottom = (mean_xx + mean_yy + c1) * (var_x + var_y + c2)
    return np.mean(top / bottom)


def _extract_model_state(state_dict):
    """Return a new dict with only the network weights, re-keyed for FusionNet.

    Keys that do not contain "model" are dropped, every "model." occurrence is
    removed from the remaining keys, and keys that still contain "loss" after
    that are dropped as well. The input mapping is left untouched.
    """
    cleaned = {}
    for key, value in state_dict.items():
        if "model" not in key:
            continue
        new_key = key.replace("model.", "")
        if "loss" in new_key:
            continue
        cleaned[new_key] = value
    return cleaned


def perform_inference(
    config,
    ckpt_path,
    test_dir="/media/hdd/sonwe1e/Competition/PBVS_Thermal/Data/valid/CUBIC_x8",
    out_dir=None,
    device=0,
    num_ckpts=1,
    select_by="max_loss",
):
    """
    Run ensemble inference with test-time augmentation and zip the results.

    Every ``.ckpt`` file found in ``ckpt_path`` is loaded into its own
    FusionNet. Each ``.bmp`` image in ``test_dir`` is predicted by every model
    under seven views (identity, three flips, three rotations); predictions
    are mapped back, averaged per model and then across models, and written to
    ``out_dir`` under the input file name.

    Args:
        config: Object exposing ``dim``, ``n_blocks``, ``upscaling_factor``,
            ``smfa_growth``, ``pcfn_growth``, ``snfa_dropout``,
            ``pcfn_dropout`` and ``p_rate``.
        ckpt_path (str): Directory containing the ``.ckpt`` files.
        test_dir (str): Directory with the input ``.bmp`` images.
        out_dir (str | None): Output directory. Defaults to a ``PRED`` folder
            next to ``test_dir``.
        device (int): CUDA device index used for the models and inputs.
        num_ckpts (int): Currently only used in the archive file name.
            NOTE: no loss-based selection is performed; all checkpoints in
            ``ckpt_path`` are used.
        select_by (str): Currently only used in the archive file name.

    Returns:
        str: The directory the predictions were written to.

    Raises:
        FileNotFoundError: If ``ckpt_path`` contains no ``.ckpt`` file.
    """
    # Local import: this function is the only user of subprocess in the file.
    import subprocess

    if out_dir is None:
        out_dir = os.path.join(os.path.dirname(test_dir), "PRED")
    os.makedirs(out_dir, exist_ok=True)

    # 1. Collect checkpoint files ------------------------------------------
    # Sorted only to make the load order reproducible; the ensemble mean does
    # not depend on it.
    ckpt_files = sorted(f for f in os.listdir(ckpt_path) if f.endswith(".ckpt"))
    if not ckpt_files:
        # Fail early instead of hitting an opaque torch.stack([]) error later.
        raise FileNotFoundError(f"No .ckpt files found in {ckpt_path}")

    # 2. Load one model per checkpoint --------------------------------------
    models = []
    for fname in ckpt_files:
        model = FusionNet(
            dim=config.dim,
            n_blocks=config.n_blocks,
            upscaling_factor=config.upscaling_factor,
            fmb_params={
                "smfa_growth": config.smfa_growth,
                "pcfn_growth": config.pcfn_growth,
                "snfa_dropout": config.snfa_dropout,
                "pcfn_dropout": config.pcfn_dropout,
                "p_rate": config.p_rate,
            },
        ).cuda(device)

        # SECURITY: weights_only=False unpickles arbitrary objects; only load
        # checkpoints from a trusted source.
        checkpoint = torch.load(
            os.path.join(ckpt_path, fname),
            map_location=f"cuda:{device}",
            weights_only=False,
        )
        model.load_state_dict(
            _extract_model_state(checkpoint["state_dict"]), strict=True
        )
        model.eval()
        models.append(model)

    # 3. Inference (all models x all TTA views) -----------------------------
    for img_name in tqdm(os.listdir(test_dir)):
        if not img_name.endswith(".bmp"):
            continue

        bgr = cv2.imread(os.path.join(test_dir, img_name))
        if bgr is None:
            # cv2.imread signals failure with None rather than raising.
            print(f"Skipping {img_name}: could not be read as an image.")
            continue
        img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

        # Scale pixel values from [0, 255] to [-1, 1].
        orig = valid_transform(image=img)["image"] / 127.5 - 1

        # Each entry pairs an augmented input (dims 1 and 2 are transformed)
        # with the function that undoes the transform on the batched output
        # (dims 2 and 3).
        # NOTE: flipping both axes equals a 180-degree rotation, so that view
        # is present twice and carries double weight in the mean. Kept as-is
        # to preserve the existing outputs.
        tta_views = [
            (orig, lambda t: t),
            (torch.flip(orig, [2]), lambda t: torch.flip(t, [3])),
            (torch.flip(orig, [1]), lambda t: torch.flip(t, [2])),
            (torch.flip(orig, [1, 2]), lambda t: torch.flip(t, [2, 3])),
            (torch.rot90(orig, 1, [1, 2]), lambda t: torch.rot90(t, 3, [2, 3])),
            (torch.rot90(orig, 2, [1, 2]), lambda t: torch.rot90(t, 2, [2, 3])),
            (torch.rot90(orig, 3, [1, 2]), lambda t: torch.rot90(t, 1, [2, 3])),
        ]

        all_preds = []
        with torch.no_grad():
            for model in models:
                restored = [
                    undo(model(view.unsqueeze(0).cuda(device)))
                    for view, undo in tta_views
                ]
                # Mean over the TTA views of a single model.
                all_preds.append(torch.mean(torch.stack(restored), dim=0))

        # Mean over the models.
        final_pred = torch.mean(torch.stack(all_preds), dim=0)

        # Drop only the batch dimension so a single-channel output keeps its
        # channel axis for the permute below.
        final_pred = final_pred.squeeze(0)
        final_pred = (
            (final_pred.permute(1, 2, 0).cpu().numpy() * 127.5 + 127.5)
            .clip(0, 255)
            .astype("uint8")
        )
        # NOTE(review): the input was converted BGR->RGB but the output is
        # written without converting back; presumably harmless for
        # equal-channel thermal images — confirm.
        cv2.imwrite(os.path.join(out_dir, img_name), final_pred)

    # 4. Zip the results -----------------------------------------------------
    print("Zipping results...")
    # normpath keeps the directory name when ckpt_path ends with a separator.
    ckpt_tag = os.path.basename(os.path.normpath(ckpt_path))
    zip_name = f"PRED_{ckpt_tag}_top{num_ckpts}_{select_by}.zip"
    # Interpreted relative to out_dir, which is the working directory of zip.
    zip_target = os.path.join("../../../PBVS", zip_name)
    bmp_names = sorted(
        name
        for name in os.listdir(out_dir)
        if name.endswith(".bmp") and not name.startswith(".")
    )
    if not bmp_names:
        print("No .bmp files to zip.")
    else:
        try:
            # Argument list + cwd instead of a shell string: paths containing
            # spaces or shell metacharacters are passed through verbatim.
            result = subprocess.run(
                ["zip", zip_target, *bmp_names], cwd=out_dir, check=False
            )
            if result.returncode != 0:
                print(f"zip exited with status {result.returncode}")
        except OSError as err:
            # Best effort, as before: a missing zip binary must not discard
            # the predictions that were already written.
            print(f"Could not run zip: {err}")

    print(f"Results saved to: {out_dir}")
    return out_dir


def calculate_metrics(pred_dir, gt_dir):
    """
    Average PSNR, SSIM and RMSE over the BMP files present in both directories.

    Files are matched by name. A file is skipped (with a printed message) when
    either image cannot be read, when the two shapes differ, or when
    processing it raises. LPIPS is not computed.

    Args:
        pred_dir (str): Directory with the predicted images.
        gt_dir (str): Directory with the ground-truth images.

    Returns:
        dict | None: ``{"PSNR": ..., "SSIM": ..., "RMSE": ...}`` with the mean
        of each metric, or ``None`` when there is no common file or no file
        produced a PSNR value.
    """

    def _bmp_names(folder):
        # Extension check is case-insensitive; the names themselves are not.
        return {name for name in os.listdir(folder) if name.lower().endswith(".bmp")}

    shared_names = sorted(_bmp_names(pred_dir) & _bmp_names(gt_dir))
    if len(shared_names) == 0:
        print("No common BMP files found between the two directories.")
        return None

    scores = {"PSNR": [], "SSIM": [], "RMSE": []}

    for filename in shared_names:
        try:
            pred_img = cv2.imread(os.path.join(pred_dir, filename), cv2.IMREAD_COLOR)
            gt_img = cv2.imread(os.path.join(gt_dir, filename), cv2.IMREAD_COLOR)

            if pred_img is None or gt_img is None:
                print(f"Skipping {filename} due to read error.")
                continue

            if pred_img.shape != gt_img.shape:
                print(f"Skipping {filename} due to the different image shape")
                continue

            # SSIM is evaluated on grayscale versions through compute_ssim.
            pred_gray = cv2.cvtColor(pred_img, cv2.COLOR_BGR2GRAY)
            gt_gray = cv2.cvtColor(gt_img, cv2.COLOR_BGR2GRAY)

            # PSNR on the 8-bit colour images.
            scores["PSNR"].append(psnr(gt_img, pred_img, data_range=255))

            scores["SSIM"].append(compute_ssim(gt_gray, pred_gray))

            # RMSE on intensities rescaled to [0, 1].
            gt_unit = gt_img.astype(np.float64) / 255.0
            pred_unit = pred_img.astype(np.float64) / 255.0
            scores["RMSE"].append(np.sqrt(np.mean((gt_unit - pred_unit) ** 2)))

        except Exception as e:
            print(f"Error processing {filename}: {e}")
            continue

    if len(scores["PSNR"]) == 0:
        print("No images were successfully processed.")
        return None

    return {metric: np.mean(values) for metric, values in scores.items()}


In [28]:
# Example configuration
class SMFANetConfig:
    """Hyper-parameters read by the FusionNet constructor calls in this file."""

    def __init__(self):
        # Network size. An alternative smaller setting is kept for reference:
        # dim=96, n_blocks=8, pcfn_growth=4.
        self.dim, self.n_blocks = 192, 12
        self.upscaling_factor = 8

        # Growth factors forwarded through fmb_params.
        self.smfa_growth, self.pcfn_growth = 4, 8

        # Dropout rates and p_rate forwarded through fmb_params.
        self.snfa_dropout, self.pcfn_dropout = 0.0, 0.12
        self.p_rate = 0.25


opt = SMFANetConfig()
# Run the ensemble + TTA inference with the checkpoints stored in
# ./checkpoints/selected; predictions go to the default output directory.
perform_inference(
    opt,
    ckpt_path="./checkpoints/selected",
    num_ckpts=10,  # NOTE(review): the original comment said "use the 2 checkpoints with the largest loss", which does not match the value 10 — confirm intent
    select_by="max_loss",
)


100%|██████████| 20/20 [00:23<00:00,  1.18s/it]


Zipping results...
updating: 001_01_D2_th.bmp (deflated 84%)
updating: 002_01_D4_th.bmp (deflated 77%)
updating: 002_02_D1_th.bmp (deflated 81%)
updating: 004_02_D1_th.bmp (deflated 88%)
updating: 006_01_D4_th.bmp (deflated 80%)
updating: 008_01_D4_th.bmp (deflated 82%)
updating: 009_01_D2_th.bmp (deflated 84%)
updating: 009_02_D1_th.bmp (deflated 83%)
updating: 013_02_D1_th.bmp (deflated 86%)
updating: 014_01_D4_th.bmp (deflated 79%)
updating: 015_01_D2_th.bmp (deflated 79%)
updating: 016_02_D1_th.bmp (deflated 86%)
updating: 017_02_D1_th.bmp (deflated 83%)
updating: 019_01_D1_th.bmp (deflated 83%)
updating: 021_01_D2_th.bmp (deflated 80%)
updating: 026_01_D3_th.bmp (deflated 80%)
updating: 026_02_D1_th.bmp (deflated 80%)
updating: 027_01_D1_th.bmp (deflated 88%)
updating: 031_02_D1_th.bmp (deflated 86%)
updating: 034_01_D1_th.bmp (deflated 78%)
Results saved to: /media/hdd/sonwe1e/Competition/PBVS_Thermal/Data/valid/PRED


'/media/hdd/sonwe1e/Competition/PBVS_Thermal/Data/valid/PRED'

In [29]:
pred_dir = "/media/hdd/sonwe1e/Competition/PBVS_Thermal/Data/valid/PRED"
gt_dir = "/media/hdd/sonwe1e/Competition/PBVS_Thermal/Data/valid/GT"

metrics = calculate_metrics(pred_dir, gt_dir)

if metrics:
    print(f"Average Metrics:")
    for key, value in metrics.items():
        print(f"{key}: {value:.4f}  ")


Average Metrics:
PSNR: 27.4827  
SSIM: 0.8336  
RMSE: 0.0440  


In [None]:
import glob
import pandas as pd


def load_model(config, ckpt_file, device=0):
    """
    Build a FusionNet from ``config`` and load the weights stored in ``ckpt_file``.

    Args:
        config: Object exposing ``dim``, ``n_blocks``, ``upscaling_factor``,
            ``smfa_growth``, ``pcfn_growth``, ``snfa_dropout``,
            ``pcfn_dropout`` and ``p_rate``.
        ckpt_file (str): Path to a checkpoint whose top-level mapping has a
            ``"state_dict"`` entry.
        device (int): CUDA device index the model is moved to. Defaults to 0,
            which is what this function previously hard-coded.

    Returns:
        The model in eval mode on the requested CUDA device.
    """
    model = FusionNet(
        dim=config.dim,
        n_blocks=config.n_blocks,
        upscaling_factor=config.upscaling_factor,
        fmb_params={
            "smfa_growth": config.smfa_growth,
            "pcfn_growth": config.pcfn_growth,
            "snfa_dropout": config.snfa_dropout,
            "pcfn_dropout": config.pcfn_dropout,
            "p_rate": config.p_rate,
        },
    ).cuda(device)

    # SECURITY: weights_only=False unpickles arbitrary objects; only load
    # checkpoints from a trusted source.
    checkpoint = torch.load(ckpt_file, map_location="cpu", weights_only=False)

    # Keep only the network weights and rename them to what FusionNet expects:
    # keys without "model" are dropped, "model." is removed from the rest, and
    # keys still containing "loss" afterwards are dropped too. A new dict is
    # built instead of re-keying the loaded one while walking over it.
    state_dict = {}
    for key, value in checkpoint["state_dict"].items():
        if "model" not in key:
            continue
        new_key = key.replace("model.", "")
        if "loss" in new_key:
            continue
        state_dict[new_key] = value

    # strict=True: any missing or unexpected key is an error.
    model.load_state_dict(state_dict, strict=True)
    model.eval()
    return model


def infer_from_model(
    model,
    output_path,
    test_dir="/media/hdd/sonwe1e/Competition/PBVS_Thermal/Data/valid/CUBIC_x8",
    device=0,
):
    """
    Predict every image in ``test_dir`` with ``model`` using test-time augmentation.

    Each image is run through the model under seven views (identity, three
    flips, three rotations). The outputs are mapped back, averaged, and written
    to ``output_path`` under the input file name.

    Args:
        model: Callable network taking a batched image tensor on ``device``.
        output_path (str): Directory the predictions are written to; created
            if missing.
        test_dir (str): Directory with the input images. Defaults to the
            validation folder this function previously hard-coded.
        device (int): CUDA device index for the inputs. Defaults to 0, the
            previously hard-coded value.
    """
    out_dir = output_path
    os.makedirs(out_dir, exist_ok=True)

    for img_name in tqdm(os.listdir(test_dir)):
        bgr = cv2.imread(os.path.join(test_dir, img_name))
        if bgr is None:
            # cv2.imread signals failure with None; skip instead of crashing
            # inside cvtColor on a stray non-image file.
            print(f"Skipping {img_name}: could not be read as an image.")
            continue
        img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

        # Scale pixel values from [0, 255] to [-1, 1].
        orig = valid_transform(image=img)["image"] / 127.5 - 1

        # Each entry pairs an augmented input (dims 1 and 2 are transformed)
        # with the function that undoes the transform on the batched output
        # (dims 2 and 3).
        # NOTE: flipping both axes equals a 180-degree rotation, so that view
        # is present twice and carries double weight in the mean. Kept as-is
        # to preserve the existing outputs.
        tta_views = [
            (orig, lambda t: t),
            (torch.flip(orig, [2]), lambda t: torch.flip(t, [3])),
            (torch.flip(orig, [1]), lambda t: torch.flip(t, [2])),
            (torch.flip(orig, [1, 2]), lambda t: torch.flip(t, [2, 3])),
            (torch.rot90(orig, 1, [1, 2]), lambda t: torch.rot90(t, 3, [2, 3])),
            (torch.rot90(orig, 2, [1, 2]), lambda t: torch.rot90(t, 2, [2, 3])),
            (torch.rot90(orig, 3, [1, 2]), lambda t: torch.rot90(t, 1, [2, 3])),
        ]

        with torch.no_grad():
            all_preds = [
                undo(model(view.unsqueeze(0).cuda(device)))
                for view, undo in tta_views
            ]

        # Mean over the TTA views.
        final_pred = torch.mean(torch.stack(all_preds), dim=0)

        # Drop only the batch dimension so a single-channel output keeps its
        # channel axis for the permute below.
        final_pred = final_pred.squeeze(0)
        final_pred = (
            (final_pred.permute(1, 2, 0).cpu().numpy() * 127.5 + 127.5)
            .clip(0, 255)
            .astype("uint8")
        )
        # NOTE(review): the input was converted BGR->RGB but the output is
        # written without converting back; presumably harmless for
        # equal-channel thermal images — confirm.
        cv2.imwrite(os.path.join(out_dir, img_name), final_pred)


def calculate_metrics(pred_dir, gt_dir):
    """
    Average PSNR, SSIM and RMSE over the BMP files present in both directories.

    Files are matched by name. A file is skipped (with a printed message) when
    either image cannot be read, when the two shapes differ, or when
    processing it raises. LPIPS is not computed.

    Args:
        pred_dir (str): Directory with the predicted images.
        gt_dir (str): Directory with the ground-truth images.

    Returns:
        dict | None: ``{"PSNR": ..., "SSIM": ..., "RMSE": ...}`` with the mean
        of each metric, or ``None`` when there is no common file or no file
        produced a PSNR value.
    """
    pred_names = {n for n in os.listdir(pred_dir) if n.lower().endswith(".bmp")}
    gt_names = {n for n in os.listdir(gt_dir) if n.lower().endswith(".bmp")}
    matched = sorted(pred_names.intersection(gt_names))

    if len(matched) == 0:
        print("No common BMP files found between the two directories.")
        return None

    per_image = {"PSNR": [], "SSIM": [], "RMSE": []}

    for filename in matched:
        try:
            pred_img = cv2.imread(os.path.join(pred_dir, filename), cv2.IMREAD_COLOR)
            gt_img = cv2.imread(os.path.join(gt_dir, filename), cv2.IMREAD_COLOR)

            if pred_img is None or gt_img is None:
                print(f"Skipping {filename} due to read error.")
                continue

            if pred_img.shape != gt_img.shape:
                print(f"Skipping {filename} due to the different image shape")
                continue

            # SSIM is evaluated on grayscale versions through compute_ssim.
            pred_gray = cv2.cvtColor(pred_img, cv2.COLOR_BGR2GRAY)
            gt_gray = cv2.cvtColor(gt_img, cv2.COLOR_BGR2GRAY)

            # PSNR on the 8-bit colour images.
            per_image["PSNR"].append(psnr(gt_img, pred_img, data_range=255))

            per_image["SSIM"].append(compute_ssim(gt_gray, pred_gray))

            # RMSE on intensities rescaled to [0, 1].
            gt_unit = gt_img.astype(np.float64) / 255.0
            pred_unit = pred_img.astype(np.float64) / 255.0
            per_image["RMSE"].append(np.sqrt(np.mean((gt_unit - pred_unit) ** 2)))

        except Exception as e:
            print(f"Error processing {filename}: {e}")
            continue

    if len(per_image["PSNR"]) == 0:
        print("No images were successfully processed.")
        return None

    return {metric: np.mean(values) for metric, values in per_image.items()}


class SMFANetConfig:
    """Hyper-parameters read by the FusionNet constructor calls in this file."""

    def __init__(self):
        # Network size.
        self.dim, self.n_blocks = 192, 12
        self.upscaling_factor = 8

        # Growth factors forwarded through fmb_params.
        self.smfa_growth, self.pcfn_growth = 4, 8

        # Dropout rates and p_rate forwarded through fmb_params.
        self.snfa_dropout, self.pcfn_dropout = 0.0, 0.12
        self.p_rate = 0.25


opt = SMFANetConfig()

# Sorted so the evaluation order (and the DataFrame index) is reproducible;
# glob makes no ordering guarantee.
ckpt_files = sorted(glob.glob("./checkpoints/v3_fintune_attn_res_df2k/*.ckpt"))
metrics_list = []

valid_pred_dir = "/media/hdd/sonwe1e/Competition/PBVS_Thermal/Data/valid/PRED"
valid_gt_dir = "/media/hdd/sonwe1e/Competition/PBVS_Thermal/Data/valid/GT"

# Evaluate every checkpoint on the validation set, then rename it so the file
# name carries its scores.
for ckpt_file in ckpt_files:
    model = load_model(opt, ckpt_file)
    infer_from_model(model, valid_pred_dir)
    metrics = calculate_metrics(valid_pred_dir, valid_gt_dir)

    if not metrics:
        # calculate_metrics returns None when nothing could be scored; the
        # rename below needs the scores, so leave this checkpoint untouched.
        print(f"No metrics for {ckpt_file}; file name left unchanged.")
        continue

    metrics["checkpoint"] = os.path.basename(ckpt_file)
    metrics_list.append(metrics)

    # Keep whatever precedes the first "ep" in the FILE name and append
    # "<PSNR-20>_<SSIM>.ckpt". Splitting only the base name (not the whole
    # path) keeps a directory containing "ep" from corrupting the target.
    ckpt_dir, ckpt_name = os.path.split(ckpt_file)
    scored_name = (
        ckpt_name.split("ep")[0]
        + f"{(metrics['PSNR'] - 20):.4f}_{metrics['SSIM']:.4f}.ckpt"
    )
    os.rename(ckpt_file, os.path.join(ckpt_dir, scored_name))

# Create DataFrame from metrics
if metrics_list:
    metrics_df = pd.DataFrame(metrics_list)
    print(metrics_df.sort_values(by="PSNR", ascending=False))  # Sort by PSNR descending
else:
    print("No valid metrics collected")

100%|██████████| 20/20 [00:04<00:00,  4.75it/s]
100%|██████████| 20/20 [00:04<00:00,  4.84it/s]
100%|██████████| 20/20 [00:04<00:00,  4.87it/s]
100%|██████████| 20/20 [00:04<00:00,  4.78it/s]


        PSNR      SSIM      RMSE                 checkpoint
0  27.479439  0.833031  0.044069  epoch_83-loss_26.746.ckpt
3  27.458195  0.832917  0.044171  epoch_91-loss_26.738.ckpt
2  27.450718  0.831366  0.044199  epoch_66-loss_26.740.ckpt
1  27.448583  0.833455  0.044188  epoch_89-loss_26.746.ckpt


: 