In [1]:
import torch
import os
import cv2
import random
import numpy as np
from torchvision import models
from torch.utils.data import Dataset, DataLoader

# DeepLabV3 / ResNet-50 whose last classifier conv is swapped for a 4-class 1x1 head,
# moved to the GPU and initialised from an earlier checkpoint. The load_state_dict
# call is the cell's last expression, so Jupyter echoes the key-matching report.
model = models.segmentation.deeplabv3_resnet50()
model.classifier[4] = torch.nn.Conv2d(in_channels=256, out_channels=4, kernel_size=1, stride=1)
model.cuda()  # nn.Module.cuda() moves parameters in place and returns the same module
model.load_state_dict(torch.load(r"work/nongye/pretrainModel/weight8.t7"))

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [2]:
class DatGenerator(Dataset):
    """Training set: random, augmented 512x512 crops from the large training tiles.

    Images come from ``work/2048train/data1`` / ``data2``; the label map with the
    same file name comes from ``label1`` / ``label2``, and only channel 0 of the
    label image is kept as the per-pixel class id. Global indices
    ``0 .. length1 - 1`` address folder 1, the remaining ones folder 2.

    Each item is ``(img, label)``: ``img`` is a channel-first ``(3, 512, 512)``
    uint8 array in OpenCV's BGR order and ``label`` is a ``(512, 512)`` uint8 array.
    """

    def __init__(self):
        self.imgpath1 = r"work/2048train/data1"
        self.imgpath2 = r"work/2048train/data2"

        self.labelpath1 = r"work/2048train/label1"
        self.labelpath2 = r"work/2048train/label2"

        self.imglist1 = os.listdir(self.imgpath1)
        self.imglist2 = os.listdir(self.imgpath2)

        # Probability of applying each augmentation in `transfor` (drawn independently).
        self.flip = 0.5      # vertical flip
        self.rotate = 0.5    # H/W transpose
        self.mirror = 0.5    # horizontal flip
        self.angle45 = 0.5   # crop from a 45-degree rotated tile instead of the plain tile

        self.length1 = len(self.imglist1)
        self.length2 = len(self.imglist2)

    def __len__(self):
        return len(self.imglist1) + len(self.imglist2)

    def _locate(self, index):
        """Map a global index to ``(image dir, label dir, file name)``."""
        if index < self.length1:
            return self.imgpath1, self.labelpath1, self.imglist1[index]
        return self.imgpath2, self.labelpath2, self.imglist2[index - self.length1]

    @staticmethod
    def _read(path):
        """``cv2.imread`` that raises instead of silently returning ``None``."""
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError("cannot read image: %s" % path)
        return img

    def __getitem__(self, index):
        img_dir, label_dir, name = self._locate(index)
        img = self._read(os.path.join(img_dir, name))
        label = self._read(os.path.join(label_dir, name))[:, :, 0]
        img, label = self.transfor(img, label, img_dir, name, label_dir)
        return img.transpose(2, 0, 1), label

    def transfor(self, Img, Label, foldPath, ImgName, LabelPath):
        """Randomly crop a 512x512 patch, then randomly flip / mirror / transpose it.

        Args:
            Img: HxWx3 image.
            Label: HxW class-id map aligned with ``Img``.
            foldPath, ImgName, LabelPath: unused; kept for existing callers.

        Returns:
            ``(Img, Label)`` augmented identically.
        """
        if random.random() < self.angle45:
            # Rotate by -45 degrees about (1024, 1024) into a 1712x1712 canvas and crop
            # with offsets in [310, 1200). For a 2048x2048 input every crop corner stays
            # inside the rotated tile, so no empty (black) corners appear.
            # NOTE(review): assumes 2048x2048 tiles, as the folder name suggests.
            RotateMatrix = cv2.getRotationMatrix2D(center=(1024, 1024), angle=-45, scale=1)
            RotImg = cv2.warpAffine(Img, RotateMatrix, (1712, 1712))
            # Labels are class ids: nearest-neighbour sampling, otherwise the default
            # bilinear interpolation creates bogus in-between ids along class edges.
            RotLabel = cv2.warpAffine(Label, RotateMatrix, (1712, 1712), flags=cv2.INTER_NEAREST)
            x = random.randrange(310, 1200)
            y = random.randrange(310, 1200)
            Img = RotImg[x:x + 512, y:y + 512]
            Label = RotLabel[x:x + 512, y:y + 512]
        else:
            # Any offset that keeps the full 512x512 window inside the image.
            cx = random.randrange(0, Img.shape[0] - 512 + 1)
            cy = random.randrange(0, Img.shape[1] - 512 + 1)
            Img = Img[cx:cx + 512, cy:cy + 512, :]
            Label = Label[cx:cx + 512, cy:cy + 512]

        if random.random() < self.flip:
            Img = cv2.flip(Img, 0)      # flip around the x-axis (upside down)
            Label = cv2.flip(Label, 0)

        if random.random() < self.mirror:
            Img = cv2.flip(Img, 1)      # flip around the y-axis (left/right)
            Label = cv2.flip(Label, 1)

        if random.random() < self.rotate:
            # Swap H and W; combined with the flips above this yields 90-degree rotations.
            Img = Img.transpose(1, 0, 2)
            Label = Label.transpose(1, 0)
        return Img, Label
          
