In [None]:
# !pip install -q pytorch_toolbelt
# !pip install -q torch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0
# !pip install -q git+https://github.com/qubvel/segmentation_models.pytorch
# !pip install -q pycocotools
# !pip install -q cython
# !pip install -q git+https://github.com/lucasb-eyer/pydensecrf.git

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
import os
import random
import time
import json

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torch.cuda.amp import GradScaler, autocast
import cv2

import numpy as np
import pandas as pd

# Python package for pre-processing 
from pycocotools.coco import COCO
import torchvision
import torchvision.transforms as transforms
from pytorch_toolbelt import losses as L
import segmentation_models_pytorch as smp

import albumentations as A
from albumentations.pytorch import ToTensor

# Python package for visualization
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()
plt.rcParams['axes.grid'] = False

import pydensecrf.densecrf as dcrf
from pydensecrf.utils import unary_from_labels, create_pairwise_bilateral
from skimage.color import gray2rgb
from skimage.color import rgb2gray

# Report the runtime environment and pick the device every later cell runs on.
gpu_available = torch.cuda.is_available()
print('Pytorch version: {}'.format(torch.__version__))
print('Is GPU available: {}'.format(gpu_available))
if gpu_available:
    print(torch.cuda.get_device_name(0))
    print('The number of GPUs available: {}'.format(torch.cuda.device_count()))
device = "cuda" if gpu_available else "cpu"

# Number of logical CPUs visible to this runtime.
print('CPU count: {}'.format(os.cpu_count()))

Pytorch version: 1.8.0
Is GPU available: True
Tesla V100-SXM2-16GB
The number of GPUs available: 1
CPU count: 4


In [None]:
%matplotlib inline

dataset_path = '/content/drive/MyDrive/segment/data'
anns_file_path = os.path.join(dataset_path, 'test.json')

# Read the COCO-format annotation file.
with open(anns_file_path, 'r') as f:
    dataset = json.load(f)

categories = dataset['categories']
anns = dataset['annotations']
imgs = dataset['images']
nr_cats = len(categories)
nr_annotations = len(anns)
nr_images = len(imgs)

# Collect category names and give every distinct supercategory an id.
# Membership is tested against the dict (not just the previous name) so a
# supercategory that re-appears later in the list is not registered twice.
cat_names = []
super_cat_names = []
super_cat_ids = {}
for cat_it in categories:
    cat_names.append(cat_it['name'])
    super_cat_name = cat_it['supercategory']
    if super_cat_name not in super_cat_ids:
        super_cat_ids[super_cat_name] = len(super_cat_names)
        super_cat_names.append(super_cat_name)
super_cat_last_name = categories[-1]['supercategory'] if categories else ''
nr_super_cats = len(super_cat_names)

print('Number of super categories:', nr_super_cats)
print('Number of categories:', nr_cats)
print('Number of annotations:', nr_annotations)
print('Number of images:', nr_images)

# Count annotations per category. The histogram is aligned with cat_names,
# i.e. with the position of a category in the file, so ids are translated
# to positions instead of being used as array indices directly.
cat_position = {cat_it['id']: pos for pos, cat_it in enumerate(categories)}
cat_histogram = np.zeros(nr_cats, dtype=int)
for ann in anns:
    cat_histogram[cat_position[ann['category_id']]] += 1

# Convert to DataFrame
df = pd.DataFrame({'Categories': cat_names, 'Number of annotations': cat_histogram})
df = df.sort_values('Number of annotations', axis=0, ascending=False)

# Restore file order and prepend the background entry so that the row index
# equals the pixel value used in the masks (0 = background).
# pd.concat replaces DataFrame.append, which no longer exists in pandas 2.x.
sorted_temp_df = df.sort_index()
sorted_df = pd.DataFrame(["Backgroud"], columns = ["Categories"])
sorted_df = pd.concat([sorted_df, sorted_temp_df], ignore_index=True)
sorted_df

In [None]:
# Class names in label order, taken from sorted_df; position 0 is the
# background row that was prepended when sorted_df was built.
category_names = list(sorted_df.Categories)
def get_classname(classID, cats):
    """Return the ``name`` of the first entry in ``cats`` whose ``id`` is ``classID``.

    Falls back to the string "None" when no entry matches.
    """
    matches = (cat['name'] for cat in cats if cat['id'] == classID)
    return next(matches, "None")

class CustomDataset(Dataset):
    """COCO-format segmentation dataset.

    ``mode`` selects what ``__getitem__`` yields:
      * 'train' / 'val': ``(image, mask, image_info)``
      * 'test':          ``(image, image_info)``
    Images are read relative to the module-level ``dataset_path``.
    """
    def __init__(self, data_dir, mode = 'train', transform = None):
        super().__init__()
        self.coco = COCO(data_dir)
        self.transform = transform
        self.mode = mode

    def _read_image(self, image_infos):
        """Load the file named in ``image_infos`` as float32 RGB scaled to [0, 1]."""
        bgr = cv2.imread(os.path.join(dataset_path, image_infos['file_name']))
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32)
        rgb /= 255.0
        return rgb

    def _build_mask(self, image_infos):
        """Rasterise every annotation of the image into one label mask (height x width)."""
        anns = self.coco.loadAnns(self.coco.getAnnIds(imgIds=image_infos['id']))
        cats = self.coco.loadCats(self.coco.getCatIds())

        masks = np.zeros((image_infos["height"], image_infos["width"]), dtype=np.float32)
        # Pixel value = position of the class name in category_names (0 is background).
        # Where annotations overlap, the larger pixel value wins.
        for ann in anns:
            pixel_value = category_names.index(get_classname(ann['category_id'], cats))
            masks = np.maximum(self.coco.annToMask(ann)*pixel_value, masks)
        return masks

    def __getitem__(self, index: int):
        # ``index`` is handed to the COCO API as the image id.
        image_infos = self.coco.loadImgs(self.coco.getImgIds(imgIds=index))[0]
        images = self._read_image(image_infos)

        if self.mode == 'test':
            if self.transform is not None:
                images = self.transform(image=images)["image"]
            return images, image_infos

        if self.mode in ('train', 'val'):
            masks = self._build_mask(image_infos)
            # Albumentations transforms image and mask together.
            if self.transform is not None:
                transformed = self.transform(image=images, mask=masks)
                images = transformed["image"]
                masks = transformed["mask"].squeeze()
            return images, masks, image_infos

    def __len__(self) -> int:
        return len(self.coco.getImgIds())

In [None]:
# COCO annotation file of the test split, located under dataset_path.
test_path = os.path.join(dataset_path, 'test.json')

# collate_fn is needed so that the DataLoader can assemble a batch from per-sample tuples
def collate_fn(batch):
    """Transpose a batch: a sequence of per-sample tuples becomes a tuple of per-field tuples."""
    per_field = zip(*batch)
    return tuple(per_field)

# Test-time pipelines: identity, always-horizontal-flip and always-vertical-flip.
test_transform = A.Compose([ToTensor()])
test_tta_transform = A.Compose([A.HorizontalFlip(p=1), ToTensor()])
test_vtta_transform = A.Compose([A.VerticalFlip(p=1), ToTensor()])

