In [1]:
# import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '1'


In [2]:
import os
import shutil
import pathlib
import gc
import pandas as pd
import ttach as tta
from pathlib import Path
from PIL import Image
import numpy as np
import cv2 as cv
import random
import matplotlib.pyplot as plt
# from tqdm.auto import tqdm
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import torch
from torch.utils.data import DataLoader, random_split
from torch.utils.data import Dataset

import torchvision
from torchvision import datasets
import cv2
from torch.cuda import amp

import torchvision.transforms as T
from torchvision.transforms import Compose, ToTensor, Resize
from torchvision.utils import make_grid
import optuna

import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
import timm

In [3]:
def seed_everything(seed=42):
    """Make the notebook reproducible by seeding every random source it touches.

    Covers Python's ``random`` module, NumPy's legacy global RNG and PyTorch
    (``torch.manual_seed`` plus the current CUDA device), pins cuDNN to
    deterministic kernels and exports ``PYTHONHASHSEED``.

    Args:
        seed: value applied to every generator (default 42).
    """
    random.seed(seed)
    np.random.seed(seed)
    for torch_seeder in (torch.manual_seed, torch.cuda.manual_seed):
        torch_seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False
    # PYTHONHASHSEED is read at interpreter start-up, so this only influences
    # Python processes launched after this call, not the running kernel.
    os.environ['PYTHONHASHSEED'] = str(seed)

In [4]:
# Fix every RNG seed up front, before any model is built.
seed_everything(42)

In [5]:
# Which pseudo-label split to process; the commented alternative is "validation".
folder = "train"  # "validation"
    
# N in ../../input/pseudo/{folder}_data_{N} (CSV manifest and output dirs) — TODO document what N denotes.
NTB = 5

In [6]:
# One dict per checkpoint in the ensemble. Keys read by the visible code:
#   backbone, img_size, num_classes, model_pth, tta.
#   'model_name', 'threshold' and 'call_sign' are not read anywhere in this
#   notebook (the ensemble uses a single global threshold instead).
# NOTE(review): ORDER MATTERS. The inference loop chooses the loader/architecture
#   by list index (0-1, 2, 3, >=4) and `params` weights are matched to entries
#   by position.
# NOTE(review): model_pth values are absolute paths on one machine — consider a
#   configurable root directory.
CFGS = [ 
    
    {
        'model_name': 'Unet++',
        'backbone': 'tf_efficientnet_b7_ns',
        'img_size': [512, 512],
        'num_classes': 1,
        'model_pth':  '/home/rohits/pv1/contrail_01/output/nirjhar/qishentrialv2fpn512_tf_efficientnet_b7_ns_best_epochcv650lb675-00.bin',
        'threshold': 0.62, #0.24,
        'call_sign': "nir_01",
        'tta': True
    },

    {
        'model_name': 'Unet++',
        'backbone': 'tf_efficientnet_b7_ns',
        'img_size': [512, 512],
        'num_classes': 1,
        'model_pth':  '/home/rohits/pv1/contrail_01/output/nirjhar/qishentrialv512unetplus_tf_efficientnet_b7_ns_best_epochstage2cv669-00.bin',
        'threshold': 0.62, #0.24,
        'call_sign': "nir_02", 
        'tta': True
    },
    # NOTE(review): labelled 'Unet', but index 2 is built by load_model() ->
    # TimmSegModel2, which uses a Unet++ decoder.
    {
        'model_name': 'Unet',
        'backbone': 'eca_nfnet_l1',
        'img_size': [512, 512],
        'num_classes': 1,
        'model_pth':  '/home/rohits/pv1/contrail_01/output/nirjhar/unetplusecarnf01_eca_nfnet_l1_best_epochstage2cv656-00.bin',
        'threshold': 0.62, #0.24,
        'call_sign': "nir_03",
        'tta': True
    },
    # Index 3 is built by load_model3() (TimmSegModel3, 72-channel conv_stem).
    {
    'model_name': 'Unet++',
        'backbone': 'tf_efficientnet_b8',
        'img_size': [512, 512],
        'num_classes': 1,
        'model_pth':  '/home/rohits/pv1/Contrail_Detection/output/nirjhar/effb8modelv2fold0_tf_efficientnet_b8_best_epochstage2-00.bin',
        'threshold': 0.79,
        'call_sign': "nir_04", 
        'tta': True
    },    

    # Indices >= 4 are smp.Unet checkpoints loaded through Net (DataParallel-wrapped).
    {
        'model_name': 'Unet',
        'backbone': 'efficientnet-b7',
        'img_size': [256, 256],
        'num_classes': 1,
        'model_pth':  '/home/rohits/pv1/Contrail_Detection/output/exp_01/Unet/efficientnet-b7-256/checkpoint_dice_fold3.pth',
        'threshold': 0.24, #0.24,
        'call_sign': "roh_06",
        'tta': True
    }, 
    {
        'model_name': 'Unet',
        'backbone': 'efficientnet-b7',
        'img_size': [256, 256],
        'num_classes': 1,
        'model_pth':  '/home/rohits/pv1/Contrail_Detection/output/exp_01_s2/Unet/efficientnet-b7-256/checkpoint_dice_fold0.pth',
        'threshold': 0.24, #0.24,
        'call_sign': "roh_07",
        'tta': True
    },     
    {
        'model_name': 'Unet',
        'backbone': 'efficientnet-b7',
        'img_size': [256, 256],
        'num_classes': 1,
        'model_pth':  '/home/rohits/pv1/Contrail_Detection/output/exp_01_s2/Unet/efficientnet-b7-256/checkpoint_dice_ctrl_fold3.pth',
        'threshold': 0.24, #0.24,
        'call_sign': "roh_08",
        'tta': True
    },     
    {
        'model_name': 'Unet',
        'backbone': 'efficientnet-b7',
        'img_size': [256, 256],
        'num_classes': 1,
        'model_pth':  '/home/rohits/pv1/Contrail_Detection/output/exp_01_pl/Unet/efficientnet-b7-256/checkpoint_dice_ctrl_fold0.pth',
        'threshold': 0.24, #0.24,
        'call_sign': "roh_11",
        'tta': True
    }, 
    
    
    # NOTE(review): call_sign "roh_11" duplicates the previous entry — TODO confirm intended label.
    {
        'model_name': 'Unet',
        'backbone': 'efficientnet-b7',
        'img_size': [256, 256],
        'num_classes': 1,
        'model_pth':  '/home/rohits/pv1/Contrail_Detection/output/exp_01_pl_s2/Unet/efficientnet-b7-256/checkpoint_dice_fold3.pth',
        'threshold': 0.24, #0.24,
        'call_sign': "roh_11",
        'tta': True
    }, 

]

