In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm
import numpy as np
import os
import logging
from data_load import EnhancedWildScenesDataset
from models.unet import UNet
from utils.metrics import calculate_miou_train, calculate_pixel_accuracy, calculate_dice_coefficient
from utils.losses import CombinedLoss
from utils.log import setup_logger, save_checkpoint
import torch.nn.functional as F

def setup_logger(log_file):
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s - %(levelname)s - %(message)s',
                        handlers=[
                            logging.FileHandler(log_file),
                            logging.StreamHandler()
                        ])

def save_checkpoint(model, optimizer, epoch, metrics, path):
    state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'metrics': metrics
    }
    torch.save(state, path)

def train_epoch(model, dataloader, criterion, optimizer, scheduler, device, num_classes, scaler):
    model.train()
    total_loss = 0
    total_miou = 0
    total_pixel_acc = 0
    total_dice = 0
    num_batches = 0

    for images, labels in tqdm(dataloader, desc="Training"):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()

        with autocast():
            outputs = model(images)
            if outputs.shape[-2:] != labels.shape[-2:]:
                outputs = F.interpolate(outputs, size=labels.shape[-2:], mode='bilinear', align_corners=False)
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
        pred = torch.argmax(outputs, dim=1)
        miou = calculate_miou_train(pred.cpu().numpy(), labels.cpu().numpy(), num_classes)
        pixel_acc = calculate_pixel_accuracy(pred.cpu().numpy(), labels.cpu().numpy())
        dice = calculate_dice_coefficient(pred.cpu().numpy(), labels.cpu().numpy(), num_classes)

        if not np.isnan(miou):
            total_miou += miou
            total_pixel_acc += pixel_acc
            total_dice += dice
            num_batches += 1

    return (total_loss / len(dataloader),
            total_miou / num_batches if num_batches > 0 else 0.0,
            total_pixel_acc / num_batches if num_batches > 0 else 0.0,
            total_dice / num_batches if num_batches > 0 else 0.0)

def validate_epoch(model, dataloader, criterion, device, num_classes):
    model.eval()
    total_loss = 0
    total_miou = 0
    total_pixel_acc = 0
    total_dice = 0

    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Validating"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            if outputs.shape[-2:] != labels.shape[-2:]:
                outputs = F.interpolate(outputs, size=labels.shape[-2:], mode='bilinear', align_corners=False)
            loss = criterion(outputs, labels)

            total_loss += loss.item()
            pred = torch.argmax(outputs, dim=1)
            miou = calculate_miou_train(pred.cpu().numpy(), labels.cpu().numpy(), num_classes)
            pixel_acc = calculate_pixel_accuracy(pred.cpu().numpy(), labels.cpu().numpy())
            dice = calculate_dice_coefficient(pred.cpu().numpy(), labels.cpu().numpy(), num_classes)

            total_miou += miou
            total_pixel_acc += pixel_acc
            total_dice += dice

    num_batches = len(dataloader)
    return (total_loss / num_batches,
            total_miou / num_batches,
            total_pixel_acc / num_batches,
            total_dice / num_batches)

def train(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, device, save_dir):
    best_loss = float('inf')
    best_miou = 0
    scaler = GradScaler()

    for epoch in range(num_epochs):
        logging.info(f"Epoch {epoch + 1}/{num_epochs}")

        train_loss, train_miou, train_pixel_acc, train_dice = train_epoch(model, train_loader, criterion, optimizer, scheduler, device, num_classes, scaler)
        logging.info(f"Epoch {epoch + 1} - Train Loss: {train_loss:.4f}, Train mIoU: {train_miou:.4f}, "
                     f"Train Pixel Acc: {train_pixel_acc:.4f}, Train Dice: {train_dice:.4f}")

        val_loss, val_miou, val_pixel_acc, val_dice = validate_epoch(model, val_loader, criterion, device, num_classes)
        logging.info(f"Epoch {epoch + 1} - Val Loss: {val_loss:.4f}, Val mIoU: {val_miou:.4f}, "
                     f"Val Pixel Acc: {val_pixel_acc:.4f}, Val Dice: {val_dice:.4f}")

        scheduler.step(val_loss)

        metrics = {
            'val_loss': val_loss,
            'val_miou': val_miou,
            'val_pixel_acc': val_pixel_acc,
            'val_dice': val_dice
        }

        if val_miou > best_miou:
            best_miou = val_miou
            best_model_path = os.path.join(save_dir, f'best_model_epoch_{epoch + 1}.pth')
            save_checkpoint(model, optimizer, epoch + 1, metrics, best_model_path)
            logging.info(f"Epoch {epoch + 1} - Best model saved with Val mIoU: {best_miou:.4f}")

        if (epoch + 1) % 5 == 0:
            checkpoint_path = os.path.join(save_dir, f'checkpoint_epoch_{epoch + 1}.pth')
            save_checkpoint(model, optimizer, epoch + 1, metrics, checkpoint_path)
            logging.info(f"Epoch {epoch + 1} - Checkpoint saved")

        current_lr = optimizer.param_groups[0]['lr']
        logging.info(f"Current learning rate: {current_lr:.6f}")

    logging.info(f"Training completed after {num_epochs} epochs.")
    return best_model_path