# One dataset per pipeline, all reading the same annotation file.
test_dataset = CustomDataset(data_dir=test_path, mode='test', transform=test_transform)
test_tta_dataset = CustomDataset(data_dir=test_path, mode='test', transform=test_tta_transform)
test_vtta_dataset = CustomDataset(data_dir=test_path, mode='test', transform=test_vtta_transform)

# The three loaders share every setting; none of them shuffles, which the
# inference code relies on when it zips them together.
_test_loader_kwargs = dict(batch_size=3,
                           pin_memory=True,
                           num_workers=4,
                           collate_fn=collate_fn)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset, **_test_loader_kwargs)
test_tta_loader = torch.utils.data.DataLoader(dataset=test_tta_dataset, **_test_loader_kwargs)
test_vtta_loader = torch.utils.data.DataLoader(dataset=test_vtta_dataset, **_test_loader_kwargs)

loading annotations into memory...
Done (t=0.01s)
creating index...
index created!
loading annotations into memory...
Done (t=0.01s)
creating index...
index created!
loading annotations into memory...
Done (t=0.01s)
creating index...
index created!


In [None]:
# Ensemble members; the following cells append one trained network each.
models = []

In [None]:
model_path = '/content/drive/MyDrive/segment/saved/realrealfinal/finalep_best_model_eff3.pt'
# model_path = '/content/drive/MyDrive/segment/saved/realfinal/40ep_best_model_resnext50_320.pt'

# Build the network, restore its weights, and only then register it in the
# ensemble. Working on a local variable instead of models[0] keeps the
# checkpoint attached to the right network even when the cell is re-run or
# cells are executed out of order.
model = smp.DeepLabV3Plus(encoder_name='tu-efficientnet_b3', classes=12, encoder_weights='imagenet').to(device)
# The checkpoint file stores a whole pickled model object, hence .state_dict().
checkpoint = torch.load(model_path, map_location=device)
state_dict = checkpoint.state_dict()
model.load_state_dict(state_dict)

# switch to evaluation mode
model.eval()
models.append(model)
print('')




In [None]:
model_path = '/content/drive/MyDrive/segment/saved/realrealfinal/finalep_best_model_seresnext50.pt'
# model_path = '/content/drive/MyDrive/segment/saved/realfinal/40ep_best_model_resnext50_320.pt'

# Build the network, restore its weights, and only then register it in the
# ensemble. Working on a local variable instead of models[1] keeps the
# checkpoint attached to the right network even when the cell is re-run or
# cells are executed out of order.
model = smp.DeepLabV3Plus(encoder_name='tu-seresnext50_32x4d', classes=12, encoder_weights='imagenet').to(device)
# The checkpoint file stores a whole pickled model object, hence .state_dict().
checkpoint = torch.load(model_path, map_location=device)
state_dict = checkpoint.state_dict()
model.load_state_dict(state_dict)

# switch to evaluation mode
model.eval()
models.append(model)
print('')




In [None]:
# path of saved best model
model_path = '/content/drive/MyDrive/segment/saved/realrealfinal/finalep_best_model_eff3_dsize.pt'
# model_path = '/content/drive/MyDrive/segment/saved/realfinal/40ep_best_model_eff3_480.pt'

# Build the network, restore its weights, and only then register it in the
# ensemble. Working on a local variable instead of models[2] keeps the
# checkpoint attached to the right network even when the cell is re-run or
# cells are executed out of order.
# NOTE(review): the checkpoint file name says "eff3" but the encoder built
# here is seresnext50 - confirm that this pairing is intended.
model = smp.DeepLabV3Plus(encoder_name='tu-seresnext50_32x4d', classes=12, encoder_weights='imagenet').to(device)
# The checkpoint file stores a whole pickled model object, hence .state_dict().
checkpoint = torch.load(model_path, map_location=device)
state_dict = checkpoint.state_dict()
model.load_state_dict(state_dict)

# switch to evaluation mode
model.eval()
models.append(model)
print('')




In [None]:
model_path = '/content/drive/MyDrive/segment/saved/realrealfinal/finalep_best_model_eff3_tmp.pt'
# model_path = '/content/drive/MyDrive/segment/saved/realfinal/40ep_best_model_resnext50_480.pt'

# Build the network, restore its weights, and only then register it in the
# ensemble. Working on a local variable instead of models[3] keeps the
# checkpoint attached to the right network even when the cell is re-run or
# cells are executed out of order.
model = smp.DeepLabV3Plus(encoder_name='tu-efficientnet_b3', classes=12, encoder_weights='imagenet').to(device)
# The checkpoint file stores a whole pickled model object, hence .state_dict().
checkpoint = torch.load(model_path, map_location=device)
state_dict = checkpoint.state_dict()
model.load_state_dict(state_dict)

# switch to evaluation mode
model.eval()
models.append(model)
print('')




In [None]:
# path of saved best model
model_path = '/content/drive/MyDrive/segment/saved/realrealfinal/finalep_best_model_eff3_sampler_320.pt'
# model_path = '/content/drive/MyDrive/segment/saved/realfinal/40ep_best_model_eff3_320.pt'

# Build the network, restore its weights, and only then register it in the
# ensemble. Working on a local variable instead of models[4] keeps the
# checkpoint attached to the right network even when the cell is re-run or
# cells are executed out of order.
model = smp.DeepLabV3Plus(encoder_name='tu-efficientnet_b3', classes=12, encoder_weights='imagenet').to(device)
# The checkpoint file stores a whole pickled model object, hence .state_dict().
checkpoint = torch.load(model_path, map_location=device)
state_dict = checkpoint.state_dict()
model.load_state_dict(state_dict)

# switch to evaluation mode
model.eval()
models.append(model)
print('')




In [None]:
# model_path = '/content/drive/MyDrive/segment/saved/realfinal/finalep_best_model_wresnet50_320.pt'
model_path = '/content/drive/MyDrive/segment/saved/realrealfinal/finalep_best_model_eff3_sampler.pt'

# Build the network, restore its weights, and only then register it in the
# ensemble. Working on a local variable instead of models[5] keeps the
# checkpoint attached to the right network even when the cell is re-run or
# cells are executed out of order.
model = smp.DeepLabV3Plus(encoder_name='tu-efficientnet_b3', classes=12, encoder_weights='imagenet').to(device)
# The checkpoint file stores a whole pickled model object, hence .state_dict().
checkpoint = torch.load(model_path, map_location=device)
state_dict = checkpoint.state_dict()
model.load_state_dict(state_dict)

# switch to evaluation mode
model.eval()
models.append(model)
print('')




In [None]:
# Sanity check: run the first ensemble member on one batch of the
# horizontally flipped loader and show the input next to the predicted mask.
models[0].eval()
with torch.no_grad():  # inference only - do not build an autograd graph
    for imgs, image_infos in test_tta_loader:
        temp_images = imgs

        # inference
        outs = models[0](torch.stack(temp_images).to(device))
        # The class axis is dim 1. No squeeze(): it would drop the batch axis
        # of a single-image batch and make dim=1 point at the wrong axis.
        oms = torch.argmax(outs, dim=1).detach().cpu().numpy()

        break

i = 0
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(16, 16))

print('Shape of Original Image :', list(temp_images[i].shape))
print('Shape of Predicted : ', list(oms[i].shape))
print('Unique values, category of transformed mask : \n', [{int(i),category_names[int(i)]} for i in list(np.unique(oms[i]))])

