In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
import torch.utils.data as data
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid
from torch.utils.data import ConcatDataset
import os
import segmentation_models_pytorch as smp
import albumentations as A
from albumentations.pytorch import ToTensorV2
import warnings
import gc
import cv2
# Silence all warnings for the rest of the notebook.
warnings.filterwarnings('ignore')
os.environ.update({"KMP_DUPLICATE_LIB_OK": "TRUE"})

# Single seed shared by every random number generator used below.
seed_value = 42

# Disable hash randomisation so that experiments are reproducible.
os.environ['PYTHONHASHSEED'] = str(seed_value)

# Python and NumPy generators.
random.seed(seed_value)
np.random.seed(seed_value)

# PyTorch generators: CPU, the current GPU, and all GPUs (multi-GPU setups).
torch.manual_seed(seed_value)
torch.cuda.manual_seed(seed_value)
torch.cuda.manual_seed_all(seed_value)

torch.backends.cudnn.deterministic = True

class CFG:
    """Static configuration namespace shared by every cell of the notebook.

    All values are class attributes and are read as ``CFG.<name>``; the class
    is never instantiated in the visible code. ``shape_list`` is mutable and is
    appended to by ``read_image_mask`` each time a fragment is loaded.
    """
    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Path of the checkpoint that is loaded when `pretrained` is True.
    # (The attribute name is misspelled, but other cells read it under this name.)
    checpoint = ''
    # ============== comp exp name =============
    comp_name = 'vesuvius'

    # Alternative (Kaggle) dataset location, kept for reference:
    # # comp_dir_path = './'
    # comp_dir_path = '/kaggle/input/'
    # comp_folder_name = 'vesuvius-challenge-ink-detection'
    # # comp_dataset_path = f'{comp_dir_path}datasets/{comp_folder_name}/'
    # comp_dataset_path = f'{comp_dir_path}{comp_folder_name}/'
        # comp_dir_path = './'
    # Local dataset root: resolves to 'data/'.
    comp_dir_path = ''
    comp_folder_name = 'data'
    # comp_dataset_path = f'{comp_dir_path}datasets/{comp_folder_name}/'
    comp_dataset_path = f'{comp_dir_path}{comp_folder_name}/'
    
    exp_name = 'ConvNeXt_UPerHead'

    # ============== pred target =============
    # Number of output channels of the segmentation head.
    target_size = 1

    # ============== model cfg =============
    model_name = 'ConvNeXt_UPerHead'

    # Number of raw volume slices read around the middle slice (the volume has 65).
    in_chans = 16 # 65
    # ============== training cfg =============
    size = 224
    tile_size = 224
    # Tiles overlap by half a tile.
    stride = tile_size // 2

    # Number of consecutive slices averaged into one input channel, so the
    # network receives in_chans // mean_num channels.
    mean_num = 4

    train_batch_size = 16 # 32
    valid_batch_size = 16
    # NOTE(review): not read by any code visible in this notebook.
    use_amp = True

    epochs = 50 # 30

    # lr = 1e-4 / warmup_factor
    lr = 1e-2

    # ============== fixed =============
    # When True, the checkpoint at `checpoint` is loaded into the model.
    pretrained = False

    # Lower bound of the cosine learning-rate schedule.
    min_lr = 1e-6
    weight_decay = 1e-6
    # Gradient-norm clipping threshold used in the training step.
    max_grad_norm = 1000

    num_workers = 4

    seed = 42

    # Probability threshold used to binarise predictions during validation.
    # (Misspelled, but referenced under this name elsewhere.)
    threshhold = 0.4

    # Padded (height, width) of every fragment that has been loaded.
    shape_list = []
    # NOTE(review): not used by any code visible in this notebook.
    test_shape_list = []

    # ============== set dataset path =============

    # outputs_path = f'/kaggle/working/outputs/{comp_name}/{exp_name}/'
    outputs_path = 'result/'

    submission_dir = outputs_path + 'submissions/'
    submission_path = submission_dir + f'submission_{exp_name}.csv'

    model_dir = outputs_path
    log_dir = outputs_path + 'logs/'

    # ============== augmentation =============
    # Training pipeline: geometric and photometric augmentation, then
    # normalisation and conversion to a channel-first tensor.
    train_aug_list = [
        # A.RandomResizedCrop(
        #     size, size, scale=(0.85, 1.0)),
        A.Resize(size, size),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.75),
        A.ShiftScaleRotate(p=0.75),
        A.OneOf([
                A.GaussNoise(var_limit=[10, 50]),
                A.GaussianBlur(),
                A.MotionBlur(),
                ], p=0.4),
        A.GridDistortion(num_steps=5, distort_limit=0.3, p=0.5),
        A.CoarseDropout(max_holes=1, max_width=int(size * 0.3), max_height=int(size * 0.3), 
                        mask_fill_value=0, p=0.5),
        # A.Cutout(max_h_size=int(size * 0.6),
        #          max_w_size=int(size * 0.6), num_holes=1, p=1.0),
        # One mean/std entry per averaged input channel.
        A.Normalize(
            mean= [0] * (in_chans // mean_num),
            std= [1] * (in_chans // mean_num)
        ),
        ToTensorV2(transpose_mask=True),
    ]

    # Validation pipeline: resize, normalise and convert only (no augmentation).
    valid_aug_list = [
        A.Resize(size, size),
        A.Normalize(
            mean= [0] * (in_chans // mean_num),
            std= [1] * (in_chans // mean_num)
        ),
        ToTensorV2(transpose_mask=True),
    ]


In [2]:
def _read_grayscale(path):
    """Read `path` as a single-channel image, failing loudly when it is missing."""
    image = cv2.imread(path, 0)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image file: {path}")
    return image


def read_image_mask(fragment_id):
    """Load one training fragment as a padded image stack, ink label and mask.

    `CFG.in_chans` slices centred on the middle of the 65-slice volume are
    read; every `CFG.mean_num` consecutive slices are averaged into a single
    channel. All arrays are zero-padded at the bottom/right up to a multiple
    of `CFG.tile_size`.

    Args:
        fragment_id: Name of the fragment directory below ``train/``.

    Returns:
        Tuple ``(images, label, mask)`` where ``images`` has shape
        ``(H, W, CFG.in_chans // CFG.mean_num)`` and ``label`` / ``mask`` are
        float32 arrays of shape ``(H, W)`` scaled to ``[0, 1]``.

    Raises:
        FileNotFoundError: If any slice, the label or the mask cannot be read.

    Side effects:
        Appends the padded ``(H, W)`` shape to ``CFG.shape_list``.
    """
    fragment_dir = CFG.comp_dataset_path + f"train/{fragment_id}/"

    # idxs = range(65)
    mid = 65 // 2
    start = mid - CFG.in_chans // 2
    end = mid + CFG.in_chans // 2
    idxs = range(start, end, CFG.mean_num)

    images = []
    for i in tqdm(idxs):
        averaged = None
        for j in range(i, i + CFG.mean_num):
            image = _read_grayscale(fragment_dir + f"surface_volume/{j:02}.tif").astype(np.float32)

            # NOTE: a dimension that is already a multiple of tile_size still
            # receives one full extra tile of padding (kept for compatibility).
            pad0 = (CFG.tile_size - image.shape[0] % CFG.tile_size)
            pad1 = (CFG.tile_size - image.shape[1] % CFG.tile_size)

            image = np.pad(image, [(0, pad0), (0, pad1)], constant_values=0)

            if averaged is None:
                # np.pad already returned a fresh array, so no extra copy is needed.
                averaged = image
            else:
                averaged += image
        images.append(averaged / CFG.mean_num)
    images = np.stack(images, axis=2)

    # Label and mask reuse the padding computed from the volume slices.
    label = _read_grayscale(fragment_dir + "inklabels.png")
    label = np.pad(label, [(0, pad0), (0, pad1)], constant_values=0)
    label = label.astype('float32')
    label /= 255.0

    mask = _read_grayscale(fragment_dir + "mask.png")
    mask = np.pad(mask, [(0, pad0), (0, pad1)], constant_values=0)
    mask = mask.astype('float32')
    mask /= 255.0

    CFG.shape_list.append(mask.shape)

    return images, label, mask

def get_train_valid_dataset(val_persent=0.05):
    """Tile fragments 1-3 and split the tiles randomly into train and validation.

    Each fragment is cut into overlapping ``CFG.tile_size`` squares (step
    ``CFG.stride``); only tiles that overlap the fragment mask are kept.

    Args:
        val_persent: Fraction of the kept tiles assigned to validation.

    Returns:
        Tuple ``(train_images, train_labels, valid_images, valid_labels,
        valid_positons)``. Each validation position is
        ``[x1, y1, x2, y2, fragment_index]`` with a zero-based fragment index.
    """
    images = []
    labels = []
    positions = []

    for fragment_id in range(1, 4):

        image, label, mask = read_image_mask(fragment_id)

        x1_list = list(range(0, image.shape[1]-CFG.tile_size+1, CFG.stride))
        y1_list = list(range(0, image.shape[0]-CFG.tile_size+1, CFG.stride))

        for y1 in y1_list:
            for x1 in x1_list:
                y2 = y1 + CFG.tile_size
                x2 = x1 + CFG.tile_size
                # The mask is non-negative, so "sum != 0" is the same as "any
                # non-zero pixel"; .any() avoids a Python-level loop per tile.
                if mask[y1:y2, x1:x2].any():
                    images.append(image[y1:y2, x1:x2])
                    labels.append(label[y1:y2, x1:x2, None])
                    positions.append([x1, y1, x2, y2, fragment_id - 1])

    indexs = list(range(len(images)))
    random.shuffle(indexs)
    n = int(val_persent * len(indexs))
    # Slice by explicit length: indexs[:-n] would be empty when n == 0 and
    # would leave the training set with no tiles at all.
    train_indexs = set(indexs[:max(len(indexs) - n, 0)])

    train_images = []
    train_labels = []
    valid_images = []
    valid_labels = []
    valid_positons = []
    for i, (image, label, position) in enumerate(zip(images, labels, positions)):
        if i in train_indexs:
            train_images.append(image)
            train_labels.append(label)
        else:
            valid_images.append(image)
            valid_labels.append(label)
            valid_positons.append(position)
    return train_images, train_labels, valid_images, valid_labels, valid_positons

In [3]:
def get_transforms(data, cfg):
    """Build the albumentations pipeline for the requested split.

    Args:
        data: Split name, either ``'train'`` or ``'valid'``.
        cfg: Object exposing ``train_aug_list`` and ``valid_aug_list``.

    Returns:
        An ``A.Compose`` wrapping the transform list of that split.

    Raises:
        ValueError: If ``data`` is not a known split name.
    """
    if data == 'train':
        return A.Compose(cfg.train_aug_list)
    if data == 'valid':
        return A.Compose(cfg.valid_aug_list)
    # Previously an unknown split fell through to `return aug` and failed with
    # an UnboundLocalError that did not mention the offending value.
    raise ValueError(f"Unknown split {data!r}; expected 'train' or 'valid'")

class SubvolumeDataset(data.Dataset):
    """Dataset of image tiles with optional labels and tile positions.

    Args:
        images: Sequence of image tiles.
        labels: Sequence of label tiles aligned with ``images``; only read
            when ``is_train`` is True.
        positions: Sequence of per-tile positions, or None/empty. Required
            when ``is_train`` is False.
        transform: Callable invoked with ``image=`` (and ``mask=`` in training
            mode) that returns a dict with ``'image'`` (and ``'mask'``), or
            None to return the raw arrays.
        is_train: When True items are ``(image, label, position)``; otherwise
            ``(image, position)``.
    """

    def __init__(self, images, labels, positions, transform, is_train):
        self.transform = transform
        self.images = images
        self.labels = labels
        self.is_train = is_train
        self.positions = positions

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image = self.images[index]
        if self.is_train:
            label = self.labels[index]
            if self.positions:
                position = np.array(self.positions[index])
            else:
                # No positions supplied: return a placeholder so that batches
                # can still be collated.
                position = np.zeros(1)
            if self.transform:
                transformed = self.transform(image=image, mask=label)
                image = transformed['image']
                label = transformed['mask']
            return image, label, position

        position = np.array(self.positions[index])
        if self.transform:
            # There is no label in inference mode; the previous code passed an
            # undefined `label` here and raised NameError.
            transformed = self.transform(image=image)
            image = transformed['image']
        return image, position
        
# IOU and Dice Score
def dice_coef(y_true, y_pred, thr=0.5, dim=(0, 1), epsilon=0.001):
    """Smoothed Dice score between a target and a thresholded prediction.

    ``y_pred`` is binarised with ``y_pred > thr``. Intersection and
    denominator are summed over ``dim`` and the resulting ratios are averaged
    over whatever dimensions remain.

    NOTE(review): with 4-D (N, C, H, W) inputs the default ``dim=(0, 1)``
    reduces over batch and channel, not over the spatial axes - confirm that
    this is intended.
    """
    target = y_true.to(torch.float32)
    binarised = (y_pred > thr).to(torch.float32)
    overlap = torch.sum(target * binarised, dim=dim)
    total = torch.sum(target, dim=dim) + torch.sum(binarised, dim=dim)
    ratio = (2 * overlap + epsilon) / (total + epsilon)
    return ratio.mean()


def iou_coef(y_true, y_pred, thr=0.5, dim=(0, 1), epsilon=0.001):
    """Smoothed intersection-over-union between a target and a thresholded prediction.

    ``y_pred`` is binarised with ``y_pred > thr``. Intersection and union are
    summed over ``dim`` and the resulting ratios are averaged over whatever
    dimensions remain.

    NOTE(review): with 4-D (N, C, H, W) inputs the default ``dim=(0, 1)``
    reduces over batch and channel, not over the spatial axes - confirm that
    this is intended.
    """
    target = y_true.to(torch.float32)
    binarised = (y_pred > thr).to(torch.float32)
    overlap = target * binarised
    intersection = overlap.sum(dim=dim)
    union = (target + binarised - overlap).sum(dim=dim)
    ratio = (intersection + epsilon) / (union + epsilon)
    return ratio.mean()

In [4]:
# Load the three fragments, tile them and split the tiles into train/validation
# with the default validation fraction of get_train_valid_dataset.
train_images, train_labels, valid_images, valid_labels, valid_positons = get_train_valid_dataset()
# Emit an empty line after the progress bars printed while loading.
print()

100%|██████████| 4/4 [00:22<00:00,  5.53s/it]
100%|██████████| 4/4 [00:30<00:00,  7.73s/it]
100%|██████████| 4/4 [00:05<00:00,  1.42s/it]





In [5]:
# Training tiles carry no positions; validation tiles keep theirs so that
# predictions can be stitched back into full-size masks. Both datasets run in
# "train" mode because validation also needs the labels.
train_dataset = SubvolumeDataset(
    images=train_images,
    labels=train_labels,
    positions=None,
    transform=get_transforms(data='train', cfg=CFG),
    is_train=True,
)
valid_dataset = SubvolumeDataset(
    images=valid_images,
    labels=valid_labels,
    positions=valid_positons,
    transform=get_transforms(data='valid', cfg=CFG),
    is_train=True,
)

# Shuffle and drop the incomplete last batch for training only.
train_loader = data.DataLoader(
    train_dataset,
    batch_size=CFG.train_batch_size,
    shuffle=True,
    num_workers=CFG.num_workers,
    pin_memory=True,
    drop_last=True,
)
valid_loader = data.DataLoader(
    valid_dataset,
    batch_size=CFG.valid_batch_size,
    shuffle=False,
    num_workers=CFG.num_workers,
    pin_memory=True,
    drop_last=False,
)

In [6]:
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_, DropPath
from mmcv.utils import get_logger
import logging
from mmcv.cnn import ConvModule

def resize(input,
           size=None,
           scale_factor=None,
           mode='nearest',
           align_corners=None,
           warning=True):
    """Wrapper around ``F.interpolate`` with an optional alignment warning.

    Args:
        input: Tensor whose dimensions from index 2 onwards are resized.
        size: Target spatial size, forwarded to ``F.interpolate``.
        scale_factor: Scale factor, forwarded to ``F.interpolate``.
        mode: Interpolation mode.
        align_corners: ``align_corners`` flag of ``F.interpolate``.
        warning: When True, warn if ``align_corners`` is set, the input is
            upsampled, and the sizes do not follow the ``x+1`` / ``nx+1``
            pattern.

    Returns:
        The interpolated tensor.
    """
    if warning:
        if size is not None and align_corners:
            input_h, input_w = tuple(int(x) for x in input.shape[2:])
            output_h, output_w = tuple(int(x) for x in size)
            # Upsampling check. The width used to be compared against
            # output_h instead of input_w, so the test was made on the wrong
            # pair of values.
            if output_h > input_h or output_w > input_w:
                if ((output_h > 1 and output_w > 1 and input_h > 1
                     and input_w > 1) and (output_h - 1) % (input_h - 1)
                        and (output_w - 1) % (input_w - 1)):
                    warnings.warn(
                        f'When align_corners={align_corners}, '
                        'the output would more aligned if '
                        f'input size {(input_h, input_w)} is `x+1` and '
                        f'out size {(output_h, output_w)} is `nx+1`')
    return F.interpolate(input, size, scale_factor, mode, align_corners)

def get_root_logger(log_file=None, log_level=logging.INFO):
    """Return the logger named ``"mmseg"``.

    Thin wrapper that delegates to ``get_logger`` with the fixed name
    ``'mmseg'`` and forwards both arguments unchanged.

    Args:
        log_file (str | None): Log filename passed through to ``get_logger``.
        log_level (int): Logging level passed through to ``get_logger``.

    Returns:
        logging.Logger: The logger returned by ``get_logger``.
    """
    return get_logger(name='mmseg', log_file=log_file, log_level=log_level)

class BayarConv2d(nn.Module):
    """Constrained convolution whose kernel centre is fixed to -1.

    Only ``kernel_size ** 2 - 1`` weights per (in, out) channel pair are
    trainable. Before every convolution they are rescaled to sum to one and
    the fixed -1 centre element is inserted between the two halves.
    """

    def __init__(self, in_channels, out_channels, kernel_size=5, stride=1, padding=2):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        # Constant centre element, one per (in, out) channel pair.
        self.minus1 = torch.ones(in_channels, out_channels, 1) * -1.000
        # Trainable off-centre elements.
        off_centre = kernel_size ** 2 - 1
        self.kernel = nn.Parameter(torch.rand(in_channels, out_channels, off_centre),
                                   requires_grad=True)

    def bayarConstraint(self):
        """Normalise the stored weights in place and return the full conv kernel."""
        # Rescale the trainable elements of every channel pair to sum to one.
        # These three statements are order-sensitive and kept as they were.
        self.kernel.data = self.kernel.permute(2, 0, 1)
        self.kernel.data = torch.div(self.kernel.data, self.kernel.data.sum(0))
        self.kernel.data = self.kernel.permute(1, 2, 0)

        centre = self.kernel_size ** 2 // 2
        head = self.kernel[:, :, :centre]
        tail = self.kernel[:, :, centre:]
        middle = self.minus1.to(self.kernel.device)
        full = torch.cat((head, middle, tail), dim=2)
        weight_shape = (self.out_channels, self.in_channels, self.kernel_size, self.kernel_size)
        return full.reshape(weight_shape)

    def forward(self, x):
        weight = self.bayarConstraint()
        return F.conv2d(x, weight, stride=self.stride, padding=self.padding)

class Block(nn.Module):
    r""" ConvNeXt residual block.

    Implemented as: depthwise 7x7 conv -> permute to (N, H, W, C) ->
    LayerNorm (channels_last) -> Linear -> GELU -> Linear -> optional
    per-channel scale -> permute back to (N, C, H, W) -> drop path ->
    residual add.

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale; values
            <= 0 disable the scale. Default: 1e-6.
    """
    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
        super().__init__()
        hidden_dim = 4 * dim
        # Depthwise convolution: one 7x7 filter per channel.
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.norm = LayerNorm(dim, eps=1e-6)
        # Pointwise (1x1) convolutions expressed as linear layers.
        self.pwconv1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(hidden_dim, dim)

        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
                                      requires_grad=True)
        else:
            self.gamma = None

        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        shortcut = x
        out = self.dwconv(x).permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        out = self.pwconv2(self.act(self.pwconv1(self.norm(out))))
        if self.gamma is not None:
            out = self.gamma * out
        out = out.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
        return shortcut + self.drop_path(out)

class LayerNorm(nn.Module):
    r""" Layer normalisation over the channel axis for two memory layouts.

    ``channels_last`` (default) expects inputs whose last dimension holds the
    channels, e.g. (batch_size, height, width, channels), and delegates to
    ``F.layer_norm``. ``channels_first`` expects
    (batch_size, channels, height, width) and normalises over dimension 1
    by hand.
    """
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        self.data_format = data_format
        self.eps = eps
        self.normalized_shape = (normalized_shape, )
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            # Mean and (biased) variance across the channel dimension.
            mean = x.mean(1, keepdim=True)
            centred = x - mean
            variance = centred.pow(2).mean(1, keepdim=True)
            normed = centred / torch.sqrt(variance + self.eps)
            # Broadcast the per-channel affine parameters over H and W.
            return self.weight[:, None, None] * normed + self.bias[:, None, None]

class ConvNeXt(nn.Module):
    r""" ConvNeXt backbone returning multi-scale feature maps.
        A PyTorch impl of : `A ConvNet for the 2020s`  -
          https://arxiv.org/pdf/2201.03545.pdf

    Args:
        in_chans (int): Number of input image channels. Default: 3
        bayar (bool): If True, a ``BayarConv2d`` is applied to the input and
            its output is concatenated with the input, which doubles the
            number of channels seen by the stem. Default: False
        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
        dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        out_indices (list(int)): Stages whose normalised output is returned.
            Default: [0, 1, 2, 3]
        pretrained (str | None): Default value used by ``init_weights``.
    """
    def __init__(self, in_chans=3, bayar=False, depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], 
                 drop_path_rate=0., layer_scale_init_value=1e-6, out_indices=[0, 1, 2, 3], pretrained=None
                 ):
        super().__init__()
        self.pretrained = pretrained
        self.in_chans = in_chans
        self.bayar = bayar
        if self.bayar:
            self.bayar_conv = BayarConv2d(in_chans, in_chans)
            # The stem receives the input concatenated with its Bayar response.
            self.in_chans = in_chans * 2

        self.downsample_layers = nn.ModuleList() # stem and the intermediate downsampling conv layers
        stem = nn.Sequential(
            nn.Conv2d(self.in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
        )
        self.downsample_layers.append(stem)
        for i in range(len(dims) - 1):
            downsample_layer = nn.Sequential(
                    LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                    nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList() # one feature resolution stage per entry of dims
        # Stochastic depth rate grows linearly from 0 to drop_path_rate over all blocks.
        dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 
        cur = 0
        for i in range(len(dims)):
            stage = nn.Sequential(
                *[Block(dim=dims[i], drop_path=dp_rates[cur + j], 
                layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.out_indices = out_indices

        # One output norm per stage, registered as norm0, norm1, ...
        norm_layer = partial(LayerNorm, eps=1e-6, data_format="channels_first")
        for i_layer in range(len(dims)):
            layer = norm_layer(dims[i_layer])
            layer_name = f'norm{i_layer}'
            self.add_module(layer_name, layer)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal weights and zero biases for conv and linear layers."""
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            # Layers created with bias=False have no bias tensor to reset.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Re-initialises the ``nn.Linear`` (and ``nn.LayerNorm``) modules.
        NOTE: no checkpoint is loaded here; when a path is given, a warning
        is logged so that this is not silent.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None, in which case ``self.pretrained`` is used.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if pretrained is None:
            pretrained = self.pretrained

        if pretrained is not None and not isinstance(pretrained, str):
            raise TypeError('pretrained must be a str or None')

        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        self.apply(_init_weights)
        if isinstance(pretrained, str):
            logger = get_root_logger()
            logger.warning(
                'init_weights received a checkpoint path (%s) but does not load it; '
                'load the state dict explicitly.', pretrained)

    def forward_features(self, x):
        """Run every stage and collect the normalised outputs of ``out_indices``."""
        if self.bayar:
            x_bayar = self.bayar_conv(x)
            x = torch.cat([x, x_bayar], dim=1)
        outs = []
        # Iterate over the stages that were actually built instead of a
        # hard-coded 4, so that dims/depths of any length work.
        for i in range(len(self.stages)):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                x_out = norm_layer(x)
                outs.append(x_out)

        return tuple(outs)

    def forward(self, x):
        """Return the tuple of feature maps selected by ``out_indices``."""
        x = self.forward_features(x)
        return x
    
class PPM(nn.ModuleList):
    """Pooling Pyramid Module used in PSPNet.

    Holds one branch per pooling scale. Each branch applies adaptive average
    pooling to that scale followed by a 1x1 ``ConvModule``.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
        align_corners (bool): align_corners argument of F.interpolate.
    """

    def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg,
                 act_cfg, align_corners, **kwargs):
        super().__init__()
        self.pool_scales = pool_scales
        self.align_corners = align_corners
        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        for scale in pool_scales:
            branch = nn.Sequential(
                nn.AdaptiveAvgPool2d(scale),
                ConvModule(
                    in_channels,
                    channels,
                    1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    **kwargs))
            self.append(branch)

    def forward(self, x):
        """Return every branch output, upsampled to the spatial size of ``x``."""
        target_size = x.size()[2:]
        return [
            resize(
                branch(x),
                size=target_size,
                mode='bilinear',
                align_corners=self.align_corners)
            for branch in self
        ]

class UPerHead(nn.Module):
    """Unified Perceptual Parsing for Scene Understanding.

    This head is the implementation of `UPerNet
    <https://arxiv.org/abs/1807.10221>`_.

    A pyramid pooling module is applied to the last input feature map, the
    remaining feature maps go through a top-down FPN, and all levels are
    resized to the resolution of the first level, concatenated, fused and
    classified by a 1x1 convolution.

    Args:
        in_channels (list[int]): Channel count of every input feature map.
        in_index (list[int]): Indices of the backbone outputs that are used.
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module applied on the last feature. Default: (1, 2, 3, 6).
        channels (int): Channel count used inside the head.
        dropout_ratio (float): Dropout probability before the classifier;
            values <= 0 disable dropout.
        num_classes (int): Number of output channels of the classifier.
        norm_cfg (dict): Norm layer config passed to every ``ConvModule``.
        align_corners (bool): align_corners argument used for every resize.
    """

    def __init__(self, 
                in_channels=[128, 256, 512, 1024],
                in_index=[0, 1, 2, 3],
                pool_scales=(1, 2, 3, 6),
                channels=512,
                dropout_ratio=0.1,
                num_classes=0,
                norm_cfg=dict(type='BN', requires_grad=True),
                align_corners=False):
        super(UPerHead, self).__init__()
        # PSP Module
        # Fixed to 'multiple_select': inputs are picked by in_index, not concatenated.
        self.input_transform = 'multiple_select'
        self.align_corners = align_corners
        self.in_index = in_index
        self.psp_modules = PPM(
            pool_scales,
            in_channels[-1],
            channels,
            conv_cfg=None,
            norm_cfg=norm_cfg,
            act_cfg=dict(type='ReLU'),
            align_corners=align_corners)
        # Fuses the last feature map with one pooled branch per pool scale.
        self.bottleneck = ConvModule(
            in_channels[-1] + len(pool_scales) * channels,
            channels,
            3,
            padding=1,
            conv_cfg=None,
            norm_cfg=norm_cfg,
            act_cfg=dict(type='ReLU'))
        # FPN Module
        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()
        for in_channel in in_channels[:-1]:  # skip the top layer
            # 1x1 conv that maps a backbone level to `channels`.
            l_conv = ConvModule(
                in_channel,
                channels,
                1,
                conv_cfg=None,
                norm_cfg=norm_cfg,
                act_cfg=dict(type='ReLU'),
                inplace=False)
            # 3x3 conv applied after the top-down merge.
            fpn_conv = ConvModule(
                channels,
                channels,
                3,
                padding=1,
                conv_cfg=None,
                norm_cfg=norm_cfg,
                act_cfg=dict(type='ReLU'),
                inplace=False)
            self.lateral_convs.append(l_conv)
            self.fpn_convs.append(fpn_conv)

        # Fuses the concatenation of all levels (len(in_channels) * channels).
        self.fpn_bottleneck = ConvModule(
            len(in_channels) * channels,
            channels,
            3,
            padding=1,
            conv_cfg=None,
            norm_cfg=norm_cfg,
            act_cfg=dict(type='ReLU'))
        if dropout_ratio > 0:
            self.dropout = nn.Dropout2d(dropout_ratio)
        else:
            self.dropout = None
        # Per-pixel classifier.
        self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)

    def psp_forward(self, inputs):
        """Forward function of PSP module.

        Concatenates the last input feature map with its pooled branches along
        the channel axis and fuses them with ``self.bottleneck``.
        """
        x = inputs[-1]
        psp_outs = [x]
        psp_outs.extend(self.psp_modules(x))
        psp_outs = torch.cat(psp_outs, dim=1)
        output = self.bottleneck(psp_outs)

        return output

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor | list[Tensor]: The transformed inputs. With the
            'multiple_select' mode set in ``__init__`` this is the list of
            features picked by ``self.in_index``.
        """

        if self.input_transform == 'resize_concat':
            # Resize every selected level to the first level and concatenate.
            inputs = [inputs[i] for i in self.in_index]
            upsampled_inputs = [
                resize(
                    input=x,
                    size=inputs[0].shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners) for x in inputs
            ]
            inputs = torch.cat(upsampled_inputs, dim=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs
    
    def cls_seg(self, feat):
        """Classify each pixel (optional dropout, then the 1x1 classifier)."""
        if self.dropout is not None:
            feat = self.dropout(feat)
        output = self.conv_seg(feat)
        return output

    def forward(self, inputs):
        """Forward function.

        Returns the classifier output at the spatial resolution of the first
        selected feature map.
        """
        inputs = self._transform_inputs(inputs)

        # build laterals
        laterals = [
            lateral_conv(inputs[i]) for i, lateral_conv in enumerate(self.lateral_convs)
        ]

        # The last level comes from the PSP module instead of a lateral conv.
        laterals.append(self.psp_forward(inputs))

        # build top-down path
        # Each level receives the upsampled level above it, from last to first.
        used_backbone_levels = len(laterals)
        for i in range(used_backbone_levels - 1, 0, -1):
            prev_shape = laterals[i - 1].shape[2:]
            laterals[i - 1] = laterals[i - 1] + resize(
                laterals[i],
                size=prev_shape,
                mode='bilinear',
                align_corners=self.align_corners)

        # build outputs
        fpn_outs = [
            self.fpn_convs[i](laterals[i])
            for i in range(used_backbone_levels - 1)
        ]
        # append psp feature
        fpn_outs.append(laterals[-1])

        # Bring every level to the resolution of the first one before fusing.
        for i in range(used_backbone_levels - 1, 0, -1):
            fpn_outs[i] = resize(
                fpn_outs[i],
                size=fpn_outs[0].shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)
        fpn_outs = torch.cat(fpn_outs, dim=1)
        output = self.fpn_bottleneck(fpn_outs)
        output = self.cls_seg(output)
        return output

In [7]:
# Backbone + decoder chained in one nn.Sequential: the tuple of feature maps
# returned by ConvNeXt is fed directly to UPerHead. The backbone takes
# CFG.in_chans // CFG.mean_num channels (one per averaged slice group), and its
# `dims` must match the head's `in_channels`.
model = nn.Sequential(ConvNeXt(in_chans=CFG.in_chans // CFG.mean_num,
                                     depths=[3, 3, 27, 3],
                                     dims=[128, 256, 512, 1024],
                                     drop_path_rate=0.4,
                                     layer_scale_init_value=1.0,
                                     out_indices=[0, 1, 2, 3]),
                            UPerHead(in_channels=[128, 256, 512, 1024],
                                     in_index=[0, 1, 2, 3],
                                     pool_scales=(1, 2, 3, 6),
                                     channels=512,
                                     dropout_ratio=0.1,
                                     num_classes=CFG.target_size,
                                     norm_cfg=dict(type='BN', requires_grad=True),
                                     align_corners=False)
                            ).to(CFG.device)

In [8]:
# Name used by valid_step when it builds checkpoint file names.
model_name = CFG.model_name
if CFG.pretrained:
    try:
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        checkpoint = torch.load(CFG.checpoint, map_location=CFG.device)
        # Partial load: copy only the entries whose key exists in the
        # checkpoint and keep the freshly initialised values for the rest.
        models_dict = model.state_dict()
        for model_part in models_dict:
            if model_part in checkpoint:
                models_dict[model_part] = checkpoint[model_part]
        model.load_state_dict(models_dict)
        print('Checkpoint loaded')
    except Exception as exc:
        # Best effort: training continues from scratch, but the reason is
        # reported. `Exception` (not a bare except) lets KeyboardInterrupt
        # and SystemExit propagate.
        print(f'Checkpoint not loaded: {exc}')

In [9]:
def train_step(train_loader, model, criterion, optimizer, writer, device, epoch, scheduler=None):
    """Train `model` for one epoch and log the mean loss to TensorBoard.

    Args:
        train_loader: Iterable yielding ``(image, label, position)`` batches.
        model: Network to train; its output is resized to the label size.
        criterion: Loss called as ``criterion(outputs, label)``.
        optimizer: Optimizer that is stepped once per batch.
        writer: TensorBoard writer receiving the ``Train/Loss`` scalar.
        device: Device that batches are moved to.
        epoch: Epoch number used as the logging step.
        scheduler: Optional learning-rate scheduler stepped once per batch.
            When omitted, a new ``CosineAnnealingLR`` is created on every
            call (the original behaviour), so its step counter restarts each
            epoch. Pass a scheduler created once outside the epoch loop to
            get a schedule that spans epochs.
    """
    model.train()
    epoch_loss = 0
    if scheduler is None:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader) * 10, eta_min=CFG.min_lr)
    bar = tqdm(enumerate(train_loader), total=len(train_loader)) 
    for step, (image, label, _) in bar:
        optimizer.zero_grad()
        outputs = model(image.to(device))
        # Bring the prediction to the label resolution before the loss.
        outputs = resize(input=outputs,
                            size=label.shape[2:],
                            mode='bilinear',
                            align_corners=False)
        loss = criterion(outputs, label.to(device))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.max_grad_norm)
        optimizer.step()
        scheduler.step()
        mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0
        # Read the LR from param_groups: optimizer.state_dict() rebuilds the
        # whole optimizer state on every call, which is wasteful per step.
        current_lr = optimizer.param_groups[0]["lr"]
        bar.set_postfix(loss=f'{loss.item():0.4f}', epoch=epoch ,gpu_mem=f'{mem:0.2f} GB', lr=f'{current_lr:0.2e}')
        epoch_loss += loss.item()
    # max(..., 1) avoids a ZeroDivisionError when the loader yields no batch.
    writer.add_scalar('Train/Loss', epoch_loss / max(len(train_loader), 1), epoch)

def valid_step(valid_loader, model, criterion, device, writer, epoch):
    """Evaluate `model`, log metrics and masks, and save a checkpoint when due.

    Predictions are thresholded with ``CFG.threshhold`` and pasted into
    full-size masks of the first three shapes in ``CFG.shape_list``. A
    checkpoint is written every fifth epoch, or when the mean IoU beats the
    best value seen so far. The best value is kept on the function object as
    ``valid_step.best_iou`` so that it survives across calls.

    Args:
        valid_loader: Iterable yielding ``(images, labels, positions)``
            batches, where each position is ``[x1, y1, x2, y2, fragment_index]``.
        model: Network to evaluate.
        criterion: Loss called as ``criterion(y_preds, labels)``.
        device: Device that batches are moved to.
        writer: TensorBoard writer.
        epoch: Epoch number used as the logging step and in the file name.
    """
    pred_label_list = []
    label_list = []
    for i in range(3):
        pred_label_list.append(torch.zeros(CFG.shape_list[i], device=device))
        label_list.append(torch.zeros(CFG.shape_list[i], device=device))
    model.eval()
    epoch_loss = 0
    dice_scores = []
    iou_scores = []

    for step, (images, labels, positions) in tqdm(enumerate(valid_loader), total=len(valid_loader)):
        images = images.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            y_preds = model(images)
            y_preds = resize(input=y_preds,
                            size=labels.shape[2:],
                            mode='bilinear',
                            align_corners=False)
            loss = criterion(y_preds, labels)
        # make whole mask
        y_preds = torch.sigmoid(y_preds)
        pred_img = (y_preds > CFG.threshhold).float()
        # Iterate over the batch dimension directly. The former
        # positions.squeeze() collapsed a batch of size 1 to a flat vector and
        # broke the unpacking below.
        for i, position in enumerate(positions):
            x1, y1, x2, y2, fragment_id = position.tolist()
            pred_label_list[fragment_id][y1:y2, x1:x2] = pred_img[i].squeeze(0)
            label_list[fragment_id][y1:y2, x1:x2] = labels[i].squeeze(0)
        epoch_loss += loss.item()
        # Per-batch metrics.
        # NOTE(review): both helpers are called with their default `dim`.
        dice_scores.append(dice_coef(labels, y_preds, thr=CFG.threshhold).item())
        iou_scores.append(iou_coef(labels, y_preds, thr=CFG.threshhold).item())

    # Convert the stitched masks to image grids and add them to TensorBoard.
    for i in range(len(pred_label_list)):
        pred_mask = make_grid(pred_label_list[i].to(device), normalize=True)
        true_mask = make_grid(label_list[i].to(device), normalize=True)
        writer.add_image('Valid/' + str(i) + '_True_mask', true_mask, global_step=epoch, dataformats="CHW")
        writer.add_image('Valid/' + str(i) + '_Pred_mask', pred_mask, global_step=epoch, dataformats="CHW")

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    mean_iou = sum(iou_scores) / max(len(iou_scores), 1)
    mean_dice = sum(dice_scores) / max(len(dice_scores), 1)
    mean_loss = epoch_loss / max(len(valid_loader), 1)

    writer.add_scalar('Val/IOU', mean_iou, epoch)
    writer.add_scalar('Val/Dice', mean_dice, epoch)
    writer.add_scalar('Valid/Loss', mean_loss, epoch)

    # The best score used to be a local reset to 0 on every call, which made
    # the comparison below true for any positive IoU.
    best_iou = getattr(valid_step, 'best_iou', 0)
    if epoch % 5 == 0 or mean_iou > best_iou:
        valid_step.best_iou = max(best_iou, mean_iou)
        os.makedirs(CFG.model_dir, exist_ok=True)
        torch.save(model.state_dict(), CFG.model_dir +  '{}-DIM-{}-[eval_loss]-{:.4f}-[dice_score]-{:.2f}-[iou_score]-{:.2f}-'.format(model_name, CFG.in_chans , mean_loss, mean_dice, mean_iou) + str(epoch) + '-epoch.pkl')

In [10]:
criterion = smp.losses.SoftBCEWithLogitsLoss()
optimizer = optim.AdamW(model.parameters(),
                        lr=CFG.lr,
                        betas=(0.9, 0.999),
                        weight_decay=CFG.weight_decay
                        )
writer = SummaryWriter('result/logs')
try:
    # Epochs are numbered from 1 in the logs and checkpoint names.
    for i in range(CFG.epochs):
        print('train:')
        train_step(train_loader, model, criterion, optimizer, writer, CFG.device, i + 1)
        print('val:')
        valid_step(valid_loader, model, criterion, CFG.device, writer, i + 1)
finally:
    # Flush and close the event file even when training is interrupted
    # (for example by a KeyboardInterrupt from the notebook).
    writer.close()

train:


 11%|█         | 86/782 [00:55<07:25,  1.56it/s, epoch=1, gpu_mem=8.29 GB, loss=0.5798, lr=1.00e-05]


KeyboardInterrupt: 