In [1]:
# import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '1'

# !pip install --upgrade segmentation_models_pytorch

In [2]:
import os
import shutil
import pathlib
import gc
import pandas as pd
import ttach as tta
from pathlib import Path
from PIL import Image
import numpy as np
import cv2 as cv
import random
import matplotlib.pyplot as plt
# from tqdm.auto import tqdm
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import torch
from torch.utils.data import DataLoader, random_split
from torch.utils.data import Dataset

import torchvision
from torchvision import datasets
import cv2
from torch.cuda import amp

import torchvision.transforms as T
from torchvision.transforms import Compose, ToTensor, Resize
from torchvision.utils import make_grid
import optuna

import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
import timm

In [3]:
def seed_everything(seed=42):
    '''Seed every RNG the notebook relies on so re-runs give identical results.

    Seeds NumPy's global generator, Python's ``random`` module and PyTorch's
    CPU/CUDA generators, forces deterministic cuDNN kernels with the autotuner
    disabled, and exports ``PYTHONHASHSEED``.

    Args:
        seed: integer seed applied to every generator (default 42).

    Note:
        ``PYTHONHASHSEED`` only affects Python processes started after this
        call; the running interpreter's ``hash()`` seed is fixed at start-up.
    '''
    for seeder in (np.random.seed, random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False
    os.environ['PYTHONHASHSEED'] = str(seed)

In [4]:
# Fix all random seeds before any model/dataset is built.
seed_everything(seed=42)

In [5]:
# Data split the notebook runs on (alternative: "validation") and the NTB setting.
folder, NTB = "train", 3

In [6]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [7]:
import torchvision.transforms as T


class ContrailDataset:
    """Loads pre-built ``.npy`` images listed in a DataFrame.

    Args:
        df: DataFrame with an ``image`` column of file names relative to
            ``data_root``. A ``label`` column is kept as ``self.labels`` when
            present, but labels are not loaded here.
        transform: optional albumentations-style callable, invoked as
            ``transform(image=...)`` and returning a dict with ``'image'``.
            Its output is transposed with axes (2, 0, 1), i.e. assumed HWC.
        normalize: apply ImageNet mean/std normalisation when True.
        data_root: prefix joined (plain string concatenation) in front of each
            file name. Defaults to the original ``"../../input/"``.
    """

    def __init__(self, df, transform=None, normalize=False, data_root="../../input/"):
        self.df = df
        self.images = df['image']
        # Labels are optional: mask loading is disabled, so test-time frames need not have them.
        self.labels = df['label'] if 'label' in df.columns else None
        self.transform = transform
        self.normalize_image = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        self.normalize = normalize
        self.data_root = data_root

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return the float64 image tensor for the ``idx``-th row (positional)."""
        # .iloc: positional lookup, so a filtered/shuffled df index no longer
        # raises KeyError or returns the wrong row.
        image = np.load(self.data_root + self.images.iloc[idx]).astype(float)

        if self.transform:
            image = self.transform(image=image)['image']
            image = np.transpose(image, (2, 0, 1))
        # NOTE(review): without a transform the array is returned in its stored
        # layout (no transpose) — confirm callers expect that.

        tensor = torch.tensor(image)
        if self.normalize:
            tensor = self.normalize_image(tensor)
        return tensor
    
    
# class ContrailDataset:
#     def __init__(self, df, transform=None):
#         self.df = df  
#         self.images = df['image']
#         self.transform =transform
        
        
#     def __len__(self):
#         return len(self.images)
        
#     def __getitem__(self, idx):
#         image = np.load("../../input/" + self.images[idx]).astype(float)   
# #         label = np.load("../../input" + self.labels[idx]).astype(float)
        
        
#         # label_cls = 1 if label.sum() > 0 else 0
#         if self.transform :
#             data = self.transform(image=image)
#             image  = data['image']
#             image = np.transpose(image, (2, 0, 1))
            
#         return torch.tensor(image)

In [8]:
### TimmSegModel (UNet++ on a timm encoder) -- Nirjhar

# Number of encoder stages fed to the UNet++ decoders. TimmSegModel,
# TimmSegModel2, TimmSegModel3 and TimmSegModel4 read this module-level name
# both when building the decoder and inside forward(), so the last binding
# in the notebook is the one in effect.
n_blocks = 4

class TimmSegModel(nn.Module):
    """UNet++ segmenter on a modified timm encoder with a 6-channel stem.

    Inputs (3 channels) first pass through ``conv1`` -> ``mybn1`` -> ReLU6,
    giving 6 channels. The encoder's ``conv_stem`` is replaced to accept those
    6 channels, ``blocks[5]`` is replaced by ``nn.Identity`` and ``blocks[6]``
    becomes a 1x1 conv + BN + ReLU6 projecting to 320 channels.

    NOTE(review): the surgery assumes an EfficientNet/MobileNet-style timm
    backbone exposing ``conv_stem``, ``blocks`` and ``blocks[4][2].conv_pwl``
    — confirm when changing ``cfg['backbone']``.

    Args:
        cfg: mapping with ``'backbone'`` (timm model name) and ``'num_classes'``.
        segtype: only ``'unet'`` creates ``self.decoder``; any other value
            leaves it unset and ``forward`` would raise ``AttributeError``.
        pretrained: ignored; the encoder is always built with ``pretrained=False``.

    The module-level ``n_blocks`` is read both here and in ``forward``.
    """
    def __init__(self, cfg, segtype='unet', pretrained=True):
        super(TimmSegModel, self).__init__()
        # Input adapter 3 -> 6 channels; conv1/mybn1 are the only pair used in forward().
        self.conv1 = nn.Conv2d(3, 6, 3, stride=1, padding=1, bias=False)
        # conv2/conv3/mybn2/mybn3 are never called in forward().
        # NOTE(review): presumably kept so saved state_dict keys still match — confirm before removing.
        self.conv2 = nn.Conv2d(6, 12, 3, stride=1, padding=1, bias=False)
        self.conv3 = nn.Conv2d(12, 36, 3, stride=1, padding=1, bias=False)
        self.mybn1 = nn.BatchNorm2d(6)
        self.mybn2 = nn.BatchNorm2d(12)
        self.mybn3 = nn.BatchNorm2d(36)
        self.encoder = timm.create_model(
            cfg['backbone'],
            in_chans=3,
            features_only=True,
            drop_rate=0.8,
            drop_path_rate=0.5,
            pretrained=False
        )
        # Replace the stem so the encoder accepts the 6-channel adapter output.
        self.encoder.conv_stem=nn.Conv2d(6, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)

        # Drop stage 5 and turn stage 6 into a 1x1 projection to 320 channels.
        self.encoder.blocks[5] = nn.Identity()
        self.encoder.blocks[6] = nn.Sequential(
            nn.Conv2d(self.encoder.blocks[4][2].conv_pwl.out_channels, 320, 1),
            nn.BatchNorm2d(320),
            nn.ReLU6(),
        )
        # Dummy forward (6-channel, 64x64) to read each feature map's channel
        # count. It runs in training mode, so BatchNorm running stats are
        # touched; they are overwritten when a checkpoint is loaded (load_model1).
        tr = torch.randn(1,6,64,64)
        g = self.encoder(tr)
        # Leading 1 fills the first decoder slot; forward() passes a literal 0 there.
        encoder_channels = [1] + [_.shape[1] for _ in g]
        decoder_channels = [256, 128, 64, 32, 16]
        if segtype == 'unet':
            self.decoder = smp.decoders.unetplusplus.decoder.UnetPlusPlusDecoder(
                encoder_channels=encoder_channels[:n_blocks+1],
                decoder_channels=decoder_channels[:n_blocks],
                n_blocks=n_blocks,
            )

        # 3x3 conv from the last used decoder stage to num_classes logit channels.
        self.segmentation_head = nn.Conv2d(
            decoder_channels[n_blocks-1],
            cfg["num_classes"],
            kernel_size=(3, 3),
            stride=(1, 1),
            padding=(1, 1)
        )

    def forward(self,x):
        """Return ``num_classes``-channel logits for a 3-channel batch ``x``."""
        x = F.relu6(self.mybn1(self.conv1(x)))
        # Placeholder 0 in front of the first n_blocks encoder features.
        global_features = [0] + self.encoder(x)[:n_blocks]
        seg_features = self.decoder(*global_features)
        seg_features = self.segmentation_head(seg_features)
        return seg_features
    

def load_model1(cfg):
    """Build a ``TimmSegModel`` and load its weights from ``cfg['model_pth']``.

    Args:
        cfg: mapping with the model's config keys plus ``'model_pth'``, the
            path of a raw ``state_dict`` file.

    Returns:
        The model on CPU with the checkpoint loaded (mode/device unchanged).
    """
    model = TimmSegModel(cfg)
    # The freshly built model lives on CPU, so map the tensors there too; this
    # also lets GPU-saved checkpoints load on CPU-only machines.
    state_dict = torch.load(cfg['model_pth'], map_location='cpu')
    model.load_state_dict(state_dict)
    return model


#### Model 2 Nirjhar

# Re-binds the shared module-level stage count (same value) read by TimmSegModel2.
n_blocks =4
class TimmSegModel2(nn.Module):
    """UNet++ segmenter on a stock timm ``features_only`` encoder (3-channel input).

    Args:
        cfg: mapping with ``'backbone'`` (timm model name).
        segtype: only ``'unet'`` creates ``self.decoder``.
        pretrained: ignored; the encoder is always built with ``pretrained=False``.

    Uses the module-level ``n_blocks`` for the number of encoder stages fed to
    the decoder. The head always predicts a single channel.
    """

    def __init__(self, cfg, segtype='unet', pretrained=True):
        super().__init__()

        self.encoder = timm.create_model(
            cfg['backbone'],
            in_chans=3,
            features_only=True,
            drop_rate=0.5,
            pretrained=False,
        )
        # A single dummy pass reveals the channel width of every encoder stage.
        probe = self.encoder(torch.rand(1, 3, 128, 128))
        enc_ch = [1]
        enc_ch.extend(feat.shape[1] for feat in probe)
        dec_ch = [256, 128, 64, 32, 16]
        if segtype == 'unet':
            self.decoder = smp.decoders.unetplusplus.decoder.UnetPlusPlusDecoder(
                encoder_channels=enc_ch[:n_blocks + 1],
                decoder_channels=dec_ch[:n_blocks],
                n_blocks=n_blocks,
            )

        head_conv = nn.Conv2d(dec_ch[n_blocks - 1], 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.segmentation_head = nn.Sequential(head_conv, nn.UpsamplingBilinear2d(scale_factor=1))

    def forward(self, x):
        """Return single-channel logits from the first ``n_blocks`` encoder features."""
        feats = self.encoder(x)[:n_blocks]
        # Leading 0 occupies the decoder's first feature slot.
        decoded = self.decoder(0, *feats)
        return self.segmentation_head(decoded)
    

def load_model2(cfg):
    """Build a ``TimmSegModel2`` and load its weights from ``cfg['model_pth']``.

    Args:
        cfg: mapping with the model's config keys plus ``'model_pth'``, the
            path of a raw ``state_dict`` file.

    Returns:
        The model on CPU with the checkpoint loaded (mode/device unchanged).
    """
    model = TimmSegModel2(cfg)
    # Map to CPU to match the freshly built model and support CPU-only hosts.
    state_dict = torch.load(cfg['model_pth'], map_location='cpu')
    model.load_state_dict(state_dict)
    return model


##################

# Shared stage count re-bound again (value unchanged); read by TimmSegModel3/TimmSegModel4.
n_blocks = 4

class TimmSegModel3(nn.Module):
    """UNet++ segmenter on a modified timm encoder with a 6-channel, 72-wide stem.

    Inputs (3 channels) go through ``conv1`` -> ``mybn1`` -> ReLU6 (6 channels).
    The encoder's ``conv_stem`` is replaced by a 6 -> 72 channel conv,
    ``blocks[5]`` becomes ``nn.Identity`` and ``blocks[6]`` a 1x1 conv + BN +
    ReLU6 projecting to 320 channels.

    NOTE(review): assumes an EfficientNet/MobileNet-style timm backbone exposing
    ``conv_stem``, ``blocks`` and ``blocks[4][2].conv_pwl`` — confirm per backbone.

    Args:
        cfg: mapping with ``"backbone"`` and ``'num_classes'``.
        segtype: only ``'unet'`` creates ``self.decoder``; other values leave it
            unset and ``forward`` would raise ``AttributeError``.
        pretrained: ignored; the encoder is always built with ``pretrained=False``.

    Reads the module-level ``n_blocks`` here and in ``forward``.
    """
    def __init__(self, cfg, segtype='unet', pretrained=True):
        super(TimmSegModel3, self).__init__()
        # Input adapter 3 -> 6 channels; only conv1/mybn1 are used in forward().
        self.conv1 = nn.Conv2d(3, 6, 3, stride=1, padding=1, bias=False)
        # conv2/conv3/mybn2/mybn3 are unused in forward().
        # NOTE(review): presumably kept so saved state_dict keys still match — confirm before removing.
        self.conv2 = nn.Conv2d(6, 12, 3, stride=1, padding=1, bias=False)
        self.conv3 = nn.Conv2d(12, 36, 3, stride=1, padding=1, bias=False)
        self.mybn1 = nn.BatchNorm2d(6)
        self.mybn2 = nn.BatchNorm2d(12)
        self.mybn3 = nn.BatchNorm2d(36)
        self.encoder = timm.create_model(
            cfg["backbone"],
            in_chans=3,
            features_only=True,
            drop_rate=0.8,
            drop_path_rate=0.5,
            pretrained=False
        )
        # Stem replaced to take the 6-channel adapter output (72 output channels).
        self.encoder.conv_stem=nn.Conv2d(6, 72, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)

        # Drop stage 5; stage 6 becomes a 1x1 projection to 320 channels.
        self.encoder.blocks[5] = nn.Identity()
        self.encoder.blocks[6] = nn.Sequential(
            nn.Conv2d(self.encoder.blocks[4][2].conv_pwl.out_channels, 320, 1),
            nn.BatchNorm2d(320),
            nn.ReLU6(),
        )
        # Dummy 6-channel forward to read per-stage channel counts (training mode,
        # so BatchNorm stats are touched until a checkpoint is loaded).
        tr = torch.randn(1,6,64,64)
        g = self.encoder(tr)
        # Leading 1 matches the placeholder 0 that forward() passes first.
        encoder_channels = [1] + [_.shape[1] for _ in g]
        decoder_channels = [256, 128, 64, 32, 16]
        if segtype == 'unet':
            self.decoder = smp.decoders.unetplusplus.decoder.UnetPlusPlusDecoder(
                encoder_channels=encoder_channels[:n_blocks+1],
                decoder_channels=decoder_channels[:n_blocks],
                n_blocks=n_blocks,
            )

        # 3x3 conv to num_classes logit channels.
        self.segmentation_head = nn.Conv2d(decoder_channels[n_blocks-1], cfg['num_classes'], kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    def forward(self,x):
        """Return ``num_classes``-channel logits for a 3-channel batch ``x``."""
        x = F.relu6(self.mybn1(self.conv1(x)))
        global_features = [0] + self.encoder(x)[:n_blocks]
        seg_features = self.decoder(*global_features)
        seg_features = self.segmentation_head(seg_features)
        return seg_features
    
    
    
def load_model3(cfg):
    """Build a ``TimmSegModel3`` and load its weights from ``cfg['model_pth']``.

    Args:
        cfg: mapping with the model's config keys plus ``'model_pth'``, the
            path of a raw ``state_dict`` file.

    Returns:
        The model on CPU with the checkpoint loaded (mode/device unchanged).
    """
    model = TimmSegModel3(cfg)
    # Map to CPU to match the freshly built model and support CPU-only hosts.
    state_dict = torch.load(cfg['model_pth'], map_location='cpu')
    model.load_state_dict(state_dict)
    return model



class TimmSegModel4(nn.Module):
    """UNet++ segmenter on a stock timm ``features_only`` encoder, sized with a 512px probe.

    Args:
        cfg: mapping with ``"backbone"`` (timm model name).
        segtype: only ``'unet'`` creates ``self.decoder``.
        pretrained: ignored; the encoder is always built with ``pretrained=False``.

    Uses the module-level ``n_blocks``; the head predicts a single channel.
    """

    def __init__(self, cfg, segtype='unet', pretrained=True):
        super().__init__()

        self.encoder = timm.create_model(
            cfg["backbone"],
            in_chans=3,
            features_only=True,
            drop_rate=0.5,
            pretrained=False,
        )
        # Probe the encoder once to learn each stage's channel width.
        stage_widths = [feat.shape[1] for feat in self.encoder(torch.rand(1, 3, 512, 512))]
        enc_ch = [1, *stage_widths]
        dec_ch = [256, 128, 64, 32, 16]
        if segtype == 'unet':
            self.decoder = smp.decoders.unetplusplus.decoder.UnetPlusPlusDecoder(
                encoder_channels=enc_ch[:n_blocks + 1],
                decoder_channels=dec_ch[:n_blocks],
                n_blocks=n_blocks,
            )

        self.segmentation_head = nn.Sequential(
            nn.Conv2d(dec_ch[n_blocks - 1], 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.UpsamplingBilinear2d(scale_factor=1),
        )

    def forward(self, x):
        """Return single-channel logits for ``x``."""
        feats = self.encoder(x)[:n_blocks]
        return self.segmentation_head(self.decoder(0, *feats))
    
    
def load_model4(cfg):
    """Build a ``TimmSegModel4`` and load its weights from ``cfg['model_pth']``.

    Args:
        cfg: mapping with the model's config keys plus ``'model_pth'``, the
            path of a raw ``state_dict`` file.

    Returns:
        The model on CPU with the checkpoint loaded (mode/device unchanged).
    """
    model = TimmSegModel4(cfg)
    # Map to CPU to match the freshly built model and support CPU-only hosts.
    state_dict = torch.load(cfg['model_pth'], map_location='cpu')
    model.load_state_dict(state_dict)
    return model







## Rohit Model 
class Net_R(nn.Module):
    """Plain smp ``Unet`` (Rohit's model) returning raw logits (``activation=None``).

    Args:
        cfg: mapping with ``"backbone"`` (encoder name) and ``"num_classes"``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        unet_kwargs = dict(
            encoder_name=cfg["backbone"],
            encoder_weights=None,
            in_channels=3,
            classes=cfg["num_classes"],
            activation=None,
        )
        self.model = smp.Unet(**unet_kwargs)

    def forward(self, inputs):
        """Delegate directly to the wrapped smp model."""
        return self.model(inputs)
    
    
def load_modelr1(cfg):
    """Build ``Net_R`` wrapped in ``DataParallel`` and load ``cfg['model_pth']``.

    The checkpoint is a dict whose ``'model'`` entry is a state_dict saved from
    a ``DataParallel`` wrapper (keys prefixed with ``module.``), so the model is
    wrapped before loading.

    Returns:
        The ``DataParallel``-wrapped model, moved to CUDA when a GPU is available
        (previously ``.cuda()`` was unconditional and crashed on CPU-only hosts).
    """
    model = torch.nn.DataParallel(Net_R(cfg))
    if torch.cuda.is_available():
        model = model.cuda()
    checkpoint = torch.load(cfg['model_pth'], map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model'])
    return model


#     if idx in [11, 12]:
#         model = TimmSegModelR1(cfg)
#         model = torch.nn.DataParallel(model).cuda()
#         model.load_state_dict(torch.load(cfg['model_pth'], map_location=torch.device('cpu'))['model'])
#     else:
#         model = Net(cfg)    
#         model = torch.nn.DataParallel(model).cuda()
#         model.load_state_dict(torch.load(cfg['model_pth'], map_location=torch.device('cpu'))['model'])

class TimmSegModelR1(nn.Module):
    """UNet++ over a stock timm ``features_only`` encoder (Rohit's variant).

    The stage count is stored on the instance (``self.n_blocks = 4``) rather
    than read from the module-level ``n_blocks``, and the head emits
    ``cfg['num_classes']`` channels.

    Args:
        cfg: mapping with ``'backbone'`` and ``'num_classes'``.
        segtype: only ``'unet'`` creates ``self.decoder``.
        pretrained: ignored; the encoder is built with ``pretrained=False``.
    """

    def __init__(self, cfg, segtype='unet', pretrained=True):
        super().__init__()

        self.n_blocks = 4
        depth = self.n_blocks
        self.encoder = timm.create_model(
            cfg['backbone'],
            in_chans=3,
            features_only=True,
            drop_rate=0.5,
            pretrained=False,
        )
        # One dummy forward pass reveals the channel width of every stage.
        stage_widths = [feat.shape[1] for feat in self.encoder(torch.rand(1, 3, 128, 128))]
        enc_ch = [1, *stage_widths]
        dec_ch = [256, 128, 64, 32, 16]
        if segtype == 'unet':
            self.decoder = smp.decoders.unetplusplus.decoder.UnetPlusPlusDecoder(
                encoder_channels=enc_ch[:depth + 1],
                decoder_channels=dec_ch[:depth],
                n_blocks=depth,
            )

        self.segmentation_head = nn.Sequential(
            nn.Conv2d(dec_ch[depth - 1], cfg['num_classes'], kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.UpsamplingBilinear2d(scale_factor=1),
        )

    def forward(self, x):
        """Return ``num_classes``-channel logits for ``x``."""
        feats = self.encoder(x)[:self.n_blocks]
        # Leading 0 occupies the decoder's first feature slot.
        return self.segmentation_head(self.decoder(0, *feats))


    
def load_modelr2(cfg):
    """Build ``TimmSegModelR1`` wrapped in ``DataParallel`` and load ``cfg['model_pth']``.

    The checkpoint's ``'model'`` entry is a ``module.``-prefixed state_dict, so
    the model is wrapped before loading.

    Returns:
        The ``DataParallel``-wrapped model, moved to CUDA only when a GPU is
        available (an unconditional ``.cuda()`` crashed on CPU-only hosts).
    """
    model = torch.nn.DataParallel(TimmSegModelR1(cfg))
    if torch.cuda.is_available():
        model = model.cuda()
    checkpoint = torch.load(cfg['model_pth'], map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model'])
    return model

### Ioannis

In [9]:
### -----------------------------------------
### exp28 Senet154 + deep_supervision 
### -----------------------------------------

import warnings 
warnings.filterwarnings('ignore')
# from einops import rearrange, reduce, repeat
from segmentation_models_pytorch.decoders.unet.decoder import UnetDecoder, DecoderBlock
import torch.nn.functional as F
import torchvision.transforms as T
import pickle 
import glob
import pytorch_lightning as pl
import timm 

# Record library versions so results can be tied to the environment that produced them.
print(f"Pytorch version: {torch.__version__}")
print(f"Segmentation Models version: {smp.__version__}")
print(f"Timm version: {timm.__version__}")


#################################
##### Dataset 
################################


class DatasetExp28(torch.utils.data.Dataset):
    """False-colour contrail images built from raw band files (exp28 inputs).

    Each ``df`` row must provide a ``path`` to a record directory holding
    ``band_11.npy``, ``band_14.npy`` and ``band_15.npy``. ``__getitem__``
    returns a float32 ``(3, H, W)`` tensor: 256x256 by default, resized when
    ``image_size`` differs, optionally ImageNet-normalised.

    Args:
        df: DataFrame of records, one row per sample.
        image_size: target size passed to ``Resize`` when not 256.
        train: stored as ``self.trn``; not otherwise used here.
        normalize: apply ImageNet mean/std normalisation when True.
        test_dir: directory listed into ``self.df_idx``. When it does not
            exist, ``df_idx`` is an empty frame instead of raising, so the
            dataset can be constructed outside the Kaggle environment.
    """

    def __init__(self, df, image_size=256, train=True, normalize=True,
                 test_dir='/kaggle/input/google-research-identify-contrails-reduce-global-warming/test'):
        self.df = df
        self.trn = train
        # df_idx only feeds the commented-out image-id return in __getitem__;
        # guard the listing so a missing folder no longer breaks construction.
        test_ids = os.listdir(test_dir) if os.path.isdir(test_dir) else []
        self.df_idx: pd.DataFrame = pd.DataFrame({'idx': test_ids})
        self.normalize_image = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        self.image_size = image_size
        self.normalize = normalize
        if image_size != 256:
            self.resize_image = T.transforms.Resize(image_size)

    def read_record(self, directory):
        """Load ``band_11/14/15.npy`` from ``directory`` into a dict keyed by band name."""
        return {
            band: np.load(os.path.join(directory, band + ".npy"))
            for band in ("band_11", "band_14", "band_15")
        }

    def normalize_range(self, data, bounds):
        """Maps data to the range [0, 1]."""
        return (data - bounds[0]) / (bounds[1] - bounds[0])

    def get_false_color(self, record_data):
        """Build the ash false-colour image at time index 4.

        Channels are band_15 - band_14, band_14 - band_11 and band_14, each
        rescaled with fixed bounds and clipped to [0, 1]. Band arrays are
        assumed to be (H, W, T) with T > 4, giving an (H, W, 3) result —
        TODO confirm against the data.
        """
        _T11_BOUNDS = (243, 303)
        _CLOUD_TOP_TDIFF_BOUNDS = (-4, 5)
        _TDIFF_BOUNDS = (-4, 2)

        N_TIMES_BEFORE = 4

        r = self.normalize_range(record_data["band_15"] - record_data["band_14"], _TDIFF_BOUNDS)
        g = self.normalize_range(record_data["band_14"] - record_data["band_11"], _CLOUD_TOP_TDIFF_BOUNDS)
        b = self.normalize_range(record_data["band_14"], _T11_BOUNDS)
        false_color = np.clip(np.stack([r, g, b], axis=2), 0, 1)
        img = false_color[..., N_TIMES_BEFORE]

        return img

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        """Return the float32 (3, H, W) image tensor for row ``index``."""
        row = self.df.iloc[index]
        con_path = row.path
        data = self.read_record(con_path)
        img = self.get_false_color(data)

        ###### Rohit dataset -- todo
        #img = np.load(self.images[index]).astype(float)

        # HWC -> CHW; the reshape requires native 256x256 records.
        img = torch.tensor(np.reshape(img, (256, 256, 3))).to(torch.float32).permute(2, 0, 1)

        if self.image_size != 256:
            img = self.resize_image(img)

        if self.normalize:
            img = self.normalize_image(img)

        #image_id = int(self.df_idx.iloc[index]['idx'])
        return img.float() #, torch.tensor(image_id)
    

    
#############################    
### MODEL    
#############################
class SmpUnetDecoder(nn.Module):
    """U-Net decoder made of smp ``DecoderBlock`` stages.

    Args:
        in_channel: channels of the deepest encoder feature (decoder input).
        skip_channel: per-stage skip-connection channel counts (0 = no skip).
        out_channel: per-stage output channel counts; stage ``k`` consumes the
            output of stage ``k-1`` (stage 0 consumes ``in_channel``).
    """

    def __init__(self, in_channel, skip_channel, out_channel):
        super().__init__()
        self.center = nn.Identity()
        stage_inputs = [in_channel] + out_channel[:-1]
        self.block = nn.ModuleList(
            DecoderBlock(c_in, c_skip, c_out, use_batchnorm=True, attention_type=None)
            for c_in, c_skip, c_out in zip(stage_inputs, skip_channel, out_channel)
        )

    def forward(self, feature, skip):
        """Run each stage with ``skip[k]``; return ``(final_output, [every stage output])``."""
        out = self.center(feature)
        stage_outputs = []
        for k, stage in enumerate(self.block):
            out = stage(out, skip[k])
            stage_outputs.append(out)
        return out, stage_outputs


def conv3x3(in_channel, out_channel):
    """3x3 convolution, stride 1, padding 1, no bias; spatial size is preserved."""
    return nn.Conv2d(
        in_channels=in_channel,
        out_channels=out_channel,
        kernel_size=3,
        stride=1,
        padding=1,
        dilation=1,
        bias=False,
    )

def conv1x1(in_channel, out_channel):
    """1x1 convolution (pointwise), stride 1, no padding, no bias; spatial size is preserved."""
    return nn.Conv2d(
        in_channels=in_channel,
        out_channels=out_channel,
        kernel_size=1,
        stride=1,
        padding=0,
        dilation=1,
        bias=False,
    )
    
class Net(nn.Module):
    """U-Net on a manually-driven timm encoder; the exp28 network (``args_28``).

    forward() runs the encoder stem (``conv1``/``bn1``/``act1``), a 2x2 average
    pool and ``layer1``..``layer4``, decodes with ``SmpUnetDecoder`` and maps the
    last decoder stage to a single logit channel with a 1x1 conv.
    NOTE(review): relies on the timm backbone exposing those ResNet-style
    attribute names — confirm before switching backbones.

    Args:
        cfg: config class/object; reads ``backbone``, ``pretrained`` and
            ``in_chans``. ``cfg.deepsupervision`` is NOT read (hard-coded off).
        vb: stored on the instance; only commented-out debug prints use it.
    """
    def __init__(self, cfg, vb=False):
        super().__init__()
        self.cfg = cfg
        self.vb = vb
        self.crop_depth = None # cfg.crop_depth; not used in this class
        self.deepsupervision = False # hard-coded: cfg.deepsupervision is ignored, so forward() returns only the main logit
        
        # Channel plan. encoder_dim: channels of the five features collected in
        # forward() (stem, layer1..layer4); decoder_dim: output channels of the
        # five decoder stages.
        # if cfg.backbone == 'resnet34' or cfg.backbone == 'seresnet34':
        if 'resnet34' in cfg.backbone:
            conv_dim=64
            encoder_dim = [conv_dim, 64, 128, 256, 512, ]
            decoder_dim = [256, 128, 64, 32, 16]
        else:
            conv_dim=128 # stem width expected from the non-resnet34 backbone (gluon_senet154 in args_28) -- TODO confirm for others
            encoder_dim  = [conv_dim] + [256, 512, 1024, 2048]
            decoder_dim = [512, 256, 128, 64, 16]
            
        # num_classes=0: no classifier head; forward() only uses the feature layers.
        self.encoder = timm.create_model(cfg.backbone, pretrained=cfg.pretrained, in_chans=cfg.in_chans, num_classes=0) #drop_path_rate=0.2,
        # Decoder input is the deepest feature; skips are the shallower features
        # deepest-first, and the final stage has no skip (channel 0 / None).
        self.decoder = SmpUnetDecoder(
            in_channel=encoder_dim[-1],
            skip_channel=encoder_dim[:-1][::-1] + [0],
            out_channel=decoder_dim
        )
        ### seg head: 1x1 conv from the last decoder stage to one logit channel
        seg_head_in_c = 16 # unused
        self.logit = nn.Conv2d(decoder_dim[-1], 1, kernel_size=1)
        
        # Per-feature 1x1 "pool attention" convs. Not used in forward();
        # NOTE(review): presumably kept so checkpoint keys match -- confirm before removing.
        self.weight = nn.ModuleList([
            nn.Conv2d(encoder_dim[i], 1, kernel_size=1, padding=0) for i in range(len(encoder_dim))
        ])
        #### deep supervision heads: deep4 <- decoder[-1], ..., deep1 <- decoder[-4].
        # NOTE(review): deep_ch matches the non-resnet34 decoder_dim only; with the
        # resnet34 branch deep3/deep2/deep1 would receive 32/64/128 channels and
        # fail if deep supervision were ever enabled.
        deep_ch = [16, 64, 128, 256]   # [64, 64, 64, 64]
        self.deep4 = conv1x1(deep_ch[0],1)#.apply(init_weight)
        self.deep3 = conv1x1(deep_ch[1],1)#.apply(init_weight)
        self.deep2 = conv1x1(deep_ch[2],1)#.apply(init_weight)
        self.deep1 = conv1x1(deep_ch[3],1)#.apply(init_weight)
        
        # Upsamplers bringing decoder[-4..-1] to a common resolution (x8, x4, x2, x1).
        self.up1 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
        self.up2 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
        self.up3 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.up4 = nn.Upsample(scale_factor=1, mode='bilinear', align_corners=True)
        
    def forward(self, batch):
        """Return a 1-channel logit map for ``batch``.

        If ``self.deepsupervision`` were True (it never is as constructed), also
        return the four auxiliary 1-channel maps.
        """
        K = 1 # unused
        x = batch
        # ---- encoder: collect stem + 4 stage outputs as skip features
        encoder = []
        e = self.encoder
        
        x = e.conv1(x)
        x = e.bn1(x)
        x = e.act1(x)
        encoder.append(x)
        
        # 2x2 average pool; the backbone's own pooling layer (if any) is not called.
        x = F.avg_pool2d(x,kernel_size=2,stride=2)
        x = e.layer1(x)
        encoder.append(x)
        
        x = e.layer2(x)
        encoder.append(x)
        
        x = e.layer3(x)
        encoder.append(x)
        
        x = e.layer4(x)
        encoder.append(x)
        #if self.vb: print('encoder', [f.shape for f in encoder])

        ### decoder: deepest feature in, shallower features as skips, last stage without skip
        last, decoder = self.decoder(feature=encoder[-1], skip = encoder[:-1][::-1] + [None])
        #if self.vb: print('decoder',[f.shape for f in decoder])
        #if self.vb: print('last',last.shape)
        
        ### head
        logit = self.logit(last)
        #if self.vb: print('logit',logit.shape)
        
        if self.deepsupervision:
            # Channel counts below are for the non-resnet34 decoder_dim.
            y4 = decoder[-1]  ### 16 channels
            y3 = decoder[-2]  ### 64 channels
            y2 = decoder[-3]  ### 128 channels
            y1 = decoder[-4]  ### 256 channels
            ### upsample to y4's resolution -- assumes each decoder stage doubles H and W
            y1 = self.up1(y1)  ### x 8
            y2 = self.up2(y2)  ### x 4
            y3 = self.up3(y3)  ### x 2
            y4 = self.up4(y4)  ### x 1
            ##################
            s4 = self.deep4(y4) ### --> B x 1 x H x W
            s3 = self.deep3(y3)
            s2 = self.deep2(y2)
            s1 = self.deep1(y1)
            logits_deeps = [s4,s3,s2,s1]
            #if self.vb: print('logits_deeps', s1.shape,  s2.shape,  s3.shape,  s4.shape)
            return logit, logits_deeps
        return logit

    
class args_28:
    """Static configuration for the exp28 ``Net`` (gluon_senet154 U-Net at 512 px).

    ``Net`` reads ``backbone``, ``pretrained`` and ``in_chans`` from here; the
    remaining fields describe how the checkpoints were trained.
    """
    exp_name = "exp28"
    use_folds = [0]
    ##use_folds = [0,1,2,3]
    in_chans = 3 #12 # 6 # 65
    size = 512 #384 # 224  (previously assigned twice with the same value)
    image_size = size
    target_size = 1
    batch_size = 8 #32
    backbone = 'gluon_senet154'
    seg_model = "Unet"   ## "Unet++", "MAnet", "Linknet", "FPN", "PSPNet", "PAN", "DeepLabV3",  "DeepLabV3+"
    model_name = f'{seg_model}-{backbone}'
    pretrained = False
    deepsupervision = False

    
    
class LightningModule_28(pl.LightningModule):
    """Lightning container for the exp28 ``Net`` built from ``args_28``.

    The network is registered as ``self.model``, so its parameters carry the
    ``model.`` prefix in this module's state_dict.
    """

    def __init__(self):
        super().__init__()
        self.model = Net(args_28)

    def forward(self, batch):
        """Delegate to the wrapped network."""
        net = self.model
        return net(batch)

    
def load_model_exp28(cfg):
    """Restore a single exp28 Lightning checkpoint.

    Args:
        cfg: mapping whose ``'model_pth'`` is the ``.ckpt`` path.

    Returns:
        The ``LightningModule_28`` produced by ``load_from_checkpoint``.
    """
    # load_from_checkpoint is a classmethod that builds its own instance; calling
    # it on the class avoids constructing (and discarding) a full extra model,
    # and newer Lightning versions reject calls made on an instance.
    return LightningModule_28.load_from_checkpoint(cfg['model_pth'])



### ensemble multiple checkpoints


### move all weights to nb 
MODEL_DIR_EXP28 = "/kaggle/input/contrails-exp28b-senet154-512/"
# Earlier 4-fold (skf) fold-0 snapshots, kept for reference:
#   Unet-gluon_senet154_fold0.ckpt, Unet-gluon_senet154_fold0-v1.ckpt ... Unet-gluon_senet154_fold0-v5.ckpt
# Active set: snapshots trained on the full train.csv (fold -1).
CKPTS_EXP28 = [
    MODEL_DIR_EXP28 + ckpt_name
    for ckpt_name in (
        'Unet-gluon_senet154_fold-1.ckpt',
        'Unet-gluon_senet154_fold-1-v1.ckpt',
        'Unet-gluon_senet154_fold-1-v2.ckpt',
    )
]


class EnsembleModel_28(nn.Module):
    """Average the outputs of several exp28 checkpoints.

    Args:
        cfg: mapping whose ``'model_pth'`` is an iterable of Lightning checkpoint
            paths (e.g. ``CKPTS_EXP28``); each is restored as a
            ``LightningModule_28``, put in eval mode and kept in ``self.model``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.model = nn.ModuleList()
        for model_path in cfg['model_pth']:
            # Classmethod call on the class: no throw-away model is built first.
            _model = LightningModule_28.load_from_checkpoint(model_path)
            _model.eval()
            print('Load weights from:', model_path)
            self.model.append(_model)

    def forward(self, x):
        """Return the element-wise mean of every member's output (no sigmoid applied)."""
        outputs = [m(x) for m in self.model]
        return torch.stack(outputs, dim=0).mean(0)

def load_model_exp28_snapshot(cfg):
    """Build the exp28 snapshot ensemble for the checkpoints in ``cfg['model_pth']``."""
    return EnsembleModel_28(cfg)

Pytorch version: 2.0.1+cu117
Segmentation Models version: 0.3.3
Timm version: 0.9.2


In [10]:
### -----------------------------------------
### exp27
### -----------------------------------------


class args_27:
    """Static configuration for exp27 (resnest26d U-Net at 512 px).

    ``Net_27`` reads only ``backbone``; it passes ``pretrained=False`` and
    ``in_chans=3`` to timm itself and hard-codes deep supervision off.
    """
    backbone = 'resnest26d'
    seg_model = "Unet"   # alternatives: "Unet++", "MAnet", "Linknet", "FPN", "PSPNet", "PAN", "DeepLabV3", "DeepLabV3+"
    in_chans = 3
    size = 512  # earlier runs: 384 / 224
    pretrained = False
    deepsupervision = False



class SmpUnetDecoder(nn.Module):
    """U-Net decoder made of smp ``DecoderBlock`` stages (re-declared for exp27).

    Args:
        in_channel: channels of the deepest encoder feature (decoder input).
        skip_channel: per-stage skip-connection channel counts (0 = no skip).
        out_channel: per-stage output channel counts; stage ``k`` consumes the
            output of stage ``k-1`` (stage 0 consumes ``in_channel``).
    """

    def __init__(self, in_channel, skip_channel, out_channel):
        super().__init__()
        self.center = nn.Identity()
        stage_inputs = [in_channel] + out_channel[:-1]
        self.block = nn.ModuleList(
            DecoderBlock(c_in, c_skip, c_out, use_batchnorm=True, attention_type=None)
            for c_in, c_skip, c_out in zip(stage_inputs, skip_channel, out_channel)
        )

    def forward(self, feature, skip):
        """Run each stage with ``skip[k]``; return ``(final_output, [every stage output])``."""
        out = self.center(feature)
        stage_outputs = []
        for k, stage in enumerate(self.block):
            out = stage(out, skip[k])
            stage_outputs.append(out)
        return out, stage_outputs


def conv3x3(in_channel, out_channel):
    """3x3 convolution, stride 1, padding 1, no bias; spatial size is preserved."""
    return nn.Conv2d(
        in_channels=in_channel,
        out_channels=out_channel,
        kernel_size=3,
        stride=1,
        padding=1,
        dilation=1,
        bias=False,
    )

def conv1x1(in_channel, out_channel):
    """1x1 convolution (pointwise), stride 1, no padding, no bias; spatial size is preserved."""
    return nn.Conv2d(
        in_channels=in_channel,
        out_channels=out_channel,
        kernel_size=1,
        stride=1,
        padding=0,
        dilation=1,
        bias=False,
    )
    
class Net_27(nn.Module):
    """U-Net on a manually-driven timm encoder; the exp27 network (``args_27``).

    Same structure as the exp28 ``Net`` but with a 64-channel stem and a
    [256, 128, 64, 32, 16] decoder for every backbone. forward() calls the
    encoder's ``conv1``/``bn1``/``act1``, a 2x2 average pool and
    ``layer1``..``layer4``.
    NOTE(review): relies on the timm backbone exposing those ResNet-style
    attribute names — confirm before switching backbones.

    Args:
        cfg: config class/object; only ``cfg.backbone`` is read. ``pretrained``
            and ``in_chans`` are hard-coded (False, 3) and deep supervision is off.
        vb: stored on the instance; only commented-out debug prints use it.
    """
    def __init__(self, cfg, vb=False):
        super().__init__()
        self.cfg = cfg
        self.vb = vb
        self.crop_depth = None # cfg.crop_depth; not used in this class
        self.deepsupervision = False # hard-coded: cfg.deepsupervision is ignored, so forward() returns only the main logit
        
        # Channel plan. encoder_dim: channels of the five features collected in
        # forward() (stem, layer1..layer4); decoder_dim: output channels of the
        # five decoder stages.
        # if cfg.backbone == 'resnet34' or cfg.backbone == 'seresnet34':
        if 'resnet34' in cfg.backbone:
            conv_dim=64
            encoder_dim = [conv_dim, 64, 128, 256, 512, ]
            decoder_dim = [256, 128, 64, 32, 16]
        else:
            conv_dim=64 # stem width expected from resnest26d (args_27) -- TODO confirm for other backbones
            encoder_dim  = [conv_dim] + [256, 512, 1024, 2048]
            decoder_dim = [256, 128, 64, 32, 16]
            
        # pretrained/in_chans are fixed here rather than taken from cfg;
        # num_classes=0: no classifier head, only the feature layers are used.
        self.encoder = timm.create_model(cfg.backbone, pretrained=False, in_chans=3, num_classes=0) #drop_path_rate=0.2,
        # Decoder input is the deepest feature; skips are the shallower features
        # deepest-first, and the final stage has no skip (channel 0 / None).
        self.decoder = SmpUnetDecoder(
            in_channel=encoder_dim[-1],
            skip_channel=encoder_dim[:-1][::-1] + [0],
            out_channel=decoder_dim
        )
        ### seg head: 1x1 conv from the last decoder stage to one logit channel
        seg_head_in_c = 16 # unused
        self.logit = nn.Conv2d(decoder_dim[-1], 1, kernel_size=1)
        
        # Per-feature 1x1 "pool attention" convs. Not used in forward();
        # NOTE(review): presumably kept so checkpoint keys match -- confirm before removing.
        self.weight = nn.ModuleList([
            nn.Conv2d(encoder_dim[i], 1, kernel_size=1, padding=0) for i in range(len(encoder_dim))
        ])
        #### deep supervision heads: deep4 <- decoder[-1] (16 ch), deep3 <- decoder[-2] (32 ch),
        # deep2 <- decoder[-3] (64 ch), deep1 <- decoder[-4] (128 ch); matches decoder_dim.
        deep_ch = [16, 32, 64, 128] # [64, 64, 64, 64]
        self.deep4 = conv1x1(deep_ch[0],1)#.apply(init_weight)
        self.deep3 = conv1x1(deep_ch[1],1)#.apply(init_weight)
        self.deep2 = conv1x1(deep_ch[2],1)#.apply(init_weight)
        self.deep1 = conv1x1(deep_ch[3],1)#.apply(init_weight)
        
        # Upsamplers bringing decoder[-4..-1] to a common resolution (x8, x4, x2, x1).
        self.up1 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
        self.up2 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
        self.up3 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.up4 = nn.Upsample(scale_factor=1, mode='bilinear', align_corners=True)
        
    def forward(self, batch):
        """Return a 1-channel logit map for ``batch``.

        If ``self.deepsupervision`` were True (it never is as constructed), also
        return the four auxiliary 1-channel maps.
        """
        K = 1 # unused
        x = batch
        # ---- encoder: collect stem + 4 stage outputs as skip features
        encoder = []
        e = self.encoder
        
        x = e.conv1(x)
        x = e.bn1(x)
        x = e.act1(x)
        encoder.append(x)
        
        # 2x2 average pool; the backbone's own pooling layer (if any) is not called.
        x = F.avg_pool2d(x,kernel_size=2,stride=2)
        x = e.layer1(x)
        encoder.append(x)
        
        x = e.layer2(x)
        encoder.append(x)
        
        x = e.layer3(x)
        encoder.append(x)
        
        x = e.layer4(x)
        encoder.append(x)
        #if self.vb: print('encoder', [f.shape for f in encoder])

        ### decoder: deepest feature in, shallower features as skips, last stage without skip
        last, decoder = self.decoder(feature=encoder[-1], skip = encoder[:-1][::-1] + [None])
        #if self.vb: print('decoder',[f.shape for f in decoder])
        #if self.vb: print('last',last.shape)
        
        ### head
        logit = self.logit(last)
        #if self.vb: print('logit',logit.shape)
        
        if self.deepsupervision:
            y4 = decoder[-1]  ### 16 channels
            y3 = decoder[-2]  ### 32 channels
            y2 = decoder[-3]  ### 64 channels
            y1 = decoder[-4]  ### 128 channels
            ### upsample to y4's resolution -- assumes each decoder stage doubles H and W
            y1 = self.up1(y1)  ### x 8
            y2 = self.up2(y2)  ### x 4
            y3 = self.up3(y3)  ### x 2
            y4 = self.up4(y4)  ### x 1
            ##################
            s4 = self.deep4(y4) ### --> B x 1 x H x W
            s3 = self.deep3(y3)
            s2 = self.deep2(y2)
            s1 = self.deep1(y1)
            logits_deeps = [s4,s3,s2,s1]
            #if self.vb: print('logits_deeps', s1.shape,  s2.shape,  s3.shape,  s4.shape)
            return logit, logits_deeps
        return logit


    
    
class LightningModule_27(pl.LightningModule):
    """Lightning container for ``Net_27`` built from ``args_27``.

    The network is registered as ``self.model``, so its parameters carry the
    ``model.`` prefix in this module's state_dict.
    """

    def __init__(self):
        super().__init__()
        self.model = Net_27(args_27)

    def forward(self, batch):
        """Delegate to the wrapped network."""
        net = self.model
        return net(batch)

    
class EnsembleModel_27(nn.Module):
    """Average the outputs of several exp27 checkpoints ("model soup" at inference).

    Args:
        cfg: mapping whose ``'model_pth'`` is an iterable of Lightning checkpoint
            paths; each is restored as a ``LightningModule_27``, put in eval
            mode and kept in ``self.model``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.model = nn.ModuleList()
        for model_path in cfg['model_pth']:
            # Classmethod call on the class: no throw-away model is built first.
            _model = LightningModule_27.load_from_checkpoint(model_path)
            _model.eval()
            print('Load weights from:', model_path)
            self.model.append(_model)

    def forward(self, x):
        """Return the element-wise mean of every member's output (no sigmoid applied)."""
        outputs = [m(x) for m in self.model]
        return torch.stack(outputs, dim=0).mean(0)

def load_model_exp27_soup(cfg):
    """Build the exp27 checkpoint ensemble for the paths in ``cfg['model_pth']``."""
    return EnsembleModel_27(cfg)

In [11]:
### -----------------------------------------
### exp26 Unet++ regnetx_080
### -----------------------------------------

# class args:
#     backbone = 'timm-regnetx_080'

class SegModel_26(nn.Module):
    """exp26: smp UNet++ with a ``timm-regnetx_080`` encoder producing 1-channel logits."""

    def __init__(self):
        super().__init__()
        self.seg = smp.UnetPlusPlus(
            encoder_name='timm-regnetx_080',
            encoder_weights=None,
            classes=1,
            activation=None,
        )

    def forward(self, x):
        """Run encoder -> decoder -> head explicitly rather than via ``self.seg(x)``."""
        net = self.seg
        return net.segmentation_head(net.decoder(*net.encoder(x)))
    
    
class LightningModule_26(pl.LightningModule):
    """Lightning container for ``SegModel_26``.

    The network is registered as ``self.model``, so its parameters carry the
    ``model.`` prefix in this module's state_dict.
    """

    def __init__(self):
        super().__init__()
        self.model = SegModel_26()

    def forward(self, batch):
        """Delegate to the wrapped network."""
        net = self.model
        return net(batch)


class EnsembleModel_26(nn.Module):
    """Average the outputs of several exp26 checkpoints.

    Args:
        cfg: mapping whose ``'model_pth'`` is an iterable of Lightning checkpoint
            paths; each is restored as a ``LightningModule_26``, put in eval
            mode and kept in ``self.model``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.model = nn.ModuleList()
        for model_path in cfg['model_pth']:
            # Classmethod call on the class: no throw-away model is built first.
            _model = LightningModule_26.load_from_checkpoint(model_path)
            _model.eval()
            print('Load weights from:', model_path)
            self.model.append(_model)

    def forward(self, x):
        """Return the element-wise mean of every member's output (no sigmoid applied)."""
        outputs = [m(x) for m in self.model]
        return torch.stack(outputs, dim=0).mean(0)

    
def load_model_exp26_soup(cfg):
    """Build the exp26 checkpoint ensemble described by ``cfg``."""
    return EnsembleModel_26(cfg)

In [12]:
### -----------------------------------------
### exp19 Unet++
### -----------------------------------------

# class args:
#     exp_name = "exp19"
#     in_chans = 3 #12 # 6 # 65
#     size = 256  
#     target_size = 1
#     batch_size = 32
#     backbone = 'timm-resnest26d'


class SegModel_19(nn.Module):
    """exp19 network: smp Unet++ with a ``timm-resnest26d`` encoder and one output channel.

    ``encoder_weights=None``: trained weights are restored from the Lightning
    checkpoint by the ensemble loader.
    """

    def __init__(self):
        super().__init__()
        self.seg = smp.UnetPlusPlus(
            encoder_name='timm-resnest26d',
            encoder_weights=None,
            classes=1,
            activation=None,
        )

    def forward(self, x):
        """Encoder -> decoder -> segmentation head; returns raw outputs (activation=None)."""
        encoder_feats = self.seg.encoder(x)
        decoded = self.seg.decoder(*encoder_feats)
        return self.seg.segmentation_head(decoded)


class LightningModule_19(pl.LightningModule):
    """Lightning wrapper around ``SegModel_19`` used to restore exp19 checkpoints.

    The attribute name ``model`` must stay as is: it is the prefix of the
    parameter names in the saved state dict.
    """

    def __init__(self):
        super().__init__()
        self.model = SegModel_19()

    def forward(self, batch):
        """Delegate to the wrapped ``SegModel_19`` and return its raw output."""
        net = self.model
        return net(batch)

    
class EnsembleModel_19(nn.Module):
    """Checkpoint ensemble ("soup") of exp19 Unet++ models whose raw outputs are averaged.

    Args:
        cfg: config dict; ``cfg['model_pth']`` is a list of Lightning checkpoint
            paths. A single path (str / PathLike) is also accepted.

    Raises:
        ValueError: if no checkpoint path is given.
    """

    def __init__(self, cfg):
        super().__init__()
        self.model = nn.ModuleList()

        model_paths = cfg['model_pth']
        # A bare path would otherwise be iterated character by character.
        if isinstance(model_paths, (str, os.PathLike)):
            model_paths = [model_paths]
        if not model_paths:
            raise ValueError("EnsembleModel_19 needs at least one checkpoint in cfg['model_pth']")

        for model_path in model_paths:
            # load_from_checkpoint is a classmethod: calling it on the class avoids
            # building a throw-away model (and newer Lightning rejects instance calls).
            _model = LightningModule_19.load_from_checkpoint(model_path)
            _model.eval()
            print('Load weights from:', model_path)
            self.model.append(_model)

    def forward(self, x):
        """Return the mean of the members' raw (pre-sigmoid) outputs for batch ``x``."""
        outputs = [m(x) for m in self.model]
        return torch.stack(outputs, dim=0).mean(0)

    
def load_model_exp19_soup(cfg):
    """Build the exp19 checkpoint ensemble described by ``cfg``."""
    return EnsembleModel_19(cfg)

In [13]:
def dice_coef(y_true, y_pred, thr=0.5, epsilon=1e-6):
    """Global (all-pixel) Dice score between a target mask and thresholded predictions.

    Args:
        y_true: target mask tensor (any dtype; cast to float32).
        y_pred: prediction tensor; a pixel counts as positive when ``y_pred > thr``.
        thr: binarisation threshold (strict ``>``).
        epsilon: smoothing term added to numerator and denominator, so two empty
            masks score 1.

    Returns:
        0-dim float32 tensor with ``(2*|A∩B| + eps) / (|A| + |B| + eps)``.
    """
    truth = y_true.to(torch.float32)
    hard_pred = (y_pred > thr).to(torch.float32)
    overlap = torch.sum(truth * hard_pred)
    total = torch.sum(truth) + torch.sum(hard_pred)
    ratio = (2 * overlap + epsilon) / (total + epsilon)
    return ratio.mean()


def get_transform(img_size, interpolation=None):
    """Build an albumentations pipeline that only resizes to ``img_size``.

    Args:
        img_size: ``(height, width)``; unpacked into ``A.Resize``.
        interpolation: OpenCV interpolation flag. ``None`` (default) means
            ``cv2.INTER_NEAREST``, matching the original behaviour.

    Returns:
        An ``A.Compose`` that always (``p=1.0``) applies the resize.
    """
    if interpolation is None:
        interpolation = cv2.INTER_NEAREST
    transform = A.Compose([
        A.Resize(*img_size, interpolation=interpolation),
    ], p=1.0)
    return transform

def get_transform2(img_size):
    """Same as ``get_transform`` but with bilinear interpolation (used for the ioa_* models)."""
    return get_transform(img_size, interpolation=cv2.INTER_LINEAR)




In [14]:
# val_df = pd.read_csv("../../input/data_utils/val_df_filled.csv")
# val_df_full = val_df.copy()
# val_dups = np.load("../../input/data_utils/dups_val.npy")
# val_dups = [int(val_id) for val_id in val_dups]

# val_df = val_df.loc[~val_df['id'].isin(val_dups)].reset_index(drop=True)
# print(val_df.shape, val_df_full.shape)

In [15]:
# Load the index CSV for the split being pseudo-labelled (folder / NTB are set in
# the config cells). NOTE: despite the name, with folder = "train" this is the
# training split, not validation data.
val_df = pd.read_csv(filepath_or_buffer=f'../../input/pseudo/{folder}_data_{NTB}.csv')

In [16]:
# Models to run for pseudo-labelling; each dict is consumed by the inference loop
# in the next cell:
#   'model_func'  - called as model_func(cfg) to build the model; the loaders
#                   read the checkpoint path(s) from 'model_pth' (presumably a
#                   list for *_snapshot / *_soup loaders, a single path otherwise
#                   - confirm against the loader definitions).
#   'img_size'    - input resolution the model is run at.
#   'tta'         - also average with a ttach flip-TTA wrapper.
#   'normalize'   - together with an "ioa_" call_sign, selects the normalised,
#                   bilinear-resized dataset.
#   'call_sign'   - file stem of the saved prediction tensor.
# 'threshold' and 'num_classes' are not read by the inference loop.
# The commented-out entries presumably produced the roh_*/nir_* prediction files
# that are reloaded later; they are kept for provenance.
CFGS1 = [
#     {
#         'model_name': 'Unet',
#         'backbone': 'efficientnet-b7',
#         'img_size': [256, 256],
#         'num_classes': 1,
#         'model_pth':  '/home/rohits/pv1/Contrail_Detection/output/exp_01_s2/Unet/efficientnet-b7-256/checkpoint_dice_fold0.pth',
#         'threshold': 0.24, #0.24,
#         'call_sign': "roh_01_tta",
#         'model_func': load_modelr1,
#         'tta': True,
#         'normalize': False
#     }, 
#     {
#         'model_name': 'Unet',
#         'backbone': 'maxvit_small_tf_512',
#         'img_size': [512, 512],
#         'num_classes': 1,
#         'model_pth':  '/home/rohits/pv1/Contrail_Detection/output/nirjhar/maxvitfold0mix_maxvit_small_tf_512_best_epochstage2oof685cv657-00.bin',
#         'threshold': 0.62, #0.24,
#         'call_sign': "nir_01_tta",
#         'model_func': load_model4,
#         'tta': True,
#         'normalize': False
#     },
    
#     {
#         'model_name': 'Unet',
#         'backbone': 'eca_nfnet_l1',
#         'img_size': [512, 512],
#         'num_classes': 1,
#         'model_pth':  '/home/rohits/pv1/Contrail_Detection/output/nirjhar/ecanfnetl1v2fold0_eca_nfnet_l1_best_epochstage2-00.bin',
#         'threshold': 0.62, #0.24,
#         'call_sign': "nir_02_tta",
#         'model_func': load_model2,
#         'tta': True,
#         'normalize': False
#     },
#     {
#         'model_name': 'Unet++',
#         'backbone': 'tf_efficientnet_b8',
#         'img_size': [512, 512],
#         'num_classes': 1,
#         'model_pth':  '/home/rohits/pv1/Contrail_Detection/output/nirjhar/effb8modelv2fold0_tf_efficientnet_b8_best_epochstage2-00.bin',
#         'threshold': 0.79,
#         'call_sign': "nir_03_tta", 
#         'model_func': load_model3,
#         'tta': True,
#         'normalize': False
#     },
#     {
#         'model_name': 'Unet++',
#         'backbone': 'tf_efficientnet_b7_ns',
#         'img_size': [512, 512],
#         'num_classes': 1,
#         'model_pth':  '/home/rohits/pv1/contrail_01/output/nirjhar/qishentrialv2fpn512_tf_efficientnet_b7_ns_best_epochcv650lb675-00.bin',
#         'threshold': 0.62, #0.24,
#         'call_sign': "nir_04_tta", 
#         'model_func': load_model1,
#         'tta': True,
#         'normalize': False
#     }, 
    # ioa_01: three exp_28 checkpoints (fold-1, -v1, -v2) loaded together by
    # load_model_exp28_snapshot.
    {
        'model_name': 'Unet',
        'backbone': 'gluon_senet154',
        'img_size': [512, 512],
        'num_classes': 1,
        'threshold': 0.5, #0.24,
        'call_sign': "ioa_01",
        'model_func': load_model_exp28_snapshot,
        'tta': True, #True
        'normalize': True,
        'model_pth': [
            '/home/rohits/pv1/Contrail_Detection/output/ioannis/exp_28/Unet-gluon_senet154_fold-1.ckpt',
            '/home/rohits/pv1/Contrail_Detection/output/ioannis/exp_28/Unet-gluon_senet154_fold-1-v1.ckpt',
            '/home/rohits/pv1/Contrail_Detection/output/ioannis/exp_28/Unet-gluon_senet154_fold-1-v2.ckpt',
        ]
    }, 
    # ioa_02: single exp_28 fold0 checkpoint (model_pth is a plain path here).
    {
        'model_name': 'Unet',
        'backbone': 'gluon_senet154',
        'img_size': [512, 512],
        'num_classes': 1,
        'threshold': 0.5, #0.24,
        'model_func': load_model_exp28,
        'model_pth': '/home/rohits/pv1/Contrail_Detection/output/ioannis/exp_28/Unet-gluon_senet154_fold0.ckpt',        
        'call_sign': "ioa_02",
        'tta': True, #False, #True
        'normalize': True
    }
    
]

In [17]:
# Run every configured model over the pseudo-label split and save its sigmoid
# probability maps, resized to 256x256, to ../../output/pseudo_preds/<call_sign>.pt.
final_preds = []

for cfg in CFGS1:
    print(cfg)
    val_transform = get_transform(cfg['img_size'])
    val_transform2 = get_transform2(cfg['img_size'])

    # ioa_* models with 'normalize' get bilinear resizing + dataset normalisation;
    # everything else gets nearest-neighbour resizing without normalisation.
    if cfg['normalize'] and ("ioa_" in cfg['call_sign']):
        valid_dataset = ContrailDataset(val_df, transform=val_transform2, normalize=True)
    else:
        valid_dataset = ContrailDataset(val_df, transform=val_transform, normalize=False)

    valid_loader = DataLoader(
        valid_dataset,
        batch_size=32,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
        drop_last=False,
    )

    model_base = cfg['model_func'](cfg)

    if cfg['tta']:
        # roh_* models use ttach's flip_transform; all others only hflip_transform.
        if "roh_" in cfg['call_sign']:
            model = tta.SegmentationTTAWrapper(model_base, tta.aliases.flip_transform(), merge_mode='mean')
        else:
            model = tta.SegmentationTTAWrapper(model_base, tta.aliases.hflip_transform(), merge_mode='mean')
        model.to(device)
        model.eval()

    model_base.to(device)
    model_base.eval()

    preds = []
    for images in tqdm(valid_loader):
        images = images.to(device, dtype=torch.float)
        with torch.inference_mode():
            images = torch.nn.functional.interpolate(images, size=cfg['img_size'][0], mode='nearest')
            pred = model_base(images).sigmoid()
            if cfg['tta']:
                # NOTE(review): ttach flip sets also contain the un-flipped view, so the
                # original orientation contributes to both terms of this average.
                # Kept as-is because the saved predictions depend on this weighting.
                pred2 = model(images).sigmoid()
                pred = (pred + pred2) / 2
            # Predictions are stored at 256x256 regardless of the model input size.
            pred = torch.nn.functional.interpolate(pred, size=256, mode='nearest')
            # Move each batch to host memory right away: accumulating every batch on
            # the GPU until the final cat needs ~5 GB of VRAM for the ~20k-image split.
            preds.append(torch.squeeze(pred, dim=1).cpu())

    model_preds = torch.cat(preds, dim=0)
    # torch.save does not create missing directories.
    os.makedirs("../../output/pseudo_preds", exist_ok=True)
    torch.save(model_preds, f"../../output/pseudo_preds/{cfg['call_sign']}.pt")

    # NOTE(review): this keeps every model's full prediction tensor in RAM; the
    # reload cell rebuilds final_preds from the saved files anyway.
    final_preds.append(model_preds)

    # Free the models before the next config. Collect garbage first so that
    # empty_cache() can actually return the released CUDA blocks.
    if cfg['tta']:
        del model
    del model_base
    gc.collect()
    torch.cuda.empty_cache()

{'model_name': 'Unet', 'backbone': 'gluon_senet154', 'img_size': [512, 512], 'num_classes': 1, 'threshold': 0.5, 'call_sign': 'ioa_01', 'model_func': <function load_model_exp28_snapshot at 0x7fdee0e22cb0>, 'tta': True, 'normalize': True, 'model_pth': ['/home/rohits/pv1/Contrail_Detection/output/ioannis/exp_28/Unet-gluon_senet154_fold-1.ckpt', '/home/rohits/pv1/Contrail_Detection/output/ioannis/exp_28/Unet-gluon_senet154_fold-1-v1.ckpt', '/home/rohits/pv1/Contrail_Detection/output/ioannis/exp_28/Unet-gluon_senet154_fold-1-v2.ckpt']}


Lightning automatically upgraded your loaded checkpoint from v1.5.10 to v2.0.4. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file ../../output/ioannis/exp_28/Unet-gluon_senet154_fold-1.ckpt`


Load weights from: /home/rohits/pv1/Contrail_Detection/output/ioannis/exp_28/Unet-gluon_senet154_fold-1.ckpt


Lightning automatically upgraded your loaded checkpoint from v1.5.10 to v2.0.4. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file ../../output/ioannis/exp_28/Unet-gluon_senet154_fold-1-v1.ckpt`


Load weights from: /home/rohits/pv1/Contrail_Detection/output/ioannis/exp_28/Unet-gluon_senet154_fold-1-v1.ckpt


Lightning automatically upgraded your loaded checkpoint from v1.5.10 to v2.0.4. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file ../../output/ioannis/exp_28/Unet-gluon_senet154_fold-1-v2.ckpt`


Load weights from: /home/rohits/pv1/Contrail_Detection/output/ioannis/exp_28/Unet-gluon_senet154_fold-1-v2.ckpt


  0%|          | 0/642 [00:00<?, ?it/s]

{'model_name': 'Unet', 'backbone': 'gluon_senet154', 'img_size': [512, 512], 'num_classes': 1, 'threshold': 0.5, 'model_func': <function load_model_exp28 at 0x7fdee0e22b00>, 'model_pth': '/home/rohits/pv1/Contrail_Detection/output/ioannis/exp_28/Unet-gluon_senet154_fold0.ckpt', 'call_sign': 'ioa_02', 'tta': True, 'normalize': True}


Lightning automatically upgraded your loaded checkpoint from v1.4.8 to v2.0.4. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file ../../output/ioannis/exp_28/Unet-gluon_senet154_fold0.ckpt`


  0%|          | 0/642 [00:00<?, ?it/s]

In [19]:
# Reload every model's saved probability maps. The ioa_* files are written by the
# inference loop above; the roh_*/nir_* files presumably come from earlier runs of
# the configs that are now commented out in CFGS1.
call_signs = [
    "roh_01_tta", "nir_01_tta", "nir_02_tta", "nir_03_tta", "nir_04_tta", "ioa_01", "ioa_02"
]


final_preds = []
for sign in tqdm(call_signs, total=len(call_signs)):
    # Same relative directory the inference loop saves to (instead of a
    # machine-specific absolute path); load onto the CPU regardless of origin.
    wt = torch.load(f"../../output/pseudo_preds/{sign}.pt", map_location="cpu")
    final_preds.append(wt)

# All members must cover the same images at the same resolution before averaging.
assert len({tuple(p.shape) for p in final_preds}) == 1, "prediction shapes differ between models"


  0%|          | 0/7 [00:00<?, ?it/s]

In [20]:
# Binarisation threshold for the pseudo labels: 0.35 is the best ensemble
# threshold reported by the validation sweep further down (Dice 0.6867434).
PSEUDO_LABEL_THRESHOLD = 0.35

# Average the per-model probability maps, then binarise. float16 holds 0/1
# exactly and is the dtype the masks are saved with, at a quarter of float64's
# memory. NOTE: this overwrites the list `final_preds`; re-run the reload cell
# before running this cell again.
final_preds = torch.stack(final_preds).mean(dim=0)
final_preds = (final_preds > PSEUDO_LABEL_THRESHOLD).to(torch.float16)

In [21]:
# The record id is the parent directory of each image path, e.g. ".../<id>/<file>" -> "<id>".
val_df['id'] = val_df['image'].map(lambda image_path: image_path.split("/")[-2])

In [22]:
# Write one binary pseudo-label mask per record next to its input data.
# NOTE(review): the file suffix appears to encode the ensemble's validation Dice
# (0.6811) and LB score (0.702) - confirm before reusing the name.
ids = val_df['id'].values
# zip() stops at the shorter input, so a count mismatch would otherwise go unnoticed.
assert len(final_preds) == len(ids), f"{len(final_preds)} predictions for {len(ids)} ids"
for (val, label_id) in tqdm(zip(final_preds, ids), total=len(ids)):
    mask = val.view(256, 256, 1).detach().cpu().numpy()
    np.save(f"../../input/pseudo/{folder}_data_{NTB}/{label_id}/label_6811_702lb.npy", mask.astype('float16'))


  0%|          | 0/20529 [00:00<?, ?it/s]

In [21]:
# Validation check: Dice of each ensemble member at threshold 0.5, Dice of their
# average at 0.5, then a threshold sweep over 0.00..1.00 for the average.
# Earlier note: 0.657336 0.01
# Previously tried member lists (only the roh/nir models, or roh_02_tta ... roh_08_tta)
# are no longer used.

call_signs1 = [
    "roh_01_tta",
    "nir_01_tta",
    "nir_02_tta",
    "nir_03_tta",
    "nir_04_tta",

    "ioa_01",
    "ioa_02",
]

model_masks = torch.load("/home/rohits/pv1/Contrail_Detection/output/final_preds/val_masks.pt")

preds1 = []
for sign in tqdm(call_signs1, total=len(call_signs1)):
    wt = torch.load(f"/home/rohits/pv1/Contrail_Detection/output/final_preds/{sign}.pt")
    preds1.append(wt)
    score = dice_coef(model_masks, wt, thr=0.5).cpu().detach().numpy()
    print(score)

final_preds = torch.stack(preds1).mean(dim=0)
score = dice_coef(model_masks, final_preds, thr=0.5).cpu().detach().numpy()
print("0.5 TH Score: ", score)

# Keep the first threshold that reaches the highest Dice (strict improvement only).
best_threshold, best_dice_score = 0.0, 0.0
for step in range(101):
    threshold = step / 100
    score = dice_coef(model_masks, final_preds, thr=threshold).cpu().detach().numpy()
    if score > best_dice_score:
        best_threshold, best_dice_score = threshold, score
print(best_dice_score, best_threshold)

  0%|          | 0/7 [00:00<?, ?it/s]

0.6510141
0.66153294
0.649828
0.6697093
0.6561677
0.66287154
0.6537375
0.5 TH Score:  0.6811164
0.6867434 0.35


In [19]:
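# Ensemble scores on the validation set as printed by the evaluation cell
# (Dice at threshold 0.5, then the best Dice and its threshold from the sweep).
# An earlier ensemble configuration:
# 0.5 TH Score:  0.67887855
# 0.6856482 0.28

# Current 7-model ensemble (roh_01_tta, nir_01..04_tta, ioa_01, ioa_02):
# 0.5 TH Score:  0.6811164
# 0.6867434 0.35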