if __name__ == "__main__":
    save_dir = 'model_checkpoints/Unet'
    os.makedirs(save_dir, exist_ok=True)

    log_file = os.path.join(save_dir, 'training.log')
    setup_logger(log_file)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging.info(f"Using device: {device}")

    train_loader = EnhancedWildScenesDataset.get_data_loader('train', batch_size=8)
    val_loader = EnhancedWildScenesDataset.get_data_loader('valid', batch_size=8)

    num_classes = 18

    model = UNet(3, num_classes).to(device)

    criterion = CombinedLoss(weight_focal=1.0, weight_dice=0.5)

    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=5e-5, weight_decay=0.01)

    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)

    num_epochs = 30
    best_model_path = train(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, device, save_dir)
    logging.info(f"Training completed! Best model saved at {best_model_path}")


2024-07-18 23:00:53,688 - INFO - Using device: cuda
2024-07-18 23:00:53,959 - INFO - Epoch 1/30
Training: 100%|██████████████████████████████████████████████████████████████████████| 137/137 [08:05<00:00,  3.54s/it]
2024-07-18 23:08:59,236 - INFO - Epoch 1 - Train Loss: 2.1668, Train mIoU: 0.1777, Train Pixel Acc: 0.5775, Train Dice: 0.2115
Validating: 100%|██████████████████████████████████████████████████████████████████████| 39/39 [02:16<00:00,  3.50s/it]
2024-07-18 23:11:15,574 - INFO - Epoch 1 - Val Loss: 1.6074, Val mIoU: 0.4863, Val Pixel Acc: 0.7662, Val Dice: 0.5237
2024-07-18 23:11:15,912 - INFO - Epoch 1 - Best model saved with Val mIoU: 0.4863
2024-07-18 23:11:15,912 - INFO - Current learning rate: 0.000050
2024-07-18 23:11:15,913 - INFO - Epoch 2/30
Training: 100%|██████████████████████████████████████████████████████████████████████| 137/137 [07:55<00:00,  3.47s/it]
2024-07-18 23:19:11,514 - INFO - Epoch 2 - Train Loss: 1.4664, Train mIoU: 0.4233, Train Pixel Acc: 0.7909,

In [14]:
import torch
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import os
import logging
from data_load import EnhancedWildScenesDataset
from models.unet import UNet
from utils.metrics import calculate_miou, calculate_pixel_accuracy, calculate_dice_coefficient
import matplotlib.pyplot as plt

def setup_logger(log_file):
    os.makedirs(os.path.dirname(log_file), exist_ok=True)
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s - %(levelname)s - %(message)s',
                        handlers=[
                            logging.FileHandler(log_file),
                            logging.StreamHandler()
                        ])
    return logging.getLogger(__name__)  # Return logger object

