<a href="https://colab.research.google.com/github/softmurata/colab_notebooks/blob/main/semantic_segmentation/unet_deeplabv3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Download the JSRT Segmentation01 archive. -nc (no-clobber) skips the download when the
# zip already exists, so re-running the cell does not pile up Segmentation01.zip.1, .2, ...
!wget -q -nc http://imgcom.jsrt.or.jp/imgcom/wp-content/uploads/2018/11/Segmentation01.zip

In [None]:
# Extract into the working directory. -o overwrites existing files without asking; without
# it a second run stops at unzip's interactive "replace ...?" prompt.
!unzip -q -o /content/Segmentation01.zip

In [None]:
# Show the runtime's Python version. (A bare `!python` starts an interactive interpreter,
# which in a shell cell either waits for input or exits without printing anything.)
!python --version

In [None]:
# Keep PyTorch Lightning on the 1.x API this notebook was written against, alongside the
# torchmetrics==0.7.0 pin below (Lightning 2.0 removed 1.x Trainer arguments such as `gpus`).
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install -q "pytorch_lightning<2.0"

In [None]:
# torchmetrics is pinned because `torchmetrics.functional.iou`, imported below, is not
# available in later releases (it was superseded by `jaccard_index`).
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install -q torchmetrics==0.7.0

In [None]:
import glob
import os
import random

import cv2
import matplotlib.pyplot as plt

# Preview a few training images next to their label masks (left: image, right: mask).
org_image_root = "./Segmentation01/train/org/"
label_image_root = "./Segmentation01/train/label/"
NUM_PREVIEW = 3   # number of image/mask rows to draw
PREVIEW_SEED = 0  # fixed so the same samples are shown on every run


def load_rgb(path):
  """Read an image file with OpenCV and return it as an RGB array.

  Raises:
    FileNotFoundError: if OpenCV cannot read `path` (cv2.imread returns None in that
      case, which would otherwise surface as an obscure cv2.cvtColor error).
  """
  image_bgr = cv2.imread(path)
  if image_bgr is None:
    raise FileNotFoundError(f"could not read image: {path}")
  return cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)


org_images = sorted(glob.glob(org_image_root + "*.png"))
if len(org_images) < NUM_PREVIEW:
  raise FileNotFoundError(
      f"expected at least {NUM_PREVIEW} PNG images in {org_image_root}, found {len(org_images)}")

# A private Random instance makes the preview reproducible without touching the global
# `random` state used elsewhere.
choice_org_images = random.Random(PREVIEW_SEED).sample(org_images, NUM_PREVIEW)
# Each mask has the same file name as the image it annotates.
choice_label_images = [os.path.join(label_image_root, os.path.basename(path)) for path in choice_org_images]

# Create the figure with one row per sample.
fig, axes = plt.subplots(NUM_PREVIEW, 2, squeeze=False)
for row, (org_path, label_path) in enumerate(zip(choice_org_images, choice_label_images)):
  axes[row, 0].imshow(load_rgb(org_path))
  axes[row, 1].imshow(load_rgb(label_path))
for ax in axes.flat:
  ax.axis("off")

plt.tight_layout()

# Show the figure.
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import glob
import PIL
from PIL import Image
import argparse


import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

import torchmetrics
from torchmetrics.functional import accuracy, iou

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torchvision
from torchvision import transforms
from torchvision.models.segmentation import deeplabv3

In [None]:
# Training hyperparameters. parse_args([]) deliberately ignores sys.argv (inside a notebook
# kernel it holds the kernel's own command-line flags), so the defaults below are what is used.
parser = argparse.ArgumentParser()
parser.add_argument("--image_size", default=256, type=int)  # side of the square resize applied to images and masks
parser.add_argument("--batch_size", default=4, type=int)
parser.add_argument("--epochs", default=100, type=int)      # maximum number of training epochs
parser.add_argument("--lr", default=1e-4, type=float)
parser.add_argument("--patience", default=10, type=int)     # early-stopping patience on val_loss
args = parser.parse_args([])

In [None]:
import glob
import os

# Dataset and DataLoader inputs.
# Images and label masks are paired by position after sorting, so for each split the two
# lists must contain exactly the same file names; this is asserted below.
DATA_ROOT = "/content/Segmentation01"

## original images (converted to RGB by the dataset)
train_img_list = sorted(glob.glob(os.path.join(DATA_ROOT, "train", "org", "*.png")))
test_img_list = sorted(glob.glob(os.path.join(DATA_ROOT, "test", "org", "*.png")))

## label mask images
train_label_list = sorted(glob.glob(os.path.join(DATA_ROOT, "train", "label", "*.png")))
test_label_list = sorted(glob.glob(os.path.join(DATA_ROOT, "test", "label", "*.png")))

for split_name, split_imgs, split_labels in (
    ("train", train_img_list, train_label_list),
    ("test", test_img_list, test_label_list),
):
  assert split_imgs, f"no {split_name} images found under {DATA_ROOT} (was the archive extracted?)"
  assert [os.path.basename(p) for p in split_imgs] == [os.path.basename(p) for p in split_labels], (
      f"{split_name}: image and label file names do not match one-to-one")


In [None]:
class CustomDataset(data.Dataset):
  """Paired (image, mask) dataset for binary segmentation.

  Args:
    img_path_list: paths of the input images.
    label_path_list: paths of the label masks, index-aligned with `img_path_list`.
    args: namespace providing `image_size`, the side of the square both are resized to.

  Each item is `(img, label)`:
    img: 3 x image_size x image_size float tensor, normalised with the ImageNet channel
      statistics that torchvision's pretrained weights expect.
    label: the mask after resizing and `ToTensor` (for 8-bit masks, values in [0, 1]);
      its channel count follows the mask file's mode.

  Raises:
    ValueError: if the two path lists have different lengths.
  """

  # ImageNet channel mean/std used by torchvision's pretrained models.
  IMAGENET_MEAN = (0.485, 0.456, 0.406)
  IMAGENET_STD = (0.229, 0.224, 0.225)

  def __init__(self, img_path_list, label_path_list, args):
    if len(img_path_list) != len(label_path_list):
      raise ValueError(
          f"got {len(img_path_list)} images but {len(label_path_list)} label masks")
    self.image_path_list = img_path_list
    self.label_path_list = label_path_list
    size = (args.image_size, args.image_size)
    self.transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize(self.IMAGENET_MEAN, self.IMAGENET_STD),
    ])
    # Nearest-neighbour resizing keeps the mask's original label values; the default
    # bilinear filter would blend foreground and background into fractional values
    # along every object boundary.
    self.label_transform = transforms.Compose([
        transforms.Resize(size, interpolation=transforms.InterpolationMode.NEAREST),
        transforms.ToTensor(),
    ])

  def __len__(self):
    return len(self.image_path_list)

  def __getitem__(self, index):
    # The transforms fully read the pixels inside each `with`, after which the file
    # handle is closed instead of lingering until garbage collection.
    with Image.open(self.image_path_list[index]) as img_file:
      img = self.transform(img_file.convert("RGB"))

    with Image.open(self.label_path_list[index]) as label_file:
      label = self.label_transform(label_file)

    return img, label




In [None]:
# create dataset and dataloader
train_dataset = CustomDataset(train_img_list, train_label_list, args)
test_dataset = CustomDataset(test_img_list, test_label_list, args)

# In train mode, torchvision's DeepLabV3 head (ASPP) pools features to 1x1 before a
# BatchNorm, which raises for a batch holding a single sample. Drop the trailing training
# batch only in the case where it would contain exactly one sample.
drop_last_train = len(train_dataset) % args.batch_size == 1

dataloader = {
    "train": data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=drop_last_train),
    # The held-out test split doubles as the validation set during training.
    "val": data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False),
}

In [None]:
# Define Network
class Net(pl.LightningModule):
  """DeepLabV3-ResNet101 (pretrained weights) with a one-channel head for binary masks.

  Args:
    lr: learning rate used by the Adam optimizer.
  """

  def __init__(self, lr:float):
    super().__init__()
    # Stores `lr` in the checkpoint so Net.load_from_checkpoint can rebuild the module.
    self.save_hyperparameters()
    self.lr = lr
    self.model = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=True)
    # Replace the pretrained classifier with a head emitting one logit per pixel.
    self.model.classifier = deeplabv3.DeepLabHead(2048, 1)

  def forward(self, x):
    """Run the segmentation model; it returns a dict whose "out" entry holds the logits."""
    return self.model(x)

  def _shared_step(self, batch, stage):
    """Compute BCE-with-logits loss and log loss/accuracy/IoU under the `stage` prefix."""
    x, t = batch
    logits = self(x)["out"]
    loss = F.binary_cross_entropy_with_logits(logits, t)

    y = torch.sigmoid(logits)
    # Binarise the targets at 0.5 for the metrics. `t.int()` would truncate every
    # fractional value (e.g. produced by resizing a mask) down to background.
    target = (t >= 0.5).int()

    self.log(f"{stage}_loss", loss, on_step=True, on_epoch=True)
    self.log(f"{stage}_acc", accuracy(y, target), on_step=True, on_epoch=True, prog_bar=True)
    self.log(f"{stage}_iou", iou(y, target), on_step=True, on_epoch=True, prog_bar=True)

    return loss

  def training_step(self, batch, batch_idx):
    return self._shared_step(batch, "train")

  def validation_step(self, batch, batch_idx):
    return self._shared_step(batch, "val")

  def configure_optimizers(self):
    # Use the configured learning rate; without `lr=` Adam silently falls back to its
    # default of 1e-3 and `self.lr` has no effect.
    optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
    return optimizer



In [None]:
# Keep only the checkpoint with the lowest validation loss.
SAVE_MODEL_PATH = "/content/model/"
model_checkpoint = ModelCheckpoint(
    SAVE_MODEL_PATH,
    # Name files after the network actually trained (DeepLabV3-ResNet101 in Net).
    filename="DeepLabV3-" + "{epoch:02d}-{val_loss:.2f}",
    monitor='val_loss',
    mode='min',
    save_top_k=1,
    save_last=False,
)

# Stop training once val_loss has not improved for `args.patience` consecutive validation checks.
early_stopping = EarlyStopping(
    monitor='val_loss',
    mode='min',
    patience=args.patience,
)

In [None]:
pl.seed_everything(0)
net = Net(lr=args.lr)
# `accelerator`/`devices` replace the `gpus` argument, which was deprecated in Lightning 1.7
# and removed in 2.0. "auto" uses a GPU when one is available and otherwise falls back to CPU.
trainer = pl.Trainer(
    max_epochs=args.epochs,
    callbacks=[model_checkpoint, early_stopping],
    accelerator="auto",
    devices=1,
)
trainer.fit(net, dataloader["train"], dataloader["val"])

In [None]:
# %reload_ext loads the extension on the first run and stays quiet on re-runs, whereas
# %load_ext prints an "already loaded" notice the second time.
%reload_ext tensorboard

In [None]:
# Open TensorBoard on the Trainer's log directory to inspect the train_/val_ curves.
%tensorboard --logdir=lightning_logs/

In [None]:
# Latest value of every metric logged through self.log, as seen by the callbacks.
dict(trainer.callback_metrics)

In [None]:
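# Open question: with a pretrained model such as DeepLabV3 as the backbone, can mask images be
# predicted from a reasonably varied dataset of only about 50 images?
# How the results relate to the size of the mask region is still unclear.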