In [10]:
import pandas as pd
import os
import sklearn
from PIL import Image
import numpy as np
import logging

In [11]:
class WildScenesDataset:
    _data_list_dir = os.path.join('datasets', 'data_list')
    _csv_files = {
        'train': os.path.join(_data_list_dir, 'train.csv'),
        'valid': os.path.join(_data_list_dir, 'valid.csv'),
        'test': os.path.join(_data_list_dir, 'test.csv'),
    }
    csv = _csv_files
    _label_to_trainid = {
        0: 15,  # 背景类
        1: 16,  # 忽略类
        2: 0,  # Bush
        3: 1,  # Dirt
        4: 2,  # Fence
        5: 3,  # Grass
        6: 4,  # Gravel
        7: 5,  # Log
        8: 6,  # Mud
        9: 7,  # Other-Object
        10: 8,  # Other-terrain
        11: 16,  # 忽略类
        12: 9,  # Rock
        13: 10,  # Sky
        14: 11,  # Structure
        15: 12,  # Tree-foliage
        16: 13,  # Tree-trunk
        17: 16,  # 忽略类
        18: 14,  # Water
    }
    
    @staticmethod
    def ensure_dir(directory):
        """确保目录存在，如果不存在则创建"""
        if not os.path.exists(directory):
            os.makedirs(directory)
            
    def __init__(self, dataset_type, transform=None):
        assert dataset_type in ('train', 'valid', 'test')
        self._dataset_type = dataset_type
        self._data_frame = pd.read_csv(WildScenesDataset._csv_files[self._dataset_type])
        self._transform = transform

    def __len__(self):
        return len(self._data_frame)

    def __getitem__(self, index):
        if index >= len(self):
            raise IndexError(f"Index {index} out of bounds for dataset of size {len(self)}")
        try:
            image_path = self._data_frame['image'].iloc[index]
            label_path = self._data_frame['label'].iloc[index]

            image = Image.open(image_path).convert('RGB')
            label = Image.open(label_path).convert('L')
            # 将灰度图转换为Numpy数组
            label_np = np.array(label)

            # 将标签索引映射到训练标识（trainId）值
            label_trainId = np.vectorize(lambda x: self._label_to_trainid.get(x, 255))(label_np)

            if self._transform is not None:
                for t in self._transform:
                    image, label_trainId = t(image, label_trainId)

            return image, label_trainId
        except Exception as e:
            logging.error(f"Error loading item at index {index}: {str(e)}")

    @staticmethod
    def _get_image_label_dir():
        """
        Traverse server image and label directories and yield image and label paths.
        :return: Generator yielding (image path, label path)
        """
        data_err = 'data error. check!'
        image_base = WildScenesDataset.image_file_base
        label_base = WildScenesDataset.label_file_base

        for image in os.listdir(image_base):
            image_origin = os.path.join(image_base, image)
            image_label = os.path.join(label_base, image)

            if not (os.path.isfile(image_label) and
                    os.path.exists(image_label) and
                    os.path.isfile(image_label)):
                print(image_origin, image_label, data_err)  # Print error message and skip if paths are invalid
                continue

            yield image_origin, image_label

    @staticmethod
    def make_data_list(train_rate=0.7, valid_rate=0.2, shuffle=True):
        """
        Shuffle and generate data_list CSV files with image and label paths sorted by filename.
        :param train_rate: Training set ratio, default 0.7
        :param valid_rate: Validation set ratio, default 0.2
        :param shuffle: Whether to shuffle the dataset, default True
        :return: None
        """
        WildScenesDataset.ensure_dir(WildScenesDataset._data_list_dir)  # 确保目录存在

        g = WildScenesDataset._get_image_label_dir()  # Get generator
        abspaths = list(g)  # Convert generator to list

        # Create DataFrame with image and label paths
        df = pd.DataFrame(
            data=abspaths,
            columns=['image', 'label']
        )

        # Sort DataFrame by filename (assumed to be timestamp in a sortable format)
        df['timestamp'] = df['image'].apply(lambda x: int(os.path.splitext(os.path.basename(x))[0].split('-')[0]))
        df = df.sort_values(by='timestamp').reset_index(drop=True)

        if shuffle:
            df = sklearn.utils.shuffle(df)  # Shuffle dataframe if specified

        # Calculate sizes for train, valid, and test sets
        train_size = int(df.shape[0] * train_rate)
        valid_size = int(df.shape[0] * valid_rate)

        print('total: {:d} | train: {:d} | val: {:d} | test: {:d}'.format(
            df.shape[0], train_size, valid_size,
            df.shape[0] - train_size - valid_size))

        # Split dataframe into train, valid, and test sets
        df_train = df[0: train_size]
        df_valid = df[train_size: train_size + valid_size]
        df_test = df[train_size + valid_size:]

        # Save train, valid, and test sets to CSV files
        df_train[['image', 'label']].to_csv(os.path.join(WildScenesDataset.csv['train']), index=False)
        df_valid[['image', 'label']].to_csv(os.path.join(WildScenesDataset.csv['valid']), index=False)
        df_test[['image', 'label']].to_csv(os.path.join(WildScenesDataset.csv['test']), index=False)

    # 测试語義分割時用
    @staticmethod
    def test_label_mapping(label_path):
        """
        Test the label mapping for a single label image.

        :param label_path: Path to the label image
        :return: Tuple of original label numpy array and mapped trainId numpy array
        """
        # Open the label image and convert to numpy array
        label = Image.open(label_path).convert('L')
        label_np = np.array(label, dtype=np.uint8)

        # Map label indices to trainId values
        label_trainId = np.vectorize(lambda x: WildScenesDataset._label_to_trainid.get(x, 255))(label_np)

        return label_np, label_trainId