def test(model, dataloader, device, num_classes, save_dir, logger):
    model.eval()
    total_miou = 0
    total_pixel_acc = 0
    total_dice = 0
    num_batches = 0
    class_iou = np.zeros(num_classes)

    with torch.no_grad():
        for i, (images, labels) in enumerate(tqdm(dataloader, desc="Testing")):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            outputs = outputs['out'] if isinstance(outputs, dict) else outputs

            if len(labels.shape) == 4 and labels.shape[1] > 1:
                labels = torch.argmax(labels, dim=1)

            if outputs.shape[2:] != labels.shape[1:]:
                outputs = F.interpolate(outputs, size=labels.shape[1:], mode='bilinear', align_corners=True)

            pred = torch.argmax(outputs, dim=1)

            miou, class_iou_batch = calculate_miou(pred.cpu().numpy(), labels.cpu().numpy(), num_classes)
            pixel_acc = calculate_pixel_accuracy(pred.cpu().numpy(), labels.cpu().numpy())
            dice = calculate_dice_coefficient(pred.cpu().numpy(), labels.cpu().numpy(), num_classes)

            total_miou += miou
            total_pixel_acc += pixel_acc
            total_dice += dice
            class_iou += class_iou_batch

            if i in [0, 10, 20, 30]:  # Save images 0, 10, 20, and 30
                save_comparison(images[0], labels[0], pred[0], i, save_dir)
                logger.info(f"Saved comparison for image {i}")

    num_batches = len(dataloader)
    avg_miou = total_miou / num_batches
    avg_pixel_acc = total_pixel_acc / num_batches
    avg_dice = total_dice / num_batches
    avg_class_iou = class_iou / num_batches

    return avg_miou, avg_pixel_acc, avg_dice, avg_class_iou
            
def normalize_image(image):
    """Normalize image tensor to [0, 1] range."""
    image = image.cpu().numpy()
    image = (image - image.min()) / (image.max() - image.min())
    return np.clip(image, 0, 1)

def save_comparison(image, ground_truth, prediction, index, save_dir):
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
    
    # Normalize and convert image to correct format
    image_np = normalize_image(image)
    image_np = np.transpose(image_np, (1, 2, 0))  # Change from (C, H, W) to (H, W, C)
    
    # Original image
    ax1.imshow(image_np)
    ax1.set_title("Original Image")
    ax1.axis('off')
    
    # Ground truth
    ax2.imshow(ground_truth.cpu().numpy(), cmap='viridis')
    ax2.set_title("Ground Truth")
    ax2.axis('off')
    
    # Prediction
    ax3.imshow(prediction.cpu().numpy(), cmap='viridis')
    ax3.set_title("Prediction")
    ax3.axis('off')
    
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, f'comparison_{index}.png'))
    plt.close()

