In [1]:
import pandas as pd
import os
import sklearn
from PIL import Image
import numpy as np
import logging


class WildScenesDataset:
    root_dirs = [
        os.path.join('..',  'WildScenes2d', 'V-01'),
        os.path.join('..',  'WildScenes2d', 'V-02'),
        os.path.join('..',  'WildScenes2d', 'V-03')
    ]
    _data_list_dir = os.path.join('datasets', 'data_list')
    _csv_files = {
        'train': os.path.join(_data_list_dir, 'train.csv'),
        'valid': os.path.join(_data_list_dir, 'valid.csv'),
        'test': os.path.join(_data_list_dir, 'test.csv'),
    }
    csv = _csv_files
    _label_to_trainid = {
        0: 15,  # 背景类
        1: 16,  # 忽略类
        2: 0,  # Bush
        3: 1,  # Dirt
        4: 2,  # Fence
        5: 3,  # Grass
        6: 4,  # Gravel
        7: 5,  # Log
        8: 6,  # Mud
        9: 7,  # Other-Object
        10: 8,  # Other-terrain
        11: 16,  # 忽略类
        12: 9,  # Rock
        13: 10,  # Sky
        14: 11,  # Structure
        15: 12,  # Tree-foliage
        16: 13,  # Tree-trunk
        17: 16,  # 忽略类
        18: 14,  # Water
    }
    def __init__(self, dataset_type, transform=None):
        assert dataset_type in ('train', 'valid', 'test')
        self._dataset_type = dataset_type
        self._data_frame = pd.read_csv(WildScenesDataset._csv_files[self._dataset_type])
        self._transform = transform

    def __len__(self):
        return len(self._data_frame)

    def __getitem__(self, index):
        if index >= len(self):
            raise IndexError(f"Index {index} out of bounds for dataset of size {len(self)}")
        try:
            image_path = self._data_frame['image'].iloc[index]
            label_path = self._data_frame['label'].iloc[index]

            image = Image.open(image_path).convert('RGB')
            label = Image.open(label_path).convert('L')
            # 将灰度图转换为Numpy数组
            label_np = np.array(label)

            # 将标签索引映射到训练标识（trainId）值
            label_trainId = np.vectorize(lambda x: self._label_to_trainid.get(x, 255))(label_np)

            if self._transform is not None:
                for t in self._transform:
                    image, label_trainId = t(image, label_trainId)

            return image, label_trainId
        except Exception as e:
            logging.error(f"Error loading item at index {index}: {str(e)}")

    @staticmethod
    def _get_image_label_dir():
        """
        Traverse server image and label directories and yield image and label paths.
        :return: Generator yielding (image path, label path)
        """
        data_err = 'data error. check!'
        image_base = WildScenesDataset.image_file_base
        label_base = WildScenesDataset.label_file_base

        for image in os.listdir(image_base):
            image_origin = os.path.join(image_base, image)
            image_label = os.path.join(label_base, image)

            if not (os.path.isfile(image_label) and
                    os.path.exists(image_label) and
                    os.path.isfile(image_label)):
                print(image_origin, image_label, data_err)  # Print error message and skip if paths are invalid
                continue

            yield image_origin, image_label

    @staticmethod
    def make_data_list(train_rate=0.7, valid_rate=0.2, shuffle=True):
        """
        Shuffle and generate data_list CSV files with image and label paths sorted by filename.
        :param train_rate: Training set ratio, default 0.7
        :param valid_rate: Validation set ratio, default 0.2
        :param shuffle: Whether to shuffle the dataset, default True
        :return: None
        """
        # Ensure the directory exists
        data_list_dir = WildScenesDataset._data_list_dir
        if not os.path.exists(data_list_dir):
            os.makedirs(data_list_dir)
            print(f"Created directory: {data_list_dir}")

        all_image_paths = []
        all_label_paths = []

        # Traverse all root directories to collect image and label paths
        for root_dir in WildScenesDataset.root_dirs:
            image_base = os.path.join(root_dir, 'image')
            label_base = os.path.join(root_dir, 'indexLabel')

            if not os.path.exists(image_base) or not os.path.exists(label_base):
                print(f"Error: Image or label directory does not exist in {root_dir}")
                continue

            for image in os.listdir(image_base):
                image_origin = os.path.join(image_base, image)
                image_label = os.path.join(label_base, image)

                if not (os.path.isfile(image_label) and os.path.exists(image_label)):
                    print(f"Warning: Skipping invalid file pair {image_origin}, {image_label}")
                    continue

                all_image_paths.append(image_origin)
                all_label_paths.append(image_label)

        # Create DataFrame with image and label paths
        df = pd.DataFrame(
            data={'image': all_image_paths, 'label': all_label_paths}
        )

        # Sort DataFrame by filename (assumed to be timestamp in a sortable format)
        df['timestamp'] = df['image'].apply(lambda x: int(os.path.splitext(os.path.basename(x))[0].split('-')[0]))
        df = df.sort_values(by='timestamp').reset_index(drop=True)

        if shuffle:
            df = sklearn.utils.shuffle(df)  # Shuffle dataframe if specified

        # Calculate sizes for train, valid, and test sets
        train_size = int(df.shape[0] * train_rate)
        valid_size = int(df.shape[0] * valid_rate)

        print('total: {:d} | train: {:d} | val: {:d} | test: {:d}'.format(
            df.shape[0], train_size, valid_size,
            df.shape[0] - train_size - valid_size))

        # Split dataframe into train, valid, and test sets
        df_train = df[0: train_size]
        df_valid = df[train_size: train_size + valid_size]
        df_test = df[train_size + valid_size:]

        # Save train, valid, and test sets to CSV files
        df_train[['image', 'label']].to_csv(os.path.join(WildScenesDataset.csv['train']), index=False)
        df_valid[['image', 'label']].to_csv(os.path.join(WildScenesDataset.csv['valid']), index=False)
        df_test[['image', 'label']].to_csv(os.path.join(WildScenesDataset.csv['test']), index=False)

    # 测试語義分割時用
    @staticmethod
    def test_label_mapping(label_path):
        """
        Test the label mapping for a single label image.

        :param label_path: Path to the label image
        :return: Tuple of original label numpy array and mapped trainId numpy array
        """
        # Open the label image and convert to numpy array
        label = Image.open(label_path).convert('L')
        label_np = np.array(label, dtype=np.uint8)

        # Map label indices to trainId values
        label_trainId = np.vectorize(lambda x: WildScenesDataset._label_to_trainid.get(x, 255))(label_np)

        return label_np, label_trainId

