In [1]:
# 調整 從epoch 30學習率就降低為0了，收斂過快，後期loss浮動沒有明顯下降，mIOU 像素精準度 Dice相關係數也沒有明顯上升
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR
from torch.cuda.amp import GradScaler, autocast
from transformers import Mask2FormerForUniversalSegmentation
from data_load_for_mask2former import EnhancedWildScenesDataset
from tqdm import tqdm
import numpy as np
import os
import logging
from utils.metrics import calculate_miou_train, calculate_pixel_accuracy, calculate_dice_coefficient
from utils.losses import CombinedLoss
from utils.log import setup_logger, save_checkpoint
import torch.nn.functional as F
import math
from models.custom_mask2former import CustomMask2Former



def train_epoch(model, dataloader, criterion, optimizer, scheduler, device, num_classes, scaler):
    model.train()
    total_loss = 0
    total_miou = 0
    total_pixel_acc = 0
    total_dice = 0
    num_batches = 0

    for images, labels in tqdm(dataloader, desc="Training"):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()

        with torch.cuda.amp.autocast():
            outputs = model(images)

            if outputs.shape[-2:] != labels.shape[-2:]:
                outputs = F.interpolate(outputs, size=labels.shape[-2:],
                                        mode='bilinear', align_corners=False)

            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()

        # 调整梯度裁剪的 max_norm 值為0.5
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)

        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
        pred = torch.argmax(outputs, dim=1)
        miou = calculate_miou_train(pred.cpu().numpy(), labels.cpu().numpy(), num_classes)
        pixel_acc = calculate_pixel_accuracy(pred.cpu().numpy(), labels.cpu().numpy())
        dice = calculate_dice_coefficient(pred.cpu().numpy(), labels.cpu().numpy(), num_classes)

        if not np.isnan(miou):
            total_miou += miou
            total_pixel_acc += pixel_acc
            total_dice += dice
            num_batches += 1

    return (total_loss / len(dataloader),
            total_miou / num_batches if num_batches > 0 else 0.0,
            total_pixel_acc / num_batches if num_batches > 0 else 0.0,
            total_dice / num_batches if num_batches > 0 else 0.0)


def validate_epoch(model, dataloader, criterion, device, num_classes):
    model.eval()
    total_loss = 0
    total_miou = 0
    total_pixel_acc = 0
    total_dice = 0

    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Validating"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            if outputs.shape[-2:] != labels.shape[-2:]:
                outputs = F.interpolate(outputs, size=labels.shape[-2:],
                                        mode='bilinear', align_corners=False)

            loss = criterion(outputs, labels)

            total_loss += loss.item()
            pred = torch.argmax(outputs, dim=1)
            miou = calculate_miou_train(pred.cpu().numpy(), labels.cpu().numpy(), num_classes)
            pixel_acc = calculate_pixel_accuracy(pred.cpu().numpy(), labels.cpu().numpy())
            dice = calculate_dice_coefficient(pred.cpu().numpy(), labels.cpu().numpy(), num_classes)

            total_miou += miou
            total_pixel_acc += pixel_acc
            total_dice += dice

    num_batches = len(dataloader)
    return (total_loss / num_batches,
            total_miou / num_batches,
            total_pixel_acc / num_batches,
            total_dice / num_batches)


def train(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, device, save_dir, num_classes):
    best_miou = 0
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    scaler = GradScaler()

    for epoch in range(num_epochs):
        current_epoch = epoch + 1
        logging.info(f"Epoch {current_epoch}/{num_epochs}")

        train_loss, train_miou, train_pixel_acc, train_dice = train_epoch(
            model, train_loader, criterion, optimizer, scheduler, device, num_classes, scaler)
        logging.info(f"Epoch {current_epoch} - Train Loss: {train_loss:.4f}, Train mIoU: {train_miou:.4f}, "
                     f"Train Pixel Acc: {train_pixel_acc:.4f}, Train Dice: {train_dice:.4f}")

        val_loss, val_miou, val_pixel_acc, val_dice = validate_epoch(
            model, val_loader, criterion, device, num_classes)
        logging.info(f"Epoch {current_epoch} - Val Loss: {val_loss:.4f}, Val mIoU: {val_miou:.4f}, "
                     f"Val Pixel Acc: {val_pixel_acc:.4f}, Val Dice: {val_dice:.4f}")

        scheduler.step(val_loss)  # 使用验证损失来调整学习率

        metrics = {
            'miou': val_miou,
            'pixel_acc': val_pixel_acc,
            'dice': val_dice
        }

        if val_miou > best_miou:
            best_miou = val_miou
            best_model_path = os.path.join(save_dir, f'best_model_epoch_{current_epoch}.pth')
            save_checkpoint(model, optimizer, current_epoch, metrics, best_model_path)
            logging.info(f"Epoch {current_epoch} - Best model saved with mIoU: {best_miou:.4f}")

        # if current_epoch % 5 == 0:
        #     checkpoint_path = os.path.join(save_dir, f'checkpoint_epoch_{current_epoch}.pth')
        #     save_checkpoint(model, optimizer, current_epoch, metrics, checkpoint_path)
        #     logging.info(f"Epoch {current_epoch} - Checkpoint saved")

        current_lr = optimizer.param_groups[0]['lr']
        logging.info(f"Current learning rate: {current_lr:.6f}")

    logging.info(f"Training completed after {num_epochs} epochs.")
    return best_model_path


