In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.nn.functional as F
import torch.optim as optim
from functools import partial
from torchsummary import summary
from einops.layers.torch import Rearrange
from torch.utils.tensorboard import SummaryWriter
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.decoders.unetplusplus.decoder import UnetPlusPlusDecoder
from segmentation_models_pytorch.base import SegmentationHead
from torch.cuda.amp import GradScaler
from torchvision.utils import make_grid
import warnings
import os
import random
import pandas as pd
import cv2
from tqdm import tqdm
import matplotlib.pyplot as plt
from model.ink_model import get_model
# 忽略所有警告
warnings.filterwarnings('ignore')

class CFG:
    """Experiment configuration, read as class attributes (e.g. ``CFG.lr``).

    Some attributes are runtime state rather than settings. For example, ``all_best_dice``
    is updated by the validation loop to track the best score seen so far.
    """
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Checkpoint used to warm-start the model when `pretrained` is True.
    # (The attribute name's spelling is kept because the loading code reads `CFG.checpoint`.)
    checpoint = 'result/resnet3d-seg-2d/resnet18/only_mask/3/Resnet3D-DIM-22-[eval_loss]-0.2345-[dice_score]-0.66-11-epoch.pkl'
    # ============== comp exp name =============
    comp_name = 'vesuvius'

    # Dataset location. On Kaggle use comp_dir_path = '/kaggle/input/' and
    # comp_folder_name = 'vesuvius-challenge-ink-detection'.
    comp_dir_path = ''
    comp_folder_name = 'data'
    comp_dataset_path = f'{comp_dir_path}{comp_folder_name}/'

    img_path = 'working/'

    encoder_name = 'convnext3d'  # resnet3d, r2plus1d_18, r3d_18, mc3_18, convnext3d, unet3d_down
    decoder_name = 'cnn'  # cnn, uper_head, unetplusplus, unet3d_up, cnn+uperhead
    # When True, Ink3DModel averages the outputs of several decoder heads.
    mix_up = False

    # ============== pred target =============
    target_size = 1  # number of output channels (passed to the model as `classes`)

    # ============== model cfg =============
    model_name = encoder_name + '-' + decoder_name

    # Surface-volume slice indices read for every fragment.
    in_idx = list(range(18, 38))  # other ranges tried: 21..43

    valid_id = 1  # hold-out fragment: 1, 2, 3, or 'random' for a random tile split

    rate_valid = 0.05  # fraction of tiles used for validation when valid_id == 'random'

    in_chans = len(in_idx)  # 20 with the default in_idx
    # ============== training cfg =============
    size = 224

    train_tile_size_1 = 224
    train_stride_1 = train_tile_size_1 // 4

    train_tile_size_2 = 224
    train_stride_2 = train_tile_size_2 // 2

    train_tile_size_3 = 224
    train_stride_3 = train_tile_size_3 // 4

    valid_tile_size = 224
    valid_stride = valid_tile_size // 2

    train_batch_size = 8  # 32
    valid_batch_size = 8
    use_amp = True

    inplanes = [64, 128, 256, 512]

    epochs = 30  # 30

    # lr = 1e-4 / warmup_factor
    lr = 1e-5

    # ============== fixed =============
    pretrained = False

    min_lr = 1e-6
    weight_decay = 1e-4
    max_grad_norm = 1000

    num_workers = 4

    seed = 42

    threshhold = 0.5  # (spelling kept for compatibility)

    # Runtime state updated during training.
    all_best_dice = 0
    # Builtin float: np.float was only an alias for it and was removed in NumPy 1.24.
    all_best_loss = float('inf')

    shape_list = []
    test_shape_list = []

    val_mask = None
    val_label = None

    # ============== augmentation =============
    train_aug_list = [
        # A.RandomCrop(height=size, width=size, p=0.5),
        A.Resize(size, size),
        A.Rotate(limit=90,  p=0.5),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.75),
        A.ShiftScaleRotate(p=0.75),
        A.OneOf([
                A.GaussNoise(var_limit=[10, 50]),
                A.GaussianBlur(),
                A.MotionBlur(),
                ], p=0.4),
        A.GridDistortion(num_steps=5, distort_limit=0.3, p=0.5),
        A.CoarseDropout(max_holes=1, max_width=int(size * 0.3), max_height=int(size * 0.3),
                        mask_fill_value=0, p=0.5),
        A.Normalize(
            mean= [0] * in_chans,
            std= [1] * in_chans
        ),
        ToTensorV2(transpose_mask=True),
    ]

    valid_aug_list = [
        A.Resize(size, size),
        A.Normalize(
            mean= [0] * in_chans,
            std= [1] * in_chans
        ),
        ToTensorV2(transpose_mask=True),
    ]
    test_aug_list = [
        A.Normalize(
            mean= [0] * in_chans,
            std= [1] * in_chans
        ),
        ToTensorV2(transpose_mask=True),
    ]