In [12]:
# Example usage
root_dir = os.path.join('..', 'WildScenes_Dataset-61gd5a0t-', 'data', 'WildScenes', 'WildScenes2d', 'V-01')
WildScenesDataset.image_file_base = os.path.join(root_dir, 'image')
WildScenesDataset.label_file_base = os.path.join(root_dir, 'indexLabel')
WildScenesDataset.make_data_list()

# Test label mapping for the first image in labelIndex
test_label_path = os.path.join(WildScenesDataset.label_file_base,
                               '1623370408-092005506.png')  # Replace with an actual image name
original_label, mapped_label = WildScenesDataset.test_label_mapping(test_label_path)

print("Original label array (shape: {}):".format(original_label.shape))
print(original_label)
print("\nMapped trainId array (shape: {}):".format(mapped_label.shape))
print(mapped_label)

# Optional: print unique values in each array
print("\nUnique values in original label array:", np.unique(original_label))
print("Unique values in mapped trainId array:", np.unique(mapped_label))

total: 1576 | train: 1103 | val: 315 | test: 158
Original label array (shape: (1512, 2016)):
[[ 8  8  8 ...  8  8  8]
 [ 8  8  8 ...  8  8  8]
 [ 8  8  8 ...  8  8  8]
 ...
 [18 18 18 ...  2  2  2]
 [18 18 18 ...  2  2  2]
 [18 18 18 ...  2  2  2]]

Mapped trainId array (shape: (1512, 2016)):
[[ 6  6  6 ...  6  6  6]
 [ 6  6  6 ...  6  6  6]
 [ 6  6  6 ...  6  6  6]
 ...
 [14 14 14 ...  0  0  0]
 [14 14 14 ...  0  0  0]
 [14 14 14 ...  0  0  0]]

Unique values in original label array: [ 2  7  8 14 15 16 17 18]
Unique values in mapped trainId array: [ 0  5  6 11 12 13 14 16]


In [13]:
import random
import torch
import numpy as np
from torchvision.transforms import functional as TF
import cv2
from PIL import Image

class TrainTransform:
    def __init__(self, size=256, gaussian_prob=0.5, gaussian_kernel=(5, 5), gaussian_sigma=(0.1, 2.0)):
        self.size = (size, size)  # Changed to tuple
        self.gaussian_prob = gaussian_prob
        self.gaussian_kernel = gaussian_kernel
        self.gaussian_sigma = gaussian_sigma

    def __call__(self, image, label):
        # Resize
        image = TF.resize(image, self.size)
        label = TF.resize(label, self.size, interpolation=TF.InterpolationMode.NEAREST)

        # Random horizontal flip
        if random.random() > 0.5:
            image = TF.hflip(image)
            label = TF.hflip(label)

        # Random rotation
        angle = random.uniform(-10, 10)
        image = TF.rotate(image, angle)
        label = TF.rotate(label, angle, interpolation=TF.InterpolationMode.NEAREST)

        # Random Gaussian blur
        if random.random() < self.gaussian_prob:
            sigma = random.uniform(self.gaussian_sigma[0], self.gaussian_sigma[1])
            image_np = np.array(image)
            image_np = cv2.GaussianBlur(image_np, self.gaussian_kernel, sigma)
            image = Image.fromarray(image_np)

        # Convert to tensor and normalize
        image = TF.to_tensor(image)
        image = TF.normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        label = torch.from_numpy(np.array(label)).long()

        # Verify shapes
        assert image.shape[0] == 3, f"Image should have 3 channels, got {image.shape[0]}"
        assert image.shape[1] == image.shape[2] == self.size[0], f"Image should be square with size {self.size[0]}, got shape {image.shape}"
        assert label.shape == image.shape[1:], f"Label shape {label.shape} doesn't match image shape {image.shape[1:]}"

        return image, label