# Shuffled training batches of 12 augmented crops; an incomplete final batch is dropped.
datloader = DataLoader(
    DatGenerator(),
    batch_size=12,
    shuffle=True,
    drop_last=True,
)

In [6]:
def focalloss(res, label, epoch, type_num = 4, class_weights = (0.2, 1.0, 7.0, 4.0), ignore_width = 11):
    """Class-weighted binary focal loss (gamma = 2) with an ignored band around each class.

    Args:
        res: per-class probabilities of shape (B, C, H, W); the training loop in this
            notebook passes ``sigmoid(logits)``.
        label: integer class-id map of shape (B, H, W), on CPU or on ``res``'s device.
        epoch: current epoch. Unused: it only gated a GHM-style re-weighting that was
            disabled. Kept so existing callers keep working.
        type_num: number of classes turned into one-hot targets (channels
            ``0 .. type_num - 1``); any further channels get zero weight.
        class_weights: per-channel loss weights broadcast over (B, H, W); the length
            must equal C.
        ignore_width: odd side of the square dilation window. In channel ``i``, pixels
            that are not class ``i`` but lie within this window of a class-``i`` pixel
            get zero weight (boundary pixels are treated as unreliable).

    Returns:
        Scalar tensor: mean of the weighted per-element loss over all B*C*H*W elements.

    Raises:
        ValueError: if ``ignore_width`` is even.
    """
    if ignore_width % 2 != 1:
        raise ValueError("ignore_width must be odd, got %r" % (ignore_width,))
    device, dtype = res.device, res.dtype
    label = torch.as_tensor(label, device=device).long()

    # One-hot targets: mask[b, i] = (label[b] == i) for i < type_num.
    classes = torch.arange(type_num, device=device).view(1, -1, 1, 1)
    mask = torch.zeros(res.shape, dtype=dtype, device=device)
    mask[:, :type_num] = (label.unsqueeze(1) == classes).to(dtype)

    # Binary dilation with a k x k rectangle == stride-1 max pooling with k//2 padding
    # (the implicit padding is -inf, so the image border never grows a region).
    dilated = torch.nn.functional.max_pool2d(mask, kernel_size=ignore_width, stride=1, padding=ignore_width // 2)
    label_weight = 1 - (dilated - mask)   # 0 on the ring just outside each class region
    label_weight[:, type_num:] = 0        # channels without a target contribute nothing

    weights = torch.as_tensor(class_weights, dtype=dtype, device=device).view(1, -1, 1, 1)

    pos_prob = mask * res           # p on pixels of the class, 0 elsewhere
    neg_prob = res * (1 - mask)     # p on pixels not of the class, 0 elsewhere
    # The 1.000001 offset is smoothing that keeps log() finite; without it the loss becomes NaN.
    positive_loss = (1 - pos_prob) ** 2 * torch.log(pos_prob + (1.000001 - mask))
    negative_loss = neg_prob ** 2 * torch.log(1.000001 - neg_prob)
    loss = -(mask * positive_loss + (1 - mask) * negative_loss)
    loss = loss * weights * label_weight
    return torch.mean(loss)

In [4]:
class ValGenerator(Dataset):
    """Validation set: tiles from ``work/2048val/data1`` / ``data2`` with their labels.

    Labels are read from ``label1`` / ``label2`` under the same file name (channel 0
    is the class id). Images go through ``process`` (per-channel CLAHE, then a 3x3
    Gaussian blur) and are returned channel-first. Global indices
    ``0 .. length1 - 1`` address folder 1, the rest folder 2.

    NOTE(review): ``DatGenerator`` and ``TestGenerator`` in this notebook feed raw
    images without CLAHE/blur, so validation inputs are preprocessed differently from
    both training and test inputs. Confirm this is intended.
    """

    def __init__(self):
        self.imgpath1 = r"work/2048val/data1"
        self.imgpath2 = r"work/2048val/data2"

        self.labelpath1 = r"work/2048val/label1"
        self.labelpath2 = r"work/2048val/label2"

        self.imglist1 = os.listdir(self.imgpath1)
        self.imglist2 = os.listdir(self.imgpath2)

        self.length1 = len(self.imglist1)
        self.length2 = len(self.imglist2)

        self.clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))

    def __len__(self):
        return len(self.imglist1) + len(self.imglist2)

    def _locate(self, index):
        """Map a global index to ``(image dir, label dir, file name)``."""
        if index < self.length1:
            return self.imgpath1, self.labelpath1, self.imglist1[index]
        return self.imgpath2, self.labelpath2, self.imglist2[index - self.length1]

    @staticmethod
    def _read(path):
        """``cv2.imread`` that raises instead of silently returning ``None``."""
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError("cannot read image: %s" % path)
        return img

    def __getitem__(self, index):
        img_dir, label_dir, name = self._locate(index)
        img = self._read(os.path.join(img_dir, name))
        label = self._read(os.path.join(label_dir, name))[:, :, 0]
        img = self.process(img)
        return img.transpose(2, 0, 1), label

    def process(self, img):
        """Equalise each colour channel with CLAHE, then apply a 3x3 Gaussian blur."""
        (b, g, r) = cv2.split(img)
        bH = self.clahe.apply(b)
        gH = self.clahe.apply(g)
        rH = self.clahe.apply(r)
        img = cv2.merge((bH, gH, rH))
        return cv2.GaussianBlur(img, (3, 3), 0)

