In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
import torch.utils.data as data
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid
from torch.utils.data import ConcatDataset
from torch.cuda.amp import autocast, GradScaler
import os
import segmentation_models_pytorch as smp
import albumentations as A
from albumentations.pytorch import ToTensorV2
import warnings
import gc
import cv2
# 忽略所有警告
warnings.filterwarnings('ignore')
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
seed_value = 42   # 设定随机数种子

np.random.seed(seed_value)
random.seed(seed_value)
os.environ['PYTHONHASHSEED'] = str(seed_value)  # 为了禁止hash随机化，使得实验可复现。

torch.manual_seed(seed_value)     # 为CPU设置随机种子
torch.cuda.manual_seed(seed_value)      # 为当前GPU设置随机种子（只用一块GPU）
torch.cuda.manual_seed_all(seed_value)   # 为所有GPU设置随机种子（多块GPU）

torch.backends.cudnn.deterministic = True

class CFG:
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    checpoint = 'result/ConvNeXt_Uper-DIM-10-[eval_loss]-0.3738-[dice_score]-0.45-19-epoch.pkl'
    # ============== comp exp name =============
    comp_name = 'vesuvius'

    # # comp_dir_path = './'
    # comp_dir_path = '/kaggle/input/'
    # comp_folder_name = 'vesuvius-challenge-ink-detection'
    # # comp_dataset_path = f'{comp_dir_path}datasets/{comp_folder_name}/'
    # comp_dataset_path = f'{comp_dir_path}{comp_folder_name}/'
        # comp_dir_path = './'
    comp_dir_path = ''
    comp_folder_name = 'data'
    # comp_dataset_path = f'{comp_dir_path}datasets/{comp_folder_name}/'
    comp_dataset_path = f'{comp_dir_path}{comp_folder_name}/'
    
    exp_name = 'Unet_stride'

    # ============== pred target =============
    target_size = 1

    # ============== model cfg =============
    model_name = 'Unet'

    in_chans = 24# 65
    # ============== training cfg =============
    size = 224
    tile_size = 224
    stride = tile_size // 2

    train_batch_size = 16 # 32
    valid_batch_size = 16
    use_amp = True

    epochs = 30 # 30

    # lr = 1e-4 / warmup_factor
    lr = 1e-4

    # ============== fixed =============
    pretrained = True

    backbone = 'se_resnext50_32x4d'

    min_lr = 1e-6
    weight_decay = 1e-6
    max_grad_norm = 1000

    num_workers = 4

    seed = 42

    threshhold = 0.5

    all_best_dice = 0
    all_best_loss = np.float('inf')

    shape_list = []
    test_shape_list = []

    val_mask = None
    val_label = None

    # ============== set dataset path =============

    # outputs_path = f'/kaggle/working/outputs/{comp_name}/{exp_name}/'
    outputs_path = 'result/'

    submission_dir = outputs_path + 'submissions/'
    submission_path = submission_dir + f'submission_{exp_name}.csv'

    model_dir = outputs_path
    log_dir = outputs_path + 'logs/'

    # ============== augmentation =============
    train_aug_list = [
        # A.RandomResizedCrop(
        #     size, size, scale=(0.85, 1.0)),
        A.Resize(size, size),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.75),
        A.ShiftScaleRotate(p=0.75),
        A.OneOf([
                A.GaussNoise(var_limit=[10, 50]),
                A.GaussianBlur(),
                A.MotionBlur(),
                ], p=0.4),
        A.GridDistortion(num_steps=5, distort_limit=0.3, p=0.5),
        A.CoarseDropout(max_holes=1, max_width=int(size * 0.3), max_height=int(size * 0.3), 
                        mask_fill_value=0, p=0.5),
        # A.Cutout(max_h_size=int(size * 0.6),
        #          max_w_size=int(size * 0.6), num_holes=1, p=1.0),
        A.Normalize(
            mean= [0] * in_chans,
            std= [1] * in_chans
        ),
        ToTensorV2(transpose_mask=True),
    ]

    valid_aug_list = [
        A.Resize(size, size),
        A.Normalize(
            mean= [0] * in_chans,
            std= [1] * in_chans
        ),
        ToTensorV2(transpose_mask=True),
    ]
