In [1]:
import torch
from torch import nn
import torch.nn.functional as f
import torchvision
import torchvision.transforms as tfs
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torchvision.models as models
import numpy as np
import os
from PIL import Image
import matplotlib.pyplot as plt
from datetime import datetime
import tqdm
from dataset import CamVidDataset
from utils import label_accuracy_score
from utils import get_weights
from utils import CrossEntropyLoss2d
# from segnet1 import SegNet
from segnet_dropout import SegNet
# from deeplabv3_plus import DeepLab
# from segnet import SegNet
# from segnet import SegResNet

In [2]:
# CamVid label table: every class name sits next to the RGB value that corresponds to it,
# so a mis-aligned entry is easy to spot when editing.
_CAMVID_LABELS = [
    ("Animal", [64, 128, 64]),
    ("Archway", [192, 0, 128]),
    ("Bicyclist", [0, 128, 192]),
    ("Bridge", [0, 128, 64]),
    ("Building", [128, 0, 0]),
    ("Car", [64, 0, 128]),
    ("CartLuggagePram", [64, 0, 192]),
    ("Child", [192, 128, 64]),
    ("Column_Pole", [192, 192, 128]),
    ("Fence", [64, 64, 128]),
    ("LaneMkgsDriv", [128, 0, 192]),
    ("LaneMkgsNonDriv", [192, 0, 64]),
    ("Misc_Text", [128, 128, 64]),
    ("MotorcycleScooter", [192, 0, 192]),
    ("OtherMoving", [128, 64, 64]),
    ("ParkingBlock", [64, 192, 128]),
    ("Pedestrian", [64, 64, 0]),
    ("Road", [128, 64, 128]),
    ("RoadShoulder", [128, 128, 192]),
    ("Sidewalk", [0, 0, 192]),
    ("SignSymbol", [192, 128, 128]),
    ("Sky", [128, 128, 128]),
    ("SUVPickupTruck", [64, 128, 192]),
    ("TrafficCone", [0, 0, 64]),
    ("TrafficLight", [0, 64, 64]),
    ("Train", [192, 64, 128]),
    ("Tree", [128, 128, 0]),
    ("Truck_Bus", [192, 128, 192]),
    ("Tunnel", [64, 0, 64]),
    ("VegetationMisc", [192, 192, 0]),
    ("Void", [0, 0, 0]),
    ("Wall", [64, 192, 0]),
]

# Label names and their RGB values, split into two parallel lists in the same order.
classes = [name for name, _ in _CAMVID_LABELS]
colormap = [rgb for _, rgb in _CAMVID_LABELS]

num_classes = len(classes)
print(num_classes)
print(len(colormap))

32
32


In [3]:
# Square input resolution in pixels (height/width are not passed to the datasets below).
size = 224
height = width = size

# Training split with augmentation; add_val=False presumably keeps the val split out of
# training (TODO confirm against CamVidDataset). The test split serves as validation data.
camvid_train = CamVidDataset(mode="train", add_val=False, augmentation=True)
camvid_test = CamVidDataset(mode="test")

# Mini-batches of 5 images; only the training stream is shuffled.
train_data = DataLoader(camvid_train, shuffle=True, batch_size=5)
valid_data = DataLoader(camvid_test, batch_size=5)

# Batches per training epoch, followed by the number of training samples.
print(len(train_data), len(camvid_train))

274 1367


In [4]:
# Compute per-class loss weights from every label map in the training set.
# NOTE(review): camvid_train was created with augmentation=True, so if the augmentation is
# random these weights come from one random augmented draw of each sample -- confirm intended.
target = torch.from_numpy(np.array([lbl.numpy() for _, lbl in camvid_train]))
weight = get_weights(target)

In [5]:
num_classes = len(classes)

# SegNet(3, num_classes): presumably (input channels, output classes) -- TODO confirm in segnet_dropout.
net = SegNet(3, num_classes)
if torch.cuda.is_available():
    net = net.cuda()

# Initialise from the torchvision VGG16-BN checkpoint. map_location="cpu" lets the file load on
# CPU-only machines; load_state_dict then copies each tensor onto the device `net` lives on.
# strict=False silently skips keys whose names don't match SegNet's parameters, so report how
# many were ignored -- a naming mismatch would otherwise leave the encoder uninitialised unnoticed.
incompatible = net.load_state_dict(
    torch.load("vgg16_bn-6c64b313.pth", map_location="cpu"), strict=False)
if incompatible is not None:  # older PyTorch releases return None here
    print("pretrained init: {} missing keys, {} unexpected keys".format(
        len(incompatible.missing_keys), len(incompatible.unexpected_keys)))

# Class-weighted cross-entropy using the weights computed from the training labels.
criterion = CrossEntropyLoss2d(weight=weight)

# SGD with momentum. Keep the learning rate low: if it is too high the predicted labels
# all collapse to 0 (tracked by the all-zero counters in the training loop).
basic_optim = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9, weight_decay=0)
optimizer = basic_optim