# Seed every RNG the notebook relies on so runs are repeatable.
seed = CFG.seed
# Only affects interpreters started after this point; the running process keeps its hash seed.
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Give up cuDNN autotuning in exchange for deterministic convolution kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [None]:
class Ink3DModel(nn.Module):
    """Encoder/decoder segmentation network whose parts come from ``get_model``.

    With ``mix_up=True``, ``self.decoder`` is iterated as a collection of decoder heads, and
    their outputs are averaged. This presumably expects an ``nn.ModuleList`` from
    ``get_model`` (TODO confirm). Otherwise ``self.decoder`` is called directly.
    """
    def __init__(self, encoder_name, decoder_name, mix_up=True, **kwargs):
        super().__init__()
        self.encoder_name, self.decoder_name, self.mix_up = encoder_name, decoder_name, mix_up
        self.encoder, self.decoder = get_model(
            encoder_name=encoder_name,
            decoder_name=decoder_name,
            inplanes=CFG.inplanes,
            **kwargs,
        )

    def forward(self, x):
        features = self.encoder(x)
        if not self.mix_up:
            return self.decoder(features)
        # Ensemble: element-wise mean of every decoder head's prediction.
        head_preds = [head(features) for head in self.decoder]
        return torch.stack(head_preds, dim=0).mean(dim=0)

class Ink3DUnet(smp.UnetPlusPlus):
    """UNet++ container whose encoder and decoder are replaced by project models from ``get_model``.

    ``smp.UnetPlusPlus.__init__`` builds its own encoder, which is overwritten immediately below.
    ``encoder_weights`` therefore defaults to ``None``, so no pretrained weights are fetched only
    to be thrown away (this also lets the notebook run offline). Pass ``encoder_weights``
    explicitly to override.

    Required kwargs: ``decoder_channels`` (its last entry sizes the segmentation head input) and
    ``classes``. The caller's kwargs are also forwarded unchanged to ``get_model``.
    """
    def __init__(self, encoder_name, decoder_name, **kwargs):
        super(Ink3DUnet, self).__init__(**{'encoder_weights': None, **kwargs})
        self.encoder_name = encoder_name
        self.encoder, self.decoder = get_model(encoder_name=encoder_name, decoder_name=decoder_name, **kwargs)
        self.segmentation_head = SegmentationHead(
            in_channels=kwargs['decoder_channels'][-1],
            out_channels=kwargs['classes'],
            activation=None,
            kernel_size=3,
            upsampling=2
        )

In [None]:
# Build the network: the UNet++ wrapper for the 'unetplusplus' decoder, the generic encoder/decoder otherwise.
if CFG.decoder_name == 'unetplusplus':
    model = Ink3DUnet(encoder_name=CFG.encoder_name,
                      decoder_name=CFG.decoder_name,
                      classes=CFG.target_size,
                      decoder_attention_type='scse',
                      encoder_depth=4,
                      decoder_channels=CFG.inplanes[::-1])
else:
    model = Ink3DModel(encoder_name=CFG.encoder_name, decoder_name=CFG.decoder_name,
                       classes=CFG.target_size, mix_up=CFG.mix_up)
print(model)

model = nn.DataParallel(model, device_ids=[0])
model = model.to(CFG.device)