# Input image as fed to the network (channel-first tensor -> channel-last for imshow)
ax1.imshow(temp_images[i].permute([1,2,0]))
ax1.grid(False)
ax1.set_title("Original image : {}".format(image_infos[i]['file_name']), fontsize = 15)

# Predicted mask
ax2.imshow(oms[i])
ax2.grid(False)
ax2.set_title("Predicted : {}".format(image_infos[i]['file_name']), fontsize = 15)

plt.show()

In [None]:
def crf(original_image, mask_img, labelmap):
    """Refine a predicted label mask with a dense CRF and return the refined mask.

    Args:
        original_image: image the mask belongs to. Only its shape is read:
            ``shape[1]`` is used as height and ``shape[2]`` as width (callers
            in this notebook pass channel-first image tensors).
        mask_img: 2-D integer label mask, or a 3-channel integer encoding of one.
        labelmap: class ids present in ``mask_img`` in ascending order, e.g.
            ``list(np.unique(mask_img))``. Entry ``k`` is the class id that
            CRF label ``k`` is translated back to.

    Returns:
        int64 array of shape ``(height, width)`` holding class ids.

    Raises:
        ValueError: if ``labelmap`` is empty.
    """
    n_labels = len(labelmap)
    if n_labels == 0:
        raise ValueError("labelmap must contain at least one class id")

    height, width = original_image.shape[1], original_image.shape[2]

    # A single-class mask has nothing to refine, and the unary term cannot be
    # built for it (it spreads the residual probability over n_labels - 1
    # other classes), so return the constant mask directly.
    if n_labels == 1:
        return np.full((height, width), int(labelmap[0]), dtype=np.int64)

    # Converting annotated image to RGB if it is gray scale
    if len(mask_img.shape) < 3:
        mask_img = gray2rgb(mask_img)

    # Pack the three channels into one integer per pixel ...
    annotated_label = mask_img[:,:,0] + (mask_img[:,:,1]<<8) + (mask_img[:,:,2]<<16)

    # ... and convert those integers to consecutive labels 0, 1, 2, ...
    _, labels = np.unique(annotated_label, return_inverse=True)

    # Setting up the CRF model. DenseCRF2D takes (width, height, n_labels).
    d = dcrf.DenseCRF2D(width, height, n_labels)

    # get unary potentials (neg log probability)
    U = unary_from_labels(labels, n_labels, gt_prob=0.7, zero_unsure=False)
    d.setUnaryEnergy(U)

    # This adds the color-independent term, features are the locations only.
    d.addPairwiseGaussian(sxy=(3, 3), compat=3, kernel=dcrf.DIAG_KERNEL,
                          normalization=dcrf.NORMALIZE_SYMMETRIC)

    # Run inference for 10 steps
    Q = d.inference(10)

    # Find out the most probable label for each pixel.
    MAP = np.argmax(Q, axis=0)

    # Translate CRF labels back to class ids with one lookup. Remapping in
    # place label by label would cascade: with labelmap [0, 2, 5] label 1
    # becomes 2 and is then rewritten to 5 in the next step.
    label_lut = np.asarray(labelmap, dtype=np.int64)
    return label_lut[MAP].reshape((height, width))

In [None]:
def sharpen(p,t=0.5):
    """Raise ``p`` to the power ``t``; a ``t`` of 0 returns ``p`` unchanged."""
    return p if t == 0 else p**t

In [None]:
# Visual check of the CRF post-processing on the 13th test batch.
# Only the networks in demo_models take part in the average, so the scores
# are normalised by that count (not by the size of the whole ensemble);
# otherwise the confidence map condpseudo would be scaled down.
demo_models = [models[1]]
t = 0
with torch.no_grad():  # inference only - do not build an autograd graph
    for imgs, image_infos in test_loader:
        t += 1
        temp_images = imgs
        batch = torch.stack(temp_images).to(device)

        outs = 0
        for model in demo_models:
            model.eval()
            # inference
            logits = model(batch)
            outs = outs + sharpen(torch.softmax(logits, dim=1))
        outs = outs / len(demo_models)

        # Per-pixel best score and best class; the class axis is dim 1.
        condpseudo = torch.amax(outs, dim=1).detach().cpu().numpy()
        oms = torch.argmax(outs, dim=1).detach().cpu().numpy()

        # Refine every mask of the batch, whatever the batch size is.
        crf_output = [crf(temp_images[k], oms[k], list(np.unique(oms[k])))
                      for k in range(len(temp_images))]
        if t==13:
            break

i = 2
fig, (ax1, ax2, ax3) = plt.subplots(nrows=1, ncols=3, figsize=(16, 16))

print('Shape of Original Image :', list(temp_images[i].shape))
print('Shape of Predicted : ', list(oms[i].shape))
print('Unique values, category of transformed mask : \n', [{int(i),category_names[int(i)]} for i in list(np.unique(oms[i]))])
print('Unique crf values, category of transformed mask : \n', [{int(i),category_names[int(i)]} for i in list(np.unique(crf_output[i]))])
# Original image
ax1.imshow(temp_images[i].permute([1,2,0]))
ax1.grid(False)
ax1.set_title("Original image : {}".format(image_infos[i]['file_name']), fontsize = 15)

# Predicted mask
ax2.imshow(oms[i])
ax2.grid(False)
ax2.set_title("Predicted : {}".format(image_infos[i]['file_name']), fontsize = 15)

# CRF mask
ax3.imshow(crf_output[i])
ax3.grid(False)
ax3.set_title("Predicted crf : {}".format(image_infos[i]['file_name']), fontsize = 15)

plt.show()