class TestTransform:
    def __init__(self, size=256):
        self.size = (size, size)  # Changed to tuple

    def __call__(self, image, label):
        # Resize
        image = TF.resize(image, self.size)
        label = TF.resize(label, self.size, interpolation=TF.InterpolationMode.NEAREST)

        # Convert to tensor and normalize
        image = TF.to_tensor(image)
        image = TF.normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        label = torch.from_numpy(np.array(label)).long()

        # Verify shapes
        assert image.shape[0] == 3, f"Image should have 3 channels, got {image.shape[0]}"
        assert image.shape[1] == image.shape[2] == self.size[0], f"Image should be square with size {self.size[0]}, got shape {image.shape}"
        assert label.shape == image.shape[1:], f"Label shape {label.shape} doesn't match image shape {image.shape[1:]}"

        return image, label

In [14]:
import os
import sys
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from data_split import WildScenesDataset
from utils.transforms import TrainTransform, TestTransform

color_map = {
    0: [224, 31, 77],  # Bush
    1: [64, 180, 78],  # Dirt
    2: [26, 127, 127],  # Fence
    3: [127, 127, 127],  # Grass
    4: [145, 24, 178],  # Gravel
    5: [125, 128, 16],  # Log
    6: [251, 225, 48],  # Mud
    7: [248, 190, 190],  # Other-object
    8: [89, 239, 239],  # Other-terrain
    9: [173, 255, 196],  # Rock
    10: [19, 0, 126],  # Sky
    11: [167, 110, 44],  # Structure
    12: [208, 245, 71],  # Tree-foliage
    13: [238, 47, 227],  # Tree-trunk
    14: [40, 127, 198],  # Water
    15: [0, 0, 0],      # 背景类（黑色）
    16: [128, 128, 128],  # 忽略类（灰色）
}

class EnhancedWildScenesDataset(WildScenesDataset):
    def __init__(self, dataset_type, transform=None):
        super().__init__(dataset_type, transform)
        self.color_map = self._load_color_map()
        self.transform = self._get_transform(dataset_type)

    def __getitem__(self, index):
        image_path = self._data_frame['image'][index]
        label_path = self._data_frame['label'][index]

        image = Image.open(image_path).convert('RGB')
        # Use test_label_mapping to get both original and mapped labels
        original_label, mapped_label = self.test_label_mapping(label_path)

        # Convert numpy array to PIL Image for compatibility with transforms
        label = Image.fromarray(mapped_label.astype(np.uint8))

        if self.transform is not None:
            image, label = self.transform(image, label)

        # Verify shapes
        # assert image.shape[0] == 3, f"Image should have 3 channels, got {image.shape[0]}"
        # assert image.shape[1] == image.shape[2], f"Image should be square, got shape {image.shape}"
        # assert label.shape == image.shape[1:], f"Label shape {label.shape} doesn't match image shape {image.shape[1:]}"

        return image, label

    def _load_color_map(self):
        return {key: np.array(value) for key, value in color_map.items()}

    def _get_transform(self, dataset_type):
        if dataset_type == 'train':
            return TrainTransform()
        elif dataset_type in ['valid', 'test']:
            return TestTransform()
        else:
            raise ValueError('Invalid dataset type')

    @staticmethod
    def get_color_coded_label(label_trainId):
        """
        Convert trainId label to RGB color-coded label.
        :param label_trainId: numpy array of trainId labels
        :return: numpy array of RGB color-coded labels
        """
        height, width = label_trainId.shape
        label_RGB = np.zeros((height, width, 3), dtype=np.uint8)
        for trainId, color in color_map.items():
            label_RGB[label_trainId == trainId] = color
        return label_RGB

def get_data_loader(dataset_type, batch_size=4):
    dataset = EnhancedWildScenesDataset(dataset_type)
    shuffle = dataset_type == 'train'
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=0, pin_memory=True, drop_last=True)

# 测试data_loader
train_loader = get_data_loader('train', batch_size=4)
for images, labels in train_loader:
    print(f"Batch image shape: {images.shape}") # 一個批次的圖像數據形狀，一個批次4個圖象，每個圖象3通道，每個圖像尺寸256*341
    print(f"Batch label shape: {labels.shape}") # 一個批次的label數據形狀，單通道，表示的是trainId標注的
    print(f"Batch label unique values: {torch.unique(labels)}")
    break

dataset = EnhancedWildScenesDataset('train')
image, label = dataset[0]
color_coded_label = dataset.get_color_coded_label(label.numpy())
print(f"Color-coded label shape: {color_coded_label.shape}") # 使用顔色編碼的label圖像形狀，發現是3通道，成了！

Batch image shape: torch.Size([4, 3, 256, 256])
Batch label shape: torch.Size([4, 256, 256])
Batch label unique values: tensor([ 0,  1,  2,  5,  6,  7, 12, 13, 14, 16])
Color-coded label shape: (256, 256, 3)