# Optional warm start. Only keys present in both the checkpoint and the model are copied.
# Checkpoints saved by valid_step come from the DataParallel-wrapped model, so their keys match.
# torch.load unpickles the file: only load checkpoints you trust.
if CFG.pretrained:
    try:
        checkpoint = torch.load(CFG.checpoint, map_location=CFG.device)
        models_dict = model.state_dict()
        matched_keys = [key for key in models_dict if key in checkpoint]
        for key in matched_keys:
            models_dict[key] = checkpoint[key]
        model.load_state_dict(models_dict)
        print('Checkpoint loaded')
        print(f'{len(matched_keys)}/{len(models_dict)} tensors taken from the checkpoint')
    except Exception as err:
        # Best-effort loading, but not a bare `except`: KeyboardInterrupt/SystemExit still
        # propagate, and the reason for the failure is shown instead of being swallowed.
        print(f'Checkpoint not loaded: {err!r}')

In [None]:
def _imread_gray(path):
    """Read ``path`` as a single-channel image; raise FileNotFoundError instead of returning None."""
    img = cv2.imread(path, 0)
    if img is None:
        raise FileNotFoundError(f'could not read image: {path}')
    return img


def read_image_mask(fragment_id, tile_size):
    """Load one training fragment: surface-volume slices, ink labels and fragment mask.

    Every array is zero-padded on the bottom/right by ``tile_size - dim % tile_size``. This
    adds a whole extra tile when ``dim`` is already a multiple of ``tile_size``; the formula is
    kept as-is because the validation-mask cell pads the same way and the shapes must agree.

    Args:
        fragment_id: fragment folder name under ``{CFG.comp_dataset_path}train/``.
        tile_size: tile edge length the padding is rounded to.

    Returns:
        images: slices ``CFG.in_idx`` stacked on the last axis, shape (H_pad, W_pad, len(in_idx)).
        mask: ink labels as float32 scaled by 1/255.
        mask_location: fragment mask scaled by 1/255.

    Raises:
        FileNotFoundError: if any expected image cannot be read.
        ValueError: if ``CFG.in_idx`` is empty.
    """
    if not CFG.in_idx:
        raise ValueError('CFG.in_idx is empty: no surface-volume slices to read')

    fragment_dir = CFG.comp_dataset_path + f"train/{fragment_id}/"

    images = []
    for i in tqdm(CFG.in_idx):
        image = _imread_gray(fragment_dir + f"surface_volume/{i:02}.tif")
        pad0 = (tile_size - image.shape[0] % tile_size)
        pad1 = (tile_size - image.shape[1] % tile_size)
        images.append(np.pad(image, [(0, pad0), (0, pad1)], constant_values=0))
    images = np.stack(images, axis=2)

    mask = _imread_gray(fragment_dir + "inklabels.png")
    mask = np.pad(mask, [(0, pad0), (0, pad1)], constant_values=0)
    mask = mask.astype('float32') / 255.0

    mask_location = _imread_gray(fragment_dir + "mask.png")
    mask_location = np.pad(mask_location, [(0, pad0), (0, pad1)], constant_values=0)
    mask_location = mask_location / 255

    return images, mask, mask_location