In [None]:
def test(model, data_loader, device, tta_loader=None, vtta_loader=None):
    """Run flip-TTA ensemble inference and collect submission-ready predictions.

    Args:
        model: one network or a list/tuple of networks to average.
        data_loader: loader yielding ``(images, image_infos)`` for the
            un-augmented test images.
        device: device the batches are moved to.
        tta_loader: loader yielding the same images horizontally flipped, in
            the same order. Defaults to the module-level ``test_tta_loader``.
        vtta_loader: loader yielding the same images vertically flipped, in
            the same order. Defaults to the module-level ``test_vtta_loader``.

    Returns:
        ``(file_names, preds_array, confidence_list, over90_list,
        over80_list, over70_list)`` where ``preds_array`` has one row of
        ``256*256`` class ids per image, ``confidence_list`` holds the mean
        per-pixel best score of each image, and the ``overXX`` lists hold the
        fraction of pixels whose best score exceeds 0.9 / 0.8 / 0.7.
    """
    size = 256
    transform = A.Compose([A.Resize(size, size)])
    print('Start prediction.')

    # Use the networks that were passed in rather than a global list.
    ensemble = list(model) if isinstance(model, (list, tuple)) else [model]
    for net in ensemble:
        net.eval()
    if tta_loader is None:
        tta_loader = test_tta_loader
    if vtta_loader is None:
        vtta_loader = test_vtta_loader

    file_name_list = []
    confidence_list = []
    over90_list = []
    over80_list = []
    over70_list = []
    pred_rows = []

    with torch.no_grad():
        for (imgs, image_infos), (himgs, _), (vimgs, _) in zip(data_loader, tta_loader, vtta_loader):
            batch = torch.stack(imgs).to(device)
            hbatch = torch.stack(himgs).to(device)
            vbatch = torch.stack(vimgs).to(device)

            # Accumulate without a pre-sized buffer so a short final batch and
            # other image sizes work as well.
            preds = 0
            for net in ensemble:
                outs = net(batch)
                # Undo the flips on the logits so all three views line up.
                ttaouts = torch.flip(net(hbatch), dims=[3])
                vttaouts = torch.flip(net(vbatch), dims=[2])
                preds = preds + (sharpen(torch.softmax(outs, dim=1))
                                 + sharpen(torch.softmax(ttaouts, dim=1))
                                 + sharpen(torch.softmax(vttaouts, dim=1))) / 3
            preds = preds / len(ensemble)

            # The class axis is dim 1; no squeeze() so a batch of one keeps its batch axis.
            oms = torch.argmax(preds, dim=1).detach().cpu().numpy()
            condpseudo = torch.amax(preds, dim=1).detach().cpu().numpy()
            confidence = condpseudo.mean(axis=(1,2))
            over90 = (condpseudo>0.9).mean(axis=(1,2))
            over80 = (condpseudo>0.8).mean(axis=(1,2))
            over70 = (condpseudo>0.7).mean(axis=(1,2))

            confidence_list.extend(confidence)
            over90_list.extend(over90)
            over80_list.extend(over80)
            over70_list.extend(over70)

            # resize (256 x 256)
            temp_mask = []
            for img, mask in zip(imgs, oms):
                # The transform needs an image argument; pass the image of the
                # current batch in channel-last layout (only the mask is kept).
                image_hwc = np.ascontiguousarray(img.permute(1, 2, 0).numpy())
                transformed = transform(image=image_hwc, mask=mask)
                temp_mask.append(transformed['mask'])

            oms = np.array(temp_mask)
            oms = oms.reshape([oms.shape[0], size*size]).astype(int)
            pred_rows.append(oms)

            file_name_list.append([i['file_name'] for i in image_infos])
    print("End prediction.")
    file_names = [y for x in file_name_list for y in x]
    # np.int64 instead of np.long, which recent NumPy releases no longer provide.
    if pred_rows:
        preds_array = np.vstack(pred_rows).astype(np.int64)
    else:
        preds_array = np.empty((0, size*size), dtype=np.int64)

    return file_names, preds_array, confidence_list, over90_list, over80_list, over70_list

In [None]:
# inference
# Run the ensemble over the test set.
file_names, preds, conf, o90, o80, o70 = test(models, test_loader, device)

# One row per image: the file name and its flattened mask written as
# space-separated class ids.
prediction_strings = [' '.join(map(str, row.tolist())) for row in preds]
submission = pd.DataFrame()
submission['image_id'] = file_names
submission['PredictionString'] = prediction_strings

# save submission.csv
submission.to_csv("submission.csv", index=False)

Start prediction.
End prediction.


In [None]:
# Attach the per-image confidence statistics returned by test() and summarise them.
for column, values in (('conf', conf), ('o90', o90), ('o80', o80), ('o70', o70)):
    submission[column] = values
submission.describe()

In [None]:
# Keep only confidently predicted images as pseudo-labels: mean best score
# above 0.955 and more than 85% of the pixels scoring above 0.9.
# condidx holds the row labels of the surviving submission rows; afterwards
# preds[k] is the prediction for submission row condidx[k].
# NOTE(review): rebinding preds makes this cell non-idempotent - running it
# twice indexes the already filtered array with the original row labels.
condidx = submission[(submission['conf']>0.955) & (submission['o90']>0.85)].index
preds = preds[condidx]

In [None]:
# Class names in label order, taken from sorted_df (same assignment as in the
# earlier dataset cell).
category_names = list(sorted_df.Categories)
def get_classname(classID, cats):
    """Return the ``name`` of the first entry in ``cats`` whose ``id`` is ``classID``.

    Falls back to the string "None" when no entry matches.
    """
    matches = (cat['name'] for cat in cats if cat['id'] == classID)
    return next(matches, "None")

class CustomDataset(Dataset):
    """COCO-format dataset that can also serve pseudo-labelled test images.

    This definition replaces the ``CustomDataset`` declared earlier in the
    notebook. In every mode ``__getitem__`` returns
    ``(image, mask, image_info)``:

      * 'train' / 'val': the mask is rasterised from the COCO annotations.
      * 'test': the mask is the pseudo-label ``preds[index]`` and the image is
        the one that prediction was made for, ``image_ids[index]``.

    Args:
        data_dir: path of the COCO annotation file.
        mode: 'train', 'val' or 'test'.
        transform: optional Albumentations transform applied to image and mask.
        preds: flattened square pseudo-label masks, one row per kept test
            image. Defaults to the module-level ``preds`` at definition time.
        image_ids: for every row of ``preds``, the index of the image it was
            predicted for. Defaults to the module-level ``condidx`` at
            definition time. Only used in 'test' mode.

    Raises:
        ValueError: in 'test' mode when ``preds`` and ``image_ids`` differ in length.
    """
    def __init__(self, data_dir, mode = 'train', transform = None, preds = preds, image_ids = condidx):
        super().__init__()
        self.mode = mode
        self.transform = transform
        self.coco = COCO(data_dir)
        self.preds = preds
        # After the confidence filter, row k of preds no longer belongs to
        # image k, so the original image index of every row is kept alongside.
        self.image_ids = [int(image_id) for image_id in image_ids]
        if self.mode == 'test' and len(self.image_ids) != len(self.preds):
            raise ValueError("preds and image_ids must have the same length")

    def __getitem__(self, index: int):
        # In test mode the dataset index addresses a pseudo-label row; map it
        # back to the image that row was predicted for.
        image_index = self.image_ids[index] if self.mode == 'test' else index

        # Get the image_info using coco library
        image_id = self.coco.getImgIds(imgIds=image_index)
        image_infos = self.coco.loadImgs(image_id)[0]

        # Load the image using opencv
        images = cv2.imread(os.path.join(dataset_path, image_infos['file_name']))
        images = cv2.cvtColor(images, cv2.COLOR_BGR2RGB).astype(np.float32)
        images /= 255.0

        if (self.mode in ('train', 'val')):
            ann_ids = self.coco.getAnnIds(imgIds=image_infos['id'])
            anns = self.coco.loadAnns(ann_ids)
            # Load the categories in a variable
            cat_ids = self.coco.getCatIds()
            cats = self.coco.loadCats(cat_ids)

            # masks_size : height x width
            masks = np.zeros((image_infos["height"], image_infos["width"]), dtype=np.float32)

            # Pixel value = position of the class name in category_names (0 is background).
            for i in range(len(anns)):
                className = get_classname(anns[i]['category_id'], cats)
                pixel_value = category_names.index(className)
                masks = np.maximum(self.coco.annToMask(anns[i])*pixel_value, masks)

            # We can use Albumentations for image & mask transformation(or augmentation)
            if self.transform is not None:
                transformed = self.transform(image=images, mask=masks)
                images = transformed["image"]
                masks = transformed["mask"]
                masks =  masks.squeeze()

            return images, masks, image_infos

        if self.mode == 'test':
            # Rebuild the square mask from its flattened row.
            flat_mask = np.asarray(self.preds[index])
            side = int(round(np.sqrt(flat_mask.size)))
            masks = flat_mask.reshape((side, side))

            # Pseudo-labels are stored at reduced resolution; bring them back
            # to the image size (nearest neighbour keeps class ids intact) so
            # they can be batched together with annotated training masks.
            height, width = images.shape[:2]
            if masks.shape != (height, width):
                masks = cv2.resize(masks.astype(np.uint8), (width, height),
                                   interpolation=cv2.INTER_NEAREST)
            # float32 like the annotated masks.
            # NOTE(review): the legacy albumentations ToTensor appears to
            # rescale uint8 masks by 1/255 - confirm; float32 avoids that path.
            masks = masks.astype(np.float32)

            if self.transform is not None:
                transformed = self.transform(image=images, mask=masks)
                images = transformed["image"]
                masks = transformed["mask"]
                masks =  masks.squeeze()

            return images, masks, image_infos


    def __len__(self) -> int:
        if self.mode == 'test':
            return len(self.preds)
        return len(self.coco.getImgIds())