In [6]:
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR
from torch.cuda.amp import GradScaler, autocast
from data_load import EnhancedWildScenesDataset
from tqdm import tqdm
import numpy as np
import os
import logging
from utils.metrics import calculate_miou_train, calculate_pixel_accuracy, calculate_dice_coefficient
from utils.losses import CombinedLoss
from models.custom_deeplabv3 import CustomDeepLabV3
import torch.nn.functional as F

os.environ['CUDA_LAUNCH_BLOCKING'] = '1'


def train_epoch(model, dataloader, criterion, optimizer, scheduler, device, num_classes, scaler):
    model.train()
    total_loss = 0
    total_miou = 0
    total_pixel_acc = 0
    total_dice = 0
    num_batches = 0

    for images, labels in tqdm(dataloader, desc="Training"):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()

        with torch.cuda.amp.autocast():
            outputs = model(images)
            outputs = outputs['out'] if isinstance(outputs, dict) else outputs

            if len(labels.shape) == 4 and labels.shape[1] > 1:
                labels = torch.argmax(labels, dim=1)

            if outputs.shape[2:] != labels.shape[1:]:
                outputs = F.interpolate(outputs, size=labels.shape[1:], mode='bilinear', align_corners=True)

            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        scheduler.step()

        total_loss += loss.item()
        pred = torch.argmax(outputs, dim=1)
        miou = calculate_miou_train(pred.cpu().numpy(), labels.cpu().numpy(), num_classes)
        pixel_acc = calculate_pixel_accuracy(pred.cpu().numpy(), labels.cpu().numpy())
        dice = calculate_dice_coefficient(pred.cpu().numpy(), labels.cpu().numpy(), num_classes)

        if not np.isnan(miou):
            total_miou += miou
            total_pixel_acc += pixel_acc
            total_dice += dice
            num_batches += 1

    return (total_loss / len(dataloader),
            total_miou / num_batches if num_batches > 0 else 0.0,
            total_pixel_acc / num_batches if num_batches > 0 else 0.0,
            total_dice / num_batches if num_batches > 0 else 0.0)


def validate_epoch(model, dataloader, criterion, device, num_classes):
    model.eval()
    total_loss = 0
    total_miou = 0
    total_pixel_acc = 0
    total_dice = 0

    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Validating"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            outputs = outputs['out'] if isinstance(outputs, dict) else outputs

            if len(labels.shape) == 4 and labels.shape[1] > 1:
                labels = torch.argmax(labels, dim=1)

            if outputs.shape[2:] != labels.shape[1:]:
                outputs = F.interpolate(outputs, size=labels.shape[1:], mode='bilinear', align_corners=True)

            loss = criterion(outputs, labels)

            total_loss += loss.item()
            pred = torch.argmax(outputs, dim=1)
            miou = calculate_miou_train(pred.cpu().numpy(), labels.cpu().numpy(), num_classes)
            pixel_acc = calculate_pixel_accuracy(pred.cpu().numpy(), labels.cpu().numpy())
            dice = calculate_dice_coefficient(pred.cpu().numpy(), labels.cpu().numpy(), num_classes)

            total_miou += miou
            total_pixel_acc += pixel_acc
            total_dice += dice

    num_batches = len(dataloader)
    return (total_loss / num_batches,
            total_miou / num_batches,
            total_pixel_acc / num_batches,
            total_dice / num_batches)


def save_checkpoint(model, optimizer, epoch, metrics, filename):
    # 保存检查点
    state = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
        'metrics': metrics
    }
    torch.save(state, filename)


def setup_logger(log_file):
    # 设置日志记录器
    logging.basicConfig(filename=log_file, level=logging.INFO,
                        format='%(asctime)s - %(levelname)s - %(message)s')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    console.setFormatter(formatter)
    logging.getLogger('').addHandler(console)