In [None]:
def get_random_train_valid_dataset(rate_valid=0.05):
    """Tile all three fragments and split the tiles randomly into train and validation sets.

    Each fragment uses its ``CFG.train_tile_size_N`` / ``CFG.train_stride_N``. Tiles with no
    pixel inside the fragment mask are dropped. Because stride < tile size, tiles overlap, so
    a random tile split lets the same pixels appear in both sets.

    Args:
        rate_valid: fraction of tiles (rounded down) that goes to validation.

    Returns:
        train_images, train_masks, valid_images, valid_masks, valid_xyxys (all lists).
        ``valid_xyxys`` are per-fragment coordinates and cannot be stitched onto one canvas.
    """
    images, masks, xyxys = [], [], []
    for fragment_id in range(1, 4):
        tile_size = getattr(CFG, f'train_tile_size_{fragment_id}')
        stride = getattr(CFG, f'train_stride_{fragment_id}')

        image, mask, mask_location = read_image_mask(fragment_id, tile_size)
        for y1 in range(0, image.shape[0] - tile_size + 1, stride):
            for x1 in range(0, image.shape[1] - tile_size + 1, stride):
                y2 = y1 + tile_size
                x2 = x1 + tile_size
                if np.sum(mask_location[y1:y2, x1:x2]) == 0:
                    continue
                images.append(image[y1:y2, x1:x2])
                masks.append(mask[y1:y2, x1:x2, None])
                xyxys.append([x1, y1, x2, y2])

    # Shuffle the (image, mask, xyxy) triples together so they stay aligned.
    tiles = list(zip(images, masks, xyxys))
    random.shuffle(tiles)

    # Slice at an explicit split index: with `images[:-n]` / `images[-n:]`, n == 0 would give an
    # empty training set and put every tile into validation.
    n_valid = int(len(tiles) * rate_valid)
    split = len(tiles) - n_valid
    train_tiles, valid_tiles = tiles[:split], tiles[split:]

    train_images = [img for img, _, _ in train_tiles]
    train_masks = [msk for _, msk, _ in train_tiles]
    valid_images = [img for img, _, _ in valid_tiles]
    valid_masks = [msk for _, msk, _ in valid_tiles]
    valid_xyxys = [xy for _, _, xy in valid_tiles]

    return train_images, train_masks, valid_images, valid_masks, valid_xyxys

def get_train_valid_dataset():
    """Tile the three training fragments, holding out fragment ``CFG.valid_id`` for validation.

    Training fragments use their ``train_tile_size_N`` / ``train_stride_N``; the held-out
    fragment uses ``valid_tile_size`` / ``valid_stride``. Tiles with no pixel inside the
    fragment mask are dropped. For ``valid_id == 2``, only tiles whose bottom-right corner lies
    in a fixed window are kept (the same window ``valid_step`` crops to).

    Returns:
        train_images, train_masks, valid_images, valid_masks, valid_xyxys. Each xyxy is
        ``[x1, y1, x2, y2]`` in padded-fragment pixel coordinates.
    """
    train_images, train_masks = [], []
    valid_images, valid_masks, valid_xyxys = [], [], []

    # Fragment-2 validation window, as inclusive bounds on a tile's (y2, x2).
    win_y_lo, win_y_hi = 4800, 4800 + 4096 + 2048
    win_x_lo, win_x_hi = 640, 640 + 4096 + 2048

    for fragment_id in range(1, 4):
        holdout = fragment_id == CFG.valid_id
        if holdout:
            tile_size, stride = CFG.valid_tile_size, CFG.valid_stride
        else:
            tile_size = getattr(CFG, f'train_tile_size_{fragment_id}')
            stride = getattr(CFG, f'train_stride_{fragment_id}')

        image, mask, mask_location = read_image_mask(fragment_id, tile_size)
        height, width = image.shape[0], image.shape[1]
        for y1 in range(0, height - tile_size + 1, stride):
            y2 = y1 + tile_size
            for x1 in range(0, width - tile_size + 1, stride):
                x2 = x1 + tile_size
                if np.sum(mask_location[y1:y2, x1:x2]) == 0:
                    continue

                if not holdout:
                    train_images.append(image[y1:y2, x1:x2])
                    train_masks.append(mask[y1:y2, x1:x2, None])
                    continue

                inside_window = win_y_lo <= y2 <= win_y_hi and win_x_lo <= x2 <= win_x_hi
                if CFG.valid_id == 2 and not inside_window:
                    continue
                valid_images.append(image[y1:y2, x1:x2])
                valid_masks.append(mask[y1:y2, x1:x2, None])
                valid_xyxys.append([x1, y1, x2, y2])

    return train_images, train_masks, valid_images, valid_masks, valid_xyxys

In [None]:
def get_transforms(data, cfg):
    """Build the albumentations pipeline for one split.

    Args:
        data: ``'train'``, ``'valid'`` or ``'test'``; selects ``cfg.<data>_aug_list``.
        cfg: config object holding the augmentation lists.

    Returns:
        ``A.Compose`` over the selected list.

    Raises:
        ValueError: for any other ``data`` value. Previously this fell through to an
            UnboundLocalError.
    """
    if data not in ('train', 'valid', 'test'):
        raise ValueError(f"unknown split {data!r}; expected 'train', 'valid' or 'test'")
    return A.Compose(getattr(cfg, f'{data}_aug_list'))