seed = CFG.seed
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [2]:
def read_image_mask(fragment_id):

    images = []

    # idxs = range(65)
    mid = 65 // 2
    start = mid - CFG.in_chans // 2
    end = mid + CFG.in_chans // 2
    idxs = range(start, end)

    for i in tqdm(idxs):
        
        image = cv2.imread(CFG.comp_dataset_path + f"train/{fragment_id}/surface_volume/{i:02}.tif", 0)

        pad0 = (CFG.tile_size - image.shape[0] % CFG.tile_size + 1)
        pad1 = (CFG.tile_size - image.shape[1] % CFG.tile_size + 1)

        image = np.pad(image, [(0, pad0), (0, pad1)], constant_values=0)

        images.append(image)
    images = np.stack(images, axis=2)

    label = cv2.imread(CFG.comp_dataset_path + f"train/{fragment_id}/inklabels.png", 0)
    label = np.pad(label, [(0, pad0), (0, pad1)], constant_values=0)

    label = label.astype('float32')
    label /= 255.0

    mask = cv2.imread(CFG.comp_dataset_path + f"train/{fragment_id}/mask.png", 0)
    mask = np.pad(mask, [(0, pad0), (0, pad1)], constant_values=0)

    mask = (mask / 255.).astype('float32')

    CFG.shape_list.append(mask.shape)
    if fragment_id == 1:
        CFG.val_mask = mask
        CFG.val_label = label
    
    return images, label, mask

def get_train_valid_dataset(val_persent=0.05):
    train_id = [2, 3]
    train_images = []
    train_labels = []

    valid_images = []
    valid_labels = []
    valid_positons = []

    for fragment_id in range(1, 4):

        image, label, mask = read_image_mask(fragment_id)

        x1_list = list(range(0, image.shape[1]-CFG.tile_size+1, CFG.stride))
        y1_list = list(range(0, image.shape[0]-CFG.tile_size+1, CFG.stride))

        for y1 in y1_list:
            for x1 in x1_list:
                y2 = y1 + CFG.tile_size
                x2 = x1 + CFG.tile_size
                if sum(list(mask[y1:y2, x1:x2].flatten())) != 0:
                    if fragment_id in train_id:
                        train_images.append(image[y1:y2, x1:x2])
                        train_labels.append(label[y1:y2, x1:x2, None])
                    else:
                        valid_images.append(image[y1:y2, x1:x2])
                        valid_labels.append(label[y1:y2, x1:x2, None])
                        valid_positons.append([x1, y1, x2, y2, fragment_id - 1])
    return train_images, train_labels, valid_images, valid_labels, valid_positons

def get_transforms(data, cfg):
    if data == 'train':
        aug = A.Compose(cfg.train_aug_list)
    elif data == 'valid':
        aug = A.Compose(cfg.valid_aug_list)
    return aug

class SubvolumeDataset(data.Dataset):
    def __init__(self, images, labels, positions, transform, is_train):
        self.transform = transform
        self.images = images
        self.labels = labels
        self.is_train = is_train
        self.positions = positions
    def __len__(self):
        return len(self.images)
    def __getitem__(self, index):
        if self.is_train:
            image = self.images[index]
            label = self.labels[index]
            if self.positions:
                position = np.array(self.positions[index])
            else:
                position = np.zeros(1)
            if self.transform:
                data = self.transform(image=image, mask=label)
                image = data['image']
                label = data['mask']
            return image, label, position
        else:
            image = self.images[index]
            position = np.array(self.positions[index])
            if self.transform:
                data = self.transform(image=image, mask=label)
                image = data['image']
            return image, position
        
# IOU and Dice Score
def dice_coef(targets, preds, thr=0.5, beta=0.5, smooth=1e-5):

    #comment out if your model contains a sigmoid or equivalent activation layer
    # flatten label and prediction tensors
    preds = (preds > thr).view(-1).float()
    targets = targets.view(-1).float()

    y_true_count = targets.sum()
    ctp = preds[targets==1].sum()
    cfp = preds[targets==0].sum()
    beta_squared = beta * beta

    c_precision = ctp / (ctp + cfp + smooth)
    c_recall = ctp / (y_true_count + smooth)
    dice = (1 + beta_squared) * (c_precision * c_recall) / (beta_squared * c_precision + c_recall + smooth)

    return dice