def train(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, device, save_dir, num_classes):
    # 训练模型的主函数
    best_miou = 0
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    scaler = GradScaler()

    for epoch in range(num_epochs):
        current_epoch = epoch + 1
        logging.info(f"Epoch {current_epoch}/{num_epochs}")

        train_loss, train_miou, train_pixel_acc, train_dice = train_epoch(
            model, train_loader, criterion, optimizer, scheduler, device, num_classes, scaler)
        logging.info(f"Epoch {current_epoch} - Train Loss: {train_loss:.4f}, Train mIoU: {train_miou:.4f}, "
                     f"Train Pixel Acc: {train_pixel_acc:.4f}, Train Dice: {train_dice:.4f}")

        val_loss, val_miou, val_pixel_acc, val_dice = validate_epoch(
            model, val_loader, criterion, device, num_classes)
        logging.info(f"Epoch {current_epoch} - Val Loss: {val_loss:.4f}, Val mIoU: {val_miou:.4f}, "
                     f"Val Pixel Acc: {val_pixel_acc:.4f}, Val Dice: {val_dice:.4f}")

        metrics = {
            'miou': val_miou,
            'pixel_acc': val_pixel_acc,
            'dice': val_dice
        }

        if val_miou > best_miou:
            best_miou = val_miou
            best_model_path = os.path.join(save_dir, f'best_model_epoch_{current_epoch}.pth')
            save_checkpoint(model, optimizer, current_epoch, metrics, best_model_path)
            logging.info(f"Epoch {current_epoch} - Best model saved with mIoU: {best_miou:.4f}")

        # if current_epoch % 5 == 0:
        #     checkpoint_path = os.path.join(save_dir, f'checkpoint_epoch_{current_epoch}.pth')
        #     save_checkpoint(model, optimizer, current_epoch, metrics, checkpoint_path)
        #     logging.info(f"Epoch {current_epoch} - Checkpoint saved")

        current_lr = optimizer.param_groups[0]['lr']
        logging.info(f"Current learning rate: {current_lr:.6f}")

    logging.info(f"Training completed after {num_epochs} epochs.")
    return best_model_path


if __name__ == "__main__":
    # 创建保存目录
    save_dir = os.path.join('model_checkpoints', 'DeepLabV3 ResNet 101')
    os.makedirs(save_dir, exist_ok=True)

    # 设置日志文件路径
    log_file = os.path.join(save_dir, 'training.log')
    setup_logger(log_file)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging.info(f"Using device: {device}")

    train_loader = EnhancedWildScenesDataset.get_data_loader('train', batch_size=16)
    val_loader = EnhancedWildScenesDataset.get_data_loader('valid', batch_size=16)

    # test_loader = EnhancedWildScenesDataset.get_data_loader('test', batch_size=8)

    num_classes = 17

    # 选择模型
    model = CustomDeepLabV3(num_classes=num_classes).to(device)

    # 选择损失函数
    # criterion = FocalLoss(alpha=1, gamma=2)
    criterion = CombinedLoss(weight_focal=1.0, weight_dice=0.5)

    # 选择优化器
    # optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0001, nesterov=True)
    optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)

    num_epochs = 60
    steps_per_epoch = len(train_loader)
    scheduler = OneCycleLR(
        optimizer,
        max_lr=0.01,
        steps_per_epoch=steps_per_epoch,
        epochs=num_epochs,
        pct_start=0.3,
        anneal_strategy='cos',
        div_factor=25,
        final_div_factor=1000
    )

    best_model_path = train(model, train_loader, val_loader, criterion, optimizer, scheduler,
                            num_epochs, device, save_dir, num_classes)

    logging.info("Training and prediction completed!")


2024-07-15 14:19:29,035 - INFO - Using device: cuda
2024-07-15 14:19:29,624 - INFO - Epoch 1/60
Training: 100%|██████████████████████████████████████████████████████████████████████| 137/137 [07:10<00:00,  3.15s/it]
2024-07-15 14:26:40,590 - INFO - Epoch 1 - Train Loss: 0.8442, Train mIoU: 0.5946, Train Pixel Acc: 0.7653, Train Dice: 0.6317
Validating: 100%|██████████████████████████████████████████████████████████████████████| 39/39 [01:58<00:00,  3.03s/it]
2024-07-15 14:28:38,886 - INFO - Epoch 1 - Val Loss: 0.6269, Val mIoU: 0.6499, Val Pixel Acc: 0.8248, Val Dice: 0.6840
2024-07-15 14:28:39,423 - INFO - Epoch 1 - Best model saved with mIoU: 0.6499
2024-07-15 14:28:39,423 - INFO - Current learning rate: 0.000473
2024-07-15 14:28:39,424 - INFO - Epoch 2/60
Training: 100%|██████████████████████████████████████████████████████████████████████| 137/137 [07:09<00:00,  3.13s/it]
2024-07-15 14:35:48,867 - INFO - Epoch 2 - Train Loss: 0.6338, Train mIoU: 0.6359, Train Pixel Acc: 0.8139, Tra

In [2]:
import os
import logging
import torch
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from data_load import EnhancedWildScenesDataset
from models.custom_deeplabv3 import CustomDeepLabV3
from utils.metrics import calculate_miou, calculate_pixel_accuracy, calculate_dice_coefficient
import torch.nn.functional as F


def setup_logger(log_file):
    os.makedirs(os.path.dirname(log_file), exist_ok=True)
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s - %(levelname)s - %(message)s',
                        handlers=[
                            logging.FileHandler(log_file),
                            logging.StreamHandler()
                        ])