In [6]:
EPOCHES = 200

# Per-epoch history on the training set: loss is the mean over batches, the other metrics
# are means over images.
train_loss = []
train_acc = []
train_acc_cls = []
train_mean_iu = []
train_fwavacc = []

# Per-epoch history on the validation (test) set, normalised the same way.
eval_loss = []
eval_acc = []
eval_acc_cls = []
eval_mean_iu = []
eval_fwavacc = []

# Number of images whose prediction is class 0 at every pixel (the "collapsed output" symptom).
# *_epoch hold the latest epoch's count, train_zero/test_zero the per-epoch history,
# num_zero_train/num_zero_test the running totals.
num_zero_train_epoch = 0
num_zero_test_epoch = 0
num_zero_train = 0
num_zero_test = 0
train_zero = []
test_zero = []

# Batches are moved to whatever device the model was placed on in the previous cell.
device = next(net.parameters()).device


def _run_epoch(loader, train):
    """Make one pass over ``loader`` with the global ``net``, ``criterion`` and ``optimizer``.

    With ``train=True`` the model is put in training mode and updated after every batch;
    otherwise it is put in eval mode and run without building an autograd graph.

    Returns a dict with:
        "loss": mean batch loss,
        "acc", "acc_cls", "mean_iu", "fwavacc": per-image means of label_accuracy_score's outputs
            (NOTE(review): if label_accuracy_score yields NaN for an image, e.g. from the
            division warnings seen in its output, the corresponding mean is NaN too),
        "num_zero": number of images predicted as class 0 at every pixel.
    """
    net.train(train)
    loss_sum = 0.0
    metric_sums = {"acc": 0.0, "acc_cls": 0.0, "mean_iu": 0.0, "fwavacc": 0.0}
    num_zero = 0

    with torch.set_grad_enabled(train):
        for img_data, img_label in loader:
            im = img_data.to(device)
            label = img_label.to(device)

            # Forward pass.
            out = net(im)
            loss = criterion(out, label)

            # Backward pass and parameter update (training only).
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            loss_sum += loss.item()

            label_pred = out.max(dim=1)[1].detach().cpu().numpy()
            label_true = label.cpu().numpy()

            # Count images (not batches) whose every pixel was assigned class 0.
            flat_pred = label_pred.reshape(len(label_pred), -1)
            num_zero += int(np.all(flat_pred == 0, axis=1).sum())

            for lbt, lbp in zip(label_true, label_pred):
                acc, acc_cls, mean_iu, fwavacc = label_accuracy_score(lbt, lbp, num_classes)
                metric_sums["acc"] += acc
                metric_sums["acc_cls"] += acc_cls
                metric_sums["mean_iu"] += mean_iu
                metric_sums["fwavacc"] += fwavacc

    num_images = len(loader.dataset)
    stats = {name: total / num_images for name, total in metric_sums.items()}
    stats["loss"] = loss_sum / len(loader)
    stats["num_zero"] = num_zero
    return stats


for e in range(EPOCHES):
    prev_time = datetime.now()

    # Training pass.
    train_stats = _run_epoch(train_data, train=True)
    num_zero_train_epoch = train_stats["num_zero"]
    num_zero_train += num_zero_train_epoch
    train_zero.append(num_zero_train_epoch)
    train_loss.append(train_stats["loss"])
    train_acc.append(train_stats["acc"])
    train_acc_cls.append(train_stats["acc_cls"])
    train_mean_iu.append(train_stats["mean_iu"])
    train_fwavacc.append(train_stats["fwavacc"])

    # Validation pass.
    eval_stats = _run_epoch(valid_data, train=False)
    num_zero_test_epoch = eval_stats["num_zero"]
    num_zero_test += num_zero_test_epoch
    test_zero.append(num_zero_test_epoch)
    eval_loss.append(eval_stats["loss"])
    eval_acc.append(eval_stats["acc"])
    eval_acc_cls.append(eval_stats["acc_cls"])
    eval_mean_iu.append(eval_stats["mean_iu"])
    eval_fwavacc.append(eval_stats["fwavacc"])

    # Print this epoch's results and its wall-clock time (training + validation) as h:m:s.
    elapsed = int((datetime.now() - prev_time).total_seconds())
    h, remainder = divmod(elapsed, 3600)
    m, s = divmod(remainder, 60)
    epoch_str = ('Epoch: {}, Train Loss: {:.5f}, Train Acc: {:.5f}, Train Mean IU: {:.5f}, '
                 'Valid Loss: {:.5f}, Valid Acc: {:.5f},Valid Mean IU: {:.5f} ,').format(
        e, train_stats["loss"], train_stats["acc"], train_stats["mean_iu"],
        eval_stats["loss"], eval_stats["acc"], eval_stats["mean_iu"])
    time_str = 'Time: {:.0f}:{:.0f}:{:.0f}'.format(h, m, s)
    print(epoch_str + time_str)
    # show()  # uncomment to plot a few validation predictions after each epoch

  acc_cls = np.diag(hist) / hist.sum(axis=1)
  iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))