# "mine": every CFGS entry is loaded as Net and wrapped with tta.aliases.flip_transform();
# any other value: per-index dispatch in the inference loop (indices 0-3 use custom loaders).
ensemple_type = "No" #"No" #"mine"




In [7]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


In [8]:
class ContrailDataset:
    """Inference-only dataset of pre-saved ``.npy`` images listed in a DataFrame.

    Item ``i`` is the array stored at ``root + df['image'].iloc[i]``, cast to
    float64, optionally passed through an albumentations-style ``transform``,
    and returned as a channel-first ``torch.Tensor``.

    Args:
        df: DataFrame with an ``'image'`` column of paths relative to ``root``.
        transform: optional callable used as ``transform(image=arr)`` that returns
            a dict with an ``'image'`` key.
        root: prefix prepended to every path in ``df['image']``.
    """

    def __init__(self, df, transform=None, root="../../input/"):
        self.df = df
        # Store a plain array so __getitem__ indexes by position. Indexing the
        # Series by label raised KeyError as soon as df had a non-default index
        # (e.g. after filtering by fold).
        self.images = df['image'].to_numpy()
        self.transform = transform
        self.root = root

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Labels are not needed: this dataset only feeds pseudo-label inference.
        image = np.load(self.root + self.images[idx]).astype(float)
        if self.transform:
            image = self.transform(image=image)['image']
        # Always return channel-first (C, H, W). Previously the transpose only ran
        # when a transform was given, so the layout depended on `transform`.
        # Assumes the stored arrays are (H, W, C) — TODO confirm.
        image = np.transpose(image, (2, 0, 1))
        return torch.tensor(image)

In [9]:
import segmentation_models_pytorch as smp
import timm
# Module-level global read by the model __init__/forward at call time; every cell sets it to 4.
n_blocks = 4

