In [1]:
import cv2
import numpy as np
import os
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.functional as F
from torch.utils.data import Dataset,DataLoader
from torchvision.transforms import transforms
#from torchsummary import summary
from torch.autograd import Variable
from sklearn.model_selection import train_test_split
from PIL import Image
from tqdm import tqdm
from sklearn.cluster import KMeans

## 图像数据处理

### 读取图片

In [2]:
#这里将dataset2改成dataset1即可读取dataset1中的数据
train_list = sorted([os.path.join('dataset2/train/',img) for img in os.listdir('dataset2/train/')])
trainGT_list = sorted([os.path.join('dataset2/train_GT/SEG/',img) for img in os.listdir('dataset2/train_GT/SEG/')])

In [3]:
#dataset1则设为628，dataset2设为500
imgsize = 500

In [4]:
def loadpro_img(file_names):
    images = []
    for file_name in file_names:
        img = cv2.imread(file_name,-1)#读取时不做改变，默认参数1加载彩色图片
        img = img.astype(np.uint8)
        #img = Image.fromarray(img)
        images.append(img)
    return np.array(images)

In [5]:
train_data = loadpro_img(train_list)
trainGT_data = loadpro_img(trainGT_list)

### 图片one-hot编码及pad

In [6]:
# GT生成二值化图像mask
trainGTbin = np.zeros(trainGT_data.shape)
for i in range(len(trainGT_data)):
    gt = trainGT_data[i].reshape((-1,))
    index = np.argwhere(gt > 0)
    #trainGTbin[i][index] = 1 
    gtbin = trainGTbin[i].reshape((-1,))
    gtbin[index] = 1
    trainGTbin[i] = gtbin.reshape((imgsize,imgsize))

In [7]:
# mask换成one-hot
from sklearn.preprocessing import OneHotEncoder
mask = np.zeros((len(trainGTbin),2,imgsize,imgsize))
for i in range(len(trainGTbin)):
    enc = OneHotEncoder(categories='auto')
    a=enc.fit_transform(trainGTbin[i].reshape((-1,1)))
    label=a.toarray()
    label1=label[:,0].reshape((imgsize,imgsize))
    label2=label[:,1].reshape((imgsize,imgsize))
    mask[i] = np.array([label1,label2])

In [8]:
#输入图像镜像pad
data = np.zeros((len(train_data),1,imgsize+2*92,imgsize+2*92))
for i in range(len(train_data)):
    data[i] = np.pad(train_data[i],((92,92)),'symmetric')

In [9]:
data_tensor = torch.from_numpy(data)
label_tensor = torch.from_numpy(mask)

### 重写dataset，图片分组

In [10]:
#重写dataset，图片分组
data_transforms = transforms.Compose([
        #transforms.Pad(92,padding_mode='symmetric'),
        transforms.ToTensor(),
        transforms.Normalize([0.5],[0.5])
    ])
datalabel_transforms = transforms.Compose([
        transforms.ToTensor(),
    ])
class myDataset(Dataset):
    def __init__(self,imgs,labels,gt,transform_x=None,transform_y=None):
        self.transform_x = transform_x
        self.transform_y = transform_y
        self.imgs = imgs
        self.labels = labels
        self.gt = gt
    def __getitem__(self,index):
        img = self.imgs[index]
        img = img/255
        img = (img-0.5)/0.5
        #img = self.transform_x(img)
        label = self.labels[index]
        #label = self.transform_y(label)
        return torch.from_numpy(img),torch.from_numpy(label),torch.from_numpy(self.gt[index])
    def __len__(self):
        return len(self.imgs)
train_dataset = myDataset(data,mask,trainGT_data,data_transforms,datalabel_transforms)
dataloders = DataLoader(train_dataset,batch_size=2,shuffle=True)

## 模型搭建

In [11]:
class convnet(nn.Module):
    def __init__(self,inchannels,outchannels):
        super(convnet,self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(inchannels,outchannels,3),
            nn.BatchNorm2d(outchannels),
            nn.ReLU(inplace=True),
            
            nn.Conv2d(outchannels,outchannels,3),
            nn.BatchNorm2d(outchannels),
            nn.ReLU(inplace=True)
        )
    def forward(self,x):
        x = self.conv(x)
        return x

