In [1]:
!pip install /kaggle/input/einops/einops-0.6.1-py3-none-any.whl

Processing /kaggle/input/einops/einops-0.6.1-py3-none-any.whl
Installing collected packages: einops
Successfully installed einops-0.6.1
[0m

In [2]:
import sys
sys.path.append("/kaggle/input/efficientnet-pytorch")
sys.path.append("/kaggle/input/pretrained-models-pytorch")
sys.path.append("/kaggle/input/smp-github/segmentation_models.pytorch-master")

import albumentations as A
import cv2
import gc
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import pytorch_lightning as pl
import segmentation_models_pytorch as smp
import timm
import timm.models.layers as layers
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import yaml

from albumentations.pytorch import ToTensorV2
from einops import rearrange
from functools import partial
from segmentation_models_pytorch.base import (
    ClassificationHead,
    SegmentationHead, 
    SegmentationModel
)
from segmentation_models_pytorch.base.initialization import initialize_decoder, initialize_head

from segmentation_models_pytorch.decoders.unet.decoder import UnetDecoder
from timm import create_model
from timm.models.layers import DropPath, trunc_normal_
from torchmetrics import AveragePrecision, Dice, FBetaScore
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io_plugins.so: undefined symbol: _ZN3tsl6StatusC1EN10tensorflow5error4CodeESt17basic_string_viewIcSt11char_traitsIcEENS_14SourceLocationE']
caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io.so: undefined symbol: _ZTVN10tensorflow13GcsFileSystemE']
  if block_type is 'proj':
  elif block_type is 'down':
  assert block_type is 'normal'


## Modules

In [3]:
def get_activation(activation):
    if activation is None:
        return nn.Identity()
    elif activation == "relu":
        return nn.ReLU(inplace=True)
    elif activation == "silu":
        return nn.SiLU(inplace=True)
    else:
        raise ValueError(f"Activation {activation} is not supported.")


class SeparableConv2d(nn.Sequential):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        bias=True,
    ):
        dephtwise_conv = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=False,
        )
        pointwise_conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            bias=bias,
        )
        super().__init__(dephtwise_conv, pointwise_conv)


class SeparableConvBnAct(nn.Sequential):
    def __init__(
        self, 
        in_channels, 
        out_channels, 
        kernel_size=3,
        stride=1,
        padding=1,
        dilation=1,
        use_batchnorm=True, 
        activation="silu"
    ):
        conv = SeparableConv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=not (use_batchnorm),
        )
        bn = nn.BatchNorm2d(out_channels) if use_batchnorm else nn.Identity()
        act = get_activation(activation)
        super(SeparableConvBnAct, self).__init__(conv, bn, act)


class ConvBnAct(nn.Sequential):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        padding=0,
        stride=1,
        use_batchnorm=True,
        activation=None
    ):
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=not (use_batchnorm),
        )
        bn = nn.BatchNorm2d(out_channels) if use_batchnorm else nn.Identity()
        act = get_activation(activation)
        super(ConvBnAct, self).__init__(conv, bn, act)


class ASPPPooling(nn.Sequential):
    def __init__(self, in_channels, out_channels, activation="silu"):
        super().__init__(
            nn.AdaptiveAvgPool2d(1),
            ConvBnAct(in_channels, out_channels, kernel_size=1, activation=activation)
        )

    def forward(self, x):
        size = x.shape[-2:]
        for module in self:
            x = module(x)
        return F.interpolate(x, size=size, mode="bilinear", align_corners=False)


class ASPP(nn.Module):
    def __init__(
        self, 
        in_channels,
        out_channels, 
        atrous_rates, 
        reduction=1,
        dropout=0.2, 
        activation="silu"
    ):
        super(ASPP, self).__init__()
        modules = []
        modules.append(
            ConvBnAct(
                in_channels, 
                out_channels // reduction, 
                kernel_size=1, 
                padding=0,
                stride=1,
                use_batchnorm=True,
                activation=activation
            )
        )
        for r in atrous_rates:
            modules.append(
                SeparableConvBnAct(
                in_channels, 
                out_channels // reduction, 
                kernel_size=3,
                stride=1,
                padding=r,
                dilation=r,
                use_batchnorm=True,
                activation=activation
            ))
        modules.append(ASPPPooling(in_channels, out_channels // reduction, activation=activation))
        self.body = nn.ModuleList(modules)
        self.project = nn.Sequential(
            ConvBnAct(
                (len(atrous_rates) + 2) * out_channels // reduction, 
                out_channels, 
                kernel_size=1, 
                padding=0, 
                stride=1,
                use_batchnorm=True,
                activation=activation
            ),
            nn.Dropout(dropout)
        )

    def forward(self, x, scale_factor=1):
        if scale_factor != 1:
            x = F.interpolate(x, scale_factor=scale_factor, mode="bilinear")
        results = []
        for module in self.body:
            results.append(module(x))
        results = torch.cat(results, dim=1)
        return self.project(results)


class SegmentationHead(nn.Sequential):
    def __init__(
        self, 
        in_channels, 
        out_channels, 
        kernel_size=3, 
        padding=1, 
        upsampling=1
    ):
        blocks = [
            ConvBnAct(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                padding=padding,
                stride=1,
                use_batchnorm=False,
                activation=None
            ),
            nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity()
        ]
        super(SegmentationHead, self).__init__(*blocks)


class SCSEModule(nn.Module):
    def __init__(self, in_channels, reduction=16):
        super().__init__()
        self.cSE = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels // reduction, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // reduction, in_channels, 1),
            nn.Sigmoid(),
        )
        self.sSE = nn.Sequential(nn.Conv2d(in_channels, 1, 1), nn.Sigmoid())

    def forward(self, x):
        return (x * self.cSE(x) + x * self.sSE(x)) / 2.
    

class Attention(nn.Module):
    def __init__(self, name, **params):
        super().__init__()
        if name is None:
            self.attention = nn.Identity(**params)
        elif name == "scse":
            self.attention = SCSEModule(**params)
        else:
            raise ValueError("Attention type {} is not implemented".format(name))

    def forward(self, x):
        return self.attention(x)
    

class SegformerMLP(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.proj = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        x = x.flatten(2).transpose(1, 2)
        x = self.proj(x)
        return x
    

class Stem3d(nn.Module):
    def __init__(self, in_channels, temporal_dim):
        super().__init__()
        self.conv1 = nn.Conv3d(
            in_channels, 
            out_channels=256, 
            kernel_size=(temporal_dim, 1, 1), 
            padding=(temporal_dim // 2, 0, 0),
            stride=(1, 1, 1)
        )
        self.bn1 = nn.BatchNorm3d(256)
        self.act1 = nn.SiLU(inplace=True)
        self.conv2 = nn.Conv3d(
            in_channels=256, 
            out_channels=256,
            kernel_size=(1, 3, 3), 
            padding=(0, 1, 1),
            stride=(1, 1, 1)
        )
        self.bn2 = nn.BatchNorm3d(256)
        self.act2 = nn.SiLU(inplace=True)
        self.conv3 = nn.Conv3d(
            in_channels=256, 
            out_channels=in_channels,
            kernel_size=(1, 3, 3), 
            padding=(0, 1, 1),
            stride=(1, 1, 1)
        )
        self.bn3 = nn.BatchNorm3d(in_channels)
        self.act3 = nn.SiLU(inplace=True)

    def forward(self, x):
        assert len(x.size()) == 5
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.act2(x)
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.act3(x)
        return x.max(dim=2)[0]

## NextViT

In [4]:
NORM_EPS = 1e-5


def merge_pre_bn(module, pre_bn_1, pre_bn_2=None):
    """ Merge pre BN to reduce inference runtime.
    """
    weight = module.weight.data
    if module.bias is None:
        zeros = torch.zeros(module.out_channels, device=weight.device).type(weight.type())
        module.bias = nn.Parameter(zeros)
    bias = module.bias.data
    if pre_bn_2 is None:
        assert pre_bn_1.track_running_stats is True, "Unsupport bn_module.track_running_stats is False"
        assert pre_bn_1.affine is True, "Unsupport bn_module.affine is False"

        scale_invstd = pre_bn_1.running_var.add(pre_bn_1.eps).pow(-0.5)
        extra_weight = scale_invstd * pre_bn_1.weight
        extra_bias = pre_bn_1.bias - pre_bn_1.weight * pre_bn_1.running_mean * scale_invstd
    else:
        assert pre_bn_1.track_running_stats is True, "Unsupport bn_module.track_running_stats is False"
        assert pre_bn_1.affine is True, "Unsupport bn_module.affine is False"

        assert pre_bn_2.track_running_stats is True, "Unsupport bn_module.track_running_stats is False"
        assert pre_bn_2.affine is True, "Unsupport bn_module.affine is False"

        scale_invstd_1 = pre_bn_1.running_var.add(pre_bn_1.eps).pow(-0.5)
        scale_invstd_2 = pre_bn_2.running_var.add(pre_bn_2.eps).pow(-0.5)

        extra_weight = scale_invstd_1 * pre_bn_1.weight * scale_invstd_2 * pre_bn_2.weight
        extra_bias = scale_invstd_2 * pre_bn_2.weight *(pre_bn_1.bias - pre_bn_1.weight * pre_bn_1.running_mean * scale_invstd_1 - pre_bn_2.running_mean) + pre_bn_2.bias


class ConvBNReLU(nn.Module):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            groups=1):
        super(ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                              padding=1, groups=groups, bias=False)
        self.norm = nn.BatchNorm2d(out_channels, eps=NORM_EPS)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        x = self.act(x)
        return x


def _make_divisible(v, divisor, min_value=None):
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class PatchEmbed(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1):
        super(PatchEmbed, self).__init__()
        norm_layer = partial(nn.BatchNorm2d, eps=NORM_EPS)
        if stride == 2:
            self.avgpool = nn.AvgPool2d((2, 2), stride=2, ceil_mode=True, count_include_pad=False)
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False)
            self.norm = norm_layer(out_channels)
        elif in_channels != out_channels:
            self.avgpool = nn.Identity()
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False)
            self.norm = norm_layer(out_channels)
        else:
            self.avgpool = nn.Identity()
            self.conv = nn.Identity()
            self.norm = nn.Identity()

    def forward(self, x):
        return self.norm(self.conv(self.avgpool(x)))