In [None]:
# COCO(test_path).getImgIds()

In [None]:
# COCO annotation files under dataset_path: the test split and the
# 'train_all.json' training file.
test_path = os.path.join(dataset_path, 'test.json')
train_all_path = os.path.join(dataset_path, 'train_all.json')

# collate_fn is needed so that the DataLoader can assemble a batch from per-sample tuples
def collate_fn(batch):
    """Transpose a batch: a sequence of per-sample tuples becomes a tuple of per-field tuples."""
    per_field = zip(*batch)
    return tuple(per_field)

# Augmentation pipeline for the annotated training images.
train_transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.9),
    A.RandomGamma(p=0.3),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=45, p=1),
    A.Cutout(num_holes=1, max_h_size=60, max_w_size=60),
    ToTensor(),
])

# Pseudo-labelled test images only get random flips.
test_transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    ToTensor(),
])

# Annotated training images plus pseudo-labelled test images form one training set.
# Note: this rebinds test_dataset and dataset, which earlier cells used for other objects.
train_all_dataset = CustomDataset(data_dir=train_all_path, mode='train', transform=train_transform)
test_dataset = CustomDataset(data_dir=test_path, mode='test', transform=test_transform)
dataset = torch.utils.data.ConcatDataset([train_all_dataset, test_dataset])

train_loader = torch.utils.data.DataLoader(dataset=dataset,
                                           collate_fn=collate_fn,
                                           batch_size=8,
                                           shuffle=True,
                                           drop_last=True,
                                           num_workers=4,
                                           pin_memory=True)

# test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
#                                           batch_size=8,
#                                           pin_memory=True,
#                                           num_workers=4,
#                                           collate_fn=collate_fn)

# test_tta_loader = torch.utils.data.DataLoader(dataset=test_tta_dataset,
#                                           batch_size=3,
#                                           pin_memory=True,
#                                           num_workers=4,
#                                           collate_fn=collate_fn)

# test_vtta_loader = torch.utils.data.DataLoader(dataset=test_vtta_dataset,
#                                           batch_size=3,
#                                           pin_memory=True,
#                                           num_workers=4,
#                                           collate_fn=collate_fn)

loading annotations into memory...
Done (t=13.34s)
creating index...
index created!
loading annotations into memory...
Done (t=0.06s)
creating index...
index created!


In [None]:
# Fetch a single training batch and show its first image next to its mask.
for imgs, masks, image_infos in train_loader:
    temp_images, temp_masks = imgs, masks
    image_infos = image_infos[0]
    break

sample_image = temp_images[0]
sample_mask = temp_masks[0]

fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(12, 12))

print('image shape:', list(sample_image.shape))
print('mask shape: ', list(sample_mask.shape))
print('Unique values, category of transformed mask : \n', [{int(i),category_names[int(i)]} for i in list(np.unique(sample_mask))])

# Left panel: the image (channel-first tensor -> channel-last for imshow).
ax1.imshow(sample_image.permute([1,2,0]))
ax1.grid(False)
ax1.set_title("input image : {}".format(image_infos['file_name']), fontsize = 15)

# Right panel: the label mask.
ax2.imshow(sample_mask)
ax2.grid(False)
ax2.set_title("masks : {}".format(image_infos['file_name']), fontsize = 15)

plt.show()

In [None]:
def train(num_epochs, model, data_loader, criterion,  optimizer, scheduler, saved_dir, i, device):
    """Train ``model`` with gradient accumulation and save it once at the end.

    Args:
        num_epochs: number of passes over ``data_loader``.
        model: network to train; its output is argmax-ed over dim 1.
        data_loader: yields ``(images, masks, infos)`` tuples of per-sample
            tensors; masks hold class ids and are one-hot encoded here.
        criterion: loss called as ``criterion(outputs, one_hot_masks)``.
        optimizer: stepped once every ``accumulation_steps`` batches.
        scheduler: stepped once per epoch.
        saved_dir: directory handed to ``save_model``.
        i: tag handed to ``save_model``.
        device: device the batches are moved to.

    The loss printed every 20 steps is the value already divided by the
    number of accumulation steps.
    """
    accumulation_steps = 4  # batches whose gradients are summed before one optimizer step
    n_class = 12            # number of segmentation classes (background included)

    print('Start training..')
    for epoch in range(num_epochs):
        model.train()
        trn_mIoU = []
        trn_acc = []
        model.zero_grad()
        for step, (images, masks, _) in enumerate(data_loader):
            images = torch.stack(images)       # (batch, channel, height, width)
            masks = torch.stack(masks).long()  # (batch, height, width)
            # One-hot encode the class ids along a new channel axis.
            masks_tensor = masks.view(images.shape[0], 1, images.shape[2], images.shape[3])
            zeros = torch.zeros(images.shape[0], n_class, images.shape[2], images.shape[3], dtype=masks.dtype)
            masks = zeros.scatter_(1, masks_tensor, 1)
            images, masks = images.to(device), masks.to(device)

            outputs = model(images)
            loss = criterion(outputs, masks)
            loss = loss / accumulation_steps
            loss.backward()
            if (step+1) % accumulation_steps == 0:
                # Clip the fully accumulated gradient once, right before the
                # update; clipping after every backward pass would rescale the
                # partial sums and weight the micro-batches unevenly.
                torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
                optimizer.step()
                model.zero_grad()

            # The class axis is dim 1; no squeeze() so a batch of one keeps its batch axis.
            outputs = torch.argmax(outputs, dim=1).detach().cpu().numpy()
            masks = torch.argmax(masks, dim=1).detach().cpu().numpy()
            res = label_accuracy_score(masks, outputs, n_class=n_class)
            trn_mIoU.append(res[2])
            trn_acc.append(res[0])

            # print the loss at 20 step intervals.
            if (step + 1) % 20 == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, lr: {}'.format(
                    epoch+1, num_epochs, step+1, len(data_loader), loss.item(), optimizer.param_groups[0]["lr"]))
        print('Epoch {} - mIoU: {:.4f}, acc: {:.4f}'.format(epoch+1, np.mean(trn_mIoU), np.mean(trn_acc)))
        scheduler.step()
    save_model(model, saved_dir, i)

