In [None]:
import gc
import os
import numpy as np
import pandas as pd

from sklearn.model_selection import StratifiedKFold

import torch
import torch.nn as nn
from torch.optim import AdamW
import torchvision.transforms as T
from torch.utils.data import DataLoader
from torchmetrics.functional import dice

from transformers import get_cosine_with_hard_restarts_schedule_with_warmup

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, TQDMProgressBar
from pytorch_lightning.loggers import CSVLogger

import warnings
warnings.filterwarnings("ignore")

In [None]:
# Extend the module search path with dataset directories before importing
# segmentation_models_pytorch.
# NOTE(review): presumably these hold offline copies of the library and its
# encoder dependencies (no internet in the kernel) — confirm against the
# attached datasets. The first two entries are relative to the working
# directory, the last two are absolute.
import sys
sys.path.append("../input/pretrained-models-pytorch")
sys.path.append("../input/efficientnet-pytorch")
sys.path.append("/kaggle/input/smp-github/segmentation_models.pytorch-master")
sys.path.append("/kaggle/input/timm-pretrained-resnest/resnest/")
# Must come after the sys.path additions above so the import can be resolved.
import segmentation_models_pytorch as smp

In [None]:
class CFG:
    """Central place for every tunable used by the notebook.

    Attributes are read directly off the class (``CFG.lr`` etc.); it is never
    instantiated.
    """

    # --- reproducibility / run length -------------------------------------
    seed = 42                    # passed to pl.seed_everything
    epochs = 20                  # used as both max_epochs and min_epochs
    train_folds = [0, 1, 2, 3]   # fold ids that get their own training run

    # --- data --------------------------------------------------------------
    data_path = "/kaggle/input/my-data-for-contr"
    image_size = 384             # network input side; 256 means "no resize"
    train_bs = 48                # training DataLoader batch size
    valid_bs = 128               # validation DataLoader batch size
    workers = 2                  # DataLoader num_workers

    # --- model / loss ------------------------------------------------------
    encoder_name = "timm-resnest26d"
    loss_smooth = 1.0            # `smooth` argument of the Dice loss

    # --- optimiser / LR schedule -------------------------------------------
    lr = 5e-4
    weight_decay = 0.0
    num_warmup_steps = 350       # scheduler warm-up, counted in optimiser steps
    num_training_steps = 3150    # total steps the scheduler is configured for
    num_cycles = 1               # cosine cycles (hard restarts) in the schedule

In [None]:
class ContrailsDataset(torch.utils.data.Dataset):
    """Map-style dataset that reads one ``.npy`` record per sample.

    Each row of ``df`` must expose a ``path`` attribute pointing at a NumPy
    file. The array's last channel is returned as the label; the remaining
    channels are reshaped to a 256x256x3 image.

    Args:
        df: DataFrame with a ``path`` column.
        image_size: Target side length. When it differs from 256 the image
            (not the label) is resized before normalisation.
        train: Stored as ``self.trn``; not used by this class itself.
    """

    def __init__(self, df, image_size=256, train=True):
        self.df = df
        self.trn = train
        self.image_size = image_size
        # Channel-wise mean/std — the values commonly used with
        # ImageNet-pretrained encoders.
        self.normalize_image = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        # The resize transform only exists when it is actually needed.
        if image_size != 256:
            self.resize_image = T.transforms.Resize(image_size)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        """Return ``(image, label)`` as float32 tensors for row ``index``."""
        record = np.load(str(self.df.iloc[index].path))

        # Everything but the last channel is the image, the last one the mask.
        pixels, mask = record[..., :-1], record[..., -1]

        # HWC array -> CHW float tensor.
        image = torch.tensor(np.reshape(pixels, (256, 256, 3)))
        image = image.to(torch.float32).permute(2, 0, 1)

        if self.image_size != 256:
            image = self.resize_image(image)
        image = self.normalize_image(image)

        return image.float(), torch.tensor(mask).float()

