In [35]:
import os
import pdb
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import models
from PIL import Image
import numpy as np

In [36]:
# --------------------------- Dataset ---------------------------
class SegmentationDataset(Dataset):
    """Paired image / label-mask dataset read from two parallel directories.

    Every file name listed in ``image_dir`` is also opened under ``mask_dir``,
    so both directories must contain identically named files.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the label masks.
        transform: Optional callable applied to the RGB PIL image. When it is
            given, the mask is resized to ``mask_size`` with nearest-neighbour
            interpolation (so label values are never blended) and converted to
            a ``torch.long`` tensor. When it is falsy, image and mask are both
            returned as unmodified PIL images.
        mask_size: ``(height, width)`` the mask is resized to. It should match
            the spatial size produced by ``transform``. Defaults to the
            previously hard-coded ``(256, 256)``.
    """

    def __init__(self, image_dir, mask_dir, transform=None, mask_size=(256, 256)):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.mask_size = tuple(mask_size)
        # Built once instead of once per sample. No Lambda is stored on the
        # instance, so the dataset stays picklable for DataLoader workers.
        self.mask_resize = transforms.Resize(
            self.mask_size,
            interpolation=transforms.InterpolationMode.NEAREST,
        )
        # Sorted so that sample indices are stable across runs.
        self.images = sorted(os.listdir(image_dir))

    def __len__(self):
        """Return the number of files found in ``image_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load the ``idx``-th image and the mask with the same file name."""
        file_name = self.images[idx]
        img_path = os.path.join(self.image_dir, file_name)
        mask_path = os.path.join(self.mask_dir, file_name)

        image = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path)

        if self.transform:
            image = self.transform(image)
            mask = self.mask_resize(mask)
            # Keep the raw integer label values; no scaling to [0, 1].
            mask = torch.from_numpy(np.array(mask)).long()

        return image, mask
    


In [37]:

# --------------------------- ASPP Module ---------------------------
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling head.

    Five parallel branches look at the same feature map: a 1x1 convolution,
    three 3x3 dilated convolutions (rates 6, 12 and 18, padded so the spatial
    size is preserved) and a global-average-pooled 1x1 convolution that is
    upsampled back to the input size. The branch outputs are concatenated on
    the channel axis and fused by a final 1x1 convolution.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by every branch and by the fused output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Attribute names are the state_dict keys and the creation order fixes
        # the seeded initialisation, so both are kept stable.
        self.atrous_block1 = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.atrous_block6 = self._dilated_conv(in_channels, out_channels, 6)
        self.atrous_block12 = self._dilated_conv(in_channels, out_channels, 12)
        self.atrous_block18 = self._dilated_conv(in_channels, out_channels, 18)
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
        )
        self.conv1x1 = nn.Conv2d(5 * out_channels, out_channels, kernel_size=1)

    @staticmethod
    def _dilated_conv(in_channels, out_channels, rate):
        """3x3 convolution with dilation == padding == ``rate`` (size-preserving)."""
        return nn.Conv2d(in_channels, out_channels, kernel_size=3,
                         padding=rate, dilation=rate)

    def forward(self, x):
        """Return the fused multi-scale features with the same H x W as ``x``."""
        spatial_size = x.shape[2:]
        conv_branches = (
            self.atrous_block1,   # 1x1 conv
            self.atrous_block6,   # 3x3 conv, dilation 6
            self.atrous_block12,  # 3x3 conv, dilation 12
            self.atrous_block18,  # 3x3 conv, dilation 18
        )
        features = [branch(x) for branch in conv_branches]
        # Image-level branch: pool to 1x1, project, then stretch back to H x W.
        image_level = F.interpolate(self.global_avg_pool(x), size=spatial_size,
                                    mode='bilinear', align_corners=True)
        features.append(image_level)
        return self.conv1x1(torch.cat(features, dim=1))


In [38]:

# --------------------------- Segmentation Model ---------------------------
class SegNet(nn.Module):
    """Encoder-decoder segmentation network: ResNet-50 encoder, ASPP head, light decoder.

    The encoder is ResNet-50 without its average-pool and fully-connected
    layers. The output of its first five children (conv1, bn1, relu, maxpool,
    layer1) is kept as low-level features, and the full encoder output is fed
    through ASPP. Both are fused and decoded into per-class logits.

    Args:
        num_classes: Number of output channels (segmentation classes).
        pretrained: Forwarded to ``models.resnet50``. Defaults to ``True``,
            which matches the previous hard-coded behaviour. Pass ``False``
            when the weights are about to be overwritten by a checkpoint.
    """

    # Number of leading encoder children whose output is the low-level feature.
    _STEM_LAYERS = 5

    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        resnet = models.resnet50(pretrained=pretrained)
        # Drop the classification head (avgpool + fc); keep the conv trunk.
        self.encoder = nn.Sequential(*list(resnet.children())[:-2])
        self.aspp = ASPP(in_channels=2048, out_channels=256)
        self.low_level_conv = nn.Conv2d(256, 48, kernel_size=1)

        self.decoder = nn.Sequential(
            nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True),
            nn.Conv2d(256, num_classes, kernel_size=1)
        )

    def forward(self, x):
        """Return logits of shape (batch, num_classes, H', W') for image batch ``x``."""
        # Run the stem once and reuse its output. Slicing an nn.Sequential
        # shares the modules, so parameters and state_dict keys are unchanged.
        # Running the stem twice would double its cost and update the
        # BatchNorm running statistics twice per training step.
        shallow = self.encoder[:self._STEM_LAYERS](x)
        x = self.encoder[self._STEM_LAYERS:](shallow)

        x = self.aspp(x)
        x = F.interpolate(x, scale_factor=4, mode='bilinear', align_corners=True)

        shallow = self.low_level_conv(shallow)
        # Bring the low-level features to the resolution of the upsampled ASPP output.
        shallow = F.interpolate(shallow, size=x.shape[2:], mode='bilinear', align_corners=True)

        x = torch.cat([x, shallow], dim=1)
        return self.decoder(x)


In [39]:

# --------------------------- Dice + CE Loss ---------------------------
class DiceCELoss(nn.Module):
    """Weighted sum of cross-entropy and soft Dice loss for segmentation logits.

    Args:
        weight_dice: Multiplier applied to the Dice term.
        weight_ce: Multiplier applied to the cross-entropy term.
        smooth: Constant added to the numerator and denominator of the Dice
            ratio, which keeps it finite for classes with no pixels.
    """

    def __init__(self, weight_dice=1.0, weight_ce=1.0, smooth=1.0):
        super().__init__()
        self.ce = nn.CrossEntropyLoss()
        self.smooth = smooth
        self.weight_ce = weight_ce
        self.weight_dice = weight_dice

    def forward(self, preds, targets):
        """Combine both terms.

        ``preds`` are logits of shape (B, C, H, W); ``targets`` are integer
        class indices of shape (B, H, W) with values in ``[0, C)``.
        """
        weighted_ce = self.weight_ce * self.ce(preds, targets)
        weighted_dice = self.weight_dice * self._dice_loss(preds, targets)
        return weighted_ce + weighted_dice

    def _dice_loss(self, preds, targets):
        """Return ``1 - mean per-class soft Dice``, reduced over batch and space."""
        class_count = preds.shape[1]
        probabilities = preds.softmax(dim=1)
        # (B, H, W) indices -> (B, H, W, C) one-hot -> (B, C, H, W) to align with preds.
        one_hot = F.one_hot(targets, num_classes=class_count)
        one_hot = one_hot.permute(0, 3, 1, 2).float()

        reduce_over = (0, 2, 3)  # everything except the class axis
        overlap = (probabilities * one_hot).sum(dim=reduce_over)
        total = (probabilities + one_hot).sum(dim=reduce_over)
        per_class_dice = (2. * overlap + self.smooth) / (total + self.smooth)
        return 1 - per_class_dice.mean()


In [40]:

# --------------------------- Training and Validation ---------------------------
def train_epoch(model, loader, optimizer, criterion):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        model: ``nn.Module`` to train; it is switched to training mode.
        loader: Iterable yielding ``(imgs, masks)`` tensor pairs.
        optimizer: Optimizer stepping the model's parameters.
        criterion: Callable ``criterion(preds, masks)`` returning a scalar loss.

    Returns:
        The average of the per-batch losses as a float, or ``nan`` when
        ``loader`` yields no batches.
    """
    model.train()
    # Follow the model's device instead of hard-coding CUDA, so the same code
    # runs on CPU-only machines and on whichever GPU the model lives on.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    total_loss = 0.0
    num_batches = 0  # counted manually: works for loaders without __len__
    for imgs, masks in loader:
        if device is not None:
            imgs, masks = imgs.to(device), masks.to(device)
        preds = model(imgs)

        loss = criterion(preds, masks)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        return float("nan")
    return total_loss / num_batches

def validate_epoch(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without gradients and return the mean batch loss.

    Args:
        model: ``nn.Module`` to evaluate; it is switched to eval mode.
        loader: Iterable yielding ``(imgs, masks)`` tensor pairs.
        criterion: Callable ``criterion(preds, masks)`` returning a scalar loss.

    Returns:
        The average of the per-batch losses as a float, or ``nan`` when
        ``loader`` yields no batches (e.g. an empty validation split).
    """
    model.eval()
    # Follow the model's device instead of hard-coding CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    total_loss = 0.0
    num_batches = 0  # counted manually: works for loaders without __len__
    with torch.no_grad():
        for imgs, masks in loader:
            if device is not None:
                imgs, masks = imgs.to(device), masks.to(device)
            loss = criterion(model(imgs), masks)
            total_loss += loss.item()
            num_batches += 1

    if num_batches == 0:
        return float("nan")
    return total_loss / num_batches


In [41]:
# Root of the dataset. The default is the original machine-specific location;
# set the SEG_DATA_PATH environment variable to run the notebook elsewhere
# without editing this cell.
data_path = os.environ.get(
    "SEG_DATA_PATH",
    "/root/workspace/ssd/txj_workspace/hawkyu/hawkyu/remote_sence/dataset/data",
)
img_path = os.path.join(data_path, 'JPEGImages')     # input images
mask_path = os.path.join(data_path, 'Annotations')   # label masks, same file names as the images
models_path = os.path.join(data_path, 'Models')      # where the final checkpoint is written

In [42]:

# --------------------------- Train Script ---------------------------
def train(epoch_num=20, batch_size=4, lr=1e-4, num_classes=5):
    """Train ``SegNet`` on the dataset under ``img_path`` / ``mask_path``.

    The dataset is split 80/20 into training and validation subsets. After
    every epoch the losses are printed and a checkpoint ``model_epoch{N}.pth``
    is written to the current working directory. The final weights are saved
    as ``Train_01.pth`` under ``models_path``.

    Args:
        epoch_num: Number of epochs to run.
        batch_size: Batch size for both loaders.
        lr: Adam learning rate.
        num_classes: Number of segmentation classes.

    Returns:
        Path of the final saved checkpoint.
    """
    image_dir = img_path
    mask_dir = mask_path

    # Create the output directory up front, so a missing folder fails (or is
    # fixed) before training rather than at the final save.
    os.makedirs(models_path, exist_ok=True)
    save_path = os.path.join(models_path, "Train_01.pth")

    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor()
    ])

    dataset = SegmentationDataset(image_dir, mask_dir, transform=transform)
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_ds, val_ds = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

    # Background pixels in the masks are all zero.
    model = SegNet(num_classes=num_classes).cuda()
    criterion = DiceCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Inclusive upper bound: range(1, epoch_num) would run one epoch too few.
    for epoch in range(1, epoch_num + 1):
        train_loss = train_epoch(model, train_loader, optimizer, criterion)
        val_loss = validate_epoch(model, val_loader, criterion)

        print(f"Epoch {epoch}: Train Loss={train_loss:.4f} | Val Loss={val_loss:.4f}")

        # Per-epoch checkpoint, relative to the current working directory.
        torch.save(model.state_dict(), f"model_epoch{epoch}.pth")

    torch.save(model.state_dict(), save_path)
    print("Model saved successfully!")
    return save_path
    
    