In [None]:
# define the evaluation function
# https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py
import numpy as np

def _fast_hist(label_true, label_pred, n_class):
    mask = (label_true >= 0) & (label_true < n_class)
    hist = np.bincount(
        n_class * label_true[mask].astype(int) +
        label_pred[mask], minlength=n_class ** 2).reshape(n_class, n_class)
    return hist


def label_accuracy_score(label_trues, label_preds, n_class):
    """Segmentation metrics accumulated over pairs of true and predicted label maps.

    Returns a tuple of
      - overall pixel accuracy
      - mean per-class accuracy
      - mean IoU
      - frequency-weighted IoU
    Classes that never occur yield NaN ratios, which the means skip.
    """
    hist = np.zeros((n_class, n_class))
    for lt, lp in zip(label_trues, label_preds):
        hist += _fast_hist(lt.flatten(), lp.flatten(), n_class)

    hits = np.diag(hist)
    true_totals = hist.sum(axis=1)
    pred_totals = hist.sum(axis=0)
    total = hist.sum()

    acc = hits.sum() / total
    with np.errstate(divide='ignore', invalid='ignore'):
        per_class_acc = hits / true_totals
    acc_cls = np.nanmean(per_class_acc)
    with np.errstate(divide='ignore', invalid='ignore'):
        iu = hits / (true_totals + pred_totals - hits)
    mean_iu = np.nanmean(iu)
    freq = true_totals / total
    present = freq > 0
    fwavacc = (freq[present] * iu[present]).sum()
    return acc, acc_cls, mean_iu, fwavacc

In [None]:
saved_dir = '/content/drive/MyDrive/segment/saved/pseudo'
# makedirs also creates missing parent directories and does nothing when the
# directory is already there.
os.makedirs(saved_dir, exist_ok=True)
    
def save_model(model, saved_dir, i, file_name='best_model_withcond.pt'):
    """Pickle the whole ``model`` to ``saved_dir/<i><file_name>``.

    Args:
        model: network to store.
        saved_dir: target directory; created (with parents) when missing.
        i: tag prepended to ``file_name`` after conversion to ``str``.
        file_name: base name of the output file.

    The complete model object is stored, not just its state dict: the loading
    cells of this notebook call ``.state_dict()`` on the loaded object.
    """
    os.makedirs(saved_dir, exist_ok=True)
    output_path = os.path.join(saved_dir, str(i) + file_name)
    torch.save(model, output_path)

In [None]:
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Variable

try:
    from itertools import ifilterfalse
except ImportError:  # py3k
    from itertools import filterfalse

eps = 1e-6


def dice_round(preds, trues):
    """Soft Dice loss of ``preds`` (cast to float) against ``trues``; no rounding is applied."""
    return soft_dice_loss(preds.float(), trues)


def soft_dice_loss(outputs, targets, per_image=False):
    """Soft Dice loss, ``1 - (2*|A.B| + eps) / (|A| + |B| + eps)``.

    With ``per_image=True`` the loss is computed per element of the first
    axis and averaged; otherwise the whole tensor is treated as one sample.
    """
    smooth = 1e-5
    groups = outputs.size()[0] if per_image else 1
    flat_target = targets.contiguous().view(groups, -1).float()
    flat_output = outputs.contiguous().view(groups, -1)
    overlap = (flat_output * flat_target).sum(dim=1)
    denominator = flat_output.sum(dim=1) + flat_target.sum(dim=1) + smooth
    return (1 - (2 * overlap + smooth) / denominator).mean()


def jaccard(outputs, targets, per_image=False, non_empty=False, min_pixels=5):
    """Soft Jaccard (IoU) loss.

    With ``per_image=True`` the loss is computed per element of the first
    axis. With ``non_empty=True`` (which requires ``per_image``) only samples
    whose target sum exceeds ``min_pixels`` are averaged; if there are none,
    the integer 0 is returned.
    """
    smooth = 1e-3
    groups = outputs.size()[0] if per_image else 1
    flat_target = targets.contiguous().view(groups, -1).float()
    flat_output = outputs.contiguous().view(groups, -1)
    target_sum = flat_target.sum(dim=1)
    intersection = (flat_output * flat_target).sum(dim=1)
    union = torch.sum(flat_output + flat_target, dim=1) - intersection
    losses = 1 - (intersection + smooth) / (union + smooth)

    if not non_empty:
        return losses.mean()

    assert per_image == True
    kept = [losses[k] for k in range(groups) if target_sum[k] > min_pixels]
    if not kept:
        return 0
    return sum(kept) / len(kept)


class DiceLoss(nn.Module):
    """Module wrapper around ``soft_dice_loss``.

    ``weight`` is registered as a buffer and ``size_average`` is stored as an
    attribute; ``forward`` reads neither of them.
    """
    def __init__(self, weight=None, size_average=True, per_image=False):
        super(DiceLoss, self).__init__()
        self.per_image = per_image
        self.size_average = size_average
        self.register_buffer('weight', weight)

    def forward(self, input, target):
        loss = soft_dice_loss(input, target, per_image=self.per_image)
        return loss


class JaccardLoss(nn.Module):
    """Module wrapper around ``jaccard`` with an optional sigmoid on the input.

    ``weight`` is registered as a buffer and ``size_average`` is stored as an
    attribute; ``forward`` reads neither of them.
    """
    def __init__(self, weight=None, size_average=True, per_image=False, non_empty=False, apply_sigmoid=False,
                 min_pixels=5):
        super(JaccardLoss, self).__init__()
        self.per_image = per_image
        self.non_empty = non_empty
        self.min_pixels = min_pixels
        self.apply_sigmoid = apply_sigmoid
        self.size_average = size_average
        self.register_buffer('weight', weight)

    def forward(self, input, target):
        scores = torch.sigmoid(input) if self.apply_sigmoid else input
        return jaccard(scores, target, per_image=self.per_image, non_empty=self.non_empty, min_pixels=self.min_pixels)


class StableBCELoss(nn.Module):
    """Binary cross-entropy on logits, written in the overflow-safe form
    ``max(x, 0) - x*t + log(1 + exp(-|x|))`` and averaged over all elements."""
    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        logits = input.float().contiguous().view(-1)
        labels = target.float().contiguous().view(-1)
        # exp(-|x|) is at most 1, so this term cannot overflow.
        log_term = torch.log(1 + torch.exp(-logits.abs()))
        per_element = logits.clamp(min=0) - logits * labels + log_term
        return per_element.mean()

class FocalLoss2d(nn.Module):
    """Focal loss on probabilities, ``-(1 - pt)**gamma * log(pt)``, averaged
    over all elements whose target differs from ``ignore_index``."""
    def __init__(self, gamma=2, ignore_index=255):
        super().__init__()
        self.ignore_index = ignore_index
        self.gamma = gamma

    def forward(self, outputs, targets):
        eps = 1e-8
        flat_targets = targets.contiguous().view(-1)
        keep = flat_targets != self.ignore_index
        flat_outputs = outputs.contiguous().view(-1)[keep]
        flat_targets = flat_targets[keep].float()
        # Keep both operands away from 0 and 1 so the logarithm stays finite.
        probs = torch.clamp(flat_outputs, eps, 1. - eps)
        labels = torch.clamp(flat_targets, eps, 1. - eps)
        pt = (1 - labels) * (1 - probs) + labels * probs
        focal = -(1. - pt) ** self.gamma * torch.log(pt)
        return focal.mean()

