In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import shutil
import time
import zarr
import glob
import PIL.Image as Image
import random
import torch.utils.data as data
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from tqdm import tqdm
from ipywidgets import interact, fixed
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

# Kaggle paths, kept for reference; the local paths below are the active ones.
# INPUT_FOLDER = "/kaggle/input/vesuvius-challenge-ink-detection"
# WORKING_FOLDER = "/kaggle/working/"
# TEMP_FOLDER = "kaggle/temp/"
INPUT_FOLDER = "data/"  # Root of the raw fragment images (<type>/<index>/...)
WORKING_FOLDER = "working/"  # Persistent location of the generated zarr stores
TEMP_FOLDER = "temp/"  # Non-persistent location of the generated zarr stores
TEST_PREFIX = ['data/test/a/', 'data/test/b/']  # NOTE(review): not referenced by the cells visible in this notebook
BUFFER = 32  # Half-width of a subvolume in x and y; patches are (2*BUFFER+1) pixels wide
Z_START = 29 # First slice in the z direction to use
Z_DIM = 6  # Number of slices in the z direction
TRAINING_STEPS = 30000  # Together with BATCH_SIZE, sets how many pixels train_val_split samples
LEARNING_RATE = 1e-3  # Initial learning rate passed to the optimizer
BATCH_SIZE = 32  # Batch size of every DataLoader
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IS_TRAIN = True  # True runs the training cell, False runs the inference cell
IF_ZARR = True  # True deletes and regenerates the zarr caches from the raw images
CHEPOINT = 'result/dataset-1-ResNet3D-DIM-16-[train_loss]-0.0540-[dice_score]-0.94-[iou_score]-0.89-5-epoch.pkl'  # Checkpoint loaded when FT is True
FT = True # Whether to load the pretrained weights from CHEPOINT
THRESHOLD = 0.55 # Probability threshold used to binarize the predicted mask

In [2]:
class TimerError(Exception):
    """Signals that a Timer was started while running or stopped while idle."""

class Timer():
    """Small stopwatch usable directly or as a context manager.

    The elapsed time is reported through ``self.logger`` (``print`` by
    default) using ``self.text`` as the message template, and is also
    returned by :meth:`stop`.

    Args:
        text: Optional label for the report. The message becomes
            ``"<text>: <seconds> seconds"``; without a label it is
            ``"Elapsed time: <seconds> seconds"``.
    """
    def __init__(self, text=None):
        if text is not None:
            self.text = text + ": {:0.4f} seconds"
        else:
            self.text = "Elapsed time: {:0.4f} seconds"
        # Callable that receives the formatted report; set to None to silence it.
        self.logger = print
        self._start_time = None

    def start(self):
        """Start timing.

        Raises:
            TimerError: If the timer is already running.
        """
        if self._start_time is not None:
            raise TimerError("Timer is already running.  Use .stop() to stop it.")
        # perf_counter is monotonic, so the measurement cannot go negative or
        # jump when the system clock is adjusted (time.time() can).
        self._start_time = time.perf_counter()

    def stop(self):
        """Stop timing, report the elapsed time and return it in seconds.

        Raises:
            TimerError: If the timer is not running.
        """
        if self._start_time is None:
            raise TimerError("Timer is not running.  Use .start() to start it.")
        elapsed_time = time.perf_counter() - self._start_time
        self._start_time = None

        if self.logger is not None:
            self.logger(self.text.format(elapsed_time))

        return elapsed_time

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Always report; returning None lets any exception propagate.
        self.stop()

The FragmentImageData class is used to store the compressed data on disk. It takes three parameters: the sample type ("test" or "train"), the sample index in the folder, and a boolean that determines whether the data is stored in the (persistent) working directory or the temporary directory. If stored persistently, then once generated, the zarr data format can be quickly loaded in other notebooks.

If the zarr file does not already exist, it will be generated by parsing the individual image files in the corresponding input directory. Otherwise, it will quickly load from the zarr file.

The image data can be accessed as attributes of the object, largely (with the exception of fancy indexing) treated as numpy arrays:

surface_volume: the 3D X-ray tomography data
mask: the 2D boolean mask describing where data exists
truth: (Training data only) the 2D boolean mask with the ink truth set
infrared: (Training data only) the 2D infrared image of the parchment

In [3]:
class FragmentImageException(Exception):
    """Raised when fragment data cannot be validated, located or safely written."""


class FragmentImageData:
    """A general class that uses persistent zarr objects to store the surface volume data,
    binary data mask, and for training sets, the truth data and infrared image of a papyrus
    fragment, in a compressed and efficient way.

    On construction the zarr store for the requested sample is opened if it
    already exists; otherwise it is generated from the image files found in
    ``INPUT_FOLDER/<sample_type>/<sample_index>``.

    Args:
        sample_type: Either ``"test"`` or ``"train"``.
        sample_index: Name of the sample sub-folder (e.g. ``"1"`` or ``"a"``).
        working: Store the zarr under ``WORKING_FOLDER`` when True, otherwise
            under ``TEMP_FOLDER``.

    Raises:
        FragmentImageException: If ``sample_type`` is invalid or neither the
            zarr store nor the input directory exists.
    """

    # Accepted values of ``sample_type``.
    VALID_SAMPLE_TYPES = ("test", "train")

    def __init__(self, sample_type: str, sample_index: str, working: bool = True):
        # Membership in a tuple: the previous check against the single string
        # "test, train" was a substring test and accepted e.g. "est" or "t".
        if sample_type not in self.VALID_SAMPLE_TYPES:
            raise FragmentImageException(
                f"Invalid sample type {sample_type!r}, must be one of 'test' or 'train'"
            )
        zarrpath = self._zarr_path(sample_type, sample_index, working)
        if os.path.exists(zarrpath):
            self.zarr = self.load_from_zarr(zarrpath)
        else:
            dirpath = os.path.join(INPUT_FOLDER, sample_type, sample_index)
            if not os.path.exists(dirpath):
                raise FragmentImageException(
                    f"No input data found at {zarrpath} or {dirpath}"
                )
            self.zarr = self.load_from_directory(dirpath, zarrpath)

    @property
    def surface_volume(self):
        """The stored X-ray slices as a (height, width, slice) zarr array."""
        return self.zarr.surface_volume

    @property
    def mask(self):
        """2D boolean zarr array read from ``mask.png``."""
        return self.zarr.mask

    @property
    def truth(self):
        """2D boolean zarr array read from ``inklabels.png`` (only stored if that file exists)."""
        return self.zarr.truth

    @property
    def infrared(self):
        """Zarr array read from ``ir.png`` (only stored if that file exists)."""
        return self.zarr.infrared

    @staticmethod
    def _zarr_path(sample_type: str, sample_index: str, working: bool = True):
        """Return the path of the zarr store for the given sample."""
        filename = f"{sample_type}-{sample_index}.zarr"
        if working:
            return os.path.join(WORKING_FOLDER, filename)
        else:
            return os.path.join(TEMP_FOLDER, filename)

    @staticmethod
    def clean_zarr(sample_type: str, sample_index: str, working: bool = True):
        """Delete the zarr store of the given sample if it exists."""
        zarrpath = FragmentImageData._zarr_path(sample_type, sample_index, working)
        if os.path.exists(zarrpath):
            shutil.rmtree(zarrpath)

    @staticmethod
    def load_from_zarr(filepath):
        """Open an existing zarr store read-only and return its root."""
        with Timer("Loading from existing zarr"):
            return zarr.open(filepath, mode="r")

    @staticmethod
    def load_from_directory(dirpath, zarrpath):
        """Build a new zarr store at ``zarrpath`` from the images in ``dirpath``.

        Only the ``Z_DIM`` surface-volume slices starting at ``Z_START`` are
        stored, so index 0 along the last axis of ``surface_volume``
        corresponds to slice ``Z_START``.

        Returns:
            The root zarr group that was written.

        Raises:
            FragmentImageException: If ``zarrpath`` already exists or no
                surface-volume slices fall inside the requested z range.
        """
        if os.path.exists(zarrpath):
            raise FragmentImageException(
                f"Trying to overwrite existing zarr at {zarrpath}"
            )
        # Resolve the slice files before creating the store, so a missing or
        # too-short input folder does not leave an empty zarr behind that
        # later runs would silently load.
        volume_dir = os.path.join(dirpath, "surface_volume")
        imgfiles = sorted(os.listdir(volume_dir))[Z_START:Z_START+Z_DIM]
        if not imgfiles:
            raise FragmentImageException(
                f"No surface volume slices found in {volume_dir} "
                f"for z range [{Z_START}, {Z_START + Z_DIM})"
            )
        # Initialize the root zarr group and write the file
        root = zarr.open_group(zarrpath, mode="w")
        # Load in the surface volume tif files
        with Timer("Surface volume loading"):
            surface_volume = None
            for imgfile in imgfiles:
                print(f"Loading file {imgfile}", end="\r")
                img_data = np.array(Image.open(os.path.join(volume_dir, imgfile)))
                if surface_volume is None:
                    # The array is allocated lazily because its shape and
                    # dtype are only known once the first slice is read.
                    surface_volume = root.zeros(
                        name="surface_volume",
                        shape=(img_data.shape[0], img_data.shape[1], len(imgfiles)),
                        chunks=(1000, 1000, 4),
                        dtype=img_data.dtype,
                        write_empty_chunks=False,
                    )
                # NOTE(review): assumes file names are zero-padded slice
                # numbers ("29.tif", ...) so that the lexicographic sort above
                # matches numeric order - confirm against the dataset.
                z_index = int(imgfile.split(".")[0]) - Z_START
                surface_volume[:,:,z_index] = img_data
        # Load in the mask
        with Timer("Mask loading"):
            img_data = np.array(Image.open(os.path.join(dirpath, "mask.png")), dtype=bool)
            root.array(
                name="mask",
                data=img_data,
                shape=img_data.shape,
                chunks=(1000, 1000),
                dtype=img_data.dtype,
                write_empty_chunks=False,
            )
        # Load in the truth set (if it exists)
        with Timer("Truth set loading"):
            truthfile = os.path.join(dirpath, "inklabels.png")
            if os.path.exists(truthfile):
                img_data = np.array(Image.open(truthfile), dtype=bool)
                root.array(
                    name="truth",
                    data=img_data,
                    shape=img_data.shape,
                    chunks=(1000, 1000),
                    dtype=img_data.dtype,
                    write_empty_chunks=False,
                )
        # Load in the infrared image (if it exists)
        with Timer("Infrared image loading"):
            irfile = os.path.join(dirpath, "ir.png")
            if os.path.exists(irfile):
                img_data = np.array(Image.open(irfile))
                root.array(
                    name = "infrared",
                    data = img_data,
                    shape = img_data.shape,
                    chunks = (1000, 1000),
                    dtype=img_data.dtype,
                    write_empty_chunks=False,
                )
        return root

Let's see how long it takes to generate a new zarr file from scratch on the first training set data. Note we need to clean up any pre-existing zarr data first, otherwise it will load directly from there instead of reading the input images.

In [4]:
# This cell has to run once to generate the zarr caches. Every existing
# cache is removed first so the store is rebuilt from the raw images.
if IF_ZARR:
    for sample_type, sample_index in (
        ("train", "1"),
        ("train", "2"),
        ("train", "3"),
        ("test", "a"),
        ("test", "b"),
    ):
        FragmentImageData.clean_zarr(sample_type, sample_index)
        # The instance is deliberately not bound to a name: assigning it to
        # `data` would shadow `import torch.utils.data as data` and break
        # data.Dataset / data.DataLoader in the cells below.
        FragmentImageData(sample_type, sample_index)

Surface volume loading: 36.1081 seconds
Mask loading: 0.0927 seconds
Truth set loading: 0.0912 seconds
Infrared image loading: 0.3774 seconds
Loading file 27.tif



Surface volume loading: 80.4000 seconds
Mask loading: 0.2657 seconds
Truth set loading: 0.2687 seconds
Infrared image loading: 1.1186 seconds
Surface volume loading: 33.4963 seconds
Mask loading: 0.9128 seconds
Truth set loading: 0.4121 seconds
Infrared image loading: 3.4408 seconds
Surface volume loading: 55.8678 seconds
Mask loading: 0.0365 seconds
Truth set loading: 0.0000 seconds
Infrared image loading: 0.0000 seconds
Surface volume loading: 7.1472 seconds
Mask loading: 0.1005 seconds
Truth set loading: 0.0000 seconds
Infrared image loading: 0.0000 seconds


In [None]:
# Open the cached fragments once; the individual names are kept alongside the
# lists so either form can be used further down.
train_data_list = [FragmentImageData("train", sample) for sample in ("1", "2", "3")]
train_data_1, train_data_2, train_data_3 = train_data_list
test_data_list = [FragmentImageData("test", sample) for sample in ("a", "b")]
test_data_a, test_data_b = test_data_list

In [None]:
def train_val_split(mask, val_percent=0.3, num_points=None, buffer=None):
    """Sample random pixel locations inside the mask and split them into train/val.

    Points are drawn uniformly (with replacement) from the pixels that are set
    in ``mask`` and lie at least ``buffer`` pixels away from every image
    border, so a full subvolume can be cut out around each of them. The
    sampling uses numpy's global random state.

    Args:
        mask: 2D array-like; truthy entries mark usable pixels.
        val_percent: Fraction of the sampled points returned as validation set.
        num_points: Number of points to sample. Defaults to
            ``TRAINING_STEPS * BATCH_SIZE``.
        buffer: Border margin in pixels. Defaults to ``BUFFER``.

    Returns:
        ``(train_points, val_points)``, two lists of ``(y, x)`` integer tuples.

    Raises:
        ValueError: If no masked pixel lies inside the border margin (the
            original rejection loop would never terminate in that case).
    """
    if num_points is None:
        num_points = TRAINING_STEPS * BATCH_SIZE
    if buffer is None:
        buffer = BUFFER
    # Convert the mask to a Numpy array
    mask_array = np.array(mask)
    height, width = mask_array.shape
    # View of the region in which a full subvolume fits around each pixel.
    interior = mask_array[buffer:height - buffer, buffer:width - buffer]
    if num_points > 0 and not interior.any():
        raise ValueError(
            "mask contains no usable pixel at least "
            f"{buffer} px away from the image border"
        )
    points = []
    # Batched rejection sampling: draw as many candidates as are still
    # missing, keep those that hit the mask, repeat until enough were found.
    while len(points) < num_points:
        missing = num_points - len(points)
        ys = np.random.randint(buffer, height - buffer, size=missing)
        xs = np.random.randint(buffer, width - buffer, size=missing)
        hit = interior[ys - buffer, xs - buffer].astype(bool)
        points.extend(zip(ys[hit].tolist(), xs[hit].tolist()))
    n_val = int(num_points * val_percent)
    # Split by explicit position: points[:-0] would be empty and points[-0:]
    # everything, which inverted the split whenever n_val was 0.
    split = num_points - n_val
    return points[:split], points[split:]
    

class SubvolumeDataset(data.Dataset):
    """Dataset of square subvolumes centred on a list of pixels.

    Args:
        image_stack: Tensor indexed as ``[y, x, z]``.
        label: Tensor indexed as ``[y, x]`` holding the ink label; may be
            None when ``is_train`` is False.
        pixels: Sequence of ``(y, x)`` centre coordinates. Every centre must
            lie at least ``buffer`` pixels away from the image border.
        is_train: When True items are ``(subvolume, inklabel)`` pairs,
            otherwise only the subvolume is returned.
        buffer: Half-width of the patch in pixels. Defaults to ``BUFFER``.
    """
    def __init__(self, image_stack, label, pixels, is_train, buffer=None):
        self.image_stack = image_stack
        self.label = label
        self.pixels = pixels
        self.is_train = is_train
        self.buffer = BUFFER if buffer is None else buffer

    def __len__(self):
        return len(self.pixels)

    def __getitem__(self, index):
        # The centre is needed in both modes; previously it was only unpacked
        # in the training branch, so inference raised on the first item.
        y, x = self.pixels[index]
        half = self.buffer
        subvolume = self.image_stack[y-half:y+half+1, x-half:x+half+1, :]
        # (H, W, Z) -> (1, H, W, Z) -> (1, Z, H, W): channel-first layout with
        # a single input channel, as expected by the Conv3d stem.
        subvolume = subvolume.unsqueeze(0).permute(0, 3, 1, 2)
        if not self.is_train:
            return subvolume
        inklabel = self.label[y, x].view(1)
        return subvolume, inklabel
# IOU and Dice Score
def dice_coef(y_true, y_pred, thr=0.5, dim=(0, 1), epsilon=0.001):
    """Dice coefficient between a target and a thresholded prediction.

    ``y_pred`` is binarized with ``y_pred > thr``; sums are taken over ``dim``
    and the per-slice scores are averaged. ``epsilon`` smooths the ratio so an
    empty target and prediction score 1.
    """
    target = y_true.to(torch.float32)
    binarized = (y_pred > thr).to(torch.float32)
    overlap = torch.sum(target * binarized, dim=dim)
    total = torch.sum(target, dim=dim) + torch.sum(binarized, dim=dim)
    return ((2 * overlap + epsilon) / (total + epsilon)).mean()


def iou_coef(y_true, y_pred, thr=0.5, dim=(0, 1), epsilon=0.001):
    """Intersection-over-union between a target and a thresholded prediction.

    ``y_pred`` is binarized with ``y_pred > thr``; sums are taken over ``dim``
    and the per-slice scores are averaged. ``epsilon`` smooths the ratio so an
    empty target and prediction score 1.
    """
    target = y_true.to(torch.float32)
    binarized = (y_pred > thr).to(torch.float32)
    overlap = target * binarized
    intersection = overlap.sum(dim=dim)
    union = (target + binarized - overlap).sum(dim=dim)
    return ((intersection + epsilon) / (union + epsilon)).mean()

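We can see it took on the order of a minute to load all the images for each training dataset. However, now that the data is on disk in our working directory, we can reload it from the zarr store much faster. The cell below times how long it takes to read and plot the infrared image of every training fragment from the cached zarr data: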

In [None]:
# Time reading and plotting the infrared image of every training fragment.
with Timer():
    for i, fragment in enumerate(train_data_list):
        fig, ax = plt.subplots(1, 1)
        ax.set_title(f"{i}_ir.png")
        ax.imshow(fragment.infrared)
    plt.show()

In [None]:
class ResNetBlock(nn.Module):
    """Single-convolution residual block for 3D feature maps.

    The main path is Conv3d(3x3x3) -> BatchNorm3d. The skip path is the
    identity when shape is preserved, otherwise a 1x1x1 strided convolution
    followed by batch norm. Both paths are summed and passed through ReLU.
    """
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv3d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.stride = stride

        # A projection is required whenever the main path changes the
        # channel count or the spatial resolution.
        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            self.shortcut = nn.Conv3d(
                in_channels, out_channels,
                kernel_size=1, stride=stride, bias=False,
            )
            self.bn_shortcut = nn.BatchNorm3d(out_channels)
        else:
            self.shortcut = nn.Identity()
            self.bn_shortcut = None

    def forward(self, x):
        out = self.bn1(self.conv1(x))

        residual = self.shortcut(x)
        if self.bn_shortcut is not None:
            residual = self.bn_shortcut(residual)

        out += residual
        return self.relu(out)
class ResNet3D(nn.Module):
    """Four-stage 3D ResNet (8/16/32/64 channels) with a three-layer MLP head.

    Args:
        block: Residual block class, called as ``block(in_ch, out_ch, stride)``.
        layers: Number of blocks in each of the four stages.
        num_classes: Size of the output vector; a sigmoid is applied to it.
    """
    def __init__(self, block, layers, num_classes=10):
        super().__init__()
        # Channel count feeding the next stage; advanced by make_layer.
        self.in_channels = 8

        self.conv1 = nn.Conv3d(1, 8, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm3d(8)
        self.relu = nn.ReLU(inplace=True)

        self.layer1 = self.make_layer(block, 8, layers[0])
        self.layer2 = self.make_layer(block, 16, layers[1], stride=2)
        self.layer3 = self.make_layer(block, 32, layers[2], stride=2)
        self.layer4 = self.make_layer(block, 64, layers[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.flatten = nn.Flatten(start_dim=1)
        self.linear1 = nn.LazyLinear(256)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.LazyLinear(128)
        self.relu2 = nn.ReLU()
        self.linear3 = nn.LazyLinear(num_classes)
        self.sigmoid = nn.Sigmoid()

    def make_layer(self, block, out_channels, blocks, stride=1):
        """Build one stage: a (possibly strided) block followed by ``blocks - 1`` plain ones."""
        stage = [block(self.in_channels, out_channels, stride)]
        self.in_channels = out_channels
        stage.extend(block(out_channels, out_channels) for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def forward(self, x):
        # Stem
        x = self.relu(self.bn1(self.conv1(x)))

        # Residual stages
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        # Global pooling and classification head
        x = self.flatten(self.avgpool(x))
        x = self.relu1(self.linear1(x))
        x = self.relu2(self.linear2(x))
        return self.sigmoid(self.linear3(x))
class ResNet3DLess(nn.Module):
    """Shallower, wider 3D ResNet: two stages (64/128 channels) and a two-layer head.

    Args:
        block: Residual block class, called as ``block(in_ch, out_ch, stride)``.
        layers: Number of blocks in each of the two stages.
        num_classes: Size of the output vector; a sigmoid is applied to it.
    """
    def __init__(self, block, layers, num_classes=10):
        super(ResNet3DLess, self).__init__()
        # Channel count feeding the next stage; advanced by make_layer.
        self.in_channels = 64

        self.conv1 = nn.Conv3d(1, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm3d(64)
        # One ReLU shared by the stem and the head. It used to be assigned
        # twice (inplace=True here, then a plain nn.ReLU() further down that
        # silently replaced it); the effective non-inplace module is kept.
        self.relu = nn.ReLU()

        self.layer1 = self.make_layer(block, 64, layers[0])
        self.layer2 = self.make_layer(block, 128, layers[1], stride=2)

        self.avgpool = nn.AdaptiveAvgPool3d((2, 2, 2))
        self.flatten = nn.Flatten(start_dim=1)
        self.linear3 = nn.LazyLinear(128)
        self.linear4 = nn.LazyLinear(num_classes)
        self.sigmoid = nn.Sigmoid()

    def make_layer(self, block, out_channels, blocks, stride=1):
        """Build one stage: a (possibly strided) block followed by ``blocks - 1`` plain ones."""
        layers = []
        layers.append(block(self.in_channels, out_channels, stride))
        self.in_channels = out_channels

        for i in range(1, blocks):
            layers.append(block(out_channels, out_channels))

        return nn.Sequential(*layers)

    def forward(self, x):
        # Stem
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        # Residual stages
        x = self.layer1(x)
        x = self.layer2(x)

        # Pool to 2x2x2 per channel, then classify
        x = self.avgpool(x)
        x = self.flatten(x)
        x = self.linear3(x)
        x = self.relu(x)
        x = self.linear4(x)
        x = self.sigmoid(x)

        return x

Now we'll train the model. Conceptually it looks like this:

<a href="https://user-images.githubusercontent.com/22727759/224853655-3fad9edb-c798-452e-94d0-f74efe71c08e.mp4"><img src="https://user-images.githubusercontent.com/22727759/224853385-ed190d89-f466-469c-82a9-499881759d57.gif"/></a>

This typically takes about 10 minutes.

In [None]:
model = ResNet3D(block=ResNetBlock, layers=[1, 1, 1, 1], num_classes=1).to(DEVICE)
model_name = 'ResNet3D'
if FT:
    # Best-effort warm start: a missing or incompatible checkpoint must not
    # stop the notebook, but the reason is reported instead of being hidden.
    try:
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code - only load checkpoints from a trusted source.
        checkpoint = torch.load(CHEPOINT, map_location=DEVICE)
        # strict=False copies every entry whose key exists in the model and
        # ignores the rest, which is what the manual key-by-key merge did.
        load_result = model.load_state_dict(checkpoint, strict=False)
        print('Checkpoint loaded')
        if load_result.missing_keys:
            print('  parameters not found in checkpoint:', load_result.missing_keys)
        if load_result.unexpected_keys:
            print('  checkpoint entries not used:', load_result.unexpected_keys)
    except Exception as exc:
        print('Checkpoint not loaded: {!r}'.format(exc))

In [None]:
if IS_TRAIN:
    import gc  # local import: only needed for the per-fragment cleanup below

    torch.cuda.empty_cache()
    # Instantiate the TensorBoard SummaryWriter
    writer = SummaryWriter('result/logs')
    EPOCH = 5
    # NOTE(review): the scheduler is stepped once per batch, but T_max is
    # derived from a fixed 30000 samples per epoch rather than from the real
    # number of batches, so the cosine schedule completes several cycles
    # during training - confirm that this is intended.
    T_max = int(30000 / BATCH_SIZE * EPOCH) + 50
    min_lr = 0.000001
    print('''
    Starting training:
        Model: {}
        Epochs: {}
        Batch size: {}
        Learning rate: {}
        CUDA: {}
    '''.format(model_name,
               EPOCH,
               BATCH_SIZE,
               LEARNING_RATE,
               torch.cuda.is_available()))
    criterion = nn.BCELoss()
    optimizer = optim.AdamW(model.parameters(),
                            lr=LEARNING_RATE,
                            betas=(0.9, 0.999),
                            weight_decay=0.01
                            )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max, eta_min=min_lr)
    # Epoch counter across all fragments, used as the TensorBoard step.
    # NOTE(review): it starts at 6 and the fragment loop starts at index 1,
    # presumably to resume after the 5 epochs on fragment 1 stored in
    # CHEPOINT - reset both when training from scratch.
    global_step = 6
    # Train on the fragments one after another; each fragment contributes the
    # random pixels sampled by train_val_split, not every pixel.
    for index in range(1, len(train_data_list)):
        gc.collect()
        torch.cuda.empty_cache()
        fragment = train_data_list[index]
        # Load the data
        pixels_train_rect, pixels_val_rect = train_val_split(fragment.mask)
        # The zarr store already contains only the Z_DIM slices starting at
        # Z_START (see FragmentImageData.load_from_directory), so the whole
        # z axis is read. Slicing it with Z_START:Z_START+Z_DIM again would
        # select an empty or truncated range.
        image_stack = torch.from_numpy(np.array(fragment.surface_volume[:, :, :], dtype=np.float32) / 65535.0)
        label = torch.from_numpy(np.array(fragment.truth)).float()
        train_dataset = SubvolumeDataset(image_stack, label, pixels_train_rect, IS_TRAIN)
        eval_dataset = SubvolumeDataset(image_stack, label, pixels_val_rect, IS_TRAIN)
        # The loaders do not change between epochs; shuffle=True still
        # reshuffles every time the training loader is iterated.
        train_loader = data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        eval_loader = data.DataLoader(eval_dataset, batch_size=BATCH_SIZE, shuffle=False)
        # Validation coordinates as index tensors so that a whole batch of
        # predictions can be written into the output image at once.
        val_ys = torch.tensor([y for y, _ in pixels_val_rect], dtype=torch.long)
        val_xs = torch.tensor([x for _, x in pixels_val_rect], dtype=torch.long)

        for epoch in range(1, EPOCH + 1):
            epoch_loss = 0
            model.train()
            bar = tqdm(enumerate(train_loader), total=len(train_loader))
            for i, (subvolumes, inklabels) in bar:
                optimizer.zero_grad()
                outputs = model(subvolumes.to(DEVICE))
                loss = criterion(outputs, inklabels.to(DEVICE))
                loss.backward()
                optimizer.step()
                scheduler.step()
                mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0
                bar.set_postfix(loss=f'{loss.item():0.4f}', epoch=global_step, dataset=str(index + 1) ,gpu_mem=f'{mem:0.2f} GB')
                epoch_loss += loss.item()
            # Mean loss per batch. The same value is logged and used in the
            # checkpoint name (which used to divide by TRAINING_STEPS instead).
            mean_train_loss = epoch_loss / max(len(train_loader), 1)
            writer.add_scalar('Train/Loss', mean_train_loss, global_step)

            # Sparse images holding prediction and truth at the validation pixels.
            output = torch.zeros(fragment.truth.shape).float()
            true = torch.zeros(fragment.truth.shape).float()
            model.eval()
            with torch.no_grad():
                for i, (subvolumes, inklabels) in enumerate(tqdm(eval_loader)):
                    outputs = model(subvolumes.to(DEVICE))
                    # eval_loader is not shuffled, so batch i covers the
                    # validation pixels [start, stop).
                    start = i * BATCH_SIZE
                    stop = start + outputs.shape[0]
                    output[val_ys[start:stop], val_xs[start:stop]] = outputs.reshape(-1).cpu()
                    true[val_ys[start:stop], val_xs[start:stop]] = inklabels.reshape(-1)

                # Compute the validation scores
                dice_score = dice_coef(true.to(DEVICE), output.to(DEVICE), thr=THRESHOLD).item()
                iou_score = iou_coef(true.to(DEVICE), output.to(DEVICE), thr=THRESHOLD).item()

                # Convert the images to grid form with make_grid
                pred_mask = make_grid(output.to(DEVICE), normalize=True)
                true_mask = make_grid(true.to(DEVICE), normalize=True)
                # Add the images to TensorBoard
                writer.add_image('Valid/True_mask', true_mask, global_step=global_step, dataformats="CHW")
                writer.add_image('Valid/Pred_mask', pred_mask, global_step=global_step, dataformats="CHW")
                # Logged at the global step like every other series; using the
                # per-fragment epoch made later fragments overwrite earlier ones.
                writer.add_scalar('Val/IOU', iou_score, global_step)
                writer.add_scalar('Val/Dice', dice_score, global_step)
            torch.save(model.state_dict(), 'result/dataset-' + str(index + 1) +  '-{}-DIM-{}-[train_loss]-{:.4f}-[dice_score]-{:.2f}-[iou_score]-{:.2f}-'.format(model_name, Z_DIM, mean_train_loss, dice_score, iou_score) + str(epoch) + '-epoch.pkl')
            global_step += 1
        # Release the large tensors before the next fragment is loaded.
        del image_stack
        del label
    writer.close()

Finally, we'll generate a prediction image. We'll use the model to predict the presence of ink for each pixel in our rectangle (the val set). Conceptually it looks like this:

<a href="https://user-images.githubusercontent.com/22727759/224853653-7cffd0a4-c6fa-49a2-93c1-e3c820863a51.mp4"><img src="https://user-images.githubusercontent.com/22727759/224853379-09ae991e-02be-4ecc-a652-313165b3005c.gif"/></a>


This should take about a minute.

Remember that the model has never seen the label data within the rectangle before!

We'll plot it side-by-side with the label image. Are you able to recognize the letter "P" in it?

In [None]:
if not IS_TRAIN:
    import cv2  # local import: only the inference path writes PNG masks

    # NOTE(review): test_image_stack_list, pixels_test_rect_list and
    # shape_list are not defined in any cell visible in this notebook; they
    # have to be prepared before this cell runs.
    output_list = []
    model.eval()
    for index in range(len(test_image_stack_list)):
        test_dataset = SubvolumeDataset(test_image_stack_list[index], None, pixels_test_rect_list[index], IS_TRAIN)
        test_eval_loader = data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
        output = torch.zeros_like(shape_list[index]).float()
        with torch.no_grad():
            for i, (subvolumes) in enumerate(tqdm(test_eval_loader)):
                # The loader is not shuffled, so element j of batch i belongs
                # to pixel i*BATCH_SIZE+j of this fragment.
                for j, value in enumerate(model(subvolumes.to(DEVICE))):
                    output[pixels_test_rect_list[index][i*BATCH_SIZE+j]] = value
        output_list.append(output)
        # Save the binarized prediction as an 8-bit style mask image.
        out = output.gt(THRESHOLD).cpu().float().numpy()
        cv2.imwrite(str(index + 1) + '.png', out * 255)
    # One panel per test fragment; the second image used to be drawn on ax1
    # as well, hiding the first and leaving ax2 empty.
    fig, (ax1, ax2) = plt.subplots(1, 2)
    ax1.imshow(output_list[0].gt(THRESHOLD).cpu(), cmap='gray')
    ax2.imshow(output_list[1].gt(THRESHOLD).cpu(), cmap='gray')
    plt.show()
    

Since our output has to be binary, we have to choose a threshold. This notebook uses the `THRESHOLD` constant from the configuration cell (0.55, i.e. 55% confidence).

In [None]:
if IS_TRAIN:
    # `output` and `index` are left over from the last epoch of the training
    # cell. `label` is deleted at the end of that cell, so the ground truth
    # of the last trained fragment is read back from its zarr store instead.
    fig, (ax1, ax2) = plt.subplots(1, 2)
    ax1.set_title('Prediction (validation pixels)')
    ax1.imshow(output.gt(THRESHOLD).cpu(), cmap='gray')
    ax2.set_title('Ink labels')
    ax2.imshow(np.array(train_data_list[index].truth), cmap='gray')
    plt.show()

Finally, Kaggle expects a runlength-encoded submission.csv file, so let's output that.

In [None]:
def rle(output, threshold=None):
    """Run-length encode a predicted mask in Kaggle submission format.

    The prediction is flattened in row-major order and binarized with
    ``output > threshold``. Every run of ones is reported as
    ``"<start> <length>"`` with 1-based start positions, all pairs joined by
    single spaces.

    Args:
        output: Tensor (or array-like) of predicted probabilities.
        threshold: Binarization threshold. Defaults to ``THRESHOLD``.

    Returns:
        The encoded string; empty when no pixel exceeds the threshold.
    """
    if threshold is None:
        threshold = THRESHOLD
    flat_img = (torch.as_tensor(output).flatten() > threshold).cpu().numpy().astype(np.uint8)
    # Pad with a zero on both sides so that runs touching the first or last
    # pixel also produce a start and an end. Without the padding such runs
    # were dropped, or shifted the start/length pairing.
    padded = np.concatenate(([0], flat_img, [0]))
    # Positions where the value changes alternate between run starts and
    # one-past-run ends, already expressed as 1-based pixel numbers.
    boundaries = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    starts_ix = boundaries[0::2]
    lengths = boundaries[1::2] - starts_ix
    pairs = np.stack((starts_ix, lengths), axis=1).reshape(-1)
    return " ".join(map(str, pairs.tolist()))
# rle_output = rle(output)
# This doesn't make too much sense, but let's just output in the required format
# so notebook works as a submission. :-)
# print("Id,Predicted\na," + rle_output + "\nb," + rle_output, file=open('submission.csv', 'w'))

Hurray! We've detected ink! Now, can you do better? :-) For example, you could start with this [example submission](https://www.kaggle.com/code/danielhavir/vesuvius-challenge-example-submission).

In [None]:
# The per-fragment predictions only exist after the inference cell has run
# (IS_TRAIN = False); `outputs` is merely the last batch of model outputs
# and must not be encoded.
if not IS_TRAIN:
    rle_list = [rle(output) for output in output_list]
    # Context manager so the file is flushed and closed deterministically.
    with open('submission.csv', 'w') as submission_file:
        print("Id,Predicted\na," + rle_list[0] + "\nb," + rle_list[1], file=submission_file)