# Inferece for segmentation with Lightning⚡Flash

**This is just inference for this training notebook: https://www.kaggle.com/code/jirkaborovec/tract-segm-eda-flash-deeplab-albumentatio**

See also: [Easy Kaggle Offline Submission With Chaining Kernel Notebooks](https://towardsdatascience.com/easy-kaggle-offline-submission-with-chaining-kernels-30bba5ea5c4d)

## Install dependencies

In [1]:
!pip uninstall -y torchtext
# !pip install -q --upgrade torch torchvision
!pip install -q "lightning-flash[image]" "torchmetrics<0.8" --no-index --find-links ../input/demo-flash-semantic-segmentation/frozen_packages
!pip install -q -U timm segmentation-models-pytorch --no-index --find-links ../input/demo-flash-semantic-segmentation/frozen_packages
!pip install -q 'kaggle-imsegm' --no-index --find-links ../input/tract-segm-eda-3d-interactive-viewer/frozen_packages

! pip list | grep torch
! pip list | grep lightning
! nvidia-smi -L

Found existing installation: torchtext 0.10.1
Uninstalling torchtext-0.10.1:
  Successfully uninstalled torchtext-0.10.1
efficientnet-pytorch                  0.6.3
pytorch-ignite                        0.4.8
pytorch-lightning                     1.5.10
segmentation-models-pytorch           0.2.1
torch                                 1.9.1
torchaudio                            0.9.1
torchmetrics                          0.6.2
torchvision                           0.10.1
lightning-bolts                       0.5.0
lightning-flash                       0.7.3
pytorch-lightning                     1.5.10
GPU 0: Tesla P100-PCIE-16GB (UUID: GPU-3a7cc727-0116-4b4f-056e-ff9639c20dc6)


In [2]:
import os, glob
import pandas as pd
import matplotlib.pyplot as plt

DATASET_FOLDER = "/kaggle/input/uw-madison-gi-tract-image-segmentation"
DATASET_IMAGES = "/kaggle/temp/dataset-flash/images"

df_train = pd.read_csv(os.path.join(DATASET_FOLDER, "train.csv"))
display(df_train.head())

LABELS = sorted(df_train["class"].unique())
print(LABELS)

Unnamed: 0,id,class,segmentation
0,case123_day20_slice_0001,large_bowel,
1,case123_day20_slice_0001,small_bowel,
2,case123_day20_slice_0001,stomach,
3,case123_day20_slice_0002,large_bowel,
4,case123_day20_slice_0002,small_bowel,


['large_bowel', 'small_bowel', 'stomach']


## Reuse augmentation and Trainer...

In [3]:
import torch

import flash
from flash.core.data.utils import download_data
from flash.image import SemanticSegmentation, SemanticSegmentationData

In [4]:
from dataclasses import dataclass
from typing import Any, Callable, Dict, Mapping, Sequence, Tuple, Union
import albumentations as alb
from flash.core.data.io.input_transform import InputTransform
from flash.image.segmentation.input_transform import prepare_target, remove_extra_dimensions
from kaggle_imsegm.augment import FlashAlbumentationsAdapter

@dataclass
class SemanticSegmentationInputTransform(InputTransform):
    # https://albumentations.ai/docs/examples/pytorch_semantic_segmentation

    image_size: Tuple[int, int] = (128, 128)

    def train_per_sample_transform(self) -> Callable:
        return FlashAlbumentationsAdapter([
            alb.Resize(*self.image_size),
            alb.VerticalFlip(p=0.5),
            alb.HorizontalFlip(p=0.5),
            alb.RandomRotate90(p=0.5),
            alb.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.05, rotate_limit=15, p=0.5),
            alb.GaussNoise(var_limit=(0.00, 0.03), mean=0, per_channel=False, p=1.0),
            #alb.ElasticTransform(p=1, alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03)
            #alb.RGBShift(r_shift_limit=25, g_shift_limit=25, b_shift_limit=25, p=0.5),
            #alb.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5),
        ])

    def per_sample_transform(self) -> Callable:
        return FlashAlbumentationsAdapter([alb.Resize(*self.image_size)])

    def predict_input_per_sample_transform(self) -> Callable:
        return FlashAlbumentationsAdapter([alb.Resize(*self.image_size)])

    def target_per_batch_transform(self) -> Callable:
        return prepare_target

    def predict_per_batch_transform(self) -> Callable:
        return remove_extra_dimensions

    def serve_per_batch_transform(self) -> Callable:
        return remove_extra_dimensions

In [5]:
trainer = flash.Trainer(gpus=torch.cuda.device_count())