Epoch: 0, Train Loss: 2.97228, Train Acc: 0.13571, Train Mean IU: 0.03113, Valid Loss: 2.79619, Valid Acc: 0.10394,Valid Mean IU: 0.01893 ,Time: 0:1:11
Epoch: 1, Train Loss: 2.68851, Train Acc: 0.18584, Train Mean IU: 0.04616, Valid Loss: 2.67868, Valid Acc: 0.13890,Valid Mean IU: 0.01895 ,Time: 0:0:35
Epoch: 2, Train Loss: 2.58813, Train Acc: 0.25194, Train Mean IU: 0.05805, Valid Loss: 2.61170, Valid Acc: 0.15936,Valid Mean IU: 0.02239 ,Time: 0:0:35
Epoch: 3, Train Loss: 2.49068, Train Acc: 0.33199, Train Mean IU: 0.07486, Valid Loss: 2.54838, Valid Acc: 0.26379,Valid Mean IU: 0.05616 ,Time: 0:0:35
Epoch: 4, Train Loss: 2.36929, Train Acc: 0.37766, Train Mean IU: 0.08532, Valid Loss: 2.29990, Valid Acc: 0.32244,Valid Mean IU: 0.07178 ,Time: 0:0:36
Epoch: 5, Train Loss: 2.27501, Train Acc: 0.42617, Train Mean IU: 0.09995, Valid Loss: 2.11556, Valid Acc: 0.37165,Valid Mean IU: 0.08569 ,Time: 0:0:35
Epoch: 6, Train Loss: 2.17330, Train Acc: 0.43370, Train Mean IU: 0.10186, Valid Loss: 1

In [7]:
import pandas as pd

# One column per recorded per-epoch curve; written to CSV without the row index.
data = dict(
    train_loss=train_loss,
    eval_loss=eval_loss,
    train_acc=train_acc,
    eval_acc=eval_acc,
    train_miou=train_mean_iu,
    eval_miou=eval_mean_iu,
)
result = pd.DataFrame.from_dict(data)
result.to_csv("result_aug_batch5_lr0.001_drop.csv", index=False)

In [8]:
def predict(img, label):
    """Run the global ``net`` on a single image and return its per-pixel class map.

    The model's train/eval mode is left untouched; call ``net.eval()`` beforehand if
    dropout/batch-norm should run in inference mode.

    Args:
        img: one image tensor without a batch dimension (presumably C x H x W -- TODO
            confirm against CamVidDataset); a batch axis of size 1 is added here.
        label: returned unchanged, so callers can unpack ``pred, label`` together.

    Returns:
        (pred, label), where ``pred`` is a NumPy array of argmax class indices with
        singleton dimensions squeezed out.
    """
    # Follow the model's device instead of hard-coding .cuda(), so this also works on CPU.
    device = next(net.parameters()).device
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        out = net(img.unsqueeze(0).to(device))
    pred = out.max(1)[1].squeeze().cpu().numpy()
    return pred, label
# Display the current network's results on a few test images.
def show(size=224, num_image=4, img_size=10, offset=0, shuffle=False):
    """Plot test images next to their ground-truth labels and the network's predictions.

    Row ``i`` shows test sample ``i + offset``: the original image, the label image and the
    class map returned by ``predict``.

    Args:
        size: not used by this function (kept for backward compatibility).
        num_image: number of rows (samples) to draw.
        img_size: figure width and height in inches; also scales the caption offset.
        offset: index of the first test sample when ``shuffle`` is False.
        shuffle: when True, each row draws a random test sample instead.
    """
    import random

    # squeeze=False keeps a 2-D axes grid even when num_image == 1.
    _, figs = plt.subplots(num_image, 3, figsize=(img_size, img_size), squeeze=False)
    for i in range(num_image):
        if shuffle:
            # Keeps i + offset within [i, len(camvid_test) - 1], a valid test index.
            offset = random.randint(0, len(camvid_test) - i - 1)
        idx = i + offset
        img_data, img_label = camvid_test[idx]
        pred, _ = predict(img_data, img_label)
        # Reload the raw files for display (the dataset tensors are presumably preprocessed).
        img_data = Image.open(camvid_test.data_list[idx])
        img_label = Image.open(camvid_test.label_list[idx])
        # NOTE(review): `crop` is not defined anywhere in this notebook; it must come from a
        # module or cell not shown here, otherwise this line raises NameError.
        img_data, img_label = crop(img_data, img_label)
        # Columns: original image, ground-truth label, model output; hide both axes on each.
        for ax, picture in zip(figs[i], (img_data, img_label, pred)):
            ax.imshow(picture)
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)

    # Column captions go underneath the last row.
    caption_y = -0.2 * (10 / img_size)
    for ax, title in zip(figs[num_image - 1], ("Image", "Label", "segnet")):
        ax.set_title(title, y=caption_y)
    plt.show()