# NOTE(review): this class is redefined in the next cell with the same name, a
# different constructor signature (backbone, cfg) and a 64-channel conv_stem, so
# after Run All this version is shadowed. It is structurally identical to
# TimmSegModel3 in the next cell.
class TimmSegModel(nn.Module):
    """timm encoder + smp Unet++ decoder behind a learned 3->6 channel input stem.

    Args:
        cfg: dict; reads cfg["backbone"] (timm model name) and cfg['num_classes'].
        segtype: only 'unet' builds ``self.decoder``; any other value leaves it
            undefined and forward() will fail.
        pretrained: unused; the encoder is always created with pretrained=False.
    """
    def __init__(self, cfg, segtype='unet', pretrained=True):
        super(TimmSegModel, self).__init__()
        # Only conv1/mybn1 are used in forward(); conv2/conv3/mybn2/mybn3 are
        # presumably kept so checkpoint state_dict keys still match — TODO confirm.
        self.conv1 = nn.Conv2d(3, 6, 3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(6, 12, 3, stride=1, padding=1, bias=False)
        self.conv3 = nn.Conv2d(12, 36, 3, stride=1, padding=1, bias=False)
        self.mybn1 = nn.BatchNorm2d(6)
        self.mybn2 = nn.BatchNorm2d(12)
        self.mybn3 = nn.BatchNorm2d(36)     
        self.encoder = timm.create_model(
            cfg["backbone"],
            in_chans=3,
            features_only=True,
            drop_rate=0.8,
            drop_path_rate=0.5,
            pretrained=False
        )
        # Replace the stem so the encoder consumes the 6-channel output of conv1.
        self.encoder.conv_stem=nn.Conv2d(6, 72, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)

        # Drop stage 5 and turn stage 6 into a 1x1 projection to 320 channels.
        # Assumes an EfficientNet-style timm encoder exposing `blocks[i][j].conv_pwl` — TODO confirm per backbone.
        self.encoder.blocks[5] = nn.Identity()
        self.encoder.blocks[6] = nn.Sequential(
            nn.Conv2d(self.encoder.blocks[4][2].conv_pwl.out_channels, 320, 1),
            nn.BatchNorm2d(320),
            nn.ReLU6(),
        )
        # Probe the modified encoder with a dummy 6-channel input to read each stage's channel count.
        tr = torch.randn(1,6,64,64)
        g = self.encoder(tr)
        # Leading 1 is a placeholder for the first decoder slot, which forward() fills with 0.
        encoder_channels = [1] + [_.shape[1] for _ in g]
        decoder_channels = [256, 128, 64, 32, 16]
        if segtype == 'unet':
            self.decoder = smp.decoders.unetplusplus.decoder.UnetPlusPlusDecoder(
                encoder_channels=encoder_channels[:n_blocks+1],
                decoder_channels=decoder_channels[:n_blocks],
                n_blocks=n_blocks,
            )

        self.segmentation_head = nn.Conv2d(decoder_channels[n_blocks-1], cfg['num_classes'], kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    def forward(self,x):
        """Return raw logits (no activation) with cfg['num_classes'] channels."""
        x = F.relu6(self.mybn1(self.conv1(x)))
        # First n_blocks encoder stages; the leading 0 occupies the decoder's first
        # feature slot and is presumably ignored by UnetPlusPlusDecoder — TODO confirm.
        global_features = [0] + self.encoder(x)[:n_blocks]
        seg_features = self.decoder(*global_features)
        seg_features = self.segmentation_head(seg_features)
        return seg_features
    
    
    
def load_model_timmunetplusplus(cfg):
    """Build the 72-channel-stem timm/Unet++ model for ``cfg`` and load ``cfg['model_pth']``.

    The class this used to instantiate (the first ``TimmSegModel``) is shadowed
    by a later cell that redefines ``TimmSegModel(backbone, cfg)``. After Run All,
    ``TimmSegModel(cfg)`` therefore raised TypeError. ``TimmSegModel3`` has the
    identical architecture and constructor signature, so it is used instead.

    The checkpoint is mapped to the CPU so GPU-saved weights load on any host.
    The inference loop moves the model to ``device`` afterwards.

    Args:
        cfg: dict with 'backbone', 'num_classes' and 'model_pth'.

    Returns:
        The model with weights loaded (on CPU).
    """
    model = TimmSegModel3(cfg)
    # NOTE: torch.load unpickles arbitrary objects; only load trusted checkpoints.
    state_dict = torch.load(cfg['model_pth'], map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    return model

In [10]:
###TImmunetplusplus model Nirjhar

# Module-level global read by the model __init__/forward at call time; every cell sets it to 4.
n_blocks = 4

class TimmSegModel(nn.Module):
    """timm encoder + smp Unet++ decoder behind a learned 3->6 channel input stem.

    Redefines the TimmSegModel from the previous cell. The differences are that
    ``backbone`` is passed explicitly and the replacement conv_stem has 64 output
    channels. The inference loop builds CFGS[0] and CFGS[1] with it through
    load_model_timmunetplus().

    Args:
        backbone: timm model name, created with features_only=True.
        cfg: dict; only cfg["num_classes"] is read here.
        segtype: only 'unet' builds ``self.decoder``; any other value leaves it
            undefined and forward() will fail.
        pretrained: unused; the encoder is always created with pretrained=False
            (weights are loaded from a checkpoint afterwards).
    """
    def __init__(self, backbone, cfg, segtype='unet', pretrained=True):
        super(TimmSegModel, self).__init__()
        # Only conv1/mybn1 are used in forward(); conv2/conv3/mybn2/mybn3 are
        # presumably kept so checkpoint state_dict keys still match — TODO confirm.
        self.conv1 = nn.Conv2d(3, 6, 3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(6, 12, 3, stride=1, padding=1, bias=False)
        self.conv3 = nn.Conv2d(12, 36, 3, stride=1, padding=1, bias=False)
        self.mybn1 = nn.BatchNorm2d(6)
        self.mybn2 = nn.BatchNorm2d(12)
        self.mybn3 = nn.BatchNorm2d(36)     
        self.encoder = timm.create_model(
            backbone,
            in_chans=3,
            features_only=True,
            drop_rate=0.8,
            drop_path_rate=0.5,
            pretrained=False
        )
        # Replace the stem so the encoder consumes the 6-channel output of conv1.
        self.encoder.conv_stem=nn.Conv2d(6, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)

        # Drop stage 5 and turn stage 6 into a 1x1 projection to 320 channels.
        # Assumes an EfficientNet-style timm encoder exposing `blocks[i][j].conv_pwl` — TODO confirm per backbone.
        self.encoder.blocks[5] = nn.Identity()
        self.encoder.blocks[6] = nn.Sequential(
            nn.Conv2d(self.encoder.blocks[4][2].conv_pwl.out_channels, 320, 1),
            nn.BatchNorm2d(320),
            nn.ReLU6(),
        )
        # Probe the modified encoder with a dummy 6-channel input to read each stage's channel count.
        tr = torch.randn(1,6,64,64)
        g = self.encoder(tr)
        # Leading 1 is a placeholder for the first decoder slot, which forward() fills with 0.
        encoder_channels = [1] + [_.shape[1] for _ in g]
        decoder_channels = [256, 128, 64, 32, 16]
        if segtype == 'unet':
            self.decoder = smp.decoders.unetplusplus.decoder.UnetPlusPlusDecoder(
                encoder_channels=encoder_channels[:n_blocks+1],
                decoder_channels=decoder_channels[:n_blocks],
                n_blocks=n_blocks,
            )

        self.segmentation_head = nn.Conv2d(
            decoder_channels[n_blocks-1],
            cfg["num_classes"], 
            kernel_size=(3, 3),
            stride=(1, 1), 
            padding=(1, 1)
        )

    def forward(self,x):
        """Return raw logits (no activation); the inference loop applies sigmoid."""
        x = F.relu6(self.mybn1(self.conv1(x)))
        # First n_blocks encoder stages; the leading 0 occupies the decoder's first
        # feature slot and is presumably ignored by UnetPlusPlusDecoder — TODO confirm.
        global_features = [0] + self.encoder(x)[:n_blocks]
        seg_features = self.decoder(*global_features)
        seg_features = self.segmentation_head(seg_features)
        return seg_features
    

#path='./exp/baselinev2/qishentrialv2_tf_efficientnet_b7_ns_last_epoch-00.bin'

def build_model_timmunetplus(backbone, cfg):
    """Instantiate an untrained TimmSegModel (backbone + Unet++ decoder) for ``cfg``."""
    return TimmSegModel(backbone, cfg)

def load_model_timmunetplus(path, backbone, cfg):
    """Build a TimmSegModel for ``backbone``/``cfg`` and load its weights from ``path``.

    The checkpoint is mapped to the CPU first. GPU-saved state dicts therefore
    load on CPU-only hosts and do not allocate GPU memory twice. This matches
    how the Net checkpoints are loaded; the inference loop moves the model to
    ``device`` afterwards.

    Args:
        path: file containing a plain state_dict.
        backbone: timm model name.
        cfg: config dict (``num_classes`` is read by the model).

    Returns:
        The model with weights loaded (on CPU).
    """
    model = build_model_timmunetplus(backbone, cfg)
    # NOTE: torch.load unpickles arbitrary objects; only load trusted checkpoints.
    state_dict = torch.load(path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    return model


#### Model 2 Nirjhar

# Module-level global read by TimmSegModel2 at call time (same value as the other cells).
n_blocks =4
class TimmSegModel2(nn.Module):
    """Plain timm encoder (3-channel input, unmodified stem) + smp Unet++ decoder.

    The inference loop builds CFGS[2] with it through load_model()/build_model().

    Args:
        backbone: timm model name, created with features_only=True and drop_rate=0.5.
        segtype: only 'unet' builds ``self.decoder``; any other value leaves it
            undefined and forward() will fail.
        pretrained: unused; the encoder is always created with pretrained=False.
    """
    def __init__(self, backbone, segtype='unet', pretrained=True):
        super(TimmSegModel2, self).__init__()

        self.encoder = timm.create_model(
            backbone,
            in_chans=3,
            features_only=True,
            drop_rate=0.5,
            pretrained=False
        )
        # Probe with a dummy 3-channel image to read each stage's channel count.
        g = self.encoder(torch.rand(1, 3, 128, 128))
        # Leading 1 is a placeholder for the first decoder slot, which forward() fills with 0.
        encoder_channels = [1] + [_.shape[1] for _ in g]
        decoder_channels = [256, 128, 64, 32, 16]
        if segtype == 'unet':
            self.decoder = smp.decoders.unetplusplus.decoder.UnetPlusPlusDecoder(
                encoder_channels=encoder_channels[:n_blocks+1],
                decoder_channels=decoder_channels[:n_blocks],
                n_blocks=n_blocks,
            )

        # Single output channel (hard-coded). UpsamplingBilinear2d(scale_factor=1)
        # keeps the spatial size; being parameter-free it adds no state_dict keys.
        self.segmentation_head = nn.Sequential(
            nn.Conv2d(decoder_channels[n_blocks-1], 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.UpsamplingBilinear2d(scale_factor=1))

    def forward(self,x):
        """Return raw logits (no activation); the inference loop applies sigmoid."""
        global_features = [0] + self.encoder(x)[:n_blocks]
        seg_features = self.decoder(*global_features)
        seg_features = self.segmentation_head(seg_features)
        return seg_features
    


def build_model(backbone):
    """Instantiate an untrained TimmSegModel2 (Unet++ decoder) for ``backbone``."""
    return TimmSegModel2(backbone=backbone, segtype='unet')

def load_model(path,backbone):
    """Build a TimmSegModel2 for ``backbone`` and load its weights from ``path``.

    The checkpoint is mapped to the CPU first, so GPU-saved state dicts load on
    any host. The inference loop moves the model to ``device`` afterwards.

    Args:
        path: file containing a plain state_dict.
        backbone: timm model name.

    Returns:
        The model with weights loaded (on CPU).
    """
    model = build_model(backbone)
    # NOTE: torch.load unpickles arbitrary objects; only load trusted checkpoints.
    state_dict = torch.load(path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    return model


##################

# Module-level global read by TimmSegModel3 at call time (same value as the other cells).
n_blocks = 4

class TimmSegModel3(nn.Module):
    """timm encoder + smp Unet++ decoder behind a learned 3->6 channel input stem.

    Same structure as the first TimmSegModel definition (72-channel conv_stem,
    backbone taken from cfg). The inference loop builds CFGS[3] with it through
    load_model3().

    Args:
        cfg: dict; reads cfg["backbone"] (timm model name) and cfg['num_classes'].
        segtype: only 'unet' builds ``self.decoder``; any other value leaves it
            undefined and forward() will fail.
        pretrained: unused; the encoder is always created with pretrained=False.
    """
    def __init__(self, cfg, segtype='unet', pretrained=True):
        super(TimmSegModel3, self).__init__()
        # Only conv1/mybn1 are used in forward(); conv2/conv3/mybn2/mybn3 are
        # presumably kept so checkpoint state_dict keys still match — TODO confirm.
        self.conv1 = nn.Conv2d(3, 6, 3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(6, 12, 3, stride=1, padding=1, bias=False)
        self.conv3 = nn.Conv2d(12, 36, 3, stride=1, padding=1, bias=False)
        self.mybn1 = nn.BatchNorm2d(6)
        self.mybn2 = nn.BatchNorm2d(12)
        self.mybn3 = nn.BatchNorm2d(36)     
        self.encoder = timm.create_model(
            cfg["backbone"],
            in_chans=3,
            features_only=True,
            drop_rate=0.8,
            drop_path_rate=0.5,
            pretrained=False
        )
        # Replace the stem so the encoder consumes the 6-channel output of conv1.
        self.encoder.conv_stem=nn.Conv2d(6, 72, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)

        # Drop stage 5 and turn stage 6 into a 1x1 projection to 320 channels.
        # Assumes an EfficientNet-style timm encoder exposing `blocks[i][j].conv_pwl` — TODO confirm per backbone.
        self.encoder.blocks[5] = nn.Identity()
        self.encoder.blocks[6] = nn.Sequential(
            nn.Conv2d(self.encoder.blocks[4][2].conv_pwl.out_channels, 320, 1),
            nn.BatchNorm2d(320),
            nn.ReLU6(),
        )
        # Probe the modified encoder with a dummy 6-channel input to read each stage's channel count.
        tr = torch.randn(1,6,64,64)
        g = self.encoder(tr)
        # Leading 1 is a placeholder for the first decoder slot, which forward() fills with 0.
        encoder_channels = [1] + [_.shape[1] for _ in g]
        decoder_channels = [256, 128, 64, 32, 16]
        if segtype == 'unet':
            self.decoder = smp.decoders.unetplusplus.decoder.UnetPlusPlusDecoder(
                encoder_channels=encoder_channels[:n_blocks+1],
                decoder_channels=decoder_channels[:n_blocks],
                n_blocks=n_blocks,
            )

        self.segmentation_head = nn.Conv2d(decoder_channels[n_blocks-1], cfg['num_classes'], kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    def forward(self,x):
        """Return raw logits (no activation); the inference loop applies sigmoid."""
        x = F.relu6(self.mybn1(self.conv1(x)))
        # First n_blocks encoder stages; the leading 0 occupies the decoder's first
        # feature slot and is presumably ignored by UnetPlusPlusDecoder — TODO confirm.
        global_features = [0] + self.encoder(x)[:n_blocks]
        seg_features = self.decoder(*global_features)
        seg_features = self.segmentation_head(seg_features)
        return seg_features
    
    
    
def load_model3(cfg):
    """Build a TimmSegModel3 for ``cfg`` and load weights from ``cfg['model_pth']``.

    The checkpoint is mapped to the CPU first, so GPU-saved state dicts load on
    any host. The inference loop moves the model to ``device`` afterwards.

    Args:
        cfg: dict with 'backbone', 'num_classes' and 'model_pth'.

    Returns:
        The model with weights loaded (on CPU).
    """
    model = TimmSegModel3(cfg)
    # NOTE: torch.load unpickles arbitrary objects; only load trusted checkpoints.
    state_dict = torch.load(cfg['model_pth'], map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    return model










## Rohit Model 
class Net(nn.Module):
    """Thin ``nn.Module`` wrapper around ``smp.Unet`` (no pretrained weights, no activation).

    Args:
        cfg: dict; reads cfg["backbone"] (smp encoder name) and cfg["num_classes"].
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        unet_kwargs = dict(
            encoder_name=cfg["backbone"],
            encoder_weights=None,
            in_channels=3,
            classes=cfg["num_classes"],
            activation=None,
        )
        # The attribute must stay named `model`: the DataParallel-wrapped
        # checkpoints loaded in the inference loop address it by that name.
        self.model = smp.Unet(**unet_kwargs)

    def forward(self, inputs):
        """Return the raw logits produced by the wrapped Unet."""
        return self.model(inputs)

In [11]:
def dice_coef(y_true, y_pred, thr=0.5, epsilon=1e-6):
    """Dice score between a target mask and thresholded predictions.

    The whole input is treated as one sample: all elements are pooled before the
    ratio is taken.

    Args:
        y_true: target tensor (cast to float32).
        y_pred: prediction tensor; elements strictly greater than ``thr`` count as positive.
        thr: binarisation threshold for ``y_pred``.
        epsilon: smoothing term added to numerator and denominator (empty-vs-empty gives 1).

    Returns:
        0-d float32 tensor with the Dice coefficient.
    """
    truth = y_true.float()
    hard = y_pred.gt(thr).float()
    overlap = torch.sum(truth * hard)
    total = torch.sum(truth) + torch.sum(hard)
    score = (2 * overlap + epsilon) / (total + epsilon)
    return score.mean()


def get_transform(img_size):
    """Deterministic inference preprocessing: nearest-neighbour resize to ``img_size``.

    Args:
        img_size: ``[height, width]`` unpacked into ``A.Resize``.

    Returns:
        An ``A.Compose`` pipeline that always applies (p=1.0).
    """
    resize = A.Resize(*img_size, interpolation=cv2.INTER_NEAREST)
    return A.Compose([resize], p=1.0)

In [12]:
# Blending weights for weighted_ensemble(), matched to CFGS by position
# (w1 -> CFGS[0], ..., w9 -> CFGS[8]) in dict insertion order.
# Presumably tuned offline (optuna is imported above) — TODO record how they were obtained.
# They need not sum to 1: weighted_ensemble() divides by their sum.
params = {'w1': 0.026130564275158946,
 'w2': 0.9456225644729294,
 'w3': 0.977118766928013,
 'w4': 0.4122243077816452,
 'w5': 0.07925478207938379,
 'w6': 0.06900567964478133,
 'w7': 0.8020704404435675,
 'w8': 0.34190685300944035,
 'w9': 0.17518538289485933}

def weighted_ensemble(params, final_preds):
    """Blend per-model predictions with a normalised weighted average.

    Weights are matched to predictions by position: the i-th value of ``params``
    (dict insertion order) weights ``final_preds[i]``.

    Args:
        params: mapping of weight name -> weight. Only the values and their order matter.
        final_preds: sequence of per-model predictions (tensors/arrays of equal shape).

    Returns:
        ``sum(w_i * p_i) / sum(w_i)``. The inputs are not modified.

    Raises:
        ValueError: if ``params`` is empty, or the number of weights differs from
            the number of predictions. The old loop silently ignored extra
            predictions and hit IndexError on extra weights.
    """
    weights = list(params.values())
    if not weights:
        raise ValueError("weighted_ensemble needs at least one weight")
    if len(weights) != len(final_preds):
        raise ValueError(
            f"got {len(weights)} weights for {len(final_preds)} predictions"
        )
    preds = weights[0] * final_preds[0]
    for weight, pred in zip(weights[1:], final_preds[1:]):
        # In-place accumulation avoids a new full-size allocation per model. This is
        # safe because `preds` is the fresh product created above, not an input.
        preds += weight * pred
    return preds / sum(weights)

In [13]:
# Manifest for this split; 'image' holds paths that ContrailDataset prefixes with ../../input/.
val_df = pd.read_csv(f'../../input/pseudo/{folder}_data_{NTB}.csv') 


# final_preds = []

In [14]:
# Quick look at the manifest (columns include image, label, fold, id).
val_df

Unnamed: 0,image,label,fold,id
0,/pseudo/train_data_5/1284412112608546821/image...,/pseudo/train_data_5/1284412112608546821/label...,4.0,1284412112608546821
1,/pseudo/train_data_5/7457695218848685981/image...,/pseudo/train_data_5/7457695218848685981/label...,1.0,7457695218848685981
2,/pseudo/train_data_5/836236084461732921/image.npy,/pseudo/train_data_5/836236084461732921/label.npy,1.0,836236084461732921
3,/pseudo/train_data_5/7829917977180135058/image...,/pseudo/train_data_5/7829917977180135058/label...,4.0,7829917977180135058
4,/pseudo/train_data_5/5319255125658459358/image...,/pseudo/train_data_5/5319255125658459358/label...,3.0,5319255125658459358
...,...,...,...,...
20524,/pseudo/train_data_5/8443915190215904823/image...,/pseudo/train_data_5/8443915190215904823/label...,3.0,8443915190215904823
20525,/pseudo/train_data_5/8495643844280686935/image...,/pseudo/train_data_5/8495643844280686935/label...,2.0,8495643844280686935
20526,/pseudo/train_data_5/856381910009426679/image.npy,/pseudo/train_data_5/856381910009426679/label.npy,2.0,856381910009426679
20527,/pseudo/train_data_5/3751790308836191485/image...,/pseudo/train_data_5/3751790308836191485/label...,0.0,3751790308836191485


In [15]:
# Run every checkpoint in CFGS over val_df and collect per-model probability maps.
# final_preds[i] is a CPU float tensor of shape (len(val_df), 256, 256) for CFGS[i].
final_preds = []

for idx, cfg in enumerate(CFGS):
    print(cfg)
    val_transform = get_transform(cfg['img_size'])
    valid_dataset = ContrailDataset(val_df, transform=val_transform)

    valid_loader = DataLoader(
        valid_dataset,
        batch_size=32,
        shuffle=False,      # keep prediction order aligned with val_df rows
        num_workers=2,
        pin_memory=True,
        drop_last=False,    # every row needs a prediction
    )

    # CFGS[0..3] each need their own architecture/loader. The remaining entries
    # (or all of them when ensemple_type == "mine") are Net checkpoints saved from
    # a DataParallel-wrapped model under the 'model' key.
    use_net = ensemple_type == "mine" or idx > 3
    if use_net:
        # DataParallel keeps the 'module.' key prefix the checkpoints were saved with.
        # .to(device) instead of .cuda() so this also runs on a CPU-only machine.
        model = torch.nn.DataParallel(Net(cfg)).to(device)
        checkpoint = torch.load(cfg['model_pth'], map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint['model'])
        del checkpoint
    elif idx <= 1:
        model = load_model_timmunetplus(cfg['model_pth'], cfg['backbone'], cfg)
    elif idx == 2:
        model = load_model(cfg['model_pth'], cfg['backbone'])
    else:  # idx == 3
        model = load_model3(cfg)

    if cfg['tta']:
        # Net checkpoints use flip_transform(); the custom timm models use hflip only.
        tta_transforms = tta.aliases.flip_transform() if use_net else tta.aliases.hflip_transform()
        model = tta.SegmentationTTAWrapper(model, tta_transforms, merge_mode='mean')

    model.to(device)
    model.eval()

    preds = []
    with torch.inference_mode():
        for images in tqdm(valid_loader):
            images = images.to(device, dtype=torch.float)
            # The transform already resized to img_size; this keeps (H, W) = img_size
            # instead of forcing a square img_size[0] x img_size[0] input.
            images = F.interpolate(images, size=tuple(cfg['img_size']), mode='nearest')
            pred = model(images)
            # Bring every model back to the common 256x256 label resolution.
            pred = F.interpolate(pred.sigmoid(), size=256, mode='nearest')
            # Move each batch off the GPU immediately. Keeping every batch on the
            # device would hold len(val_df) * 256 * 256 floats in GPU memory per model.
            preds.append(torch.squeeze(pred, dim=1).cpu())

    model_preds = torch.cat(preds, dim=0)
    final_preds.append(model_preds)

    del model, model_preds, preds
    torch.cuda.empty_cache()
    gc.collect()

{'model_name': 'Unet++', 'backbone': 'tf_efficientnet_b7_ns', 'img_size': [512, 512], 'num_classes': 1, 'model_pth': '/home/rohits/pv1/contrail_01/output/nirjhar/qishentrialv2fpn512_tf_efficientnet_b7_ns_best_epochcv650lb675-00.bin', 'threshold': 0.62, 'call_sign': 'nir_01', 'tta': True}


  0%|          | 0/642 [00:00<?, ?it/s]

{'model_name': 'Unet++', 'backbone': 'tf_efficientnet_b7_ns', 'img_size': [512, 512], 'num_classes': 1, 'model_pth': '/home/rohits/pv1/contrail_01/output/nirjhar/qishentrialv512unetplus_tf_efficientnet_b7_ns_best_epochstage2cv669-00.bin', 'threshold': 0.62, 'call_sign': 'nir_02', 'tta': True}


  0%|          | 0/642 [00:00<?, ?it/s]

{'model_name': 'Unet', 'backbone': 'eca_nfnet_l1', 'img_size': [512, 512], 'num_classes': 1, 'model_pth': '/home/rohits/pv1/contrail_01/output/nirjhar/unetplusecarnf01_eca_nfnet_l1_best_epochstage2cv656-00.bin', 'threshold': 0.62, 'call_sign': 'nir_03', 'tta': True}


  0%|          | 0/642 [00:00<?, ?it/s]

{'model_name': 'Unet++', 'backbone': 'tf_efficientnet_b8', 'img_size': [512, 512], 'num_classes': 1, 'model_pth': '/home/rohits/pv1/Contrail_Detection/output/nirjhar/effb8modelv2fold0_tf_efficientnet_b8_best_epochstage2-00.bin', 'threshold': 0.79, 'call_sign': 'nir_04', 'tta': True}


  0%|          | 0/642 [00:00<?, ?it/s]

{'model_name': 'Unet', 'backbone': 'efficientnet-b7', 'img_size': [256, 256], 'num_classes': 1, 'model_pth': '/home/rohits/pv1/Contrail_Detection/output/exp_01/Unet/efficientnet-b7-256/checkpoint_dice_fold3.pth', 'threshold': 0.24, 'call_sign': 'roh_06', 'tta': True}


  0%|          | 0/642 [00:00<?, ?it/s]

{'model_name': 'Unet', 'backbone': 'efficientnet-b7', 'img_size': [256, 256], 'num_classes': 1, 'model_pth': '/home/rohits/pv1/Contrail_Detection/output/exp_01_s2/Unet/efficientnet-b7-256/checkpoint_dice_fold0.pth', 'threshold': 0.24, 'call_sign': 'roh_07', 'tta': True}


  0%|          | 0/642 [00:00<?, ?it/s]

{'model_name': 'Unet', 'backbone': 'efficientnet-b7', 'img_size': [256, 256], 'num_classes': 1, 'model_pth': '/home/rohits/pv1/Contrail_Detection/output/exp_01_s2/Unet/efficientnet-b7-256/checkpoint_dice_ctrl_fold3.pth', 'threshold': 0.24, 'call_sign': 'roh_08', 'tta': True}


  0%|          | 0/642 [00:00<?, ?it/s]

{'model_name': 'Unet', 'backbone': 'efficientnet-b7', 'img_size': [256, 256], 'num_classes': 1, 'model_pth': '/home/rohits/pv1/Contrail_Detection/output/exp_01_pl/Unet/efficientnet-b7-256/checkpoint_dice_ctrl_fold0.pth', 'threshold': 0.24, 'call_sign': 'roh_11', 'tta': True}


  0%|          | 0/642 [00:00<?, ?it/s]

{'model_name': 'Unet', 'backbone': 'efficientnet-b7', 'img_size': [256, 256], 'num_classes': 1, 'model_pth': '/home/rohits/pv1/Contrail_Detection/output/exp_01_pl_s2/Unet/efficientnet-b7-256/checkpoint_dice_fold3.pth', 'threshold': 0.24, 'call_sign': 'roh_11', 'tta': True}


  0%|          | 0/642 [00:00<?, ?it/s]

In [16]:
# Blend the per-model probability maps and binarise them.
# The per-model list is kept in `per_model_preds`. Re-running this cell then reuses
# it instead of feeding its own (already thresholded) output back into the ensemble.
if isinstance(final_preds, list):
    per_model_preds = final_preds
threshold = 0.59  # single ensemble threshold; the per-model cfg['threshold'] values are not used
ensemble_probs = weighted_ensemble(params, per_model_preds)
# float32 instead of float64: the values are exactly 0/1 and are saved as float16
# later, so this halves the memory of the (len(val_df), 256, 256) mask tensor.
final_preds = (ensemble_probs > threshold).float()
del ensemble_probs

In [17]:
# Re-derive each sample id from its image path (".../<id>/image.npy" -> "<id>").
# This replaces the CSV's numeric 'id' column with strings; the saved file paths are unchanged.
val_df['id'] = [image_path.split("/")[-2] for image_path in val_df['image']]

In [18]:
# Write each binarised ensemble mask next to its sample as a float16 .npy pseudo-label.
ids = val_df['id'].values
# zip() would silently stop at the shorter input; make a row/prediction mismatch loud.
assert len(final_preds) == len(ids), f"{len(final_preds)} masks for {len(ids)} ids"
for val, label_id in tqdm(zip(final_preds, ids), total=len(ids)):
    mask = val.view(256, 256, 1).detach().cpu().numpy()
    np.save(f"../../input/pseudo/{folder}_data_{NTB}/{label_id}/label_684cv_687lb.npy", mask.astype('float16'))


  0%|          | 0/20529 [00:00<?, ?it/s]

In [19]:
# Sanity check: each ensemble mask is a single 256x256 map.
final_preds[0].shape

torch.Size([256, 256])

In [20]:
# val_df['image'][0].split("/")[-2]

In [21]:
# Number of pseudo-labels written (one per manifest row).
len(ids)

20529