if __name__ == "__main__":
    # 创建保存目录
    save_dir = os.path.join('model_checkpoints', 'Mask2Former_Swin-L')
    os.makedirs(save_dir, exist_ok=True)

    # 设置日志文件路径
    log_file = os.path.join(save_dir, 'training.log')
    setup_logger(log_file)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging.info(f"Using device: {device}")

    train_loader = EnhancedWildScenesDataset.get_data_loader('train', batch_size=16)
    val_loader = EnhancedWildScenesDataset.get_data_loader('valid', batch_size=16)

    num_classes = 17

    model = CustomMask2Former(num_classes=num_classes).to(device)

    # 調整損失函數比重
    criterion = CombinedLoss(weight_focal=0.75, weight_dice=0.25)

    # 降低優化器的weight_decay，嘗試不同的學習率
    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4, weight_decay=0.001)
    
    num_epochs = 60
    
    # scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)
    scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)

    best_model_path = train(model, train_loader, val_loader, criterion, optimizer, scheduler,
                            num_epochs, device, save_dir, num_classes)

    logging.info("Training and prediction completed!")

Using device: cuda
Epoch 1/60
Training: 100%|██████████| 68/68 [16:51<00:00, 14.88s/it]
Epoch 1 - Train Loss: 1.6685, Train mIoU: 0.3724, Train Pixel Acc: 0.5963, Train Dice: 0.4039
Validating: 100%|██████████| 19/19 [02:55<00:00,  9.23s/it]
Epoch 1 - Val Loss: 0.4175, Val mIoU: 0.5663, Val Pixel Acc: 0.7771, Val Dice: 0.5985
Checkpoint saved: model_checkpoints\Mask2Former_Swin-L\best_model_epoch_1.pth
Epoch 1 - Best model saved with mIoU: 0.5663
Current learning rate: 0.000100
Epoch 2/60
Training: 100%|██████████| 68/68 [17:24<00:00, 15.36s/it]
Epoch 2 - Train Loss: 0.4065, Train mIoU: 0.5417, Train Pixel Acc: 0.7908, Train Dice: 0.5796
Validating: 100%|██████████| 19/19 [02:44<00:00,  8.66s/it]
Epoch 2 - Val Loss: 0.3566, Val mIoU: 0.5958, Val Pixel Acc: 0.8258, Val Dice: 0.6357
Checkpoint saved: model_checkpoints\Mask2Former_Swin-L\best_model_epoch_2.pth
Epoch 2 - Best model saved with mIoU: 0.5958
Current learning rate: 0.000100
Epoch 3/60
Training: 100%|██████████| 68/68 [17:16<00

In [5]:
import os
import logging
import torch
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch.nn.functional as F
from data_load_for_mask2former import EnhancedWildScenesDataset
from models.custom_mask2former import CustomMask2Former
from utils.metrics import calculate_miou, calculate_pixel_accuracy, calculate_dice_coefficient

def setup_logger(log_file):
    os.makedirs(os.path.dirname(log_file), exist_ok=True)
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s - %(levelname)s - %(message)s',
                        handlers=[
                            logging.FileHandler(log_file),
                            logging.StreamHandler()
                        ])
    return logging.getLogger(__name__)  # Return logger object

def normalize_image(image):
    """Normalize image tensor to [0, 1] range."""
    image = image.cpu().numpy()
    image = (image - image.min()) / (image.max() - image.min())
    return np.clip(image, 0, 1)

