In [None]:
import gc
import os
import random
import sys
from glob import glob

import cupy as cp
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda import amp
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# resnet3d is loaded from this directory, so the directory has to be on
# sys.path *before* the import statement runs.
sys.path.append("/kaggle/input/resnet3d/")
from resnet3d import generate_model

In [None]:
!pip install gdown
!gdown 1Nb4abvIkkp_ydPFA9sNPT1WakoVKA8Fa

!pip install zarr imageio-ffmpeg
!mkdir ./ckpts

In [None]:
class CFG:
    """Static configuration shared by the cells of this notebook."""

    # --- data ---------------------------------------------------------
    dataset_path = '/kaggle/input/vesuvius-challenge-ink-detection/'
    input_channels = 16   # number of surface-volume slices read per fragment
    load_channels = 16    # not referenced by the code visible in this notebook

    # --- model --------------------------------------------------------
    model_name = 'Unet'   # only printed by build_model
    backbone = 'resnet3d'
    pretrained = True     # not referenced by the code visible in this notebook
    target_size = 1       # number of output classes for the smp.Unet branches

    # --- tiling / loading ---------------------------------------------
    prd = 192             # tile edge length in pixels; images are padded using it
    stride = prd // 8     # step between neighbouring tiles (overlapping windows)
    batch_size = 24
    num_workers = 2

In [None]:
def RLE(img):
    """Run-length encode a binary mask.

    The mask is flattened with ``img.flatten()`` and every run of non-zero
    pixels is reported as ``"start length"``, where ``start`` is the 1-based
    position in the flattened array. Pairs are joined by single spaces; an
    all-zero mask gives an empty string.
    """
    # A zero on each side guarantees that every run has both a rising and a
    # falling edge, so the edges always come in (start, end) pairs.
    padded = np.concatenate([[0], img.flatten(), [0]])
    edges = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Turn each end position into a run length (end - start), in place.
    edges[1::2] -= edges[::2]
    return ' '.join(map(str, edges))

def batchnorm(x):
    """Standardise every sample of a batch independently.

    Mean and standard deviation are computed over all dimensions except the
    first one, so each ``x[i]`` ends up with roughly zero mean and unit
    standard deviation. Despite the name, no statistics are shared across
    the batch and nothing is learned.
    """
    per_sample_dims = [axis for axis in range(x.ndim) if axis != 0]
    centered = x - x.mean(dim=per_sample_dims, keepdim=True)
    # The small constant avoids a division by zero for constant samples.
    spread = x.std(dim=per_sample_dims, keepdim=True) + 1e-9
    return centered / spread

In [None]:
# Finite-difference stencils used by fourier_derivative:
#   "xx" - 1x3 second difference along axis 1
#   "yy" - 3x1 second difference along axis 0
#   "xy" - 2x2 mixed difference
delta_dict = dict(
    xx=cp.array([[1, -2, 1]], dtype=float),
    yy=cp.array([[1], [-2], [1]], dtype=float),
    xy=cp.array([[1, -1], [-1, 1]], dtype=float),
)

def fourier_derivative(img_shape, pair):
    """Return the squared magnitude spectrum of one difference stencil.

    ``pair`` selects a stencil from ``delta_dict`` ("xx", "yy" or "xy").
    The stencil is transformed with an FFT of shape ``img_shape`` and
    multiplied by its own complex conjugate; the result is a complex array
    of shape ``img_shape``.
    """
    spectrum = cp.fft.fftn(delta_dict[pair], img_shape)
    return spectrum * cp.conj(spectrum)

def soft_threshold(vec, thresh):
    """Shrink every element of ``vec`` towards zero by ``thresh``.

    Elements whose magnitude is at most ``thresh`` become zero; the others
    keep their sign and lose ``thresh`` in magnitude.
    """
    shrunk_magnitude = cp.maximum(cp.abs(vec) - thresh, 0)
    return cp.sign(vec) * shrunk_magnitude