def val(net=None, loader=None, batch_size=2):
    """Compute per-class IoU for classes 1-3 over the validation set and print it.

    Args:
        net: segmentation model returning a dict whose ``"out"`` entry holds logits of
            shape (B, C, H, W). Defaults to the global ``model``.
        loader: iterable of ``(images, labels)`` batches. Defaults to an unshuffled
            ``DataLoader`` over ``ValGenerator()`` that keeps the last partial batch.
        batch_size: batch size of the default loader.

    Returns:
        ``(mean_iou, [iou1, iou2, iou3])`` as Python floats, where the IoU of class ``c``
        is ``TP / (TP + FN + FP + 1e-6)`` with counts accumulated over the whole set.
    """
    if net is None:
        net = model
    if loader is None:
        # Order is irrelevant for summed counts; keep every sample (no drop_last).
        loader = DataLoader(ValGenerator(), batch_size=batch_size, shuffle=False, drop_last=False)

    device = next(net.parameters()).device
    classes = (1, 2, 3)
    tp = [0] * len(classes)
    fp = [0] * len(classes)
    fn = [0] * len(classes)

    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():  # inference only: no autograd graph / activation memory
            for val_data, val_label in loader:
                logits = net(val_data.to(device).float())["out"]
                # softmax is monotonic per pixel, so the argmax of the logits is the same.
                pred = torch.argmax(logits, dim=1)
                target = val_label.to(device)
                for k, c in enumerate(classes):
                    pred_c = pred == c
                    true_c = target == c
                    tp[k] += int((pred_c & true_c).sum())
                    fn[k] += int((~pred_c & true_c).sum())
                    fp[k] += int((pred_c & ~true_c).sum())
    finally:
        net.train(was_training)  # restore the caller's mode even if evaluation fails

    ious = [tp[k] / (tp[k] + fn[k] + fp[k] + 1e-6) for k in range(len(classes))]
    iou = sum(ious) / len(ious)
    print("val iou = %s, iou1 = %s, iou2 = %s, iou3 = %s"%(iou, ious[0], ious[1], ious[2]))
    return iou, ious