def save_comparison(image, ground_truth, prediction, index, save_dir):
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))

    # Normalize and convert image to correct format
    image_np = normalize_image(image)
    image_np = np.transpose(image_np, (1, 2, 0))  # Change from (C, H, W) to (H, W, C)

    ax1.imshow(image_np)
    ax1.set_title('Original Image')
    ax1.axis('off')

    ax2.imshow(ground_truth.cpu().numpy(), cmap='viridis')
    ax2.set_title('Ground Truth')
    ax2.axis('off')

    ax3.imshow(prediction.cpu().numpy(), cmap='viridis')
    ax3.set_title('Prediction')
    ax3.axis('off')

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, f'comparison_{index}.png'))
    plt.close()

def test(model, dataloader, device, num_classes, save_dir, logger):
    model.eval()
    total_miou = 0
    total_pixel_acc = 0
    total_dice = 0
    class_iou = np.zeros(num_classes)
    image_count = 0
    save_indices = {0, 10, 20, 30}  # Set of indices to save

    with torch.no_grad():
        for i, (images, labels) in enumerate(tqdm(dataloader, desc="Testing")):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            if outputs.shape[-2:] != labels.shape[-2:]:
                outputs = F.interpolate(outputs, size=labels.shape[-2:], mode='bilinear', align_corners=False)

            pred = torch.argmax(outputs, dim=1)

            miou, class_iou_batch = calculate_miou(pred.cpu().numpy(), labels.cpu().numpy(), num_classes)
            pixel_acc = calculate_pixel_accuracy(pred.cpu().numpy(), labels.cpu().numpy())
            dice = calculate_dice_coefficient(pred.cpu().numpy(), labels.cpu().numpy(), num_classes)

            total_miou += miou
            total_pixel_acc += pixel_acc
            total_dice += dice
            class_iou += class_iou_batch

            # Save ground truth and prediction for specific images in the batch
            for j in range(images.shape[0]):
                if image_count in save_indices:
                    save_comparison(images[j], labels[j], pred[j], image_count, save_dir)
                    logger.info(f"Saved comparison for image {image_count}")
                image_count += 1

    num_batches = len(dataloader)
    avg_miou = total_miou / num_batches
    avg_pixel_acc = total_pixel_acc / num_batches
    avg_dice = total_dice / num_batches
    avg_class_iou = class_iou / num_batches

    logger.info(f"Total images processed: {image_count}")

    return avg_miou, avg_pixel_acc, avg_dice, avg_class_iou