if __name__=='__main__':
    # Example usage
    root_dirs = [
        os.path.join('..',  'WildScenes2d', 'V-01'),
        os.path.join('..',  'WildScenes2d', 'V-02'),
        os.path.join('..',  'WildScenes2d', 'V-03')
    ]

    for root_dir in root_dirs:
        WildScenesDataset.image_file_base = os.path.join(root_dir, 'image')
        WildScenesDataset.label_file_base = os.path.join(root_dir, 'indexLabel')
        WildScenesDataset.make_data_list()

Created directory: datasets/data_list
total: 2789 | train: 1952 | val: 557 | test: 280
total: 2789 | train: 1952 | val: 557 | test: 280
total: 2789 | train: 1952 | val: 557 | test: 280


In [3]:
import os
import sys
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from data_split import WildScenesDataset
from utils.transforms import TrainTransform, TestTransform

color_map = {
    0: [224, 31, 77],  # Bush
    1: [64, 180, 78],  # Dirt
    2: [26, 127, 127],  # Fence
    3: [127, 127, 127],  # Grass
    4: [145, 24, 178],  # Gravel
    5: [125, 128, 16],  # Log
    6: [251, 225, 48],  # Mud
    7: [248, 190, 190],  # Other-object
    8: [89, 239, 239],  # Other-terrain
    9: [173, 255, 196],  # Rock
    10: [19, 0, 126],  # Sky
    11: [167, 110, 44],  # Structure
    12: [208, 245, 71],  # Tree-foliage
    13: [238, 47, 227],  # Tree-trunk
    14: [40, 127, 198],  # Water
    15: [0, 0, 0],      # 背景类（黑色）
    16: [128, 128, 128],  # 忽略类（灰色）
}