In [12]:
class myUnet(nn.Module):
    def __init__(self):
        super(myUnet,self).__init__()
        #down_sample
        self.convnet1 = convnet(1,64)
        self.maxpool1 = nn.MaxPool2d(2)
        
        self.convnet2 = convnet(64,128)
        self.maxpool2 = nn.MaxPool2d(2)
        
        self.convnet3 = convnet(128,256)
        self.maxpool3 = nn.MaxPool2d(2)
        
        self.convnet4 = convnet(256,512)
        self.maxpool4 = nn.MaxPool2d(2)
        
        self.convnet5 = convnet(512,1024)
        
        #up_sample
        self.up1 = nn.ConvTranspose2d(1024,512,2,stride=2)
        self.convnet6 = convnet(1024,512)
        
        self.up2 = nn.ConvTranspose2d(512,256,2,stride=2)
        self.convnet7 = convnet(512,256)
        
        self.up3 = nn.ConvTranspose2d(256,128,2,stride=2)
        self.convnet8 = convnet(256,128)
        
        self.up4 = nn.ConvTranspose2d(128,64,2,stride=2)
        self.convnet9 = convnet(128,64)
        
        self.conv10 = nn.Conv2d(64,2,1)
    def forward(self,x):
        x1 = self.convnet1(x)
        x2 = self.convnet2(self.maxpool1(x1))
        x3 = self.convnet3(self.maxpool2(x2))
        x4 = self.convnet4(self.maxpool3(x3))
        
        x5 = self.convnet5(self.maxpool4(x4))
        
        xup1 = self.up1(x5)
        pad = (x4.size(2)-xup1.size(2))//2
        xcrop1 = x4[:,:,pad:pad+xup1.size(2),pad:pad+xup1.size(2)]
        xcat1 = torch.cat([xcrop1,xup1],1)
        x6 = self.convnet6(xcat1)
        
        xup2 = self.up2(x6)
        pad = (x3.size(2)-xup2.size(2))//2
        xcrop2 = x3[:,:,pad:pad+xup2.size(2),pad:pad+xup2.size(2)]
        xcat2 = torch.cat([xcrop2,xup2],1)
        x7 = self.convnet7(xcat2)
        
        xup3 = self.up3(x7)
        pad = (x2.size(2)-xup3.size(2))//2
        xcrop3 = x2[:,:,pad:pad+xup3.size(2),pad:pad+xup3.size(2)]
        xcat3 = torch.cat([xcrop3,xup3],1)
        x8 = self.convnet8(xcat3)
        
        xup4 = self.up4(x8)
        pad = (x1.size(2)-xup4.size(2))//2
        xcrop4 = x1[:,:,pad:pad+xup4.size(2),pad:pad+xup4.size(2)]
        xcat4 = torch.cat([xcrop4,xup4],1)
        x9 = self.convnet9(xcat4)
        
        xf = self.conv10(x9)
        return xf

## 训练

In [13]:
model = myUnet().cuda()

In [15]:
## 加载现有模型，则不需要下面的训练过程，可以直接开始评估
model.load_state_dict(torch.load('unet2params1.pkl'))

<All keys matched successfully>

In [14]:
#训练参数的设计
epoch_nums = 35
#optimizer = torch.optim.SGD(model.parameters(),lr=1e-6,momentum=0.8)
optimizer = torch.optim.Adam(model.parameters(),lr=1e-4)
loss = nn.BCEWithLogitsLoss().cuda()
batch_size=2

In [15]:
#训练过程
losses = []
accs = []
jses = []
for echo in range(epoch_nums):
    train_loss = 0
    train_js=0
    acc = 0
    n = 0
    for img,label,gt in tqdm(dataloders):
        img = Variable(img.float()).cuda()
        label = Variable(label).cuda()
        out = model(img)
        lossvalue = loss(out,label.float())
        optimizer.zero_grad()
        lossvalue.backward()
        optimizer.step()
        
        train_loss += float(lossvalue)
        #acc = 0
        n += batch_size
        #loss参考意义不大
        if n >len(dataloders.dataset):
            pred=out[0].argmax(0)
            true=label[0].argmax(0)
            accnum = (pred==true).sum()
            acc += float(accnum)/(imgsize*imgsize)
        else:
            for j in range(batch_size):
                pred = torch.sigmoid(out[j])
                pred=pred.argmax(0)
                true=label[j].argmax(0)
                accnum = (pred==true).sum()
                acc += float(accnum)/(imgsize*imgsize)
        #jsvalue