class ComboLoss(nn.Module):
    """Weighted sum of BCE, Dice, focal and Jaccard losses.

    Args:
        weights: dict mapping a loss name (``'bce'``, ``'dice'``, ``'focal'``,
            ``'jaccard'``) to its weight; entries with a falsy weight are skipped.
        per_image: accepted for interface compatibility. NOTE(review): it is
            not forwarded -- the Dice and Jaccard members are always built with
            ``per_image=False``.
        channel_weights: per-channel multipliers for the per-channel losses
            (Dice and Jaccard); must provide one entry per target channel.
        channel_losses: optional per-channel collection of loss names; when
            given, a per-channel loss is only evaluated on channel ``c`` if its
            name is in ``channel_losses[c]``.

    After each ``forward`` call, ``self.values`` holds the unweighted value of
    every loss that was evaluated.
    """

    def __init__(self, weights, per_image=False, channel_weights=(1, 0.5, 0.5), channel_losses=None):
        super().__init__()
        self.weights = weights
        self.bce = StableBCELoss()
        self.dice = DiceLoss(per_image=False)
        self.jaccard = JaccardLoss(per_image=False)
        self.focal = FocalLoss2d()
        self.mapping = {'bce': self.bce,
                        'dice': self.dice,
                        'focal': self.focal,
                        'jaccard': self.jaccard}
        # Losses that consume probabilities rather than raw logits.
        self.expect_sigmoid = {'dice', 'focal', 'jaccard'}
        # Losses evaluated one channel at a time and combined with channel_weights.
        self.per_channel = {'dice', 'jaccard'}
        self.values = {}
        # Copy so that instances never share (or alias) a default/caller sequence.
        self.channel_weights = list(channel_weights)
        self.channel_losses = channel_losses

    def forward(self, outputs, targets):
        """Return the combined loss for logits ``outputs`` against ``targets``.

        The result is clamped from below at ``1e-5``. When no component
        contributes (all weights falsy) the clamped floor is returned as a
        tensor on the same device/dtype as ``outputs``.
        """
        total = 0
        probabilities = torch.sigmoid(outputs)
        for name, weight in self.weights.items():
            if not weight:
                continue
            loss_fn = self.mapping[name]
            predictions = probabilities if name in self.expect_sigmoid else outputs
            if name in self.per_channel:
                val = 0
                for c in range(targets.size(1)):
                    if not self.channel_losses or name in self.channel_losses[c]:
                        val += self.channel_weights[c] * loss_fn(predictions[:, c, ...], targets[:, c, ...])
            else:
                val = loss_fn(predictions, targets)
            self.values[name] = val
            total += weight * val
        if not torch.is_tensor(total):
            # No tensor-valued term was added; a plain number has no ``.clamp``.
            total = outputs.new_tensor(float(total))
        return total.clamp(min=1e-5)

In [None]:
## Over9000 Optimizer . Inspired by Iafoss . Over and Out !
##https://github.com/mgrankin/over9000/blob/master/ralamb.py
import torch, math
from torch.optim.optimizer import Optimizer