if __name__ == "__main__":
    save_dir = 'prediction/Mask2Former'
    log_file = os.path.join(save_dir, 'testing.log')

    os.makedirs(save_dir, exist_ok=True)

    logger = setup_logger(log_file)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    test_loader = EnhancedWildScenesDataset.get_data_loader('test', batch_size=8)

    num_classes = 17

    model = CustomMask2Former(num_classes=num_classes).to(device)

    best_model_path = os.path.join('model_checkpoints', 'Mask2Former_Swin-L', 'best_model_epoch_48.pth')

    if os.path.exists(best_model_path):
        checkpoint = torch.load(best_model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        logger.info(f"Loaded model from {best_model_path}")
    else:
        logger.error(f"Model file not found: {best_model_path}")
        exit(1)

    miou, pixel_acc, dice, class_iou = test(model, test_loader, device, num_classes, save_dir, logger)

    logger.info(f"Test Results:")
    logger.info(f"Mean IoU: {miou:.4f}")
    logger.info(f"Pixel Accuracy: {pixel_acc:.4f}")
    logger.info(f"Dice Coefficient: {dice:.4f}")

    logger.info("Per-class IoU:")
    for i, iou in enumerate(class_iou):
        logger.info(f"Class {i}: {iou:.4f}")

    # Visualize per-class IoU
    plt.figure(figsize=(12, 6))
    plt.bar(range(num_classes), class_iou)
    plt.title('Per-class IoU')
    plt.xlabel('Class')
    plt.ylabel('IoU')
    plt.savefig(os.path.join(save_dir, 'per_class_iou.png'))
    plt.close()

Using device: cuda
Loaded model from model_checkpoints\Mask2Former_Swin-L\best_model_epoch_48.pth
Testing:   0%|          | 0/19 [00:00<?, ?it/s]Saved comparison for image 0
Testing:   5%|▌         | 1/19 [00:06<02:02,  6.82s/it]Saved comparison for image 10
Testing:  11%|█         | 2/19 [00:13<01:52,  6.61s/it]Saved comparison for image 20
Testing:  16%|█▌        | 3/19 [00:20<01:51,  6.94s/it]Saved comparison for image 30
Testing: 100%|██████████| 19/19 [02:06<00:00,  6.63s/it]
Total images processed: 152
Test Results:
Mean IoU: 0.4783
Pixel Accuracy: 0.8571
Dice Coefficient: 0.7614
Per-class IoU:
Class 0: 0.7960
Class 1: nan
Class 2: nan
Class 3: nan
Class 4: nan
Class 5: 0.4989
Class 6: 0.8346
Class 7: 0.1172
Class 8: nan
Class 9: nan
Class 10: nan
Class 11: nan
Class 12: 0.1926
Class 13: 0.1425
Class 14: 0.6896
Class 15: nan
Class 16: 0.8284


In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from tqdm import tqdm
import seaborn as sns
from data_load_for_mask2former import EnhancedWildScenesDataset
from models.custom_mask2former import CustomMask2Former
from sklearn.metrics import confusion_matrix
import csv
import pandas as pd

# 定义颜色映射
color_map = {
    1: [224, 31, 77],  # Bush
    0: [64, 180, 78],  # Dirt
    2: [26, 127, 127],  # Fence
    14: [127, 127, 127],  # Grass
    3: [145, 24, 178],  # Gravel
    13: [125, 128, 16],  # Log
    12: [251, 225, 48],  # Mud
    7: [248, 190, 190],  # Other-object
    8: [89, 239, 239],  # Other-terrain
    9: [173, 255, 196],  # Rock
    16: [19, 0, 126],  # Sky
    11: [167, 110, 44],  # Structure
    6: [208, 245, 71],  # Tree-foliage
    5: [238, 47, 227],  # Tree-trunk
    4: [40, 127, 198],  # Water
    15: [0, 0, 0],  # 背景类（黑色）
    10: [128, 128, 128],  # 忽略类（灰色）
}

class_names = ['Dirt', 'Bush', 'Fence', 'Gravel', 'Water', 'Tree-trunk', 'Tree-Foliage', 'Other-object',
               'Other-terrain', 'Rock', 'Ignore', 'Structure', 'Mud', 'Log', 'Grass', 'Background', 'Sky']

def plot_per_class_iou(class_iou, save_path):
    plt.figure(figsize=(12, 6))
    plt.bar(range(len(class_iou)), class_iou)
    plt.title('Per-class IoU')
    plt.xlabel('Class')
    plt.ylabel('IoU')
    plt.xticks(range(len(class_iou)), class_names, rotation=90)
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()

def apply_color_map(segmentation):
    color_segmentation = np.zeros((*segmentation.shape, 3), dtype=np.uint8)
    for class_idx, color in color_map.items():
        color_segmentation[segmentation == class_idx] = color
    return color_segmentation

def overlay_segmentation(image, segmentation, alpha=0.5):
    colored_seg = apply_color_map(segmentation)
    return (image * (1 - alpha) + colored_seg * alpha).astype(np.uint8)

def save_comparison(image, ground_truth, prediction, index, save_dir):
    fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(20, 5))

    ax1.imshow(image.permute(1, 2, 0).cpu().numpy())
    ax1.set_title('Original Image')
    ax1.axis('off')

    ax2.imshow(apply_color_map(ground_truth.cpu().numpy()))
    ax2.set_title('Ground Truth')
    ax2.axis('off')

    ax3.imshow(apply_color_map(prediction.cpu().numpy()))
    ax3.set_title('Prediction')
    ax3.axis('off')

    ax4.imshow(overlay_segmentation(image.permute(1, 2, 0).cpu().numpy(), prediction.cpu().numpy()))
    ax4.set_title('Overlay')
    ax4.axis('off')

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, f'comparison_{index}.png'))
    plt.close()

def calculate_miou(pred, target, num_classes):
    ious = []
    pred = pred.ravel()
    target = target.ravel()
    for cls in range(num_classes):
        pred_inds = pred == cls
        target_inds = target == cls
        intersection = np.logical_and(pred_inds, target_inds).sum()
        union = np.logical_or(pred_inds, target_inds).sum()
        if union == 0:
            ious.append(float('nan'))  # 如果该类别不存在，则IoU为NaN
        else:
            ious.append(intersection / union)
    miou = np.nanmean(ious)  # 忽略NaN值计算平均IoU
    return miou, np.array(ious)