#         if n >len(dataloders.dataset):
#             pred=out[0].argmax(0).cpu()
#             pred = pred.numpy()
#             pred = pred.astype(np.uint8)
#             maxval,pred_img = cv2.connectedComponents(pred, 4, cv2.CV_32S)
#             true=gt[0]
#             _,score=Jaccard_eval(pred_img,true)
#             train_js += score
#         else:
#             for j in range(batch_size):
#                 pred=out[j].argmax(0).cpu()
#                 pred = pred.numpy()
#                 pred = pred.astype(np.uint8)
#                 true=gt[j]
#                 maxval,pred_img = cv2.connectedComponents(pred, 4, cv2.CV_32S)
#                 _,score=Jaccard_eval(pred_img,true)
#                 train_js += score
    jses.append(acc/len(dataloders.dataset))   
    losses.append(train_loss/len(dataloders))
    print('echo:'+ ' '+str(echo))
    print('loss:'+ ' '+str(train_loss/len(dataloders.dataset)))
    print('acc:'+ ' '+str(acc/len(dataloders.dataset)))
        

100%|██████████| 84/84 [00:51<00:00,  1.39it/s]
  0%|          | 0/84 [00:00<?, ?it/s]

echo: 0
loss: 0.2666672148874828
acc: 0.8511354047619047


100%|██████████| 84/84 [01:54<00:00,  1.95s/it]
  0%|          | 0/84 [00:00<?, ?it/s]

echo: 1
loss: 0.2241946311578864
acc: 0.8783200476190474


100%|██████████| 84/84 [02:21<00:00,  1.46s/it]
  0%|          | 0/84 [00:00<?, ?it/s]

echo: 2
loss: 0.20106244690361477
acc: 0.8857747380952378


100%|██████████| 84/84 [02:49<00:00,  2.06s/it]
  0%|          | 0/84 [00:00<?, ?it/s]

echo: 3
loss: 0.18360234335774467
acc: 0.8901962857142852


100%|██████████| 84/84 [02:39<00:00,  2.11s/it]
  0%|          | 0/84 [00:00<?, ?it/s]

echo: 4
loss: 0.16834300242009617
acc: 0.8966801904761911


100%|██████████| 84/84 [02:48<00:00,  2.01s/it]
  0%|          | 0/84 [00:00<?, ?it/s]

echo: 5
loss: 0.15643310103388058
acc: 0.8995805952380946


100%|██████████| 84/84 [02:31<00:00,  1.77s/it]
  0%|          | 0/84 [00:00<?, ?it/s]

echo: 6
loss: 0.14546752641243593
acc: 0.9036394761904765


100%|██████████| 84/84 [02:39<00:00,  1.96s/it]
  0%|          | 0/84 [00:00<?, ?it/s]

echo: 7
loss: 0.13642989764256136
acc: 0.9079555714285719


100%|██████████| 84/84 [02:18<00:00,  1.10it/s]
  0%|          | 0/84 [00:00<?, ?it/s]

echo: 8
loss: 0.12863714318899883
acc: 0.911835214285714


100%|██████████| 84/84 [01:50<00:00,  1.64s/it]
  0%|          | 0/84 [00:00<?, ?it/s]

echo: 9
loss: 0.11816571280360222
acc: 0.9190058809523807


100%|██████████| 84/84 [02:30<00:00,  1.02it/s]
  0%|          | 0/84 [00:00<?, ?it/s]

echo: 10
loss: 0.11096080055549032
acc: 0.923596119047619


100%|██████████| 84/84 [02:29<00:00,  1.84s/it]
  0%|          | 0/84 [00:00<?, ?it/s]

echo: 11
loss: 0.10500690438562915
acc: 0.9267613095238089


100%|██████████| 84/84 [02:17<00:00,  1.16s/it]
  0%|          | 0/84 [00:00<?, ?it/s]

echo: 12
loss: 0.09496035267199789
acc: 0.9350543333333334


100%|██████████| 84/84 [02:30<00:00,  1.96s/it]
  0%|          | 0/84 [00:00<?, ?it/s]

echo: 13
loss: 0.08866490716380733
acc: 0.9393232619047612


100%|██████████| 84/84 [02:21<00:00,  1.29s/it]
  0%|          | 0/84 [00:00<?, ?it/s]

echo: 14
loss: 0.08152453140133903
acc: 0.9448032380952378


