In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import os
import glob
import numpy as np
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from torchvision.models.segmentation.fcn import fcn_resnet50
from osgeo import gdal, ogr
from torch.utils.data import DataLoader, Dataset, random_split
import matplotlib.pyplot as plt

In [3]:
class CustomDataset(Dataset):
  '''
  自定义数据集类
  '''
  def __init__(self,images_dir,mask_dir, txt_path, transform=None):
    self.images_dir = images_dir
    self.mask_dir = mask_dir
    self.imagelist = self.get_imagelist(txt_path)
    self.transform = transform

  def get_imagelist(self,txt_path):
    with open(txt_path,'r') as f:
      imagelist = f.readlines()
    return imagelist

  def __len__(self):
    # 返回数据集的大小
    return len(self.imagelist)

  def __getitem__(self,index):
    # 获取单个图片
    img_path = os.path.join(self.images_dir,self.imagelist[index][:-1])
    mask_path = os.path.join(self.mask_dir,self.imagelist[index][:-1])
    raster_file1 = gdal.Open(img_path)
    raster_array1 = raster_file1.ReadAsArray()
    raster_file2 = gdal.Open(mask_path)
    raster_array2 = raster_file2.ReadAsArray()

    # x数据的处理
    img_data = np.zeros((5,raster_array1.shape[1],raster_array1.shape[2]))
    img_data[0] = raster_array2[0,:,:]
    img_data[1] = raster_array2[1,:,:]
    img_data[2] = raster_array1[2,:,:]
    img_data[3] = raster_array1[7,:,:]
    img_data[4] = raster_array1[10,:,:]
    img_data = np.nan_to_num(img_data)  # 用0填充img中的nan
    img_data[0:2] = (img_data[0:2] + 50) / 51  # 归一化
    img_data[2:] = img_data[2:] / 5000 # 归一化
    img = img_data.transpose(1,2,0) # channel first


    # y数据的处理

    mask_data = raster_array2[3] # 第四个波段label
    mask_data = np.nan_to_num(mask_data) # 用0填充mask中的nan
    mask = np.zeros([512,512])
    mask[mask_data == 0] = 0
    mask[mask_data == 1] = 0
    mask[mask_data == 2] = 1
    mask[mask_data == 3] = 1
    # 数据增强
    if self.transform:
      image = self.transform(img)
      mask = self.transform(mask)
    return image,mask

  def get_image_names(self, idx):
    return self.imagelist[idx]

In [4]:
# 数据强化
transform = transforms.Compose([
    transforms.ToTensor(),
    # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

# 创建数据集
img_path = '/content/drive/MyDrive/RS/HAND/S2_split/'
mask_path = '/content/drive/MyDrive/RS/HAND/S1_split/'
train_path = '/content/drive/MyDrive/RS/HAND/train.txt'
test_path = '/content/drive/MyDrive/RS/HAND/test.txt'

train_dataset = CustomDataset(img_path,mask_path,train_path,transform)
test_dataset = CustomDataset(img_path,mask_path,test_path,transform)

# 创建数据集加载器
train_loader = DataLoader(train_dataset,batch_size=4,shuffle=True)
test_loader = DataLoader(test_dataset,batch_size=4,shuffle=False)

In [5]:
# 定义块
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1,
                      stride=1, padding_mode='reflect', bias=False),

            nn.BatchNorm2d(out_channels),
            # nn.InstanceNorm2d(out_channels),  # nnUNet

            nn.ReLU(inplace=True),
            # nn.LeakyReLU(),  # nnUNet

            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1,
                      stride=1, padding_mode='reflect', bias=False),

            nn.BatchNorm2d(out_channels),
            # nn.InstanceNorm2d(out_channels),  # nnUNet

            nn.ReLU(inplace=True)
            # nn.LeakyReLU()  # nnUNet
        )

    def forward(self, x):
        return self.block(x)


class UpSample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UpSample, self).__init__()
        self.block = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        return self.block(x)