def visualize_results(model, dataloader, device, num_classes, save_dir):
    model.eval()
    all_class_ious = np.zeros(num_classes)
    class_counts = np.zeros(num_classes)
    image_ious = []

    with torch.no_grad():
        for i, (images, labels) in enumerate(tqdm(dataloader, desc="Visualizing")):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            if len(labels.shape) == 4 and labels.shape[1] > 1:
                labels = torch.argmax(labels, dim=1)

            if outputs.shape[2:] != labels.shape[1:]:
                outputs = F.interpolate(outputs, size=labels.shape[1:], mode='bilinear', align_corners=True)

            pred = torch.argmax(outputs, dim=1)

            for j in range(images.shape[0]):
                image_index = i * dataloader.batch_size + j
                _, class_ious = calculate_miou(pred[j].cpu().numpy(), labels[j].cpu().numpy(), num_classes)
                all_class_ious += np.nan_to_num(class_ious)
                class_counts += ~np.isnan(class_ious)
                image_ious.append(class_ious)

                if image_index in [0, 10, 20, 30]:
                    save_comparison(images[j], labels[j], pred[j], image_index, save_dir)

    average_ious = np.where(class_counts > 0, all_class_ious / class_counts, 0)

    print("Class IoUs:")
    for cls, iou in enumerate(average_ious):
        print(f"Class {cls} ({class_names[cls]}): {iou:.4f}")

    plot_per_class_iou(average_ious, os.path.join(save_dir, 'per_class_iou.png'))
    print(f"Per-class IoU plot saved to {os.path.join(save_dir, 'per_class_iou.png')}")

    miou = np.mean(average_ious)
    print(f"Mean IoU: {miou:.4f}")

    csv_file = os.path.join(save_dir, 'image_class_ious.csv')
    df = pd.DataFrame(image_ious, columns=[f"{cls}_{name}" for cls, name in enumerate(class_names)])
    df.index.name = 'Image_Number'
    df.to_csv(csv_file)
    print(f"Image-wise class IoUs saved to {csv_file}")

    return average_ious, miou

if __name__ == "__main__":
    save_dir = 'visualization_results/Mask2Former'
    os.makedirs(save_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    num_classes = 17

    model = CustomMask2Former(num_classes=num_classes).to(device)

    model_path = os.path.join('model_checkpoints', 'Mask2Former_Swin-L', 'best_model_epoch_24.pth')
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded model from {model_path}")

    model.eval()

    test_loader = EnhancedWildScenesDataset.get_data_loader('test', batch_size=8)

    class_ious, miou = visualize_results(model, test_loader, device, num_classes, save_dir)

    results_file = os.path.join(save_dir, 'iou_results.txt')
    with open(results_file, 'w') as f:
        f.write(f"Mean IoU: {miou:.4f}\n\n")
        f.write("Class IoUs:\n")
        for cls, iou in enumerate(class_ious):
            f.write(f"Class {cls} ({class_names[cls]}): {iou:.4f}\n")

    print(f"Visualization results and IoU statistics saved in {save_dir}")

Using device: cuda
Loaded model from model_checkpoints\Mask2Former_Swin-L\best_model_epoch_24.pth


Visualizing:   0%|          | 0/29 [00:00<?, ?it/s]Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-2.0749335..2.6400008].
Visualizing:   3%|▎         | 1/29 [00:04<02:14,  4.80s/it]Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-2.1179044..2.6400008].
Visualizing:   7%|▋         | 2/29 [00:08<01:57,  4.35s/it]Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-1.9486468..2.6400008].
Visualizing:  10%|█         | 3/29 [00:13<01:51,  4.29s/it]Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-2.1179044..2.6400008].
Visualizing: 100%|██████████| 29/29 [01:47<00:00,  3.71s/it]

Class IoUs:
Class 0 (Dirt): 0.6643
Class 1 (Bush): 0.0623
Class 2 (Fence): 0.4809
Class 3 (Gravel): 0.5521
Class 4 (Water): 0.0000
Class 5 (Tree-trunk): 0.3863
Class 6 (Tree-Foliage): 0.8332
Class 7 (Other-object): 0.1582
Class 8 (Other-terrain): 0.0000
Class 9 (Rock): 0.1148
Class 10 (Ignore): 0.0000
Class 11 (Structure): 0.1158
Class 12 (Mud): 0.1731
Class 13 (Log): 0.0979
Class 14 (Grass): 0.6769
Class 15 (Background): 0.0000
Class 16 (Sky): 0.5704
Per-class IoU plot saved to visualization_results/Mask2Former\per_class_iou.png
Mean IoU: 0.2874
Image-wise class IoUs saved to visualization_results/Mask2Former\image_class_ious.csv
Visualization results and IoU statistics saved in visualization_results/Mask2Former



  average_ious = np.where(class_counts > 0, all_class_ious / class_counts, 0)
