In [19]:
# Environment setup: reinstall a pinned headless OpenCV build, then install the
# segmentation, augmentation and training packages imported below.
# NOTE(review): the uninstall line names 4.5.5.62 while the recorded output shows
# 4.5.2.52 being removed — confirm which version the pin was meant to target.
# NOTE(review): opencv-python, opencv-contrib-python and opencv-python-headless
# presumably all provide the same `cv2` module; confirm that installing all three
# side by side is intended.
# NOTE(review): cv2-tools is not imported by any cell visible here — confirm it is needed.
!pip uninstall opencv-python-headless==4.5.5.62 -y
!pip install opencv-python-headless==4.5.2.52
!pip install segmentation-models-pytorch
!pip install -U albumentations
!pip install opencv-python
!pip install opencv-contrib-python
!pip install cv2-tools
!pip install pytorch-lightning

Found existing installation: opencv-python-headless 4.5.2.52
Uninstalling opencv-python-headless-4.5.2.52:
  Successfully uninstalled opencv-python-headless-4.5.2.52
Collecting opencv-python-headless==4.5.2.52
  Using cached opencv_python_headless-4.5.2.52-cp37-cp37m-manylinux2014_x86_64.whl (38.2 MB)
Installing collected packages: opencv-python-headless
Successfully installed opencv-python-headless-4.5.2.52


Collecting pytorch-lightning
  Downloading pytorch_lightning-1.5.9-py3-none-any.whl (527 kB)
