In [1]:
import torch


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mask_folder = "F:\\PythonProject\\deep\\image_segmentation\\deep_learning\\data\\object_detection_segment\\segmentation"
data_folder = "F:\\PythonProject\\deep\\image_segmentation\\deep_learning\\data\\object_detection_segment\\object_detection/"
sr_data_folder = "../data/super_resolution/"
checkpoint = "../chapter_three/net.pth"
sr_checkpoint = "../datachapter_three/sr.pth"

batch_size = 8
lr = 1e-3
epoch_lr = [(20,0.01),(10,0.001),(10,0.0001)]

  from .autonotebook import tqdm as notebook_tqdm


In [19]:
# 图像增强
import torch
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor, ToPILImage
from sklearn.model_selection import train_test_split
from glob import glob
import os.path as osp
from PIL import Image
import numpy as np
import cv2

class Compose:
    def __init__(self,transform_list):
        self.transform_list = transform_list
    def __call__(self, img, mask):
        for transform in self.transform_list:
            img, mask = transform(img,mask)
        return img, mask

class ToArraySegment:
    def __call__(self,img,mask):
        img = np.array(img)
        mask = np.array(mask)
        return img,mask

class ToTensorSegment:
    def __call__(self,img,mask):
        return torch.from_numpy(img).permute(2,0,1).float()/255.,torch.from_numpy(mask).float()/255.

class Resize:
    def __init__(self,size=320):
        self.size = size
    def __call__(self,img,mask):
        img = cv2.resize(img,(self.size,self.size))
        mask = cv2.resize(mask,(self.size,self.size))
        return img,mask

class Expand:
    def __call__(self,img,mask):
        if np.random.randint(2):
            width,_,channels = img.shape
            ratio = np.random.uniform()
            expand_img = np.zeros((int(width*(1+ratio)),int(width*(1+ratio)),channels))
            expand_mask = np.zeros((int(width*(1+ratio)),int(width*(1+ratio))))
            left = np.random.uniform(0,width*ratio)
            top = np.random.uniform(0,width*ratio)
            left = int(left)
            top = int(top)
            expand_img[top:top+width,left:left+width,:] = img
            expand_mask[top:top+width,left:left+width]=mask

            return expand_img,expand_mask
        else:
            return img,mask

class MIrror:
    def __call__(self,img,mask):
        #在绝对坐标系啊运行
        if np.random.randint(2):
            width = img.shape[0]
            img = img[:,::-1]
            mask = mask[:,::-1]
            return img,mask

class TrainTrainsform:
    def __init__(self,size=320):
        self.size = size
        self.augment = Compose([
            ToArraySegment(),
            MIrror(),
            Expand(),
            Resize(self.size),
            ToTensorSegment()
        ])
    def __call__(self,img,mask):
        img,mask = self.augment(img,mask)
        return img,mask

class TestTrainsform:
    def __init__(self,size=320):
        self.size = size
        self.augment=Compose([
            ToArraySegment(),
            Resize(self.size),
            ToTensorSegment()
        ])

    def __call__(self,img,mask):
        img,mask = self.augment(img,mask)
        return img,mask
        

In [20]:
#数据加载
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
from PIL import Image
from glob import glob
import os 
from sklearn.model_selection import train_test_split
#from transform import TrainTransform,TestTrainsform
#from config import data_folder,mask_folder

class SegmentationData(Dataset):
    def __init__(self,data_folder=data_folder,mask_folder=mask_folder,subset="train",trainsform=None):
        super(SegmentationData,self).__init__()
        img_paths = sorted(glob(os.path.join(data_folder,"*.jpg")))
        mask_paths = sorted(glob(os.path.join(mask_folder,"*.jpg")))
        for i in range(len(img_paths)):
            assert os.path.basename(img_paths[i])==os.path.basename(mask_paths[i])
        img_paths_train,img_paths_test,mask_paths_train,mask_paths_test = train_test_split(img_paths,mask_paths,test_size=0.2,random_state=20)
        if subset=="train":
            self.img_paths=img_paths_train
            self.mask_paths = mask_paths_train
        else:
            self.img_paths = img_paths_test
            self.mask_paths = mask_paths_test

        self.transform = trainsform

    def __getitem__(self, index):
        image = Image.open(self.img_paths[index]).resize((224,224))
        mask_path = self.mask_paths[index]
        mask = Image.open(mask_path).resize((224,224)).convert("L")
        if self.transform:
            image,mask = self.transform(image,mask)
        else:
            image,mask = ToTensor()(image),ToTensor()(mask)
        return image,mask

    def __len__(self):
        return len(self.img_paths)