def save_comparison(image, ground_truth, prediction, index, save_dir):
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))

    ax1.imshow(image.permute(1, 2, 0).cpu().numpy())
    ax1.set_title('Original Image')
    ax1.axis('off')

    ax2.imshow(ground_truth.cpu().numpy())
    ax2.set_title('Ground Truth')
    ax2.axis('off')

    ax3.imshow(prediction.cpu().numpy())
    ax3.set_title('Prediction')
    ax3.axis('off')

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, f'comparison_{index}.png'))
    plt.close()


def test(model, dataloader, device, num_classes, save_dir):
    model.eval()
    total_miou = 0
    total_pixel_acc = 0
    total_dice = 0
    class_iou = np.zeros(num_classes)

    with torch.no_grad():
        for i, (images, labels) in enumerate(tqdm(dataloader, desc="Testing")):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            outputs = outputs['out'] if isinstance(outputs, dict) else outputs

            if len(labels.shape) == 4 and labels.shape[1] > 1:
                labels = torch.argmax(labels, dim=1)

            if outputs.shape[2:] != labels.shape[1:]:
                outputs = F.interpolate(outputs, size=labels.shape[1:], mode='bilinear', align_corners=True)

            pred = torch.argmax(outputs, dim=1)

            miou, class_iou_batch = calculate_miou(pred.cpu().numpy(), labels.cpu().numpy(), num_classes)
            pixel_acc = calculate_pixel_accuracy(pred.cpu().numpy(), labels.cpu().numpy())
            dice = calculate_dice_coefficient(pred.cpu().numpy(), labels.cpu().numpy(), num_classes)

            total_miou += miou
            total_pixel_acc += pixel_acc
            total_dice += dice
            class_iou += class_iou_batch

            if i % 10 == 0:  # Save every 10th image
                save_comparison(images[0], labels[0], pred[0], i, save_dir)

    num_batches = len(dataloader)
    avg_miou = total_miou / num_batches
    avg_pixel_acc = total_pixel_acc / num_batches
    avg_dice = total_dice / num_batches
    avg_class_iou = class_iou / num_batches

    return avg_miou, avg_pixel_acc, avg_dice, avg_class_iou