class EnhancedWildScenesDataset(WildScenesDataset):
    def __init__(self, dataset_type, transform=None):
        super().__init__(dataset_type, transform)
        self.color_map = self._load_color_map()
        self.transform = self._get_transform(dataset_type)

    def __getitem__(self, index):
        image_path = self._data_frame['image'][index]
        label_path = self._data_frame['label'][index]

        image = Image.open(image_path).convert('RGB')
        original_label, mapped_label = self.test_label_mapping(label_path)
        label = Image.fromarray(mapped_label.astype(np.uint8))

        if self.transform is not None:
            image, label = self.transform(image, label)

        return image, label

    def _load_color_map(self):
        return {key: np.array(value) for key, value in color_map.items()}

    def _get_transform(self, dataset_type):
        if dataset_type == 'train':
            return TrainTransform()
        elif dataset_type in ['valid', 'test']:
            return TestTransform()
        else:
            raise ValueError('Invalid dataset type')

    @staticmethod
    def get_color_coded_label(label_trainId):
        height, width = label_trainId.shape
        label_RGB = np.zeros((height, width, 3), dtype=np.uint8)
        for trainId, color in color_map.items():
            label_RGB[label_trainId == trainId] = color
        return label_RGB

    @staticmethod
    def get_data_loader(dataset_type, batch_size=4):
        dataset = EnhancedWildScenesDataset(dataset_type)
        shuffle = dataset_type == 'train'
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=0, pin_memory=True, drop_last=True)

# if __name__ == '__main__':
#     # 测试data_loader
#     train_loader = EnhancedWildScenesDataset.get_data_loader('train', batch_size=4)
#     for images, labels in train_loader:
#         print(f"Batch image shape: {images.shape}")  # 一个批次的图像数据形状，一个批次4个图像，每个图像3通道，每个图像尺寸256*341
#         print(f"Batch label shape: {labels.shape}")  # 一个批次的label数据形状，单通道，表示的是trainId标注的
#         print(f"Batch label unique values: {torch.unique(labels)}")
#         break

#     dataset = EnhancedWildScenesDataset('train')
#     image, label = dataset[0]
#     color_coded_label = dataset.get_color_coded_label(label.numpy())
#     print(f"Color-coded label shape: {color_coded_label.shape}")  # 使用颜色编码的label图像形状，发现是3通道，成了！


Batch image shape: torch.Size([4, 3, 256, 256])
Batch label shape: torch.Size([4, 256, 256])
Batch label unique values: tensor([ 0,  5,  6, 12, 13, 14, 16])
Color-coded label shape: (256, 256, 3)


In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm
import numpy as np
import os
import logging
from data_load import EnhancedWildScenesDataset
from utils.metrics import calculate_miou_train, calculate_pixel_accuracy, calculate_dice_coefficient
from utils.losses import CombinedLoss
from utils.log import setup_logger, save_checkpoint
from models.unet import UNet
import torch.nn.functional as F


def setup_logger(log_file):
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s - %(levelname)s - %(message)s',
                        handlers=[
                            logging.FileHandler(log_file),
                            logging.StreamHandler()
                        ])

def save_checkpoint(model, optimizer, epoch, metrics, path):
    state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'metrics': metrics
    }
    torch.save(state, path)

def train_epoch(model, dataloader, criterion, optimizer, scheduler, device, num_classes, scaler):
    model.train()
    total_loss = 0
    total_miou = 0
    total_pixel_acc = 0
    total_dice = 0
    num_batches = 0

    for images, labels in tqdm(dataloader, desc="Training"):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()

        with autocast():
            outputs = model(images)
            if outputs.shape[-2:] != labels.shape[-2:]:
                outputs = F.interpolate(outputs, size=labels.shape[-2:], mode='bilinear', align_corners=False)
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
        pred = torch.argmax(outputs, dim=1)
        miou = calculate_miou_train(pred.cpu().numpy(), labels.cpu().numpy(), num_classes)
        pixel_acc = calculate_pixel_accuracy(pred.cpu().numpy(), labels.cpu().numpy())
        dice = calculate_dice_coefficient(pred.cpu().numpy(), labels.cpu().numpy(), num_classes)

        if not np.isnan(miou):
            total_miou += miou
            total_pixel_acc += pixel_acc
            total_dice += dice
            num_batches += 1

    return (total_loss / len(dataloader),
            total_miou / num_batches if num_batches > 0 else 0.0,
            total_pixel_acc / num_batches if num_batches > 0 else 0.0,
            total_dice / num_batches if num_batches > 0 else 0.0)