In [6]:
!ls -l ../input/tract-segm-eda-flash-deeplab-albumentatio/*.pt

model = SemanticSegmentation.load_from_checkpoint(
    "../input/tract-segm-eda-flash-deeplab-albumentatio/semantic_segmentation_model.pt"
)

-rw-r--r-- 1 nobody nogroup 469840399 Apr 30 03:22 ../input/tract-segm-eda-flash-deeplab-albumentatio/semantic_segmentation_model.pt


## Parse sample submissison

In [7]:
df_pred = pd.read_csv(os.path.join(DATASET_FOLDER, "sample_submission.csv"))
sfolder = "test"
display(df_pred.head())

if df_pred.empty:
    sfolder = "train"
    df_pred = df_train[df_train["id"].str.startswith("case123_day")]

os.makedirs(os.path.join(DATASET_IMAGES, sfolder), exist_ok=True)

Unnamed: 0,id,class,predicted


In [8]:
from pprint import pprint
from kaggle_imsegm.data import extract_tract_details

pprint(extract_tract_details(df_pred['id'].iloc[0], DATASET_FOLDER, folder=sfolder))

df_pred[['Case','Day','Slice', 'image', 'image_path', 'height', 'width']] = df_pred['id'].apply(
    lambda x: pd.Series(extract_tract_details(x, DATASET_FOLDER, folder=sfolder))
)
df_pred["Case_Day"] = [f"case{r['Case']}_day{r['Day']}" for _, r in df_pred.iterrows()]
display(df_pred.head())

{'Case': 123,
 'Day': 20,
 'Slice': '0001',
 'height': 266,
 'image': 'slice_0001_266_266_1.50_1.50.png',
 'image_path': 'train/case123/case123_day20/scans/slice_0001_266_266_1.50_1.50.png',
 'width': 266}


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self[k1] = value[k2]
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  if __name__ == '__main__':


Unnamed: 0,id,class,segmentation,Case,Day,Slice,image,image_path,height,width,Case_Day
0,case123_day20_slice_0001,large_bowel,,123,20,1,slice_0001_266_266_1.50_1.50.png,train/case123/case123_day20/scans/slice_0001_2...,266,266,case123_day20
1,case123_day20_slice_0001,small_bowel,,123,20,1,slice_0001_266_266_1.50_1.50.png,train/case123/case123_day20/scans/slice_0001_2...,266,266,case123_day20
2,case123_day20_slice_0001,stomach,,123,20,1,slice_0001_266_266_1.50_1.50.png,train/case123/case123_day20/scans/slice_0001_2...,266,266,case123_day20
3,case123_day20_slice_0002,large_bowel,,123,20,2,slice_0002_266_266_1.50_1.50.png,train/case123/case123_day20/scans/slice_0002_2...,266,266,case123_day20
4,case123_day20_slice_0002,small_bowel,,123,20,2,slice_0002_266_266_1.50_1.50.png,train/case123/case123_day20/scans/slice_0002_2...,266,266,case123_day20


## Predictions for test scans

In [9]:
from joblib import Parallel, delayed
from kaggle_imsegm.data import preprocess_tract_scan

_args = dict(
    dir_data=os.path.join(DATASET_FOLDER, sfolder),
    dir_imgs=DATASET_IMAGES,
    dir_segm=None,
    labels=LABELS,
    sfolder=sfolder,
)
test_scans = Parallel(n_jobs=6)(
    delayed(preprocess_tract_scan)(dfg, **_args)
    for _, dfg in df_pred.groupby("Case_Day")
)

In [10]:
import numpy as np
from itertools import chain
from kaggle_imsegm.mask import rle_encode

preds = []
for test_imgs in test_scans:
    dm = SemanticSegmentationData.from_files(
        predict_files=test_imgs,
        predict_transform=SemanticSegmentationInputTransform,
        transform_kwargs=dict(image_size=(256, 256)),
        num_classes=len(LABELS) + 1,
        batch_size=10,
        num_workers=3,
    )
    pred = trainer.predict(model, datamodule=dm, output="labels")
    pred = list(chain(*pred))
    for img, seg in zip(test_imgs, pred):
        rle = rle_encode(np.array(seg)) if np.sum(seg) > 1 else {}
        name, _ = os.path.splitext(os.path.basename(img))
        id_ = "_".join(name.split("_")[:4])
        preds += [{"id": id_, "class": lb, "predicted": rle.get(i + 1, "")} for i, lb in enumerate(LABELS)]

df_pred = pd.DataFrame(preds)
display(df_pred[df_pred["predicted"] != ""].head())

  cpuset_checked))


Predicting: 0it [00:00, ?it/s]

  "See the documentation of nn.Upsample for details.".format(mode)


Predicting: 0it [00:00, ?it/s]

Predicting: 0it [00:00, ?it/s]

Unnamed: 0,id,class,predicted
149,case123_day0_slice_0050,stomach,30760 2 31025 3 31290 5 31555 6 31821 6 32086 ...
152,case123_day0_slice_0051,stomach,27834 1 28092 10 28358 11 28623 12 28888 14 29...
155,case123_day0_slice_0052,stomach,26238 1 26500 6 26763 10 27028 11 27293 12 275...
158,case123_day0_slice_0053,stomach,25440 2 25704 5 25968 7 26232 10 26497 11 2676...
161,case123_day0_slice_0054,stomach,22780 2 23046 2 23311 3 23576 5 23841 6 24106 ...


## Finalize submissions

In [11]:
df_ssub = pd.read_csv(os.path.join(DATASET_FOLDER, "sample_submission.csv"))
del df_ssub['predicted']
df_pred = df_ssub.merge(df_pred, on=['id','class'])

df_pred[['id', 'class', 'predicted']].to_csv("submission.csv", index=False)

!head submission.csv

id,class,predicted
