In [1]:
from net.unetpp import NestedUNet as unetpp
from util import DiceLoss
import cv2
import numpy as np
import os
import pandas as ps
import torch
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.model_selection import train_test_split
import torch.nn as nn
from torch import optim
from torch.utils.data import ConcatDataset, DataLoader, Subset, Dataset
# This is for the progress bar.
from tqdm.auto import tqdm
from torch.nn import functional as F
from torchvision import transforms
from torchvision.utils import save_image
from model import UNet, MyDataset
from torch.utils.data import random_split
from my_metrics import calculate_metrics
import csv

In [2]:
# Select the compute device: CUDA when available, otherwise CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
data_folder_path = '../data/FAZ/Domain1/train/imgs'
mask_folder_path = '../data/FAZ/Domain1/train/mask'

# Keep only the .png entries of each folder.
# os.listdir returns names in arbitrary, filesystem-dependent order, so both
# lists are sorted. This makes the seeded random_split later in this cell
# reproducible across machines. It also keeps images and masks in matching
# order when they share file names.
# NOTE(review): assumes MyDataset pairs data_files[i] with mask_files[i] — confirm.
data_files = sorted(name for name in os.listdir(data_folder_path) if name.endswith(".png"))
mask_files = sorted(name for name in os.listdir(mask_folder_path) if name.endswith(".png"))

dataset = MyDataset(data_file=data_files,mask_files=mask_files,data_folder_path=data_folder_path,mask_folder_path=mask_folder_path)
weight_path='params/unet.pth'
model = UNet().to(device)

# Hold out ~20% of the samples for validation. Deriving the lengths from the
# dataset size gives [195, 49] for 244 samples (the previously hard-coded
# values) and still works if the number of images changes.
VALID_FRACTION = 0.2
valid_size = round(len(dataset) * VALID_FRACTION)
train_size = len(dataset) - valid_size
train_dataset, valid_dataset = random_split(dataset=dataset, lengths=[train_size, valid_size], generator=torch.Generator().manual_seed(0))
train_loader = DataLoader(train_dataset,batch_size=16,shuffle=True)
# Validation metrics are averaged over the whole split, so order is irrelevant.
valid_loader = DataLoader(valid_dataset,batch_size=16,shuffle=False)

train_loader_length = len(train_loader)

opt = optim.SGD(model.parameters(),lr=0.01, momentum=0.9, weight_decay=0.0001)
# NOTE(review): nn.BCELoss expects probabilities in [0, 1]; assumes UNet ends
# with a sigmoid — confirm.
ce_loss = nn.BCELoss()
# NOTE(review): the argument 1 is presumably the number of classes — confirm in util.DiceLoss.
dice_loss = DiceLoss(1)
all_epoch = 100  # maximum number of training epochs

# Per-epoch history, exported to CSV after training.
lossArr = []
dicelist = []
jclist = []
hd95list = []
assdlist = []
splist = []
recalllist = []
prelist = []