In [None]:
# Start the training run defined by train() above.
train()

In [None]:
# Show the kernel's working directory: relative checkpoint paths such as
# "model_epoch{N}.pth" and "Train_01.pth" are resolved against it.
print(os.getcwd())

# Evaluation

## Calculate mIoU

In [43]:
def calculate_mIoU(preds, labels, num_classes):
    """Compute the mean IoU over the classes that appear in ``preds`` or ``labels``.

    Args:
        preds: Tensor of predicted class indices.
        labels: Tensor of ground-truth class indices, same shape as ``preds``.
        num_classes: Number of classes; IoU is computed for ``0 .. num_classes - 1``.

    Returns:
        The mean of the per-class IoU values. Classes that occur in neither
        tensor get ``nan`` and are left out of the mean.
    """
    def _class_iou(cls):
        predicted = (preds == cls)
        actual = (labels == cls)
        union = (predicted | actual).sum().item()
        if union == 0:
            # The class occurs in neither the prediction nor the label.
            return float('nan')
        return (predicted & actual).sum().item() / union

    per_class = [_class_iou(cls) for cls in range(num_classes)]
    # nanmean skips the absent classes.
    return np.nanmean(per_class)


In [44]:
@torch.no_grad()
def evaluate_model(model, dataloader, num_classes=5, max_samples=10):
    """Print the per-image mIoU and return the average over the evaluated images.

    Args:
        model: Segmentation network returning logits of shape [B, C, H, W].
        dataloader: Iterable yielding ``(images, masks)`` batches.
        num_classes: Number of classes passed to ``calculate_mIoU``.
        max_samples: Maximum number of images (not batches) to evaluate;
            ``None`` evaluates everything.

    Returns:
        The average per-image mIoU, or ``nan`` when no image was evaluated.
    """
    model.eval()
    # Follow the model's device instead of hard-coding CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    total_miou = 0
    count = 0

    for images, masks in dataloader:
        if max_samples is not None and count >= max_samples:
            break

        if device is not None:
            images, masks = images.to(device), masks.to(device)

        outputs = model(images)  # [B, C, H, W]
        preds = torch.argmax(outputs, dim=1)  # [B, H, W]

        for b in range(images.size(0)):
            # Enforce the limit per image, so batch sizes above 1 cannot overshoot it.
            if max_samples is not None and count >= max_samples:
                break
            miou = calculate_mIoU(preds[b], masks[b], num_classes)
            print(f"Image {count + 1} mIoU: {miou:.4f}")
            total_miou += miou
            count += 1

    if count == 0:
        print("No samples were evaluated.")
        return float("nan")

    avg_miou = total_miou / count
    print(f"\n✅ Average mIoU over {count} samples: {avg_miou:.4f}")
    return avg_miou



In [47]:
## miou作为验证指标
def eval_miou(model_path=None):
    """Load a trained ``SegNet`` checkpoint and report per-image mIoU.

    The evaluation runs over every image under ``img_path`` (up to 1000), which
    includes images used for training because the random train/validation split
    is not reproduced here.

    Args:
        model_path: Checkpoint to load. When ``None``, the file written by
            ``train()`` (``Train_01.pth`` under ``models_path``) is used if it
            exists; otherwise ``Train_01.pth`` in the current working directory.
    """
    if model_path is None:
        trained = os.path.join(models_path, "Train_01.pth")
        model_path = trained if os.path.exists(trained) else "Train_01.pth"
    image_dir = img_path
    mask_dir = mask_path
    num_classes = 5

    # Load the model. weights_only=True restricts unpickling to tensors and
    # plain containers (enough for a state_dict), and map_location="cpu" makes
    # loading independent of the device the checkpoint was saved from.
    model = SegNet(num_classes=num_classes)
    state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model = model.cuda()

    # Same preprocessing as in training.
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor()
    ])

    # Data loading.
    dataset = SegmentationDataset(image_dir, mask_dir, transform=transform)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

    # Evaluate the model.
    evaluate_model(model, dataloader, num_classes=num_classes, max_samples=1000)

