In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline
import seaborn as sns

import os
import random

from PIL import Image
import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, Dataset
from torchvision import datasets, transforms
from utils import *
import albumentations as album
#import extractors
import segmentation_models_pytorch as smp
from tqdm.auto import tqdm as tq

# Global seaborn look for every figure in the notebook:
# notebook-sized text, plain white background, "deep" colour palette.
sns.set(context='notebook', style='white', palette='deep')

# Preprocessing required

In [3]:
# Model configuration. Building the model fetches the pretrained encoder
# weights, so this cell needs to be run only one time.
ENCODER = 'resnext101_32x4d'
ENCODER_WEIGHTS = 'ssl'
CLASSES = 25  # number of segmentation classes (output channels of the model)
ACTIVATION = 'softmax' # could be None for logits or 'softmax2d','softmax' for multiclass segmentation

# create segmentation model with pretrained encoder
model = smp.DeepLabV3(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    classes=CLASSES,  # use the constant so the class count is defined in one place
    activation=ACTIVATION,
)

# Input-normalisation function that matches the chosen encoder / weights pair.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

Downloading: "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x4-dc43570a.pth" to C:\Users\rkroc/.cache\torch\hub\checkpoints\semi_supervised_resnext101_32x4-dc43570a.pth


  0%|          | 0.00/169M [00:00<?, ?B/s]

In [None]:
# Resize the image to (256 x 256)
# CenterCrop it to (224 x 224)
# Convert it to Tensor – all the values in the image will be scaled so they lie in [0, 1] instead of the original [0, 255] range.
# Normalize it with the Imagenet specific values where mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225]

In [None]:
# helper function for data visualization
def visualize(**images):
    """Show the given images side by side in a single row.

    Each keyword argument is one panel: the keyword name (underscores turned
    into spaces, then title-cased) becomes the panel title and the value is
    handed to ``plt.imshow``.
    """
    panel_count = len(images)
    plt.figure(figsize=(12, 8))
    for position, label in enumerate(images, start=1):
        plt.subplot(1, panel_count, position)
        # hide the tick marks on both axes
        plt.xticks([])
        plt.yticks([])
        plt.title(label.replace('_', ' ').title(), fontsize=20)
        plt.imshow(images[label])
    plt.show()

# Perform one-hot encoding on the label mask
def one_hot_encode(image, n_classes):
    """Return the one-hot encoding of an integer label tensor.

    Thin wrapper around ``torch.nn.functional.one_hot``: the result has one
    extra trailing dimension of size ``n_classes``.
    """
    return F.one_hot(image, num_classes=n_classes)
 
    
# Perform reverse one-hot-encoding on labels / preds
def reverse_one_hot(image):
    """Collapse the last axis of ``image`` to the index of its largest entry."""
    return np.argmax(image, axis=-1)

In [None]:
def to_tensor(x, **kwargs):
    """Reorder axes (2, 0, 1) and cast to float32.

    Despite the name, the result comes from ``x.transpose(...).astype(...)``,
    so it is an array rather than a ``torch.Tensor``. Extra keyword
    arguments are accepted and ignored.
    """
    channels_first = x.transpose(2, 0, 1)
    return channels_first.astype('float32')

def get_preprocessing(preprocessing_fn=None):
    """Build the albumentations pipeline applied to each (image, mask) sample.

    Args:
        preprocessing_fn: callable applied to the image only; it is passed
            straight to ``album.Lambda(image=...)``.

    Returns:
        An ``album.Compose`` of two steps: the image-only ``preprocessing_fn``
        step, followed by ``to_tensor`` applied to both image and mask.
    """
    normalize_step = album.Lambda(image=preprocessing_fn)
    layout_step = album.Lambda(image=to_tensor, mask=to_tensor)
    return album.Compose([normalize_step, layout_step])

In [None]:
# transform = transforms.Compose([transforms.ToTensor()])