In [3]:
train_images, train_labels, valid_images, valid_labels, valid_positons = get_train_valid_dataset()

100%|██████████| 24/24 [00:21<00:00,  1.14it/s]
100%|██████████| 24/24 [01:12<00:00,  3.00s/it]
100%|██████████| 24/24 [00:06<00:00,  3.54it/s]


In [4]:
train_dataset = SubvolumeDataset(train_images, train_labels, None,get_transforms(data='train', cfg=CFG), True)
valid_dataset = SubvolumeDataset(valid_images, valid_labels, valid_positons, get_transforms(data='valid', cfg=CFG), True)
train_loader = data.DataLoader(train_dataset,
                          batch_size=CFG.train_batch_size,
                          shuffle=True,
                          num_workers=CFG.num_workers, pin_memory=True, drop_last=True,
                          )
valid_loader = data.DataLoader(valid_dataset,
                          batch_size=CFG.valid_batch_size,
                          shuffle=False,
                          num_workers=CFG.num_workers, pin_memory=True, drop_last=False)

In [5]:
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_, DropPath
from mmcv.utils import get_logger
import logging
from mmcv.cnn import ConvModule

def resize(input,
           size=None,
           scale_factor=None,
           mode='nearest',
           align_corners=None,
           warning=True):
    if warning:
        if size is not None and align_corners:
            input_h, input_w = tuple(int(x) for x in input.shape[2:])
            output_h, output_w = tuple(int(x) for x in size)
            if output_h > input_h or output_w > output_h:
                if ((output_h > 1 and output_w > 1 and input_h > 1
                     and input_w > 1) and (output_h - 1) % (input_h - 1)
                        and (output_w - 1) % (input_w - 1)):
                    warnings.warn(
                        f'When align_corners={align_corners}, '
                        'the output would more aligned if '
                        f'input size {(input_h, input_w)} is `x+1` and '
                        f'out size {(output_h, output_w)} is `nx+1`')
    return F.interpolate(input, size, scale_factor, mode, align_corners)

def get_root_logger(log_file=None, log_level=logging.INFO):
    """Get the root logger.

    The logger will be initialized if it has not been initialized. By default a
    StreamHandler will be added. If `log_file` is specified, a FileHandler will
    also be added. The name of the root logger is the top-level package name,
    e.g., "mmseg".

    Args:
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the root logger.
        log_level (int): The root logger level. Note that only the process of
            rank 0 is affected, while other processes will set the level to
            "Error" and be silent most of the time.

    Returns:
        logging.Logger: The root logger.
    """

    logger = get_logger(name='mmseg', log_file=log_file, log_level=log_level)

    return logger

class BayarConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=5, stride=1, padding=2):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.minus1 = (torch.ones(self.in_channels, self.out_channels, 1) * -1.000)

        super(BayarConv2d, self).__init__()
        # only (kernel_size ** 2 - 1) trainable params as the center element is always -1
        self.kernel = nn.Parameter(torch.rand(self.in_channels, self.out_channels, kernel_size ** 2 - 1),
                                   requires_grad=True)


    def bayarConstraint(self):
        self.kernel.data = self.kernel.permute(2, 0, 1)
        self.kernel.data = torch.div(self.kernel.data, self.kernel.data.sum(0))
        self.kernel.data = self.kernel.permute(1, 2, 0)
        ctr = self.kernel_size ** 2 // 2
        real_kernel = torch.cat((self.kernel[:, :, :ctr], self.minus1.to(self.kernel.device), self.kernel[:, :, ctr:]), dim=2)
        real_kernel = real_kernel.reshape((self.out_channels, self.in_channels, self.kernel_size, self.kernel_size))
        return real_kernel

    def forward(self, x):
        x = F.conv2d(x, self.bayarConstraint(), stride=self.stride, padding=self.padding)
        return x

class Block(nn.Module):
    r""" ConvNeXt Block. There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    We use (2) as we find it slightly faster in PyTorch
    
    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
    """
    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), 
                                    requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)

        x = input + self.drop_path(x)
        return x