In [6]:
# 定义UNet模型
class UNet(nn.Module):
    def __init__(self, in_channels=5, out_channels=2):
        super().__init__()
        self.conv1 = DoubleConv(in_channels, 64)
        self.down1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = DoubleConv(64, 128)
        self.down2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = DoubleConv(128, 256)
        self.down3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv4 = DoubleConv(256, 512)
        self.down4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv_mid = DoubleConv(512, 1024)

        self.up1 = UpSample(1024, 512)
        self.conv5 = DoubleConv(1024,512)
        self.up2 = UpSample(512, 256)
        self.conv6 = DoubleConv(512,256)
        self.up3 = UpSample(256, 128)
        self.conv7 = DoubleConv(256,128)
        self.up4 = UpSample(128, 64)
        self.conv8 = DoubleConv(128,64)

        self.out_channel = nn.Conv2d(64, out_channels, kernel_size=1, stride=1)

    def forward(self, x):
        conv1 = self.conv1(x)
        down1 = self.down1(conv1)
        conv2 = self.conv2(down1)
        down2 = self.down2(conv2)
        conv3 = self.conv3(down2)
        down3 = self.down3(conv3)
        conv4 = self.conv4(down3)
        down4 = self.down4(conv4)

        conv_mid = self.conv_mid(down4)

        up1 = self.up1(conv_mid)
        cat1 = torch.cat([up1, conv4], dim=1)
        down5  = self.conv5(cat1)

        up2 = self.up2(down5)
        cat2 = torch.cat([up2, conv3], dim=1)
        down6  = self.conv6(cat2)

        up3 = self.up3(down6)
        cat3 = torch.cat([up3, conv2], dim=1)
        down7  = self.conv7(cat3)

        up4 = self.up4(down7)
        cat4 = torch.cat([up4, conv1], dim=1)
        down8  = self.conv8(cat4)

        out_channel = self.out_channel(down8)

        return out_channel

In [7]:
model = UNet()
# 定义损失函数、优化器
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = 0.001)

In [8]:
# 定义损失函数、优化器
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = 0.001)

In [9]:
def computeIOU(output, target):
  output = torch.argmax(output, dim=1).flatten()
  target = target.flatten()

  no_ignore = target.ne(255).cuda()
  output = output.masked_select(no_ignore)
  target = target.masked_select(no_ignore)
  intersection = torch.sum(output * target)
  union = torch.sum(target) + torch.sum(output) - intersection
  iou = (intersection + .0000001) / (union + .0000001)

  if iou != iou:
    print("failed, replacing with 0")
    iou = torch.tensor(0).float()

  return iou

def computeAccuracy(output, target):
  output = torch.argmax(output, dim=1).flatten()
  target = target.flatten()

  no_ignore = target.ne(255).cuda()
  output = output.masked_select(no_ignore)
  target = target.masked_select(no_ignore)
  correct = torch.sum(output.eq(target))

  return correct.float() / len(target)

def truePositives(output, target):
  output = torch.argmax(output, dim=1).flatten()
  target = target.flatten()
  no_ignore = target.ne(255).cuda()
  output = output.masked_select(no_ignore)
  target = target.masked_select(no_ignore)
  correct = torch.sum(output * target)

  return correct

def trueNegatives(output, target):
  output = torch.argmax(output, dim=1).flatten()
  target = target.flatten()
  no_ignore = target.ne(255).cuda()
  output = output.masked_select(no_ignore)
  target = target.masked_select(no_ignore)
  output = (output == 0)
  target = (target == 0)
  correct = torch.sum(output * target)

  return correct

def falsePositives(output, target):
  output = torch.argmax(output, dim=1).flatten()
  target = target.flatten()
  no_ignore = target.ne(255).cuda()
  output = output.masked_select(no_ignore)
  target = target.masked_select(no_ignore)
  output = (output == 1)
  target = (target == 0)
  correct = torch.sum(output * target)

  return correct

def falseNegatives(output, target):
  output = torch.argmax(output, dim=1).flatten()
  target = target.flatten()
  no_ignore = target.ne(255).cuda()
  output = output.masked_select(no_ignore)
  target = target.masked_select(no_ignore)
  output = (output == 0)
  target = (target == 1)
  correct = torch.sum(output * target)

  return correct

In [10]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# 把网络设置在训练模式
model.train()
model.to(device)