In [48]:
# Run the per-image mIoU evaluation; the per-image scores are printed below.
eval_miou()


  model.load_state_dict(torch.load(model_path))


Image 1 mIoU: 0.9464
Image 2 mIoU: 0.9556
Image 3 mIoU: 0.9125
Image 4 mIoU: 0.9487
Image 5 mIoU: 0.6716
Image 6 mIoU: 0.9251
Image 7 mIoU: 0.9563
Image 8 mIoU: 0.9284
Image 9 mIoU: 0.9489
Image 10 mIoU: 0.9614
Image 11 mIoU: 0.9368
Image 12 mIoU: 0.9299
Image 13 mIoU: 0.9159
Image 14 mIoU: 0.9160
Image 15 mIoU: 0.9412
Image 16 mIoU: 0.9023
Image 17 mIoU: 0.9504
Image 18 mIoU: 0.9043
Image 19 mIoU: 0.9540
Image 20 mIoU: 0.6900
Image 21 mIoU: 0.9581
Image 22 mIoU: 0.9573
Image 23 mIoU: 0.9539
Image 24 mIoU: 0.9384
Image 25 mIoU: 0.9078
Image 26 mIoU: 0.9653
Image 27 mIoU: 0.9548
Image 28 mIoU: 0.6086
Image 29 mIoU: 0.9146
Image 30 mIoU: 0.9264
Image 31 mIoU: 0.9560
Image 32 mIoU: 0.9611
Image 33 mIoU: 0.9421
Image 34 mIoU: 0.9292
Image 35 mIoU: 0.8882
Image 36 mIoU: 0.9600
Image 37 mIoU: 0.9381
Image 38 mIoU: 0.9237
Image 39 mIoU: 0.9108
Image 40 mIoU: 0.9211
Image 41 mIoU: 0.9212
Image 42 mIoU: 0.9215
Image 43 mIoU: 0.9198
Image 44 mIoU: 0.9585
Image 45 mIoU: 0.9253
Image 46 mIoU: 0.94

## Calculate confusion matrix

In [None]:
import numpy as np
import torch

def compute_confusion_matrix(preds, labels, num_classes):
    """Build the pixel-level confusion matrix for one batch.

    Args:
        preds: Tensor of predicted class indices, e.g. [B, H, W].
        labels: Tensor of ground-truth class indices, same shape as ``preds``.
        num_classes: Number of classes.

    Returns:
        Integer array of shape ``(num_classes, num_classes)`` where entry
        ``[t, p]`` counts pixels with true class ``t`` predicted as ``p``.
        Pixels whose label or prediction is outside ``[0, num_classes)`` are
        ignored.
    """
    # reshape (not view) also handles non-contiguous tensors. int64 avoids
    # overflow in ``num_classes * labels`` for narrow dtypes such as uint8.
    preds = preds.detach().reshape(-1).cpu().numpy().astype(np.int64)
    labels = labels.detach().reshape(-1).cpu().numpy().astype(np.int64)

    valid = (
        (labels >= 0) & (labels < num_classes)
        & (preds >= 0) & (preds < num_classes)
    )
    # Encode every (label, pred) pair as one index and count occurrences.
    pair_index = num_classes * labels[valid] + preds[valid]
    conf_matrix = np.bincount(
        pair_index,
        minlength=num_classes**2
    ).reshape(num_classes, num_classes)
    return conf_matrix