if __name__ == "__main__":
    save_dir = 'prediction/DeepLabV3_Resnet101'
    log_file = os.path.join(save_dir, 'testing.log')

    os.makedirs(save_dir, exist_ok=True)

    setup_logger(log_file)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging.info(f"Using device: {device}")

    test_loader = EnhancedWildScenesDataset.get_data_loader('test', batch_size=8)

    num_classes = 17

    model = CustomDeepLabV3(num_classes=num_classes).to(device)

    best_model_path = os.path.join('model_checkpoints', 'DeepLabV3 ResNet 101', 'best_model_epoch_20.pth')

    if os.path.exists(best_model_path):
        checkpoint = torch.load(best_model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        logging.info(f"Loaded model from {best_model_path}")
    else:
        logging.error(f"Model file not found: {best_model_path}")
        exit(1)

    miou, pixel_acc, dice, class_iou = test(model, test_loader, device, num_classes, save_dir)

    logging.info(f"Test Results:")
    logging.info(f"Mean IoU: {miou:.4f}")
    logging.info(f"Pixel Accuracy: {pixel_acc:.4f}")
    logging.info(f"Dice Coefficient: {dice:.4f}")

    logging.info("Per-class IoU:")
    for i, iou in enumerate(class_iou):
        logging.info(f"Class {i}: {iou:.4f}")

    # 可视化每个类别的IoU
    plt.figure(figsize=(12, 6))
    plt.bar(range(num_classes), class_iou)
    plt.title('Per-class IoU')
    plt.xlabel('Class')
    plt.ylabel('IoU')
    plt.savefig(os.path.join(save_dir, 'per_class_iou.png'))
    plt.close()

2024-07-22 02:58:18,353 - INFO - Using device: cuda


2024-07-22 02:58:20,109 - INFO - Loaded model from model_checkpoints/DeepLabV3 ResNet 101/best_model_epoch_20.pth
Testing: 100%|██████████| 35/35 [03:27<00:00,  5.94s/it]
2024-07-22 03:01:47,972 - INFO - Test Results:
2024-07-22 03:01:47,973 - INFO - Mean IoU: 0.4199
2024-07-22 03:01:47,974 - INFO - Pixel Accuracy: 0.8244
2024-07-22 03:01:47,975 - INFO - Dice Coefficient: 0.7097
2024-07-22 03:01:47,976 - INFO - Per-class IoU:
2024-07-22 03:01:47,977 - INFO - Class 0: 0.8110
2024-07-22 03:01:47,978 - INFO - Class 1: nan
2024-07-22 03:01:47,978 - INFO - Class 2: nan
2024-07-22 03:01:47,979 - INFO - Class 3: nan
2024-07-22 03:01:47,982 - INFO - Class 4: nan
2024-07-22 03:01:47,983 - INFO - Class 5: 0.4309
2024-07-22 03:01:47,984 - INFO - Class 6: 0.7979
2024-07-22 03:01:47,985 - INFO - Class 7: nan
2024-07-22 03:01:47,985 - INFO - Class 8: nan
2024-07-22 03:01:47,988 - INFO - Class 9: nan
2024-07-22 03:01:47,988 - INFO - Class 10: nan
2024-07-22 03:01:47,989 - INFO - Class 11: nan
2024-07

In [3]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from tqdm import tqdm
import seaborn as sns
from data_load import EnhancedWildScenesDataset
from models.custom_deeplabv3 import CustomDeepLabV3
from sklearn.metrics import confusion_matrix
import csv
import pandas as pd

# 定义颜色映射
color_map = {
    1: [224, 31, 77],  # Bush
    0: [64, 180, 78],  # Dirt
    2: [26, 127, 127],  # Fence
    14: [127, 127, 127],  # Grass
    3: [145, 24, 178],  # Gravel
    13: [125, 128, 16],  # Log
    12: [251, 225, 48],  # Mud
    7: [248, 190, 190],  # Other-object
    8: [89, 239, 239],  # Other-terrain
    9: [173, 255, 196],  # Rock
    16: [19, 0, 126],  # Sky
    11: [167, 110, 44],  # Structure
    6: [208, 245, 71],  # Tree-foliage
    5: [238, 47, 227],  # Tree-trunk
    4: [40, 127, 198],  # Water
    15: [0, 0, 0],  # 背景类（黑色）
    10: [128, 128, 128],  # 忽略类（灰色）
}

class_names = ['Dirt', 'Bush', 'Fence', 'Gravel', 'Water', 'Tree-trunk', 'Tree-Foliage', 'Other-object',
               'Other-terrain', 'Rock', 'Ignore', 'Structure', 'Mud', 'Log', 'Grass', 'Background', 'Sky']


def apply_color_map(segmentation):
    color_segmentation = np.zeros((*segmentation.shape, 3), dtype=np.uint8)
    for class_idx, color in color_map.items():
        color_segmentation[segmentation == class_idx] = color
    return color_segmentation


def overlay_segmentation(image, segmentation, alpha=0.5):
    colored_seg = apply_color_map(segmentation)
    return (image * (1 - alpha) + colored_seg * alpha).astype(np.uint8)


def save_comparison(image, ground_truth, prediction, index, save_dir):
    fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(20, 5))

    ax1.imshow(image.permute(1, 2, 0).cpu().numpy())
    ax1.set_title('Original Image')
    ax1.axis('off')

    ax2.imshow(apply_color_map(ground_truth.cpu().numpy()))
    ax2.set_title('Ground Truth')
    ax2.axis('off')

    ax3.imshow(apply_color_map(prediction.cpu().numpy()))
    ax3.set_title('Prediction')
    ax3.axis('off')

    ax4.imshow(overlay_segmentation(image.permute(1, 2, 0).cpu().numpy(), prediction.cpu().numpy()))
    ax4.set_title('Overlay')
    ax4.axis('off')

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, f'comparison_{index}.png'))
    plt.close()


def plot_confusion_matrix(y_true, y_pred, save_path):
    cm = confusion_matrix(y_true, y_pred)
    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.savefig(save_path)
    plt.close()


def plot_per_class_iou(class_iou, save_path):
    plt.figure(figsize=(12, 6))
    valid_iou = [iou for iou in class_iou if not np.isnan(iou)]
    plt.bar(range(len(valid_iou)), valid_iou)
    plt.title('Per-class IoU')
    plt.xlabel('Class')
    plt.ylabel('IoU')
    valid_class_names = [name for i, name in enumerate(class_names) if i < len(valid_iou)]
    plt.xticks(range(len(valid_iou)), valid_class_names, rotation=90)
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()


def calculate_miou(pred, target, num_classes):
    ious = []
    pred = pred.ravel()
    target = target.ravel()
    for cls in range(num_classes):
        pred_inds = pred == cls
        target_inds = target == cls
        intersection = np.logical_and(pred_inds, target_inds).sum()
        union = np.logical_or(pred_inds, target_inds).sum()
        if union == 0:
            ious.append(float('nan'))  # 如果该类别不存在，则IoU为NaN
        else:
            ious.append(intersection / union)
    miou = np.nanmean(ious)  # 忽略NaN值计算平均IoU
    return miou, np.array(ious)


