## GAMMA Challenge Subtask 3 - Official Baseline

Link for GAMMA:

	MICCAI2021 Contest - GAMMA: https://aistudio.baidu.com/aistudio/competition/detail/90

Challenge Description:

	The GAMMA Challenge is an international ophthalmology competition held by Baidu at the MICCAI2021 seminar OMIA8. MICCAI is a comprehensive academic conference in the fields of medical image computing and computer assisted intervention, and is the top conference in these fields. OMIA is an Ophthalmic Medical Image Analysis seminar organized by Baidu at the MICCAI conference, which has been held for eight sessions so far.

    The GAMMA Challenge focused on glaucoma analysis in multimodal images and consisted of three sub-tasks:  
1) glaucoma grading, 2) macular fovea localization, 3) optic disc and cup segmentation.  
    
Task Description of this baseline

	This baseline corresponds to Task 3 of the GAMMA Challenge, which is to segment the optic disc and optic cup in 2D color fundus images.
    
Dataset Description

    The dataset used for this baseline is 2D colour fundus images released in GAMMA. Users can obtain the corresponding datasets by signing up for the GAMMA challenge.

In [17]:
### remove the extraneous .DS_Store files from the data folders
### (one and two directory levels below the working directory);
### later cells treat every entry returned by os.listdir() as a sample id