100%|██████████| 84/84 [02:04<00:00,  2.06s/it]
  0%|          | 0/84 [00:00<?, ?it/s]

echo: 15
loss: 0.07480077259242535
acc: 0.9497113571428569


100%|██████████| 84/84 [01:52<00:00,  1.06s/it]
  0%|          | 0/84 [00:00<?, ?it/s]

echo: 16
loss: 0.06872488970735244
acc: 0.9545202857142859


100%|██████████| 84/84 [02:46<00:00,  1.61s/it]
  0%|          | 0/84 [00:00<?, ?it/s]

echo: 17
loss: 0.06414900405243748
acc: 0.9574474523809519


100%|██████████| 84/84 [03:40<00:00,  2.61s/it]
  0%|          | 0/84 [00:00<?, ?it/s]

echo: 18
loss: 0.05999473939161925
acc: 0.9598781428571433


100%|██████████| 84/84 [03:56<00:00,  3.00s/it]
  0%|          | 0/84 [00:00<?, ?it/s]

echo: 19
loss: 0.054256148503295014
acc: 0.9644102380952382


100%|██████████| 84/84 [03:59<00:00,  3.03s/it]
  0%|          | 0/84 [00:00<?, ?it/s]

echo: 20
loss: 0.051247499350990565
acc: 0.9660927619047619


100%|██████████| 84/84 [03:49<00:00,  2.75s/it]
  0%|          | 0/84 [00:00<?, ?it/s]

echo: 21
loss: 0.047971450812405066
acc: 0.968096547619048


100%|██████████| 84/84 [04:13<00:00,  3.23s/it]
  0%|          | 0/84 [00:00<?, ?it/s]

echo: 22
loss: 0.04587424932313817
acc: 0.9691277857142857


100%|██████████| 84/84 [04:45<00:00,  3.27s/it]
  0%|          | 0/84 [00:00<?, ?it/s]

echo: 23
loss: 0.04504803705605723
acc: 0.9689228571428565


100%|██████████| 84/84 [04:45<00:00,  3.38s/it]
  0%|          | 0/84 [00:00<?, ?it/s]

echo: 24
loss: 0.04204334897388305
acc: 0.9709939523809523


100%|██████████| 84/84 [04:32<00:00,  3.03s/it]
  0%|          | 0/84 [00:00<?, ?it/s]

echo: 25
loss: 0.03917756164446473
acc: 0.9732229523809527


100%|██████████| 84/84 [04:18<00:00,  3.27s/it]
  0%|          | 0/84 [00:00<?, ?it/s]

echo: 26
loss: 0.03667912119999528
acc: 0.975171880952381


100%|██████████| 84/84 [04:41<00:00,  2.33s/it]
  0%|          | 0/84 [00:00<?, ?it/s]

echo: 27
loss: 0.034648114283170016
acc: 0.9762338095238089


100%|██████████| 84/84 [04:33<00:00,  3.13s/it]
  0%|          | 0/84 [00:00<?, ?it/s]

echo: 28
loss: 0.03311372834390828
acc: 0.976861023809524


100%|██████████| 84/84 [04:40<00:00,  3.52s/it]
  0%|          | 0/84 [00:00<?, ?it/s]

echo: 29
loss: 0.03206446691460553
acc: 0.9776954761904756


100%|██████████| 84/84 [04:15<00:00,  3.17s/it]
  0%|          | 0/84 [00:00<?, ?it/s]

echo: 30
loss: 0.030409898453702528
acc: 0.978765547619048


100%|██████████| 84/84 [04:32<00:00,  2.97s/it]
  0%|          | 0/84 [00:00<?, ?it/s]

echo: 31
loss: 0.029706198966041916
acc: 0.9791069761904764


100%|██████████| 84/84 [04:10<00:00,  3.52s/it]
  0%|          | 0/84 [00:00<?, ?it/s]

echo: 32
loss: 0.028936017326833235
acc: 0.9793348571428572


100%|██████████| 84/84 [04:16<00:00,  3.34s/it]
  0%|          | 0/84 [00:00<?, ?it/s]

echo: 33
loss: 0.02805051693160619
acc: 0.9798344285714284


100%|██████████| 84/84 [04:33<00:00,  3.17s/it]

echo: 34
loss: 0.026977916869024437
acc: 0.9805693571428571





In [103]:
#保存训练模型
torch.save(model.state_dict(),'unet2paramsn.pkl')

## 测试评估