class MHCA(nn.Module):
    """
    Multi-Head Convolutional Attention
    """
    def __init__(self, out_channels, head_dim):
        super(MHCA, self).__init__()
        norm_layer = partial(nn.BatchNorm2d, eps=NORM_EPS)
        self.group_conv3x3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1,
                                       padding=1, groups=out_channels // head_dim, bias=False)
        self.norm = norm_layer(out_channels)
        self.act = nn.ReLU(inplace=True)
        self.projection = nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False)

    def forward(self, x):
        out = self.group_conv3x3(x)
        out = self.norm(out)
        out = self.act(out)
        out = self.projection(out)
        return out


class Mlp(nn.Module):
    def __init__(self, in_features, out_features=None, mlp_ratio=None, drop=0., bias=True):
        super().__init__()
        out_features = out_features or in_features
        hidden_dim = _make_divisible(in_features * mlp_ratio, 32)
        self.conv1 = nn.Conv2d(in_features, hidden_dim, kernel_size=1, bias=bias)
        self.act = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(hidden_dim, out_features, kernel_size=1, bias=bias)
        self.drop = nn.Dropout(drop)

    def merge_bn(self, pre_norm):
        merge_pre_bn(self.conv1, pre_norm)

    def forward(self, x):
        x = self.conv1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.conv2(x)
        x = self.drop(x)
        return x