class LayerNorm(nn.Module):
    r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. 
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 
    shape (batch_size, height, width, channels) while channels_first corresponds to inputs 
    with shape (batch_size, channels, height, width).
    """
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError 
        self.normalized_shape = (normalized_shape, )
    
    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x

class ConvNeXt(nn.Module):
    r""" ConvNeXt
        A PyTorch impl of : `A ConvNet for the 2020s`  -
          https://arxiv.org/pdf/2201.03545.pdf

    Args:
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
        dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
    """
    def __init__(self, in_chans=3, bayar=False, depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], 
                 drop_path_rate=0., layer_scale_init_value=1e-6, out_indices=[0, 1, 2, 3], pretrained=None
                 ):
        super().__init__()
        self.pretrained = pretrained
        self.in_chans = in_chans
        self.bayar = bayar
        if self.bayar:
            self.bayar_conv = BayarConv2d(in_chans, in_chans)
            self.in_chans = in_chans * 2

        self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
        stem = nn.Sequential(
            nn.Conv2d(self.in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
        )
        self.downsample_layers.append(stem)
        for i in range(len(dims) - 1):
            downsample_layer = nn.Sequential(
                    LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                    nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
        dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 
        cur = 0
        for i in range(len(dims)):
            stage = nn.Sequential(
                *[Block(dim=dims[i], drop_path=dp_rates[cur + j], 
                layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.out_indices = out_indices

        norm_layer = partial(LayerNorm, eps=1e-6, data_format="channels_first")
        for i_layer in range(len(dims)):
            layer = norm_layer(dims[i_layer])
            layer_name = f'norm{i_layer}'
            self.add_module(layer_name, layer)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            nn.init.constant_(m.bias, 0)

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.
        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        if pretrained is None:
            pretrained = self.pretrained

        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        if isinstance(pretrained, str):
            self.apply(_init_weights)
            logger = get_root_logger()
        elif pretrained is None:
            self.apply(_init_weights)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward_features(self, x):
        if self.bayar:
            x_bayar = self.bayar_conv(x)
            x = torch.cat([x, x_bayar], dim=1)
        outs = []
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                x_out = norm_layer(x)
                outs.append(x_out)

        return tuple(outs)

    def forward(self, x):
        x = self.forward_features(x)
        return x
    
class PPM(nn.ModuleList):
    """Pooling Pyramid Module used in PSPNet.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
        align_corners (bool): align_corners argument of F.interpolate.
    """

    def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg,
                 act_cfg, align_corners, **kwargs):
        super(PPM, self).__init__()
        self.pool_scales = pool_scales
        self.align_corners = align_corners
        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        for pool_scale in pool_scales:
            self.append(
                nn.Sequential(
                    nn.AdaptiveAvgPool2d(pool_scale),
                    ConvModule(
                        self.in_channels,
                        self.channels,
                        1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg,
                        **kwargs)))

    def forward(self, x):
        """Forward function."""
        ppm_outs = []
        for ppm in self:
            ppm_out = ppm(x)
            upsampled_ppm_out = resize(
                ppm_out,
                size=x.size()[2:],
                mode='bilinear',
                align_corners=self.align_corners)
            ppm_outs.append(upsampled_ppm_out)
        return ppm_outs

class UPerHead(nn.Module):
    """Unified Perceptual Parsing for Scene Understanding.

    This head is the implementation of `UPerNet
    <https://arxiv.org/abs/1807.10221>`_.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module applied on the last feature. Default: (1, 2, 3, 6).
    """

    def __init__(self, 
                in_channels=[128, 256, 512, 1024],
                in_index=[0, 1, 2, 3],
                pool_scales=(1, 2, 3, 6),
                channels=512,
                dropout_ratio=0.1,
                num_classes=0,
                norm_cfg=dict(type='BN', requires_grad=True),
                align_corners=False):
        super(UPerHead, self).__init__()
        # PSP Module
        self.input_transform = 'multiple_select'
        self.align_corners = align_corners
        self.in_index = in_index
        self.psp_modules = PPM(
            pool_scales,
            in_channels[-1],
            channels,
            conv_cfg=None,
            norm_cfg=norm_cfg,
            act_cfg=dict(type='ReLU'),
            align_corners=align_corners)
        self.bottleneck = ConvModule(
            in_channels[-1] + len(pool_scales) * channels,
            channels,
            3,
            padding=1,
            conv_cfg=None,
            norm_cfg=norm_cfg,
            act_cfg=dict(type='ReLU'))
        # FPN Module
        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()
        for in_channel in in_channels[:-1]:  # skip the top layer
            l_conv = ConvModule(
                in_channel,
                channels,
                1,
                conv_cfg=None,
                norm_cfg=norm_cfg,
                act_cfg=dict(type='ReLU'),
                inplace=False)
            fpn_conv = ConvModule(
                channels,
                channels,
                3,
                padding=1,
                conv_cfg=None,
                norm_cfg=norm_cfg,
                act_cfg=dict(type='ReLU'),
                inplace=False)
            self.lateral_convs.append(l_conv)
            self.fpn_convs.append(fpn_conv)

        self.fpn_bottleneck = ConvModule(
            len(in_channels) * channels,
            channels,
            3,
            padding=1,
            conv_cfg=None,
            norm_cfg=norm_cfg,
            act_cfg=dict(type='ReLU'))
        if dropout_ratio > 0:
            self.dropout = nn.Dropout2d(dropout_ratio)
        else:
            self.dropout = None
        self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)

    def psp_forward(self, inputs):
        """Forward function of PSP module."""
        x = inputs[-1]
        psp_outs = [x]
        psp_outs.extend(self.psp_modules(x))
        psp_outs = torch.cat(psp_outs, dim=1)
        output = self.bottleneck(psp_outs)

        return output

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor: The transformed inputs
        """

        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            upsampled_inputs = [
                resize(
                    input=x,
                    size=inputs[0].shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners) for x in inputs
            ]
            inputs = torch.cat(upsampled_inputs, dim=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs
    
    def cls_seg(self, feat):
        """Classify each pixel."""
        if self.dropout is not None:
            feat = self.dropout(feat)
        output = self.conv_seg(feat)
        return output

    def forward(self, inputs):
        """Forward function."""
        inputs = self._transform_inputs(inputs)

        # build laterals
        laterals = [
            lateral_conv(inputs[i]) for i, lateral_conv in enumerate(self.lateral_convs)
        ]

        laterals.append(self.psp_forward(inputs))

        # build top-down path
        used_backbone_levels = len(laterals)
        for i in range(used_backbone_levels - 1, 0, -1):
            prev_shape = laterals[i - 1].shape[2:]
            laterals[i - 1] = laterals[i - 1] + resize(
                laterals[i],
                size=prev_shape,
                mode='bilinear',
                align_corners=self.align_corners)

        # build outputs
        fpn_outs = [
            self.fpn_convs[i](laterals[i])
            for i in range(used_backbone_levels - 1)
        ]
        # append psp feature
        fpn_outs.append(laterals[-1])

        for i in range(used_backbone_levels - 1, 0, -1):
            fpn_outs[i] = resize(
                fpn_outs[i],
                size=fpn_outs[0].shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)
        fpn_outs = torch.cat(fpn_outs, dim=1)
        output = self.fpn_bottleneck(fpn_outs)
        output = self.cls_seg(output)
        return output
    
class FCNHead(nn.Module):
    """Fully Convolution Networks for Semantic Segmentation.

    This head is implemented of `FCNNet <https://arxiv.org/abs/1411.4038>`_.

    Args:
        num_convs (int): Number of convs in the head. Default: 2.
        kernel_size (int): The kernel size for convs in the head. Default: 3.
        concat_input (bool): Whether concat the input and output of convs
            before classification layer.
        dilation (int): The dilation rate for convs in the head. Default: 1.
    """

    def __init__(self,
                kernel_size=3,
                dilation=1,
                in_channels=512,
                in_index=2,
                channels=256,
                num_convs=1,
                concat_input=False,
                dropout_ratio=0.1,
                num_classes=2,
                norm_cfg=dict(type='BN', requires_grad=True),
                align_corners=False):
        assert num_convs >= 0 and dilation > 0 and isinstance(dilation, int)
        self.align_corners = align_corners
        self.in_index = in_index
        self.input_transform = None
        self.num_convs = num_convs
        self.concat_input = concat_input
        self.kernel_size = kernel_size
        super(FCNHead, self).__init__()
        if num_convs == 0:
            assert in_channels == channels

        conv_padding = (kernel_size // 2) * dilation
        convs = []
        convs.append(
            ConvModule(
                in_channels,
                channels,
                kernel_size=kernel_size,
                padding=conv_padding,
                dilation=dilation,
                conv_cfg=None,
                norm_cfg=norm_cfg,
                act_cfg=dict(type='ReLU')))
        for i in range(num_convs - 1):
            convs.append(
                ConvModule(
                    channels,
                    channels,
                    kernel_size=kernel_size,
                    padding=conv_padding,
                    dilation=dilation,
                    conv_cfg=None,
                    norm_cfg=norm_cfg,
                    act_cfg=dict(type='ReLU')))
        if num_convs == 0:
            self.convs = nn.Identity()
        else:
            self.convs = nn.Sequential(*convs)
        if self.concat_input:
            self.conv_cat = ConvModule(
                in_channels + channels,
                channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                conv_cfg=None,
                norm_cfg=norm_cfg,
                act_cfg=dict(type='ReLU'))
        if dropout_ratio > 0:
            self.dropout = nn.Dropout2d(dropout_ratio)
        else:
            self.dropout = None
        self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor: The transformed inputs
        """

        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            upsampled_inputs = [
                resize(
                    input=x,
                    size=inputs[0].shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners) for x in inputs
            ]
            inputs = torch.cat(upsampled_inputs, dim=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs
    
    def cls_seg(self, feat):
        """Classify each pixel."""
        if self.dropout is not None:
            feat = self.dropout(feat)
        output = self.conv_seg(feat)
        return output

    def forward(self, inputs):
        """Forward function."""
        x = self._transform_inputs(inputs)
        output = self.convs(x)
        if self.concat_input:
            output = self.conv_cat(torch.cat([x, output], dim=1))
        output = self.cls_seg(output)
        return output

In [6]:
class Ink_detection(nn.Module):
    def __init__(self) -> None:
        super(Ink_detection, self).__init__()
        self.encoder = ConvNeXt(in_chans=CFG.in_chans,
                                depths=[3, 3, 16, 3],
                                dims=[128, 256, 512, 1024],
                                drop_path_rate=0.1,
                                layer_scale_init_value=0.1,
                                out_indices=[0, 1, 2, 3])
        self.decoder_uper = UPerHead(in_channels=[128, 256, 512, 1024],
                                     in_index=[0, 1, 2, 3],
                                     pool_scales=(1, 2, 3, 6),
                                     channels=512,
                                     dropout_ratio=0.1,
                                     num_classes=CFG.target_size,
                                     norm_cfg=dict(type='BN', requires_grad=True),
                                     align_corners=False)
        self.decoder_fcn = FCNHead(kernel_size=3,
                                   dilation=1,
                                   in_channels=512,
                                   in_index=2,
                                   channels=256,
                                   num_convs=1,
                                   concat_input=False,
                                   dropout_ratio=0.1,
                                   num_classes=CFG.target_size,
                                   norm_cfg=dict(type='BN', requires_grad=True),
                                   align_corners=False)
        param = torch.ones(2, requires_grad=True)
        self.param = nn.Parameter(param)
    def forward(self, x):
        feature_list = self.encoder(x)
        res_uper = self.decoder_uper(feature_list)
        res_fcn = self.decoder_fcn(feature_list)
        res_uper = resize(input=res_uper,
                          size=(CFG.size, CFG.size),
                          mode='bilinear',
                          align_corners=False)
        res_fcn = resize(input=res_fcn,
                          size=(CFG.size, CFG.size),
                          mode='bilinear',
                          align_corners=False)
        res = self.param[0] * res_uper +  self.param[1] * res_fcn
        return res       

In [7]:
model = Ink_detection().to(CFG.device)
from torchsummary import summary
# summary(model, input_size=(CFG.in_chans, 224, 224))
# x = torch.ones((16, CFG.in_chans, 224, 224)).float().cuda()
# o = model(x)
model_name = 'ConvNeXt_Uper'
if CFG.pretrained:
    try:
        checkpoint = torch.load(CFG.checpoint, map_location=CFG.device)
        models_dict = model.state_dict()
        for model_part in models_dict:
            if model_part in checkpoint:
                models_dict[model_part] = checkpoint[model_part]
        model.load_state_dict(models_dict)
        print('Checkpoint loaded')
    except:
        print('Checkpoint not loaded')
        pass

Checkpoint not loaded


In [8]:
from warmup_scheduler import GradualWarmupScheduler


class GradualWarmupSchedulerV2(GradualWarmupScheduler):
    """
    https://www.kaggle.com/code/underwearfitting/single-fold-training-of-resnet200d-lb0-965
    """
    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        super(GradualWarmupSchedulerV2, self).__init__(
            optimizer, multiplier, total_epoch, after_scheduler)

    def get_lr(self):
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    self.after_scheduler.base_lrs = [
                        base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]
        if self.multiplier == 1.0:
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        else:
            return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]

def get_scheduler(cfg, optimizer):
    scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, cfg.epochs, eta_min=1e-7)
    scheduler = GradualWarmupSchedulerV2(
        optimizer, multiplier=10, total_epoch=1, after_scheduler=scheduler_cosine)

    return scheduler

def scheduler_step(scheduler, avg_val_loss, epoch):
    scheduler.step(epoch)

In [9]:
def train_step(train_loader, model, criterion, optimizer, writer, device, epoch):
    model.train()
    epoch_loss = 0
    scaler = GradScaler(enabled=CFG.use_amp)
    bar = tqdm(enumerate(train_loader), total=len(train_loader)) 
    for step, (image, label, _) in bar:
        optimizer.zero_grad()
        outputs = model(image.to(device))
        # outputs = resize(input=outputs,
        #                 size=(CFG.size, CFG.size),
        #                 mode='bilinear',
        #                 align_corners=False)
        loss = criterion(outputs, label.to(device))
        scaler.scale(loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.max_grad_norm)
        scaler.step(optimizer)
        scaler.update()
        mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0
        bar.set_postfix(loss=f'{loss.item():0.4f}', epoch=epoch ,gpu_mem=f'{mem:0.2f} GB', lr=f'{optimizer.state_dict()["param_groups"][0]["lr"]:0.2e}')
        epoch_loss += loss.item()
    writer.add_scalar('Train/Loss', epoch_loss / len(train_loader), epoch)
    return epoch_loss / len(train_loader)

def valid_step(valid_loader, model, criterion, device, writer, epoch):
    pred_label = np.zeros(CFG.shape_list[0])
    true_label = CFG.val_label
    mask_count = np.zeros(CFG.shape_list[0])
    model.eval()
    epoch_loss = 0
    for step, (images, labels, positions) in tqdm(enumerate(valid_loader), total=len(valid_loader)):
        images = images.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            y_preds = model(images)
            # y_preds = resize(input=y_preds,
            #             size=(CFG.size, CFG.size),
            #             mode='bilinear',
            #             align_corners=False)
            loss = criterion(y_preds, labels)
        # make whole mask
        y_preds = torch.sigmoid(y_preds)
        pred_img =y_preds.squeeze().cpu().numpy()
        positions = positions.squeeze()
        for i in range(len(positions)):
            x1, y1, x2, y2, _ = positions[i].numpy().tolist()
            pred_label[y1:y2, x1:x2] += pred_img[i]
            mask_count[y1:y2, x1:x2] += np.ones((CFG.tile_size, CFG.tile_size))
        epoch_loss += loss.item()
    avg_loss = epoch_loss / len(valid_loader)
    print(f'mask_count_min: {mask_count.min()}')
    print(f'mask_count_max: {mask_count.max()}')
    # 计算准确率
    pred_label /= mask_count
    pred_label *= CFG.val_mask
    best_th = 0
    best_dice = 0
    for th in np.arange(3, 6, 0.5) / 10:
        dice_score = dice_coef(torch.from_numpy(true_label).to(CFG.device), torch.from_numpy(pred_label).to(CFG.device), thr=th).item()
        # dice_scores.append(dice_score)
        if dice_score > best_dice:
            best_dice = dice_score
            best_th = th
    if CFG.all_best_dice < best_dice:
        print('best_th={:2f}' .format(best_th),"score up: {:2f}->{:2f}".format(CFG.all_best_dice, best_dice))       
        cv2.imwrite('result/logs/img/'+str(epoch) + '_res.png', ((pred_label > best_th).astype('int')* 255) )
        CFG.all_best_dice = best_dice
        torch.save(model.state_dict(), 'result/' +  '{}-DIM-{}-[eval_loss]-{:.4f}-[dice_score]-{:.2f}-'.format(model_name, CFG.in_chans , avg_loss, best_dice) + str(epoch) + '-epoch.pkl')
        # 使用make_grid将图片转换成网格形式
        # pred_mask = make_grid(torch.from_numpy(pred_label ).to(CFG.device),normalize=True)
        # true_mask = make_grid(torch.from_numpy(true_label).to(CFG.device), normalize=True)
        # 使用add_image方法将图片添加到TensorBoard中
        # writer.add_image('Valid/True_mask', true_mask, global_step=epoch, dataformats="CHW")
        # writer.add_image('Valid/Pred_mask', pred_mask, global_step=epoch, dataformats="CHW")
    if CFG.all_best_loss > avg_loss:
        print('best_loss={:2f}'.format(avg_loss), "loss down: {:2f}->{:2f}".format(CFG.all_best_loss, avg_loss))
        cv2.imwrite('result/logs/img/' + str(epoch) + '_res.png', ((pred_label > best_th).astype('int')* 255) )
        CFG.all_best_loss = avg_loss
        torch.save(model.state_dict(), 'result/' +  '{}-DIM-{}-[eval_loss]-{:.4f}-[dice_score]-{:.2f}-'.format(model_name, CFG.in_chans ,avg_loss, best_dice) + str(epoch) + '-epoch.pkl')
    writer.add_scalar('Val/Dice', best_dice, epoch)
    writer.add_scalar('Valid/Loss', avg_loss , epoch)
    return avg_loss

In [10]:
criterion = smp.losses.SoftBCEWithLogitsLoss()
optimizer = optim.AdamW(model.parameters(),
                        lr=CFG.lr,
                        betas=(0.9, 0.999),
                        weight_decay=CFG.weight_decay
                        )
scheduler = get_scheduler(CFG, optimizer)
writer = SummaryWriter('result/logs')
for i in range(CFG.epochs):
    print('train:')
    train_step(train_loader, model, criterion, optimizer, writer, CFG.device, i + 1)
    print('val:')
    val_loss = valid_step(valid_loader, model, criterion, CFG.device, writer, i + 1)
    scheduler_step(scheduler, val_loss, i + 1)

train:


100%|██████████| 660/660 [06:32<00:00,  1.68it/s, epoch=1, gpu_mem=7.03 GB, loss=0.3897, lr=1.00e-04]

val:



100%|██████████| 164/164 [00:33<00:00,  4.93it/s]


mask_count_min: 0.0
mask_count_max: 4.0
best_th=0.300000 score up: 0.000000->0.319557
best_loss=0.409941 loss down: inf->0.409941
train:


100%|██████████| 660/660 [06:39<00:00,  1.65it/s, epoch=2, gpu_mem=7.47 GB, loss=0.4697, lr=1.00e-03]

val:



100%|██████████| 164/164 [00:33<00:00,  4.94it/s]


mask_count_min: 0.0
mask_count_max: 4.0
train:


100%|██████████| 660/660 [06:42<00:00,  1.64it/s, epoch=3, gpu_mem=7.47 GB, loss=0.4186, lr=1.00e-03]

val:



100%|██████████| 164/164 [00:33<00:00,  4.94it/s]


mask_count_min: 0.0
mask_count_max: 4.0
train:


100%|██████████| 660/660 [06:42<00:00,  1.64it/s, epoch=4, gpu_mem=7.47 GB, loss=0.2907, lr=9.89e-04]

val:



100%|██████████| 164/164 [00:33<00:00,  4.92it/s]


mask_count_min: 0.0
mask_count_max: 4.0
best_loss=0.407708 loss down: 0.409941->0.407708
train:


 80%|███████▉  | 525/660 [05:15<01:18,  1.72it/s, epoch=5, gpu_mem=7.47 GB, loss=0.2150, lr=9.76e-04]