[K     |████████████████████████████████| 527 kB 7.5 MB/s 
[?25hCollecting future>=0.17.1
  Downloading future-0.18.2.tar.gz (829 kB)
[K     |████████████████████████████████| 829 kB 52.6 MB/s 
Collecting pyDeprecate==0.3.1
  Downloading pyDeprecate-0.3.1-py3-none-any.whl (10 kB)
Collecting setuptools==59.5.0
  Downloading setuptools-59.5.0-py3-none-any.whl (952 kB)
[K     |████████████████████████████████| 952 kB 55.9 MB/s 
[?25hCollecting PyYAML>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 57.1 MB/s 
Collecting torchmetrics>=0.4.1
  Downloading torchmetrics-0.7.0-py3-none-any.whl (396 kB)
[K     |████████████████████████████████| 396 kB 35.4 MB/s 
Collecting fsspec[http]!=2021.06.0,>=2021.05.0
  Downloading fsspec-2022.1.0-py3-none-any.whl (133 kB)

In [20]:
from pathlib import Path

import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
# import torchvision.transforms as T
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl
from pytorch_lightning import callbacks
from pytorch_lightning.callbacks.progress import ProgressBarBase
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.utilities.seed import seed_everything
import segmentation_models_pytorch as smp
import matplotlib.pyplot as plt
from PIL import Image
import pandas as pd
import seaborn as sns
from tqdm.notebook import tqdm
from sklearn.model_selection import StratifiedKFold

# Pixel values looked up in the trimap annotation masks.
# NOTE(review): presumably the Oxford-IIIT Pet convention (1 = pet, 2 = background,
# 3 = unclassified border) — confirm against the dataset README.
MASK_FG = 1
MASK_BG = 2
MASK_IGNORE = 3

# Echo the installed library versions as the cell output.
pl.__version__, smp.__version__#, python.__version__

('1.5.9', '0.2.1')

In [21]:
# Experiment settings, grouped by the component that reads them.
config = dict(
    # forwarded as keyword arguments to the segmentation-model factory
    model=dict(
        arch="Unet",
        encoder_name="inceptionv4",  # resnet50
        encoder_weights="imagenet",
        in_channels=3,
        classes=1,
    ),
    # training-loop length
    train=dict(epoch=10),
    # optimizer / learning-rate schedule
    optim=dict(weight_decay=0.001, lr_max=2e-3),
)

In [22]:
# Peek at the first 10 lines of the annotation index before parsing it.
# NOTE(review): the recorded output shows this path was missing in the saved run —
# confirm the dataset location.
!head -n 10 ../input/the-oxfordiiit-pet-dataset/annotations/annotations/list.txt

head: cannot open '../input/the-oxfordiiit-pet-dataset/annotations/annotations/list.txt' for reading: No such file or directory


In [None]:
# Root of the dataset copy; every path built below hangs off it.
DATA_ROOT = "../input/the-oxfordiiit-pet-dataset"

# Annotation index: space-separated columns, with the first 6 lines skipped.
df = pd.read_csv(
    f"{DATA_ROOT}/annotations/annotations/list.txt",
    delimiter=" ",
    skiprows=6,
    header=None,
    names=["stem", "class_id", "species", "breed"],
)
# Strip only the trailing `_<suffix>` from the stem. Splitting on the first
# underscore would cut multi-word names short (e.g. `american_bulldog_100` ->
# `american`) and merge distinct breeds.
# NOTE(review): assumes stems follow `<breed name>_<index>` — confirm against the dataset README.
df["class_name"] = df.stem.map(lambda x: x.rsplit("_", 1)[0])
df["image"] = df.stem.map(lambda x: f"{DATA_ROOT}/images/images/{x}.jpg")
df["trimap"] = df.stem.map(lambda x: f"{DATA_ROOT}/annotations/annotations/trimaps/{x}.png")

df

In [None]:
# Distribution of each label column, one panel per column.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
sns.histplot(df.class_id, ax=axes[0])
sns.histplot(df.species, discrete=True, ax=axes[1])
sns.histplot(df.breed, ax=axes[2])
plt.show()

In [None]:
# First sample: its metadata row, then the photo next to its trimap.
print(df.iloc[0])
img = Image.open(df.image[0])
annot = Image.open(df.trimap[0])

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
for ax, picture in zip(axes, (img, annot)):
    ax.imshow(picture)
    # the axes carry no information for an image preview
    ax.set_xticks([])
    ax.set_yticks([])
plt.show()

In [None]:
# Optional integrity check over every image / trimap pair; flip the flag to run it.
RUN_SANITY_CHECK = False
if RUN_SANITY_CHECK:
    # Loop names are distinct from `img` / `annot` so the preview objects are not clobbered.
    for image_path, trimap_path in tqdm(zip(df.image, df.trimap), total=len(df)):
        for path in (image_path, trimap_path):
            # The context manager releases the file handle once verify() returns.
            with Image.open(path) as im:
                im.verify()

In [None]:
class IIITDataset(Dataset):
    """Image / trimap pairs loaded from the file paths stored in a DataFrame.

    Args:
        df: frame with an ``image`` column and a ``trimap`` column holding file paths.
        tfm: optional callable invoked as ``tfm(image=..., mask=...)`` that returns a
            mapping with ``"image"`` and ``"mask"`` entries.
    """

    def __init__(self, df, tfm=None):
        self.df = df
        self.tfm = tfm

    def __len__(self):
        """Number of rows in the backing frame."""
        return len(self.df)

    def __getitem__(self, i):
        """Load the sample at position ``i``.

        Returns:
            ``(img, mask)``: the RGB image and the trimap as arrays, or whatever
            ``tfm`` produced for them when a transform is set.
        """
        # Read the pixels inside `with` blocks so both file handles are closed
        # right away instead of lingering until garbage collection.
        with Image.open(self.df.image.iloc[i]) as im:
            img = np.asarray(im.convert('RGB'))
        with Image.open(self.df.trimap.iloc[i]) as im:
            mask = np.asarray(im)
        if self.tfm:
            augmented = self.tfm(image=img, mask=mask)
            img, mask = augmented["image"], augmented["mask"]
        return img, mask

In [None]:
# Training pipeline: random flips, scale, rotation and brightness/contrast,
# followed by a 224x224 random crop, normalisation and tensor conversion.
train_tfm = A.Compose(
    [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomScale(),
        # mask pixels exposed by the rotation are filled with the background label
        A.Rotate(border_mode=cv2.BORDER_CONSTANT, mask_value=MASK_BG),
        A.RandomBrightnessContrast(p=0.2),
        A.SmallestMaxSize(224),
        A.RandomCrop(224, 224),
        A.Normalize(),
        ToTensorV2(),
    ]
)

# Validation pipeline: resize plus a 224x224 centre crop, with no random augmentation.
val_tfm = A.Compose(
    [
        A.SmallestMaxSize(224),
        A.CenterCrop(224, 224),
        A.Normalize(),
        ToTensorV2(),
    ]
)

# Per-channel statistics used to undo the normalisation for display.
imagenet_mean = (0.485, 0.456, 0.406)
imagenet_std = (0.229, 0.224, 0.225)


def imagenet_denorm(x):
    """Map a normalised image back to its original scale.

    x: array-like with shape (..., H, W, C); the channel axis must come last so
    the 3-element statistics broadcast against it.
    Returns ``x * std + mean``, element-wise per channel.
    """
    rescaled = x * imagenet_std
    return rescaled + imagenet_mean

# Five folds stratified on class_id; the first fold backs the preview datasets.
skf = StratifiedKFold(n_splits=5)
train_idx, val_idx = next(skf.split(df, df.class_id))
train_df, val_df = df.iloc[train_idx], df.iloc[val_idx]

train_ds = IIITDataset(train_df, tfm=train_tfm)
val_ds = IIITDataset(val_df, tfm=val_tfm)

In [None]:
# Draw the same training sample three times to eyeball the random augmentation.
for _ in range(3):
    img, mask = train_ds[0]
    # CHW tensor -> HWC array, then undo the normalisation for display
    photo = imagenet_denorm(img.numpy().transpose(1, 2, 0))
    panels = (photo, mask == MASK_FG)
    for col, panel in enumerate(panels, start=1):
        plt.subplot(1, 2, col)
        plt.imshow(panel)
        plt.xticks([])
        plt.yticks([])
    plt.show()

In [None]:
# Seed the random number generators (seed 42) before the training cells below.
# NOTE(review): the augmentation preview above has already drawn random numbers
# by this point — confirm whether the seed should be set earlier.
seed_everything(42)

In [None]:
class Task(LightningModule):
    """Foreground-vs-rest segmentation task trained on trimap masks.

    Args:
        cfg: nested settings dict; ``cfg["model"]`` is forwarded to
            ``smp.create_model``, ``cfg["optim"]`` supplies ``weight_decay`` and
            ``lr_max``, and ``cfg["train"]["epoch"]`` sets the schedule length.
        train_df: frame describing the training samples.
        val_df: frame describing the validation samples.
    """

    def __init__(self, cfg, train_df, val_df):
        super().__init__()
        self.cfg = cfg
        self.train_df = train_df
        self.val_df = val_df
        # Built here rather than in setup(): the weights then exist as soon as the
        # module does (so a checkpoint can be loaded into it) and are not
        # re-initialised if setup() is invoked again for another stage.
        self.model = smp.create_model(**self.cfg["model"])
        # Per-pixel losses are kept so ignored pixels can be masked out before averaging.
        self.loss_fn = nn.BCEWithLogitsLoss(reduction='none')

    def setup(self, stage=None):
        """Create the train / validation datasets from the stored frames."""
        self.train_ds = IIITDataset(self.train_df, tfm=train_tfm)
        self.val_ds = IIITDataset(self.val_df, tfm=val_tfm)

    def train_dataloader(self):
        """Shuffled loader over the training dataset."""
        return DataLoader(
            self.train_ds,
            batch_size=16,
            shuffle=True,
            pin_memory=True,
        )

    def val_dataloader(self):
        """Unshuffled loader over the validation dataset."""
        return DataLoader(
            self.val_ds,
            batch_size=16,
            pin_memory=True,
        )

    def forward(self, x):
        """Return the raw logits of the wrapped segmentation model."""
        return self.model(x)

    def _masked_bce(self, batch):
        """Mean BCE-with-logits over all pixels whose label is not MASK_IGNORE."""
        x, t = batch
        fg = (t == MASK_FG).float()
        valid_mask = t != MASK_IGNORE
        # Drop only the channel axis. A bare squeeze() would also remove the batch
        # axis for a batch of one sample and break the shape match with `fg`.
        y = self.model(x).squeeze(1)
        return torch.masked_select(self.loss_fn(y, fg), valid_mask).mean()

    @staticmethod
    def _mean_loss(outputs):
        """Average the ``"loss"`` entries collected over one epoch."""
        return torch.stack([out["loss"].detach() for out in outputs]).mean()

    def training_step(self, batch, batch_idx):
        """Return the masked BCE loss of one training batch."""
        return self._masked_bce(batch)

    def validation_step(self, batch, batch_idx):
        """Return the masked BCE loss of one validation batch under ``"loss"``."""
        return {"loss": self._masked_bce(batch)}

    def training_epoch_end(self, outputs):
        """Log the epoch-mean training loss as ``loss``."""
        self.log("loss", self._mean_loss(outputs))

    def validation_epoch_end(self, outputs):
        """Log the epoch-mean validation loss as ``val_loss``."""
        self.log("val_loss", self._mean_loss(outputs))

    def configure_optimizers(self):
        """AdamW with a one-cycle learning-rate schedule stepped every batch."""
        optimizer = torch.optim.AdamW(self.model.parameters(), weight_decay=self.cfg["optim"]["weight_decay"])
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": torch.optim.lr_scheduler.OneCycleLR(
                    optimizer, self.cfg["optim"]["lr_max"],
                    epochs=self.cfg["train"]["epoch"],
                    # the schedule length is counted in optimizer steps
                    steps_per_epoch=len(self.train_dataloader())
                ),
                "interval": "step",
            },
        }

In [None]:
class PredictImageCallback(callbacks.Callback):
    """Log one validation batch (inputs, ground truth, prediction) after each training epoch.

    Args:
        model: the module being trained; it must provide ``val_dataloader()``,
            ``device`` and a logger whose ``experiment`` has ``add_image``.
    """

    def __init__(self, model: LightningModule):
        super().__init__()
        self.model = model

    def on_train_epoch_end(self, trainer, pl_module):
        """Predict the first validation batch and write three image grids to the logger."""
        xs, ts = next(iter(self.model.val_dataloader()))

        # Predict in eval mode so this preview pass does not update normalisation
        # statistics or apply dropout; the previous mode is restored afterwards.
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                # squeeze(1) drops only the channel axis, so a batch of one keeps its batch axis
                ys = self.model(xs.to(self.model.device)).squeeze(1).cpu().numpy()
        finally:
            self.model.train(was_training)

        # NCHW normalised tensor -> NHWC uint8 image batch
        arr = xs.numpy()
        arr = (imagenet_denorm(arr.transpose(0,2,3,1)) * 255).astype(np.uint8)
        fg_true = ts.numpy()[:, :, :, None] == MASK_FG
        fg_pred = ys[:, :, :, None] > 0  # logit > 0 means foreground

        writer = self.model.logger.experiment
        # Tag each image with its epoch so successive epochs are distinguishable in the log.
        step = trainer.current_epoch
        writer.add_image("val-image", arr, global_step=step, dataformats='NHWC')
        writer.add_image("val-mask", np.where(fg_true, arr, 0), global_step=step, dataformats='NHWC')
        writer.add_image("val-pred", np.where(fg_pred, arr, 0), global_step=step, dataformats='NHWC')

In [None]:
# Train one model per stratified fold; each fold logs to its own `fold-<n>` directory.
for fold, (train_idx, val_idx) in enumerate(skf.split(df, df.class_id)):
    # split() yields positional indices, so rows are selected by position.
    train_df = df.iloc[train_idx].reset_index(drop=True)
    val_df = df.iloc[val_idx].reset_index(drop=True)
    task = Task(config, train_df, val_df)
    early_stopping = EarlyStopping(monitor="val_loss")
    lr_monitor = callbacks.LearningRateMonitor(logging_interval='step')
    # keep only the weights of the epoch with the lowest validation loss
    loss_checkpoint = callbacks.ModelCheckpoint(
        filename="best_loss",
        monitor="val_loss",
        save_top_k=1,
        mode="min",
        save_last=False,
        save_weights_only=True
    )
    logger = TensorBoardLogger(f"fold-{fold}")

    trainer = pl.Trainer(
        # fall back to CPU instead of failing when no GPU is visible
        gpus=1 if torch.cuda.is_available() else 0,
        logger=logger,
        max_epochs=config["train"]["epoch"],
        callbacks=[lr_monitor, loss_checkpoint, early_stopping, PredictImageCallback(task)],
    )
    trainer.fit(task)
#     break

In [None]:
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
from glob import glob
from io import BytesIO
from matplotlib import gridspec

def show_log(fold, version=0, max_epochs_shown=4):
    """Plot the learning-rate curve, loss curves and prediction previews logged for one fold.

    Args:
        fold: fold number; logs are read from ``./fold-<fold>/default/version_<version>/``.
        version: logger version directory to read (default 0).
        max_epochs_shown: upper bound on the number of epochs whose prediction and
            ground-truth grids are drawn (default 4).

    Raises:
        FileNotFoundError: if no event file exists for the requested fold / version.
    """
    pattern = f'./fold-{fold}/default/version_{version}/events*'
    event_files = sorted(glob(pattern))
    if not event_files:
        raise FileNotFoundError(f"no TensorBoard event file matches {pattern!r}")

    # A size guidance of 0 keeps every event. Images are otherwise down-sampled to
    # a handful of entries, which would break the index-to-epoch mapping below.
    event_acc = EventAccumulator(event_files[0], size_guidance={'scalars': 0, 'images': 0})
    event_acc.Reload()

    scalars = {}
    for tag in event_acc.Tags()['scalars']:
        scalars[tag] = [event.value for event in event_acc.Scalars(tag)]

    images = {}
    for tag in event_acc.Tags()['images']:
        images[tag] = [
            Image.open(BytesIO(event.encoded_image_string))
            for event in event_acc.Images(tag)
        ]

    plt.figure(figsize=(16, 6))
    plt.subplot(1, 2, 1)
    plt.plot(range(len(scalars['lr-AdamW'])), scalars['lr-AdamW'])
    plt.xlabel('steps')
    plt.ylabel('lr')
    plt.title('adamw lr')

    plt.subplot(1, 2, 2)
    plt.plot(range(len(scalars['loss'])), scalars['loss'], label='train_loss')
    plt.plot(range(len(scalars['val_loss'])), scalars['val_loss'], label='val_loss')
    plt.legend()
    plt.ylabel('bce')
    plt.xlabel('epoch')
    plt.title('train/val loss')
    plt.show()

    # Never index past what was actually logged (e.g. when training stopped early).
    n_rows = min(max_epochs_shown, len(images["val-pred"]), len(images["val-mask"]))

    plt.figure(figsize=(16, 10))
    gs = gridspec.GridSpec(max_epochs_shown + 1, 2)
    ax = plt.subplot(gs[0, :])
    ax.imshow(images["val-image"][0])
    ax.set_xticks([]); ax.set_yticks([])
    for i in range(n_rows):
        ax0 = plt.subplot(gs[i + 1, 0])
        ax1 = plt.subplot(gs[i + 1, 1])
        ax0.set_title(f"epoch {i+1}:pred")
        ax0.imshow(images["val-pred"][i])
        ax0.set_xticks([]); ax0.set_yticks([])
        ax1.set_title(f"epoch {i+1}:gt")
        ax1.imshow(images["val-mask"][i])
        ax1.set_xticks([]); ax1.set_yticks([])

    plt.show()

# One report per fold; the count comes from the splitter so it cannot drift from the training loop.
for fold in range(skf.get_n_splits()):
    show_log(fold)