best_score = float('inf')  # best (lowest) validation score seen so far
patience = 8  # stop after this many consecutive epochs without improvement
counter = 0  # number of consecutive epochs without improvement
print("model starts training\n")
# torch.save does not create missing directories.
os.makedirs('params', exist_ok=True)
for current_epoch in range(all_epoch):
    # ---- training ----
    model.train()
    lossItem = 0  # sum of the per-batch training losses of this epoch (stored in lossArr)
    with tqdm(total=train_loader_length, desc=f'Epoch {current_epoch + 1}/{all_epoch}', unit='batch') as pbar:
        for image, segment_image in train_loader:
            image = image.to(device)
            segment_image = segment_image.to(device)
            opt.zero_grad()
            output_img = model(image)
            # Equally weighted BCE + Dice loss.
            train_loss = 0.5*ce_loss(output_img, segment_image)+0.5*dice_loss(output_img, segment_image)
            train_loss.backward()
            opt.step()
            lossItem = lossItem + train_loss.item()
            pbar.update(1)  # advance the progress bar by one batch

        print("start to eval")

        # ---- validation ----
        model.eval()
        valid_loss = 0.0
        all_dc,all_jc,all_hd,all_assd,all_sp,all_recall,all_pre = 0,0,0,0,0,0,0
        datalen = len(valid_dataset)
        with torch.no_grad():
            for valid_img, valid_seg_img in valid_loader:
                valid_img = valid_img.to(device)
                valid_seg_img = valid_seg_img.to(device)
                valid_output = model(valid_img)
                # Compute the batch loss once per batch. Weight it by the batch
                # size so that dividing by datalen below gives a per-sample mean.
                # NOTE(review): assumes dice_loss returns a batch mean, as BCELoss does by default.
                batch_loss = 0.5*ce_loss(valid_output,valid_seg_img)+0.5*dice_loss(valid_output,valid_seg_img)
                valid_loss += batch_loss.item() * valid_img.size(0)
                # Per-sample metrics, each averaged over the whole validation split.
                for pred,target in zip(valid_output,valid_seg_img):
                    dice,jaccard,hd95_score,assd_score,sp_score,recall_score,pre_score = calculate_metrics(pred,target)
                    all_dc += dice/datalen
                    all_jc += jaccard/datalen
                    all_hd += hd95_score/datalen
                    all_assd += assd_score/datalen
                    all_sp += sp_score/datalen
                    all_recall += recall_score/datalen
                    all_pre += pre_score/datalen
        valid_loss /= max(datalen, 1)  # guard against an empty validation split

        current_score = valid_loss  # lower is better

        pbar.set_postfix({'Train_Loss': lossItem,'Valid_Loss':valid_loss})  # show losses next to the progress bar

        # Per-epoch checkpoint.
        torch.save(model.state_dict(), f'params/unet_{current_epoch}.pth')

        lossArr.append(lossItem)
        dicelist.append(all_dc)
        jclist.append(all_jc)
        hd95list.append(all_hd)
        assdlist.append(all_assd)
        splist.append(all_sp)
        recalllist.append(all_recall)
        prelist.append(all_pre)

        # Early stopping on the validation loss. Keep the best weights in
        # weight_path and stop after `patience` epochs without improvement.
        if current_score < best_score:
            best_score = current_score
            counter = 0
            torch.save(model.state_dict(), weight_path)
        else:
            counter += 1
            if counter >= patience:
                print(f"early stopping at epoch {current_epoch + 1}: no validation improvement for {patience} epochs")
                break
        
# Write one row per completed epoch (0-based index): the summed training loss
# followed by the validation metrics averaged over the validation split.
with open('output_unet.csv', 'w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(['epoch', 'lossArr', 'dicelist', 'jclist', 'hd95list', 'assdlist', 'splist', 'recalllist', 'prelist'])

    per_epoch_rows = zip(lossArr, dicelist, jclist, hd95list, assdlist, splist, recalllist, prelist)
    for epoch, row in enumerate(per_epoch_rows):
        writer.writerow([epoch, *row])

cuda
model starts training



Epoch 1/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 2/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 3/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 4/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 5/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 6/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 7/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 8/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 9/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 10/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 11/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 12/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 13/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 14/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 15/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 16/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 17/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 18/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 19/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 20/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 21/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 22/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 23/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 24/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 25/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 26/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 27/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 28/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 29/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 30/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 31/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 32/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 33/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 34/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 35/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 36/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 37/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 38/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 39/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 40/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 41/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 42/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 43/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 44/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 45/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 46/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 47/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 48/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 49/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 50/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 51/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 52/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 53/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 54/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 55/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 56/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 57/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 58/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 59/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 60/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 61/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 62/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 63/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 64/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 65/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 66/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 67/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 68/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 69/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 70/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 71/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 72/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 73/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 74/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 75/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 76/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 77/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 78/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 79/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 80/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 81/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 82/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 83/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 84/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 85/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 86/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 87/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 88/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 89/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 90/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 91/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 92/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 93/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 94/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 95/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 96/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 97/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 98/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 99/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


Epoch 100/100:   0%|          | 0/13 [00:00<?, ?batch/s]

start to eval


In [3]:
# Write one row per completed epoch (0-based index) of the recorded history.
# NOTE(review): these lists are filled by the training cell, which trains
# `UNet()` and already exports them to output_unet.csv. This cell therefore
# writes the same UNet results under a UNet++ file name — confirm this is
# intended before comparing the two files.
with open('output_unetpp.csv', 'w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(['epoch', 'lossArr', 'dicelist', 'jclist', 'hd95list', 'assdlist', 'splist', 'recalllist', 'prelist'])

    per_epoch_rows = zip(lossArr, dicelist, jclist, hd95list, assdlist, splist, recalllist, prelist)
    for epoch, row in enumerate(per_epoch_rows):
        writer.writerow([epoch, *row])