In [None]:
# # This Python 3 environment comes with many helpful analytics libraries installed
# # It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# # For example, here's several helpful packages to load

# import numpy as np # linear algebra
# import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# # Input data files are available in the read-only "../input/" directory
# # For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

# import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# # You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# # You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np
import pandas as pd
import cv2

import matplotlib.pyplot as plt

import os
import random
import tqdm

from IPython.display import clear_output

import gc

# Seed every random number generator used in this notebook (Python's
# `random`, NumPy, PyTorch CPU and PyTorch CUDA) with the same value.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed(RANDOM_SEED)
# Request deterministic cuDNN algorithms.
torch.backends.cudnn.deterministic = True

In [None]:
# Install the third-party packages imported by the next cell.
# albumentations is taken from its GitHub repository (no version pin),
# pytorch-lightning is pinned to 0.8.5, segmentation-models-pytorch is unpinned.
!pip3 install -U git+https://github.com/albu/albumentations --no-cache-dir
!pip3 install pytorch-lightning==0.8.5
!pip3 install segmentation-models-pytorch
# Wipe the pip logs from the cell output.
clear_output()

In [None]:
import albumentations as albu
from albumentations.pytorch.transforms import ToTensorV2

import pytorch_lightning as pl
from pytorch_lightning.logging import TensorBoardLogger

import segmentation_models_pytorch as smp

In [None]:
# Install pytorch_toolbelt (pinned to 0.3.2); the next cell imports its losses.
!pip3 install pytorch_toolbelt==0.3.2
# Wipe the pip logs from the cell output.
clear_output()

In [None]:
from pytorch_toolbelt import losses as L
from pytorch_toolbelt.losses import DiceLoss, FocalLoss
import pytorch_toolbelt

In [None]:
# Extract the training archive into the current working directory.
# NOTE(review): presumably this creates the './images/' and './masks/'
# folders referenced by IMAGE_PATH / MASK_PATH - confirm against the archive.
!unzip ../input/tgs-salt-identification-challenge/train.zip
# Wipe the unzip listing from the cell output.
clear_output()

In [None]:
# Filesystem locations used by the rest of the notebook.
# root_dir: competition input directory; 'train.csv' and 'depths.csv' are read from it.
root_dir        = '/kaggle/input/tgs-salt-identification-challenge/'
# IMAGE_PATH / MASK_PATH: folders SaltDataset reads '<id>.png' from in the 'train' phase.
IMAGE_PATH      = './images/'
MASK_PATH       = './masks/'
# IMAGE_PATH_TEST: folder SaltDataset reads '<id>.png' from in any other phase.
IMAGE_PATH_TEST = './competition_data/test/images/'

# **Transforms**

In [None]:
def get_val_augs(height=128, width=128):
    """Build the augmentation-free pipeline used for validation and inference.

    The returned ``albu.Compose`` resizes its input to ``height`` x ``width``
    with bicubic interpolation, applies ``albu.Normalize`` with its default
    arguments and finally converts the result with ``ToTensorV2``.

    Args:
        height: Target height passed to ``albu.Resize``.
        width: Target width passed to ``albu.Resize``.

    Returns:
        The composed albumentations pipeline.
    """
    resize = albu.Resize(height=height, width=width, interpolation=cv2.INTER_CUBIC, always_apply=True)
    return albu.Compose([resize, albu.Normalize(), ToTensorV2()])


def get_train_augs(height=128, width=128):
    """Build the training pipeline: resize, random geometric augmentations, normalize, to-tensor.

    Steps, in order:
      1. ``albu.Resize`` to ``height`` x ``width`` (bicubic, always applied).
      2. Horizontal flip and vertical flip, each with probability 0.5.
      3. ``albu.OpticalDistortion`` with probability 0.3.
      4. ``albu.ShiftScaleRotate`` with probability 0.5.
      5. ``albu.Normalize`` with default arguments.
      6. ``ToTensorV2``.

    Args:
        height: Target height passed to ``albu.Resize``.
        width: Target width passed to ``albu.Resize``.

    Returns:
        The composed albumentations pipeline.
    """
    resize = albu.Resize(height=height, width=width, interpolation=cv2.INTER_CUBIC, always_apply=True)
    flips = [albu.HorizontalFlip(p=0.5), albu.VerticalFlip(p=0.5)]
    # Both warps fill the border by reflection and interpolate bicubically.
    warps = [
        albu.OpticalDistortion(interpolation=cv2.INTER_CUBIC, border_mode=cv2.BORDER_REFLECT_101, p=0.3),
        albu.ShiftScaleRotate(
            shift_limit=0.01,
            scale_limit=(-0.15, 0.25),
            rotate_limit=5,
            interpolation=cv2.INTER_CUBIC,
            border_mode=cv2.BORDER_REFLECT_101,
            p=0.5,
        ),
    ]
    finalize = [albu.Normalize(), ToTensorV2()]

    return albu.Compose([resize] + flips + warps + finalize)