def validate_epoch(model, dataloader, criterion, device, num_classes):
    model.eval()
    total_loss = 0
    total_miou = 0
    total_pixel_acc = 0
    total_dice = 0

    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Validating"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            if outputs.shape[-2:] != labels.shape[-2:]:
                outputs = F.interpolate(outputs, size=labels.shape[-2:], mode='bilinear', align_corners=False)
            loss = criterion(outputs, labels)

            total_loss += loss.item()
            pred = torch.argmax(outputs, dim=1)
            miou = calculate_miou_train(pred.cpu().numpy(), labels.cpu().numpy(), num_classes)
            pixel_acc = calculate_pixel_accuracy(pred.cpu().numpy(), labels.cpu().numpy())
            dice = calculate_dice_coefficient(pred.cpu().numpy(), labels.cpu().numpy(), num_classes)

            total_miou += miou
            total_pixel_acc += pixel_acc
            total_dice += dice

    num_batches = len(dataloader)
    return (total_loss / num_batches,
            total_miou / num_batches,
            total_pixel_acc / num_batches,
            total_dice / num_batches)

def train(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, device, save_dir):
    best_loss = float('inf')
    best_miou = 0
    scaler = GradScaler()

    for epoch in range(num_epochs):
        logging.info(f"Epoch {epoch + 1}/{num_epochs}")

        train_loss, train_miou, train_pixel_acc, train_dice = train_epoch(model, train_loader, criterion, optimizer, scheduler, device, num_classes, scaler)
        logging.info(f"Epoch {epoch + 1} - Train Loss: {train_loss:.4f}, Train mIoU: {train_miou:.4f}, "
                     f"Train Pixel Acc: {train_pixel_acc:.4f}, Train Dice: {train_dice:.4f}")

        val_loss, val_miou, val_pixel_acc, val_dice = validate_epoch(model, val_loader, criterion, device, num_classes)
        logging.info(f"Epoch {epoch + 1} - Val Loss: {val_loss:.4f}, Val mIoU: {val_miou:.4f}, "
                     f"Val Pixel Acc: {val_pixel_acc:.4f}, Val Dice: {val_dice:.4f}")

        scheduler.step(val_loss)

        metrics = {
            'val_loss': val_loss,
            'val_miou': val_miou,
            'val_pixel_acc': val_pixel_acc,
            'val_dice': val_dice
        }

        if val_miou > best_miou:
            best_miou = val_miou
            best_model_path = os.path.join(save_dir, f'best_model_epoch_{epoch + 1}.pth')
            save_checkpoint(model, optimizer, epoch + 1, metrics, best_model_path)
            logging.info(f"Epoch {epoch + 1} - Best model saved with Val mIoU: {best_miou:.4f}")

        if (epoch + 1) % 5 == 0:
            checkpoint_path = os.path.join(save_dir, f'checkpoint_epoch_{epoch + 1}.pth')
            save_checkpoint(model, optimizer, epoch + 1, metrics, checkpoint_path)
            logging.info(f"Epoch {epoch + 1} - Checkpoint saved")

        current_lr = optimizer.param_groups[0]['lr']
        logging.info(f"Current learning rate: {current_lr:.6f}")

    logging.info(f"Training completed after {num_epochs} epochs.")
    return best_model_path

if __name__ == "__main__":
    save_dir = 'model_checkpoints/Unet'
    os.makedirs(save_dir, exist_ok=True)

    log_file = os.path.join(save_dir, 'training.log')
    setup_logger(log_file)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging.info(f"Using device: {device}")

    train_loader = EnhancedWildScenesDataset.get_data_loader('train', batch_size=8)
    val_loader = EnhancedWildScenesDataset.get_data_loader('valid', batch_size=8)

    num_classes = 17

    model = UNet(3, num_classes).to(device)

    criterion = CombinedLoss(weight_focal=1.0, weight_dice=0.5)

    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=5e-5, weight_decay=0.01)

    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)

    num_epochs = 50
    best_model_path = train(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, device, save_dir)
    logging.info(f"Training completed! Best model saved at {best_model_path}")


