In [21]:
import os
import numpy as np
import pandas as pd
import cv2
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset
from networks.vit_seg_modeling import VisionTransformer as ViT_seg
from networks.vit_seg_modeling import CONFIGS as CONFIGS_ViT_seg
import albumentations as albu
import torch
import base64
from pycocotools import _mask as coco_mask
import typing as t
import zlib
import json

import argparse
import logging
import os
import random
import sys
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from datasets.dataset_synapse import Synapse_dataset
from utils import test_single_volume
from networks.vit_seg_modeling import VisionTransformer as ViT_seg
from networks.vit_seg_modeling import CONFIGS as CONFIGS_ViT_seg
from scipy.ndimage import zoom

In [22]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DEVICE

'cpu'

In [23]:
seed = 1234
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

dataset_config = {
        'HubMap': {
            'Dataset': Synapse_dataset,
            'volume_path': '../data/HubMap/test_h5',
            'list_dir': './lists/lists_HubMap',
            'num_classes': 2,
            'z_spacing': 1,
        },
    }
dataset_name = 'HubMap'
num_classes = dataset_config[dataset_name]['num_classes']
volume_path = dataset_config[dataset_name]['volume_path']
Dataset = dataset_config[dataset_name]['Dataset']
list_dir = dataset_config[dataset_name]['list_dir']
z_spacing = dataset_config[dataset_name]['z_spacing']

img_size = 512
is_pretrain = True
vit_name = 'R50-ViT-B_16'
n_skip = 3
vit_patches_size = 16
max_epochs = 50
batch_size = 4
base_lr = 0.002
tgt_model_epoch = 39
exp = 'TU_' + dataset_name + str(img_size)
snapshot_path = "../model/{}/{}".format(exp, 'TU')
snapshot_path = snapshot_path + '_pretrain' if is_pretrain else snapshot_path
snapshot_path += '_' + vit_name
snapshot_path = snapshot_path + '_skip' + str(n_skip)
snapshot_path = snapshot_path + '_vitpatch' + str(vit_patches_size) if vit_patches_size!=16 else snapshot_path
snapshot_path = snapshot_path + '_epo' + str(max_epochs) if max_epochs != 30 else snapshot_path
if dataset_name == 'ACDC':  # using max_epoch instead of iteration to control training duration
    snapshot_path = snapshot_path + '_' + str(max_iterations)[0:2] + 'k' if max_iterations != 30000 else snapshot_path
snapshot_path = snapshot_path+'_bs'+str(batch_size)
snapshot_path = snapshot_path + '_lr' + str(base_lr) if base_lr != 0.01 else snapshot_path
snapshot_path = snapshot_path + '_'+str(img_size)
snapshot_path = snapshot_path + '_s'+str(seed) if seed!=1234 else snapshot_path

config_vit = CONFIGS_ViT_seg[vit_name]
config_vit.n_classes = num_classes
config_vit.n_skip = n_skip
config_vit.patches.size = (vit_patches_size, vit_patches_size)
if vit_name.find('R50') !=-1:
    config_vit.patches.grid = (int(img_size/vit_patches_size), int(img_size/vit_patches_size))
net = ViT_seg(config_vit, img_size=img_size, num_classes=config_vit.n_classes)

snapshot = os.path.join(snapshot_path, 'best_model.pth')
if not os.path.exists(snapshot): snapshot = snapshot.replace('best_model', 'epoch_'+str(tgt_model_epoch))
#if not os.path.exists(snapshot): snapshot = snapshot.replace('best_model', 'epoch_'+str(14))
print(f'Loaded snapshot path is: {snapshot}')

net.load_state_dict(torch.load(snapshot, map_location=torch.device('cpu')))

Loaded snapshot path is: ../model/TU_HubMap512/TU_pretrain_R50-ViT-B_16_skip3_epo50_bs4_lr0.002_512/epoch_39.pth


<All keys matched successfully>

In [24]:
net