# RAdam + LARS
class Ralamb(Optimizer):
    """RAdam update scaled by a layer-wise (LARS/LAMB-style) trust ratio.

    Args:
        params: iterable of parameters or dicts defining parameter groups.
        lr: learning rate.
        betas: coefficients ``(beta1, beta2)`` for the running averages of the
            gradient and of its square.
        eps: term added to the denominator for numerical stability.
        weight_decay: decay factor; each step shrinks the weights by
            ``weight_decay * lr`` before the gradient update.

    Per-parameter state holds ``step``, ``exp_avg``, ``exp_avg_sq`` and, for
    inspection, the most recent ``weight_norm``, ``adam_norm`` and
    ``trust_ratio``.
    """

    def __init__(self, params, lr=1e-2, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-4):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        # Small ring of memoised rectification terms, indexed by ``step % 10``.
        # Each slot is ``[cache_key, N_sma, radam_step_size]``.
        self.buffer = [[None, None, None] for _ in range(10)]
        super(Ralamb, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(Ralamb, self).__setstate__(state)

    def step(self, closure=None):
        """Perform a single optimisation step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if none was given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            beta1, beta2 = group['betas']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('Ralamb does not support sparse gradients')

                # All arithmetic is done on a float32 view/copy of the weights
                # and written back at the end.
                p_data_fp32 = p.data.float()

                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                # Decay the first and second moment running average coefficient
                # m_t
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                # v_t
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                state['step'] += 1
                buffered = self.buffer[int(state['step'] % 10)]

                # The rectification terms depend on the betas as well as on the
                # step, so the betas are part of the cache key; keying on the
                # step alone would hand one group's step size to another group
                # configured with different betas.
                cache_key = (state['step'], beta1, beta2)
                if buffered[0] == cache_key:
                    N_sma, radam_step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = cache_key
                    beta2_t = beta2 ** state['step']
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma

                    # more conservative since it's an approximated value
                    if N_sma >= 5:
                        radam_step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    else:
                        radam_step_size = 1.0 / (1 - beta1 ** state['step'])
                    buffered[2] = radam_step_size

                if group['weight_decay'] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * lr)

                # Tentative RAdam result (weights after a plain RAdam update);
                # only its norm is used, to form the trust ratio below.
                radam_step = p_data_fp32.clone()
                if N_sma >= 5:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    radam_step.addcdiv_(exp_avg, denom, value=-radam_step_size * lr)
                else:
                    radam_step.add_(exp_avg, alpha=-radam_step_size * lr)

                radam_norm = radam_step.pow(2).sum().sqrt()
                weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)
                if weight_norm == 0 or radam_norm == 0:
                    trust_ratio = 1
                else:
                    trust_ratio = weight_norm / radam_norm

                state['weight_norm'] = weight_norm
                state['adam_norm'] = radam_norm
                state['trust_ratio'] = trust_ratio

                # ``trust_ratio`` may be a 0-dim tensor; reduce it to a Python
                # float so the scale is a plain scalar for alpha=/value=.
                scaled_step = -radam_step_size * lr * float(trust_ratio)
                if N_sma >= 5:
                    p_data_fp32.addcdiv_(exp_avg, denom, value=scaled_step)
                else:
                    p_data_fp32.add_(exp_avg, alpha=scaled_step)

                p.data.copy_(p_data_fp32)

        return loss

# Lookahead implementation from https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/lookahead.py

""" Lookahead Optimizer Wrapper.
Implementation modified from: https://github.com/alphadl/lookahead.pytorch
Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610
"""
import torch
from torch.optim.optimizer import Optimizer
from collections import defaultdict

class Lookahead(Optimizer):
    """Wrap an optimizer with Lookahead: k fast steps, then one slow interpolation.

    The wrapper shares ``param_groups`` with ``base_optimizer`` and stores its
    own bookkeeping (``lookahead_alpha``, ``lookahead_k``, ``lookahead_step``)
    inside each of those groups. Slow weights live in ``self.state`` under the
    key ``'slow_buffer'``.

    Args:
        base_optimizer: the optimizer performing the fast updates.
        alpha: slow-weights interpolation factor, in ``[0, 1]``.
        k: number of fast steps between slow updates, ``>= 1``.

    Raises:
        ValueError: if ``alpha`` or ``k`` is out of range.
    """

    def __init__(self, base_optimizer, alpha=0.5, k=6):
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f'Invalid slow update rate: {alpha}')
        if not 1 <= k:
            raise ValueError(f'Invalid lookahead steps: {k}')
        defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
        self.base_optimizer = base_optimizer
        self.param_groups = self.base_optimizer.param_groups
        self.defaults = base_optimizer.defaults
        self.defaults.update(defaults)
        self.state = defaultdict(dict)
        # manually add our defaults to the param groups
        for name, default in defaults.items():
            for group in self.param_groups:
                group.setdefault(name, default)

    def update_slow(self, group):
        """Move the slow weights of ``group`` toward the fast ones and copy them back.

        Parameters without a gradient are skipped. The slow buffer is created
        from the current fast weights the first time a parameter is seen.
        """
        for fast_p in group["params"]:
            if fast_p.grad is None:
                continue
            param_state = self.state[fast_p]
            if 'slow_buffer' not in param_state:
                param_state['slow_buffer'] = torch.empty_like(fast_p.data)
                param_state['slow_buffer'].copy_(fast_p.data)
            slow = param_state['slow_buffer']
            # slow <- slow + alpha * (fast - slow)
            slow.add_(fast_p.data - slow, alpha=group['lookahead_alpha'])
            fast_p.data.copy_(slow)

    def sync_lookahead(self):
        """Run a slow update on every parameter group immediately."""
        for group in self.param_groups:
            self.update_slow(group)

    def step(self, closure=None):
        """Take one fast step; every ``lookahead_k``-th step also update the slow weights.

        Returns:
            Whatever ``base_optimizer.step(closure)`` returns.
        """
        # print(self.k)
        #assert id(self.param_groups) == id(self.base_optimizer.param_groups)
        loss = self.base_optimizer.step(closure)
        for group in self.param_groups:
            group['lookahead_step'] += 1
            if group['lookahead_step'] % group['lookahead_k'] == 0:
                self.update_slow(group)
        return loss

    def state_dict(self):
        """Return the base optimizer's state plus the slow state under ``'slow_state'``.

        Tensor keys of the slow state are replaced by their ``id()``.
        NOTE(review): those ids are not the parameter indices used in
        ``param_groups`` -- confirm that slow buffers round-trip as intended.
        """
        fast_state_dict = self.base_optimizer.state_dict()
        slow_state = {
            (id(k) if isinstance(k, torch.Tensor) else k): v
            for k, v in self.state.items()
        }
        fast_state = fast_state_dict['state']
        param_groups = fast_state_dict['param_groups']
        return {
            'state': fast_state,
            'slow_state': slow_state,
            'param_groups': param_groups,
        }

    def load_state_dict(self, state_dict):
        """Restore the base optimizer and, when present, the slow state.

        A ``state_dict`` saved without Lookahead (no ``'slow_state'`` key) is
        accepted: the slow state starts empty and the Lookahead defaults are
        re-applied to the parameter groups. The ``state_dict`` argument itself
        is not modified.
        """
        fast_state_dict = {
            'state': state_dict['state'],
            'param_groups': state_dict['param_groups'],
        }
        self.base_optimizer.load_state_dict(fast_state_dict)

        # We want to restore the slow state, but share param_groups reference
        # with base_optimizer. This is a bit redundant but least code
        slow_state_new = 'slow_state' not in state_dict
        if slow_state_new:
            print('Loading state_dict from optimizer without Lookahead applied.')
            slow_state = defaultdict(dict)
        else:
            slow_state = state_dict['slow_state']
        slow_state_dict = {
            'state': slow_state,
            'param_groups': state_dict['param_groups'],  # this is pointless but saves code
        }
        super(Lookahead, self).load_state_dict(slow_state_dict)
        self.param_groups = self.base_optimizer.param_groups  # make both ref same container
        if slow_state_new:
            # reapply defaults to catch missing lookahead specific ones
            for name, default in self.defaults.items():
                for group in self.param_groups:
                    group.setdefault(name, default)

def LookaheadAdam(params, alpha=0.5, k=6, *args, **kwargs):
    """Build an Adam optimizer wrapped in ``Lookahead``.

    Args:
        params: parameters (or parameter groups) to optimise.
        alpha: Lookahead slow-weights interpolation factor.
        k: number of fast steps between slow updates.
        *args, **kwargs: forwarded to ``torch.optim.Adam``.

    Returns:
        A ``Lookahead`` instance whose base optimizer is Adam.
    """
    # Refer to Adam through the ``torch`` module imported in this cell, so the
    # factory does not depend on a bare ``Adam`` name being in scope.
    adam = torch.optim.Adam(params, *args, **kwargs)
    return Lookahead(adam, alpha, k)


# RAdam + LARS + LookAHead

# Lookahead implementation from https://github.com/lonePatient/lookahead_pytorch/blob/master/optimizer.py
# RAdam + LARS implementation from https://gist.github.com/redknightlois/c4023d393eb8f92bb44b2ab582d7ec20

def Over9000(params, alpha=0.5, k=6, *args, **kwargs):
    """Build a ``Ralamb`` optimizer wrapped in ``Lookahead``.

    Args:
        params: parameters (or parameter groups) to optimise.
        alpha: Lookahead slow-weights interpolation factor.
        k: number of fast steps between slow updates.
        *args, **kwargs: forwarded to ``Ralamb``.

    Returns:
        A ``Lookahead`` instance whose base optimizer is ``Ralamb``.
    """
    return Lookahead(Ralamb(params, *args, **kwargs), alpha, k)


# Second name for the same factory.
RangerLars = Over9000

In [None]:
# Loss: BCE and focal terms weighted 5, Dice weighted 1; the per-channel Dice
# term uses the same 0.01 multiplier for each of the 12 channels.
criterion = ComboLoss(
    weights={'bce': 5, 'dice': 1, 'focal': 5},
    channel_weights=[0.01] * 12,
)
# Alternative, non-uniform channel weights kept for reference (not in use):
#   channel_weights=[0.0001, 0.1813306 , 0.01049447, 0.00313614, 0.04424467, 0.05193036, 0.04788384, 0.00945399, 0.02173117, 0.00382078, 0.46088194, 0.16509204],
#   channel_losses=0)

# Ralamb + Lookahead optimizer, with a cosine schedule that restarts every 10 scheduler steps.
optimizer = Over9000(model.parameters(), lr=1e-7, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=10,
    T_mult=1,
    eta_min=1e-8,
    last_epoch=-1,
)

In [None]:
# Train the first two models in turn. Each one gets its own optimizer and
# scheduler: an optimizer only updates the parameters it was constructed with,
# so reusing a single optimizer built from some other model would leave the
# model being trained untouched and would carry schedule state across runs.
i = 0
for model in models[:2]:
  optimizer = Over9000(model.parameters(), lr=1e-7, weight_decay=1e-4)
  scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=1, eta_min=1e-8, last_epoch=-1)
  train(10, model, train_loader, criterion, optimizer, scheduler, saved_dir, i, device)
  i+=1