In [1]:
import pandas as pd
import os
import sklearn
from PIL import Image
import numpy as np
import logging

In [2]:
class WildScenesDataset:
    """Image/label pairs of the WildScenes 2D dataset, listed in per-split CSV files.

    The CSV files (columns ``image`` and ``label``) are written by
    :meth:`make_data_list` from the directories configured in the class
    attributes ``image_file_base`` and ``label_file_base``.
    """

    _data_list_dir = os.path.join('datasets', 'data_list')
    _csv_files = {
        'train': os.path.join(_data_list_dir, 'train.csv'),
        'valid': os.path.join(_data_list_dir, 'valid.csv'),
        'test': os.path.join(_data_list_dir, 'test.csv'),
    }
    csv = _csv_files

    # Directories holding the raw images and the index-label PNGs. They must be
    # assigned (e.g. ``WildScenesDataset.image_file_base = ...``) before
    # make_data_list() is called.
    image_file_base = None
    label_file_base = None

    # trainId given to raw label values that are missing from _label_to_trainid.
    _unmapped_trainid = 255

    # Raw label value (pixel value of the index-label PNG) -> trainId.
    _label_to_trainid = {
        1: 15,  # Ignored class
        2: 0,  # Bush
        3: 1,  # Dirt
        4: 2,  # Fence
        5: 3,  # Grass
        6: 4,  # Gravel
        7: 5,  # Log
        8: 6,  # Mud
        9: 7,  # Other-Object
        10: 8,  # Other-terrain
        11: 15,  # Ignored class
        12: 9,  # Rock
        13: 10,  # Sky
        14: 11,  # Structure
        15: 12,  # Tree-foliage
        16: 13,  # Tree-trunk
        17: 15,  # Ignored class
        18: 14,  # Water
    }

    def __init__(self, dataset_type, transform=None):
        """
        :param dataset_type: 'train', 'valid' or 'test'; selects the CSV file to read.
        :param transform: optional paired transform ``(image, label) -> (image, label)``,
            or an iterable of such transforms applied in order.
        :raises ValueError: for an unknown ``dataset_type``.
        """
        if dataset_type not in ('train', 'valid', 'test'):
            raise ValueError(f"dataset_type must be 'train', 'valid' or 'test', got {dataset_type!r}")
        self._dataset_type = dataset_type
        self._data_frame = pd.read_csv(WildScenesDataset._csv_files[self._dataset_type])
        self._transform = transform

    def __len__(self):
        return len(self._data_frame)

    def __getitem__(self, index):
        """Return ``(image, label_trainId)`` for row ``index`` of the split CSV.

        Before transforms, ``image`` is an RGB PIL image and ``label_trainId`` an
        int64 array of trainIds.

        :raises IndexError: if ``index`` lies outside the dataset.
        :raises Exception: any loading/transform error is logged and re-raised
            (returning None would only fail later, far from the cause).
        """
        size = len(self)
        if not -size <= index < size:
            raise IndexError(f"Index {index} out of bounds for dataset of size {size}")
        try:
            image_path = self._data_frame['image'].iloc[index]
            label_path = self._data_frame['label'].iloc[index]

            # Context managers close the underlying files once the data is converted.
            with Image.open(image_path) as image_file:
                image = image_file.convert('RGB')
            with Image.open(label_path) as label_file:
                # Grayscale label -> numpy array of raw label values.
                label_np = np.array(label_file.convert('L'))

            # Map raw label values to training ids (trainId).
            label_trainId = self._map_to_trainid(label_np)

            for t in self._transform_steps():
                image, label_trainId = t(image, label_trainId)

            return image, label_trainId
        except Exception as e:
            logging.error(f"Error loading item at index {index}: {str(e)}")
            raise

    def _transform_steps(self):
        """Return the configured transforms as an iterable (empty when none were given)."""
        transform = self._transform
        if transform is None:
            return ()
        try:
            return iter(transform)
        except TypeError:
            # A single paired transform (e.g. TrainTransform) is callable but not iterable.
            return (transform,)

    @classmethod
    def _map_to_trainid(cls, label_np):
        """Map raw label values to trainIds using a 256-entry lookup table.

        Equivalent to applying ``_label_to_trainid.get(x, 255)`` to every pixel,
        but vectorised instead of one Python call per pixel.

        :param label_np: integer array with values in [0, 255] (e.g. an 'L'-mode image).
        :return: int64 array of trainIds with the same shape as ``label_np``.
        :raises ValueError: for non-integer arrays or values outside [0, 255].
        """
        label_np = np.asarray(label_np)
        if not np.issubdtype(label_np.dtype, np.integer):
            raise ValueError(f"label array must have an integer dtype, got {label_np.dtype}")
        if label_np.size and (label_np.min() < 0 or label_np.max() > 255):
            raise ValueError("label values must lie in [0, 255]")
        lut = np.full(256, cls._unmapped_trainid, dtype=np.int64)
        for raw_value, train_id in cls._label_to_trainid.items():
            lut[raw_value] = train_id
        return lut[label_np]

    @staticmethod
    def _get_image_label_dir():
        """
        Traverse the image directory and yield (image path, label path) pairs.

        A label is matched to an image by identical file name. Entries whose
        image or label is not a regular file are reported and skipped.
        :return: Generator yielding (image path, label path), in file-name order
        :raises AttributeError: if image_file_base / label_file_base were not set.
        """
        data_err = 'data error. check!'
        image_base = WildScenesDataset.image_file_base
        label_base = WildScenesDataset.label_file_base
        if image_base is None or label_base is None:
            raise AttributeError(
                'Set WildScenesDataset.image_file_base and WildScenesDataset.label_file_base '
                'before building the data list.')

        for image in sorted(os.listdir(image_base)):
            image_origin = os.path.join(image_base, image)
            image_label = os.path.join(label_base, image)

            # Both sides must be regular files (a sub-directory is not an image).
            if not (os.path.isfile(image_origin) and os.path.isfile(image_label)):
                print(image_origin, image_label, data_err)  # Print error message and skip if paths are invalid
                continue

            yield image_origin, image_label

    @staticmethod
    def make_data_list(train_rate=0.7, valid_rate=0.2, shuffle=True, random_state=None):
        """
        Write the train/valid/test CSV files listing image and label paths.

        Pairs are sorted chronologically by the timestamp in the file name,
        optionally shuffled, then split in order; the test split receives the
        remaining ``1 - train_rate - valid_rate``.
        :param train_rate: Training set ratio, default 0.7
        :param valid_rate: Validation set ratio, default 0.2
        :param shuffle: Whether to shuffle the dataset, default True
        :param random_state: seed (or numpy RandomState) for the shuffle; None
            gives a different split on every call.
        :return: None
        :raises ValueError: if the ratios are negative or sum to more than 1.
        """
        if train_rate < 0 or valid_rate < 0 or train_rate + valid_rate > 1:
            raise ValueError(
                f"train_rate and valid_rate must be non-negative and sum to at most 1, "
                f"got {train_rate} and {valid_rate}")

        df = pd.DataFrame(
            data=list(WildScenesDataset._get_image_label_dir()),
            columns=['image', 'label'],
        )

        # Seconds part of '<seconds>-<fraction>.png'. Ties are broken by the path,
        # which orders the fractional part when it is zero-padded (as in
        # '1623370408-092005506.png') and makes the order deterministic regardless.
        df['timestamp'] = df['image'].apply(lambda x: int(os.path.splitext(os.path.basename(x))[0].split('-')[0]))
        df = df.sort_values(by=['timestamp', 'image']).reset_index(drop=True)

        if shuffle:
            df = df.sample(frac=1, random_state=random_state)

        train_size = int(df.shape[0] * train_rate)
        valid_size = int(df.shape[0] * valid_rate)

        print('total: {:d} | train: {:d} | val: {:d} | test: {:d}'.format(
            df.shape[0], train_size, valid_size,
            df.shape[0] - train_size - valid_size))

        splits = {
            'train': df.iloc[:train_size],
            'valid': df.iloc[train_size:train_size + valid_size],
            'test': df.iloc[train_size + valid_size:],
        }
        for split_name, split_df in splits.items():
            csv_path = WildScenesDataset.csv[split_name]
            # The data_list directory may not exist on a fresh checkout.
            os.makedirs(os.path.dirname(csv_path) or '.', exist_ok=True)
            split_df[['image', 'label']].to_csv(csv_path, index=False)

    # Helper for checking the semantic-segmentation label mapping.
    @staticmethod
    def test_label_mapping(label_path):
        """
        Test the label mapping for a single label image.

        :param label_path: Path to the label image
        :return: Tuple of original label numpy array (uint8) and mapped trainId numpy array (int64)
        """
        with Image.open(label_path) as label_file:
            label_np = np.array(label_file.convert('L'), dtype=np.uint8)

        label_trainId = WildScenesDataset._map_to_trainid(label_np)

        return label_np, label_trainId


In [3]:
# Example usage: build the CSV splits for sequence V-01, then inspect how one
# label image is remapped to trainIds.
root_dir = os.path.join('..', 'WildScenes_Dataset-61gd5a0t-', 'data', 'WildScenes', 'WildScenes2d', 'V-01')
WildScenesDataset.image_file_base = os.path.join(root_dir, 'image')
WildScenesDataset.label_file_base = os.path.join(root_dir, 'indexLabel')
WildScenesDataset.make_data_list()

# Inspect one label image from indexLabel. If the example frame is not part of
# the configured sequence, fall back to the first label file by name.
SAMPLE_LABEL_NAME = '1623370408-092005506.png'
test_label_path = os.path.join(WildScenesDataset.label_file_base, SAMPLE_LABEL_NAME)
if not os.path.isfile(test_label_path):
    label_files = sorted(
        name for name in os.listdir(WildScenesDataset.label_file_base)
        if os.path.isfile(os.path.join(WildScenesDataset.label_file_base, name))
    )
    if not label_files:
        raise FileNotFoundError(f"No label images found in {WildScenesDataset.label_file_base}")
    test_label_path = os.path.join(WildScenesDataset.label_file_base, label_files[0])
original_label, mapped_label = WildScenesDataset.test_label_mapping(test_label_path)

print("Original label array (shape: {}):".format(original_label.shape))
print(original_label)
print("\nMapped trainId array (shape: {}):".format(mapped_label.shape))
print(mapped_label)

# Unique values present before and after the mapping.
print("\nUnique values in original label array:", np.unique(original_label))
print("Unique values in mapped trainId array:", np.unique(mapped_label))

total: 1576 | train: 1103 | val: 315 | test: 158
Original label array (shape: (1512, 2016)):
[[ 8  8  8 ...  8  8  8]
 [ 8  8  8 ...  8  8  8]
 [ 8  8  8 ...  8  8  8]
 ...
 [18 18 18 ...  2  2  2]
 [18 18 18 ...  2  2  2]
 [18 18 18 ...  2  2  2]]

Mapped trainId array (shape: (1512, 2016)):
[[ 6  6  6 ...  6  6  6]
 [ 6  6  6 ...  6  6  6]
 [ 6  6  6 ...  6  6  6]
 ...
 [14 14 14 ...  0  0  0]
 [14 14 14 ...  0  0  0]
 [14 14 14 ...  0  0  0]]

Unique values in original label array: [ 2  7  8 14 15 16 17 18]
Unique values in mapped trainId array: [ 0  5  6 11 12 13 14 15]


In [10]:
import random
import torch
import numpy as np
from torchvision.transforms import functional as TF
import cv2
from PIL import Image

class TrainTransform:
    """Paired training augmentation for ``(image, label)``.

    Steps: resize, random horizontal flip, random rotation in [-10, 10] degrees,
    optional Gaussian blur (image only), then tensor conversion. The image is
    normalised with ImageNet mean/std and the label becomes a ``long`` tensor.
    Both inputs are expected to be PIL images (the label holding trainIds).
    """

    def __init__(self, size=256, gaussian_prob=0.5, gaussian_kernel=(5, 5), gaussian_sigma=(0.1, 2.0),
                 label_fill=15):
        """
        :param size: output size. An int gives ``(size, size)``; a pair is read as ``(height, width)``.
        :param gaussian_prob: probability of blurring the image.
        :param gaussian_kernel: ``cv2.GaussianBlur`` kernel size (cv2 requires positive odd values).
        :param gaussian_sigma: ``(low, high)`` range from which the blur sigma is drawn uniformly.
        :param label_fill: value written into label pixels that the rotation leaves
            uncovered. ``TF.rotate`` fills with 0 by default, which is trainId 0
            (Bush). 15 is the trainId the dataset assigns to its ignored classes.
        """
        if isinstance(size, int):
            self.size = (size, size)
        else:
            height, width = size
            self.size = (int(height), int(width))
        self.gaussian_prob = gaussian_prob
        self.gaussian_kernel = gaussian_kernel
        self.gaussian_sigma = gaussian_sigma
        self.label_fill = label_fill

    def __call__(self, image, label):
        """Augment one pair; returns ``(image tensor (3, H, W), label tensor (H, W))``."""
        # Resize (nearest neighbour keeps label ids discrete)
        image = TF.resize(image, self.size)
        label = TF.resize(label, self.size, interpolation=TF.InterpolationMode.NEAREST)

        # Random horizontal flip
        if random.random() > 0.5:
            image = TF.hflip(image)
            label = TF.hflip(label)

        # Random rotation. Label corners rotated in from outside get label_fill
        # instead of 0, so they are not labelled as a real class.
        angle = random.uniform(-10, 10)
        image = TF.rotate(image, angle)
        label = TF.rotate(label, angle, interpolation=TF.InterpolationMode.NEAREST, fill=self.label_fill)

        # Random Gaussian blur (image only)
        if random.random() < self.gaussian_prob:
            sigma = random.uniform(self.gaussian_sigma[0], self.gaussian_sigma[1])
            image_np = cv2.GaussianBlur(np.array(image), self.gaussian_kernel, sigma)
            image = Image.fromarray(image_np)

        # Convert to tensor and normalize
        image = TF.to_tensor(image)
        image = TF.normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        label = torch.from_numpy(np.array(label)).long()

        # Verify shapes
        assert image.shape[0] == 3, f"Image should have 3 channels, got {image.shape[0]}"
        assert tuple(image.shape[1:]) == self.size, f"Image should have size {self.size}, got shape {image.shape}"
        assert label.shape == image.shape[1:], f"Label shape {label.shape} doesn't match image shape {image.shape[1:]}"

        return image, label

class TestTransform:
    """Deterministic evaluation transform for ``(image, label)``: resize, then tensor conversion.

    The image is normalised with ImageNet mean/std and the label (a PIL image
    of trainIds) becomes a ``long`` tensor. No augmentation is applied.
    """

    def __init__(self, size=256):
        """:param size: output size. An int gives ``(size, size)``; a pair is read as ``(height, width)``."""
        if isinstance(size, int):
            self.size = (size, size)
        else:
            height, width = size
            self.size = (int(height), int(width))

    def __call__(self, image, label):
        """Transform one pair; returns ``(image tensor (3, H, W), label tensor (H, W))``."""
        # Resize (nearest neighbour keeps label ids discrete)
        image = TF.resize(image, self.size)
        label = TF.resize(label, self.size, interpolation=TF.InterpolationMode.NEAREST)

        # Convert to tensor and normalize
        image = TF.to_tensor(image)
        image = TF.normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        label = torch.from_numpy(np.array(label)).long()

        # Verify shapes
        assert image.shape[0] == 3, f"Image should have 3 channels, got {image.shape[0]}"
        assert tuple(image.shape[1:]) == self.size, f"Image should have size {self.size}, got shape {image.shape}"
        assert label.shape == image.shape[1:], f"Label shape {label.shape} doesn't match image shape {image.shape[1:]}"

        return image, label

In [9]:
import os
import sys
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image

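# Assumes the imports above are correct; adjust them if your setup differs.
# TrainTransform / TestTransform are defined in the transforms cell above. When this
# code runs as a standalone module, import them instead:
# from transforms import TrainTransform, TestTransform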

# RGB colour for each trainId, listed in trainId order (list index i -> trainId i).
# trainId 15 (the dataset's ignored classes) and 255 (unmapped values) have no entry.
color_map = dict(enumerate([
    [224, 31, 77],    # 0  Bush
    [64, 180, 78],    # 1  Dirt
    [26, 127, 127],   # 2  Fence
    [127, 127, 127],  # 3  Grass
    [145, 24, 178],   # 4  Gravel
    [125, 128, 16],   # 5  Log
    [251, 225, 48],   # 6  Mud
    [248, 190, 190],  # 7  Other-object
    [89, 239, 239],   # 8  Other-terrain
    [173, 255, 196],  # 9  Rock
    [19, 0, 126],     # 10 Sky
    [167, 110, 44],   # 11 Structure
    [208, 245, 71],   # 12 Tree-foliage
    [238, 47, 227],   # 13 Tree-trunk
    [40, 127, 198],   # 14 Water
]))

class EnhancedWildScenesDataset(WildScenesDataset):
    """WildScenesDataset variant that yields model-ready tensors.

    Labels are remapped to trainIds and wrapped in an 8-bit PIL image. Image and
    label then go through TrainTransform (``'train'``) or TestTransform
    (``'valid'``/``'test'``), unless an explicit ``transform`` is supplied.
    """

    def __init__(self, dataset_type, transform=None):
        """
        :param dataset_type: 'train', 'valid' or 'test'.
        :param transform: optional callable ``(image, label) -> (image, label)``.
            When None, the split's default transform is used.
        """
        super().__init__(dataset_type, transform)
        self.color_map = self._load_color_map()
        # Honour a caller-supplied transform; fall back to the split default otherwise.
        self.transform = transform if transform is not None else self._get_transform(dataset_type)

    def __getitem__(self, index):
        """Return the transformed ``(image, label)`` pair for row ``index``.

        :raises IndexError: if ``index`` lies outside the dataset.
        """
        size = len(self)
        if not -size <= index < size:
            raise IndexError(f"Index {index} out of bounds for dataset of size {size}")

        image_path = self._data_frame['image'].iloc[index]
        label_path = self._data_frame['label'].iloc[index]

        with Image.open(image_path) as image_file:
            image = image_file.convert('RGB')
        # Only the trainId map is needed here; the raw label array is discarded.
        _, mapped_label = self.test_label_mapping(label_path)

        # The transforms operate on PIL images, so wrap the trainIds as an 8-bit image.
        label = Image.fromarray(mapped_label.astype(np.uint8))

        if self.transform is not None:
            image, label = self.transform(image, label)

        return image, label

    def _load_color_map(self):
        """Return the module-level trainId -> RGB table with numpy-array values."""
        return {key: np.array(value) for key, value in color_map.items()}

    def _get_transform(self, dataset_type):
        """Default transform for a split: augmentation for 'train', plain resizing otherwise."""
        if dataset_type == 'train':
            return TrainTransform()
        elif dataset_type in ['valid', 'test']:
            return TestTransform()
        else:
            raise ValueError('Invalid dataset type')

    @classmethod
    def get_data_loader(cls, dataset_type, batch_size=4):
        """DataLoader for ``dataset_type``; delegates to the module-level get_data_loader."""
        return get_data_loader(dataset_type, batch_size=batch_size)

    @staticmethod
    def get_color_coded_label(label_trainId):
        """
        Convert a trainId label to an RGB color-coded label.
        :param label_trainId: array (or CPU tensor) of trainIds, of any shape
        :return: uint8 numpy array of shape ``label_trainId.shape + (3,)``. trainIds
            without an entry in color_map stay black.
        """
        label_trainId = np.asarray(label_trainId)
        label_RGB = np.zeros(label_trainId.shape + (3,), dtype=np.uint8)
        for trainId, color in color_map.items():
            label_RGB[label_trainId == trainId] = color
        return label_RGB

def get_data_loader(dataset_type, batch_size=4, num_workers=0):
    """Create a DataLoader over EnhancedWildScenesDataset for one split.

    The training split is shuffled and drops its last incomplete batch.
    Validation/test keep every sample so evaluation covers the whole split.

    :param dataset_type: 'train', 'valid' or 'test'.
    :param batch_size: samples per batch.
    :param num_workers: DataLoader worker processes (0 loads in the main process).
    :return: torch.utils.data.DataLoader
    """
    dataset = EnhancedWildScenesDataset(dataset_type)
    is_train = dataset_type == 'train'
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=is_train,
        num_workers=num_workers,
        # Pinned memory only speeds up host-to-GPU copies; skip it without CUDA.
        pin_memory=torch.cuda.is_available(),
        drop_last=is_train,
    )

# Smoke-test the data pipeline: fetch one training batch, then colour-code a label.
train_loader = get_data_loader('train', batch_size=4)
images, labels = next(iter(train_loader))
# images: (batch, 3, H, W) normalised float tensor; TrainTransform resizes to 256x256 by default.
print(f"Batch image shape: {images.shape}")
# labels: (batch, H, W) single-channel trainId maps.
print(f"Batch label shape: {labels.shape}")
print(f"Batch label unique values: {torch.unique(labels)}")

dataset = EnhancedWildScenesDataset('train')
image, label = dataset[0]
color_coded_label = dataset.get_color_coded_label(label.numpy())
# The colour-coded label is an (H, W, 3) RGB array.
print(f"Color-coded label shape: {color_coded_label.shape}")

Batch image shape: torch.Size([4, 3, 256, 256])
Batch label shape: torch.Size([4, 256, 256])
Batch label unique values: tensor([ 0,  3,  5,  6,  9, 12, 13, 14, 15])
Color-coded label shape: (256, 256, 3)


In [11]:
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR
from torch.cuda.amp import GradScaler, autocast
from data_load import EnhancedWildScenesDataset
from tqdm import tqdm
import numpy as np
import os
import logging
from utils.metrics import calculate_miou_train, calculate_pixel_accuracy, calculate_dice_coefficient
from utils.losses import CombinedLoss
from models.custom_deeplabv3 import CustomDeepLabV3
from models.dense_unet import DenseUNet, TransitionUp, DenseBlock
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F
# CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous, so device-side
# errors are reported at the failing call. It slows training considerably, so
# enable it only while debugging. It must be set before CUDA is initialised.
DEBUG_CUDA_SYNC = False
if DEBUG_CUDA_SYNC:
    os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

def train_epoch(model, dataloader, criterion, optimizer, scheduler, device, num_classes, scaler):
    """Run one training epoch with mixed precision (on CUDA devices).

    The scheduler is stepped once per batch, as OneCycleLR expects. Batches whose
    mIoU is NaN are left out of the metric averages but still count toward the
    mean loss.

    :param scaler: GradScaler used to scale the loss for the backward pass.
    :return: (mean loss, mean mIoU, mean pixel accuracy, mean Dice). Each mean is
        0.0 when no batch contributed to it (e.g. an empty dataloader).
    """
    model.train()
    device_type = torch.device(device).type
    total_loss = 0.0
    total_miou = 0.0
    total_pixel_acc = 0.0
    total_dice = 0.0
    num_loss_batches = 0
    num_metric_batches = 0

    for images, labels in tqdm(dataloader, desc="Training"):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()

        # Autocast only on CUDA; elsewhere run in full precision instead of
        # triggering CUDA-autocast warnings.
        with torch.autocast(device_type=device_type, enabled=device_type == 'cuda'):
            outputs = model(images)
            outputs = outputs['out'] if isinstance(outputs, dict) else outputs

            # (N, C, H, W) labels with C > 1 are collapsed to class indices.
            if len(labels.shape) == 4 and labels.shape[1] > 1:
                labels = torch.argmax(labels, dim=1)

            if outputs.shape[2:] != labels.shape[1:]:
                outputs = F.interpolate(outputs, size=labels.shape[1:], mode='bilinear', align_corners=True)

            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        scheduler.step()

        total_loss += loss.item()
        num_loss_batches += 1

        # Move predictions/labels to host once and reuse them for every metric.
        pred_np = torch.argmax(outputs, dim=1).cpu().numpy()
        labels_np = labels.cpu().numpy()
        miou = calculate_miou_train(pred_np, labels_np, num_classes)
        if np.isnan(miou):
            continue
        total_miou += miou
        total_pixel_acc += calculate_pixel_accuracy(pred_np, labels_np)
        total_dice += calculate_dice_coefficient(pred_np, labels_np, num_classes)
        num_metric_batches += 1

    return (total_loss / num_loss_batches if num_loss_batches > 0 else 0.0,
            total_miou / num_metric_batches if num_metric_batches > 0 else 0.0,
            total_pixel_acc / num_metric_batches if num_metric_batches > 0 else 0.0,
            total_dice / num_metric_batches if num_metric_batches > 0 else 0.0)


def validate_epoch(model, dataloader, criterion, device, num_classes):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Follows train_epoch's bookkeeping: batches whose mIoU is NaN are left out
    of the metric averages. Otherwise one NaN batch makes the epoch mIoU NaN,
    and since ``NaN > best`` is never true, no best model would be saved.

    :return: (mean loss, mean mIoU, mean pixel accuracy, mean Dice). Each mean is
        0.0 when no batch contributed to it (e.g. an empty dataloader).
    """
    model.eval()
    total_loss = 0.0
    total_miou = 0.0
    total_pixel_acc = 0.0
    total_dice = 0.0
    num_loss_batches = 0
    num_metric_batches = 0

    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Validating"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            outputs = outputs['out'] if isinstance(outputs, dict) else outputs

            # (N, C, H, W) labels with C > 1 are collapsed to class indices.
            if len(labels.shape) == 4 and labels.shape[1] > 1:
                labels = torch.argmax(labels, dim=1)

            if outputs.shape[2:] != labels.shape[1:]:
                outputs = F.interpolate(outputs, size=labels.shape[1:], mode='bilinear', align_corners=True)

            loss = criterion(outputs, labels)
            total_loss += loss.item()
            num_loss_batches += 1

            # Move predictions/labels to host once and reuse them for every metric.
            pred_np = torch.argmax(outputs, dim=1).cpu().numpy()
            labels_np = labels.cpu().numpy()
            miou = calculate_miou_train(pred_np, labels_np, num_classes)
            if np.isnan(miou):
                continue
            total_miou += miou
            total_pixel_acc += calculate_pixel_accuracy(pred_np, labels_np)
            total_dice += calculate_dice_coefficient(pred_np, labels_np, num_classes)
            num_metric_batches += 1

    return (total_loss / num_loss_batches if num_loss_batches > 0 else 0.0,
            total_miou / num_metric_batches if num_metric_batches > 0 else 0.0,
            total_pixel_acc / num_metric_batches if num_metric_batches > 0 else 0.0,
            total_dice / num_metric_batches if num_metric_batches > 0 else 0.0)


def save_checkpoint(model, optimizer, epoch, metrics, filename):
    """Write model and optimizer state, epoch number and metrics to ``filename`` with torch.save."""
    torch.save(
        {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch,
            'metrics': metrics,
        },
        filename,
    )


def setup_logger(log_file):
    """Send INFO-and-above records from the root logger to ``log_file`` and to the console.

    ``force=True`` replaces any handlers already on the root logger, e.g. from a
    previous call when the cell is re-run, or from the implicit basicConfig()
    that module-level ``logging.error(...)`` performs. Without it basicConfig
    silently does nothing and every call adds one more console handler.

    :param log_file: path of the log file (opened in append mode).
    """
    log_format = '%(asctime)s - %(levelname)s - %(message)s'
    logging.basicConfig(filename=log_file, level=logging.INFO, format=log_format, force=True)
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    console.setFormatter(logging.Formatter(log_format))
    logging.getLogger('').addHandler(console)


def train(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, device, save_dir, num_classes,
          checkpoint_every=5):
    """Main training loop: train and validate each epoch, saving checkpoints to ``save_dir``.

    A ``best_model_epoch_<n>.pth`` checkpoint is written whenever the validation
    mIoU improves. A ``checkpoint_epoch_<n>.pth`` is written every
    ``checkpoint_every`` epochs (pass 0 or None to disable periodic checkpoints).

    :return: path of the best checkpoint, or None if no epoch produced a
        comparable (non-NaN) validation mIoU.
    """
    # -inf guarantees the first finite mIoU is saved, even if it is 0.
    best_miou = float('-inf')
    best_model_path = None
    os.makedirs(save_dir, exist_ok=True)

    # Loss scaling is only meaningful for CUDA AMP; disabling it elsewhere
    # avoids the "CUDA not available" warning.
    scaler = GradScaler(enabled=torch.device(device).type == 'cuda')

    for epoch in range(num_epochs):
        current_epoch = epoch + 1
        logging.info(f"Epoch {current_epoch}/{num_epochs}")

        train_loss, train_miou, train_pixel_acc, train_dice = train_epoch(
            model, train_loader, criterion, optimizer, scheduler, device, num_classes, scaler)
        logging.info(f"Epoch {current_epoch} - Train Loss: {train_loss:.4f}, Train mIoU: {train_miou:.4f}, "
                     f"Train Pixel Acc: {train_pixel_acc:.4f}, Train Dice: {train_dice:.4f}")

        val_loss, val_miou, val_pixel_acc, val_dice = validate_epoch(
            model, val_loader, criterion, device, num_classes)
        logging.info(f"Epoch {current_epoch} - Val Loss: {val_loss:.4f}, Val mIoU: {val_miou:.4f}, "
                     f"Val Pixel Acc: {val_pixel_acc:.4f}, Val Dice: {val_dice:.4f}")

        metrics = {
            'miou': val_miou,
            'pixel_acc': val_pixel_acc,
            'dice': val_dice
        }

        if val_miou > best_miou:
            best_miou = val_miou
            best_model_path = os.path.join(save_dir, f'best_model_epoch_{current_epoch}.pth')
            save_checkpoint(model, optimizer, current_epoch, metrics, best_model_path)
            logging.info(f"Epoch {current_epoch} - Best model saved with mIoU: {best_miou:.4f}")

        if checkpoint_every and current_epoch % checkpoint_every == 0:
            checkpoint_path = os.path.join(save_dir, f'checkpoint_epoch_{current_epoch}.pth')
            save_checkpoint(model, optimizer, current_epoch, metrics, checkpoint_path)
            logging.info(f"Epoch {current_epoch} - Checkpoint saved")

        current_lr = optimizer.param_groups[0]['lr']
        logging.info(f"Current learning rate: {current_lr:.6f}")

    logging.info(f"Training completed after {num_epochs} epochs.")
    return best_model_path

if __name__ == "__main__":
    import random

    setup_logger('training.log')

    # Seed every RNG the pipeline draws from (Python `random` in the transforms,
    # NumPy, and PyTorch) so runs are reproducible.
    SEED = 42
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging.info(f"Using device: {device}")

    train_loader = EnhancedWildScenesDataset.get_data_loader('train', batch_size=8)
    val_loader = EnhancedWildScenesDataset.get_data_loader('valid', batch_size=8)

    # test_loader = EnhancedWildScenesDataset.get_data_loader('test', batch_size=8)

    # NOTE(review): the WildScenesDataset mapping in this notebook produces
    # trainIds 0..15 (plus 255 for unmapped values), so output channels 16 and 17
    # never receive a target. Confirm whether 18 is intended here.
    num_classes = 18

    # Choose the model
    model = CustomDeepLabV3(num_classes=num_classes).to(device)

    # Choose the loss function
    # criterion = FocalLoss(alpha=1, gamma=2)
    criterion = CombinedLoss(weight_focal=1.0, weight_dice=0.5)

    # Choose the optimizer
    # optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0001, nesterov=True)
    optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)

    num_epochs = 60
    steps_per_epoch = len(train_loader)
    # OneCycleLR starts at max_lr / div_factor and is stepped once per batch in train_epoch.
    scheduler = OneCycleLR(
        optimizer,
        max_lr=0.01,
        steps_per_epoch=steps_per_epoch,
        epochs=num_epochs,
        pct_start=0.3,
        anneal_strategy='cos',
        div_factor=25,
        final_div_factor=1000
    )

    save_dir = 'model_checkpoints'
    best_model_path = train(model, train_loader, val_loader, criterion, optimizer, scheduler,
                            num_epochs, device, save_dir, num_classes)
    logging.info(f"Best checkpoint: {best_model_path}")

    logging.info("Training and prediction completed!")


2024-07-15 03:23:44,834 - INFO - Using device: cuda
2024-07-15 03:23:45,891 - INFO - Epoch 1/60
Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 137/137 [03:18<00:00,  1.45s/it]
2024-07-15 03:27:04,110 - INFO - Epoch 1 - Train Loss: 0.8528, Train mIoU: 0.5948, Train Pixel Acc: 0.7644, Train Dice: 0.6334
Validating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 39/39 [00:52<00:00,  1.34s/it]
2024-07-15 03:27:56,492 - INFO - Epoch 1 - Val Loss: 0.8528, Val mIoU: 0.6205, Val Pixel Acc: 0.7122, Val Dice: 0.6636
2024-07-15 03:27:57,102 - INFO - Epoch 1 - Best model saved with mIoU: 0.6205
2024-07-15 03:27:57,103 - INFO - Current learning rate: 0.000473
2024-07-15 03:27:57,104 - INFO - Epoch 2/60
Training: 100%|████████████████████████████████████████████████