def back_grad(input_image, dim):
    """Backward finite difference of a 2-D array along axis ``dim``.

    Along ``dim`` the result is ``x[i] - x[i - 1]`` for ``i >= 1`` and ``0``
    for ``i == 0``. The output is a float array with the same shape as the
    input.

    The previous implementation produced the same values by copying the
    image into two (r + 1, n + 1) scratch buffers and slicing them with
    bounds stored in CuPy arrays; every such bound has to be converted to a
    host integer. Plain Python slices avoid both the scratch buffers and the
    conversions.

    Args:
        input_image: 2-D array.
        dim: axis to differentiate along (0 or 1).
    """
    image = cp.asarray(input_image, dtype=float)
    rows, cols = image.shape  # also enforces that the input is 2-D
    grad = cp.zeros((rows, cols), dtype=float)

    current = [slice(None), slice(None)]
    previous = [slice(None), slice(None)]
    current[dim] = slice(1, None)
    previous[dim] = slice(None, -1)

    # The first slice along `dim` stays zero, as before.
    grad[tuple(current)] = image[tuple(current)] - image[tuple(previous)]
    return grad

def forward_grad(input_image, dim):
    """Forward finite difference of a 2-D array along axis ``dim``.

    Along ``dim`` the result is ``x[i + 1] - x[i]`` for every index except
    the last one, where it is ``0``. The output is a float array with the
    same shape as the input.

    The previous implementation produced the same values by copying the
    image into two (r + 1, n + 1) scratch buffers and slicing them with
    bounds stored in CuPy arrays; every such bound has to be converted to a
    host integer. Plain Python slices avoid both the scratch buffers and the
    conversions.

    Args:
        input_image: 2-D array.
        dim: axis to differentiate along (0 or 1).
    """
    image = cp.asarray(input_image, dtype=float)
    rows, cols = image.shape  # also enforces that the input is 2-D
    grad = cp.zeros((rows, cols), dtype=float)

    current = [slice(None), slice(None)]
    following = [slice(None), slice(None)]
    current[dim] = slice(None, -1)
    following[dim] = slice(1, None)

    # The last slice along `dim` stays zero, as before.
    grad[tuple(current)] = image[tuple(following)] - image[tuple(current)]
    return grad

def iter_grad(input_image, b, scale, mu, dim1, dim2):
    """One update step for a second-difference term.

    Takes the forward difference along ``dim1`` followed by the backward
    difference along ``dim2``, soft-thresholds it (threshold ``1 / mu``)
    together with the auxiliary array ``b``, updates ``b`` and maps the
    thresholded residual back through the differences in reverse order.

    Returns:
        ``(update, new_b)`` where ``update`` is already multiplied by
        ``scale``.
    """
    second_diff = back_grad(forward_grad(input_image, dim1), dim2)
    shrunk = soft_threshold(second_diff + b, 1 / mu)
    new_b = b + (second_diff - shrunk)
    update = scale * back_grad(forward_grad(shrunk - new_b, dim2), dim1)
    return update, new_b

def it_xx(*args):
    """Run ``iter_grad`` with both differences taken along axis 1.

    ``args`` is forwarded positionally as ``(input_image, b, scale, mu)``.
    """
    return iter_grad(*args, dim1=1, dim2=1)

def it_yy(*args):
    """Run ``iter_grad`` with both differences taken along axis 0.

    ``args`` is forwarded positionally as ``(input_image, b, scale, mu)``.
    """
    return iter_grad(*args, dim1=0, dim2=0)

def it_xy(*args):
    """Run ``iter_grad`` for the mixed term (axis 0 first, then axis 1).

    ``args`` is forwarded positionally as ``(input_image, b, scale, mu)``.
    """
    return iter_grad(*args, dim1=0, dim2=1)

def it_sparse(input_image, b_sparse, scale, mu):
    """One update step for the L1 term applied to the image itself.

    Same structure as ``iter_grad`` but without any difference operator:
    the image plus ``b_sparse`` is soft-thresholded with threshold
    ``1 / mu`` and ``b_sparse`` is updated with the residual.

    Returns:
        ``(update, new_b_sparse)`` where ``update`` is already multiplied
        by ``scale``.
    """
    shrunk = soft_threshold(input_image + b_sparse, 1 / mu)
    new_b = b_sparse + (input_image - shrunk)
    return scale * (shrunk - new_b), new_b