In [None]:
class LightningModule(pl.LightningModule):
    """U-Net segmentation model trained with a binary Dice loss.

    The network outputs single-channel logits (``activation=None``). When
    ``CFG.image_size`` is not 256 the logits are bilinearly resized to 256x256
    before the loss/metric is computed, so they match the un-resized labels
    produced by the dataset.
    """

    def __init__(self):
        super().__init__()
        self.model = smp.Unet(encoder_name=CFG.encoder_name,
                              encoder_weights="imagenet",
                              in_channels=3,
                              classes=1,
                              activation=None,
                             )

        self.loss_module = smp.losses.DiceLoss(mode="binary", smooth=CFG.loss_smooth)
        # Per-batch logits/labels collected during validation and consumed
        # (then cleared) in on_validation_epoch_end.
        self.val_step_outputs = []
        self.val_step_labels = []

    def forward(self, batch):
        """Return raw logits at the network's native output resolution."""
        imgs = batch
        preds = self.model(imgs)
        return preds

    def _logits_at_label_size(self, imgs):
        """Run the model and, if inputs were resized, bring logits back to 256x256."""
        preds = self.model(imgs)
        if CFG.image_size != 256:
            preds = torch.nn.functional.interpolate(preds, size=256, mode='bilinear')
        return preds

    def configure_optimizers(self):
        """AdamW plus a warm-up / cosine-with-hard-restarts schedule stepped every batch."""
        optimizer = AdamW(self.parameters(),
                          lr=CFG.lr,
                          weight_decay=CFG.weight_decay
                         )

        scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(optimizer,
                                                                       num_warmup_steps=CFG.num_warmup_steps,
                                                                       num_training_steps=CFG.num_training_steps,
                                                                       num_cycles=CFG.num_cycles
                                                                      )

        # "interval": "step" -> the scheduler advances after every optimiser step.
        lr_scheduler_dict = {"scheduler": scheduler, "interval": "step"}
        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_dict}

    def training_step(self, batch, batch_idx):
        """Compute and log the Dice loss (and current LR) for one training batch."""
        imgs, labels = batch
        preds = self._logits_at_label_size(imgs)

        loss = self.loss_module(preds, labels)
        # Use the real batch size (previously hard-coded to 16) so the
        # epoch-level average is weighted correctly, including a short
        # final batch.
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True,
                 batch_size=imgs.size(0))

        # Same value the original loop ended up with: the LR of the last param group.
        lr = self.trainer.optimizers[0].param_groups[-1]["lr"]
        self.log("lr", lr, on_step=True, on_epoch=False, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and stash logits/labels for the epoch-level Dice."""
        imgs, labels = batch
        preds = self._logits_at_label_size(imgs)
        loss = self.loss_module(preds, labels)
        # sync_dist reduces the value across devices so that every rank (and
        # therefore every callback monitoring it) sees the same number.
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True,
                 batch_size=imgs.size(0), sync_dist=True)
        self.val_step_outputs.append(preds)
        self.val_step_labels.append(labels)

    def on_validation_epoch_end(self):
        """Compute Dice over everything this rank saw during validation and log it."""
        all_preds = torch.cat(self.val_step_outputs)
        all_labels = torch.cat(self.val_step_labels)
        all_preds = torch.sigmoid(all_preds)
        # Free the per-batch buffers before the next validation run.
        self.val_step_outputs.clear()
        self.val_step_labels.clear()
        val_dice = dice(all_preds, all_labels.long())
        # Each rank only holds its own shard; sync_dist averages the per-rank
        # scores so the checkpoint monitor is not based on a single shard.
        self.log("val_dice", val_dice, on_step=False, on_epoch=True, prog_bar=True,
                 sync_dist=True)
        if self.trainer.global_rank == 0:
            print(f"\nEpoch: {self.current_epoch}", flush=True)

In [None]:
def stratified_kfold_loaders(metadata, num_splits=5, random_state=42):
    """Add a ``kfold`` column with stratified fold ids to ``metadata``.

    Folds are stratified on the ``contrail`` column. The frame is modified in
    place and also returned, as before.

    Args:
        metadata: DataFrame containing a ``contrail`` column.
        num_splits: Number of folds.
        random_state: Seed for the shuffled split.

    Returns:
        The same ``metadata`` object with an integer ``kfold`` column in
        ``range(num_splits)``.
    """
    skfold = StratifiedKFold(n_splits=num_splits, shuffle=True, random_state=random_state)

    # split() yields *positional* indices, so collect fold ids in a positional
    # array instead of writing through label-based `.loc` (which is only
    # correct when the frame happens to have a 0..n-1 index).
    fold_ids = np.full(len(metadata), -1, dtype=int)
    for n, (_, valid_indices) in enumerate(skfold.split(metadata, metadata["contrail"])):
        fold_ids[valid_indices] = n

    # Every row is in exactly one validation split, so no -1 remains.
    metadata["kfold"] = fold_ids
    return metadata

In [None]:
# Seed Python, NumPy and torch RNGs for reproducibility.
pl.seed_everything(CFG.seed)
gc.enable()

# Build paths with os.path.join: CFG.data_path has no trailing separator, so
# plain string concatenation produced ".../my-data-for-contrcontrails/".
# The trailing "" keeps a final separator on the directory because record
# file names are appended to it by string concatenation below.
contrails = os.path.join(CFG.data_path, "contrails", "")
train_path = os.path.join(CFG.data_path, "train_df.csv")
valid_path = os.path.join(CFG.data_path, "valid_df.csv")

train_df = pd.read_csv(train_path)
valid_df = pd.read_csv(valid_path)

# One .npy file per record id.
train_df['path'] = contrails + train_df['record_id'].astype(str) + '.npy'
valid_df['path'] = contrails + valid_df['record_id'].astype(str) + '.npy'

# Pool both splits and re-split with stratified k-fold. ignore_index gives a
# clean 0..n-1 index without leaving the old one behind as an "index" column.
data = pd.concat([train_df, valid_df], ignore_index=True)

data = stratified_kfold_loaders(data, num_splits=4, random_state=CFG.seed)
data['kfold'] = data['kfold'].astype(int)

# Train one independent model per requested fold.
for fold in CFG.train_folds:
    print(f"Fold {fold}")
    # Rows of the current fold are held out for validation.
    trn_df = data[data.kfold != fold].reset_index(drop=True)
    vld_df = data[data.kfold == fold].reset_index(drop=True)

    dataset_train = ContrailsDataset(trn_df, CFG.image_size, train=True)
    dataset_validation = ContrailsDataset(vld_df, CFG.image_size, train=False)

    data_loader_train = DataLoader(
        dataset_train,
        batch_size=CFG.train_bs,
        shuffle=True,
        num_workers=CFG.workers,
    )
    data_loader_validation = DataLoader(
        dataset_validation,
        batch_size=CFG.valid_bs,
        shuffle=False,
        num_workers=CFG.workers,
    )

    # Keep only the weights of the epoch with the best validation Dice.
    checkpoint_callback = ModelCheckpoint(
        save_weights_only=True,
        monitor="val_dice",
        dirpath="models",
        mode="max",
        filename=f"model-f{fold}-{{val_dice:.4f}}",
        save_top_k=1,
        verbose=1,
    )

    progress_bar_callback = TQDMProgressBar(refresh_rate=1)

    # The keyword arguments were missing their separating commas, which made
    # the whole cell a SyntaxError. With patience=999 (far above CFG.epochs)
    # the callback will not actually stop training early.
    early_stop_callback = EarlyStopping(monitor="val_loss",
                                        mode="min",
                                        patience=999,
                                        verbose=1,
                                       )

    trainer = pl.Trainer(callbacks=[checkpoint_callback, early_stop_callback, progress_bar_callback],
                         logger=CSVLogger(save_dir=f'logs_f{fold}/'),
                         max_epochs=CFG.epochs,
                         min_epochs=CFG.epochs,
                         enable_progress_bar=True,
                         precision="16-mixed",
                         devices=2,
                        )

    model = LightningModule()

    trainer.fit(model, data_loader_train, data_loader_validation)

    # Drop every per-fold object and release cached GPU memory before the
    # next fold starts.
    del (dataset_train,
         dataset_validation,
         data_loader_train,
         data_loader_validation,
         model,
         trainer,
         checkpoint_callback,
         progress_bar_callback,
         early_stop_callback,
        )
    torch.cuda.empty_cache()
    gc.collect()