In [16]:
# Jaccard相似度评估函数
def get_onelabelmap(data,label):
    index = np.argwhere(data==label)
    sam = np.zeros(data.shape)
    if len(index) is not 0:
        for i in range(len(index)):
            sam[index[i][0]][index[i][1]] = 1
    sam = sam.astype(np.int32)#注意类型转换，否则无法进行与运算
    return sam
def get_fit_GTimgs(cell_pred,label,cell_gt):
    #对于每个预测标签细胞，求与其匹配的gt细胞
    fit_lists = []
    Jaccard_lists = []
    labels = np.unique(cell_pred)
    cell_gt_map = get_onelabelmap(cell_gt,label)
    if len(labels) is not 0:
        for i in range(1,len(labels)):
            cell_pred_map = get_onelabelmap(cell_pred,labels[i])
            andmap=cell_pred_map&cell_gt_map
            if np.sum(andmap)>0.5*np.sum(cell_gt_map):
                a = np.sum(cell_pred_map&cell_gt_map)
                b = np.sum(cell_pred_map|cell_gt_map)
                fit_lists.append(labels[i])
                Jaccard_lists.append(a/b)
    return fit_lists,Jaccard_lists
def Jaccard_eval(cell_pred,cell_gt):
    #Jaccard相似度评估
    labels_gt = np.unique(cell_gt)
    JS_val = np.zeros((len(labels_gt)-1,))
    for i in range(1,len(labels_gt)):
        fit_lists,Jaccard_lists = get_fit_GTimgs(cell_pred,labels_gt[i],cell_gt)
        if len(fit_lists) is not 0:
             JS_val[i-1] = np.max(Jaccard_lists)
    return JS_val,np.mean(JS_val)

### 后处理方法

#### watershed process

In [17]:
def watershed_process(mask):
    kernel = np.ones((3,3),np.uint8)
    # sure background area
    sure_bg = cv2.dilate(mask,kernel,iterations=3)
    # Finding sure foreground area
    dist_transform = cv2.distanceTransform(mask,cv2.DIST_L2,5)
    ret, sure_fg = cv2.threshold(dist_transform,0.592*dist_transform.max(),255,0)
    # Finding unknown region
    sure_fg = np.uint8(sure_fg)
    unknown = cv2.subtract(sure_bg,sure_fg)
    # Marker labelling
    ret, markers = cv2.connectedComponents(sure_fg)
    # Add one to all labels so that sure background is not 0, but 1
    markers = markers + 1
    # Now, mark the region of unknown with zero
    markers[unknown == 1] = 0
    #predpro = 255 - predpro
    rgb = cv2.cvtColor(mask,cv2.COLOR_GRAY2BGR)
    markers2 = cv2.watershed(rgb,markers)
    newmark=markers2
    newmark[markers2==-1] = 1
    newmark = newmark-1
    return newmark

#### clusters

In [18]:
def kluster_proess(mask):
    #maxval,pred_img,_,_ = cv2.connectedComponentsWithStats(mask, 4, cv2.CV_32S)
    pred_img = watershed_process(mask)
    zonelabels = np.unique(pred_img)
    imgpro = np.zeros(mask.shape,np.uint8)
    curnum = 0
    for i in range(1,len(zonelabels)):
        zone = np.zeros(mask.shape,np.uint8)
        zone[pred_img==i]=1
        index = np.argwhere(pred_img==i)
        # plt.imshow(zone)
        # plt.show()
        #erode
        kernel = np.ones((3,3),np.uint8)
        img_erode = cv2.erode(zone,kernel,iterations = 20)
        # plt.imshow(img_erode)
        # plt.show()

        #clusters
        maxval,curzone,_,centorids = cv2.connectedComponentsWithStats(img_erode, 4, cv2.CV_32S)
        curlabels = np.unique(curzone)
        kclasses = len(curlabels)-1

        flags = centorids[1:,:]
        if len(flags)>0:
            dataxy = index.astype(np.float32)
            clf = KMeans(n_clusters=kclasses,init=flags,n_init=1,tol=1e-6)
            clf.fit(dataxy)
            clf.labels_ += 1 #最小标签是0

            clf.labels_ += curnum #防止标签重复
            for j in range(len(index)):
                imgpro[index[j,0]][index[j,1]] = clf.labels_[j]
            curnum = np.max(clf.labels_)
        else:
            for j in range(len(index)):
                imgpro[index[j,0]][index[j,1]] = curnum + 1
            curnum = 1
            
        #print(curnum)
    #plt.imshow(imgpro)
    #plt.show()
    return imgpro