def denoise(input_image, iter_num=100, fidelity=150, sparsity=10, continuity=0.5, mu=1):
    """Iteratively smooth a 2-D map and rescale it to the range [0, 1].

    Every iteration solves a linear system in the Fourier domain (division
    by ``normed_array``) and then accumulates soft-threshold updates for the
    xx, yy and xy second-difference terms and for an L1 term on the image.

    Args:
        input_image: 2-D CuPy array to denoise.
        iter_num: number of iterations (a tqdm progress bar is shown).
        fidelity: weight of the term tying the result to ``input_image``.
        sparsity: scale of the L1 term.
        continuity: scale of the second-difference terms (the mixed term
            uses ``2 * continuity``).
        mu: the soft-threshold level is ``1 / mu``.

    Returns:
        Array of the same shape; negative values are clipped to zero, the
        minimum is subtracted and the result is divided by its maximum.
        If the result is identically zero, zeros are returned (the
        previous code divided 0 by 0 and returned NaN everywhere).
    """
    image_size = cp.shape(input_image)
    data_weight = fidelity / mu  # loop-invariant weight of the data term

    normed_array = (
        fourier_derivative(image_size, "xx") +
        fourier_derivative(image_size, "yy") +
        2 * fourier_derivative(image_size, "xy")
    )
    normed_array += data_weight + sparsity ** 2
    b_arrays = {
        "xx": cp.zeros(image_size, dtype=float),
        "yy": cp.zeros(image_size, dtype=float),
        "xy": cp.zeros(image_size, dtype=float),
        "L1": cp.zeros(image_size, dtype=float),
    }
    grad_upd = cp.multiply(data_weight, input_image)
    for i in tqdm(range(iter_num), total=iter_num):
        grad_upd = cp.fft.fftn(grad_upd)
        if i == 0:
            # First pass: only the data term is present, so undo its weight.
            g = cp.fft.ifftn(grad_upd / data_weight).real
        else:
            g = cp.fft.ifftn(cp.divide(grad_upd, normed_array)).real
        # Start the next right-hand side from the data term ...
        grad_upd = cp.multiply(data_weight, input_image)

        # ... and add the contribution of every regulariser.
        L, b_arrays["xx"] = it_xx(g, b_arrays["xx"], continuity, mu)
        grad_upd += L
        L, b_arrays["yy"] = it_yy(g, b_arrays["yy"], continuity, mu)
        grad_upd += L
        L, b_arrays["xy"] = it_xy(g, b_arrays["xy"], 2 * continuity, mu)
        grad_upd += L
        L, b_arrays["L1"] = it_sparse(g, b_arrays["L1"], sparsity, mu)
        grad_upd += L

    grad_upd = cp.fft.fftn(grad_upd)
    g = cp.fft.ifftn(cp.divide(grad_upd, normed_array)).real

    g[g < 0] = 0
    g -= g.min()
    peak = g.max()
    if float(peak) == 0.0:
        # Nothing left after clipping: avoid 0 / 0 -> NaN.
        return cp.zeros_like(g)
    return g / peak