if __name__ == "__main__":
    save_dir = 'prediction/Unet'
    log_file = os.path.join(save_dir, 'testing.log')
    
    os.makedirs(save_dir, exist_ok=True)
    
    logger = setup_logger(log_file)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    test_loader = EnhancedWildScenesDataset.get_data_loader('test', batch_size=8)

    num_classes = 17

    model = UNet(3, num_classes).to(device)
    
    best_model_path = os.path.join('model_checkpoints', 'Unet', 'best_model_epoch_30.pth')
    
    if os.path.exists(best_model_path):
        checkpoint = torch.load(best_model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        logger.info(f"Loaded model from {best_model_path}")
    else:
        logger.error(f"Model file not found at {best_model_path}")
        exit(1)

    miou, pixel_acc, dice, class_iou = test(model, test_loader, device, num_classes, save_dir, logger)
    
    logger.info(f"Test Results:")
    logger.info(f"Mean IoU: {miou:.4f}")
    logger.info(f"Pixel Accuracy: {pixel_acc:.4f}")
    logger.info(f"Dice Coefficient: {dice:.4f}")
    
    # 打印每个类别的IoU
    for i, iou in enumerate(class_iou):
        logger.info(f"Class {i} IoU: {iou:.4f}")

    # 可视化每个类别的IoU
    plt.figure(figsize=(12, 6))
    plt.bar(range(num_classes), class_iou)
    plt.title('Per-class IoU')
    plt.xlabel('Class')
    plt.ylabel('IoU')
    plt.savefig(os.path.join(save_dir, 'per_class_iou.png'))
    plt.close()

2024-07-22 05:20:29,763 - INFO - Using device: cuda


2024-07-22 05:20:30,344 - INFO - Loaded model from model_checkpoints/Unet/best_model_epoch_30.pth
Testing:   0%|          | 0/35 [00:00<?, ?it/s]2024-07-22 05:20:37,779 - INFO - Saved comparison for image 0
Testing:  29%|██▊       | 10/35 [00:57<02:19,  5.59s/it]2024-07-22 05:21:33,673 - INFO - Saved comparison for image 10
Testing:  57%|█████▋    | 20/35 [01:53<01:23,  5.57s/it]2024-07-22 05:22:29,755 - INFO - Saved comparison for image 20
Testing:  86%|████████▌ | 30/35 [02:49<00:27,  5.59s/it]2024-07-22 05:23:26,069 - INFO - Saved comparison for image 30
Testing: 100%|██████████| 35/35 [03:18<00:00,  5.66s/it]
2024-07-22 05:23:48,441 - INFO - Test Results:
2024-07-22 05:23:48,442 - INFO - Mean IoU: 0.4426
2024-07-22 05:23:48,443 - INFO - Pixel Accuracy: 0.8469
2024-07-22 05:23:48,444 - INFO - Dice Coefficient: 0.7168
2024-07-22 05:23:48,444 - INFO - Class 0 IoU: 0.7873
2024-07-22 05:23:48,445 - INFO - Class 1 IoU: nan
2024-07-22 05:23:48,446 - INFO - Class 2 IoU: nan
2024-07-22 05:2

In [13]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from tqdm import tqdm
import seaborn as sns
from data_load import EnhancedWildScenesDataset
from torchvision import models
from models.unet import UNet
from sklearn.metrics import confusion_matrix
import csv
import pandas as pd

# 定义颜色映射
color_map = {
    1: [224, 31, 77],  # Bush
    0: [64, 180, 78],  # Dirt
    2: [26, 127, 127],  # Fence
    14: [127, 127, 127],  # Grass
    3: [145, 24, 178],  # Gravel
    13: [125, 128, 16],  # Log
    12: [251, 225, 48],  # Mud
    7: [248, 190, 190],  # Other-object
    8: [89, 239, 239],  # Other-terrain
    9: [173, 255, 196],  # Rock
    16: [19, 0, 126],  # Sky
    11: [167, 110, 44],  # Structure
    6: [208, 245, 71],  # Tree-foliage
    5: [238, 47, 227],  # Tree-trunk
    4: [40, 127, 198],  # Water
    15: [0, 0, 0],  # 背景类（黑色）
    10: [128, 128, 128],  # 忽略类（灰色）
}

class_names = ['Dirt', 'Bush', 'Fence', 'Gravel', 'Water', 'Tree-trunk', 'Tree-Foliage', 'Other-object',
               'Other-terrain', 'Rock', 'Ignore', 'Structure', 'Mud', 'Log', 'Grass', 'Background', 'Sky']


def apply_color_map(segmentation):
    color_segmentation = np.zeros((*segmentation.shape, 3), dtype=np.uint8)
    for class_idx, color in color_map.items():
        color_segmentation[segmentation == class_idx] = color
    return color_segmentation


def overlay_segmentation(image, segmentation, alpha=0.5):
    colored_seg = apply_color_map(segmentation)
    return (image * (1 - alpha) + colored_seg * alpha).astype(np.uint8)


def save_comparison(image, ground_truth, prediction, index, save_dir):
    fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(20, 5))

    ax1.imshow(image.permute(1, 2, 0).cpu().numpy())
    ax1.set_title('Original Image')
    ax1.axis('off')

    ax2.imshow(apply_color_map(ground_truth.cpu().numpy()))
    ax2.set_title('Ground Truth')
    ax2.axis('off')

    ax3.imshow(apply_color_map(prediction.cpu().numpy()))
    ax3.set_title('Prediction')
    ax3.axis('off')

    ax4.imshow(overlay_segmentation(image.permute(1, 2, 0).cpu().numpy(), prediction.cpu().numpy()))
    ax4.set_title('Overlay')
    ax4.axis('off')

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, f'comparison_{index}.png'))
    plt.close()


def plot_confusion_matrix(y_true, y_pred, save_path):
    cm = confusion_matrix(y_true, y_pred)
    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.savefig(save_path)
    plt.close()


def plot_per_class_iou(class_iou, save_path):
    plt.figure(figsize=(12, 6))
    valid_iou = [iou for iou in class_iou if not np.isnan(iou)]
    plt.bar(range(len(valid_iou)), valid_iou)
    plt.title('Per-class IoU')
    plt.xlabel('Class')
    plt.ylabel('IoU')
    valid_class_names = [name for i, name in enumerate(class_names) if i < len(valid_iou)]
    plt.xticks(range(len(valid_iou)), valid_class_names, rotation=90)
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()


def calculate_miou(pred, target, num_classes):
    ious = []
    pred = pred.ravel()
    target = target.ravel()
    for cls in range(num_classes):
        pred_inds = pred == cls
        target_inds = target == cls
        intersection = np.logical_and(pred_inds, target_inds).sum()
        union = np.logical_or(pred_inds, target_inds).sum()
        if union == 0:
            ious.append(float('nan'))  # 如果该类别不存在，则IoU为NaN
        else:
            ious.append(intersection / union)
    miou = np.nanmean(ious)  # 忽略NaN值计算平均IoU
    return miou, np.array(ious)