2024-07-26 01:24:16,119 - INFO - Using device: cuda
2024-07-26 01:24:16,738 - INFO - Epoch 1/60
Training: 100%|██████████| 244/244 [22:15<00:00,  5.47s/it]
2024-07-26 01:46:31,841 - INFO - Epoch 1 - Train Loss: 1.1914, Train mIoU: 0.1744, Train Pixel Acc: 0.7038, Train Dice: 0.2150
Validating: 100%|██████████| 69/69 [06:10<00:00,  5.36s/it]
2024-07-26 01:52:41,892 - INFO - Epoch 1 - Val Loss: 0.8037, Val mIoU: 0.5552, Val Pixel Acc: 0.7822, Val Dice: 0.5953
2024-07-26 01:52:42,217 - INFO - Epoch 1 - Best model saved with Val mIoU: 0.5552
2024-07-26 01:52:42,219 - INFO - Current learning rate: 0.000050
2024-07-26 01:52:42,220 - INFO - Epoch 2/60
Training: 100%|██████████| 244/244 [21:25<00:00,  5.27s/it]
2024-07-26 02:14:07,229 - INFO - Epoch 2 - Train Loss: 0.8033, Train mIoU: 0.2750, Train Pixel Acc: 0.7809, Train Dice: 0.3176
Validating: 100%|██████████| 69/69 [06:04<00:00,  5.28s/it]
2024-07-26 02:20:11,414 - INFO - Epoch 2 - Val Loss: 0.7317, Val mIoU: 0.6134, Val Pixel Acc: 0.8018

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from tqdm import tqdm
import seaborn as sns
from data_load import EnhancedWildScenesDataset
from torchvision import models
from models.unet import UNet
from sklearn.metrics import confusion_matrix
import csv
import pandas as pd

# 定义颜色映射
color_map = {
    1: [224, 31, 77],  # Bush
    0: [64, 180, 78],  # Dirt
    2: [26, 127, 127],  # Fence
    14: [127, 127, 127],  # Grass
    3: [145, 24, 178],  # Gravel
    13: [125, 128, 16],  # Log
    12: [251, 225, 48],  # Mud
    7: [248, 190, 190],  # Other-object
    8: [89, 239, 239],  # Other-terrain
    9: [173, 255, 196],  # Rock
    16: [19, 0, 126],  # Sky
    11: [167, 110, 44],  # Structure
    6: [208, 245, 71],  # Tree-foliage
    5: [238, 47, 227],  # Tree-trunk
    4: [40, 127, 198],  # Water
    15: [0, 0, 0],  # 背景类（黑色）
    10: [128, 128, 128],  # 忽略类（灰色）
}

class_names = ['Dirt', 'Bush', 'Fence', 'Gravel', 'Water', 'Tree-trunk', 'Tree-Foliage', 'Other-object',
               'Other-terrain', 'Rock', 'Ignore', 'Structure', 'Mud', 'Log', 'Grass', 'Background', 'Sky']


def apply_color_map(segmentation):
    color_segmentation = np.zeros((*segmentation.shape, 3), dtype=np.uint8)
    for class_idx, color in color_map.items():
        color_segmentation[segmentation == class_idx] = color
    return color_segmentation


def overlay_segmentation(image, segmentation, alpha=0.5):
    colored_seg = apply_color_map(segmentation)
    return (image * (1 - alpha) + colored_seg * alpha).astype(np.uint8)


def save_comparison(image, ground_truth, prediction, index, save_dir):
    fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(20, 5))

    ax1.imshow(image.permute(1, 2, 0).cpu().numpy())
    ax1.set_title('Original Image')
    ax1.axis('off')

    ax2.imshow(apply_color_map(ground_truth.cpu().numpy()))
    ax2.set_title('Ground Truth')
    ax2.axis('off')

    ax3.imshow(apply_color_map(prediction.cpu().numpy()))
    ax3.set_title('Prediction')
    ax3.axis('off')

    ax4.imshow(overlay_segmentation(image.permute(1, 2, 0).cpu().numpy(), prediction.cpu().numpy()))
    ax4.set_title('Overlay')
    ax4.axis('off')

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, f'comparison_{index}.png'))
    plt.close()


def plot_confusion_matrix(y_true, y_pred, save_path):
    cm = confusion_matrix(y_true, y_pred)
    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.savefig(save_path)
    plt.close()


