# 基于ResNet-50主干的FCN-8s网络的图像分割

**TGS Salt Identification Challenge**

Segment salt deposits beneath the Earth's surface

![image](https://storage.googleapis.com/kaggle-competitions/kaggle/10151/logos/header.png)

数据集网址：https://www.kaggle.com/c/tgs-salt-identification-challenge

In [None]:
!cat /proc/cpuinfo | grep name | cut -f2 -d: | uniq -c
!nvidia-smi

In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the read-only Kaggle input directory and print the full path of every file in it.
for root, _subdirs, files in os.walk('/kaggle/input'):
    for entry in files:
        full_path = os.path.join(root, entry)
        print(full_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
!pip install torchsummary

In [None]:
# Set your own project id here
# NOTE(review): 'fcn8_resnet50' looks like a placeholder rather than a real
# Google Cloud project id -- confirm before running this cell.
PROJECT_ID = 'fcn8_resnet50'
from google.cloud import storage
# Cloud Storage client bound to PROJECT_ID; upload_files() uses it to reach the bucket.
storage_client = storage.Client(project=PROJECT_ID)
def upload_files(bucket_name, source_folder):
    """Upload every regular file in ``source_folder`` to a Cloud Storage bucket.

    Each file is stored in the bucket under its bare file name (no folder
    prefix). Sub-directories of ``source_folder`` are skipped.

    Args:
        bucket_name: Name of the destination bucket.
        source_folder: Local directory whose files are uploaded. A trailing
            path separator is optional.
    """
    bucket = storage_client.get_bucket(bucket_name)
    for filename in os.listdir(source_folder):
        # Join with os.path so the folder works with or without a trailing slash.
        local_path = os.path.join(source_folder, filename)
        if not os.path.isfile(local_path):
            # listdir() also returns directories, which cannot be uploaded as a single blob.
            continue
        bucket.blob(filename).upload_from_filename(local_path)

In [None]:
!mkdir -p '../content/kgs_data'
!cp '/kaggle/input/tgs-salt-identification-challenge/train.zip' -d '../content/kgs_data_train.zip'
!unzip '../content/kgs_data_train.zip' -d '../content/kgs_data'

## 数据预处理
#### 调用必要库

In [None]:
import numpy as np
import os
import time 
import math
import glob
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as T
from torchsummary import summary

def timeSince(since):
    """Return the time elapsed since ``since`` formatted as '<minutes>m <seconds>s'.

    Args:
        since: Start timestamp in seconds, as returned by ``time.time()``.

    Returns:
        A string such as ``'2m 5s'``; fractional seconds are truncated.
    """
    elapsed = time.time() - since
    minutes = math.floor(elapsed / 60)
    # %d drops the fractional part of the seconds left over after whole minutes.
    return '%dm %ds' % (minutes, elapsed - minutes * 60)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("current device is : ", device)

In [None]:
!ls

#### 确认数据集位置（查看train）
*@note*：一定要解压在非云盘位置，避免后期训练集因为图片读取速率而导致的速度问题

In [None]:
# Folder containing the training images (PNG files).
image_path = "../content/kgs_data/images"
# Folder containing the masks; each mask is looked up under the same file name as its image.
mask_path = "../content/kgs_data/masks"

#### 样例输出

In [None]:
# Pick a few arbitrary samples from the training set and show, per row:
# the raw image, its mask, and the downscaled image with the mask overlaid.
names = ['000e218f21','41cfd4b320','3c2f5ba174']
images = [Image.open(os.path.join(image_path, sample + '.png')) for sample in names]
masks = [Image.open(os.path.join(mask_path, sample + '.png')) for sample in names]

'''Transform 用法
transform = transforms.Compose([
    transforms.Grayscale(),  # 将图像转化为灰度图
    transforms.RandomCrop(32, padding=4),  #先四周填充0，在吧图像随机裁剪成32*32
    transforms.RandomHorizontalFlip(),  #图像一半的概率翻转，一半的概率不翻转
    transforms.RandomRotation((-45,45)), #随机旋转
    transforms.ToTensor(),   # 图像转tensor
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.229, 0.224, 0.225)), #R,G,B每层的归一化用到的均值和方差
])
'''
transforms1 = T.Resize((56,56),interpolation=Image.NEAREST)
transforms2 = T.Compose([T.Grayscale(), T.ToTensor()])  # grayscale conversion followed by tensor conversion
x = torch.stack([transforms2(transforms1(picture)) for picture in images])
y = torch.stack([transforms2(transforms1(overlay)) for overlay in masks])
print(x.size())
fig = plt.figure( figsize=(9, 9))

# One row of the 3x3 grid per sample.
for row in range(3):
    first_cell = 3 * row + 1

    # plt.imshow draws on the subplot that was just added (the current axes).
    ax = fig.add_subplot(3, 3, first_cell)
    plt.imshow(images[row])
    ax = fig.add_subplot(3, 3, first_cell + 1)
    plt.imshow(masks[row])

    # Resized image with the resized mask drawn semi-transparently on top.
    ax = fig.add_subplot(3, 3, first_cell + 2)
    ax.imshow(x[row].squeeze(), cmap="Greys")
    ax.imshow(y[row].squeeze(), alpha=0.5, cmap="Greens_r")

plt.show()


#### 训练集获取

In [None]:
class segmentDataset(Dataset):
    """Dataset of (image, mask) pairs stored as same-named PNG files in two folders.

    Every ``*.png`` file found in ``image_path`` defines one sample; its mask is
    read from ``mask_path`` under the same file name. Image and mask both go
    through the same pipeline: grayscale conversion, tensor conversion and a
    nearest-neighbour resize to 112x112.

    Args:
        image_path: Directory containing the input images.
        mask_path: Directory containing the masks.
    """

    def __init__(self, image_path, mask_path):
        self.image_path = image_path
        self.mask_path = mask_path
        self.sample_names = self._list_sample_names(image_path)
        # Grayscale -> tensor -> resize to (112, 112).
        self.transforms = T.Compose([T.Grayscale(), T.ToTensor(), T.Resize((112,112),interpolation=Image.NEAREST)])

    @staticmethod
    def _list_sample_names(image_path):
        """Return the sorted base names (without extension) of all PNG files in ``image_path``."""
        pattern = os.path.join(image_path, '*.png')
        # basename/splitext work with any path separator and keep dots inside the
        # file name; sorting makes the index -> sample mapping reproducible,
        # because glob returns files in arbitrary order.
        return sorted(os.path.splitext(os.path.basename(path))[0] for path in glob.glob(pattern))

    def __getitem__(self, idx):
        """Return the transformed ``(image, mask)`` tensors of sample ``idx``."""
        file_name = self.sample_names[idx] + '.png'
        # The context managers close both files once the tensors have been built.
        with Image.open(os.path.join(self.image_path, file_name)) as image, \
                Image.open(os.path.join(self.mask_path, file_name)) as mask:
            return self.transforms(image), self.transforms(mask)

    def __len__(self):
        """Return the number of samples."""
        return len(self.sample_names)


In [None]:
# Training dataset built from the image and mask folders configured in image_path / mask_path.
train_dataset = segmentDataset(image_path = image_path, mask_path = mask_path)

## FCN(fully Convolutional Networks)全卷积神经网络
提出论文：[Fully Convolutional Networks for Semantic Segmentation](https://arxiv.org/pdf/1605.06211.pdf)

### 与VGG-Net的不同之处

对于一般的分类CNN网络，如VGG和Resnet，都会在网络的最后加入一些全连接层，经过softmax后就可以获得类别概率信息。
```python
# 三层线性全连接层
self.classifier = nn.Sequential(
    # first layer
    nn.Linear(512 * 7 * 7, 4096, bias = True),
    nn.ReLU(True),
    nn.Dropout(),
    # second layer
    nn.Linear(4096, 4096, bias = True),
    nn.ReLU(True),
    nn.Dropout(),
    # third layer
    nn.Linear(4096, num_classes, bias = True),
)
```
但是这个概率信息是1维的，即只能标识整个图片的类别，不能标识每个像素点的类别，所以这种全连接方法不适用于图像分割。

如果强行训练，计算损失时会因为网络输出与标签的形状不一致而报错，如下：
```bash
ValueError: Target size (torch.Size([64, 1, 224, 224])) must be the same as input size (torch.Size([64, 1000]))
```

为此，FCN提出可以把后面几个全连接都换成卷积，这样就可以获得一张2维的feature map，后接softmax获得每个像素点的分类信息，从而解决了分割问题，如下图所示：

![image](https://pic1.zhimg.com/80/v2-7bfe6e1792c2fb8bcfab6eea632d5e2c_720w.jpg)

![image](https://pic1.zhimg.com/80/v2-721ef7417b32a5aa4973f1e8dd16d90c_720w.jpg)

1. 对于FCN-32s，直接对pool5 feature进行32倍上采样获得32x upsampled feature，再对32x upsampled feature每个点做softmax prediction获得32x upsampled feature prediction（即分割图）。
2. 对于FCN-16s，首先对pool5 feature进行2倍上采样获得2x upsampled feature，再把pool4 feature和2x upsampled feature逐点相加，然后对相加的feature进行16倍上采样，并softmax prediction，获得16x upsampled feature prediction。
3. 对于FCN-8s，首先进行pool4+2x upsampled feature逐点相加，然后又进行pool3+2x upsampled逐点相加，即进行更多次特征融合。具体过程与16s类似，不再赘述。

### 总结
FCN网络以VGG为主干，将分类器（全连接层）替换为卷积和反卷积层，整体上类似于Encoder-Decoder（编码-解码）结构。

### 参考
1. https://zhuanlan.zhihu.com/p/31428783
2. https://zhuanlan.zhihu.com/p/32506912


In [None]:
import torchvision.models as models
       

 
class FCNx8_ResNet(nn.Module):
    """FCN-8s style segmentation network built on a torchvision ResNet-50.

    The backbone is cut into a stem (``head``) and three stages. The output of
    each stage is turned into per-class score maps by a 1x1 convolution; the
    score maps are fused coarse-to-fine by addition and upsampled with
    transposed convolutions initialised to bilinear interpolation.

    NOTE(review): ``upsamplex32`` uses a fixed 7x7 kernel with stride 3 and no
    padding, so it only lines up with the ``stage2`` scores when ``stage3``
    ends in a 1x1 map and ``stage2`` emits a 7x7 map. That presumably holds for
    the 112x112 inputs used in this file; other input sizes would likely fail
    at the ``s2 + s3`` addition -- TODO confirm.

    Args:
        num_classes: Number of output channels (one score map per class).
    """

    # When true, the first forward pass prints the intermediate feature-map sizes.
    debug_info = True

    def __init__(self, num_classes=1):
        super(FCNx8_ResNet, self).__init__()
        pretrained_net = models.resnet50(pretrained=True)
        # Inputs are single-channel, so the stem convolution is replaced by a
        # 1-channel one (a freshly initialised layer, not the pretrained one).
        pretrained_net.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Every child except the last one. Per torchvision's ResNet this should be
        # conv1, bn1, relu, maxpool, layer1..layer4, avgpool (fc dropped).
        backbone = list(pretrained_net.children())[:-1]

        self.head = nn.Sequential(*backbone[:4])
        # The residual blocks of each backbone layer are flattened into the stage
        # (not nested), so parameter names are stage1.0, stage1.1, ...
        self.stage1 = nn.Sequential(*backbone[4], *backbone[5])
        self.stage2 = nn.Sequential(*backbone[6])
        self.stage3 = nn.Sequential(
            *backbone[7],
            backbone[8],
            nn.Conv2d(in_channels=2048, out_channels=1024, kernel_size=1, stride=1, padding=0),
            nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=1, stride=1, padding=0),
        )

        # 1x1 convolutions mapping each stage's features to class scores.
        self.scores3 = nn.Conv2d(in_channels=512, out_channels=num_classes, kernel_size=1)
        self.scores2 = nn.Conv2d(in_channels=1024, out_channels=num_classes, kernel_size=1)
        self.scores1 = nn.Conv2d(in_channels=512, out_channels=num_classes, kernel_size=1)

        # Transposed-convolution output size: N = (w - 1) * s + k - 2p
        self.upsamplex8 = self._bilinear_upsample(num_classes, kernel_size=16, stride=8, padding=4)
        self.upsamplex16 = self._bilinear_upsample(num_classes, kernel_size=4, stride=2, padding=1)
        # The declared kernel size matches the 7x7 bilinear weight that is loaded into the layer.
        self.upsamplex32 = self._bilinear_upsample(num_classes, kernel_size=7, stride=3, padding=0)

    def _bilinear_upsample(self, num_classes, kernel_size, stride, padding):
        """Build a bias-free transposed convolution initialised with a bilinear kernel."""
        layer = nn.ConvTranspose2d(in_channels=num_classes, out_channels=num_classes,
                                   kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
        # Copy in place so the registered parameter object is kept.
        with torch.no_grad():
            layer.weight.copy_(self.bilinear_kernel(num_classes, num_classes, kernel_size))
        return layer

    def forward(self, x):
        """Return the fused, upsampled class score maps for a batch ``x``.

        No activation is applied to the result, so the output holds raw scores.
        """
        x = self.head(x)
        s1 = self.stage1(x)
        s2 = self.stage2(s1)
        s3 = self.stage3(s2)
        if self.debug_info:
            print(s1.size(),s2.size(),s3.size())

        # Coarsest scores, upsampled to be added to the stage2 scores.
        s3 = self.upsamplex32(self.scores3(s3))

        s2 = self.scores2(s2)
        if self.debug_info:
            print(s1.size(),s2.size(),s3.size())
        s2 = self.upsamplex16(s2 + s3)

        s1 = self.scores1(s1)
        if self.debug_info:
            print(s1.size(),s2.size(),s3.size())
        s = self.upsamplex8(s1 + s2)

        # Only the first forward pass prints the debug sizes.
        self.debug_info = False
        return s

    def bilinear_kernel(self,in_channels, out_channels, kernel_size):
        """Return a bilinear-interpolation filter as a float32 tensor.

        The result has shape ``(in_channels, out_channels, kernel_size, kernel_size)``.
        Entry ``[i, i]`` holds the 2-D bilinear filter and all other channel
        pairs are zero. The channel indices are paired element-wise, so this is
        intended for ``in_channels == out_channels``.
        """
        factor = (kernel_size + 1) // 2
        # Centre of the filter: a pixel for odd sizes, between two pixels for even sizes.
        if kernel_size % 2 == 1:
            center = factor - 1
        else:
            center = factor - 0.5
        og = np.ogrid[:kernel_size, :kernel_size]
        # Outer product of two 1-D triangular (linear interpolation) profiles.
        filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
        weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), dtype='float32')
        weight[range(in_channels), range(out_channels), :, :] = filt
        return torch.from_numpy(weight)
# Example: build the single-class model, move it to the active device and print a layer summary.
fcn_model = FCNx8_ResNet(num_classes=1)
fcn_model = fcn_model.to(device)
summary(fcn_model, (1, 112, 112))

## 模型训练

In [None]:
def get_iou_score(outputs, labels):
    """Compute the per-sample IoU between thresholded scores and binary masks.

    Each sample is flattened before counting, so the result has one entry per
    sample even when the batch holds a single element.

    Args:
        outputs: Raw scores with the batch on dim 0. A pixel counts as predicted
            positive unless its score is below 0.
        labels: Ground-truth masks with the batch on dim 0 and the same number
            of elements per sample as ``outputs``. Non-zero values are positive.

    Returns:
        A 1-D float tensor with one IoU value per sample. Both counts are
        smoothed by 1e-6, so a sample with empty prediction and empty mask
        scores 1.0.
    """
    truth = labels.reshape(labels.size(0), -1).bool()
    # "not below zero" keeps the original thresholding rule exactly.
    pred = ~(outputs.reshape(outputs.size(0), -1) < 0.)
    intersection = (truth & pred).float().sum(1)
    union = (truth | pred).float().sum(1)
    return (intersection + 1e-6) / (union + 1e-6)
  
def train_one_batch(model, x, y):
    """Run one optimisation step on a single batch.

    Uses the module-level ``device``, ``loss_fn`` and ``optimizer``.

    Args:
        model: Network to train.
        x: Batch of inputs.
        y: Batch of targets.

    Returns:
        A ``(loss, iou)`` tuple of Python floats: the batch loss and the mean
        IoU over the batch.
    """
    inputs = x.to(device)
    targets = y.to(device)

    # Clear gradients left over from the previous step before the new backward pass.
    optimizer.zero_grad()
    predictions = model(inputs)
    batch_loss = loss_fn(predictions, targets)
    mean_iou = get_iou_score(predictions, targets).mean()

    batch_loss.backward()
    optimizer.step()
    return batch_loss.item(), mean_iou.item()

### 保存模型和参数

In [None]:
def save_model_args(epoch):
    """Save a checkpoint plus the training curves recorded so far.

    Reads the module-level ``fcn_model``, ``optimizer``, ``BATCH_SIZE``,
    ``train_losses`` and ``train_ious``. All files are written to the current
    working directory, and every file name carries ``epoch`` and ``BATCH_SIZE``.

    Args:
        epoch: Number of the epoch that has just finished.
    """
    # Checkpoint: network weights, optimizer state and the epoch number.
    model_path = 'fcn{}_resnet{}_model_{}_batch_{}.pth'.format(8, 50, epoch, BATCH_SIZE)
    state = {'net':fcn_model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch}
    torch.save(state, model_path)

    # Training loss and IoU history.
    train_losses_save = np.array(train_losses)
    train_ious_save = np.array(train_ious)
    plt.plot(train_losses_save, label = 'loss')
    plt.plot(train_ious_save, label = 'IoU')
    plt.xlabel('Epoch')
    plt.ylabel('Metric')
    plt.legend()
    # Name the plot after this checkpoint, like the .pth and .npy files.
    plt.savefig('fcn{}_resnet{}_loss_{}_batch_{}.png'.format(8, 50, epoch, BATCH_SIZE), bbox_inches='tight')
    plt.show()
    np.save('fcn{}_resnet{}_loss_{}_batch_{}'.format(8, 50, epoch, BATCH_SIZE), train_losses_save)
    np.save('fcn{}_resnet{}_iou_{}_batch_{}'.format(8, 50, epoch, BATCH_SIZE), train_ious_save)



In [None]:
!mkdir -p '../content/drive/MyDrive/实验数据&模型/fcn_resnet'

In [None]:
NUM_EPOCHS = 200
BATCH_SIZE = 64

fcn_model.train()  # the model must be in training mode before optimisation starts
optimizer = torch.optim.Adam(fcn_model.parameters())
# 'max' mode: the scheduler is stepped with the mean IoU, which is expected to increase.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max')

loss_fn = nn.BCEWithLogitsLoss()

train_dataloader = DataLoader(train_dataset, batch_size = BATCH_SIZE, shuffle = True, drop_last = True)
# With drop_last=True this is the number of full batches per epoch.
steps = len(train_dataloader)
if steps == 0:
    raise ValueError('the training set holds fewer samples than one batch of {}'.format(BATCH_SIZE))
print(steps,"steps per epoch")

start = time.time()
train_losses = []
train_ious = []
for epoch in range(1, NUM_EPOCHS + 1):
    print('-' * 10)
    print('Epoch {}/{}'.format(epoch, NUM_EPOCHS))
    running_iou = []
    running_loss = []
    for step, (x, y) in enumerate(train_dataloader):
        loss, iou = train_one_batch(fcn_model, x, y)
        running_iou.append(iou)
        running_loss.append(loss)
        # '\r' rewrites the same console line to act as a progress indicator.
        print('\r{:6.1f} %\tloss {:8.4f}\tIoU {:8.4f}'.format(100*(step+1)/steps, loss,iou), end = "")

    epoch_loss = float(np.mean(running_loss))
    epoch_iou = float(np.mean(running_iou))
    print('\r{:6.1f} %\tloss {:8.4f}\tIoU {:8.4f}\t{}'.format(100*(step+1)/steps, epoch_loss, epoch_iou, timeSince(start)))
    scheduler.step(epoch_iou)

    # Record the epoch averages, the same values that are printed and fed to
    # the scheduler, rather than the values of the last mini-batch.
    train_losses.append(epoch_loss)
    train_ious.append(epoch_iou)
    if epoch % 50 == 0:
        save_model_args(epoch)




epoch 200/200 253m46s

### 恢复模型

In [None]:
# Reference snippet for restoring a saved run. It is wrapped in a string literal,
# so none of it executes.
# NOTE(review): inside the snippet, `net = fcn_model()` would call the existing model
# instance instead of building a new network -- presumably `FCNx8_ResNet()` was
# intended; confirm before enabling.
# NOTE(review): the snippet reads from a Drive folder, while save_model_args() writes to
# the current working directory -- confirm the files were copied there.
'''
import torchvision.models as models
# 加载训练参数（损失+iou）
train_losses_save = np.load("../content/drive/MyDrive/实验数据&模型/fcn_resnet/fcn{}_resnet{}_loss_{}_batch_{}.npy".format(8,50,200,64))
train_ious_save = np.load("../content/drive/MyDrive/实验数据&模型/fcn_resnet/fcn{}_resnet{}_iou_{}_batch_{}.npy".format(8,50,200,64))


# 加载模型参数
params = torch.load('../content/drive/MyDrive/实验数据&模型/fcn_resnet/fcn{}_resnet{}_model_{}_batch_{}.pth'.format(8,50,50,64))
print(params)

# 加载模型
net = fcn_model()
pthfile = r'../content/drive/MyDrive/实验数据&模型/fcn_resnet/fcn{}_resnet{}_model_{}_batch_{}.pth'.format(8,50,50,64)
net.load_state_dict(torch.load(pthfile)['net'])
print(net)
'''