### 训练数据评估

In [21]:
model.eval()
pred_all = []
score_all = []
isval = True #true表明用来预测，否则就是验证正确率
i = 0
for img,label,gt in train_dataset:
    img = img.reshape((1,1,684,684))
    img = Variable(img.float()).cuda()
    label = Variable(label).cuda()
    orimg = img
    with torch.no_grad():
        testout = model(img)
        testout = torch.sigmoid(testout)
    #testloss = loss(testout,label)
    pred = testout[0].argmax(0).cpu()
    pred = pred.numpy()
    pred = pred.astype(np.uint8)
    
    #pred_img = kluster_proess(pred)
    #pred_img = watershed_process(pred)
    maxval,pred_img = cv2.connectedComponents(pred, 4, cv2.CV_32S)
    pred_all.append(pred_img)
    if isval:
        true = trainGT_data[i]
        _,score = Jaccard_eval(pred_img,true)
        score_all.append(score)
        print('image:{}/168, score: {:.4f}'.format(i+1,score))
    i = i+1
print('final score: %.5f'%np.mean(score_all))        

image:1/168, score: 0.3382
image:2/168, score: 0.1950
image:3/168, score: 0.1930
image:4/168, score: 0.6099
image:5/168, score: 0.3090
image:6/168, score: 0.1762
image:7/168, score: 0.2953
image:8/168, score: 0.0812
image:9/168, score: 0.0755
image:10/168, score: 0.5223
image:11/168, score: 0.1499
image:12/168, score: 0.1492
image:13/168, score: 0.4380
image:14/168, score: 0.3157
image:15/168, score: 0.2952
image:16/168, score: 0.7331
image:17/168, score: 0.5484
image:18/168, score: 0.4328
image:19/168, score: 0.3599
image:20/168, score: 0.3231
image:21/168, score: 0.1472
image:22/168, score: 0.1750
image:23/168, score: 0.1276
image:24/168, score: 0.3813
image:25/168, score: 0.0638
image:26/168, score: 0.3175
image:27/168, score: 0.0693
image:28/168, score: 0.2614
image:29/168, score: 0.0568
image:30/168, score: 0.5725
image:31/168, score: 0.4809
image:32/168, score: 0.3160
image:33/168, score: 0.0509
image:34/168, score: 0.0813
image:35/168, score: 0.1948
image:36/168, score: 0.1485
i

## 预测

In [18]:
import imageio

In [19]:
#存储路径，需要先建好文件夹，否则会报错
result_path = 'dataset2/test_RES'
test_list = sorted([os.path.join('dataset2/test/',img) for img in os.listdir('dataset2/test/')])
test_data = loadpro_img(test_list)

In [20]:
test_datapro = np.zeros((len(test_data),1,684,684))
for i in range(len(test_data)):
    test_datapro[i] = np.pad(test_data[i],((92,92)),'symmetric')
    test_datapro[i] = test_datapro[i]/255.0
    test_datapro[i] = (test_datapro[i]-0.5)/0.5

In [21]:
test_datapro = torch.from_numpy(test_datapro)

In [30]:
model.eval()
pred_all=[]
i = 0
for index,img in enumerate(test_datapro):
    img = img.reshape((1,1,684,684))
    img = Variable(img.float()).cuda()
    with torch.no_grad():
        testout = model(img)
        testout = torch.sigmoid(testout)
    #testloss = loss(testout,label)
    predtry = testout[0][0].cpu()
    predtry = predtry.numpy()*255
    predtry = predtry.astype(np.uint8)
    pred = testout[0].argmax(0).cpu()
    pred = pred.numpy()
    pred = pred.astype(np.uint8)
    
    ## 可以选择不同的后处理方法
    #pred_img = kluster_proess(pred)
    pred_img = watershed_process(pred)
    #maxval,pred_img = cv2.connectedComponents(pred, 4, cv2.CV_32S)
    imageio.imwrite(os.path.join(result_path, 'mask{:0>3d}.tif'.format(index)),pred_img.astype(np.uint16))
    pred_all.append(pred_img)
    i = i+1
    print('image:{}/6'.format(i))
print('finish')    

image:1/6
image:2/6
image:3/6
image:4/6
image:5/6
image:6/6
finish