if __name__=="__main__":
    topil = ToPILImage()
    transform = TrainTrainsform()
    data = SegmentationData(data_folder,mask_folder,transform)
    image,mask = data[0]
    image,mask = topil(image),topil(mask)
    image.save("./sample.jpg")
    mask.save("./sample_mask.jpg")
    image.show()

In [23]:
#模型搭建
import torch
from torch import nn
from torchvision.models import resnet18

class DecoderBlock(nn.Module):
    def __init__(self,in_channel,out_channel,kernel_size) -> None:
        super(DecoderBlock,self).__init__()
        #卷积
        self.conv1 = nn.Conv2d(in_channel,in_channel//4,kernel_size,padding=1,bias=False)
        self.bn1 = nn.BatchNorm2d(in_channel//4)
        self.relu1 = nn.ReLU(inplace=True)

        #反卷积
        self.deconv = nn.ConvTranspose2d(
            in_channel//4,
            in_channel//4,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            bias=False
        )
        self.bn2 = nn.BatchNorm2d(in_channel//4)
        self.relu2 = nn.ReLU(inplace=True)

        #卷积
        self.conv3 = nn.Conv2d(
            in_channel//4,
            out_channel,
            kernel_size=kernel_size,
            padding=1,
            bias=False
        )
        self.bn3 = nn.BatchNorm2d(out_channel)
        self.relu3 = nn.ReLU(inplace=True)

    def forward(self,x):
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.deconv(x)))
        x = self.relu3(self.bn3(self.conv3(x)))

        return x


class ResNet18Unet(nn.Module):
    def __init__(self,num_classes=2,pretrained=True) -> None:
        super(ResNet18Unet,self).__init__()
        base = resnet18(pretrained=pretrained)
        self.firstconv = base.conv1
        self.firstbn = base.bn1
        self.firstrelu = base.relu
        self.firstmaxpool = base.maxpool
        self.encoder1 = base.layer1
        self.encoder2 = base.layer2
        self.encoder3 = base.layer3
        self.encoder4 = base.layer4

        out_channels = [64,128,256,512]
        self.center = DecoderBlock(
            in_channel=out_channels[3],
            out_channel=out_channels[3],
            kernel_size=3
        )
        self.decoder4 = DecoderBlock(
            in_channel=out_channels[3]+out_channels[2],
            out_channel=out_channels[2],
            kernel_size=3
        )
        self.decoder3 = DecoderBlock(
            in_channel=out_channels[2]+out_channels[1],
            out_channel=out_channels[1],
            kernel_size=3
        )
        self.decoder2 = DecoderBlock(
            in_channel=out_channels[1]+out_channels[0],
            out_channel=out_channels[0],
            kernel_size=3
        )
        self.decoder1 = DecoderBlock(
            in_channel=out_channels[0]+out_channels[0],
            out_channel=out_channels[0],
            kernel_size=3
        )
        self.finalconv = nn.Sequential(
            nn.Conv2d(out_channels[0],32,3,padding=1,bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Dropout2d(0.1,False),
            nn.Conv2d(32,num_classes,1)
        )

    def forward(self,x):
        x = self.firstconv(x)
        x = self.firstbn(x)
        x = self.firstrelu(x)
        x_ = self.firstmaxpool(x)

        #Encoder
        e1 = self.encoder1(x_)

        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        e4 = self.encoder4(e3)

        center = self.center(e4)

        d4 = self.decoder4(torch.cat([center,e3],1))
        d3 = self.decoder3(torch.cat([d4,e2],1))
        d2 = self.decoder2(torch.cat([d3,e1],1))
        d1 = self.decoder1(torch.cat([d2,x],1))

        f = self.finalconv(d1)

        return f

if __name__=="__main__":
    net = ResNet18Unet(pretrained=False)
    img = torch.rand(1,3,320,320)
    out = net(img)
    print(out.shape)



torch.Size([1, 2, 320, 320])


In [24]:
#模型训练
import torch
from torch import nn,optim
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from tqdm import tqdm
import os.path as osp

def Train():
    net = ResNet18Unet().to(device=device)
    trainTrainsform = TrainTrainsform()
    trainset = SegmentationData(data_folder,mask_folder,trainTrainsform)
    testset = SegmentationData(data_folder,mask_folder,subset="test",trainsform=TestTrainsform())
    trainloader = DataLoader(trainset,batch_size=batch_size,shuffle=True,num_workers=0)
    print(next(iter(trainloader)))
    testloader = DataLoader(testset,batch_size=batch_size,shuffle=True,num_workers=0)
    cirteron = nn.CrossEntropyLoss(weight=torch.Tensor([0.3,1.0]).to(device=device))
    best_loss = 1e9
    if osp.exists(checkpoint):
        ckpt = torch.load(checkpoint)
        best_loss = ckpt["loss"]
        net.load_state_dict(ckpt["params"])
        print("checkpoint loaded……")
    writer = SummaryWriter("logs")
    for n,(num_eopchs,lr) in enumerate(epoch_lr):
        optimizer = optim.SGD(net.parameters(),lr=lr,momentum=0.9,weight_decay=5e-3)
        for epoch in range(num_eopchs):
            net.train()
            #pbar = tqdm(enumerate(trainloader),total=len(trainloader))
            epoch_loss = 0.0
            for i, (img,mask) in enumerate(trainloader):
                out = net(img.to(device))
                loss = cirteron(out,mask.to(device).long().squeeze(1))
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                if i%10==0:
                    #pbar.set_description("loss:{}".format(loss))
                    pass
                epoch_loss+=loss.item()
                print("Epoch_loss:{}".format(epoch_loss/len(trainloader.dataset)))
                writer.add_scalar("seg_epoch_loss",epoch_loss/len(trainloader.dataset),sum([e[0] for e in epoch_lr[:n]])+epoch)
                with torch.no_grad():
                    net.eval()
                    test_loss = 0.0
                    for i,(img,mask) in tqdm(enumerate(testloader),total=len(testloader)):
                        out = net(img.to(device))
                        loss = cirteron(out,mask.to(device).long().squeeze(1))
                        test_loss+=loss.item()
                    print("Test_loss:{}".format(test_loss/len(testloader.dataset)))
                    writer.add_scalar(
                        "seg_test_loss",
                        test_loss/len(testloader.dataset),
                        sum([e[0] for e in epoch_lr[:n]])+epoch
                    )
                if test_loss<best_loss:
                    best_loss=test_loss
                    torch.save({"params":net.state_dict(),"loss":test_loss},checkpoint)
    writer.close()

if __name__=="__main__":
    Train()


[tensor([[[[0.3020, 0.3020, 0.3020,  ..., 0.3804, 0.3804, 0.3804],
          [0.3020, 0.3020, 0.3020,  ..., 0.3804, 0.3804, 0.3804],
          [0.3059, 0.3059, 0.3059,  ..., 0.3843, 0.3843, 0.3843],
          ...,
          [0.3137, 0.3137, 0.3137,  ..., 0.3882, 0.3882, 0.3882],
          [0.3137, 0.3137, 0.3137,  ..., 0.3882, 0.3882, 0.3882],
          [0.3137, 0.3137, 0.3137,  ..., 0.3882, 0.3882, 0.3882]],

         [[0.5647, 0.5647, 0.5647,  ..., 0.5608, 0.5608, 0.5608],
          [0.5647, 0.5647, 0.5647,  ..., 0.5608, 0.5608, 0.5608],
          [0.5686, 0.5686, 0.5686,  ..., 0.5647, 0.5647, 0.5647],
          ...,
          [0.5412, 0.5412, 0.5412,  ..., 0.5686, 0.5686, 0.5686],
          [0.5412, 0.5412, 0.5412,  ..., 0.5686, 0.5686, 0.5686],
          [0.5412, 0.5412, 0.5412,  ..., 0.5686, 0.5686, 0.5686]],

         [[0.9216, 0.9216, 0.9216,  ..., 0.8980, 0.8980, 0.8980],
          [0.9216, 0.9216, 0.9216,  ..., 0.8980, 0.8980, 0.8980],
          [0.9255, 0.9255, 0.9255,  ..., 

100%|██████████| 125/125 [00:16<00:00,  7.48it/s]

Test_loss:0.09158184945583343





FileNotFoundError: [Errno 2] No such file or directory: '../chapter_three/net.pth'