In [None]:
def read_image(fragment_id):
    """Load the central surface-volume slices of one fragment.

    ``CFG.input_channels`` slices centred on index ``65 // 2`` are read from
    ``{CFG.dataset_path}{mode}/{fragment_id}/surface_volume/NN.tif`` using
    the module-level ``mode``. Each slice is read with OpenCV flag ``0``
    (grayscale) and zero-padded at the bottom/right by
    ``CFG.prd - size % CFG.prd`` pixels.

    Returns:
        Array of shape ``(padded_height, padded_width, CFG.input_channels)``.

    Raises:
        FileNotFoundError: if a slice cannot be read. ``cv2.imread`` returns
            ``None`` in that case, which previously surfaced as an
            ``AttributeError`` on ``image.shape``.
    """
    # NOTE(review): 65 is presumably the number of slices per fragment
    # (00.tif .. 64.tif) -- confirm against the dataset.
    mid = 65 // 2
    start = mid - CFG.input_channels // 2
    end = mid + CFG.input_channels // 2

    images = []
    for i in tqdm(range(start, end)):
        path = CFG.dataset_path + f"{mode}/{fragment_id}/surface_volume/{i:02}.tif"
        image = cv2.imread(path, 0)
        if image is None:
            raise FileNotFoundError(f"Could not read surface-volume slice: {path}")
        # A dimension that is already a multiple of CFG.prd still receives a
        # full extra tile of padding. The inference loop pads mask.png with
        # the same formula, so the two shapes stay in agreement.
        pad0 = (CFG.prd - image.shape[0] % CFG.prd)
        pad1 = (CFG.prd - image.shape[1] % CFG.prd)
        image = np.pad(image, [(0, pad0), (0, pad1)], constant_values=0)
        images.append(image)

    return np.stack(images, axis=2)

In [None]:
class CustomDataset(Dataset):
    """Dataset of image tiles paired with their coordinates.

    ``images[idx]`` is expected to be a NumPy array laid out as
    (height, width, channels); ``xys[idx]`` is returned unchanged next to
    the tile. ``cfg`` and ``labels`` are stored but not used when items are
    fetched.
    """

    def __init__(self, images, cfg, xys, labels=None):
        self.images, self.xys = images, xys
        self.cfg, self.labels = cfg, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Channels first, scaled from 0..255 to 0..1 ...
        tile = torch.from_numpy(self.images[idx]).permute(2, 0, 1).float().div(255)
        # ... then shifted and scaled by fixed constants.
        tile = tile.sub(0.45).div(0.225)
        return tile, self.xys[idx]

In [None]:
def make_test_dataset(fragment_id):
    """Cut a fragment into overlapping tiles and wrap them in a DataLoader.

    Square windows of edge ``CFG.prd`` are placed every ``CFG.stride``
    pixels over the padded volume returned by ``read_image``. Windows that
    contain only zeros are skipped.

    Returns:
        ``(test_loader, xyxys)`` where ``xyxys`` is an integer array of
        shape ``(num_tiles, 4)`` holding ``(x1, y1, x2, y2)`` per tile. If
        every window is blank the array has shape ``(0, 4)`` and the loader
        is empty (the previous ``np.stack`` raised on an empty list).
    """
    test_images = read_image(fragment_id)
    height, width = test_images.shape[:2]
    tile = CFG.prd

    test_images_list = []
    xyxys = []
    for y1 in range(0, height - tile + 1, CFG.stride):
        y2 = y1 + tile
        for x1 in range(0, width - tile + 1, CFG.stride):
            x2 = x1 + tile
            patch = test_images[y1:y2, x1:x2]
            if not patch.any():
                # Entirely zero window (e.g. padding): nothing to predict.
                continue
            test_images_list.append(patch)
            xyxys.append((x1, y1, x2, y2))
    # reshape keeps the (N, 4) layout even when no tile was kept.
    xyxys = np.asarray(xyxys, dtype=int).reshape(-1, 4)

    test_dataset = CustomDataset(test_images_list, CFG, xys=xyxys)

    test_loader = DataLoader(test_dataset,
                             batch_size=CFG.batch_size,
                             shuffle=False,
                             num_workers=CFG.num_workers, pin_memory=True, drop_last=False)

    return test_loader, xyxys