def visualize_results(model, dataloader, device, num_classes, save_dir):
    model.eval()
    all_class_ious = np.zeros(num_classes)
    class_counts = np.zeros(num_classes)
    image_ious = []

    with torch.no_grad():
        for i, (images, labels) in enumerate(tqdm(dataloader, desc="Visualizing")):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            if len(labels.shape) == 4 and labels.shape[1] > 1:
                labels = torch.argmax(labels, dim=1)

            if outputs.shape[2:] != labels.shape[1:]:
                outputs = F.interpolate(outputs, size=labels.shape[1:], mode='bilinear', align_corners=True)

            pred = torch.argmax(outputs, dim=1)

            # 计算每张图片的IoU
            for j in range(images.shape[0]):
                image_index = i * dataloader.batch_size + j
                _, class_ious = calculate_miou(pred[j].cpu().numpy(), labels[j].cpu().numpy(), num_classes)
                all_class_ious += np.nan_to_num(class_ious)
                class_counts += ~np.isnan(class_ious)
                image_ious.append(class_ious)

                # 只保存特定索引的图片
                if image_index in [0, 10, 20, 30]:
                    save_comparison(images[j], labels[j], pred[j], image_index, save_dir)

    # 计算平均IoU
    average_ious = np.where(class_counts > 0, all_class_ious / class_counts, 0)

    print("Class IoUs:")
    for cls, iou in enumerate(average_ious):
        print(f"Class {cls} ({class_names[cls]}): {iou:.4f}")

    # 绘制每类IoU图
    plot_per_class_iou(average_ious, os.path.join(save_dir, 'per_class_iou.png'))

    # 计算mIoU
    miou = np.mean(average_ious)
    print(f"Mean IoU: {miou:.4f}")

    # 创建并保存CSV文件
    csv_file = os.path.join(save_dir, 'image_class_ious.csv')
    df = pd.DataFrame(image_ious, columns=[f"{cls}_{name}" for cls, name in enumerate(class_names)])
    df.index.name = 'Image_Number'
    df.to_csv(csv_file)
    print(f"Image-wise class IoUs saved to {csv_file}")

    return average_ious, miou

if __name__ == "__main__":
    save_dir = 'visualization_results/Unet'
    os.makedirs(save_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_classes = 17

    # 创建UNet模型实例
    model = UNet(3, num_classes).to(device)

    # 加载模型权重
    model_path = os.path.join('model_checkpoints', 'Unet', 'best_model_epoch_30.pth')
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    model.eval()

    test_loader = EnhancedWildScenesDataset.get_data_loader('test', batch_size=8)

    # 运行可视化并获取结果
    class_ious, miou = visualize_results(model, test_loader, device, num_classes, save_dir)

    # 保存结果到文件
    results_file = os.path.join(save_dir, 'iou_results.txt')
    with open(results_file, 'w') as f:
        f.write(f"Mean IoU: {miou:.4f}\n\n")
        f.write("Class IoUs:\n")
        for cls, iou in enumerate(class_ious):
            f.write(f"Class {cls} ({class_names[cls]}): {iou:.4f}\n")

    print(f"Visualization results and IoU statistics saved in {save_dir}")

Visualizing: 100%|██████████| 35/35 [03:10<00:00,  5.44s/it]

Class IoUs:
Class 0 (Dirt): 0.6739
Class 1 (Bush): 0.0000
Class 2 (Fence): 0.4713
Class 3 (Gravel): 0.0244
Class 4 (Water): 0.0000
Class 5 (Tree-trunk): 0.5130
Class 6 (Tree-Foliage): 0.8327
Class 7 (Other-object): 0.0966
Class 8 (Other-terrain): 0.0000
Class 9 (Rock): 0.0000
Class 10 (Ignore): 0.0000
Class 11 (Structure): 0.2700
Class 12 (Mud): 0.1508
Class 13 (Log): 0.0000
Class 14 (Grass): 0.6017
Class 15 (Background): 0.0000
Class 16 (Sky): 0.5573
Mean IoU: 0.2466
Image-wise class IoUs saved to visualization_results/Unet/image_class_ious.csv
Visualization results and IoU statistics saved in visualization_results/Unet



  average_ious = np.where(class_counts > 0, all_class_ious / class_counts, 0)