# **Dataset**

In [None]:
def stratify_dataset(df, columns, n_fold, ds_path=None):
    """Assign every row of ``df`` to one of ``n_fold`` folds in round-robin order.

    Despite the name, no stratification is performed: the row at position
    ``i`` (in the frame's current row order) receives fold ``i % n_fold``.

    Args:
        df: Source ``DataFrame``. It is copied and never modified.
        columns: Unused; accepted only so existing call sites keep working.
        n_fold: Number of folds. Must be at least 1.
        ds_path: Optional output path. When given, the frame including the
            new ``fold`` column is written there as CSV. The index is NOT
            written (``index=False``).

    Returns:
        A copy of ``df`` with an additional ``fold`` column.

    Raises:
        ValueError: If ``n_fold`` is smaller than 1.
    """
    if n_fold < 1:
        raise ValueError("n_fold must be a positive integer, got {!r}".format(n_fold))

    folds = df.copy()
    # Cycle 0..n_fold-1 over the rows without materialising an
    # n_fold * n_rows list first.
    folds["fold"] = [position % n_fold for position in range(folds.shape[0])]
    if ds_path is not None:
        folds.to_csv(str(ds_path), index=False)
    return folds

In [None]:
def split_by_fold(folds, val_fold_n):
    """Split a fold-annotated frame into training and validation parts.

    Args:
        folds: ``DataFrame`` with a ``fold`` column.
        val_fold_n: Fold value whose rows form the validation set.

    Returns:
        ``(train_df, val_df)``: rows whose ``fold`` differs from
        ``val_fold_n`` and rows whose ``fold`` equals it, respectively.
    """
    frame = folds.copy()
    is_validation = frame.fold.isin([val_fold_n])
    return frame[~is_validation], frame[is_validation]

In [None]:
class SaltDataset(Dataset):
    """Image (and mask) dataset addressed by the index of a ``DataFrame``.

    In the ``'train'`` phase every item is an ``(image, mask)`` pair read
    from ``IMAGE_PATH`` and ``MASK_PATH``. In any other phase only the image
    is returned, read from ``IMAGE_PATH_TEST``.
    """

    def __init__(self, df, transforms=None, phase='train'):
        """Store the sample ids and the transform to apply.

        Args:
            df: Frame whose index holds the sample ids; ``'<id>.png'`` is the
                file name looked up on disk.
            transforms: Optional callable invoked as
                ``transforms(image=..., mask=...)`` (train phase) or
                ``transforms(image=...)`` (other phases) that returns a
                mapping with ``'image'`` (and ``'mask'``) entries.
            phase: ``'train'`` to yield image/mask pairs; anything else
                yields images only.
        """
        self.root_dir   = root_dir
        self.ids        = df.index
        self.transforms = transforms
        self.phase = phase

    @staticmethod
    def _read_image(path):
        """Read ``path`` as a 3-channel image, failing loudly when it cannot be read."""
        image = cv2.imread(path, cv2.IMREAD_COLOR)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise FileNotFoundError("Could not read image file: {}".format(path))
        return image

    def __getitem__(self, index: int):
        """Return the sample at position ``index``.

        Returns:
            ``(image, mask)`` in the ``'train'`` phase, otherwise ``image``.

        Raises:
            FileNotFoundError: If a required image or mask file cannot be read.
        """
        sample_id = self.ids[index]
        if self.phase == 'train':
            image = self._read_image(str(IMAGE_PATH + sample_id + '.png'))
            mask = self._read_image(str(MASK_PATH + sample_id + '.png'))
            # A pixel is foreground (1) only when all three channels equal 255.
            mask = np.all(mask == (255, 255, 255), axis=2).astype(np.uint8)
            if self.transforms is not None:
                data = self.transforms(image=image, mask=mask)
                image = data['image']
                mask = data['mask']
            return image, mask

        image = self._read_image(str(IMAGE_PATH_TEST + sample_id + '.png'))
        if self.transforms is not None:
            image = self.transforms(image=image)['image']
        return image

    def __len__(self):
        """Return the number of samples."""
        return len(self.ids)