class NCB(nn.Module):
    """
    Next Convolution Block
    """
    def __init__(self, in_channels, out_channels, stride=1, path_dropout=0,
                 drop=0, head_dim=32, mlp_ratio=3):
        super(NCB, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        norm_layer = partial(nn.BatchNorm2d, eps=NORM_EPS)
        assert out_channels % head_dim == 0

        self.patch_embed = PatchEmbed(in_channels, out_channels, stride)
        self.mhca = MHCA(out_channels, head_dim)
        self.attention_path_dropout = DropPath(path_dropout)

        self.norm = norm_layer(out_channels)
        self.mlp = Mlp(out_channels, mlp_ratio=mlp_ratio, drop=drop, bias=True)
        self.mlp_path_dropout = DropPath(path_dropout)
        self.is_bn_merged = False

    def merge_bn(self):
        if not self.is_bn_merged:
            self.mlp.merge_bn(self.norm)
            self.is_bn_merged = True

    def forward(self, x):
        x = self.patch_embed(x)
        x = x + self.attention_path_dropout(self.mhca(x))
        if not torch.onnx.is_in_onnx_export() and not self.is_bn_merged:
            out = self.norm(x)
        else:
            out = x
        x = x + self.mlp_path_dropout(self.mlp(out))
        return x


class E_MHSA(nn.Module):
    """
    Efficient Multi-Head Self Attention
    """
    def __init__(self, dim, out_dim=None, head_dim=32, qkv_bias=True, qk_scale=None,
                 attn_drop=0, proj_drop=0., sr_ratio=1):
        super().__init__()
        self.dim = dim
        self.out_dim = out_dim if out_dim is not None else dim
        self.num_heads = self.dim // head_dim
        self.scale = qk_scale or head_dim ** -0.5
        self.q = nn.Linear(dim, self.dim, bias=qkv_bias)
        self.k = nn.Linear(dim, self.dim, bias=qkv_bias)
        self.v = nn.Linear(dim, self.dim, bias=qkv_bias)
        self.proj = nn.Linear(self.dim, self.out_dim)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        self.N_ratio = sr_ratio ** 2
        if sr_ratio > 1:
            self.sr = nn.AvgPool1d(kernel_size=self.N_ratio, stride=self.N_ratio)
            self.norm = nn.BatchNorm1d(dim, eps=NORM_EPS)
        self.is_bn_merged = False

    def merge_bn(self, pre_bn):
        merge_pre_bn(self.q, pre_bn)
        if self.sr_ratio > 1:
            merge_pre_bn(self.k, pre_bn, self.norm)
            merge_pre_bn(self.v, pre_bn, self.norm)
        else:
            merge_pre_bn(self.k, pre_bn)
            merge_pre_bn(self.v, pre_bn)
        self.is_bn_merged = True

    def forward(self, x):
        B, N, C = x.shape
        q = self.q(x)
        q = q.reshape(B, N, self.num_heads, int(C // self.num_heads)).permute(0, 2, 1, 3)

        if self.sr_ratio > 1:
            x_ = x.transpose(1, 2)
            x_ = self.sr(x_)
            if not torch.onnx.is_in_onnx_export() and not self.is_bn_merged:
                x_ = self.norm(x_)
            x_ = x_.transpose(1, 2)
            k = self.k(x_)
            k = k.reshape(B, -1, self.num_heads, int(C // self.num_heads)).permute(0, 2, 3, 1)
            v = self.v(x_)
            v = v.reshape(B, -1, self.num_heads, int(C // self.num_heads)).permute(0, 2, 1, 3)
        else:
            k = self.k(x)
            k = k.reshape(B, -1, self.num_heads, int(C // self.num_heads)).permute(0, 2, 3, 1)
            v = self.v(x)
            v = v.reshape(B, -1, self.num_heads, int(C // self.num_heads)).permute(0, 2, 1, 3)
        attn = (q @ k) * self.scale

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class NTB(nn.Module):
    """
    Next Transformer Block
    """
    def __init__(
            self, in_channels, out_channels, path_dropout, stride=1, sr_ratio=1,
            mlp_ratio=2, head_dim=32, mix_block_ratio=0.75, attn_drop=0, drop=0,
    ):
        super(NTB, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.mix_block_ratio = mix_block_ratio
        norm_func = partial(nn.BatchNorm2d, eps=NORM_EPS)

        self.mhsa_out_channels = _make_divisible(int(out_channels * mix_block_ratio), 32)
        self.mhca_out_channels = out_channels - self.mhsa_out_channels

        self.patch_embed = PatchEmbed(in_channels, self.mhsa_out_channels, stride)
        self.norm1 = norm_func(self.mhsa_out_channels)
        self.e_mhsa = E_MHSA(self.mhsa_out_channels, head_dim=head_dim, sr_ratio=sr_ratio,
                             attn_drop=attn_drop, proj_drop=drop)
        self.mhsa_path_dropout = DropPath(path_dropout * mix_block_ratio)

        self.projection = PatchEmbed(self.mhsa_out_channels, self.mhca_out_channels, stride=1)
        self.mhca = MHCA(self.mhca_out_channels, head_dim=head_dim)
        self.mhca_path_dropout = DropPath(path_dropout * (1 - mix_block_ratio))

        self.norm2 = norm_func(out_channels)
        self.mlp = Mlp(out_channels, mlp_ratio=mlp_ratio, drop=drop)
        self.mlp_path_dropout = DropPath(path_dropout)

        self.is_bn_merged = False

    def merge_bn(self):
        if not self.is_bn_merged:
            self.e_mhsa.merge_bn(self.norm1)
            self.mlp.merge_bn(self.norm2)
            self.is_bn_merged = True

    def forward(self, x):
        x = self.patch_embed(x)
        B, C, H, W = x.shape
        if not torch.onnx.is_in_onnx_export() and not self.is_bn_merged:
            out = self.norm1(x)
        else:
            out = x
        out = rearrange(out, "b c h w -> b (h w) c")  # b n c
        out = self.mhsa_path_dropout(self.e_mhsa(out))
        x = x + rearrange(out, "b (h w) c -> b c h w", h=H)

        out = self.projection(x)
        out = out + self.mhca_path_dropout(self.mhca(out))
        x = torch.cat([x, out], dim=1)

        if not torch.onnx.is_in_onnx_export() and not self.is_bn_merged:
            out = self.norm2(x)
        else:
            out = x
        x = x + self.mlp_path_dropout(self.mlp(out))
        return x


class NextViT(nn.Module):
    def __init__(self, stem_chs, depths, path_dropout, attn_drop=0, drop=0, num_classes=1000,
                 strides=[1, 2, 2, 2], sr_ratios=[8, 4, 2, 1], head_dim=32, mix_block_ratio=0.75,
                 use_checkpoint=False):
        super(NextViT, self).__init__()
        self.use_checkpoint = use_checkpoint

        self.stage_out_channels = [[96] * (depths[0]),
                                   [192] * (depths[1] - 1) + [256],
                                   [384, 384, 384, 384, 512] * (depths[2] // 5),
                                   [768] * (depths[3] - 1) + [1024]]

        # Next Hybrid Strategy
        self.stage_block_types = [[NCB] * depths[0],
                                  [NCB] * (depths[1] - 1) + [NTB],
                                  [NCB, NCB, NCB, NCB, NTB] * (depths[2] // 5),
                                  [NCB] * (depths[3] - 1) + [NTB]]

        self.stem = nn.Sequential(
            ConvBNReLU(3, stem_chs[0], kernel_size=3, stride=2),
            ConvBNReLU(stem_chs[0], stem_chs[1], kernel_size=3, stride=1),
            ConvBNReLU(stem_chs[1], stem_chs[2], kernel_size=3, stride=1),
            ConvBNReLU(stem_chs[2], stem_chs[2], kernel_size=3, stride=2),
        )
        input_channel = stem_chs[-1]
        features = []
        idx = 0
        dpr = [x.item() for x in torch.linspace(0, path_dropout, sum(depths))]  # stochastic depth decay rule
        for stage_id in range(len(depths)):
            numrepeat = depths[stage_id]
            output_channels = self.stage_out_channels[stage_id]
            block_types = self.stage_block_types[stage_id]
            for block_id in range(numrepeat):
                if strides[stage_id] == 2 and block_id == 0:
                    stride = 2
                else:
                    stride = 1
                output_channel = output_channels[block_id]
                block_type = block_types[block_id]
                if block_type is NCB:
                    layer = NCB(input_channel, output_channel, stride=stride, path_dropout=dpr[idx + block_id],
                                drop=drop, head_dim=head_dim)
                    features.append(layer)
                elif block_type is NTB:
                    layer = NTB(input_channel, output_channel, path_dropout=dpr[idx + block_id], stride=stride,
                                sr_ratio=sr_ratios[stage_id], head_dim=head_dim, mix_block_ratio=mix_block_ratio,
                                attn_drop=attn_drop, drop=drop)
                    features.append(layer)
                input_channel = output_channel
            idx += numrepeat
        self.features = nn.Sequential(*features)

        self.norm = nn.BatchNorm2d(output_channel, eps=NORM_EPS)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.proj_head = nn.Sequential(
            nn.Linear(output_channel, num_classes),
        )

        self.stage_out_idx = [sum(depths[:idx + 1]) - 1 for idx in range(len(depths))]
        # print('initialize_weights...')
        self._initialize_weights()

    def merge_bn(self):
        self.eval()
        for idx, module in self.named_modules():
            if isinstance(module, NCB) or isinstance(module, NTB):
                module.merge_bn()

    def _initialize_weights(self):
        for n, m in self.named_modules():
            if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm, nn.BatchNorm1d)):
                nn.init.constant_(m.weight, 1.0)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Conv2d):
                trunc_normal_(m.weight, std=.02)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.stem(x)
        for idx, layer in enumerate(self.features):
            if self.use_checkpoint:
                x = checkpoint.checkpoint(layer, x)
            else:
                x = layer(x)
        x = self.norm(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.proj_head(x)
        return x

    def forward_features(self, x):
        x = self.stem(x)
        for idx, layer in enumerate(self.features):
            if self.use_checkpoint:
                x = checkpoint.checkpoint(layer, x)
            else:
                x = layer(x)
        x = self.norm(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return x


def nextvit_small(**kwargs):
    model = NextViT(stem_chs=[64, 32, 64], depths=[3, 4, 10, 3], path_dropout=0.1, **kwargs)
    return model


def nextvit_base(**kwargs):
    model = NextViT(stem_chs=[64, 32, 64], depths=[3, 4, 20, 3], path_dropout=0.2, **kwargs)
    return model


def nextvit_large(**kwargs):
    model = NextViT(stem_chs=[64, 32, 64], depths=[3, 4, 30, 3], path_dropout=0.2, **kwargs)
    return model

## Inflated Networks

In [5]:
def inflate_conv2d(conv2d, temporal_dim=1):
    kh, kw = conv2d.kernel_size
    sh, sw = conv2d.stride
    ph, pw = conv2d.padding if not isinstance(conv2d, layers.Conv2dSame) else [kh // 2, kw // 2]
    dh, dw = conv2d.dilation
    groups = conv2d.groups
    conv3d = torch.nn.Conv3d(
        conv2d.in_channels,
        conv2d.out_channels,
        kernel_size=(temporal_dim, kh, kw),
        stride=(1, sh, sw),
        padding=(temporal_dim // 2, ph, pw),
        dilation=(1, dh, dw),
        groups=groups
    )
    weight_2d = conv2d.weight.data
    weight_3d = weight_2d.unsqueeze(2).repeat(1, 1, temporal_dim, 1, 1) / temporal_dim
    conv3d.weight = nn.Parameter(weight_3d)
    conv3d.bias = conv2d.bias
    return conv3d


def inflate_bn2d(bn2d):
    bn3d = nn.BatchNorm3d(
        bn2d.num_features, 
        eps=bn2d.eps, 
        momentum=bn2d.momentum, 
        affine=bn2d.affine, 
        track_running_stats=bn2d.track_running_stats
    )
    bn3d.weight = bn2d.weight
    bn3d.bias = bn2d.bias
    return bn3d


def inflate_ln2d(ln2d):
    ln3d = nn.LayerNorm(
        ln2d.normalized_shape, 
        eps=ln2d.eps, 
        elementwise_affine=ln2d.elementwise_affine
    )
    ln3d.weight = ln2d.weight
    ln3d.bias = ln2d.bias
    return ln3d


def inflate_pool2d(pool2d, temporal_dim=1):
    if isinstance(pool2d, nn.MaxPool2d):
        pool3d = nn.MaxPool3d(
            (temporal_dim, pool2d.kernel_size, pool2d.kernel_size),
            stride=(1, pool2d.stride, pool2d.stride),
            padding=(temporal_dim // 2, pool2d.padding, pool2d.padding),
            dilation=(1, pool2d.dilation, pool2d.dilation),
            ceil_mode=pool2d.ceil_mode
        )
    elif isinstance(pool2d, nn.AvgPool2d):
        pool3d = nn.AvgPool3d(
            (temporal_dim, pool2d.kernel_size, pool2d.kernel_size),
            stride=(1, pool2d.stride, pool2d.stride),
            padding=(0, pool2d.padding, pool2d.padding)
        )
    else:
        raise ValueError(f"Layer {pool2d} is not supported.")
    return pool3d


class InflatedBatchNormAct2d(nn.Module):
    def __init__(self, block, **kwargs):
        super().__init__(**kwargs)
        self.bn1 = inflate_bn2d(block)
        self.drop = block.drop
        self.act = block.act

    def forward(self, x):
        x = self.bn1(x)
        x = self.drop(x)
        x = self.act(x)
        return x


class InflatedLayerNorm(nn.Module):
    def __init__(self, block, **kwargs):
        super().__init__(**kwargs)
        self.ln3d = inflate_ln2d(block)
        
    def forward(self, x):
        x = x.permute(0, 2, 3, 4, 1)
        weight, bias = self.ln3d.weight, self.ln3d.bias
        if torch.is_autocast_enabled():
            dt = torch.get_autocast_gpu_dtype()
            x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None
        with torch.cuda.amp.autocast(enabled=False):
            x = F.layer_norm(x, self.ln3d.normalized_shape, weight, bias, self.ln3d.eps)
        x = x.permute(0, 4, 1, 2, 3)
        return x
        

class InflatedConvBnAct(nn.Module):
    def __init__(self, block, temporal_dim=1, **kwargs):
        super().__init__(**kwargs)
        self.conv = inflate_conv2d(block.conv, temporal_dim)
        self.bn1 = InflatedBatchNormAct2d(block.bn1)
        self.drop_path = block.drop_path
        self.has_skip = block.has_skip
    
    def forward(self, x):
        shortcut = x
        x = self.conv(x)
        x = self.bn1(x)
        if self.has_skip:
            x = self.drop_path(x) + shortcut
        return x
    

class InflatedSqueezeExcite(nn.Module):
    def __init__(self, block, temporal_dim=1, **kwargs):
        super().__init__(**kwargs)
        self.conv_reduce = inflate_conv2d(block.conv_reduce, temporal_dim)
        self.act1 = block.act1
        self.conv_expand = inflate_conv2d(block.conv_expand, temporal_dim)
        self.gate = block.gate

    def forward(self, x):
        x_se = x.mean((2, 3), keepdim=True)
        x_se = self.conv_reduce(x_se)
        x_se = self.act1(x_se)
        x_se = self.conv_expand(x_se)
        return x * self.gate(x_se)

    
class InflatedEdgeResidual(nn.Module):
    def __init__(self, block, temporal_dim=1, **kwargs):
        super().__init__(**kwargs)
        self.conv_exp = inflate_conv2d(block.conv_exp, temporal_dim)
        self.bn1 = InflatedBatchNormAct2d(block.bn1)
        self.se = block.se
        self.conv_pwl = inflate_conv2d(block.conv_pwl, temporal_dim)
        self.bn2 = InflatedBatchNormAct2d(block.bn2)
        self.drop_path = block.drop_path
        self.has_skip = block.has_skip

    def forward(self, x):
        shortcut = x
        x = self.conv_exp(x)
        x = self.bn1(x)
        x = self.se(x)
        x = self.conv_pwl(x)
        x = self.bn2(x)
        if self.has_skip:
            x = self.drop_path(x) + shortcut
        return x

        
class InflatedInvertedResidual(nn.Module):
    def __init__(self, block, temporal_dim=1, **kwargs):
        super().__init__(**kwargs)
        self.conv_pw = inflate_conv2d(block.conv_pw, temporal_dim) # Only inflate the first point-wise convolution
        self.bn1 = InflatedBatchNormAct2d(block.bn1)
        self.conv_dw = inflate_conv2d(block.conv_dw, 1)
        self.bn2 = InflatedBatchNormAct2d(block.bn2)
        self.se = InflatedSqueezeExcite(block.se, 1)
        self.conv_pwl = inflate_conv2d(block.conv_pwl, 1)
        self.bn3 = InflatedBatchNormAct2d(block.bn3)
        self.drop_path = block.drop_path
        self.has_skip = block.has_skip

    def forward(self, x):
        shortcut = x
        x = self.conv_pw(x)
        x = self.bn1(x)
        x = self.conv_dw(x)
        x = self.bn2(x)
        x = self.se(x)
        x = self.conv_pwl(x)
        x = self.bn3(x)
        if self.has_skip:
            x = self.drop_path(x) + shortcut
        return x


class InflatedEfficientNet(nn.Module):
    def __init__(
        self, 
        name, 
        backbone_params, 
        temporal_dim=1,
        block_idx=[4, 5],
        **kwargs
    ):
        super().__init__(**kwargs)
        encoder = timm.create_model(name, **backbone_params)
        self.temporal_dim = temporal_dim
        self.block_idx = block_idx
        self.conv_stem = inflate_conv2d(encoder.conv_stem, 1)
        self.bn1 = InflatedBatchNormAct2d(encoder.bn1)
        blocks = []
        blocks.append(
            nn.Sequential(*[
                InflatedConvBnAct(cba, 1) for cba in encoder.blocks[0]
            ])
        )
        blocks.append(
            nn.Sequential(*[
                InflatedEdgeResidual(er, 1) for er in encoder.blocks[1]
            ])
        )
        blocks.append(
            nn.Sequential(*[
                InflatedEdgeResidual(er, 1) for er in encoder.blocks[2]
            ])
        )
        for i in range(3, len(encoder.blocks)):
            blocks.append(
                nn.Sequential(*[
                    InflatedInvertedResidual(ir, temporal_dim if i in block_idx else 1) for ir in encoder.blocks[i]
                ])
            )
        self.blocks = nn.Sequential(*blocks)
        
    def forward(self, x):
        x = self.conv_stem(x)
        x = self.bn1(x)
        x = self.blocks(x)
        return x

## Encoder

In [6]:
def create_encoder(encoder_params):
    module = getattr(sys.modules[__name__], encoder_params["class"])
    name = encoder_params["encoder_name"]
    return module(name=name, **encoder_params["params"])


class BaseEncoder(nn.Module):
    def __init__(self, out_channels, **kwargs):
        super().__init__()
        self.out_channels = out_channels

    def get_stages(self):
        return [nn.Identity()]

    def forward(self, x):
        stages = self.get_stages()
        features = []
        for stage in stages:
            x = stage(x)
            features.append(x)
        return features


class ConvNeXtEncoder2d(BaseEncoder):
    def __init__(
        self, 
        name,
        stage_idx,
        backbone_params={},
        **kwargs
    ):
        super().__init__(**kwargs)
        self.encoder = create_model(name, **backbone_params)
        assert len(stage_idx) <= len(self.encoder.stages)
        self.stage_idx = stage_idx
        self.depth = len(stage_idx) + 2

    def get_stages(self):
        return [nn.Identity(), self.encoder.stem] + \
            [self.encoder.stages[i : j] for i, j in zip([0] + self.stage_idx, self.stage_idx + [len(self.encoder.stages)])]


class EfficientNetEncoder2d(BaseEncoder):
    def __init__(
        self, 
        name,
        stage_idx,
        backbone_params={},
        **kwargs
    ):
        super().__init__(**kwargs)
        self.encoder = create_model(name, **backbone_params)
        assert len(stage_idx) <= len(self.encoder.blocks)
        self.stage_idx = stage_idx
        self.depth = len(stage_idx) + 2

    def get_stages(self):
        return [nn.Identity(), nn.Sequential(self.encoder.conv_stem, self.encoder.bn1)] + \
            [self.encoder.blocks[i : j] for i, j in zip([0] + self.stage_idx, self.stage_idx + [len(self.encoder.blocks)])]
    
    
class EfficientNetEncoder3d(BaseEncoder):
    def __init__(
        self, 
        name,
        stage_idx,
        backbone_params={},
        temporal_dim=1,
        block_idx=[5, 6],
        slice_idx=0,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.encoder = InflatedEfficientNet(name, backbone_params, temporal_dim, block_idx)
        assert slice_idx < temporal_dim
        assert len(stage_idx) <= len(self.encoder.blocks)
        self.stage_idx = stage_idx
        self.depth = len(stage_idx) + 2
        self.slice_idx = slice_idx

    def get_stages(self):
        return [nn.Identity(), nn.Sequential(self.encoder.conv_stem, self.encoder.bn1)] + \
            [self.encoder.blocks[i : j] for i, j in zip([0] + self.stage_idx, self.stage_idx + [len(self.encoder.blocks)])]
    
    def forward(self, x):
        stages = self.get_stages()
        features = []
        for stage in stages:
            x = stage(x)
            features.append(x[:, :, self.slice_idx, :, :])
        return features


class MaxxViTEncoder2d(BaseEncoder):
    def __init__(
        self, 
        name,
        backbone_params={},
        **kwargs
    ):
        super().__init__(**kwargs)
        self.encoder = create_model(name, **backbone_params)
        self.depth = 5

    def get_stages(self):
        return [
            nn.Identity(),
            self.encoder.stem,
            self.encoder.stages[0],
            self.encoder.stages[1],
            self.encoder.stages[2],
            self.encoder.stages[3]
        ]


class NextViTEncoder2d(BaseEncoder):
    def __init__(
        self, 
        name, 
        stage_idx,
        create_fn,
        backbone_params={}, 
        **kwargs
    ):
        super().__init__(**kwargs)
        backbone_params = {} if backbone_params is None else backbone_params # Null params fix
        self.encoder = getattr(sys.modules[__name__], create_fn)(**backbone_params)
        self.stage_idx = stage_idx

    def forward(self, x):
        features = []
        stage_id = 0
        features.append(x)
        x = self.encoder.stem(x)
        for i, layer in enumerate(self.encoder.features):
            x = layer(x)
            if i == self.stage_idx[stage_id]:
                if i == 0:
                    features.append(nn.functional.interpolate(x, scale_factor=2., mode="bilinear"))
                else:
                    features.append(x)
                stage_id += 1
        return features

## Decoder

In [7]:
class DecoderBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        skip_channels,
        out_channels,
        block_depth=1,
        separable=False,
        use_aspp=False,
        use_batchnorm=True,
        attention_type=None,
        activation="relu"
    ):
        super().__init__()
        self.attention = nn.ModuleList([
            Attention(attention_type, in_channels=in_channels + skip_channels),
            Attention(attention_type, in_channels=out_channels)
        ])
        self.aspp = ASPP(
            in_channels,
            in_channels,
            atrous_rates=[1, 2, 4],
            reduction=2,
            dropout=0.2,
            activation=activation
        ) if use_aspp else nn.Identity()
        module = SeparableConvBnAct if separable else ConvBnAct
        self.stem = module(
            in_channels + skip_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
            activation=activation
        )
        self.body = nn.Sequential(*[
            module(
                out_channels, 
                out_channels, 
                kernel_size=3, 
                padding=1, 
                use_batchnorm=use_batchnorm,
                activation=activation
            ) for _ in range(block_depth)
         ])

    def forward(self, x, skip=None, scale_factor=1):
        if scale_factor != 1:
            x = F.interpolate(x, scale_factor=scale_factor, mode="bilinear")
        x = self.aspp(x)
        if skip is not None:
            x = torch.cat([x, skip], dim=1)
            x = self.attention[0](x)
        x = self.stem(x)
        x = self.body(x)
        x = self.attention[1](x)
        return x


class UnetDecoder(nn.Module):
    def __init__(
        self,
        encoder_channels,
        decoder_channels,
        scale_factors,
        num_blocks=5,
        block_depth=1,
        separable=False,
        use_aspp=False,
        use_batchnorm=True,
        attention_type=None,
        activation="relu"
    ):
        super().__init__()
        assert num_blocks >= len(encoder_channels) - 1
        assert num_blocks == len(decoder_channels)
        assert num_blocks == len(scale_factors)
        self.scale_factors = scale_factors
        encoder_channels = encoder_channels[1:][::-1]
        in_channels = [encoder_channels[0]] + list(decoder_channels[:-1])
        skip_channels = list(encoder_channels[1:])
        skip_channels += [0] * (len(in_channels) - len(skip_channels))
        out_channels = decoder_channels
        aspp_idx = len(in_channels) - 2
        blocks = []
        for i, (i_ch, s_ch, o_ch) in enumerate(zip(in_channels, skip_channels, out_channels)):
            blocks.append(
                DecoderBlock(
                    i_ch, 
                    s_ch, 
                    o_ch, 
                    block_depth,
                    separable=separable,
                    use_aspp=use_aspp if i == aspp_idx else False,
                    use_batchnorm=use_batchnorm, 
                    attention_type=attention_type,
                    activation=activation
                )
            )
        self.blocks = nn.ModuleList(blocks)

    def forward(self, *features):
        features = features[1:][::-1]
        x = features[0]
        skips = features[1:]
        for i, (block, scale_factor) in enumerate(zip(self.blocks, self.scale_factors)):
            skip = skips[i] if i < len(skips) else None
            x = block(x, skip, scale_factor)
        return x

## Model

In [8]:
def create_segmentation_model(config):
    config_ = config.copy()
    family = config_.pop("family")
    if family == "unet":
        return Unet(**config_)
    else:
        raise ValueError(f"Model family {family} is not supported.")
        
        
def create_classification_model(config):
    return ContrailsClassifier(**config)
        

def load_model(module, config, checkpoint_path):
    model = module(config)
    model.load_state_dict(torch.load(checkpoint_path)["state_dict"])
    return model
    

class Unet(nn.Module):
    def __init__(
        self, 
        encoder_params,
        decoder_params, 
        num_classes=1
    ):
        super().__init__()
        self.encoder = create_encoder(encoder_params)
        self.decoder = UnetDecoder(self.encoder.out_channels, **decoder_params)
        self.seg_head = SegmentationHead(
            decoder_params["decoder_channels"][-1], 
            num_classes, 
            kernel_size=3,
            padding=1, 
            upsampling=1
        )
        initialize_decoder(self.decoder)
        initialize_head(self.seg_head)

    def forward(self, x):
        x = x.squeeze(dim=2)
        features = self.encoder(x)
        decoder_output = self.decoder(*features)
        logits = self.seg_head(decoder_output)
        return logits
    
    
class ContrailsClassifier(nn.Module):
    def __init__(
        self, 
        encoder_name, 
        representation_dim, 
        dropout=0.2, 
        backbone_params={}, 
        **kwargs
    ):
        super().__init__(**kwargs)
        backbone_params = {} if backbone_params is None else backbone_params
        self.model = create_model(encoder_name, num_classes=0, **backbone_params)
        self.drop = nn.Dropout(dropout) if dropout > 0. else nn.Identity()
        self.clf_head = nn.Linear(self.model.num_features, 1)
        self.rep_head = nn.Linear(self.model.num_features, representation_dim)
    
    def forward(self, x):
        if len(x.size()) == 5:
            x = x[:, :, 0]
        x = self.model(x)
        x = self.drop(x)
        clf_logits = self.clf_head(x)[:, 0]
        rep_logits = self.rep_head(x)
        return clf_logits, rep_logits
    
    
class SegmentationModule2d(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model = create_segmentation_model(config["model"])
        
        
class ClassificationModule(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model = create_classification_model(config["model"])

## Dataset

In [9]:
DATA_DIR = "/kaggle/input/google-research-identify-contrails-reduce-global-warming"
N_TIMES_BEFORE = 4
_T11_BOUNDS = (243, 303)
_CLOUD_TOP_TDIFF_BOUNDS = (-4, 5)
_TDIFF_BOUNDS = (-4, 2)
_CLOUD_BOUNDS = (323, 203)


def normalize_range(data, bounds):
    return (data - bounds[0]) / (bounds[1] - bounds[0])


def load_metadata(split="train"):
    with open(os.path.join(DATA_DIR, f"{split}_metadata.json"), "r") as f:
        meta = json.load(f)
    for record in meta:
        record["split"] = split
        record["record_path"] = os.path.join(DATA_DIR, split, record["record_id"])
    return meta


def data_split(path):
    if os.path.exists(path):
        df = pd.read_csv(path)
    else:
        df_train = pd.DataFrame(load_metadata("train"))
        df_valid = pd.DataFrame(load_metadata("validation"))
        kf = KFold(shuffle=True, n_splits=FOLDS)
        for n, (_, valid_idx) in enumerate(kf.split(df_train)):
            df_train.loc[valid_idx, "fold"] = int(n)
        df_train["fold"] = df_train["fold"].astype(int)
        df_valid["fold"] = -1
        df = pd.concat([df_train, df_valid], axis=0).reset_index(drop=True)
        df = dedup_records(df)
        df.to_csv(path, index=None)
    return df


def dedup_records(df):
    rows = []
    for _, group in df.groupby(["row_min", "row_size", "col_min", "col_size"]):
        prev_timestamp = None
        for _, row in group.sort_values(by="timestamp", ascending=True).iterrows():
            if prev_timestamp is not None and (row.timestamp - prev_timestamp) / 3600 < 1.:
                pass
            else:
                rows.append(row)
            prev_timestamp = row.timestamp
    return pd.DataFrame(rows).reset_index(drop=True)


def ash_color(bands):
    r = normalize_range(bands[2] - bands[1], _TDIFF_BOUNDS)
    g = normalize_range(bands[2] - bands[0], _CLOUD_TOP_TDIFF_BOUNDS)
    b = normalize_range(bands[1], _T11_BOUNDS)
    false_color = np.clip(np.stack([r, g, b], axis=2), 0, 1)
    return false_color


def normalize_temperature_signal(t):
    t_smooth = cv2.GaussianBlur(t, (5, 5), 0)
    t_signal = np.clip(t - t_smooth, 0, 1)
    t_std = np.sqrt(cv2.GaussianBlur(t_signal ** 2, (5, 5), 0))
    t_norm = t_signal / (t_std + 0.1)
    return t_norm


def load_band(record_path, band):
    return np.load(os.path.join(record_path, f"band_{band:02d}.npy"))


def load_mask(record_path):
    mask = np.load(os.path.join(record_path, "human_pixel_masks.npy"))[..., 0]
    return mask


def load_frames(record_path, timesteps, add_temp_diff=False):
    bands = [load_band(record_path, band) for band in [11, 14, 15]]
    frames = ash_color(bands)
    frames = np.transpose(frames[..., timesteps], (3, 0, 1, 2))
    if add_temp_diff:
        temp_diff = []
        for t_diff in frames[..., 0]:
            temp_diff.append(normalize_temperature_signal(t_diff))
        temp_diff = np.stack(temp_diff, axis=0)
        frames = np.concatenate([frames, temp_diff[..., None]], axis=-1)
    return frames


def load_record(record_path, timesteps=[N_TIMES_BEFORE], add_temp_diff=False):
    frames = load_frames(record_path, timesteps, add_temp_diff)
    mask = load_mask(record_path)
    return frames, mask


def get_transform(split="train"):
    if split == "train":
        augments = [
            A.Flip(p=0.5),
            A.RandomRotate90(p=0.5),
            A.RandomResizedCrop(height=256, width=256, scale=(0.8, 1.2), ratio=(0.8, 1.2), interpolation=cv2.INTER_LINEAR, p=0.8),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=1.0, always_apply=True)
        ]
    else:
        augments = [
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=1.0, always_apply=True)
        ]
    return A.ReplayCompose(augments)
    
    
def rle_encode(x, fg_val=1):
    dots = np.where(
        x.T.flatten() == fg_val)[0]  # .T sets Fortran order down-then-right
    run_lengths = []
    prev = -2
    for b in dots:
        if b > prev + 1:
            run_lengths.extend((b + 1, 0))
        run_lengths[-1] += 1
        prev = b
    return run_lengths


def list_to_string(x):
    if x: # non-empty list
        s = str(x).replace("[", "").replace("]", "").replace(",", "")
    else:
        s = '-'
    return s
    

class ContrailsInferenceDataset(Dataset):
    def __init__(
        self,
        df, 
        timesteps=[4],
        image_size=512
    ):
        self.df = df
        self.timesteps = timesteps
        self.image_size = image_size
        self.split = "validation"
        self.transform = get_transform(self.split)

    def __len__(self):
        return len(self.df)

    def resize_to_tensor(self, image):
        image = A.Compose([
            A.Resize(self.image_size, self.image_size, interpolation=cv2.INTER_LINEAR, always_apply=True),
            ToTensorV2(always_apply=True)
        ])(image=image)["image"]
        return image
    
    def apply_transform(self, frames):
        tf = self.transform(image=frames[0])
        image = tf["image"]
        processed = [self.resize_to_tensor(image)]
        if len(frames) > 1:
            for frame in frames[1:]:
                frame = A.ReplayCompose.replay(tf["replay"], image=frame)["image"]
                processed.append(self.resize_to_tensor(frame))
        frames = torch.stack(processed, dim=1)
        return frames

    def __getitem__(self, i):
        row = self.df.iloc[i]
        frames = load_frames(row["record_path"], self.timesteps, False)
        frames = self.apply_transform(frames)
        sample = {
            "frames": frames
        }
        return sample
    
    
class ContrailsDebugDataset(Dataset):
    def __init__(
        self,
        df, 
        timesteps=[4],
        image_size=512
    ):
        self.df = df
        self.timesteps = timesteps
        self.image_size = image_size
        self.split = "validation"
        self.transform = get_transform(self.split)

    def __len__(self):
        return len(self.df)
    
    def resize_to_tensor(self, image):
        image = A.Compose([
            A.Resize(self.image_size, self.image_size, interpolation=cv2.INTER_LINEAR, always_apply=True),
            ToTensorV2(always_apply=True)
        ])(image=image)["image"]
        return image
    
    def apply_transform(self, frames, mask):
        tf = self.transform(image=frames[0], mask=mask)
        image, mask = tf["image"], tf["mask"]
        processed = [self.resize_to_tensor(image)]
        if len(frames) > 1:
            for frame in frames[1:]:
                frame = A.ReplayCompose.replay(tf["replay"], image=frame)["image"]
                processed.append(self.resize_to_tensor(frame))
        frames = torch.stack(processed, dim=1)
        return frames, mask

    def __getitem__(self, i):
        row = self.df.iloc[i]
        frames, mask = load_record(row["record_path"], self.timesteps, False)
        frames, mask = self.apply_transform(frames, mask)
        sample = {
            "frames": frames,
            "mask": mask[None, :],
            "label": float(mask.sum() > 0)
        }
        return sample

## Ensemble

In [10]:
# TTA
def flip_batch(x, i):
    if i == 0:
        return x
    elif i == 1:
        return x.flip(3)
    elif i == 2:
        return x.flip(4)
    else:
        return x.flip(3).flip(4)

In [11]:
# Ensemble config
clf_models = {
    "clf_efficientnetv2_s": {
        "weight": 1.0,
        "config_path": "/kaggle/input/contrails-configs/clf_efficientnetv2_s_512.yaml",
        "checkpoint_paths": [
            "/kaggle/input/contrails-models-final/tf_efficientnetv2_s.in21k_ft_in1k_512_s0.ckpt",
            "/kaggle/input/contrails-models-final/tf_efficientnetv2_s.in21k_ft_in1k_512_s1.ckpt"
        ]
    }
}

seg_models = {
    "seg_convnextv2_base": {
        "weight": 16.0,
        "use_tta": False,
        "config_path": "/kaggle/input/contrails-configs/seg_convnextv2_base_1024.yaml",
        "checkpoint_paths": [
            "/kaggle/input/contrails-models-final/convnextv2_base.fcmae_ft_in22k_in1k_384_1024_s777_fold_0.ckpt",
        ]
    },
    "seg_convnextv2_large": {
        "weight": 2.0,
        "use_tta": False,
        "config_path": "/kaggle/input/contrails-configs/2d_convnextv2_large_512_s1997_small.yaml",
        "checkpoint_paths": [
            "/kaggle/input/contrails-models-final/finetuning__family_2d__backbone_convnextv2_large.fcmae_ft_in22k_in1k_384__fold_0.ckpt",
            "/kaggle/input/contrails-models-final/finetuning__family_2d__backbone_convnextv2_large.fcmae_ft_in22k_in1k_384__fold_1.ckpt",
            "/kaggle/input/contrails-models-final/finetuning__family_2d__backbone_convnextv2_large.fcmae_ft_in22k_in1k_384__fold_2.ckpt",
            "/kaggle/input/contrails-models-final/finetuning__family_2d__backbone_convnextv2_large.fcmae_ft_in22k_in1k_384__fold_3.ckpt",
#             "/kaggle/input/contrails-models-final/finetuning__family_2d__backbone_convnextv2_large.fcmae_ft_in22k_in1k_384__fold_4.ckpt",
        ]
    },
    "seg_maxxvitv2_large": {
        "weight": 1.0,
        "use_tta": False,
        "config_path": "/kaggle/input/contrails-configs/2d_maxxvitv2_base_384_s1997_small.yaml",
        "checkpoint_paths": [
            "/kaggle/input/contrails-models-final/finetuning__family_2d__backbone_maxxvitv2_rmlp_base_rw_384.sw_in12k_ft_in1k__fold_0.ckpt",
            "/kaggle/input/contrails-models-final/finetuning__family_2d__backbone_maxxvitv2_rmlp_base_rw_384.sw_in12k_ft_in1k__fold_1.ckpt",
            "/kaggle/input/contrails-models-final/finetuning__family_2d__backbone_maxxvitv2_rmlp_base_rw_384.sw_in12k_ft_in1k__fold_2.ckpt",
            "/kaggle/input/contrails-models-final/finetuning__family_2d__backbone_maxxvitv2_rmlp_base_rw_384.sw_in12k_ft_in1k__fold_3.ckpt",
#             "/kaggle/input/contrails-models-final/finetuning__family_2d__backbone_maxxvitv2_rmlp_base_rw_384.sw_in12k_ft_in1k__fold_4.ckpt",
        ]
    },
    "seg_nextvit_large": {
        "weight": 1.0,
        "use_tta": False,
        "config_path": "/kaggle/input/contrails-configs/2d_nextvit_large_512_s1997_small.yaml",
        "checkpoint_paths": [
            "/kaggle/input/contrails-models-final/finetuning__family_2d__backbone_upernet_160k_nextvit_large_1n1k6m_pretrained__fold_0.ckpt",
            "/kaggle/input/contrails-models-final/finetuning__family_2d__backbone_upernet_160k_nextvit_large_1n1k6m_pretrained__fold_1.ckpt",
            "/kaggle/input/contrails-models-final/finetuning__family_2d__backbone_upernet_160k_nextvit_large_1n1k6m_pretrained__fold_2.ckpt",
            "/kaggle/input/contrails-models-final/finetuning__family_2d__backbone_upernet_160k_nextvit_large_1n1k6m_pretrained__fold_3.ckpt",
#             "/kaggle/input/contrails-models-final/finetuning__family_2d__backbone_upernet_160k_nextvit_large_1n1k6m_pretrained__fold_4.ckpt",
        ]
    },
    "seg_efficientnetv2_l": {
        "weight": 1.0,
        "use_tta": False,
        "config_path": "/kaggle/input/contrails-configs/tf_efficientnetv2_l_512.yaml",
        "checkpoint_paths": [
            "/kaggle/input/contrails-models-final/finetuning__family_2d__backbone_tf_efficientnetv2_l.in21k_ft_in1k__fold_0.ckpt",
            "/kaggle/input/contrails-models-final/finetuning__family_2d__backbone_tf_efficientnetv2_l.in21k_ft_in1k__fold_1.ckpt",
            "/kaggle/input/contrails-models-final/finetuning__family_2d__backbone_tf_efficientnetv2_l.in21k_ft_in1k__fold_2.ckpt",
            "/kaggle/input/contrails-models-final/finetuning__family_2d__backbone_tf_efficientnetv2_l.in21k_ft_in1k__fold_3.ckpt",
#             "/kaggle/input/contrails-models-final/finetuning__family_2d__backbone_tf_efficientnetv2_l.in21k_ft_in1k__fold_4.ckpt",
        ]
    },
    "seg_efficientnetv2_m_3d": {
        "weight": 1.0,
        "use_tta": False,
        "config_path": "/kaggle/input/contrails-configs/3d_tf_efficientnetv2_m.yaml",
        "checkpoint_paths": [
            "/kaggle/input/contrails-models-final/finetuning__family_3d__backbone_tf_efficientnetv2_m.in21k_ft_in1k__fold_0.ckpt",
            "/kaggle/input/contrails-models-final/finetuning__family_3d__backbone_tf_efficientnetv2_m.in21k_ft_in1k__fold_1.ckpt",
            "/kaggle/input/contrails-models-final/finetuning__family_3d__backbone_tf_efficientnetv2_m.in21k_ft_in1k__fold_2.ckpt",
            "/kaggle/input/contrails-models-final/finetuning__family_3d__backbone_tf_efficientnetv2_m.in21k_ft_in1k__fold_3.ckpt",
#             "/kaggle/input/contrails-models-final/finetuning__family_3d__backbone_tf_efficientnetv2_m.in21k_ft_in1k__fold_4.ckpt",
        ]
    }
}

In [12]:
# # Debugging
# device = "cuda"
# batch_size = 16
# num_workers = 2
# bins = 100
# stride = 64
# thresholds = [i / bins for i in range(1, bins)]
# total_clf_weight = sum([clf_models[k]["weight"] * len(clf_models[k]["checkpoint_paths"]) for k in clf_models])
# total_seg_weight = sum([seg_models[k]["weight"] * len(seg_models[k]["checkpoint_paths"]) for k in seg_models])

# split = "validation"
# records = os.listdir(os.path.join(DATA_DIR, split))
# df = pd.DataFrame(records, columns=["record_id"])
# df["record_path"] = df["record_id"].apply(lambda r: os.path.join(DATA_DIR, split, str(r)))
# print(f"Images: {df.shape[0]}")

# # Classification filtering
# predictions = {}
# labels = {}
# for name, model in clf_models.items():
#     weight = model["weight"]
#     config_path = model["config_path"]
#     checkpoint_paths = model["checkpoint_paths"]

#     print(f"Loading classification model `{name}` with configuration: {config_path}")
#     with open(config_path, "rb") as f:
#         config = yaml.load(f, Loader=yaml.FullLoader)
#     try:
#         config["model"]["model"]["backbone_params"]["pretrained"] = False
#     except:
#         pass

#     for checkpoint_path in checkpoint_paths:
#         print(f"Checkpoint: {checkpoint_path}")
#         model = load_model(ClassificationModule, config["model"], checkpoint_path)
#         model.to(device)
#         model.eval()

#         dataset = ContrailsDebugDataset(
#             df, 
#             timesteps=[i for i in range(8)], 
#             image_size=config["model"]["data"]["image_size"]
#         )
#         dataloader = DataLoader(
#             dataset, 
#             shuffle=False, 
#             batch_size=batch_size // 8,
#             num_workers=num_workers
#         )

#         fbeta = [FBetaScore(task="binary", beta=1., threshold=t).to(device) for t in thresholds]

#         with torch.no_grad():
#             for i, batch in tqdm(enumerate(dataloader)):
#                 x = batch["frames"].to(device)
#                 y = batch["label"].to(device)
#                 n, c, t, h, w = x.shape
#                 x = torch.permute(x, (0, 2, 1, 3, 4)).contiguous()
#                 x = x.view(-1, c, h, w)
#                 logits, _ = model.model(x)
#                 probs = torch.sigmoid(logits).view(n, t)
#                 for j, (pp, yy) in enumerate(zip(probs, y)):
#                     record_id = df.iloc[i * batch_size // 8 + j]["record_id"]
#                     predictions.setdefault(record_id, torch.zeros(pp.shape, dtype=torch.float32).to(device))
#                     predictions[record_id] += pp * weight / total_clf_weight
#                     labels[record_id] = yy
#                 for fb in fbeta:
#                     fb.update(probs[:, 4], y)

#         scores = [fb.compute().cpu().numpy() for fb in fbeta]
#         idx = np.argmax(scores)
#         print(f"F1 coefficient (t = {thresholds[idx]}): {scores[idx]:.04f}")

#         del model, dataset, dataloader, fbeta
#         torch.cuda.empty_cache()
#         gc.collect()

# ensemble_fbeta = [FBetaScore(task="binary", beta=1., threshold=t).to(device) for t in thresholds]
# for record_id in predictions:
#     probs, y = predictions[record_id], labels[record_id]
#     for ef in ensemble_fbeta:
#         ef.update(probs[None, 4], y[None])    

# scores = [ef.compute().cpu().numpy() for ef in ensemble_fbeta]
# idx = np.argmax(scores)
# threshold, score = thresholds[idx], scores[idx]
# negative_records = set([record_id for record_id in predictions if (predictions[record_id] > threshold).sum() < 2])
# print(f"Ensemble F1 coefficient (t = {threshold}): {score:.04f}")
# print(f"Negative records: {len(negative_records)}")

# # Segmentation predictions
# predictions = {}
# labels = {}
# for name, model in seg_models.items():
#     weight = model["weight"]
#     use_tta = model["use_tta"]
#     config_path = model["config_path"]
#     checkpoint_paths = model["checkpoint_paths"]

#     print(f"Loading segmentation model `{name}` with configuration: {config_path}")
#     with open(config_path, "rb") as f:
#         config = yaml.load(f, Loader=yaml.FullLoader)
#     try:
#         config["model"]["model"]["encoder_params"]["params"]["backbone_params"]["pretrained"] = False
#     except:
#         pass

#     for checkpoint_path in checkpoint_paths:
#         print(f"Checkpoint: {checkpoint_path}")
#         model = load_model(SegmentationModule2d, config["model"], checkpoint_path)
#         model.to(device)
#         model.eval()

#         image_size = config["model"]["data"]["image_size"]
#         timesteps = config["model"]["data"]["timesteps"]
#         resize_i = 3 * image_size // 4
#         resize_m = 192
#         stride_i = image_size - resize_i
#         stride_m = 256 - resize_m
#         strided_steps = (image_size - resize_i) // stride_i + 1
        
#         dataset = ContrailsDebugDataset(
#             df, 
#             timesteps=timesteps, 
#             image_size=image_size
#         )
#         dataloader = DataLoader(
#             dataset, 
#             shuffle=False, 
#             batch_size=batch_size,
#             num_workers=num_workers
#         )

#         global_dice = [Dice(threshold=t).to(device) for t in thresholds]
#         auprc = AveragePrecision(task="binary").to(device)

#         with torch.no_grad():
#             for i, batch in tqdm(enumerate(dataloader)):
#                 x = batch["frames"].to(device)
#                 y = batch["mask"].to(device)
                
#                 # Full resolution prediction
#                 logits = model.model(x)
#                 if logits.shape[-1] != 256:
#                     logits = F.interpolate(logits, size=256, mode="bilinear")
#                 probs = torch.sigmoid(logits)
#                 probs = probs.view(probs.size(0), 1, 256, 256)
                
# #                 # Strided crops
# #                 probs *= float(strided_steps)
# #                 count = torch.ones(probs.shape, dtype=torch.int16).to(device) * strided_steps
# #                 for j in range(strided_steps):
# #                     for k in range(strided_steps):
# #                         x_ = x[..., j * stride_i : j * stride_i + resize_i, k * stride_i : k * stride_i + resize_i]
# #                         x_ = torch.stack([
# #                             F.interpolate(x_[:, :, i], size=image_size, mode="bilinear") for i in range(len(timesteps))
# #                         ], dim=2)
# #                         logits_ = model.model(x_)
# #                         logits_ = F.interpolate(logits_, size=resize_m, mode="bilinear")
# #                         probs_ = torch.sigmoid(logits_)
# #                         probs_ = probs_.view(probs_.size(0), 1, resize_m, resize_m)
# #                         probs[..., j * stride_m : j * stride_m + resize_m, k * stride_m : k * stride_m + resize_m] += probs_
# #                         count[..., j * stride_m : j * stride_m + resize_m, k * stride_m : k * stride_m + resize_m] += 1
# #                 probs /= count

#                 # TTA
#                 if use_tta:
#                     probs *= 3.
#                     for j in range(3):
#                         x_ = flip_batch(x, j + 1)
#                         logits_ = model.model(x_)
#                         if logits_.shape[-1] != 256:
#                             logits_ = F.interpolate(logits_, size=256, mode="bilinear")
#                         probs_ = torch.sigmoid(logits_)
#                         probs_ = probs_.view(probs_.size(0), 1, 1, 256, 256)
#                         probs_ = flip_batch(probs_, j + 1)[:, 0]
#                         probs += probs_
#                     probs /= 6.
                    
#                 for j, (pp, yy) in enumerate(zip(probs, y)):
#                     record_id = df.iloc[i * batch_size + j]["record_id"]
#                     predictions.setdefault(record_id, torch.zeros(pp.shape, dtype=torch.float32).to(device))
#                     if record_id in negative_records:
#                         pp = torch.zeros(pp.shape, dtype=torch.float32).to(device)
#                     predictions[record_id] += pp * weight / total_seg_weight
#                     labels[record_id] = yy
#                     for gd in global_dice:
#                         gd.update(pp, yy)
#                     auprc.update(pp, yy)

#         scores = [gd.compute().cpu().numpy() for gd in global_dice]
#         idx = np.argmax(scores)
#         print(f"DICE coefficient (t = {thresholds[idx]}): {scores[idx]:.04f}")
#         print(f"AUPRC: {auprc.compute():.04f}")

#         del model, dataset, dataloader, global_dice, auprc
#         torch.cuda.empty_cache()
#         gc.collect()

# ensemble_dice = [Dice(threshold=t).to(device) for t in thresholds]
# for record_id in predictions:
#     probs, y = predictions[record_id], labels[record_id]
#     for ed in ensemble_dice:
#         ed.update(probs, y)    

# scores = [ed.compute().cpu().numpy() for ed in ensemble_dice]
# idx = np.argmax(scores)
# threshold, score = thresholds[idx], scores[idx]
# print(f"Ensemble DICE coefficient (t = {threshold}): {score:.04f}")

In [13]:
# Inference
device = "cuda"
batch_size = 4
num_workers = 2
stride = 64
total_clf_weight = sum([clf_models[k]["weight"] * len(clf_models[k]["checkpoint_paths"]) for k in clf_models])
total_seg_weight = sum([seg_models[k]["weight"] * len(seg_models[k]["checkpoint_paths"]) for k in seg_models])
clf_threshold = 0.39
seg_threshold = 0.55

split = "test"
records = os.listdir(os.path.join(DATA_DIR, split))
df = pd.DataFrame(records, columns=["record_id"])
df["record_path"] = df["record_id"].apply(lambda r: os.path.join(DATA_DIR, split, str(r)))
print(f"Images: {df.shape[0]}")

# Classification filtering
predictions = {}
for name, model in clf_models.items():
    weight = model["weight"]
    config_path = model["config_path"]
    checkpoint_paths = model["checkpoint_paths"]

    print(f"Loading classification model `{name}` with configuration: {config_path}")
    with open(config_path, "rb") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    try:
        config["model"]["model"]["backbone_params"]["pretrained"] = False
    except:
        pass

    for checkpoint_path in checkpoint_paths:
        print(f"Checkpoint: {checkpoint_path}")
        model = load_model(ClassificationModule, config["model"], checkpoint_path)
        model.to(device)
        model.eval()

        dataset = ContrailsInferenceDataset(
            df, 
            timesteps=[i for i in range(8)], 
            image_size=config["model"]["data"]["image_size"]
        )
        dataloader = DataLoader(
            dataset, 
            shuffle=False, 
            batch_size=batch_size // 4,
            num_workers=num_workers
        )

        with torch.no_grad():
            for i, batch in tqdm(enumerate(dataloader)):
                x = batch["frames"].to(device)
                n, c, t, h, w = x.shape
                x = torch.permute(x, (0, 2, 1, 3, 4)).contiguous()
                x = x.view(-1, c, h, w)
                logits, _ = model.model(x)
                probs = torch.sigmoid(logits).view(n, t)
                for j, pp in enumerate(probs):
                    record_id = df.iloc[i * batch_size // 4 + j]["record_id"]
                    predictions.setdefault(record_id, torch.zeros(pp.shape, dtype=torch.float32).to(device))
                    predictions[record_id] += pp * weight / total_clf_weight

        del model, dataset, dataloader
        torch.cuda.empty_cache()
        gc.collect()

negative_records = set([record_id for record_id in predictions if (predictions[record_id] > clf_threshold).sum() < 2])
print(f"Negative records: {len(negative_records)}")

# Segmentation predictions
predictions = {}
for name, model in seg_models.items():
    weight = model["weight"]
    use_tta = model["use_tta"]
    config_path = model["config_path"]
    checkpoint_paths = model["checkpoint_paths"]

    print(f"Loading segmentation model `{name}` with configuration: {config_path}")
    with open(config_path, "rb") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    try:
        config["model"]["model"]["encoder_params"]["params"]["backbone_params"]["pretrained"] = False
    except:
        pass

    for checkpoint_path in checkpoint_paths:
        print(f"Checkpoint: {checkpoint_path}")
        model = load_model(SegmentationModule2d, config["model"], checkpoint_path)
        model.to(device)
        model.eval()

        image_size = config["model"]["data"]["image_size"]
        timesteps = config["model"]["data"]["timesteps"]
        resize_i = 3 * image_size // 4
        resize_m = 192
        stride_i = image_size - resize_i
        stride_m = 256 - resize_m
        strided_steps = (image_size - resize_i) // stride_i + 1
        
        dataset = ContrailsInferenceDataset(
            df, 
            timesteps=timesteps, 
            image_size=image_size
        )
        dataloader = DataLoader(
            dataset, 
            shuffle=False, 
            batch_size=batch_size,
            num_workers=num_workers
        )

        with torch.no_grad():
            for i, batch in tqdm(enumerate(dataloader)):
                x = batch["frames"].to(device)
                
                # Full resolution prediction
                logits = model.model(x)
                if logits.shape[-1] != 256:
                    logits = F.interpolate(logits, size=256, mode="bilinear")
                probs = torch.sigmoid(logits)
                probs = probs.view(probs.size(0), 1, 256, 256)
                
#                 # Strided crops
#                 probs *= float(strided_steps)
#                 count = torch.ones(probs.shape, dtype=torch.int16).to(device) * strided_steps
#                 for j in range(strided_steps):
#                     for k in range(strided_steps):
#                         x_ = x[..., j * stride_i : j * stride_i + resize_i, k * stride_i : k * stride_i + resize_i]
#                         x_ = torch.stack([
#                             F.interpolate(x_[:, :, i], size=image_size, mode="bilinear") for i in range(len(timesteps))
#                         ], dim=2)
#                         logits_ = model.model(x_)
#                         logits_ = F.interpolate(logits_, size=resize_m, mode="bilinear")
#                         probs_ = torch.sigmoid(logits_)
#                         probs_ = probs_.view(probs_.size(0), 1, resize_m, resize_m)
#                         probs[..., j * stride_m : j * stride_m + resize_m, k * stride_m : k * stride_m + resize_m] += probs_
#                         count[..., j * stride_m : j * stride_m + resize_m, k * stride_m : k * stride_m + resize_m] += 1
#                 probs /= count

                # TTA
                if use_tta:
                    probs *= 3.
                    for j in range(3):
                        x_ = flip_batch(x, j + 1)
                        logits_ = model.model(x_)
                        if logits_.shape[-1] != 256:
                            logits_ = F.interpolate(logits_, size=256, mode="bilinear")
                        probs_ = torch.sigmoid(logits_)
                        probs_ = probs_.view(probs_.size(0), 1, 1, 256, 256)
                        probs_ = flip_batch(probs_, j + 1)[:, 0]
                        probs += probs_
                    probs /= 6.
                    
                for j, pp in enumerate(probs):
                    record_id = df.iloc[i * batch_size + j]["record_id"]
                    predictions.setdefault(record_id, np.zeros(pp.shape, dtype=np.float32))
                    if record_id in negative_records:
                        pp = torch.zeros(pp.shape, dtype=torch.float32).to(device)
                    predictions[record_id] += (pp * weight / total_seg_weight).cpu().numpy()

        del model, dataset, dataloader
        torch.cuda.empty_cache()
        gc.collect()
        
for record_id in predictions:
    predictions[record_id] = (predictions[record_id] > seg_threshold).astype(np.uint8)

Images: 2
Loading classification model `clf_efficientnetv2_s` with configuration: /kaggle/input/contrails-configs/clf_efficientnetv2_s_512.yaml
Checkpoint: /kaggle/input/contrails-models-final/tf_efficientnetv2_s.in21k_ft_in1k_512_s0.ckpt


0it [00:00, ?it/s]

Checkpoint: /kaggle/input/contrails-models-final/tf_efficientnetv2_s.in21k_ft_in1k_512_s1.ckpt


0it [00:00, ?it/s]

Negative records: 2
Loading segmentation model `seg_convnextv2_base` with configuration: /kaggle/input/contrails-configs/seg_convnextv2_base_1024.yaml
Checkpoint: /kaggle/input/contrails-models-final/convnextv2_base.fcmae_ft_in22k_in1k_384_1024_s777_fold_0.ckpt


0it [00:00, ?it/s]

Loading segmentation model `seg_convnextv2_large` with configuration: /kaggle/input/contrails-configs/2d_convnextv2_large_512_s1997_small.yaml
Checkpoint: /kaggle/input/contrails-models-final/finetuning__family_2d__backbone_convnextv2_large.fcmae_ft_in22k_in1k_384__fold_0.ckpt


0it [00:00, ?it/s]

Checkpoint: /kaggle/input/contrails-models-final/finetuning__family_2d__backbone_convnextv2_large.fcmae_ft_in22k_in1k_384__fold_1.ckpt


0it [00:00, ?it/s]

Checkpoint: /kaggle/input/contrails-models-final/finetuning__family_2d__backbone_convnextv2_large.fcmae_ft_in22k_in1k_384__fold_2.ckpt


0it [00:00, ?it/s]

Checkpoint: /kaggle/input/contrails-models-final/finetuning__family_2d__backbone_convnextv2_large.fcmae_ft_in22k_in1k_384__fold_3.ckpt


0it [00:00, ?it/s]

Loading segmentation model `seg_maxxvitv2_large` with configuration: /kaggle/input/contrails-configs/2d_maxxvitv2_base_384_s1997_small.yaml
Checkpoint: /kaggle/input/contrails-models-final/finetuning__family_2d__backbone_maxxvitv2_rmlp_base_rw_384.sw_in12k_ft_in1k__fold_0.ckpt


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


0it [00:00, ?it/s]

Checkpoint: /kaggle/input/contrails-models-final/finetuning__family_2d__backbone_maxxvitv2_rmlp_base_rw_384.sw_in12k_ft_in1k__fold_1.ckpt


0it [00:00, ?it/s]

Checkpoint: /kaggle/input/contrails-models-final/finetuning__family_2d__backbone_maxxvitv2_rmlp_base_rw_384.sw_in12k_ft_in1k__fold_2.ckpt


0it [00:00, ?it/s]

Checkpoint: /kaggle/input/contrails-models-final/finetuning__family_2d__backbone_maxxvitv2_rmlp_base_rw_384.sw_in12k_ft_in1k__fold_3.ckpt


0it [00:00, ?it/s]

Loading segmentation model `seg_nextvit_large` with configuration: /kaggle/input/contrails-configs/2d_nextvit_large_512_s1997_small.yaml
Checkpoint: /kaggle/input/contrails-models-final/finetuning__family_2d__backbone_upernet_160k_nextvit_large_1n1k6m_pretrained__fold_0.ckpt


0it [00:00, ?it/s]

Checkpoint: /kaggle/input/contrails-models-final/finetuning__family_2d__backbone_upernet_160k_nextvit_large_1n1k6m_pretrained__fold_1.ckpt


0it [00:00, ?it/s]

Checkpoint: /kaggle/input/contrails-models-final/finetuning__family_2d__backbone_upernet_160k_nextvit_large_1n1k6m_pretrained__fold_2.ckpt


0it [00:00, ?it/s]

Checkpoint: /kaggle/input/contrails-models-final/finetuning__family_2d__backbone_upernet_160k_nextvit_large_1n1k6m_pretrained__fold_3.ckpt


0it [00:00, ?it/s]

Loading segmentation model `seg_efficientnetv2_l` with configuration: /kaggle/input/contrails-configs/tf_efficientnetv2_l_512.yaml
Checkpoint: /kaggle/input/contrails-models-final/finetuning__family_2d__backbone_tf_efficientnetv2_l.in21k_ft_in1k__fold_0.ckpt


0it [00:00, ?it/s]

Checkpoint: /kaggle/input/contrails-models-final/finetuning__family_2d__backbone_tf_efficientnetv2_l.in21k_ft_in1k__fold_1.ckpt


0it [00:00, ?it/s]

Checkpoint: /kaggle/input/contrails-models-final/finetuning__family_2d__backbone_tf_efficientnetv2_l.in21k_ft_in1k__fold_2.ckpt


0it [00:00, ?it/s]

Checkpoint: /kaggle/input/contrails-models-final/finetuning__family_2d__backbone_tf_efficientnetv2_l.in21k_ft_in1k__fold_3.ckpt


0it [00:00, ?it/s]

Loading segmentation model `seg_efficientnetv2_m_3d` with configuration: /kaggle/input/contrails-configs/3d_tf_efficientnetv2_m.yaml
Checkpoint: /kaggle/input/contrails-models-final/finetuning__family_3d__backbone_tf_efficientnetv2_m.in21k_ft_in1k__fold_0.ckpt


0it [00:00, ?it/s]

Checkpoint: /kaggle/input/contrails-models-final/finetuning__family_3d__backbone_tf_efficientnetv2_m.in21k_ft_in1k__fold_1.ckpt


0it [00:00, ?it/s]

Checkpoint: /kaggle/input/contrails-models-final/finetuning__family_3d__backbone_tf_efficientnetv2_m.in21k_ft_in1k__fold_2.ckpt


0it [00:00, ?it/s]

Checkpoint: /kaggle/input/contrails-models-final/finetuning__family_3d__backbone_tf_efficientnetv2_m.in21k_ft_in1k__fold_3.ckpt


0it [00:00, ?it/s]

In [14]:
submission = pd.read_csv(os.path.join(DATA_DIR, "sample_submission.csv"), index_col="record_id")

for record_id in submission.index.tolist():
    submission.loc[record_id, "encoded_pixels"] = list_to_string(rle_encode(predictions[str(record_id)]))
    
submission.head()

Unnamed: 0_level_0,encoded_pixels
record_id,Unnamed: 1_level_1
1000834164244036115,-
1002653297254493116,-


In [15]:
submission.to_csv("submission.csv")