# Preparation: unpack the competition data


In [None]:
from IPython.display import clear_output

In [None]:
# Extract the training archive into ./train, then clear the cell output
# so the per-file unzip listing does not stay in the notebook.
!unzip ../input/tgs-salt-identification-challenge/train.zip -d ./train
clear_output()

In [None]:
# Extract the test archive into ./test, then clear the cell output
# so the per-file unzip listing does not stay in the notebook.
!unzip ../input/tgs-salt-identification-challenge/test.zip -d ./test
clear_output()

**Copy the competition files into the working directory**

In [None]:
import shutil

fromDir = '../input/tgs-salt-identification-challenge'
toDir = './'

# Copy every file of the competition input folder into the working directory.
# `distutils` is deprecated (PEP 632) and removed in Python 3.12, so the
# standard-library `shutil.copytree` is used instead; `dirs_exist_ok=True`
# lets it merge into the already existing working directory.
# The trailing semicolon keeps the return value out of the cell output.
shutil.copytree(fromDir, toDir, dirs_exist_ok=True);

# Libraries

In [None]:
import torch
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, dataloader, random_split

from PIL import Image
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import os
import wandb
import time
import copy

# login wandb
# The API key is a secret and must never be stored in the notebook: take it
# from the environment and only prompt for it (without echo) when it is missing.
# A key that has ever been committed in plain text should be treated as leaked
# and rotated.
if not os.environ.get('WANDB_API_KEY'):
    from getpass import getpass
    os.environ['WANDB_API_KEY'] = getpass('W&B API key: ')

# CONFIG

In [None]:
# ---- Model / training settings ----
RUN_NAME = 'unetv1'
N_CLASSES = 1
INPUT_SIZE = 101
EPOCHS = 100
LEARNING_RATE = 0.002
START_FRAME = 16
DROP_RATE = 0.5

# ---- Data locations ----
SAVE_PATH = './'
DATA_PATH = './'                 # default root_dir of TGSDataset
IMAGE_PATH = 'train/images/'     # appended to the dataset root
MASK_PATH = 'train/masks/'       # appended to the dataset root

# ---- Data loading settings ----
REAL_SIZE = 101
RANDOM_SEED = 42                 # default seed for the train/valid split
VALID_RATIO = 0.2                # default fraction held out for validation
BATCH_SIZE = 16                  # default samples per batch
NUM_WORKERS = 0                  # default DataLoader worker subprocesses
CLASSES = {1: 'salt'}

# Create the training dataset and dataloaders

In [None]:
class TGSDataset(Dataset):
    """TGS Salt Identification dataset.

    Each item is an ``(image, mask)`` pair loaded from ``IMAGE_PATH`` and
    ``MASK_PATH`` below ``root_dir``; the same ``transform`` is applied to both.
    """

    def __init__(self, root_dir=DATA_PATH, transform=None):
        """
        Args:
            root_dir (string): Directory containing ``train.csv``,
                ``depths.csv`` and the unzipped image/mask folders.
            transform (callable, optional): Applied to the image and to the
                mask. When omitted, both are converted to grayscale and then
                to tensors.
        """

        # load dataset from root dir
        train_df  = pd.read_csv(os.path.join(root_dir, 'train.csv'), index_col='id')
        depths_df = pd.read_csv(os.path.join(root_dir, 'depths.csv'), index_col='id')
        train_df  = train_df.join(depths_df)

        self.root_dir   = root_dir
        self.ids        = train_df.index
        self.depths     = train_df['z'].to_numpy()
        self.rle        = train_df['rle_mask'].to_numpy()

        if transform is None:
            transform = transforms.Compose([transforms.Grayscale(),
                                            transforms.ToTensor(),])

        # Always store the transform: previously it was only set in the default
        # case, so passing a custom transform crashed in __getitem__.
        self.transform = transform
        # Misspelled legacy attribute name, kept so existing references keep working.
        self.transfrom = transform

    def __len__(self):
        """Return the number of samples listed in ``train.csv``."""
        return len(self.ids)

    def __getitem__(self, index):
        """Load and transform the image and mask of the sample at ``index``.

        Returns:
            tuple: ``(image, mask)`` after ``self.transform`` was applied to each.
        """
        sample_id = self.ids[index]
        file_name = sample_id + '.png'

        # file should be unzipped
        image = Image.open(os.path.join(self.root_dir, IMAGE_PATH, file_name))
        mask  = Image.open(os.path.join(self.root_dir, MASK_PATH, file_name))

        image = self.transform(image)
        mask  = self.transform(mask)

        return image, mask