In [None]:
# Load the training table and the depth table (both indexed by 'id') and
# join the depths onto the training rows.
df  = pd.read_csv(root_dir+'train.csv', index_col='id')
depths_df = pd.read_csv(root_dir+'depths.csv', index_col='id')
df = df.join(depths_df)

# Add a 'fold' column with 5 folds.
df = stratify_dataset(df, ['id', 'rle_mask', 'z'], 5)

# Fold 0 is held out for validation; the remaining folds are used for training.
train_df, val_df = split_by_fold(df, 0)

train_transforms = get_train_augs()
val_transforms   = get_val_augs()

train_dataset = SaltDataset(df=train_df, transforms=train_transforms)
valid_dataset = SaltDataset(df=val_df, transforms=val_transforms)

batch_size  = 16
num_workers = 2
    
# Only the training loader shuffles; validation keeps the frame's row order.
train_loader = DataLoader(train_dataset, batch_size, shuffle=True, num_workers=num_workers)
valid_loader = DataLoader(valid_dataset, batch_size, shuffle=False, num_workers=num_workers)

In [None]:
# Sanity check: fetch one training batch, print its shapes and display the
# first image and the first mask.
x, y = next(iter(train_loader))
print(x.shape, y.shape)
# Move the channel axis last, which is the layout matplotlib expects.
# NOTE(review): the image has already gone through albu.Normalize, so the
# displayed colours are presumably not the original ones - confirm.
plt.imshow(x[0].permute(1, 2, 0).numpy())
plt.show()
plt.imshow(y[0])
plt.show()

# **Metrics**

In [None]:
def iou_torch(preds, target, thresh=0.5):
    """Mean intersection-over-union between thresholded predictions and targets.

    ``preds`` are treated as logits: dimension 1 is squeezed away, a sigmoid
    is applied and values above ``thresh`` count as foreground. ``target`` is
    cast to a byte tensor. IoU is computed per sample over dimensions 1 and 2
    and then averaged over the batch. A small constant (1e-6) is added to
    both intersection and union, so two empty masks score 1.

    Args:
        preds: Logits with a singleton dimension 1.
        target: Ground-truth masks matching ``preds`` after the squeeze.
        thresh: Probability threshold for the foreground class.

    Returns:
        A scalar tensor with the batch-mean IoU. No gradients are tracked.
    """
    eps = 1e-6
    with torch.no_grad():
        truth = target.byte()
        predicted = (torch.sigmoid(preds.squeeze(1)) > thresh).byte()

        overlap = (predicted & truth).float().sum(dim=(1, 2))
        combined = (predicted | truth).float().sum(dim=(1, 2))

        per_sample = (overlap + eps) / (combined + eps)
        return per_sample.mean()

# **Model**