In [None]:
class Decoder(nn.Module):
    """Top-down decoder that fuses a pyramid of 2-D feature maps into a mask.

    Args:
        encoder_dims: channel count of every pyramid level, finest first.
        upscale: factor by which the single-channel logit map is upsampled
            at the end.
    """

    def __init__(self, encoder_dims, upscale):
        super().__init__()
        # convs[i-1] fuses level i (upsampled) into level i-1.
        self.convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(encoder_dims[i] + encoder_dims[i-1], encoder_dims[i-1], 3, 1, 1, bias=False),
                nn.BatchNorm2d(encoder_dims[i-1]),
                nn.ReLU(inplace=True)
            ) for i in range(1, len(encoder_dims))])
        self.logit = nn.Conv2d(encoder_dims[0], 1, 1, 1, 0)
        self.up = nn.Upsample(scale_factor=upscale, mode="bilinear")

    def forward(self, feature_maps):
        """Return the upsampled single-channel logit map.

        ``feature_maps`` is a sequence of tensors ordered like
        ``encoder_dims``; every level is expected to have half the spatial
        size of the previous one. The sequence passed in is not modified
        (the previous implementation overwrote its entries in place and
        therefore also rejected tuples).
        """
        fused = list(feature_maps)  # local copy: never write into the caller's list
        for i in range(len(fused) - 1, 0, -1):
            f_up = F.interpolate(fused[i], scale_factor=2, mode="bilinear")
            f = torch.cat([fused[i-1], f_up], dim=1)
            fused[i-1] = self.convs[i-1](f)

        x = self.logit(fused[0])
        mask = self.up(x)
        return mask
    
class SegModel(nn.Module):
    """3-D encoder followed by the 2-D ``Decoder``.

    The encoder comes from ``generate_model`` and is built for one input
    channel; its feature maps are averaged over dimension 2 before they are
    handed to the decoder.
    """

    def __init__(self,model_depth=34):
        super().__init__()
        self.encoder = generate_model(model_depth=model_depth, n_input_channels=1)
        self.decoder = Decoder(encoder_dims=[64, 128, 256, 512], upscale=4)

    def forward(self, x):
        # A 4-D input gets a singleton dimension inserted at position 1.
        volume = x.unsqueeze(1) if x.ndim == 4 else x
        # NOTE(review): assumes the encoder returns one feature map per
        # decoder level, each with depth on dim 2 -- confirm in resnet3d.
        depth_pooled = [stage.mean(dim=2) for stage in self.encoder(volume)]
        return self.decoder(depth_pooled)
        
class CustomModel(nn.Module):
    """Wrapper that builds the network selected by ``cfg.backbone``.

    * ``"resnet3d"`` -> ``SegModel``.
    * a name starting with ``"mit"`` -> ``smp.Unet`` whose first patch
      embedding projection is replaced so it accepts ``cfg.input_channels``.
    * anything else -> ``smp.Unet`` created with
      ``in_channels=cfg.input_channels``.

    NOTE(review): ``smp`` is not imported anywhere in the visible source, so
    the two ``smp.Unet`` branches would raise ``NameError`` as written.
    """

    def __init__(self, cfg=CFG, weight=None):
        super().__init__()
        self.cfg = cfg

        if cfg.backbone == "resnet3d":
            self.encoder = SegModel()
            return

        unet_kwargs = dict(
            encoder_name=cfg.backbone,
            encoder_weights=weight,
            classes=cfg.target_size,
            activation=None,
        )
        if cfg.backbone[:3] != "mit":
            self.encoder = smp.Unet(in_channels=cfg.input_channels, **unet_kwargs)
            return

        # "mit" backbones: build with the default input layer, then swap the
        # first projection for one with the configured channel count.
        self.encoder = smp.Unet(**unet_kwargs)
        print(self.encoder.encoder.patch_embed1.proj)
        out_channels = self.encoder.encoder.patch_embed1.proj.out_channels
        self.encoder.encoder.patch_embed1.proj = nn.Conv2d(cfg.input_channels, out_channels, 7, 4, 3)

    def forward(self, images:torch.Tensor):
        """Standardise each sample with ``batchnorm`` and run the network."""
        # A 4-D batch gets a singleton dimension inserted at position 1.
        if images.ndim == 4:
            images = images.unsqueeze(1)
        return self.encoder(batchnorm(images))

def build_model(cfg, weight="imagenet"):
    """Print the configured model/backbone names and build a ``CustomModel``.

    ``weight`` is passed through to ``CustomModel``, which forwards it as
    ``encoder_weights`` only for the ``smp.Unet`` backbones.
    """
    print('model_name', cfg.model_name)
    print('backbone', cfg.backbone)
    return CustomModel(cfg, weight)

