In [9]:
import gc
import os
import sys
import warnings
from glob import glob

import albumentations as A
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
import torch.nn.functional as F
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import KFold
from torch.cuda import amp
from torch.utils.data import DataLoader
from tqdm import tqdm

In [10]:
sys.path.append('../')
from script.metrics import *
from script.dataset import *
from script.helper import *
from script.scheduler import *
from script.model import *
from script.loss import *

# Config

In [11]:
class CFG:
    """Static settings for the pseudo-labelling run: paths, loader sizes, transforms."""

    # When True, the notebook only processes the first 2000 unlabeled rows.
    debug = False
    # ============== comp exp name =============
    comp_name = 'contrail'
    comp_dir_path = '/kaggle/input/'
    comp_folder_name = 'google-research-identify-contrails-reduce-global-warming'

    # Root of the prepared dataset; `train_df.csv` is read from this directory.
    dataset_path = "/kaggle/working/dataset_train/pseud_ashcolor_4label/"
    # Directory the generated labels are written to. Built with os.path.join so
    # the trailing slash of `dataset_path` does not yield a doubled separator
    # ("...4label//labels/..."); the empty last component keeps the trailing "/".
    new_label_path = os.path.join(dataset_path, "labels", "model2_iter_0", "")
    # Checkpoint file handed to `load_model`.
    model_path = "/kaggle/working/notebook/experiment/v2/model30/model30/model30.pth"

    # DataLoader settings used for inference.
    valid_batch_size = 32
    num_workers = 4

    # Transforms given to the dataset: tensor conversion only, no augmentation.
    valid_aug_list = [
        ToTensorV2(transpose_mask=True),
    ]



# Notebook-wide runtime settings.
# Seeding and experiment-directory creation are not performed here: CFG
# defines neither `seed` nor `exp_name`.

# Silence every warning category for the rest of the session.
warnings.filterwarnings(action="ignore")

# Widen the pandas display limits so long file paths and large frames stay readable.
pd.set_option("display.max_rows", 500)
pd.set_option("display.max_colwidth", 300)

# Enable cuDNN benchmark mode.
torch.backends.cudnn.benchmark = True

# Pseudo Labelling

In [12]:
# Read the sample index and keep only the rows whose `label_path` is missing;
# these are the frames that will receive pseudo labels.
train_df = pd.read_csv(f"{CFG.dataset_path}/train_df.csv")
label_df = train_df.loc[train_df["label_path"].isna()]

# Debug runs work on a small leading slice to keep iteration fast.
if CFG.debug:
    label_df = label_df.head(2000)

label_df

Unnamed: 0,record_id,image_path,time,label_path
0,1000216489776414077,/kaggle/working/dataset_train/pseud_ashcolor_4label/images/1000216489776414077_0.npy,0,
1,1000216489776414077,/kaggle/working/dataset_train/pseud_ashcolor_4label/images/1000216489776414077_1.npy,1,
2,1000216489776414077,/kaggle/working/dataset_train/pseud_ashcolor_4label/images/1000216489776414077_2.npy,2,
3,1000216489776414077,/kaggle/working/dataset_train/pseud_ashcolor_4label/images/1000216489776414077_3.npy,3,
5,1000216489776414077,/kaggle/working/dataset_train/pseud_ashcolor_4label/images/1000216489776414077_5.npy,5,
...,...,...,...,...
164226,999815704182867427,/kaggle/working/dataset_train/pseud_ashcolor_4label/images/999815704182867427_2.npy,2,
164227,999815704182867427,/kaggle/working/dataset_train/pseud_ashcolor_4label/images/999815704182867427_3.npy,3,
164229,999815704182867427,/kaggle/working/dataset_train/pseud_ashcolor_4label/images/999815704182867427_5.npy,5,
164230,999815704182867427,/kaggle/working/dataset_train/pseud_ashcolor_4label/images/999815704182867427_6.npy,6,


In [13]:
# Wrap the unlabeled rows in the project dataset (mode "pseudo_labeling") and
# batch them for inference. Shuffling stays off (the DataLoader default), so
# batches follow the row order of `label_df`.
dataset_label = ContrailsDataset(label_df, CFG.valid_aug_list, "pseudo_labeling")
dataloader_label = DataLoader(
    dataset_label,
    batch_size=CFG.valid_batch_size,
    shuffle=False,
    num_workers=CFG.num_workers,
)

In [14]:
# `load_model` returns the network plus two values unpacked here as the Dice
# score and the binarisation threshold (presumably recorded when the
# checkpoint was saved - confirm in script.model).
(model, dice_score, thresh) = load_model(CFG.model_path)

# Put the network in evaluation mode before running inference.
model.eval()

print(f"{dice_score = :.4f}", f"{thresh = }", sep="\n")

model_arch:  UnetPlusPlus
backbone:  timm-resnest101e
dice_score = 0.6702
thresh = 0.01


In [15]:
def _resolve_inference_device(model):
    """Return the device on which `model` expects its input tensors."""
    parameters = getattr(model, "parameters", None)
    first_param = next(iter(parameters()), None) if callable(parameters) else None
    if first_param is not None:
        return first_param.device
    # Parameter-free model: use the GPU when one exists, otherwise the CPU.
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def pseudo_inference(dataloader_label, model, thresh, save_dir):
    """Run `model` over every batch and save one thresholded mask per sample.

    Args:
        dataloader_label: Sized iterable yielding `(images, record_ids, times)`
            batches. `images` is a tensor batch; `record_ids` and `times` hold
            one int-convertible entry per sample.
        model: Callable mapping an image batch to 4-D logits. It is not
            switched to eval mode here; the caller is expected to have done so.
        thresh: Probabilities strictly greater than this value become 1,
            everything else 0.
        save_dir: Output directory, created if missing. Each sample is written
            to `<save_dir>/<record_id>_<time>.npy`.

    Inputs are moved to the device that holds the model's parameters instead
    of unconditionally calling `.cuda()`, so the function also works on CPU or
    on a non-default GPU.
    """
    os.makedirs(save_dir, exist_ok=True)
    device = _resolve_inference_device(model)

    for images, record_ids, times in tqdm(dataloader_label, total=len(dataloader_label)):
        # No autograd graph is needed for inference; sigmoid turns logits into probabilities.
        with torch.no_grad():
            probs = torch.sigmoid(model(images.to(device)))

        # Binarise on the host; the integer literals keep NumPy's default integer dtype.
        masks = np.where(probs.cpu().numpy() > thresh, 1, 0)
        # Move axis 1 of the batch to the end so each saved sample has its
        # leading (presumably channel/label) axis last.
        masks = masks.transpose(0, 2, 3, 1)

        for mask, record_id, frame_time in zip(masks, record_ids, times):
            # os.path.join avoids a doubled "/" when save_dir already ends with one.
            save_path = os.path.join(save_dir, f"{int(record_id)}_{int(frame_time)}.npy")
            np.save(save_path, mask)
            
# Generate pseudo labels for every unlabeled frame and write them to CFG.new_label_path.
pseudo_inference(
    dataloader_label=dataloader_label,
    model=model,
    thresh=thresh,
    save_dir=CFG.new_label_path,
)

100%|██████████| 4491/4491 [2:21:29<00:00,  1.89s/it]  