VisionTransformer(
  (transformer): Transformer(
    (embeddings): Embeddings(
      (hybrid_model): ResNetV2(
        (root): Sequential(
          (conv): StdConv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
          (gn): GroupNorm(32, 64, eps=1e-06, affine=True)
          (relu): ReLU(inplace=True)
        )
        (body): Sequential(
          (block1): Sequential(
            (unit1): PreActBottleneck(
              (gn1): GroupNorm(32, 64, eps=1e-06, affine=True)
              (conv1): StdConv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (gn2): GroupNorm(32, 64, eps=1e-06, affine=True)
              (conv2): StdConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (gn3): GroupNorm(32, 256, eps=1e-06, affine=True)
              (conv3): StdConv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (relu): ReLU(inplace=True)
              (downsample): StdConv2d(64, 256, 

In [25]:
DATA_DIR = './'
x_test_dir = os.path.join(DATA_DIR, '../../test')
y_test_dir = os.path.join(DATA_DIR, '../../test')
best_model = net
best_model = best_model.to(DEVICE)
suffix = f'all_test_preproc_trans_unet_img_size_{img_size}_tgt_epoch_{tgt_model_epoch}'
print(suffix)

all_test_preproc_trans_unet_img_size_512_tgt_epoch_39


In [26]:
# helper function for data visualization
def visualize(**images):
    """PLot images in one row."""
    n = len(images)
    plt.figure(figsize=(16, 5))
    for i, (name, image) in enumerate(images.items()):
        plt.subplot(1, n, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.title(' '.join(name.split('_')).title())
        plt.imshow(image)
    plt.show()

In [27]:
class HubMapDataset(BaseDataset):
    """Read images, apply augmentation and preprocessing transformations.
    
    Args:
        images_dir (str): path to images folder
        masks_dir (str): path to segmentation masks folder
        class_values (list): values of classes to extract from segmentation mask
        augmentation (albumentations.Compose): data transfromation pipeline 
            (e.g. flip, scale, etc.)
        preprocessing (albumentations.Compose): data preprocessing 
            (e.g. noralization, shape manipulation, etc.)
    
    """
    
    CLASSES = ['unlabelled', 'blood_vessel']
    
    def __init__(
            self, 
            images_dir, 
            masks_dir, 
            classes=None, 
            augmentation=None, 
            preprocessing=None,
    ):
        self.ids = os.listdir(images_dir)
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]
        
        # convert str names to class values on masks
        self.class_values = [self.CLASSES.index(cls.lower()) for cls in classes]
        
        self.augmentation = augmentation
        self.preprocessing = preprocessing
    
    def __getitem__(self, i):
        
        # read data
        image = cv2.imread(self.images_fps[i])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(self.masks_fps[i], 0)
        
        # extract certain classes from mask (e.g. cars)
        masks = [(mask == v) for v in self.class_values]
        mask = np.stack(masks, axis=-1).astype('float')
        
        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
        
        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
            
        return self.images_fps[i], image, mask
        
    def __len__(self):
        return len(self.ids)

In [28]:
def get_training_augmentation():
    train_transform = [

        albu.HorizontalFlip(p=0.5),

        albu.ShiftScaleRotate(scale_limit=0.5, rotate_limit=0, shift_limit=0.1, p=1, border_mode=0),

        albu.PadIfNeeded(min_height=512, min_width=352, always_apply=True, border_mode=0),
        albu.RandomCrop(height=512, width=352, always_apply=True),

        albu.GaussNoise(p=0.2),
        albu.Perspective(p=0.5),

        albu.OneOf(
            [
                albu.CLAHE(p=1),
                albu.RandomBrightnessContrast(p=1),
                albu.RandomGamma(p=1),
            ],
            p=0.9,
        ),

        albu.OneOf(
            [
                albu.Sharpen(p=1),
                albu.Blur(blur_limit=3, p=1),
                albu.MotionBlur(blur_limit=3, p=1),
            ],
            p=0.9,
        ),

        albu.OneOf(
            [
                albu.RandomBrightnessContrast(p=1),
                albu.HueSaturationValue(p=1),
            ],
            p=0.9,
        ),
    ]
    return albu.Compose(train_transform)


def get_validation_augmentation():
    """Add paddings to make image shape divisible by 32"""
    test_transform = [
        albu.PadIfNeeded(512, 512)
    ]
    return albu.Compose(test_transform)

def to_tensor(image, **kwargs):
  return torch.tensor(image, dtype=torch.float32)

def convert_to_grayscale(image):
  grayscale = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
  max_grayscale_num = grayscale.max()
  min_grayscale_num = grayscale.min()
  grayscale = (grayscale-min_grayscale_num) / float(max_grayscale_num-min_grayscale_num)
  return grayscale

def preprocess_validation_imgs(image, img_size, **kwargs):
  grayscale = convert_to_grayscale(image)
  height, width = grayscale.shape
  if img_size < height:
    grayscale = zoom(grayscale, (img_size / height, img_size / width), order=3)
  return grayscale
  
from functools import partial
def get_preprocessing(preprocessing_fn, img_size):
    """Construct preprocessing transform
    
    Args:
        preprocessing_fn (callbale): data normalization function 
            (can be specific for each pretrained neural network)
    Return:
        transform: albumentations.Compose
    
    """
    partial_func = partial(preprocessing_fn, img_size=img_size)
    _transform = [
        albu.Lambda(image=partial_func),
        albu.Lambda(image=to_tensor, mask=to_tensor),
    ]
    return albu.Compose(_transform)

In [29]:
test_dataset_without_preproc = HubMapDataset(
    x_test_dir, 
    y_test_dir, 
    classes=['unlabelled', 'blood_vessel'],
)

test_dataset = HubMapDataset(
    x_test_dir, 
    y_test_dir, 
    preprocessing=get_preprocessing(preprocess_validation_imgs, img_size),
    classes=['unlabelled', 'blood_vessel'],
)

test_loader = DataLoader(test_dataset, batch_size=8)

In [35]:
import cv2
import numpy as np
import base64
from pycocotools import _mask as coco_mask
import typing as t
import zlib

def extract_polygon_masks(mask):
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    masks = []
    
    for contour in contours:
        epsilon = 0.01 * cv2.arcLength(contour, True)
        approx = cv2.approxPolyDP(contour, epsilon, True)
        
        if approx.shape[0] >= 3:
            polygon_mask = np.zeros_like(mask)
            cv2.drawContours(polygon_mask, [approx], 0, 1, -1)
            masks.append(polygon_mask.astype('bool'))
    
    return masks

def encode_binary_mask(mask: np.ndarray) -> t.Text:
  """Converts a binary mask into OID challenge encoding ascii text."""

  # check input mask --
  if mask.dtype != np.bool_:
    raise ValueError(
        "encode_binary_mask expects a binary mask, received dtype == %s" %
        mask.dtype)

  mask = np.squeeze(mask)
  if len(mask.shape) != 2:
    raise ValueError(
        "encode_binary_mask expects a 2d mask, received shape == %s" %
        mask.shape)
  
  # convert input mask to expected COCO API input --
  mask_to_encode = mask.reshape(mask.shape[0], mask.shape[1], 1)
  mask_to_encode = mask_to_encode.astype(np.uint8)
  mask_to_encode = np.asfortranarray(mask_to_encode)

  # RLE encode mask --
  encoded_mask = coco_mask.encode(mask_to_encode)[0]["counts"]

  # compress and base64 encoding --
  binary_str = zlib.compress(encoded_mask, zlib.Z_BEST_COMPRESSION)
  base64_str = base64.b64encode(binary_str)
  return base64_str

def decode_binary_mask(encoded_mask):
    # Decode base64 and decompress the binary string
    binary_str = base64.b64decode(encoded_mask)
    decompressed_str = zlib.decompress(binary_str)

    # Decode RLE-encoded mask
    encoded_mask = np.frombuffer(decompressed_str, dtype=np.uint8)
    decoded_mask = coco_mask.decode({"counts": encoded_mask})

    # Convert COCO API format to binary mask
    mask = np.squeeze(decoded_mask)

    return mask

In [31]:
def overlay_mask(image, mask):
    # Make a copy of the original image to avoid modifying it
    overlay = np.copy(image)
    
    # Convert the binary mask to a boolean mask
    mask = mask.astype(bool)
    
    # Set the red channel of the overlay where the mask is True to 255
    overlay[mask, 0] = 255
    
    # Set the green and blue channels of the overlay where the mask is True to 0
    overlay[mask, 1:] = 0
    
    return overlay

In [None]:
from PIL import Image
import matplotlib.pyplot as plt
import torch
import numpy as np
from torchvision import models
from torchvision import transforms

best_model.eval()
for i in range(20):
  image_vis = test_dataset_without_preproc[i][1].astype('uint8')
  _, image, gt_mask = test_dataset[i]

  gt_mask = gt_mask.squeeze()
  
  pr_mask = best_model(image.unsqueeze(0).unsqueeze(0))
  pr_mask = torch.softmax(pr_mask, dim=1)[:,1,:,:].squeeze().detach().numpy()
  pr_mask_height, pr_mask_width = pr_mask.shape
  if pr_mask_height != image_vis.shape[0]:
    pr_mask_zoomed = zoom(pr_mask, (float(image_vis.shape[0]/pr_mask_height), float(image_vis.shape[1]/pr_mask_width)), order=3)
  else:
    pr_mask_zoomed = pr_mask
  pr_mask_zoomed = (pr_mask_zoomed>0.5).astype('uint8')
  print(np.unique(pr_mask_zoomed), pr_mask_zoomed.shape)
  print(image_vis.shape, gt_mask.shape)
  masked_image_gt = overlay_mask(image_vis, gt_mask[:,:,1].numpy())
  print(pr_mask_zoomed.shape)
  masked_image_pr = overlay_mask(image_vis, pr_mask_zoomed)

  visualize(
      image_orig = image_vis,
      image=image, 
      ground_truth_mask=masked_image_gt, 
      predicted_mask=pr_mask_zoomed
  )

In [32]:
import time
def generate_submission(model, device, dataloader, suffix):
    model.eval()
    num_batches = len(dataloader)
    print(f'Processing a total of {num_batches} images for submission')
    submission_dicts = []
    start_time = time.time()
    # Disable gradient calculation
    with torch.no_grad():
        # Iterate over the validation dataset
        for batch_idx, (img_file, inputs, targets) in enumerate(dataloader):
            model_inference_start_time = time.time()
            inputs = inputs.to(device)
            outputs = model(inputs.unsqueeze(1))
            print(f'Takes {float(time.time()-model_inference_start_time)} seconds to inference for 8 samples')
            post_proc_start_time = time.time()
            outputs = torch.softmax(outputs, dim=1)[:,1,:,:].detach().numpy()
            output_batch_size, outputs_height, outputs_width = outputs.shape
            if outputs_height != 512:
              outputs_zoomed = zoom(outputs, (1, float(512/outputs_height), float(512/outputs_width)), order=3)
            else:
              outputs_zoomed = outputs
            outputs_zoomed_thresh = (outputs_zoomed>0.5).astype('uint8')
            print(outputs.shape, outputs_zoomed.shape, outputs_zoomed_thresh.shape)
            for i in range(len(outputs_zoomed_thresh)):
              cur_dict = dict()
              img_id = img_file[i].split('/')[-1].split('.')[0]
              cur_dict['id'] = img_id
              cur_dict['height'] = 512
              cur_dict['width'] = 512
              prediction_string = ''
              polygon_masks = extract_polygon_masks(outputs_zoomed_thresh[i,:,:])
              for polygon_mask in polygon_masks:
                polygon_mask_conf = round(((polygon_mask * outputs_zoomed[i,:,:]).sum())/(polygon_mask.sum()), 2)
                polygon_mask_string = encode_binary_mask(polygon_mask).decode('utf-8')
                prediction_string += f'0 {polygon_mask_conf} {polygon_mask_string} '
              cur_dict['prediction_string'] = prediction_string.strip()
              submission_dicts.append(cur_dict)
            print(f'Takes {float(time.time()-post_proc_start_time)} seconds for post processing 8 samples')
            if (batch_idx+1) % 10 == 0:
              print(f'On batch {batch_idx} and finished in {float(time.time()-start_time)/60} minutes')
        submission_df = pd.DataFrame.from_dict(submission_dicts)
        submission_df.to_csv(f'./submissions/submission_{suffix}.csv', index=False)

In [33]:
if not os.path.exists('./submissions'):
  os.mkdir('./submissions')

In [34]:
generate_submission(best_model, DEVICE, test_loader, suffix)

Processing a total of 1 images for submission
Takes 2.8165123462677 seconds to inference for 8 samples
(1, 512, 512) (1, 512, 512) (1, 512, 512)
Takes 0.015571832656860352 seconds for post processing 8 samples


In [None]:
def generate_ground_truth_map_files(tiles_dicts_new, x_test_dir, suffix):
  bbox_dicts = []
  segfile_dicts = []
  labels_info = set()
  img_width = 512
  img_height = 512
  print(f'Processing a total of {len(tiles_dicts_new)} tiles')
  start_time = time.time()
  for idx, tiles_dict in enumerate(tiles_dicts_new):
    img_id = tiles_dict['id']
    base_image = cv2.imread(f'{x_test_dir}/{img_id}.png')
    for annot in tiles_dict['annotations']:
      if annot['type'] == 'blood_vessel':
        blood_vessel_masked_image = np.zeros((512, 512))
        cur_dict = dict()
        cur_segfile_dict = dict()
        cur_dict['ImageID'] = img_id
        cur_dict['LabelName'] = annot['type']
        coords = annot['coordinates'][0]
        cv2.fillPoly(blood_vessel_masked_image, pts=[np.array(coords)], color=1)
        encoded_mask = encode_binary_mask(blood_vessel_masked_image.astype('bool')).decode('utf-8')
        x_vals = [x[0] for x in coords]
        y_vals = [x[1] for x in coords]
        x_min = float(min(x_vals))/img_width
        x_max = float(max(x_vals))/img_width
        y_min = float(min(y_vals))/img_height
        y_max = float(max(y_vals))/img_height
        cur_dict['XMin'] = x_min
        cur_dict['XMax'] = x_max
        cur_dict['YMin'] = y_min
        cur_dict['YMax'] = y_max
        cur_dict['IsGroupOf'] = 0
        cur_segfile_dict['ImageID'] = img_id
        cur_segfile_dict['LabelName'] = annot['type']
        cur_segfile_dict['ImageWidth'] = img_width
        cur_segfile_dict['ImageHeight'] = img_height
        cur_segfile_dict['XMin'] = x_min
        cur_segfile_dict['XMax'] = x_max
        cur_segfile_dict['YMin'] = y_min
        cur_segfile_dict['YMax'] = y_max
        cur_segfile_dict['IsGroupOf'] = 0
        cur_segfile_dict['Mask'] = encoded_mask
        bbox_dicts.append(cur_dict)
        segfile_dicts.append(cur_segfile_dict)
        cur_labels_info = (img_id, annot['type'], 1)
        if cur_labels_info not in labels_info:
          labels_info.add(cur_labels_info)
    if idx % 50 == 0:
      print(f'Finished {idx} tiles in {float(time.time()-start_time)/60} minutes')
  bbox_dicts_df = pd.DataFrame.from_dict(bbox_dicts)
  bbox_dicts_df.to_csv(f'./map_input_data/segmentation_bbox_{suffix}.csv', index=False)
  labels_info_df = pd.DataFrame(list(labels_info), columns=['ImageID', 'LabelName', 'Confidence'])
  labels_info_df.to_csv(f'./map_input_data/segmentation_labels_{suffix}.csv', index=False)
  segfile_df = pd.DataFrame.from_dict(segfile_dicts)
  segfile_df.to_csv(f'./map_input_data/segmentation_masks_{suffix}.csv', index=False)

In [None]:
if not os.path.exists('./map_input_data'):
  os.mkdir('./map_input_data')

In [None]:
with open(f'./polygons.jsonl', 'r') as json_file:
    json_list = list(json_file)
    
tiles_dicts = []
for json_str in json_list:
    tiles_dicts.append(json.loads(json_str))

In [None]:
x_test_dir

In [None]:
tgt_ids = [x.split('.')[0] for x in os.listdir(x_test_dir)]
tiles_dicts_new = [x for x in tiles_dicts if x['id'] in tgt_ids]

In [None]:
generate_ground_truth_map_files(tiles_dicts_new, x_test_dir, suffix)

In [None]:
def create_pred_parts(pred_string):
  pred_parts = pred_string.split()
  pred_parts_arr = []
  for i in range(0, len(pred_parts), 3):
    pred_parts_arr.append(pred_parts[i:i+3])
  return pred_parts_arr

In [None]:
suffix

In [None]:
def generate_prediction_map_file(submission_df, suffix):
  non_empty_pred_mask = submission_df.apply(lambda x: x['prediction_string']!='', axis=1)
  submission_df = submission_df[non_empty_pred_mask].dropna(subset=['prediction_string'], axis=0)
  submission_df = submission_df.rename(columns={'id': 'ImageID', 'height': 'ImageHeight', 'width': 'ImageWidth'})
  submission_df['prediction_string'] = submission_df['prediction_string'].apply(create_pred_parts)
  submission_df = submission_df.explode('prediction_string')
  submission_df[['LabelName', 'Score', 'Mask']] = submission_df['prediction_string'].apply(pd.Series)
  submission_df['LabelName'] = submission_df['LabelName'].astype(str).replace('0', 'blood_vessel')
  submission_df = submission_df.drop('prediction_string', axis=1)
  submission_df.to_csv(f'./map_input_data/seg_preds_{suffix}.csv', index=False)

In [None]:
submission_df = pd.read_csv(f'./submissions/submission_{suffix}.csv')
generate_prediction_map_file(submission_df, suffix)

In [None]:
seg_preds_df = pd.read_csv(f'./map_input_data/seg_preds_{suffix}.csv')
seg_preds_df.head()

In [None]:
seg_preds_df = seg_preds_df.loc[seg_preds_df['Score']>=0.7]
seg_preds_df.to_csv(f'./map_input_data/seg_preds_{suffix}_test.csv', index=False)
seg_preds_df.head()

In [None]:
import seaborn as sns
print(seg_preds_df.shape, seg_preds_df.loc[seg_preds_df['Score']>0.7].shape, seg_preds_df.loc[seg_preds_df['Score']>0.75].shape)
sns.boxplot(x=seg_preds_df["Score"])

In [79]:
import pycocotools
def encode_binary_mask(mask: np.ndarray) -> t.Text:
  """Converts a binary mask into OID challenge encoding ascii text."""

  # check input mask --
  if mask.dtype != np.bool_:
    raise ValueError(
        "encode_binary_mask expects a binary mask, received dtype == %s" %
        mask.dtype)

  mask = np.squeeze(mask)
  if len(mask.shape) != 2:
    raise ValueError(
        "encode_binary_mask expects a 2d mask, received shape == %s" %
        mask.shape)
  
  # convert input mask to expected COCO API input --
  mask_to_encode = mask.reshape(mask.shape[0], mask.shape[1], 1)
  mask_to_encode = mask_to_encode.astype(np.uint8)
  mask_to_encode = np.asfortranarray(mask_to_encode)

  # RLE encode mask --
  encoded_mask = coco_mask.encode(mask_to_encode)[0]["counts"]

  # compress and base64 encoding --
  binary_str = zlib.compress(encoded_mask, zlib.Z_BEST_COMPRESSION)
  base64_str = base64.b64encode(binary_str)
  return base64_str

def decode_binary_mask(encoded_mask):
    # Decode base64 and decompress the binary string
    binary_str = base64.b64decode(encoded_mask)
    print(binary_str)
    decompressed_str = zlib.decompress(binary_str)
    print(decompressed_str)
    
    # Decode RLE-encoded mask
    encoded_mask = np.frombuffer(decompressed_str, dtype=np.uint8)
    print(encoded_mask)
    decoded_mask = coco_mask.decode(decompressed_str)
    print(decoded_mask)

    # Convert COCO API format to binary mask
    mask = np.squeeze(decoded_mask)

    return mask

In [81]:
input_str = 'eNrLDo8xM8yzN/YzMDTwNzD0D8lIMAQAO8IFkQ=='
mask = decode_binary_mask(input_str.encode('utf-8'))

b'x\xda\xcb\x0e\x8f13\xcc\xb37\xf6304\xf070\xf4\x0f\xc9H0\x04\x00;\xc2\x05\x91'
b'kW\\61n?3N010O01OTh`1'
[107  87  92  54  49 110  63  51  78  48  49  48  79  48  49  79  84 104
  96  49]


TypeError: 'int' object is not subscriptable

In [None]:
import time
from multiprocessing import Pool
def generate_submission_multiproc(model, device, dataloader, suffix):
  model.eval()
  num_batches = len(dataloader)
  print(f'Processing a total of {num_batches} images for submission')
  submission_dicts = []
  start_time = time.time()
  cur_batch = []
  batch_size = 8
  with torch.no_grad():
    for batch_idx, (img_file, inputs, targets) in enumerate(dataloader):
      if (batch_idx+1) % batch_size == 0:
        with Pool(8) as pool:
          result = pool.starmap(run_inference_multiproc, cur_batch)
          return result
      else:
        cur_batch.append((model, device, img_file, inputs))


def run_inference_multiproc(model, device, img_file, inputs):
    with torch.no_grad():
      cur_dict = dict()
      img_id = img_file[0].split('/')[-1].split('.')[0]
      cur_dict['id'] = img_id
      cur_dict['height'] = 512
      cur_dict['width'] = 512
      prediction_string = ''
      inputs = inputs.to(device)
      outputs = model(inputs.unsqueeze(1))
      outputs = torch.softmax(outputs, dim=1)[:,1,:,:].squeeze().detach().numpy()
      outputs_height, outputs_width = outputs.shape
      if outputs_height != cur_dict['height']:
        outputs_zoomed = zoom(outputs, (float(cur_dict['height']/outputs_height), float(cur_dict['width']/outputs_width)), order=3)
      else:
        outputs_zoomed = outputs
      outputs_zoomed_thresh = (outputs_zoomed>0.5).astype('uint8')
      polygon_masks = extract_polygon_masks(outputs_zoomed_thresh)
      for polygon_mask in polygon_masks:
        polygon_mask_conf = round(((polygon_mask * outputs_zoomed).sum())/(polygon_mask.sum()), 2)
        polygon_mask_string = encode_binary_mask(polygon_mask).decode('utf-8')
        prediction_string += f'0 {polygon_mask_conf} {polygon_mask_string} '
      cur_dict['prediction_string'] = prediction_string.strip()
    return cur_dict

In [None]:
result = generate_submission_multiproc(best_model, DEVICE, test_loader, suffix)