class Ink_Detection_Dataset(data.Dataset):
    """Tile dataset yielding ``(image, label)`` pairs.

    Args:
        images: sequence of image tiles (the tiling helpers produce (H, W, C) arrays).
        cfg: config object; ``cfg.encoder_name`` decides the output layout.
        labels: sequence of label tiles aligned with ``images`` (required by ``__getitem__``).
        transform: albumentations-style callable ``transform(image=..., mask=...) -> dict``.
            When given, 2D encoders receive the transformed image as-is. All other encoders get
            an extra leading axis via ``unsqueeze(0)``.
    """
    # Encoders that consume the transformed image without an added leading axis.
    _ENCODERS_2D = ('swin', 'convnext')

    def __init__(self, images, cfg, labels=None, transform=None):
        self.images = images
        self.cfg = cfg
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]

        if self.transform:
            # Named `augmented` so it does not shadow the `torch.utils.data as data` alias.
            augmented = self.transform(image=image, mask=label)
            image = augmented['image']
            # Use the config this dataset was built with, not the global CFG.
            if self.cfg.encoder_name not in self._ENCODERS_2D:
                image = image.unsqueeze(0)
            label = augmented['mask']

        return image, label

In [None]:
# Build the tile lists: a held-out fragment (stitchable coordinates) or a random tile split.
if CFG.valid_id != "random":
    train_images, train_masks, valid_images, valid_masks, valid_xyxys = get_train_valid_dataset()
    valid_xyxys = np.stack(valid_xyxys)
else:
    # Random-split coordinates span several fragments, so they are discarded.
    train_images, train_masks, valid_images, valid_masks, _ = get_random_train_valid_dataset(CFG.rate_valid)

In [None]:
# Tile coordinates are only meaningful for a held-out fragment; random splits have none to stitch.
valid_xyxys = None if CFG.valid_id == 'random' else np.stack(valid_xyxys)

In [None]:
train_dataset = Ink_Detection_Dataset(
    train_images,
    CFG,
    labels=train_masks,
    transform=get_transforms(data='train', cfg=CFG),
)
valid_dataset = Ink_Detection_Dataset(
    valid_images,
    CFG,
    labels=valid_masks,
    transform=get_transforms(data='valid', cfg=CFG),
)

# Training shuffles and drops the ragged last batch. Validation keeps every tile in order,
# because valid_step maps batch positions back onto valid_xyxys.
train_loader = data.DataLoader(
    train_dataset,
    batch_size=CFG.train_batch_size,
    shuffle=True,
    num_workers=CFG.num_workers,
    pin_memory=True,
    drop_last=True,
)
valid_loader = data.DataLoader(
    valid_dataset,
    batch_size=CFG.valid_batch_size,
    shuffle=False,
    num_workers=CFG.num_workers,
    pin_memory=True,
    drop_last=False,
)

In [None]:
from warmup_scheduler import GradualWarmupScheduler


class GradualWarmupSchedulerV2(GradualWarmupScheduler):
    """Warm-up LR scheduler that hands over to ``after_scheduler`` once warm-up ends.

    Adapted from https://www.kaggle.com/code/underwearfitting/single-fold-training-of-resnet200d-lb0-965

    While ``last_epoch <= total_epoch`` the LR ramps linearly. With ``multiplier == 1`` it goes
    from 0 to ``base_lr``; otherwise it goes from ``base_lr`` to ``base_lr * multiplier``. After
    warm-up, ``after_scheduler`` (if any) is re-based once to ``base_lr * multiplier`` and its
    ``get_lr()`` is used. Without one, the LR stays at ``base_lr * multiplier``.
    """
    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        super().__init__(optimizer, multiplier, total_epoch, after_scheduler)

    def get_lr(self):
        if self.last_epoch <= self.total_epoch:
            # Still warming up.
            if self.multiplier == 1.0:
                return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
            ramp = (self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.
            return [base_lr * ramp for base_lr in self.base_lrs]

        # Warm-up is over.
        if not self.after_scheduler:
            return [base_lr * self.multiplier for base_lr in self.base_lrs]
        if not self.finished:
            # Hand over exactly once: the follow-up schedule starts from the warmed-up LR.
            self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
            self.finished = True
        return self.after_scheduler.get_lr()

def get_scheduler(cfg, optimizer):
    """One warm-up epoch up to 10x the optimizer's LR, then cosine annealing over ``cfg.epochs``.

    The cosine floor is hard-coded to 1e-7; ``cfg.min_lr`` is not used here.
    """
    cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, cfg.epochs, eta_min=1e-7)
    return GradualWarmupSchedulerV2(optimizer, multiplier=10, total_epoch=1, after_scheduler=cosine)