In [None]:
def TTA(x:torch.Tensor,model:nn.Module):
    """Average sigmoid predictions over the four 90-degree rotations.

    The batch and its three rotated copies (rotation over the last two
    dimensions) are sent through ``model`` in a single call. Every
    prediction is rotated back before averaging.

    Returns:
        Tensor of shape ``(x.shape[0], *x.shape[2:])``.
    """
    batch, spatial = x.shape[0], x.shape[2:]
    plane = (-2, -1)

    rotated_inputs = [torch.rot90(x, k=turns, dims=plane) for turns in range(4)]
    probs = torch.sigmoid(model(torch.cat(rotated_inputs, dim=0)))

    # One slab of predictions per rotation, in the order used above.
    probs = probs.reshape(4, batch, *spatial)
    realigned = torch.stack(
        [torch.rot90(probs[turns], k=-turns, dims=plane) for turns in range(4)],
        dim=0,
    )
    return realigned.mean(0)

In [None]:
# Run configuration: these globals are read by read_image, the inference loop
# and the submission cell.
in_submission = True   # when False the inference loop exits immediately and no submission is built
train_mode = False
mode = 'train' if train_mode else 'test'   # dataset sub-folder to read fragments from
threshold = 0.55       # cut-off applied to the denoised probability map
if mode == 'test':
    # Every entry of the test folder is treated as a fragment id.
    fragment_ids = sorted(os.listdir(CFG.dataset_path + mode))
else:
    fragment_ids = [3]

model = build_model(CFG)
# NOTE(review): torch.load unpickles the checkpoint -- only load files from a trusted source.
model.load_state_dict(torch.load("/kaggle/input/3d-resnet-baseline-inference-model-data/resnet3d-34_3d_seg_epoch_14.pth"))
model = nn.DataParallel(model)
# NOTE(review): .eval() is commented out, so the model stays in training mode
# and its BatchNorm layers use per-batch statistics during inference.
# Confirm that this is intentional.
model = model.cuda()#.eval()
# Displayed as the cell output to show which mode the model is in.
model.training

In [None]:
def fbeta_numpy(targets, preds, beta=0.5, smooth=1e-5):
    """Smoothed F-beta score between a binary target and a prediction array.

    Args:
        targets: array of 0/1 ground-truth values.
        preds: array of the same shape; its values are summed over the
            positive and negative target positions.
        beta: weight of recall relative to precision.
        smooth: small constant added to every denominator.
    """
    positives = targets.sum()
    true_pos = preds[targets == 1].sum()
    false_pos = preds[targets == 0].sum()
    b2 = beta * beta

    precision = true_pos / (true_pos + false_pos + smooth)
    recall = true_pos / (positives + smooth)
    numerator = (1 + b2) * (precision * recall)
    return numerator / (b2 * precision + recall + smooth)