def visualize_results(model, dataloader, device, num_classes, save_dir):
    model.eval()
    all_class_ious = np.zeros(num_classes)
    class_counts = np.zeros(num_classes)
    image_ious = []

    with torch.no_grad():
        for i, (images, labels) in enumerate(tqdm(dataloader, desc="Visualizing")):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            outputs = outputs['out'] if isinstance(outputs, dict) else outputs

            if len(labels.shape) == 4 and labels.shape[1] > 1:
                labels = torch.argmax(labels, dim=1)

            if outputs.shape[2:] != labels.shape[1:]:
                outputs = F.interpolate(outputs, size=labels.shape[1:], mode='bilinear', align_corners=True)

            pred = torch.argmax(outputs, dim=1)

            # 计算每张图片的IoU
            for j in range(images.shape[0]):
                image_index = i * dataloader.batch_size + j
                _, class_ious = calculate_miou(pred[j].cpu().numpy(), labels[j].cpu().numpy(), num_classes)
                all_class_ious += np.nan_to_num(class_ious)  # 将NaN转换为0
                class_counts += ~np.isnan(class_ious)  # 统计非NaN值的数量
                image_ious.append(class_ious)

                # 保存每张图片的可视化结果
                save_comparison(images[j], labels[j], pred[j], image_index, save_dir)

    # 计算平均IoU
    average_ious = np.where(class_counts > 0, all_class_ious / class_counts, 0)

    print("Class IoUs:")
    for cls, iou in enumerate(average_ious):
        print(f"Class {cls} ({class_names[cls]}): {iou:.4f}")

    # 绘制每类IoU图
    plot_per_class_iou(average_ious, os.path.join(save_dir, 'per_class_iou.png'))

    # 计算mIoU
    miou = np.mean(average_ious)
    print(f"Mean IoU: {miou:.4f}")

    # 创建并保存CSV文件
    csv_file = os.path.join(save_dir, 'image_class_ious.csv')
    df = pd.DataFrame(image_ious, columns=[f"{cls}_{name}" for cls, name in enumerate(class_names)])
    df.index.name = 'Image_Number'
    df.to_csv(csv_file)
    print(f"Image-wise class IoUs saved to {csv_file}")

    return average_ious, miou


if __name__ == "__main__":
    # 设置保存目录
    save_dir = 'visualization_results/DeepLabV3_Resnet101'
    os.makedirs(save_dir, exist_ok=True)

    # 加载模型和数据
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_classes = 17

    # 创建模型实例
    model = CustomDeepLabV3(num_classes=num_classes).to(device)

    # 加载模型权重
    model_path = os.path.join('model_checkpoints', 'DeepLabV3 ResNet 101', 'best_model_epoch_20.pth')
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    # 确保模型处于评估模式
    model.eval()

    # 加载测试数据
    test_loader = EnhancedWildScenesDataset.get_data_loader('test', batch_size=8)

    # 打印一些样本数据
    for images, labels in test_loader:
        print("Sample labels shape:", labels.shape)
        print("Unique values in sample labels:", torch.unique(labels))
        break

    # 运行可视化并获取结果
    class_ious, miou = visualize_results(model, test_loader, device, num_classes, save_dir)

    # 保存结果到文件
    results_file = os.path.join(save_dir, 'iou_results.txt')
    with open(results_file, 'w') as f:
        f.write(f"Mean IoU: {miou:.4f}\n\n")
        f.write("Class IoUs:\n")
        for cls, iou in enumerate(class_ious):
            f.write(f"Class {cls} ({class_names[cls]}): {iou:.4f}\n")

    print(f"Visualization results and IoU statistics saved in {save_dir}")


Sample labels shape: torch.Size([8, 256, 256])
Unique values in sample labels: tensor([ 0,  1,  2,  5,  6,  7, 11, 12, 13, 14, 16])


Visualizing:   0%|          | 0/35 [00:00<?, ?it/s]Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Visualizing:   3%|▎         | 1/35 [00:09<05:31,  9.75s/it]Clipping i

Class IoUs:
Class 0 (Dirt): 0.7203
Class 1 (Bush): 0.0000
Class 2 (Fence): 0.4876
Class 3 (Gravel): 0.0921
Class 4 (Water): 0.0000
Class 5 (Tree-trunk): 0.3720
Class 6 (Tree-Foliage): 0.7955
Class 7 (Other-object): 0.0123
Class 8 (Other-terrain): 0.0000
Class 9 (Rock): 0.1165
Class 10 (Ignore): 0.0000
Class 11 (Structure): 0.2892
Class 12 (Mud): 0.1223
Class 13 (Log): 0.0530
Class 14 (Grass): 0.6002
Class 15 (Background): 0.0000
Class 16 (Sky): 0.4857
Mean IoU: 0.2439
Image-wise class IoUs saved to visualization_results/DeepLabV3_Resnet101/image_class_ious.csv
Visualization results and IoU statistics saved in visualization_results/DeepLabV3_Resnet101