In [None]:
class SegModel(pl.LightningModule):
    """Lightning module wrapping a ResNet34-encoder U-Net for binary segmentation.

    The module builds its own training/validation datasets in ``__init__``
    (fold 0 is the validation fold), trains with a binary Dice loss and
    tracks the batch-mean IoU computed by ``iou_torch``. Per-epoch averages
    are appended to ``losses_log`` / ``metrics_log`` under the ``'train'``
    and ``'val'`` keys.
    """

    def __init__(self):
        """Create the network, the datasets, the loss and the history containers."""
        super(SegModel, self).__init__()
        self.batch_size = 16
        self.learning_rate = 1e-3
        self.num_workers = 2
        # Single-channel raw logits (activation=None); sigmoid is applied by the metric.
        self.net = smp.Unet("resnet34",
                 encoder_weights = "imagenet",
                 in_channels = 3,
                 classes = 1,
                 activation = None)

        # Training table joined with the depth table, both indexed by 'id'.
        df  = pd.read_csv(root_dir+'train.csv', index_col='id')
        depths_df = pd.read_csv(root_dir+'depths.csv', index_col='id')
        df = df.join(depths_df)

        df = stratify_dataset(df, ['id', 'rle_mask', 'z'], 5)

        # Fold 0 is held out for validation.
        train_df, val_df = split_by_fold(df, 0)

        train_transforms = get_train_augs()
        val_transforms   = get_val_augs()

        self.train_dataset = SaltDataset(df=train_df, transforms=train_transforms)
        self.valid_dataset = SaltDataset(df=val_df, transforms=val_transforms)

        self.loss = DiceLoss(mode = 'binary', log_loss = False)
        # Per-epoch averages, one entry appended per epoch end hook.
        self.losses_log = {
            'train': [],
            'val': []
        }
        self.metrics_log = {
            'train': [],
            'val': []
        }

    def forward(self, x):
        """Return the network output for the batch ``x``."""
        return self.net(x)

    def training_step(self, batch, batch_nb) :
        """Compute loss and IoU for one training batch.

        Returns:
            Dict with ``'loss'`` (mean loss) and ``'metric'`` (mean IoU).
        """
        x, y = batch
        y_hat = self.forward(x)
        losses = self.loss(y_hat, y)
        metrics = iou_torch(y_hat, y)
        # Aggressive per-step memory cleanup. `gc.collect` must be CALLED to
        # have any effect; a bare reference is a no-op. Note that this runs a
        # full collection on every step.
        torch.cuda.empty_cache()
        gc.collect()

        return {'loss' : losses.mean(),
                'metric': metrics.mean()}

    def validation_step(self, batch, batch_nb):
        """Compute loss and IoU for one validation batch.

        Returns:
            Dict with ``'val_loss'`` (mean loss) and ``'val_metric'`` (mean IoU).
        """
        x, y = batch
        y_hat = self(x.float())
        losses = self.loss(y_hat, y)
        metrics = iou_torch(y_hat, y)
        # Same per-step memory cleanup as in training_step.
        torch.cuda.empty_cache()
        gc.collect()

        return {'val_loss' : losses.mean(),
                'val_metric': metrics.mean()}

    def training_epoch_end(self, outputs):
        """Average the per-step training results, record, print and return them.

        Args:
            outputs: Sequence of the dicts returned by ``training_step``.

        Returns:
            Dict with ``'avg_loss'``, ``'avg_metric'`` and a ``'log'`` dict
            holding the same values under ``'loss'`` / ``'metric'``.
        """
        avg_losses = torch.tensor([x['loss'] for x in outputs]).mean()
        self.losses_log['train'].append(avg_losses)

        avg_metrics = torch.tensor([x['metric'] for x in outputs]).mean()
        self.metrics_log['train'].append(avg_metrics)

        print('epoch: %.0f | phase: train | loss: %.3f | metric: %.3f |'% (self.current_epoch, avg_losses, avg_metrics))

        tensorboard_logs = {'loss': avg_losses,
                            'metric' : avg_metrics}

        return {'avg_loss': avg_losses,
                'avg_metric': avg_metrics,
                'log': tensorboard_logs}

    def validation_epoch_end(self, outputs):
        """Average the per-step validation results, record, print and return them.

        Args:
            outputs: Sequence of the dicts returned by ``validation_step``.

        Returns:
            Dict with ``'avg_val_loss'``, ``'avg_val_metric'`` and a ``'log'``
            dict holding the same values under ``'val_loss'`` / ``'val_metric'``.
        """
        avg_losses = torch.tensor([x['val_loss'] for x in outputs]).mean()
        self.losses_log['val'].append(avg_losses)

        avg_metrics = torch.tensor([x['val_metric'] for x in outputs]).mean()
        self.metrics_log['val'].append(avg_metrics)

        print('epoch: %.0f | phase: val | loss: %.3f | metric: %.3f |'% (self.current_epoch, avg_losses, avg_metrics))

        tensorboard_logs = {'val_loss': avg_losses,
                            'val_metric' : avg_metrics}

        return {'avg_val_loss': avg_losses,
                'avg_val_metric': avg_metrics,
                'log': tensorboard_logs}

    def configure_optimizers(self):
        """Return Adam over the network parameters plus a ReduceLROnPlateau scheduler.

        The scheduler runs in ``"min"`` mode, multiplying the learning rate
        by 0.3 after 5 epochs without improvement.
        NOTE(review): no monitored quantity is named here, so which value
        drives the scheduler depends on the Lightning default - confirm.
        """
        opt = torch.optim.Adam(self.net.parameters(),
                               lr = self.learning_rate)
        sch = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,
                                                         mode = "min",
                                                         factor = 0.3,
                                                         patience = 5,
                                                         verbose = True)
        return [opt], [sch]

    def train_dataloader(self):
        """Return the shuffled loader over the training dataset."""
        return torch.utils.data.DataLoader(self.train_dataset,
                                           batch_size = self.batch_size,
                                           shuffle = True,
                                           num_workers = self.num_workers)

    def val_dataloader(self):
        """Return the unshuffled loader over the validation dataset."""
        return torch.utils.data.DataLoader(self.valid_dataset,
                                           batch_size = self.batch_size,
                                           shuffle = False,
                                           num_workers = self.num_workers)