def scheduler_step(scheduler, avg_val_loss, epoch):
    """Advance ``scheduler`` to ``epoch``.

    ``avg_val_loss`` is accepted for a plateau-style call signature but is not used.
    """
    advance = scheduler.step
    advance(epoch)

def dice_coef(targets, preds, thr=0.5, beta=0.5, smooth=1e-5):
    """F-beta score of thresholded predictions (beta=0.5 weights precision over recall).

    Args:
        targets: ground-truth tensor. True/false positives are counted where it equals 1/0;
            recall's denominator is ``targets.sum()``.
        preds: probability tensor (callers apply sigmoid first), binarized with ``preds > thr``.
        thr: binarization threshold.
        beta: F-beta weight.
        smooth: small constant that keeps every denominator non-zero.

    Returns:
        0-d tensor holding the score.
    """
    pred_flat = (preds > thr).view(-1).float()
    true_flat = targets.view(-1).float()
    beta_sq = beta * beta

    true_pos = pred_flat[true_flat == 1].sum()
    false_pos = pred_flat[true_flat == 0].sum()
    actual_pos = true_flat.sum()

    precision = true_pos / (true_pos + false_pos + smooth)
    recall = true_pos / (actual_pos + smooth)
    return (1 + beta_sq) * (precision * recall) / (beta_sq * precision + recall + smooth)

In [None]:
def TTA(x: torch.Tensor, model: nn.Module):
    """Rotation test-time augmentation: average sigmoid predictions over 0/90/180/270 degree rotations.

    All four rotations (over the last two axes) go through ``model`` in one batch. The outputs
    are reshaped to (4, B, 1, H, W) using the input's last two dims, so the model must return
    one channel at the input resolution. They are rotated back, passed through sigmoid, and
    averaged.
    """
    batch = x.shape[0]
    spatial = x.shape[-2:]
    rotated = torch.cat([torch.rot90(x, k=k, dims=(-2, -1)) for k in range(4)], dim=0)
    logits = model(rotated).reshape(4, batch, 1, *spatial)
    restored = torch.stack([torch.rot90(logits[k], k=-k, dims=(-2, -1)) for k in range(4)], dim=0)
    return torch.sigmoid(restored).mean(0)

In [None]:
def train_step(train_loader, model, criterion, optimizer, writer, device, epoch, scaler=None):
    """Train ``model`` for one epoch, log the mean loss to TensorBoard and return it.

    Mixed precision (``CFG.use_amp``, CUDA only): the forward pass and loss run under autocast,
    and the loss is scaled by a GradScaler. Gradients are unscaled before clipping, so
    ``CFG.max_grad_norm`` is compared with the true gradient norm rather than the scaled one.

    Args:
        train_loader: iterable of ``(image, label)`` batches.
        model: network to train (switched to train mode).
        criterion: loss callable ``criterion(outputs, label)``.
        optimizer: optimizer stepped once per batch.
        writer: TensorBoard writer receiving ``Train/Loss``.
        device: device the batches are moved to.
        epoch: epoch number used for logging.
        scaler: optional GradScaler to reuse across epochs (keeps its learned loss scale).
            A fresh one is created when omitted.

    Returns:
        Mean batch loss (0.0 for an empty loader).
    """
    model.train()
    amp_enabled = CFG.use_amp and torch.device(device).type == 'cuda'
    if scaler is None:
        scaler = GradScaler(enabled=amp_enabled)

    epoch_loss = 0.0
    n_batches = len(train_loader)
    bar = tqdm(enumerate(train_loader), total=n_batches)
    for step, (image, label) in bar:
        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=amp_enabled):
            outputs = model(image.to(device))
            loss = criterion(outputs, label.to(device))
        scaler.scale(loss).backward()
        # Gradients are still multiplied by the loss scale here; unscale before clipping.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.max_grad_norm)
        scaler.step(optimizer)
        scaler.update()

        mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0
        current_lr = optimizer.param_groups[0]["lr"]
        bar.set_postfix(loss=f'{loss.item():0.4f}', epoch=epoch, gpu_mem=f'{mem:0.2f} GB', lr=f'{current_lr:0.2e}')
        epoch_loss += loss.item()

    mean_loss = epoch_loss / max(n_batches, 1)
    writer.add_scalar('Train/Loss', mean_loss, epoch)
    return mean_loss

