In [1]:
# import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '1'

# !pip install --upgrade segmentation_models_pytorch

In [13]:
import os
import shutil
import pathlib
import gc
import pandas as pd
import ttach as tta
from pathlib import Path
from PIL import Image
import numpy as np
import cv2 as cv
import random
import matplotlib.pyplot as plt
# from tqdm.auto import tqdm
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import torch
from torch.utils.data import DataLoader, random_split
from torch.utils.data import Dataset

import torchvision
from torchvision import datasets
import cv2
from torch.cuda import amp

import torchvision.transforms as T
from torchvision.transforms import Compose, ToTensor, Resize
from torchvision.utils import make_grid
import optuna

import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
import timm

In [14]:
def seed_everything(seed=42):
    '''Sets the seed of the entire notebook so results are the same every time we run.
    This is for REPRODUCIBILITY.'''
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)

In [15]:
seed_everything(42)

In [16]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [17]:
class Unet(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.model = smp.Unet(**kwargs)

    def forward(self, x):
        x = self.model(x)
        return x

    
class ConvBNReLU(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, norm_layer=nn.BatchNorm2d, bias=False):
        super(ConvBNReLU, self).__init__(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,
                      dilation=dilation, stride=stride, padding=((stride - 1) + dilation * (kernel_size - 1)) // 2),
            norm_layer(out_channels),
            nn.ReLU6()
        )


class ConvBN(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, norm_layer=nn.BatchNorm2d, bias=False):
        super(ConvBN, self).__init__(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,
                      dilation=dilation, stride=stride, padding=((stride - 1) + dilation * (kernel_size - 1)) // 2),
            norm_layer(out_channels)
        )


class Conv(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, bias=False):
        super(Conv, self).__init__(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,
                      dilation=dilation, stride=stride, padding=((stride - 1) + dilation * (kernel_size - 1)) // 2)
        )


class SeparableConvBNReLU(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1,
                 norm_layer=nn.BatchNorm2d):
        super(SeparableConvBNReLU, self).__init__(
            nn.Conv2d(in_channels, in_channels, kernel_size, stride=stride, dilation=dilation,
                      padding=((stride - 1) + dilation * (kernel_size - 1)) // 2,
                      groups=in_channels, bias=False),
            norm_layer(out_channels),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.ReLU6()
        )


class SeparableConvBN(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1,
                 norm_layer=nn.BatchNorm2d):
        super(SeparableConvBN, self).__init__(
            nn.Conv2d(in_channels, in_channels, kernel_size, stride=stride, dilation=dilation,
                      padding=((stride - 1) + dilation * (kernel_size - 1)) // 2,
                      groups=in_channels, bias=False),
            norm_layer(out_channels),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        )


class SeparableConv(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1):
        super(SeparableConv, self).__init__(
            nn.Conv2d(in_channels, in_channels, kernel_size, stride=stride, dilation=dilation,
                      padding=((stride - 1) + dilation * (kernel_size - 1)) // 2,
                      groups=in_channels, bias=False),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        )


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU6, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1, 1, 0, bias=True)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1, 1, 0, bias=True)
        self.drop = nn.Dropout(drop, inplace=True)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class GlobalLocalAttention(nn.Module):
    def __init__(self,
                 dim=256,
                 num_heads=16,
                 qkv_bias=False,
                 window_size=8,
                 relative_pos_embedding=True
                 ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // self.num_heads
        self.scale = head_dim ** -0.5
        self.ws = window_size

        self.qkv = Conv(dim, 3*dim, kernel_size=1, bias=qkv_bias)
        self.local1 = ConvBN(dim, dim, kernel_size=3)
        self.local2 = ConvBN(dim, dim, kernel_size=1)
        self.proj = SeparableConvBN(dim, dim, kernel_size=window_size)

        self.attn_x = nn.AvgPool2d(kernel_size=(window_size, 1), stride=1,  padding=(window_size//2 - 1, 0))
        self.attn_y = nn.AvgPool2d(kernel_size=(1, window_size), stride=1, padding=(0, window_size//2 - 1))

        self.relative_pos_embedding = relative_pos_embedding

        if self.relative_pos_embedding:
            # define a parameter table of relative position bias
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

            # get pair-wise relative position index for each token inside the window
            coords_h = torch.arange(self.ws)
            coords_w = torch.arange(self.ws)
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += self.ws - 1  # shift to start from 0
            relative_coords[:, :, 1] += self.ws - 1
            relative_coords[:, :, 0] *= 2 * self.ws - 1
            relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            self.register_buffer("relative_position_index", relative_position_index)

            trunc_normal_(self.relative_position_bias_table, std=.02)

    def pad(self, x, ps):
        _, _, H, W = x.size()
        if W % ps != 0:
            x = F.pad(x, (0, ps - W % ps), mode='reflect')
        if H % ps != 0:
            x = F.pad(x, (0, 0, 0, ps - H % ps), mode='reflect')
        return x

    def pad_out(self, x):
        x = F.pad(x, pad=(0, 1, 0, 1), mode='reflect')
        return x

    def forward(self, x):
        B, C, H, W = x.shape

        local = self.local2(x) + self.local1(x)

        x = self.pad(x, self.ws)
        B, C, Hp, Wp = x.shape
        qkv = self.qkv(x)

        q, k, v = rearrange(qkv, 'b (qkv h d) (hh ws1) (ww ws2) -> qkv (b hh ww) h (ws1 ws2) d', h=self.num_heads,
                            d=C//self.num_heads, hh=Hp//self.ws, ww=Wp//self.ws, qkv=3, ws1=self.ws, ws2=self.ws)

        dots = (q @ k.transpose(-2, -1)) * self.scale

        if self.relative_pos_embedding:
            relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                self.ws * self.ws, self.ws * self.ws, -1)  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
            dots += relative_position_bias.unsqueeze(0)

        attn = dots.softmax(dim=-1)
        attn = attn @ v

        attn = rearrange(attn, '(b hh ww) h (ws1 ws2) d -> b (h d) (hh ws1) (ww ws2)', h=self.num_heads,
                         d=C//self.num_heads, hh=Hp//self.ws, ww=Wp//self.ws, ws1=self.ws, ws2=self.ws)

        attn = attn[:, :, :H, :W]

        out = self.attn_x(F.pad(attn, pad=(0, 0, 0, 1), mode='reflect')) + \
              self.attn_y(F.pad(attn, pad=(0, 1, 0, 0), mode='reflect'))

        out = out + local
        out = self.pad_out(out)
        out = self.proj(out)
        # print(out.size())
        out = out[:, :, :H, :W]

        return out


class Block(nn.Module):
    def __init__(self, dim=256, num_heads=16,  mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.ReLU6, norm_layer=nn.BatchNorm2d, window_size=8):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = GlobalLocalAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, window_size=window_size)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, out_features=dim, act_layer=act_layer, drop=drop)
        self.norm2 = norm_layer(dim)

    def forward(self, x):

        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x


class WF(nn.Module):
    def __init__(self, in_channels=128, decode_channels=128, eps=1e-8):
        super(WF, self).__init__()
        self.pre_conv = Conv(in_channels, decode_channels, kernel_size=1)

        self.weights = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
        self.eps = eps
        self.post_conv = ConvBNReLU(decode_channels, decode_channels, kernel_size=3)

    def forward(self, x, res):
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        weights = nn.ReLU()(self.weights)
        fuse_weights = weights / (torch.sum(weights, dim=0) + self.eps)
        x = fuse_weights[0] * self.pre_conv(res) + fuse_weights[1] * x
        x = self.post_conv(x)
        return x


class FeatureRefinementHead(nn.Module):
    def __init__(self, in_channels=64, decode_channels=64):
        super().__init__()
        self.pre_conv = Conv(in_channels, decode_channels, kernel_size=1)

        self.weights = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
        self.eps = 1e-8
        self.post_conv = ConvBNReLU(decode_channels, decode_channels, kernel_size=3)

        self.pa = nn.Sequential(nn.Conv2d(decode_channels, decode_channels, kernel_size=3, padding=1, groups=decode_channels),
                                nn.Sigmoid())
        self.ca = nn.Sequential(nn.AdaptiveAvgPool2d(1),
                                Conv(decode_channels, decode_channels//16, kernel_size=1),
                                nn.ReLU6(),
                                Conv(decode_channels//16, decode_channels, kernel_size=1),
                                nn.Sigmoid())

        self.shortcut = ConvBN(decode_channels, decode_channels, kernel_size=1)
        self.proj = SeparableConvBN(decode_channels, decode_channels, kernel_size=3)
        self.act = nn.ReLU6()

    def forward(self, x, res):
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        weights = nn.ReLU()(self.weights)
        fuse_weights = weights / (torch.sum(weights, dim=0) + self.eps)
        x = fuse_weights[0] * self.pre_conv(res) + fuse_weights[1] * x
        x = self.post_conv(x)
        shortcut = self.shortcut(x)
        pa = self.pa(x) * x
        ca = self.ca(x) * x
        x = pa + ca
        x = self.proj(x) + shortcut
        x = self.act(x)

        return x


class AuxHead(nn.Module):

    def __init__(self, in_channels=64, num_classes=8):
        super().__init__()
        self.conv = ConvBNReLU(in_channels, in_channels)
        self.drop = nn.Dropout(0.1)
        self.conv_out = Conv(in_channels, num_classes, kernel_size=1)

    def forward(self, x, h, w):
        feat = self.conv(x)
        feat = self.drop(feat)
        feat = self.conv_out(feat)
        feat = F.interpolate(feat, size=(h, w), mode='bilinear', align_corners=False)
        return feat


class Decoder(nn.Module):
    def __init__(self,
                 encoder_channels=(64, 128, 256, 512),
                 decode_channels=64,
                 dropout=0.1,
                 window_size=8,
                 num_classes=6):
        super(Decoder, self).__init__()

        self.pre_conv = ConvBN(encoder_channels[-1], decode_channels, kernel_size=1)
        self.b4 = Block(dim=decode_channels, num_heads=8, window_size=window_size)

        self.b3 = Block(dim=decode_channels, num_heads=8, window_size=window_size)
        self.p3 = WF(encoder_channels[-2], decode_channels)

        self.b2 = Block(dim=decode_channels, num_heads=8, window_size=window_size)
        self.p2 = WF(encoder_channels[-3], decode_channels)

        if self.training:
            self.up4 = nn.UpsamplingBilinear2d(scale_factor=4)
            self.up3 = nn.UpsamplingBilinear2d(scale_factor=2)
            self.aux_head = AuxHead(decode_channels, num_classes)

        self.p1 = FeatureRefinementHead(encoder_channels[-4], decode_channels)

        self.segmentation_head = nn.Sequential(ConvBNReLU(decode_channels, decode_channels),
                                               nn.Dropout2d(p=dropout, inplace=True),
                                               Conv(decode_channels, num_classes, kernel_size=1))
        self.init_weight()

    def forward(self, res1, res2, res3, res4, h, w):
        if self.training:
            x = self.b4(self.pre_conv(res4))
            h4 = self.up4(x)

            x = self.p3(x, res3)
            x = self.b3(x)
            h3 = self.up3(x)

            x = self.p2(x, res2)
            x = self.b2(x)
            h2 = x
            x = self.p1(x, res1)
            x = self.segmentation_head(x)
            x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=False)

            ah = h4 + h3 + h2
            ah = self.aux_head(ah, h, w)

            return x, ah

        else:
            x = self.b4(self.pre_conv(res4))
            x = self.p3(x, res3)
            x = self.b3(x)

            x = self.p2(x, res2)
            x = self.b2(x)

            x = self.p1(x, res1)

            x = self.segmentation_head(x)
            x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=False)

            return x

    def init_weight(self):
        for m in self.children():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, a=1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)


class UNetFormer(nn.Module):
    def __init__(self,
                 decode_channels=64,
                 dropout=0.1,
                 backbone_name='swsl_resnet18',
                 pretrained=True,
                 window_size=8,
                 num_classes=6
                 ):
        super().__init__()

        if 'convnext' in backbone_name:
            self.backbone = timm.create_model(backbone_name, features_only=True,
                                              output_stride=32, pretrained=pretrained)
        else:
            self.backbone = timm.create_model(backbone_name, features_only=True, output_stride=32,
                                              out_indices=(1, 2, 3, 4), pretrained=pretrained)

        encoder_channels = self.backbone.feature_info.channels()

        self.decoder = Decoder(encoder_channels, decode_channels, dropout, window_size, num_classes)

    def forward(self, x):
        h, w = x.size()[-2:]
        res1, res2, res3, res4 = self.backbone(x)
        if self.training:
            x, ah = self.decoder(res1, res2, res3, res4, h, w)
            return x, ah
        else:
            x = self.decoder(res1, res2, res3, res4, h, w)
            return x
    
    
def get_model(config):
    model = globals().get(config.model.arch)

    return model(**config.model.params)

In [18]:
import torchvision.transforms as T


# class ContrailDataset:
#     def __init__(self, df, transform=None, normalize=False):
#         self.df = df  
#         self.images = df['image']
#         self.labels = df['label']
#         self.transform =transform
#         self.normalize_image = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
#         self.normalize=normalize
        
        
#     def __len__(self):
#         return len(self.images)
        
#     def __getitem__(self, idx):
#         image = np.load("../../input/" + self.images[idx]).astype(float)   
#         label = np.load("../../input/" + self.labels[idx]).astype(float)
        

#         # label_cls = 1 if label.sum() > 0 else 0
#         if self.transform :
#             data = self.transform(image=image, mask=label)
#             image  = data['image']
#             label  = data['mask']
#             image = np.transpose(image, (2, 0, 1))
#             label = np.transpose(label, (2, 0, 1))    
            
            
# #         return torch.tensor(image), torch.tensor(label)
    
#         if self.normalize:
#             image = self.normalize_image(torch.tensor(image))
#             return image, torch.tensor(label)
#         else:
#             return torch.tensor(image), torch.tensor(label)
    
    

In [19]:
def dice_coef(y_true, y_pred, thr=0.5, epsilon=1e-6):
    y_true = y_true.to(torch.float32)
    y_pred = (y_pred>thr).to(torch.float32)
    inter = (y_true*y_pred).sum()
    den = y_true.sum() + y_pred.sum()
    dice = ((2*inter+epsilon)/(den+epsilon)).mean()
    
    return dice


def get_transform(img_size):
    transform = A.Compose([
        A.Resize(*img_size, interpolation=cv2.INTER_NEAREST),
    ], p=1.0)
    return transform

def get_transform2(img_size):
    transform = A.Compose([
        A.Resize(*img_size, interpolation=cv2.INTER_LINEAR),
    ], p=1.0)
    return transform




In [20]:
val_df = pd.read_csv("../../input/data_utils/val_df_filled.csv")
val_df_full = val_df.copy()

val_dups = np.load("../../input/data_utils/dups_val.npy")

val_dups = [int(val_id) for val_id in val_dups]

In [27]:
# val_df = val_df.loc[~val_df['id'].isin(val_dups)].reset_index(drop=True)
print(val_df.shape, val_df_full.shape)

(1856, 4) (1856, 4)


In [50]:
# !pip install einops
# !pip install easydict

In [55]:
from torchvision.transforms import functional as transforms_F
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
from torchvision.utils import _log_api_usage_once


_pil_interpolation_to_str = {
    Image.NEAREST: 'PIL.Image.NEAREST',
    Image.BILINEAR: 'PIL.Image.BILINEAR',
    Image.BICUBIC: 'PIL.Image.BICUBIC',
    Image.LANCZOS: 'PIL.Image.LANCZOS',
    Image.HAMMING: 'PIL.Image.HAMMING',
    Image.BOX: 'PIL.Image.BOX',
}


def _pil_interp(method):
    if method == 'bicubic':
        return Image.BICUBIC
    elif method == 'lanczos':
        return Image.LANCZOS
    elif method == 'hamming':
        return Image.HAMMING
    else:
        # default bilinear, do we want to allow nearest?
        return Image.BILINEAR


class Resize(torch.nn.Module):

    def __init__(self, size, interpolation=Image.BILINEAR, max_size=None, antialias=None):
        super().__init__()
        _log_api_usage_once(self)
        if not isinstance(size, (int, Sequence)):
            raise TypeError(f"Size should be int or sequence. Got {type(size)}")
        if isinstance(size, Sequence) and len(size) not in (1, 2):
            raise ValueError("If size is a sequence, it should have 1 or 2 values")
        self.size = size
        self.max_size = max_size

        # Backward compatibility with integer value
        if isinstance(interpolation, int):
            interpolation = _interpolation_modes_from_int(interpolation)

        self.interpolation = interpolation
        self.antialias = antialias

    def forward(self, img, mask):
        img = transforms_F.resize(img, self.size, self.interpolation, self.max_size, self.antialias)
        mask = transforms_F.resize(mask, self.size, Image.NEAREST, self.max_size, self.antialias)
        return img, mask

    def __repr__(self) -> str:
        detail = f"(size={self.size}, interpolation={self.interpolation.value}, max_size={self.max_size}, antialias={self.antialias})"
        return f"{self.__class__.__name__}{detail}"

    

class ToTensor:
    def __init__(self) -> None:
        _log_api_usage_once(self)

    def __call__(self, img_pic, mask_pic):
        """
        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        return transforms_F.to_tensor(img_pic), transforms_F.to_tensor(mask_pic)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"

    
class Normalize(torch.nn.Module):
    def __init__(self, mean, std, inplace=False):
        super().__init__()
        _log_api_usage_once(self)
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def forward(self, img_tensor, mask_tensor) -> tuple:
        """
        Args:
            tensor (Tensor): Tensor image to be normalized.

        Returns:
            Tensor: Normalized Tensor image.
        """
        return transforms_F.normalize(img_tensor, self.mean, self.std, self.inplace), mask_tensor

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"

    
class Compose:
    def __init__(self, transforms):
        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
            _log_api_usage_once(self)
        self.transforms = transforms

    def __call__(self, img, mask):
        for t in self.transforms:
            img, mask = t(img, mask)
        return img, mask

    def __repr__(self) -> str:
        format_string = self.__class__.__name__ + "("
        for t in self.transforms:
            format_string += "\n"
            format_string += f"    {t}"
        format_string += "\n)"
        return format_string


def transforms_imagenet_eval(
        img_size=224,
        crop_pct=None,
        interpolation='bilinear',
#         interpolation='nearest',
        use_prefetcher=False,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        normalize=True,
):
    tfl = [
        Resize(img_size, _pil_interp(interpolation)),
    ]
    if use_prefetcher:
        # prefetcher and collate will handle tensor conversion and norm
        tfl += [ToNumpy()]
    else:
        tfl += [
            ToTensor(),
        ]

        if normalize:
            tfl += [
                Normalize(
                 mean=torch.tensor(mean),
                 std=torch.tensor(std))
            ]

    return Compose(tfl)


def get_transforms(phase_config):
    return transforms_imagenet_eval(
            img_size=phase_config.Resize.height,
            normalize=True,
    )

In [61]:
class ContrailDataset:
    def __init__(self, df, transform=None, normalize=False):
        self.df = df  
        self.images = df['image']
        self.labels = df['label']
        self.transform =transform
#         self.normalize_image = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        self.normalize=normalize
        
        
    def __len__(self):
        return len(self.images)
        
    def __getitem__(self, idx):
        image = np.load("../../input/" + self.images[idx]).astype(float)   
        label = np.load("../../input/" + self.labels[idx]).astype(float)
        

        # label_cls = 1 if label.sum() > 0 else 0
        if self.transform :
            data = self.transform(image=image, mask=label)
            image  = data['image']
            label  = data['mask']
            image = np.transpose(image, (2, 0, 1))
            label = np.transpose(label, (2, 0, 1))    
            
            
#         return torch.tensor(image), torch.tensor(label)
    
        if self.normalize:
            image = self.normalize_image(torch.tensor(image))
            return image, torch.tensor(label)
        else:
            return torch.tensor(image), torch.tensor(label)

In [71]:
from easydict import EasyDict as edict
import yaml
from typing import Sequence
from torchvision.transforms.functional import InterpolationMode, _interpolation_modes_from_int



def load_config(config_path):
    with open(config_path, 'r') as fid:
        yaml_config = edict(yaml.load(fid, Loader=yaml.Loader))

    return yaml_config

config = load_config('/home/rohits/pv1/Contrail_Detection/output/config.yml')


if config.model.arch != 'UNetFormer':
    config.model.params.encoder_weights = None
else:
    config.model.params.pretrained = False

transforms = get_transforms(config.transforms.test)
model = get_model(config)
model.cuda()
model.load_state_dict(torch.load('/home/rohits/pv1/Contrail_Detection/output/checkpoints/best_model.pth')["model_state_dict"])
model.eval()
model.float()
print()




In [66]:
final_preds = []

        
val_transform = get_transform([512,512])
valid_dataset = ContrailDataset(val_df, transform=transforms, normalize=True)  

valid_loader = DataLoader(
    valid_dataset, 
    batch_size = 32, #32, 
    shuffle = False, 
    num_workers = 2, 
    pin_memory = True, 
    drop_last = False
)


model.to(device)
model.eval()
print()






In [None]:
# get_transform([512,512])['image']

In [73]:
val_transform, transforms

(Compose([
   Resize(always_apply=False, p=1, height=512, width=512, interpolation=0),
 ], p=1.0, bbox_params=None, keypoint_params=None, additional_targets={}),
 Compose(
     Resize(size=512, interpolation=bilinear, max_size=None, antialias=None)
     ToTensor()
     Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
 ))

In [64]:
preds = []
masks_ = []      

for index, (images, masks) in enumerate(tqdm(valid_loader)):  
    images  = images.to(device, dtype=torch.float)
    masks  = masks.to(device, dtype=torch.float)
    masks = torch.nn.functional.interpolate(masks, size=256, mode='nearest') 

    masks_.append(torch.squeeze(masks, dim=1))
    with torch.inference_mode():
        images = torch.nn.functional.interpolate(images, size=512, mode='nearest')
        pred = model(images).sigmoid()   

        pred = torch.nn.functional.interpolate(pred, size=256, mode='nearest')
        preds.append(torch.squeeze(pred, dim=1))
        
    
model_masks = torch.cat(masks_, dim=0)
model_preds = torch.cat(preds, dim=0)

  0%|          | 0/58 [00:00<?, ?it/s]

TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/rohits/anaconda3/envs/contrail/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/rohits/anaconda3/envs/contrail/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/rohits/anaconda3/envs/contrail/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/tmp/ipykernel_4162342/1165069015.py", line 21, in __getitem__
    data = self.transform(image=image, mask=label)
TypeError: Compose.__call__() got an unexpected keyword argument 'image'


In [49]:
best_threshold = 0.0
best_dice_score = 0.0
for threshold in [i / 100 for i in range(101)] :
    score = dice_coef(model_masks, model_preds, thr=threshold).cpu().detach().numpy() 
    if score > best_dice_score:
        best_dice_score = score
        best_threshold = threshold
        
print(best_dice_score, best_threshold)

0.6495943 0.52


In [65]:
transforms

Compose(
    Resize(size=512, interpolation=bilinear, max_size=None, antialias=None)
    ToTensor()
    Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
)

In [32]:
# val_df = val_df.head(100)

In [33]:
final_preds1 = []

for idx, cfg in enumerate(CFGS1):   

    if idx <= 7:
        continue
        
    print(cfg)
    val_transform = get_transform(cfg['img_size'])
    val_transform2 = get_transform2(cfg['img_size'])

    valid_dataset = ContrailDataset(val_df, transform=val_transform, normalize=False)  
    if cfg['normalize'] and ("ioa_" in cfg['call_sign']):
        valid_dataset = ContrailDataset(val_df, transform=val_transform2, normalize=True)  
#         valid_dataset = ContrailDataset(val_df, transform=val_transform, normalize=True)  



    
    valid_loader = DataLoader(
        valid_dataset, 
        batch_size = 32, #32, 
        shuffle = False, 
        num_workers = 2, 
        pin_memory = True, 
        drop_last = False
    )
    
    
    model_base = cfg['model_func'](cfg)        
      
    if cfg['tta']:
        if "roh_" in cfg['call_sign']:
            model = tta.SegmentationTTAWrapper(model_base, tta.aliases.flip_transform(), merge_mode='mean')
        else:
            model = tta.SegmentationTTAWrapper(model_base, tta.aliases.hflip_transform(), merge_mode='mean')
    
        model.to(device)
        model.eval()
    
    
    model_base.to(device)
    model_base.eval()
    

    
    preds = []
    masks_ = []      
        
    for index, (images, masks) in enumerate(tqdm(valid_loader)):  
        images  = images.to(device, dtype=torch.float)
        masks  = masks.to(device, dtype=torch.float)
        if cfg['img_size'][0] != 256:
#             masks = torch.nn.functional.interpolate(masks, size=256, mode='nearest') 
            masks = torch.nn.functional.interpolate(masks, size=256, mode='nearest') 


        masks_.append(torch.squeeze(masks, dim=1))
        with torch.inference_mode():
            images = torch.nn.functional.interpolate(images, size=cfg['img_size'][0], mode='nearest')
            pred = model_base(images).sigmoid()   
            if cfg['tta']:
                pred = model(images).sigmoid()
#                 pred2 = model(images).sigmoid()
#                 pred = (pred + pred2) / 2
                
            pred = torch.nn.functional.interpolate(pred, size=256, mode='nearest')
            preds.append(torch.squeeze(pred, dim=1))
        
            
            
            
    
    model_masks = torch.cat(masks_, dim=0)
    model_preds = torch.cat(preds, dim=0)
        
    print(model_preds.shape)
        
#     model_masks = torch.flatten(model_masks, start_dim=0, end_dim=1)
#     model_preds = torch.flatten(model_preds, start_dim=0, end_dim=1)  
    
    # save
#     torch.save(model_preds, f"../../output/final_preds/{cfg['call_sign']}.pt")    
#     torch.save(model_preds, f"../../output/final_preds/{cfg['call_sign']}_full.pt")    

#     if idx == 6:
#     torch.save(model_masks, f'../../output/final_preds/val_masks.pt')

    np.save(f"../../output/final_preds/{cfg['call_sign']}", model_preds.detach().cpu().numpy())
    if idx == 0:
        np.save(f"../../output/final_preds/val_masks", model_masks.detach().cpu().numpy())

    
    best_threshold = 0.0
    best_dice_score = 0.0
    for threshold in [i / 100 for i in range(101)] :
        score = dice_coef(model_masks, model_preds, thr=threshold).cpu().detach().numpy() 

        
        if score > best_dice_score:
            best_dice_score = score
            best_threshold = threshold
    
        
    print(best_dice_score, best_threshold)
    final_preds1.append(model_preds)
    
    if cfg['tta']: del model
    del model_base
    torch.cuda.empty_cache()
    gc.collect()

{'model_name': 'Unet++', 'backbone': 'tf_efficientnet_b8', 'img_size': [512, 512], 'num_classes': 1, 'model_pth': '/home/rohits/pv1/Contrail_Detection/output/nirjhar/fullfinetune-effb8modelv2fold0_tf_efficientnet_b8_best_epochstage2cv670-00.bin', 'threshold': 0.79, 'call_sign': 'nir_05_tta', 'model_func': <function load_model3 at 0x7f4d2aa97eb0>, 'tta': True, 'normalize': False}


  0%|          | 0/58 [00:00<?, ?it/s]

torch.Size([1856, 256, 256])
0.6744085 0.76


In [34]:
call_signs1 = [
#     "roh_01_tta", # can be removed
    "roh_02_tta",
    
    
#     "roh_03_tta",
    "roh_04_tta", 
    
    
    "nir_01_tta",
#     "nir_02_tta",
#     "nir_03_tta",
#     "nir_04_tta",
    "nir_05_tta",
    
    
    "121", 
    "122", 
    "125", 
    "127", 
    "131"  
] 

model_masks = torch.from_numpy(np.load(f"/home/rohits/pv1/Contrail_Detection/output/final_preds/val_masks.npy")).flatten() 

preds = []

for idx, sign in tqdm(enumerate(call_signs1), total=len(call_signs1)):
    wt = torch.from_numpy(np.load(f"/home/rohits/pv1/Contrail_Detection/output/final_preds/{sign}.npy")).flatten()   
    preds.append(wt)
    print(wt.shape)
    
    score = dice_coef(model_masks, wt, thr=0.5).cpu().detach().numpy() 
    print(score)
    
    
# # preds1.append(preds2)
    
final_preds = preds
final_preds = torch.stack(final_preds).mean(dim=0)
score = dice_coef(model_masks, final_preds, thr=0.5).cpu().detach().numpy() 

print("0.5 TH Score: ", score)


best_threshold = 0.0
best_dice_score = 0.0
for threshold in [i / 100 for i in range(101)] :
    if threshold < 0.35 or threshold > 0.51:
        continue
    
    score = dice_coef(model_masks, final_preds, thr=threshold).cpu().detach().numpy() 
    if score > best_dice_score:
        best_dice_score = score
        best_threshold = threshold
        
print(best_dice_score, best_threshold)

  0%|          | 0/9 [00:00<?, ?it/s]

torch.Size([121634816])
0.67922837
torch.Size([121634816])
0.6797982
torch.Size([121634816])
0.6567023
torch.Size([121634816])
0.67340404
torch.Size([121634816])
0.6707117
torch.Size([121634816])
0.67332405
torch.Size([121634816])
0.65260977
torch.Size([121634816])
0.6755674
torch.Size([121634816])
0.6600541
0.5 TH Score:  0.69304866
0.6936055 0.45


In [27]:
# ,dkngld

# 0.6857662 # 0.68593323 : 0.47 TH
# 0.68798006  # 0.6884547 0.45
# 0.6903099 # 0.6911206 0.44
# 0.6909886 # 0.6928196 0.45



# 0.6914439 # 0.69235253 0.44


In [None]:
# 0.657336 0.01


# call_signs1 = [
#     "roh_01_tta", "nir_01_tta", "nir_02_tta", "nir_03_tta", "nir_04_tta"
# ] 

# call_signs1 = [
#     "roh_02_tta",  "roh_03_tta", "roh_04_tta", "roh_05_tta", "roh_06_tta",  "roh_07_tta",  "roh_08_tta"  
# ] 

call_signs1 = [
    "roh_01_tta",
    "roh_02_tta", 
    "roh_03_tta", 
    "nir_01_tta",
    "nir_02_tta",
    "nir_03_tta",
    "nir_04_tta",
    "121", 
    "122",
    "125", 
    "127",
    "131"
    
    
#     "ioa_01", 
#     "ioa_02",
#     "roh_02_tta"
    
] 

model_masks = torch.load(f"/home/rohits/pv1/Contrail_Detection/output/final_preds/val_masks.pt")


preds1 = []

for idx, sign in tqdm(enumerate(call_signs1), total=len(call_signs1)):
    wt = torch.load(f"/home/rohits/pv1/Contrail_Detection/output/final_preds/{sign}.pt")    
    preds1.append(wt)
    
    score = dice_coef(model_masks, wt, thr=0.5).cpu().detach().numpy() 
    print(score)
    
    
# preds1.append(preds2)
    
final_preds = preds1
final_preds = torch.stack(final_preds).mean(dim=0)
score = dice_coef(model_masks, final_preds, thr=0.5).cpu().detach().numpy() 

print("0.5 TH Score: ", score)


best_threshold = 0.0
best_dice_score = 0.0
for threshold in [i / 100 for i in range(101)] :
    score = dice_coef(model_masks, final_preds, thr=threshold).cpu().detach().numpy() 
    if score > best_dice_score:
        best_dice_score = score
        best_threshold = threshold
print(best_dice_score, best_threshold)

In [39]:
# 0.5 TH Score:  0.67887855
# 0.6856482 0.28



# 0.5 TH Score:  0.6811164
# 0.6867434 0.35


In [34]:
call_signs1 = [
    "roh_01_tta_full", "nir_01_tta_full", "nir_02_tta_full", "nir_03_tta_full", "nir_04_tta_full"
] 

model_masks = torch.load(f"/home/rohits/pv1/Contrail_Detection/output/final_preds/val_masks_full.pt")


preds1 = []

for idx, sign in tqdm(enumerate(call_signs1), total=len(call_signs1)):
    wt = torch.load(f"/home/rohits/pv1/Contrail_Detection/output/final_preds/{sign}.pt")    
    preds1.append(wt)
    
    score = dice_coef(model_masks, wt, thr=0.5).cpu().detach().numpy() 
    print(score)
    
final_preds = preds1
final_preds = torch.stack(final_preds).mean(dim=0)
score = dice_coef(model_masks, final_preds, thr=0.5).cpu().detach().numpy() 
print("0.5 TH Score: ", score)

best_threshold = 0.0
best_dice_score = 0.0
for threshold in [i / 100 for i in range(101)] :
    score = dice_coef(model_masks, final_preds, thr=threshold).cpu().detach().numpy() 
    if score > best_dice_score:
        best_dice_score = score
        best_threshold = threshold
print(best_dice_score, best_threshold)


  0%|          | 0/5 [00:00<?, ?it/s]

0.6469362
0.6567023
0.6457618
0.6704817
0.655622
0.5 TH Score:  0.6751985
0.6808599 0.35


In [None]:
# 0.5 TH Score:  0.6760312
# 0.6810049 0.33

In [None]:
# CFGS1 = [   
    
#     {
#         'model_name': 'Unet',
#         'backbone': 'gluon_senet154',
#         'img_size': [512, 512],
#         'num_classes': 1,
#         'threshold': 0.5, #0.24,
#         'model_func': load_model_exp28,
#         'model_pth': '/home/rohits/pv1/Contrail_Detection/output/ioannis/Unet-gluon_senet154_fold-1.ckpt',
#         'call_sign': "ioa_01",
#         'tta': True, #False, #True
#         'normalize': True
#     }, 
    
#     {
#         'model_name': 'Unet',
#         'backbone': 'gluon_senet154',
#         'img_size': [512, 512],
#         'num_classes': 1,
#         'threshold': 0.5, #0.24,
#         'call_sign': "ioa_02",
#         'model_func': load_model_exp28_snapshot,
#         'tta': False, #True
#         'normalize': True,
#         'model_pth': [
#             ######## full train.csv 
#             '/home/rohits/pv1/Contrail_Detection/output/ioannis/Unet-gluon_senet154_fold-1.ckpt',
#             '/home/rohits/pv1/Contrail_Detection/output/ioannis/Unet-gluon_senet154_fold-1-v1.ckpt',
#             '/home/rohits/pv1/Contrail_Detection/output/ioannis/Unet-gluon_senet154_fold-1-v2.ckpt',
#             ######### 4-skf split 
#         ]
#     }, 
    
    
#     {
#         'model_name': 'Unet',
#         'backbone': 'gluon_senet154',
#         'img_size': [512, 512],
#         'num_classes': 1,
#         'threshold': 0.5, #0.24,
#         'call_sign': "roh_02",
#         'model_func': load_model_exp28_snapshot,        
#         'model_pth': '/home/rohits/pv1/Contrail_Detection/output/ioannis/Unet-gluon_senet154_fold-1.ckpt',
#         'tta': False, #True
#     }, 
    
#     {
#         'model_name': 'Unet',
#         'backbone': 'gluon_senet154',
#         'img_size': [512, 512],
#         'num_classes': 1,
#         'threshold': 0.5, #0.24,
#         'call_sign': "roh_03",
#         'model_func': load_model_exp28_snapshot,        
#         'model_pth': '/home/rohits/pv1/Contrail_Detection/output/ioannis/Unet-gluon_senet154_fold-1-v1.ckpt',
#         'tta': False, #True
#     }, 
    
#     {
#         'model_name': 'Unet',
#         'backbone': 'gluon_senet154',
#         'img_size': [512, 512],
#         'num_classes': 1,
#         'threshold': 0.5, #0.24,
#         'call_sign': "roh_04",
#         'model_func': load_model_exp28_snapshot,        
#         'model_pth': '/home/rohits/pv1/Contrail_Detection/output/ioannis/Unet-gluon_senet154_fold-1-v2.ckpt',
#         'tta': False, #True
#     }, 
    
#         {
#         'model_name': 'Unet',
#         'backbone': 'gluon_senet154',
#         'img_size': [512, 512],
#         'num_classes': 1,
#         'threshold': 0.5, #0.24,
#         'call_sign': "ioa_02",
#         'model_func': load_model_exp28_snapshot,
#         'tta': False, #True
#         'normalize': True,
#         'model_pth': [
#             ######## full train.csv 
#             '/home/rohits/pv1/Contrail_Detection/output/ioannis/Unet-gluon_senet154_fold-1.ckpt',
#             '/home/rohits/pv1/Contrail_Detection/output/ioannis/Unet-gluon_senet154_fold-1-v1.ckpt',
#             '/home/rohits/pv1/Contrail_Detection/output/ioannis/Unet-gluon_senet154_fold-1-v2.ckpt',
#             ######### 4-skf split 
#             #     '/kaggle/input/contrails-exp28b-senet154-512/Unet-gluon_senet154_fold0.ckpt',
#             #     '/kaggle/input/contrails-exp28b-senet154-512/Unet-gluon_senet154_fold0-v1.ckpt',
#             #     '/kaggle/input/contrails-exp28b-senet154-512/Unet-gluon_senet154_fold0-v2.ckpt',
#             #     '/kaggle/input/contrails-exp28b-senet154-512/Unet-gluon_senet154_fold0-v3.ckpt',
#             #     '/kaggle/input/contrails-exp28b-senet154-512/Unet-gluon_senet154_fold0-v4.ckpt',
#             #     '/kaggle/input/contrails-exp28b-senet154-512/Unet-gluon_senet154_fold0-v5.ckpt'
#         ]
#     }, 
    
    
#     {
#         'model_name': 'Unet',
#         'backbone': 'gluon_senet154',
#         'img_size': [512, 512],
#         'num_classes': 1,
#         'threshold': 0.5, #0.24,
#         'call_sign': "ioa_02_tta",
#         'model_func': load_model_exp28_snapshot,
#         'tta': True, #True
#         'normalize': True,
#         'model_pth': [
#             ######## full train.csv 
#             '/home/rohits/pv1/Contrail_Detection/output/ioannis/Unet-gluon_senet154_fold-1.ckpt',
#             '/home/rohits/pv1/Contrail_Detection/output/ioannis/Unet-gluon_senet154_fold-1-v1.ckpt',
#             '/home/rohits/pv1/Contrail_Detection/output/ioannis/Unet-gluon_senet154_fold-1-v2.ckpt',
#             ######### 4-skf split 
#             #     '/kaggle/input/contrails-exp28b-senet154-512/Unet-gluon_senet154_fold0.ckpt',
#             #     '/kaggle/input/contrails-exp28b-senet154-512/Unet-gluon_senet154_fold0-v1.ckpt',
#             #     '/kaggle/input/contrails-exp28b-senet154-512/Unet-gluon_senet154_fold0-v2.ckpt',
#             #     '/kaggle/input/contrails-exp28b-senet154-512/Unet-gluon_senet154_fold0-v3.ckpt',
#             #     '/kaggle/input/contrails-exp28b-senet154-512/Unet-gluon_senet154_fold0-v4.ckpt',
#             #     '/kaggle/input/contrails-exp28b-senet154-512/Unet-gluon_senet154_fold0-v5.ckpt'
#         ]
#     }, 
    
    
#     {
#         'model_name': 'Unet',
#         'backbone': 'efficientnet-b7',
#         'img_size': [256, 256],
#         'num_classes': 1,
#         'model_pth':  '/home/rohits/pv1/Contrail_Detection/output/exp_01_s2/Unet/efficientnet-b7-256/checkpoint_dice_fold0.pth',
#         'threshold': 0.24, #0.24,
#         'call_sign': "roh_01_tta",
#         'model_func': load_modelr1,
#         'tta': True
#     }, 
    
#     {
#         'model_name': 'Unet',
#         'backbone': 'efficientnet-b7',
#         'img_size': [256, 256],
#         'num_classes': 1,
#         'model_pth':  '/home/rohits/pv1/Contrail_Detection/output/exp_01_s2/Unet/efficientnet-b7-256/checkpoint_dice_fold0.pth',
#         'threshold': 0.24, #0.24,
#         'call_sign': "roh_01",
#         'model_func': load_modelr1,
#         'tta': False
#     }, 
    
    
#     {
#         'model_name': 'Unet',
#         'backbone': 'cs3darknet_l',
#         'img_size': [512, 512],
#         'num_classes': 1,
#         'model_pth':  '/home/rohits/pv1/Contrail_Detection/output/exp_01_s2/Unet/cs3darknet_l-512/checkpoint_dice_ctrl_fold0.pth',
#         'threshold': 0.24, #0.24,
#         'call_sign': "roh_02_tta",
#         'model_func': load_modelr2,
#         'tta': True
#     }, 
    
#         {
#         'model_name': 'Unet',
#         'backbone': 'cs3edgenet_x',
#         'img_size': [512, 512],
#         'num_classes': 1,
#         'model_pth':  '/home/rohits/pv1/Contrail_Detection/output/exp_01_s2/Unet/cs3edgenet_x-512/checkpoint_dice_ctrl_fold0.pth',
#         'threshold': 0.24, #0.24,
#         'call_sign': "roh_03_tta",
#         'model_func': load_modelr2,
#         'tta': True
#     }, 
    
#     {
#         'model_name': 'Unet',
#         'backbone': 'cs3sedarknet_l',
#         'img_size': [512, 512],
#         'num_classes': 1,
#         'model_pth':  '/home/rohits/pv1/Contrail_Detection/output/exp_01_s2/Unet/cs3sedarknet_l-512/checkpoint_dice_ctrl_fold0.pth',
#         'threshold': 0.24, #0.24,
#         'call_sign': "roh_04_tta",
#         'model_func': load_modelr2,
#         'tta': True
#     }, 
    
#     {
#         'model_name': 'Unet',
#         'backbone': 'ecaresnet101d',
#         'img_size': [512, 512],
#         'num_classes': 1,
#         'model_pth':  '/home/rohits/pv1/Contrail_Detection/output/exp_01_s2/Unet/ecaresnet101d-512/checkpoint_dice_ctrl_fold0.pth',
#         'threshold': 0.24, #0.24,
#         'call_sign': "roh_05_tta",
#         'model_func': load_modelr2,
#         'tta': True
#     }, 
    
#     {
#         'model_name': 'Unet',
#         'backbone': 'ecaresnet269d',
#         'img_size': [512, 512],
#         'num_classes': 1,
#         'model_pth':  '/home/rohits/pv1/Contrail_Detection/output/exp_01_s2/Unet/ecaresnet269d-512/checkpoint_dice_ctrl_fold0.pth',
#         'threshold': 0.24, #0.24,
#         'call_sign': "roh_06_tta",
#         'model_func': load_modelr2,
#         'tta': True
#     }, 
    
#     {
#         'model_name': 'Unet',
#         'backbone': 'fbnetv3_g',
#         'img_size': [512, 512],
#         'num_classes': 1,
#         'model_pth':  '/home/rohits/pv1/Contrail_Detection/output/exp_01_s2/Unet/fbnetv3_g-512/checkpoint_dice_ctrl_fold0.pth',
#         'threshold': 0.24, #0.24,
#         'call_sign': "roh_07_tta",
#         'model_func': load_modelr2,
#         'tta': True
#     }, 
    
        
#     {
#         'model_name': 'Unet',
#         'backbone': 'tf_mixnet_l',
#         'img_size': [512, 512],
#         'num_classes': 1,
#         'model_pth':  '/home/rohits/pv1/Contrail_Detection/output/exp_01_s2/Unet/tf_mixnet_l-512/checkpoint_dice_ctrl_fold0.pth',
#         'threshold': 0.24, #0.24,
#         'call_sign': "roh_08_tta",
#         'model_func': load_modelr2,
#         'tta': True
#     }, 
    
    

#     {
#         'model_name': 'Unet',
#         'backbone': 'maxvit_small_tf_512',
#         'img_size': [512, 512],
#         'num_classes': 1,
#         'model_pth':  '/home/rohits/pv1/Contrail_Detection/output/nirjhar/maxvitfold0mix_maxvit_small_tf_512_best_epochstage2oof685cv657-00.bin',
#         'threshold': 0.62, #0.24,
#         'call_sign': "nir_01_tta",
#         'model_func': load_model4,
#         'tta': True
#     },
    
    
#     {
#         'model_name': 'Unet',
#         'backbone': 'eca_nfnet_l1',
#         'img_size': [512, 512],
#         'num_classes': 1,
#         'model_pth':  '/home/rohits/pv1/Contrail_Detection/output/nirjhar/ecanfnetl1v2fold0_eca_nfnet_l1_best_epochstage2-00.bin',
#         'threshold': 0.62, #0.24,
#         'call_sign': "nir_02_tta",
#         'model_func': load_model2,
#         'tta': True
#     },
#     {
#         'model_name': 'Unet++',
#         'backbone': 'tf_efficientnet_b8',
#         'img_size': [512, 512],
#         'num_classes': 1,
#         'model_pth':  '/home/rohits/pv1/Contrail_Detection/output/nirjhar/effb8modelv2fold0_tf_efficientnet_b8_best_epochstage2-00.bin',
#         'threshold': 0.79,
#         'call_sign': "nir_03_tta", 
#         'model_func': load_model3,
#         'tta': True
#     },
#     {
#         'model_name': 'Unet++',
#         'backbone': 'tf_efficientnet_b7_ns',
#         'img_size': [512, 512],
#         'num_classes': 1,
#         'model_pth':  '/home/rohits/pv1/contrail_01/output/nirjhar/qishentrialv2fpn512_tf_efficientnet_b7_ns_best_epochcv650lb675-00.bin',
#         'threshold': 0.62, #0.24,
#         'call_sign': "nir_04_tta", 
#         'model_func': load_model1,
#         'tta': True
#     }, 
    
    
    
#         {
#         'model_name': 'Unet',
#         'backbone': 'eca_nfnet_l2',
#         'img_size': [512, 512],
#         'num_classes': 1,
#         'model_pth':  '/home/rohits/pv1/Contrail_Detection/output/nirjhar/ecanfnetl2v2fold0_eca_nfnet_l2_best_epochstage2-00.bin',
#         'threshold': 0.62, #0.24,
#         'call_sign': "nir_05_tta",
#         'model_func': load_model2,
#         'tta': True
#     },
#     {
#     'model_name': 'Unet++',
#         'backbone': 'densenetblur121d',
#         'img_size': [512, 512],
#         'num_classes': 1,
#         'model_pth':  '/home/rohits/pv1/Contrail_Detection/output/nirjhar/ecanfnetl2v2fold0_densenetblur121d_best_epochstage2-00.bin',
#         'threshold': 0.79,
#         'call_sign': "nir_06_tta", 
#         'model_func': load_model2,
#         'tta': True
#     },
    
#     {
#     'model_name': 'Unet++',
#         'backbone': 'convnextv2_base',
#         'img_size': [512, 512],
#         'num_classes': 1,
#         'model_pth':  '/home/rohits/pv1/Contrail_Detection/output/nirjhar/convnextfold0_convnextv2_base_best_epochstage2-00.bin',
#         'threshold': 0.79,
#         'call_sign': "nir_07_tta", 
#         'model_func': load_model2,
#         'tta': True
#     },
    
#     {
#         'model_name': 'Unet',
#         'backbone': 'maxvit_small_tf_512',
#         'img_size': [512, 512],
#         'num_classes': 1,
#         'model_pth':  '/home/rohits/pv1/Contrail_Detection/output/nirjhar/maxvitfold0mix_maxvit_small_tf_512_best_epochstage2oof685cv657-00.bin',
#         'threshold': 0.62, #0.24,
#         'call_sign': "nir_01",
#         'model_func': load_model4,
#         'tta': False
#     },
    
    
#     {
#         'model_name': 'Unet',
#         'backbone': 'eca_nfnet_l1',
#         'img_size': [512, 512],
#         'num_classes': 1,
#         'model_pth':  '/home/rohits/pv1/Contrail_Detection/output/nirjhar/ecanfnetl1v2fold0_eca_nfnet_l1_best_epochstage2-00.bin',
#         'threshold': 0.62, #0.24,
#         'call_sign': "nir_02",
#         'model_func': load_model2,
#         'tta': False
#     },
#     {
#         'model_name': 'Unet++',
#         'backbone': 'tf_efficientnet_b8',
#         'img_size': [512, 512],
#         'num_classes': 1,
#         'model_pth':  '/home/rohits/pv1/Contrail_Detection/output/nirjhar/effb8modelv2fold0_tf_efficientnet_b8_best_epochstage2-00.bin',
#         'threshold': 0.79,
#         'call_sign': "nir_03", 
#         'model_func': load_model3,
#         'tta': False
#     },
#     {
#         'model_name': 'Unet++',
#         'backbone': 'tf_efficientnet_b7_ns',
#         'img_size': [512, 512],
#         'num_classes': 1,
#         'model_pth':  '/home/rohits/pv1/contrail_01/output/nirjhar/qishentrialv2fpn512_tf_efficientnet_b7_ns_best_epochcv650lb675-00.bin',
#         'threshold': 0.62, #0.24,
#         'call_sign': "nir_04", 
#         'model_func': load_model1,
#         'tta': False
#     }, 
    
#         {
#         'model_name': 'Unet',
#         'backbone': 'eca_nfnet_l2',
#         'img_size': [512, 512],
#         'num_classes': 1,
#         'model_pth':  '/home/rohits/pv1/Contrail_Detection/output/nirjhar/ecanfnetl2v2fold0_eca_nfnet_l2_best_epochstage2-00.bin',
#         'threshold': 0.62, #0.24,
#         'call_sign': "nir_05",
#         'model_func': load_model2,
#         'tta': False
#     },
#     {
#     'model_name': 'Unet++',
#         'backbone': 'densenetblur121d',
#         'img_size': [512, 512],
#         'num_classes': 1,
#         'model_pth':  '/home/rohits/pv1/Contrail_Detection/output/nirjhar/ecanfnetl2v2fold0_densenetblur121d_best_epochstage2-00.bin',
#         'threshold': 0.79,
#         'call_sign': "nir_06", 
#         'model_func': load_model2,
#         'tta': False
#     },
    
#     {
#     'model_name': 'Unet++',
#         'backbone': 'convnextv2_base',
#         'img_size': [512, 512],
#         'num_classes': 1,
#         'model_pth':  '/home/rohits/pv1/Contrail_Detection/output/nirjhar/convnextfold0_convnextv2_base_best_epochstage2-00.bin',
#         'threshold': 0.79,
#         'call_sign': "nir_07", 
#         'model_func': load_model2,
#         'tta': False
#     },
    
    
    
    
    
# ]