In [8]:
# Training loop, resumed at epoch 16: class-weighted focal loss on sigmoid outputs.
# The learning rate is halved every 20 epochs. An exponential moving average of the
# per-class IoU (classes 1-3) is tracked. Every 5 epochs the loop saves a checkpoint,
# runs validation and prints the class fractions of the last batch.
model.train(True)
iou_perepoch = []
iou_mean = [0, 0, 0]
lr = 1e-5
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
for epoch in range(16, 71):
    count = 0
    if epoch % 20 == 0:
        lr /= 2
        # Lower the rate in place: building a new Adam here would discard its running
        # first/second-moment estimates every time the rate changes.
        for group in optimizer.param_groups:
            group["lr"] = lr
    for train_data, train_label in datloader:
        res = model(train_data.cuda().float())
        res = torch.sigmoid(res["out"])
        pred = torch.argmax(res, dim=1)
        optimizer.zero_grad()
        loss = focalloss(res, train_label, epoch)
        loss.backward()
        optimizer.step()
        count += 1

        target = train_label.cuda()
        for i in [1, 2, 3]:
            pred_i = pred == i
            true_i = target == i
            TP = torch.sum(pred_i & true_i)
            FN = torch.sum(~pred_i & true_i)
            FP = torch.sum(pred_i & ~true_i)
            iou = TP.float() / (TP + FP + FN + 1e-6)
            # Update the running IoU only when the class occurs in this batch.
            if torch.sum(true_i) != 0:
                iou_mean[i - 1] = iou_mean[i - 1] * 0.97 + iou.item() * 0.03

    # Weighted mean of the running per-class IoUs (weights 2 / 1.3 / 3.1, sum 6.4).
    iou = (iou_mean[0] * 2 + iou_mean[1] * 1.3 + iou_mean[2] * 3.1) / 6.4
    print("epoch = %s, lr = %s, loss = %s"%(epoch, lr, loss.item()))
    print("iou = %.4f, iou1 = %.4f, iou2 = %.4f, iou3 = %.4f"%(iou, iou_mean[0], iou_mean[1], iou_mean[2]))

    if epoch % 5 == 0:
        torch.save(model.state_dict(), r"work/nongye/7_19MODEL/weight%d.t7"%(epoch//5))
        val()
        iou_perepoch.append(iou)
        # Class fractions of the last training batch: predictions vs. ground truth.
        pixCaculate = [torch.sum(pred == i).item() for i in range(4)]
        GTCaculate = [torch.sum(train_label == i).item() for i in range(4)]
        pixTotal = sum(pixCaculate)
        GTTotal = float(sum(GTCaculate))
        print("pix0 = %s, pix1 = %s, pix2 = %s, pix3 = %s"%(pixCaculate[0]/pixTotal, pixCaculate[1]/pixTotal, pixCaculate[2]/pixTotal, pixCaculate[3]/pixTotal))
        print("GT0 = %s, GT1 = %s, GT2 = %s, GT3 = %s"%(GTCaculate[0]/GTTotal, GTCaculate[1]/GTTotal, GTCaculate[2]/GTTotal, GTCaculate[3]/GTTotal))

In [9]:
# Generate the prediction results.
class TestGenerator(Dataset):
    """Inference set over the test tiles in ``.../test3`` and ``.../test4``.

    Items are ``(img, file_name)``: ``img`` is the image as read by ``cv2.imread``,
    transposed to channel-first, with no other preprocessing. Global indices
    ``0 .. length1 - 1`` address test3, the rest test4.
    """

    def __init__(self):
        self.imgpath1 = r"work/nongye/new_test/new_test/test3"
        self.imgpath2 = r"work/nongye/new_test/new_test/test4"

        self.imglist1 = os.listdir(self.imgpath1)
        self.imglist2 = os.listdir(self.imgpath2)

        self.length1 = len(self.imglist1)
        self.length2 = len(self.imglist2)

    def __len__(self):
        return len(self.imglist1) + len(self.imglist2)

    def _locate(self, index):
        """Map a global index to ``(image dir, file name)``."""
        if index < self.length1:
            return self.imgpath1, self.imglist1[index]
        return self.imgpath2, self.imglist2[index - self.length1]

    def __getitem__(self, index):
        folder, name = self._locate(index)
        path = os.path.join(folder, name)
        img = cv2.imread(path)
        if img is None:  # cv2.imread signals failure by returning None, not raising
            raise FileNotFoundError("cannot read image: %s" % path)
        return img.transpose(2, 0, 1), name
          
# Predict every test tile and save its per-pixel argmax class ids as an 8-bit image
# with the original file name under OUTPUT_DIR.
testloader = DataLoader(TestGenerator(), batch_size = 4)
OUTPUT_DIR = r"work/nongye/7_19/epoch=15"
os.makedirs(OUTPUT_DIR, exist_ok=True)  # cv2.imwrite only returns False if the folder is missing
model.eval()
count = 0
with torch.no_grad():  # inference only: no autograd graph / activation memory
    for test_img, testImgName in testloader:
        res = model(test_img.cuda().float())
        # sigmoid is monotonic, so the argmax of the raw logits gives the same classes.
        pred = torch.argmax(res["out"], dim=1)
        pred = pred.cpu().numpy().astype(np.uint8)
        count += 1
        if count % 400 == 0:
            print("count == ", count)
        for i in range(pred.shape[0]):
            name = os.path.join(OUTPUT_DIR, testImgName[i])
            if not cv2.imwrite(name, pred[i]):
                raise IOError("failed to write %s" % name)

count ==  400
count ==  800
count ==  1200
count ==  1600


In [10]:
#!zip -r work/nongye/7_19/epoch=15.zip work/nongye/7_19/epoch=15

  adding: work/nongye/7_19/epoch=15/image3_21120_13440.bmp (deflated 99%)  adding: work/nongye/7_19/epoch=15/image3_15360_13440.bmp (deflated 98%)  adding: work/nongye/7_19/epoch=15/image4_18816_21504.bmp (deflated 99%)  adding: work/nongye/7_19/epoch=15/image3_8832_6144.bmp (deflated 98%)  adding: work/nongye/7_19/epoch=15/image4_3072_19968.bmp (deflated 99%)  adding: work/nongye/7_19/epoch=15/image3_20352_11136.bmp (deflated 98%)  adding: work/nongye/7_19/epoch=15/image4_12672_24576.bmp (deflated 99%)  adding: work/nongye/7_19/epoch=15/image4_4608_10752.bmp (deflated 98%)  adding: work/nongye/7_19/epoch=15/image3_384_14208.bmp (deflated 99%)  adding: work/nongye/7_19/epoch=15/image4_4992_15360.bmp (deflated 99%)  adding: work/nongye/7_19/epoch=15/image3_14208_11904.bmp (deflated 99%)  adding: work/nongye/7_19/epoch=15/image3_23424_18816.bmp (deflated 99%)  adding: work/nongye/7_19/epoch=15/image3_14976_9600.bmp (deflated 98%)  adding: work/nongye/7_19/epoch=15/image3_921

In [11]:
#!unzip data/data9014/new_train.zip
#!unzip data/data9287/new_train1.zip -d work/2048train/
#!zip -r work/nongye/7_19/epoch=8.zip work/nongye/7_19/epoch=8

  adding: work/nongye/7_19/epoch=8/image3_21120_13440.bmp (deflated 99%)  adding: work/nongye/7_19/epoch=8/image3_15360_13440.bmp (deflated 99%)  adding: work/nongye/7_19/epoch=8/image4_18816_21504.bmp (deflated 99%)  adding: work/nongye/7_19/epoch=8/image3_8832_6144.bmp (deflated 99%)  adding: work/nongye/7_19/epoch=8/image4_3072_19968.bmp (deflated 99%)  adding: work/nongye/7_19/epoch=8/image3_20352_11136.bmp (deflated 99%)  adding: work/nongye/7_19/epoch=8/image4_12672_24576.bmp (deflated 99%)  adding: work/nongye/7_19/epoch=8/image4_4608_10752.bmp (deflated 99%)  adding: work/nongye/7_19/epoch=8/image3_384_14208.bmp (deflated 99%)  adding: work/nongye/7_19/epoch=8/image4_4992_15360.bmp (deflated 99%)  adding: work/nongye/7_19/epoch=8/image3_14208_11904.bmp (deflated 99%)  adding: work/nongye/7_19/epoch=8/image3_23424_18816.bmp (deflated 99%)  adding: work/nongye/7_19/epoch=8/image3_14976_9600.bmp (deflated 99%)  adding: work/nongye/7_19/epoch=8/image3_9216_12288.bmp (d

In [26]:
"""
import shutil

def splitValidate():
    dat1 = os.listdir(r"work/2048train/data1")
    dat2 = os.listdir(r"work/2048train/data2")
    label1 = os.listdir(r"work/2048train/label1") 
    label2 = os.listdir(r"work/2048train/label2") 
    random.shuffle(dat1)
    random.shuffle(dat2)
    
    print("dat1.len = %s, dat2.len = %s"%(len(dat1), len(dat2)))
    
    val1 = dat1[:15]
    val2 = dat2[:15]
    
    for imgname in val1:
        img = cv2.imread(os.path.join(r"work/2048train/data1", imgname))
        label = cv2.imread(os.path.join(r"work/2048train/label1", imgname))
        for x in range(0,2048,512):
            for y in range(0,2048,512):
                new_name = imgname[:4] + "_%s_%s"%(x,y)+".bmp"
                cv2.imwrite(os.path.join(r"work/2048val/data1",new_name), img[x:x+512,y:y+512])
                cv2.imwrite(os.path.join(r"work/2048val/label1",new_name), label[x:x+512,y:y+512])
        srcpath = os.path.join(r"work/2048train/data1", imgname)
        newpath = os.path.join(r"work/2048val/temp/data", imgname)
        shutil.move(srcpath, newpath)
        
        srcpath = os.path.join(r"work/2048train/label1", imgname)
        newpath = os.path.join(r"work/2048val/temp/label", imgname)
        shutil.move(srcpath, newpath)
    
    for imgname in val2:
        img = cv2.imread(os.path.join(r"work/2048train/data2", imgname))
        label = cv2.imread(os.path.join(r"work/2048train/label2", imgname))
        for x in range(0,2048,512):
            for y in range(0,2048,512):
                new_name = imgname[:4] + "_%s_%s"%(x,y)+".bmp"
                cv2.imwrite(os.path.join(r"work/2048val/data2",new_name), img[x:x+512,y:y+512])
                cv2.imwrite(os.path.join(r"work/2048val/label2",new_name), label[x:x+512,y:y+512])
        srcpath = os.path.join(r"work/2048train/data2", imgname)
        newpath = os.path.join(r"work/2048val/temp/data", imgname)
        shutil.move(srcpath, newpath)
        
        srcpath = os.path.join(r"work/2048train/label2", imgname)
        newpath = os.path.join(r"work/2048val/temp/label", imgname)
        shutil.move(srcpath, newpath)
        
    dat1 = os.listdir(r"work/2048train/data1")
    dat2 = os.listdir(r"work/2048train/data2")
    print("dat1.len = %s, dat2.len = %s"%(len(dat1), len(dat2)))
"""    

dat1.len = 320, dat2.len = 402