In [None]:
# Instantiate the Lightning module; its constructor also reads the CSV files
# and builds the training/validation datasets.
model = SegModel()

In [None]:
# Keep only the single best checkpoint, ranked by the highest
# 'avg_val_metric' (the key returned by SegModel.validation_epoch_end).
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    filepath = 'top_model_weights.ckpt',
    save_top_k = 1,
    verbose = True, 
    monitor = 'avg_val_metric',
    mode = 'max')

In [None]:
# TensorBoard logger; "lightning_logs" is passed as its first argument.
logger = TensorBoardLogger(
    "lightning_logs")

In [None]:
# Train on one GPU for at most 250 epochs with the checkpoint callback and
# logger defined above. No early-stopping callback is supplied (None).
trainer = pl.Trainer(gpus = 1, 
                     max_epochs = 250, 
                     checkpoint_callback = checkpoint_callback,
                     early_stop_callback = None,
                     logger = logger,
                     show_progress_bar = True)

# **Fit**

In [None]:
# Run training. No dataloaders are passed here; SegModel defines
# train_dataloader() and val_dataloader() itself.
trainer.fit(model)

In [None]:
# Reload the best checkpoint recorded by the checkpoint callback into a
# fresh SegModel and switch it to evaluation mode for inference.
best_path = checkpoint_callback.best_model_path
print(best_path)
pred_model = SegModel.load_from_checkpoint(best_path)
pred_model.eval()

In [None]:
# Extract the competition archive into the current working directory.
# NOTE(review): presumably this creates './competition_data/' with the test
# images and the sample submission read below - confirm against the archive.
!unzip ../input/tgs-salt-identification-challenge/competition_data.zip
# Wipe the unzip listing from the cell output.
clear_output()

In [None]:
def rle_encode(im):
    """Run-length encode a mask array as a space-separated string.

    The array is flattened in column-major (Fortran) order and padded with a
    zero on each side. Every position where the value changes marks the
    start or the end of a run, so the result alternates 1-based run starts
    and run lengths: ``"start1 length1 start2 length2 ..."``. An all-zero
    mask yields an empty string.

    Args:
        im: NumPy array holding the mask.

    Returns:
        The encoded string.
    """
    padded = np.concatenate([[0], im.flatten(order = 'F'), [0]])
    # Indices (1-based with respect to the unpadded data) where the value flips.
    runs = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Turn every run end into a run length by subtracting the matching start.
    runs[1::2] -= runs[::2]
    return ' '.join(map(str, runs))

In [None]:
# Predict a mask for every image listed in the sample submission and
# run-length encode it.
image_path = "./competition_data/test/images/"
sub_df = pd.read_csv('./competition_data/sample_submission.csv')
n = sub_df.shape[0]

rle_mask = []
# Inference only: no autograd graph is needed.
with torch.no_grad():
    for idx in range(n):

        sample_name = sub_df['id'][idx]
        file_name = str(image_path+sample_name+'.png')
        image = cv2.imread(file_name, cv2.IMREAD_COLOR)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise FileNotFoundError("Could not read image file: {}".format(file_name))
        orig_height, orig_width = image.shape[:2]
        image = val_transforms(image=image)['image']

        pred = pred_model(image.unsqueeze(0))
        prob = torch.sigmoid(pred).squeeze().cpu().numpy()
        # val_transforms resizes the input, so the prediction lives on the
        # resized grid. Map it back to the original image size before
        # encoding; otherwise the run positions refer to the wrong pixel grid.
        if prob.shape != (orig_height, orig_width):
            # cv2.resize takes the target size as (width, height).
            prob = cv2.resize(prob, (orig_width, orig_height), interpolation=cv2.INTER_LINEAR)
        mask = (prob > 0.5).astype(np.uint8)

        rle_mask.append(rle_encode(mask))
        print("\rprogress {}/{}".format(idx+1, n), end = "")

sub_df['rle_mask'] = rle_mask

In [None]:
# Write the submission file without the DataFrame index.
sub_df.to_csv('submission.csv', index = False)