def plot_per_class_iou(class_iou, save_path):
    plt.figure(figsize=(12, 6))
    valid_iou = [iou for iou in class_iou if not np.isnan(iou)]
    plt.bar(range(len(valid_iou)), valid_iou)
    plt.title('Per-class IoU')
    plt.xlabel('Class')
    plt.ylabel('IoU')
    valid_class_names = [name for i, name in enumerate(class_names) if i < len(valid_iou)]
    plt.xticks(range(len(valid_iou)), valid_class_names, rotation=90)
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()


def calculate_miou(pred, target, num_classes):
    ious = []
    pred = pred.ravel()
    target = target.ravel()
    for cls in range(num_classes):
        pred_inds = pred == cls
        target_inds = target == cls
        intersection = np.logical_and(pred_inds, target_inds).sum()
        union = np.logical_or(pred_inds, target_inds).sum()
        if union == 0:
            ious.append(float('nan'))  # 如果该类别不存在，则IoU为NaN
        else:
            ious.append(intersection / union)
    miou = np.nanmean(ious)  # 忽略NaN值计算平均IoU
    return miou, np.array(ious)


def visualize_results(model, dataloader, device, num_classes, save_dir):
    model.eval()
    all_class_ious = np.zeros(num_classes)
    class_counts = np.zeros(num_classes)
    image_ious = []

    with torch.no_grad():
        for i, (images, labels) in enumerate(tqdm(dataloader, desc="Visualizing")):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            if len(labels.shape) == 4 and labels.shape[1] > 1:
                labels = torch.argmax(labels, dim=1)

            if outputs.shape[2:] != labels.shape[1:]:
                outputs = F.interpolate(outputs, size=labels.shape[1:], mode='bilinear', align_corners=True)

            pred = torch.argmax(outputs, dim=1)

            # 计算每张图片的IoU
            for j in range(images.shape[0]):
                image_index = i * dataloader.batch_size + j
                _, class_ious = calculate_miou(pred[j].cpu().numpy(), labels[j].cpu().numpy(), num_classes)
                all_class_ious += np.nan_to_num(class_ious)
                class_counts += ~np.isnan(class_ious)
                image_ious.append(class_ious)

                # 只保存特定索引的图片
                if image_index in [0, 10, 20, 30]:
                    save_comparison(images[j], labels[j], pred[j], image_index, save_dir)

    # 计算平均IoU
    average_ious = np.where(class_counts > 0, all_class_ious / class_counts, 0)

    print("Class IoUs:")
    for cls, iou in enumerate(average_ious):
        print(f"Class {cls} ({class_names[cls]}): {iou:.4f}")

    # 绘制每类IoU图
    plot_per_class_iou(average_ious, os.path.join(save_dir, 'per_class_iou.png'))

    # 计算mIoU
    miou = np.mean(average_ious)
    print(f"Mean IoU: {miou:.4f}")

    # 创建并保存CSV文件
    csv_file = os.path.join(save_dir, 'image_class_ious.csv')
    df = pd.DataFrame(image_ious, columns=[f"{cls}_{name}" for cls, name in enumerate(class_names)])
    df.index.name = 'Image_Number'
    df.to_csv(csv_file)
    print(f"Image-wise class IoUs saved to {csv_file}")

    return average_ious, miou

if __name__ == "__main__":
    save_dir = 'visualization_results/Unet'
    os.makedirs(save_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_classes = 17

    # 创建UNet模型实例
    model = UNet(3, num_classes).to(device)

    # 加载模型权重
    model_path = os.path.join('model_checkpoints', 'Unet', 'checkpoint_epoch_50.pth')
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    model.eval()

    test_loader = EnhancedWildScenesDataset.get_data_loader('test', batch_size=8)

    # 运行可视化并获取结果
    class_ious, miou = visualize_results(model, test_loader, device, num_classes, save_dir)

    # 保存结果到文件
    results_file = os.path.join(save_dir, 'iou_results.txt')
    with open(results_file, 'w') as f:
        f.write(f"Mean IoU: {miou:.4f}\n\n")
        f.write("Class IoUs:\n")
        for cls, iou in enumerate(class_ious):
            f.write(f"Class {cls} ({class_names[cls]}): {iou:.4f}\n")

    print(f"Visualization results and IoU statistics saved in {save_dir}")