def _pick_best_threshold(dice_scores):
    """Return ``(best_dice, best_th)``, the threshold whose mean recorded dice is highest.

    Thresholds are scanned in insertion order. Only a strictly higher mean replaces the
    current best, so ties keep the earlier threshold. Thresholds with no scores are skipped.
    The result starts at ``(0, 0)``.
    """
    best_dice, best_th = 0, 0
    for th, scores in dice_scores.items():
        if not scores:
            continue
        mean_dice = sum(scores) / len(scores)
        if mean_dice > best_dice:
            best_dice, best_th = mean_dice, th
    return best_dice, best_th


def valid_step(valid_loader, model, valid_xyxys, valid_mask , criterion, device, writer, epoch):
    """Evaluate on the validation set, log loss and dice, and save a checkpoint.

    Held-out fragment (``CFG.valid_id != 'random'``): sigmoid tile predictions are stitched onto
    a full-fragment canvas at ``valid_xyxys``. This relies on ``valid_loader`` iterating in
    dataset order. Overlaps are averaged, and dice is computed once per threshold. The binarized
    best prediction and the ground truth are written under ``result/logs/``.
    Random split: dice is averaged over batches.
    In both modes the model state is saved on every call.

    Returns:
        Mean validation loss over batches (0.0 for an empty loader).
    """
    model.eval()
    stitch = CFG.valid_id != 'random'
    thresholds = np.arange(3, 7, 0.5) / 10

    if stitch:
        mask_pred = np.zeros(valid_mask.shape)
        # Pixels outside the fragment start at count 1 so they never divide by zero.
        mask_count = (1 - valid_mask).astype(np.float64)
        valid_mask_gt = np.zeros(valid_mask.shape)

    epoch_loss = 0
    dice_scores = {th: [] for th in thresholds}

    bar = tqdm(enumerate(valid_loader), total=len(valid_loader))
    for step, (image, label) in bar:
        image = image.to(device)
        label = label.to(device)
        with torch.no_grad():
            y_pred = model(image)
            loss = criterion(y_pred, label)
        mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0
        bar.set_postfix(loss=f'{loss.item():0.4f}', epoch=epoch ,gpu_mem=f'{mem:0.2f} GB')

        if stitch:
            # Accumulate this batch's tiles onto the whole-fragment canvas.
            probs = torch.sigmoid(y_pred).to('cpu').numpy()
            labels_np = label.to('cpu').numpy()
            start_idx = step * CFG.valid_batch_size
            batch_xyxys = valid_xyxys[start_idx:start_idx + CFG.valid_batch_size]
            for i, (x1, y1, x2, y2) in enumerate(batch_xyxys):
                mask_pred[y1:y2, x1:x2] += probs[i].squeeze(0)
                valid_mask_gt[y1:y2, x1:x2] = labels_np[i].squeeze(0)
                mask_count[y1:y2, x1:x2] += 1
        else:
            y_pred = torch.sigmoid(y_pred)
            for th in thresholds:
                dice_scores[th].append(dice_coef(label, y_pred, thr=th).item())
        epoch_loss += loss.item()

    avg_loss = epoch_loss / max(len(valid_loader), 1)
    writer.add_scalar('Valid/Loss', avg_loss, epoch)

    if stitch:
        print(f'mask_count_min: {mask_count.min()}')
        # Average overlapping tiles. In-fragment pixels no tile covered (count 0, e.g. outside the
        # fragment-2 window) stay 0 instead of becoming 0/0 = NaN.
        mask_pred = np.divide(mask_pred, mask_count, out=np.zeros_like(mask_pred), where=mask_count > 0)
        mask_pred *= valid_mask
        has_nan = np.isnan(mask_pred).any()
        print(has_nan)
        if CFG.valid_id == 2:
            # Crop to the validation window to avoid running out of memory
            # (matches the tile filter in get_train_valid_dataset).
            valid_mask_gt = valid_mask_gt[4800:4800+4096+2048, 640:640+4096+2048]
            mask_pred = mask_pred[4800:4800+4096+2048, 640:640+4096+2048]
            valid_mask = valid_mask[4800:4800+4096+2048, 640:640+4096+2048]
        gt_tensor = torch.from_numpy(valid_mask_gt).to(device)
        pred_tensor = torch.from_numpy(mask_pred).to(device)
        for th in thresholds:
            dice_scores[th].append(dice_coef(gt_tensor, pred_tensor, thr=th).item())
        best_dice, best_th = _pick_best_threshold(dice_scores)
        # uint8 is a depth cv2.imwrite accepts; int64 from astype(int) is not.
        binary_pred = (mask_pred >= best_th).astype(np.uint8)
        cv2.imwrite(f'result/logs/{epoch}.png', binary_pred * 255)
        cv2.imwrite(f'result/logs/gt.png', (valid_mask_gt * 255).astype(np.uint8))
    else:
        print(dice_scores.keys())
        print('slice dice:')
        best_dice, best_th = _pick_best_threshold(dice_scores)

    print(best_dice, best_th)
    if CFG.all_best_dice < best_dice:
        print('best_th={:2f}' .format(best_th),"score up: {:2f}->{:2f}".format(CFG.all_best_dice, best_dice))
        CFG.all_best_dice = best_dice
    torch.save(model.state_dict(), 'result/' +  '{}-DIM-{}-[eval_loss]-{:.4f}-[dice_score]-{:.2f}-'.format(CFG.model_name, CFG.in_chans , avg_loss, best_dice) + str(epoch) + '-epoch.pkl')
    writer.add_scalar('Valid/Dice', best_dice, epoch)

    return avg_loss
    

