In [1]:
!pip install -q segmentation_models_pytorch

In [2]:
from pathlib import Path
import os 
import pandas as pd 
import numpy as np 
import matplotlib.pyplot as plt
import cv2
import random
import torch
from tqdm import tqdm
from torch import nn
from glob import glob

In [3]:
class CFG:
    img_size = [224, 224]
    valid_bs = 64
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    thr = 0.5

In [4]:
from torch.utils.data import Dataset, DataLoader 
import albumentations as A

data_transforms = {
    "train": A.Compose([
        A.Resize(*CFG.img_size, interpolation=cv2.INTER_NEAREST),
#         A.HorizontalFlip(p=0.5),
# #         A.VerticalFlip(p=0.5),
#         A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.05, rotate_limit=10, p=0.5),
#         A.OneOf([
#             A.GridDistortion(num_steps=5, distort_limit=0.05, p=1.0),
# # #             A.OpticalDistortion(distort_limit=0.05, shift_limit=0.05, p=1.0),
#             A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=1.0)
#         ], p=0.25),
#         A.CoarseDropout(max_holes=8, max_height=CFG.img_size[0]//20, max_width=CFG.img_size[1]//20,
#                          min_holes=5, fill_value=0, mask_fill_value=0, p=0.5),
        ], p=1.0),
    
    "valid": A.Compose([
        A.Resize(*CFG.img_size, interpolation=cv2.INTER_NEAREST),
        ], p=1.0)
}

def load_img(path):
    img = cv2.imread(path, cv2.IMREAD_COLOR)  # read as BGR by default
    if img is None:
        raise ValueError(f"Image at {path} could not be loaded.")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # convert to RGB

    img = img.astype('float32')  # convert from uint8 to float32
    mx = np.max(img)
    if mx:
        img /= mx  # scale to [0, 1]
    return img
    
class ShipDataset(Dataset):
    def __init__(self, image_list, transforms=None):
        self.images = image_list 
        self.transforms = transforms

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        img = load_img(img_path)
        if self.transforms:
            data = self.transforms(image=img)
            img = data['image']
        img = np.transpose(img, (2, 0, 1))
        return torch.tensor(img, dtype=torch.float32), img_path

In [5]:
import segmentation_models_pytorch as smp

def build_model():
    model = smp.Unet(
        encoder_name="efficientnet-b1",      # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
        encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
        in_channels=3,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
        classes=1,        # model output channels (number of classes in your dataset)
        activation=None,
    )
    model.to(CFG.device)
    return model

def load_model(path):
    model = build_model()
    model.load_state_dict(torch.load(path, map_location='cpu'))
    model.eval()
    return model

In [6]:
def mask2rle(mask):
    mask = cv2.resize(mask, (768, 768), interpolation=cv2.INTER_NEAREST)
    '''
    mask: numpy array, 1 - mask, 0 - background
    Returns run length as string formatted
    '''
    pixels = mask.T.flatten() 
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] = runs[1::2] - runs[:-1:2]
    rle = ' '.join(str(x) for x in runs)
    return rle

def masks2rles(msks, ids):
    pred_strings = []
    for idx in range(msks.shape[0]):
        msk = cv2.resize(msks[idx], 
                         dsize=(768, 768), 
                         interpolation=cv2.INTER_NEAREST) # back to original shape
        rle = mask2rle(msk)
        pred_strings.append(rle)
    return pred_strings

def rle2mask(rle):
    shape = (768,768)
    img = np.zeros(shape[0]*shape[1], dtype=np.uint8)
    if(type(rle) == float): return img.reshape(rle)
    if(type(rle) == str): masks = [rle]
    for mask in masks:
        s = mask.split()
        
        for i in range(len(s)//2):
            start = int(s[2*i]) - 1
            length = int(s[2*i+1])
            img[start:start+length] = 1
    return img.reshape(shape).T

In [7]:
@torch.no_grad()
def infer(model, test_loader, num_log=1, thr=CFG.thr):
    pred_strings = []; pred_ids = []; pred_classes = [];
    for idx, (img, ids) in enumerate(tqdm(test_loader, total=len(test_loader), desc='Infer ')):
        img = img.to(CFG.device, dtype=torch.float) # .squeeze(0)
        size = img.size()
        msk = torch.zeros((size[0], size[2], size[3]), device=CFG.device, dtype=torch.float32) # (32, 3, 224, 224)
        logit = model(img)
        pred = logit.squeeze()
        pred = nn.Sigmoid()(pred)
        msk += pred
        msk = (msk > CFG.thr).to(torch.uint8).cpu().detach().numpy()
        rles = masks2rles(msk, idx)
        pred_strings.extend(rles)
        pred_ids.extend(ids)
    return pred_strings, pred_ids

In [8]:
test_images = glob('/kaggle/input/airbus-ship-detection/test_v2/*.jpg')
len(test_images)

15606

In [9]:
model_path = "/kaggle/input/ship-segmentation/pytorch/default/1/best_epoch-00.bin"

model = load_model(model_path)

config.json:   0%|          | 0.00/106 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/31.5M [00:00<?, ?B/s]

In [10]:
test_ds = ShipDataset(test_images, transforms=data_transforms['valid'])
test_dl  = DataLoader(test_ds, batch_size=CFG.valid_bs, num_workers=4, shuffle=False, pin_memory=False)

In [11]:
pred_strings, pred_ids = infer(model, test_dl)

Infer : 100%|██████████| 244/244 [02:00<00:00,  2.03it/s]


In [12]:
df = pd.read_csv('/kaggle/input/airbus-ship-detection/sample_submission_v2.csv')
print(df.shape)
df.head()

(15606, 2)


Unnamed: 0,ImageId,EncodedPixels
0,00002bd58.jpg,1 2
1,00015efb6.jpg,1 2
2,00023d5fc.jpg,1 2
3,000367c13.jpg,1 2
4,0008ca6e9.jpg,1 2


In [13]:
sub = pd.DataFrame({'ImageId': pred_ids, 'EncodedPixels': pred_strings})
sub.head()

Unnamed: 0,ImageId,EncodedPixels
0,/kaggle/input/airbus-ship-detection/test_v2/42...,
1,/kaggle/input/airbus-ship-detection/test_v2/b0...,193219 17 193987 17 194755 17 195519 24 196287...
2,/kaggle/input/airbus-ship-detection/test_v2/9f...,
3,/kaggle/input/airbus-ship-detection/test_v2/91...,
4,/kaggle/input/airbus-ship-detection/test_v2/6f...,


In [14]:
sub['ImageId'] = sub['ImageId'].apply(lambda x: x.split('/')[-1])
sub.to_csv('submission.csv', index=False)

In [15]:
df = pd.read_csv('submission.csv')
df.head()

Unnamed: 0,ImageId,EncodedPixels
0,4291f3a66.jpg,
1,b0808caaf.jpg,193219 17 193987 17 194755 17 195519 24 196287...
2,9f582d5ce.jpg,
3,916ae8dd3.jpg,
4,6fa533973.jpg,