class BackgroundDataset(torch.utils.data.Dataset):
    """Segmentation dataset pairing ``train_images/*.jpg`` with ``train_masks/*``.

    Samples are enumerated from the files found in ``<path_m>/train_masks``.
    The matching image is looked up in ``<path_i>/train_images`` under the
    same base name with a ``.jpg`` extension.

    Args:
        path_m: root folder that contains the ``train_masks`` directory.
        path_i: root folder that contains the ``train_images`` directory.
            Defaults to ``path_m`` so a single root can be passed.
        augmentation: optional callable invoked as
            ``augmentation(image=..., mask=...)`` and returning a dict with
            ``'image'`` and ``'mask'`` keys. It is applied to the resized
            image and the integer label mask, before one-hot encoding.
        preprocessing: optional callable with the same calling convention,
            applied last to the float32 image and the one-hot mask.
        n_classes: number of classes; label values outside
            ``[0, n_classes)`` are mapped to class 0.
        size: ``(width, height)`` both image and mask are resized to.

    NOTE(review): assumes masks are single-channel images whose pixel values
    are class ids -- confirm against the data.
    """

    def __init__(
            self, path_m, path_i=None,
            augmentation=None,
            preprocessing=None,
            n_classes=25,
            size=(256, 256),
    ):
        self.path_m = path_m
        self.path_i = path_m if path_i is None else path_i
        # os.listdir order is arbitrary; sort so an index always maps to the
        # same sample across runs and machines.
        self.name = sorted(os.listdir(os.path.join(self.path_m, 'train_masks')))
        self.augmentation = augmentation
        self.preprocessing = preprocessing
        self.n_classes = n_classes
        self.size = size

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` pair for sample ``idx``."""
        mask_name = self.name[idx]
        mask_path = os.path.join(self.path_m, 'train_masks', mask_name)
        # Swap only the extension; a plain str.replace would also rewrite a
        # "png" that happens to occur inside the base name.
        img_name = os.path.splitext(mask_name)[0] + '.jpg'
        img_path = os.path.join(self.path_i, 'train_images', img_name)

        # Context managers close the file handles once the pixels are copied.
        with Image.open(img_path) as img_file:
            # LANCZOS is the filter the removed Image.ANTIALIAS alias pointed to.
            image = np.array(img_file.convert('RGB').resize(self.size, Image.LANCZOS))
        with Image.open(mask_path) as mask_file:
            # Nearest-neighbour keeps label values intact; an interpolating
            # filter would blend neighbouring class ids into invalid ones.
            mask = np.array(mask_file.resize(self.size, Image.NEAREST))

        if self.augmentation is not None:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        image = image.astype('float32')
        mask = mask.astype('int64')
        # Send every label outside the known range to class 0.
        mask = np.where((mask >= 0) & (mask < self.n_classes), mask, 0)

        # One-hot encode the mask: one trailing channel per class.
        mask = one_hot_encode(torch.from_numpy(mask), self.n_classes)
        mask = mask.numpy().astype('int32')

        if self.preprocessing is not None:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask

    def __len__(self):
        """Number of mask files found, i.e. number of samples."""
        return len(self.name)
    
if __name__ == '__main__':
    # Smoke test: load one sample and report its shapes, dtypes and the dataset size.
    data_root = ''  # root folder holding train_images/ and train_masks/
    # The dataset takes a mask root and an image root; both live under the same folder here.
    data = BackgroundDataset(data_root, data_root, preprocessing=get_preprocessing(preprocessing_fn))
    sample_idx = min(219, len(data) - 1)  # arbitrary sample, clamped so small datasets still work
    check_image, check_mask = data[sample_idx]  # index once so the files are read only once
    print(check_image.shape,check_mask.shape)
    print(check_image.dtype,check_mask.dtype)
    print(len(data))

In [None]:
bs = 3  # batch size
nw = 0  # number of DataLoader workers
data_root = ''  # root folder holding train_images/ and train_masks/

# Splitting into Train and Val (90% / 10%)
# The dataset takes a mask root and an image root; both live under the same folder here.
full_dataset = BackgroundDataset(data_root, data_root, preprocessing=get_preprocessing(preprocessing_fn))
train_size = int(0.9 * len(full_dataset))
val_size   = len(full_dataset) - train_size
# Seed the split so the same samples land in train / val on every run.
split_rng = torch.Generator().manual_seed(42)
train_dataset, val_dataset = torch.utils.data.random_split(
    full_dataset, [train_size, val_size], generator=split_rng,
)

# Creating data loaders: shuffle only the training data; validation order has
# no effect on the metrics, so keep it fixed.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=bs,num_workers=nw,shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=bs,num_workers=nw,shuffle=False)


In [None]:
# Master switch: True runs the training loop; False is meant for
# prediction-only runs that use an older model checkpoint.
TRAINING = True

# Number of training epochs
EPOCHS = 20

# Use the GPU when one is available, otherwise fall back to the CPU
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Loss function: Dice loss
loss = smp.utils.losses.DiceLoss()

# Metrics: IoU with a 0.5 threshold
metrics = [smp.utils.metrics.IoU(threshold=0.5)]

# Optimizer: Adam with a single parameter group covering the whole model
optimizer = torch.optim.Adam([{'params': model.parameters(), 'lr': 0.0005}])

# Learning-rate scheduler (created here but not used in this NB)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=1,
    T_mult=2,
    eta_min=5e-5,
)

# load best saved model checkpoint from previous commit (if present)
# if os.path.exists('../input/pyramid-scene-parsing-pspnet-resnext50-pytorch/best_model.pth'):
#     model = torch.load('../input/pyramid-scene-parsing-pspnet-resnext50-pytorch/best_model.pth', map_location=DEVICE)

In [None]:
# Settings shared by the training and the validation runner.
_epoch_kwargs = dict(loss=loss, metrics=metrics, device=DEVICE, verbose=True)

# Training runner: the only one that receives the optimizer.
train_epoch = smp.utils.train.TrainEpoch(model, optimizer=optimizer, **_epoch_kwargs)

# Validation runner: same model, loss and metrics, no optimizer.
valid_epoch = smp.utils.train.ValidEpoch(model, **_epoch_kwargs)

In [None]:
%%time

if TRAINING:

    # Best validation IoU seen so far, plus the per-epoch log history
    best_iou_score = 0.0
    train_logs_list, valid_logs_list = [], []

    for epoch in range(EPOCHS):
        print('\nEpoch: {}'.format(epoch))

        # One pass over the training loader, then one over the validation loader
        train_logs = train_epoch.run(train_loader)
        valid_logs = valid_epoch.run(val_loader)
        train_logs_list.append(train_logs)
        valid_logs_list.append(valid_logs)

        # Save the weights only when the validation IoU strictly improves
        epoch_iou = valid_logs['iou_score']
        if epoch_iou > best_iou_score:
            best_iou_score = epoch_iou
            torch.save(model.state_dict(), 'DEEPMIND_model.pt')
            print('Model saved!')