In [None]:
# Inference: for every fragment, predict all tiles, average the overlapping
# predictions, denoise, threshold and run-length encode the result.
results = []
for fragment_id in fragment_ids:
    if not in_submission:
        break
    test_loader, xyxys = make_test_dataset(fragment_id)

    binary_mask = cv2.imread(CFG.dataset_path + f"{mode}/{fragment_id}/mask.png", 0)
    binary_mask = (binary_mask / 255).astype(int)

    # Original (unpadded) size, used to crop the prediction back at the end.
    hei = binary_mask.shape[0]
    wid = binary_mask.shape[1]

    # Same padding formula as read_image so that tile coordinates line up.
    pad0 = (CFG.prd - binary_mask.shape[0] % CFG.prd)
    pad1 = (CFG.prd - binary_mask.shape[1] % CFG.prd)

    binary_mask = np.pad(binary_mask, [(0, pad0), (0, pad1)], constant_values=0)

    mask_pred = np.zeros(binary_mask.shape)
    mask_count = np.zeros(binary_mask.shape)
    for images, xys in tqdm(test_loader, total=len(test_loader)):
        images = images.cuda()
        with torch.no_grad():
            y_preds = TTA(images, model)

        # One device-to-host copy per batch instead of one per tile.
        tile_probs = y_preds.cpu().numpy()
        for k, (x1, y1, x2, y2) in enumerate(xys.tolist()):
            mask_pred[y1:y2, x1:x2] += tile_probs[k]
            mask_count[y1:y2, x1:x2] += 1

    print(f'mask_count_min: {mask_count.min()}')
    # Average overlapping tiles; the epsilon protects pixels no tile covered.
    mask_pred /= (mask_count + 1e-7)

    fig, axes = plt.subplots(1, 4, figsize=(15, 8))
    axes[0].imshow(mask_count)
    axes[1].imshow(mask_pred.copy())

    mask_pred = cp.array(mask_pred)
    mask_pred = denoise(mask_pred, iter_num=10)
    mask_pred = mask_pred.get()
    axes[2].imshow(mask_pred)
    mask_pred = mask_pred[:hei, :wid]
    binary_mask = binary_mask[:hei, :wid]
    mask_pred = (mask_pred >= threshold).astype(int)
    # This is an F0.5 score of the prediction against mask.png (the value
    # read into binary_mask above), printed under the label 'dice'.
    dice = fbeta_numpy(binary_mask, mask_pred, beta=0.5)
    print(dice, 'dice')
    # Keep predictions only inside the fragment mask.
    mask_pred *= binary_mask
    axes[3].imshow(mask_pred)
    plt.show()

    inklabels_rle = RLE(mask_pred)

    results.append((fragment_id, inklabels_rle))

    # Free per-fragment memory before the next fragment is loaded.
    del mask_pred, mask_count
    del test_loader

    gc.collect()
    torch.cuda.empty_cache()
    fig.clear()
    plt.close(fig)

In [None]:
# Start from the sample submission so that a submission.csv exists even when
# in_submission is False.
! cp /kaggle/input/vesuvius-challenge-ink-detection/sample_submission.csv submission.csv
if in_submission:
    # One row per processed fragment: (fragment id, run-length encoded mask).
    sub = pd.DataFrame(results, columns=['Id', 'Predicted'])
    # Left-join onto the sample submission's Id column so the output keeps the
    # expected rows and order; ids without a prediction get an empty value.
    sample_sub = pd.read_csv(CFG.dataset_path + 'sample_submission.csv')
    sample_sub = pd.merge(sample_sub[['Id']], sub, on='Id', how='left')
    # Overwrites the copy made by the shell command above.
    sample_sub.to_csv("submission.csv", index=False)
    print("ok")

In [None]:
def f1(y_pred, y_true, beta=0.5):
    """F-beta score of a hard prediction against a binary target.

    Despite the name, the default ``beta=0.5`` gives an F0.5 score.

    Args:
        y_pred: tensor of predicted labels (compared to ``y_true`` with ``==``).
        y_true: tensor of 0/1 ground-truth labels with the same number of
            elements.
        beta: weight of recall relative to precision.

    Returns:
        0-d float tensor. Returns 0 instead of NaN when there are no
        predicted positives, no true positives or no positives at all
        (the previous code divided 0 by 0 in those cases).
    """
    y_pred_f = torch.flatten(y_pred)
    y_true_f = torch.flatten(y_true)

    agree = y_pred_f == y_true_f
    tp = (agree & (y_true_f == 1)).sum()
    fp = (~agree & (y_true_f == 0)).sum()
    fn = (~agree & (y_true_f == 1)).sum()

    # Counts are integers, so clamping a denominator to >= 1 only changes
    # the 0 / 0 case (which becomes 0 / 1 = 0).
    p = tp / (tp + fp).clamp(min=1)
    r = tp / (tp + fn).clamp(min=1)

    beta_sq = beta * beta
    denom = beta_sq * p + r
    # denom is zero only when p == r == 0; the numerator is then zero too.
    safe_denom = torch.where(denom > 0, denom, torch.ones_like(denom))
    return (1 + beta_sq) * p * r / safe_denom