In [None]:
def get_dataloader(dataset,
                    batch_size=BATCH_SIZE, random_seed=RANDOM_SEED,
                    valid_ratio=VALID_RATIO, shuffle=True, num_workers=NUM_WORKERS):
    """Split ``dataset`` randomly and wrap both parts in DataLoaders.

    Params:
    -------
    - dataset: the dataset to split.
    - batch_size: number of samples per batch for both loaders.
    - random_seed: value passed to ``torch.manual_seed`` right before the
      split, so the split is reproducible.
    - valid_ratio: fraction of the dataset placed in the validation part.
      Must lie in the range [0, 1].
    - shuffle: whether the training loader shuffles its samples; the
      validation loader never shuffles.
    - num_workers: number of subprocesses each loader uses.

    Returns:
    --------
    - (train_loader, valid_loader)
    """

    assert 0 <= valid_ratio <= 1, "[!] valid_ratio should be in the range [0, 1]."

    # work out how many samples go to each part
    total = len(dataset)
    valid_count = int(valid_ratio * total)
    split_sizes = (total - valid_count, valid_count)

    # seed the global torch RNG so the random split is repeatable
    torch.manual_seed(random_seed)
    train_part, valid_part = random_split(dataset, split_sizes)

    loader_options = dict(num_workers=num_workers)
    train_loader = DataLoader(train_part, batch_size, shuffle=shuffle, **loader_options)
    valid_loader = DataLoader(valid_part, batch_size, shuffle=False, **loader_options)
    return train_loader, valid_loader

In [None]:
# load dataset
# Build the dataset from DATA_PATH, then split it into train/validation loaders.
# NOTE(review): valid_ratio=0.05 overrides the VALID_RATIO (0.2) constant from
# the config cell — confirm which value is intended.
dataset = TGSDataset(DATA_PATH)
trainloader, validloader = get_dataloader(dataset=dataset, valid_ratio=0.05)

In [None]:
def show_dataset(dataset, n_sample=4):
    """Visualize the first ``n_sample`` items of ``dataset`` side by side.

    Each panel shows the image with its mask drawn semi-transparently on top.
    The PIL size of every image and mask is printed as well.

    Args:
        dataset: indexable collection yielding ``(image, mask)`` pairs that
            ``transforms.ToPILImage`` can convert.
        n_sample (int): number of samples to draw; nothing is drawn when it
            is not positive.
    """
    if n_sample <= 0:
        return

    # squeeze=False keeps `axes` two-dimensional even for a single sample
    fig, axes = plt.subplots(1, n_sample, squeeze=False)
    to_pil = transforms.ToPILImage()

    # show image
    for i, ax in enumerate(axes[0]):
        image, mask = dataset[i]
        image = to_pil(image)
        mask = to_pil(mask)
        print(i, image.size, mask.size)

        ax.set_title('Sample #{}'.format(i))
        ax.axis('off')

        ax.imshow(image, cmap="Greys")
        ax.imshow(mask, alpha=0.3, cmap="OrRd")

    # lay out once, after every axes and title exists
    fig.tight_layout()
    plt.show()

In [None]:
# Visual sanity check: draw the first samples with their masks overlaid.
show_dataset(dataset)

# Train and Eval Function