In [None]:
if CFG.valid_id == 'random':
    # A random tile split has no single fragment canvas to stitch onto.
    valid_mask = None
else:
    fragment_id = CFG.valid_id
    # Fragment footprint scaled by 1/255, padded exactly like read_image_mask pads the held-out
    # fragment (it is read with valid_tile_size), so the stitching canvases line up.
    valid_mask = cv2.imread(CFG.comp_dataset_path + f"train/{fragment_id}/mask.png", 0).astype('float32') / 255.
    pad0 = CFG.valid_tile_size - valid_mask.shape[0] % CFG.valid_tile_size
    pad1 = CFG.valid_tile_size - valid_mask.shape[1] % CFG.valid_tile_size
    valid_mask = np.pad(valid_mask, [(0, pad0), (0, pad1)], constant_values=0)

In [None]:
import gc

# Loss on raw logits (the model has no final sigmoid), AdamW, and warm-up + cosine LR schedule.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.AdamW(model.parameters(),
                        lr=CFG.lr,
                        betas=(0.9, 0.999),
                        weight_decay=CFG.weight_decay
                        )
scheduler = get_scheduler(CFG, optimizer)
writer = SummaryWriter('result/logs')

try:
    for i in range(CFG.epochs):
        print('train:')
        train_step(train_loader, model, criterion, optimizer, writer, CFG.device, i + 1)
        print('val:')
        val_loss = valid_step(valid_loader, model, valid_xyxys, valid_mask, criterion, CFG.device, writer, i + 1)
        scheduler_step(scheduler, val_loss, i + 1)
        gc.collect()
finally:
    # Flush buffered TensorBoard events even if training is interrupted.
    writer.close()