!rm */.DS_Store
!rm */*/.DS_Store

In [2]:
### import the necessary packages

import sys 
sys.path.append('/home/aistudio/external-libraries')
import os
import cv2
import random
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics.pairwise import euclidean_distances 
import matplotlib.pylab as plt

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.io import Dataset


### Config

In [3]:
### paths and hyper-parameters shared by the cells below

images_file = ''  # root folder of the training data; samples are read from <id>/<id>.jpg
gt_file = 'Disc_Cup_Mask/'  # folder holding the ground-truth masks, read as <id>.png
test_file = ''  # root folder of the testing data (same layout as images_file)
image_size = 256 # images and masks are resized to (image_size, image_size) before use
val_ratio = 0.2  # fraction of the training samples held out for validation
BATCH_SIZE = 8 # batch size of the train and validation loaders
iters = 3000 # total number of training iterations
optimizer_type = 'adam' # optimizer name checked by the model-setup cell
num_workers = 4 # number of worker processes used to load data
init_lr = 1e-3 # the initial learning rate


### Train / Val split


In [20]:
### split the samples of the training set into a training and a validation subset

# os.listdir() returns entries in arbitrary, filesystem-dependent order.
# Sorting first makes the fixed random_state produce the same split on
# every machine and every run.
filelists = sorted(os.listdir(images_file))
train_filelists, val_filelists = train_test_split(
    filelists, test_size=val_ratio, random_state=42)
print("Total Nums: {}, train: {}, val: {}".format(len(filelists), len(train_filelists), len(val_filelists)))

### DataLoader

In [4]:
### load the fundus images from the data folder, 
### and extract the corresponding ground truth to generate training samples

class FundusDataset(Dataset):
    """Fundus images (and, for training, their disc/cup masks).

    A sample with id ``<id>`` is read from ``<image_file>/<id>/<id>.jpg``.
    In ``'train'`` mode its mask is read from ``<gt_path>/<id>.png``.
    Images and masks are resized to the module-level ``image_size``.

    Args:
        image_file: Root folder that holds one sub-folder per sample id.
        gt_path: Folder with the ground-truth masks; required in 'train' mode.
        filelists: Optional collection of sample ids to keep. ``None`` keeps
            every entry found in ``image_file``.
        mode: ``'train'`` makes ``__getitem__`` return ``(image, mask)``;
            ``'test'`` makes it return ``(image, sample_id, height, width)``.

    Raises:
        ValueError: If ``mode`` is unknown, or ``gt_path`` is missing in
            'train' mode.
    """

    def __init__(self, image_file, gt_path=None, filelists=None,  mode='train'):
        super(FundusDataset, self).__init__()
        # Fail early: an unknown mode would otherwise make __getitem__
        # silently return None.
        if mode not in ('train', 'test'):
            raise ValueError("mode must be 'train' or 'test'. Received: " + str(mode))
        if mode == 'train' and gt_path is None:
            raise ValueError("gt_path is required when mode is 'train'")

        self.mode = mode
        self.image_path = image_file
        self.gt_path = gt_path

        # Entries look like 0001, with the fundus image inside folder 0001.
        # Sorted so that sample order does not depend on the filesystem.
        self.file_list = sorted(os.listdir(self.image_path))

        if filelists is not None:
            wanted = set(filelists)  # O(1) membership instead of scanning a list
            self.file_list = [item for item in self.file_list if item in wanted]

    def __getitem__(self, idx):
        real_index = self.file_list[idx]
        fundus_img_path = os.path.join(self.image_path, real_index, real_index + '.jpg')
        fundus_img = cv2.imread(fundus_img_path)
        if fundus_img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError("cannot read fundus image: " + fundus_img_path)
        fundus_img = fundus_img[:, :, ::-1]  # BGR -> RGB
        h, w = fundus_img.shape[:2]

        fundus_re = cv2.resize(fundus_img, (image_size, image_size))
        img = fundus_re.transpose(2, 0, 1)  # H, W, C -> C, H, W

        if self.mode == 'test':
            ### During the testing process,
            ### the sample returns fundus image, sample name,
            ### height and width of the original image
            return img, real_index, h, w

        ### During the training process,
        ### the sample returns fundus image and its corresponding ground truth
        gt_tmp_path = os.path.join(self.gt_path, real_index + '.png')
        gt_img = cv2.imread(gt_tmp_path)
        if gt_img is None:
            raise FileNotFoundError("cannot read ground-truth mask: " + gt_tmp_path)

        ### In the ground truth, a pixel value of 0 is the optic cup (class 0),
        ### a pixel value of 128 is the optic disc (class 1),
        ### and a pixel value of 255 is the background (class 2).
        gt_img[gt_img == 128] = 1
        gt_img[gt_img == 255] = 2

        # Class ids must not be blended: nearest-neighbour keeps every output
        # pixel equal to one of the input labels, whereas the default
        # bilinear interpolation can invent intermediate ids at boundaries.
        gt_img = cv2.resize(gt_img, (image_size, image_size),
                            interpolation=cv2.INTER_NEAREST)
        gt_img = gt_img[:, :, 1]

        return img, gt_img

    def __len__(self):
        return len(self.file_list)


In [22]:
### generate a _train and a _val Dataset for presenting images in the training dataset

# Restrict each preview dataset to its own subset of the split. Without
# `filelists` both objects list the whole folder, so the "validation"
# preview would show exactly the same samples as the training preview.
_train = FundusDataset(image_file=images_file,
                       gt_path=gt_file,
                       filelists=train_filelists)

_val = FundusDataset(image_file=images_file,
                     gt_path=gt_file,
                     filelists=val_filelists)

In [23]:
### preview the first five samples of the _train Dataset:
### fundus images on the top row, label maps on the bottom row
### (label classes: 0-optic cup, 1-optic disc, 2-background)

plt.figure(figsize=(15, 5))

for i in range(5):
    fundus_img, label = _train[i]

    # top row: the dataset yields C, H, W while imshow expects H, W, C
    plt.subplot(2, 5, i + 1)
    plt.imshow(np.transpose(fundus_img, (1, 2, 0)))
    plt.axis("off")

    # bottom row: the matching ground truth
    plt.subplot(2, 5, i + 1 + 5)
    plt.imshow(label)
    plt.axis("off")

    

In [24]:
### preview the first five samples of the _val Dataset:
### fundus images on the top row, label maps on the bottom row

plt.figure(figsize=(15, 5))

for i in range(5):
    fundus_img, label = _val[i]

    # top row: the dataset yields C, H, W while imshow expects H, W, C
    plt.subplot(2, 5, i + 1)
    plt.imshow(np.transpose(fundus_img, (1, 2, 0)))
    plt.axis("off")

    # bottom row: the matching ground truth
    plt.subplot(2, 5, i + 1 + 5)
    plt.imshow(label)
    plt.axis("off")

   

### Network

This network is UNet.
U-Net is a U-shaped network that can be viewed as two large stages: the image is first downsampled by the Encoder to obtain a high-level semantic feature map, and then upsampled by the Decoder to restore the feature map to the resolution of the original image.
The details of the codes can be seen at https://www.paddlepaddle.org.cn/documentation/docs/zh/tutorial/cv_case/image_segmentation/image_segmentation.html

In [5]:
class SeparableConv2D(nn.Layer):
    """Depthwise-separable 2-D convolution.

    ``forward`` applies a depthwise convolution (``groups=in_channels``,
    no bias) followed by a 1x1 pointwise convolution with bias.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels of the pointwise convolution.
        kernel_size: Int, or sequence of two ints, for the depthwise kernel.
        stride: Stride of the depthwise convolution.
        padding: Padding of the depthwise convolution.
        dilation: Dilation of the depthwise convolution.
        groups: Accepted for signature compatibility; not used by this layer.
        weight_attr: Parameter attribute used for both convolution weights.
        bias_attr: Parameter attribute used for the pointwise bias.
        data_format: Data layout string forwarded to ``F.conv2d``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=None,
                 weight_attr=None,
                 bias_attr=None,
                 data_format="NCHW"):
        super(SeparableConv2D, self).__init__()

        self._padding = padding
        self._stride = stride
        self._dilation = dilation
        self._in_channels = in_channels
        self._data_format = data_format

        # Parameters of the first (depthwise) convolution; it has no bias.
        filter_shape = [in_channels, 1] + self.convert_to_list(kernel_size, 2, 'kernel_size')
        self.weight_conv = self.create_parameter(shape=filter_shape, attr=weight_attr)

        # Parameters of the second (1x1 pointwise) convolution.
        filter_shape = [out_channels, in_channels] + self.convert_to_list(1, 2, 'kernel_size')
        self.weight_pointwise = self.create_parameter(shape=filter_shape, attr=weight_attr)
        self.bias_pointwise = self.create_parameter(shape=[out_channels],
                                                    attr=bias_attr,
                                                    is_bias=True)

    # The default used to be `np.int`, an alias of the builtin `int` that was
    # removed in NumPy 1.24. Because defaults are evaluated when the method is
    # defined, the whole class failed to load on newer NumPy.
    def convert_to_list(self, value, n, name, dtype=int):
        """Return ``value`` as a list of ``n`` entries.

        Args:
            value: A single ``dtype`` instance (repeated ``n`` times) or an
                iterable of exactly ``n`` items convertible to ``dtype``.
            n: Required length of the result.
            name: Argument name used in error messages.
            dtype: Expected element type.

        Raises:
            ValueError: If ``value`` is neither a ``dtype`` instance nor an
                iterable, has the wrong length, or holds an element that
                cannot be converted to ``dtype``.
        """
        if isinstance(value, dtype):
            return [value, ] * n

        try:
            value_list = list(value)
        except TypeError:
            raise ValueError("The " + name +
                             "'s type must be list or tuple. Received: " + str(
                                 value))
        if len(value_list) != n:
            raise ValueError("The " + name + "'s length must be " + str(n) +
                             ". Received: " + str(value))
        for single_value in value_list:
            try:
                dtype(single_value)
            except (ValueError, TypeError):
                raise ValueError(
                    "The " + name + "'s type must be a list or tuple of " + str(
                        n) + " " + str(dtype) + " . Received: " + str(
                            value) + " "
                    "including element " + str(single_value) + " of type" + " "
                    + str(type(single_value)))
        return value_list

    def forward(self, inputs):
        # Depthwise step: one group per input channel.
        conv_out = F.conv2d(inputs,
                            self.weight_conv,
                            padding=self._padding,
                            stride=self._stride,
                            dilation=self._dilation,
                            groups=self._in_channels,
                            data_format=self._data_format)

        # Pointwise step: 1x1 convolution that mixes the channels.
        out = F.conv2d(conv_out,
                       self.weight_pointwise,
                       bias=self.bias_pointwise,
                       padding=0,
                       stride=1,
                       dilation=1,
                       groups=1,
                       data_format=self._data_format)

        return out


In [6]:
class Encoder(nn.Layer):
    """Residual encoder block.

    Main path: two (ReLU -> separable 3x3 conv -> batch norm) stages and a
    stride-2 max pool. Shortcut path: a stride-2 1x1 convolution of the
    block input. The two paths are added element-wise.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.relus = nn.LayerList([nn.ReLU() for _ in range(2)])
        self.separable_conv_01 = SeparableConv2D(
            in_channels, out_channels, kernel_size=3, padding='same')
        self.bns = nn.LayerList(
            [nn.BatchNorm2D(out_channels) for _ in range(2)])
        self.separable_conv_02 = SeparableConv2D(
            out_channels, out_channels, kernel_size=3, padding='same')
        self.pool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
        self.residual_conv = nn.Conv2D(
            in_channels, out_channels, kernel_size=1, stride=2, padding='same')

    def forward(self, inputs):
        stages = zip(self.relus,
                     (self.separable_conv_01, self.separable_conv_02),
                     self.bns)

        features = inputs
        for activation, conv, norm in stages:
            features = norm(conv(activation(features)))
        features = self.pool(features)

        # the shortcut is computed from the untouched block input
        shortcut = self.residual_conv(inputs)
        return paddle.add(features, shortcut)


In [7]:
class Decoder(nn.Layer):
    """Residual decoder block.

    Main path: two (ReLU -> 3x3 transposed conv -> batch norm) stages and a
    x2 upsampling. Shortcut path: the block input upsampled x2 and passed
    through a 1x1 convolution. The two paths are added element-wise.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.relus = nn.LayerList([nn.ReLU() for _ in range(2)])
        self.conv_transpose_01 = nn.Conv2DTranspose(
            in_channels, out_channels, kernel_size=3, padding=1)
        self.conv_transpose_02 = nn.Conv2DTranspose(
            out_channels, out_channels, kernel_size=3, padding=1)
        self.bns = nn.LayerList(
            [nn.BatchNorm2D(out_channels) for _ in range(2)])
        self.upsamples = nn.LayerList(
            [nn.Upsample(scale_factor=2.0) for _ in range(2)])
        self.residual_conv = nn.Conv2D(
            in_channels, out_channels, kernel_size=1, padding='same')

    def forward(self, inputs):
        main_upsample, shortcut_upsample = self.upsamples[0], self.upsamples[1]
        stages = zip(self.relus,
                     (self.conv_transpose_01, self.conv_transpose_02),
                     self.bns)

        features = inputs
        for activation, conv, norm in stages:
            features = norm(conv(activation(features)))
        features = main_upsample(features)

        # the shortcut is computed from the untouched block input
        shortcut = self.residual_conv(shortcut_upsample(inputs))
        return paddle.add(features, shortcut)


In [8]:
class cup_disc_UNet(nn.Layer):
    """U-Net style segmentation network built from Encoder/Decoder blocks.

    A stride-2 stem convolution (3 -> 32 channels) with batch norm and ReLU
    is followed by the encoder blocks, the decoder blocks and a final 3x3
    convolution that outputs ``num_classes`` channels.

    Args:
        num_classes: Number of output channels of the last convolution.
    """

    def __init__(self, num_classes):
        super().__init__()

        # stem
        self.conv_1 = nn.Conv2D(3, 32, kernel_size=3, stride=2, padding='same')
        self.bn = nn.BatchNorm2D(32)
        self.relu = nn.ReLU()

        channels = 32
        self.encoders = []
        self.encoder_list = [64, 128, 256]
        self.decoder_list = [256, 128, 64, 32]

        # Create the downsampling sub-layers in a loop, driven by the
        # configured widths, instead of repeating the same code per block.
        for width in self.encoder_list:
            self.encoders.append(
                self.add_sublayer('encoder_{}'.format(width),
                                  Encoder(channels, width)))
            channels = width

        self.decoders = []

        # Same approach for the upsampling sub-layers.
        for width in self.decoder_list:
            self.decoders.append(
                self.add_sublayer('decoder_{}'.format(width),
                                  Decoder(channels, width)))
            channels = width

        self.output_conv = nn.Conv2D(
            channels, num_classes, kernel_size=3, padding='same')

    def forward(self, inputs):
        features = self.relu(self.bn(self.conv_1(inputs)))

        # encoder blocks first, then decoder blocks, in creation order
        for block in self.encoders + self.decoders:
            features = block(features)

        return self.output_conv(features)


### Utils

In [9]:
### we use DICE metric to validate the predicted results 
### The detailed introduction of DICE coefficient 
### can be found at https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient

class DiceLoss(nn.Layer):
    """
    Computes the mean soft Dice coefficient of softmax predictions.

    Despite the class name, ``forward`` returns the Dice coefficient itself
    (``2 * intersection / cardinality``, averaged over classes), not
    ``1 - dice``. The training code uses it as a metric.

    Args:
        ignore_index (int64): Label value whose pixels get their logits
            multiplied by zero before the softmax. Default ``2``.
    """

    def __init__(self, ignore_index=2):
        super(DiceLoss, self).__init__()
        self.ignore_index = ignore_index
        self.eps = 1e-5  # keeps the division finite when a class is absent

    def forward(self, logits, labels):
        """Return the class-averaged Dice coefficient.

        Args:
            logits: Raw network scores with the class axis at position 1.
            labels: Integer class ids; a missing class axis is inserted at
                position 1.
        """
        if len(labels.shape) != len(logits.shape):
            labels = paddle.unsqueeze(labels, 1)
        num_classes = logits.shape[1]

        # Cast the boolean mask explicitly so the product does not rely on
        # implicit bool/float promotion.
        # NOTE(review): the mask is applied before the softmax, so ignored
        # pixels still receive a uniform probability and still count towards
        # `cardinality`; labels equal to ignore_index also still appear in
        # the one-hot target. Confirm whether that is intended before using
        # this value as a loss.
        mask = paddle.cast(labels != self.ignore_index, logits.dtype)
        logits = logits * mask

        labels = paddle.cast(labels, dtype='int32')
        single_label_lists = [
            paddle.squeeze(paddle.cast(labels == c, dtype='int32'), axis=1)
            for c in range(num_classes)
        ]
        labels_one_hot = paddle.stack(tuple(single_label_lists), axis=1)
        labels_one_hot = paddle.cast(labels_one_hot, dtype='float32')

        logits = F.softmax(logits, axis=1)

        # reduce over the batch axis and all trailing axes, keeping the class axis
        dims = (0,) + tuple(range(2, labels.ndimension()))
        intersection = paddle.sum(logits * labels_one_hot, dims)
        cardinality = paddle.sum(logits + labels_one_hot, dims)
        dice_loss = (2. * intersection / (cardinality + self.eps)).mean()
        return dice_loss

In [30]:
### Training function

def _to_scalar(tensor):
    """Return the single value held by ``tensor`` as a Python number.

    ``ndarray.item()`` accepts both 0-D arrays and arrays of shape ``(1,)``,
    whereas indexing with ``[0]`` fails on 0-D results.
    """
    return tensor.numpy().item()


def train(model, iters, train_dataloader, val_dataloader, optimizer, criterion, metric, log_interval, evl_interval):
    """Train ``model`` for ``iters`` iterations with periodic logging and validation.

    Args:
        model: Network to optimise; called on a float32 image batch.
        iters: Total number of optimisation steps.
        train_dataloader: Iterable yielding ``(images, labels)`` batches.
        val_dataloader: Iterable passed to ``val`` at every evaluation.
        optimizer: Object whose ``step()`` applies the gradients.
        criterion: Callable ``(logits, labels) -> loss`` that is back-propagated.
        metric: Callable ``(logits, labels) -> score`` used for reporting.
        log_interval: Print averaged training statistics every this many steps.
        evl_interval: Run validation every this many steps; the weights are
            saved whenever the validation score is at least the best so far.

    Raises:
        ValueError: If ``train_dataloader`` yields no batch. The outer loop
            could otherwise never reach ``iters`` and would spin forever.
    """
    step = 0  # named `step` so the builtin `iter` is not shadowed
    model.train()
    avg_loss_list = []
    avg_dice_list = []
    best_dice = 0.
    while step < iters:
        saw_batch = False
        for data in train_dataloader:
            saw_batch = True
            step += 1
            if step > iters:
                break
            fundus_img = (data[0]/255.).astype("float32")  # scale pixel values to [0, 1]
            gt_label = (data[1]).astype("int64")
            logits = model(fundus_img)
            loss = criterion(logits, gt_label)
            dice = metric(logits, gt_label)

            loss.backward()
            optimizer.step()

            model.clear_gradients()
            avg_loss_list.append(_to_scalar(loss))
            avg_dice_list.append(_to_scalar(dice))

            if step % log_interval == 0:
                avg_loss = np.array(avg_loss_list).mean()
                avg_dice = np.array(avg_dice_list).mean()
                avg_loss_list = []
                avg_dice_list = []
                print("[TRAIN] iter={}/{} avg_loss={:.4f} avg_dice={:.4f}".format(step, iters, avg_loss, avg_dice))

            if step % evl_interval == 0:
                # Evaluate with the same criterion/metric that drive training.
                avg_loss, avg_dice = val(model, val_dataloader, criterion, metric)
                print("[EVAL] iter={}/{} avg_loss={:.4f} dice={:.4f}".format(step, iters, avg_loss, avg_dice))
                if avg_dice >= best_dice:
                    best_dice = avg_dice
                    paddle.save(model.state_dict(),
                                os.path.join("best_model_{:.4f}".format(best_dice), 'model.pdparams'))
                model.train()  # val() switched the model to eval mode

        if not saw_batch:
            raise ValueError("train_dataloader yielded no batches")

### validation function

def val(model, val_dataloader, criterion=None, metric=None):
    """Evaluate ``model`` on ``val_dataloader``.

    Args:
        model: Network to evaluate; it is left in eval mode.
        val_dataloader: Iterable yielding ``(images, labels)`` batches.
        criterion: Callable ``(logits, labels) -> loss``. When ``None`` the
            module-level ``criterion`` is used, as in the two-argument form.
        metric: Callable ``(logits, labels) -> score``. When ``None`` the
            module-level ``metric`` is used.

    Returns:
        Tuple ``(avg_loss, avg_dice)``: the per-batch means over the loader.
    """
    # Backward-compatible fallback for callers that pass only two arguments.
    if criterion is None:
        criterion = globals()["criterion"]
    if metric is None:
        metric = globals()["metric"]

    model.eval()
    avg_loss_list = []
    avg_dice_list = []
    with paddle.no_grad():
        for data in val_dataloader:
            fundus_img = (data[0] / 255.).astype("float32")
            gt_label = (data[1]).astype("int64")

            pred = model(fundus_img)
            loss = criterion(pred, gt_label)
            dice = metric(pred, gt_label)

            avg_loss_list.append(_to_scalar(loss))
            avg_dice_list.append(_to_scalar(dice))

    avg_loss = np.array(avg_loss_list).mean()
    avg_dice = np.array(avg_dice_list).mean()

    return avg_loss, avg_dice

### Training

In [31]:
### build the Datasets used for optimisation and evaluation,
### each restricted to its own subset of the train/val split

train_dataset = FundusDataset(images_file, gt_file, train_filelists)

val_dataset = FundusDataset(images_file, gt_file, val_filelists)

In [32]:
### Load the samples

# Training batches are reshuffled on every pass over the data.
train_loader = paddle.io.DataLoader(
    train_dataset,
    batch_sampler=paddle.io.DistributedBatchSampler(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=False),
    num_workers=num_workers,
    return_list=True,
    use_shared_memory=False
)

# Validation keeps a fixed order. The reported loss and dice are means of
# per-batch values, so reshuffling would change the batch composition, and
# with it the score, between two evaluations of the same weights.
val_loader = paddle.io.DataLoader(
    val_dataset,
    batch_sampler=paddle.io.DistributedBatchSampler(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False),
    num_workers=num_workers,
    return_list=True,
    use_shared_memory=False
)

In [34]:
### Model code was used to generate Model instance, and Optimizer, loss function, 
### evaluation index and other information were defined for subsequent training.

model = cup_disc_UNet(num_classes=3)

### The summary interface provided by PaddlePaddle can be called to visualize the constructed model,
### which is convenient to view and confirm the model structure and parameter information.
# paddle.Model(model).summary((-1,3,256,256)) 

# Build the optimizer named by `optimizer_type` (case-insensitive). Unknown
# names fail here with a clear message instead of leaving `optimizer`
# undefined until the training call.
if optimizer_type.lower() == "adam":
    optimizer = paddle.optimizer.Adam(init_lr, parameters=model.parameters())
elif optimizer_type.lower() == "sgd":
    optimizer = paddle.optimizer.SGD(learning_rate=init_lr, parameters=model.parameters())
elif optimizer_type.lower() == "rmsprop":
    optimizer = paddle.optimizer.RMSProp(learning_rate=init_lr, parameters=model.parameters())
else:
    raise ValueError("Unsupported optimizer_type: " + str(optimizer_type))

# Cross-entropy over the class axis is the training loss; DiceLoss() is
# passed to train() as the reported metric.
criterion = nn.CrossEntropyLoss(axis=1)
metric = DiceLoss()

In [None]:
### launch training: log every 50 iterations, validate every 100

train(
    model=model,
    iters=iters,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    optimizer=optimizer,
    criterion=criterion,
    metric=metric,
    log_interval=50,
    evl_interval=100,
)

### Inference

In [10]:
### inference (testing) process: load the saved parameters into a fresh network
### best_model_path must point at a checkpoint written by train(); the folder
### name carries the validation dice of that checkpoint

best_model_path = "./best_model_0.1794/model.pdparams"
model = cup_disc_UNet(num_classes = 3)
para_state_dict = paddle.load(best_model_path)
model.set_state_dict(para_state_dict)
model.eval()

In [11]:
### generate the test Dataset (no ground truth; samples carry their id and original size)

test_dataset = FundusDataset(test_file, mode='test')

In [15]:
### The fundus images in the test dataset are segmented one by one
### The segmentation results are resized to the original image size
### and stored as BMP images

# cv2.imwrite does not create missing folders; make sure the target exists.
os.makedirs('Disc_Cup_Segmentations/', exist_ok=True)

# No gradients are needed for inference.
with paddle.no_grad():
    for fundus_img, idx, h, w in test_dataset:
        fundus_img = fundus_img[np.newaxis, ...]  # add the batch axis
        fundus_img = paddle.to_tensor((fundus_img / 255.).astype("float32"))
        logits = model(fundus_img)
        pred_img = logits.numpy().argmax(1)

        # Map class ids back to the grey levels of the ground truth:
        # 0 stays 0 (cup), 1 -> 128 (disc), 2 -> 255 (background).
        pred_gray = np.squeeze(pred_img, axis=0).astype('uint8')
        pred_gray[pred_gray == 1] = 128
        pred_gray[pred_gray == 2] = 255

        # Nearest-neighbour keeps the output limited to {0, 128, 255}; the
        # default bilinear interpolation would blend neighbouring classes
        # into grey levels that belong to no class.
        pred_ = cv2.resize(pred_gray, (w, h), interpolation=cv2.INTER_NEAREST)

        # cv2.imwrite reports failure through its return value only.
        if not cv2.imwrite('Disc_Cup_Segmentations/'+idx+'.bmp', pred_):
            raise IOError("failed to write segmentation for sample " + idx)

### Summary

    This baseline realized the segmentation of optic cup and optic disc in 2D color fundus photography, and the baseline model is U-Net.
    Users can try other tricks on top of this baseline, such as joint training with macular segmentation or localization tasks, or coarse-to-fine segmentation.