cuda:0


UNet(
  (conv1): DoubleConv(
    (block): Sequential(
      (0): Conv2d(5, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, padding_mode=reflect)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, padding_mode=reflect)
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): DoubleConv(
    (block): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, padding_mode=reflect)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, padding_mode=r

In [11]:
loss_history = [] #训练过程中的loss数据
num_epochs = 20
for epoch in range(num_epochs):
    for i,data in enumerate(train_loader,0):
        inputs, targets = data[0].to(device), data[1].to(device)
        inputs = inputs.float()
        targets = targets.long()
        #(0) 复位优化器的梯度
        optimizer.zero_grad()
        #(1) 前向计算
        y_pred = model(inputs)
        #(2) 计算loss
        targets = targets.squeeze(1)
        loss = loss_fn(y_pred, targets)
        #(3) 反向求导
        loss.backward()
        #(4) 反向迭代
        optimizer.step()

    # 记录训练过程中的准确率
    loss_history.append(loss.item())
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}')

Epoch 1/20, Loss: 0.16447287797927856
Epoch 2/20, Loss: 0.1635044366121292
Epoch 3/20, Loss: 0.16615250706672668
Epoch 4/20, Loss: 0.17054006457328796
Epoch 5/20, Loss: 0.09933825582265854
Epoch 6/20, Loss: 0.18359552323818207
Epoch 7/20, Loss: 0.09616755694150925
Epoch 8/20, Loss: 0.06506507098674774
Epoch 9/20, Loss: 0.15975381433963776
Epoch 10/20, Loss: 0.06251441687345505
Epoch 11/20, Loss: 0.19316399097442627
Epoch 12/20, Loss: 0.11907809972763062
Epoch 13/20, Loss: 0.07195213437080383
Epoch 14/20, Loss: 0.02803892269730568
Epoch 15/20, Loss: 0.09281354397535324
Epoch 16/20, Loss: 0.056515585631132126
Epoch 17/20, Loss: 0.18024761974811554
Epoch 18/20, Loss: 0.08565669506788254
Epoch 19/20, Loss: 0.18324299156665802
Epoch 20/20, Loss: 0.034076839685440063


In [12]:
model.eval()
iou = []
acc = []
TP = 0
TN = 0
FP = 0
FN = 0
for idx in range(len(test_dataset)):
  inputs, labels = test_dataset[idx]
  inputs = inputs.unsqueeze(0)
  inputs = inputs.to(device).float()
  labels = labels.to(device)
  with torch.no_grad():
    outputs = model(inputs)
  test_acc = computeAccuracy(outputs,labels).cpu().numpy()
  acc.append(test_acc)
  test_iou = computeIOU(outputs,labels).cpu().numpy()
  iou.append(test_iou)

  TP += truePositives(outputs,labels).cpu().numpy()
  TN += trueNegatives(outputs,labels).cpu().numpy()
  FP += falsePositives(outputs,labels).cpu().numpy()
  FN += falseNegatives(outputs,labels).cpu().numpy()

mean_iou = sum(iou) / len(iou)
mean_acc = sum(acc) / len(acc)

In [13]:
print(mean_acc)
print(mean_iou)

print(TP)
print(TN)
print(FP)
print(FN)

confusion_matrix = [TP,TN,FP,FN]

0.9700045757983105
0.48786355919423086
1278448.0
19826865
171315
481324


In [14]:
# 保存模型参数为字典
torch.save(model.state_dict(), '/content/drive/MyDrive/RS/HAND/UNet_S1S2_state_dict.pth')
np.savetxt('/content/drive/MyDrive/RS/HAND/S1S2_evaluate/S1S2_UNet_loss.txt', loss_history)
np.savetxt('/content/drive/MyDrive/RS/HAND/S1S2_evaluate/S1S2_UNet_acc.txt', acc)
np.savetxt('/content/drive/MyDrive/RS/HAND/S1S2_evaluate/S1S2_UNet_iou.txt', iou)
np.savetxt('/content/drive/MyDrive/RS/HAND/S1S2_evaluate/S1S2_UNet_confusion_matrix.txt', confusion_matrix)