In [None]:
@torch.no_grad()
def evaluate_model_confusion(model, dataloader, num_classes=5, max_samples=None):
    """Accumulate a confusion matrix over ``dataloader`` and report per-class IoU.

    Args:
        model: Segmentation network returning logits of shape [B, C, H, W].
        dataloader: Iterable yielding ``(images, masks)`` batches.
        num_classes: Number of classes.
        max_samples: Stop once at least this many images have been evaluated;
            ``None`` evaluates everything.

    Returns:
        Tuple ``(confusion_matrix, mean_iou)``. The matrix has true classes on
        rows and predicted classes on columns. ``mean_iou`` averages the
        per-class IoU, skipping classes that never occur (``nan``).
    """
    model.eval()
    # Follow the model's device instead of hard-coding CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    total_conf_matrix = np.zeros((num_classes, num_classes), dtype=np.int64)
    count = 0

    for images, masks in dataloader:
        if max_samples is not None and count >= max_samples:
            break

        if device is not None:
            images = images.to(device)

        outputs = model(images)
        preds = torch.argmax(outputs, dim=1)

        # The masks are not moved to the device: compute_confusion_matrix
        # brings both tensors to the CPU anyway.
        total_conf_matrix += compute_confusion_matrix(preds, masks, num_classes)
        count += images.size(0)

    print("✅ 混淆矩阵（行=真实，列=预测）:")
    print(total_conf_matrix)

    # Per-class IoU = TP / (TP + FP + FN), read off the confusion matrix.
    ious = []
    for i in range(num_classes):
        tp = total_conf_matrix[i, i]
        fn = total_conf_matrix[i, :].sum() - tp
        fp = total_conf_matrix[:, i].sum() - tp
        denom = tp + fp + fn
        # A class that is absent from both labels and predictions has no defined IoU.
        iou = np.nan if denom == 0 else tp / denom
        ious.append(iou)
        print(f"Class {i}: IoU = {iou:.4f}")

    mean_iou = np.nanmean(ious)
    print(f"✅ Evaluated {count} image(s) for confusion matrix and mIoU.")
    print(f"\n✅ Mean IoU: {mean_iou:.4f}")
    return total_conf_matrix, mean_iou
    

In [None]:
def eval_confusion_martix(model_path=None):
    """Load a trained ``SegNet`` checkpoint and report the confusion matrix and IoU.

    The evaluation runs over the images under ``img_path`` (up to 1000), which
    includes images used for training because the random train/validation split
    is not reproduced here.

    Args:
        model_path: Checkpoint to load. When ``None``, the file written by
            ``train()`` (``Train_01.pth`` under ``models_path``) is used if it
            exists; otherwise ``Train_01.pth`` in the current working directory.
    """
    if model_path is None:
        trained = os.path.join(models_path, "Train_01.pth")
        model_path = trained if os.path.exists(trained) else "Train_01.pth"
    image_dir = img_path
    mask_dir = mask_path
    num_classes = 5

    # Load the model. weights_only=True restricts unpickling to tensors and
    # plain containers (enough for a state_dict), and map_location="cpu" makes
    # loading independent of the device the checkpoint was saved from.
    model = SegNet(num_classes=num_classes)
    state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model = model.cuda()

    # Same preprocessing as in training.
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor()
    ])

    # Data loading.
    dataset = SegmentationDataset(image_dir, mask_dir, transform=transform)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

    # Evaluate the model.
    evaluate_model_confusion(model, dataloader, num_classes=num_classes, max_samples=1000)

In [None]:
# Run the confusion-matrix evaluation; the matrix and per-class IoU are printed below.
eval_confusion_martix()


  model.load_state_dict(torch.load(model_path))


✅ 混淆矩阵（行=真实，列=预测）:
[[46279618   301110   345828    24995        0]
 [  295013  8784001    44812        0        0]
 [  256040    54926  8923974        0        0]
 [   15588      710        8   209377        0]
 [       0        0        0        0        0]]
Class 0: IoU = 0.9739
Class 1: IoU = 0.9265
Class 2: IoU = 0.9271
Class 3: IoU = 0.8352
Class 4: IoU = nan
✅ Evaluated 1000 image(s) for confusion matrix and mIoU.

✅ Mean IoU: 0.9157


# Checkpoint