## U-Net

In [6]:
!pip install segmentation_models_pytorch
!pip install torchmetrics

Collecting torchmetrics
  Downloading torchmetrics-1.7.1-py3-none-any.whl.metadata (21 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.14.3-py3-none-any.whl.metadata (5.6 kB)
Downloading torchmetrics-1.7.1-py3-none-any.whl (961 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m961.5/961.5 kB[0m [31m22.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lightning_utilities-0.14.3-py3-none-any.whl (28 kB)
Installing collected packages: lightning-utilities, torchmetrics
Successfully installed lightning-utilities-0.14.3 torchmetrics-1.7.1


In [None]:
import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import OxfordIIITPet
from torchvision import transforms
from torch.utils.data import DataLoader, random_split
from torchmetrics.classification import BinaryJaccardIndex

In [None]:
# Hyperparameters
batch_size = 8
num_epochs = 10
patience = 5          # epochs without validation improvement before early stopping
# Adam's documented default step size. The previous value (1e-1) is far too
# large for Adam and would immediately overwrite the ImageNet-pretrained
# encoder weights loaded below; ReduceLROnPlateau only lowers it further.
initial_lr = 1e-3

# Model: U-Net with an ImageNet-pretrained ResNet-34 encoder.
# A single output channel with a sigmoid activation yields per-pixel
# foreground probabilities (binary segmentation).
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    in_channels=3,
    classes=1,
    activation="sigmoid"
)

In [20]:
# Image preprocessing: resize to the network input size and convert the PIL
# image to a float tensor.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# Mask preprocessing. The segmentation target stores class labels, not
# intensities, so it must be resized with nearest-neighbour interpolation:
# the default (bilinear) would blend neighbouring labels into values that do
# not correspond to any class. PILToTensor keeps the raw integer labels.
mask_transform = transforms.Compose([
    transforms.Resize(
        (224, 224),
        interpolation=transforms.InterpolationMode.NEAREST,
    ),
    transforms.PILToTensor()
])

# Oxford-IIIT Pet dataset with segmentation (trimap) targets.
full_dataset = OxfordIIITPet(
    root="./data",
    split="trainval",    # official trainval split; re-split into train/val/test in a later cell
    target_types="segmentation",
    download=True,
    transform=transform,
    target_transform=mask_transform
)

In [21]:
# Split the dataset into train / validation / test (60% / 20% / 20%).
train_size = int(0.6 * len(full_dataset))
val_size = int(0.2 * len(full_dataset))
test_size = len(full_dataset) - train_size - val_size

# Use a dedicated, seeded generator so that re-running this cell always
# produces the same partition. With an unseeded split, a re-run would reshuffle
# the subsets and samples already used for training could end up in the
# validation or test set.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset, test_dataset = random_split(
    full_dataset,
    [train_size, val_size, test_size],
    generator=split_generator,
)

# Only the training data is shuffled; validation and test keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

# Optimizer, loss function and evaluation metric.
# BCELoss expects probabilities, which matches the model's sigmoid output.
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=initial_lr)
# mode='max' because the scheduler is stepped with the validation IoU
# (higher is better) in the training loop.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=2, factor=0.1)
iou_metric = BinaryJaccardIndex()

In [23]:
# Early-stopping helper
class EarlyStopping:
    """Track a validation score and flag when it has stopped improving.

    Call the instance once per epoch with the latest score (higher is
    better). A strictly greater score becomes the new best and resets the
    counter; anything else increments it. Once the counter reaches
    ``patience`` the ``early_stop`` flag is set to ``True``.

    Attributes:
        patience: Number of consecutive non-improving calls tolerated.
        counter: Non-improving calls seen since the last improvement.
        best_score: Highest score observed so far, or ``None`` before the
            first call.
        early_stop: ``True`` once ``counter`` has reached ``patience``.
    """

    def __init__(self, patience=5):
        self.early_stop = False
        self.best_score = None
        self.counter = 0
        self.patience = patience

    def __call__(self, val_score):
        """Record ``val_score`` and update the stopping state."""
        is_improvement = self.best_score is None or val_score > self.best_score
        if is_improvement:
            # New best: remember it and restart the patience window.
            self.best_score, self.counter = val_score, 0
            return

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True

# Training loop
es = EarlyStopping(patience=patience)

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")

    # ---- training phase ----
    model.train()
    train_loss = 0.0
    train_iou = 0.0
    for images, masks in train_loader:
        # The Oxford-IIIT Pet target is a trimap of class labels
        # (1 = pet, 2 = background, 3 = unclassified border), not a 0-255
        # intensity image. Dividing by 255 produced targets of ~0.004-0.012
        # that all truncated to 0 under .int(), so both the loss target and
        # the IoU were meaningless. Build a proper binary mask instead;
        # border pixels are treated as background here.
        masks = (masks == 1).float()
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        train_iou += iou_metric(outputs, masks.int()).item()

    # ---- validation phase (no gradient tracking) ----
    model.eval()
    val_loss = 0.0
    val_iou = 0.0
    with torch.no_grad():
        for images, masks in val_loader:
            masks = (masks == 1).float()  # same binarisation as in training
            outputs = model(images)
            loss = criterion(outputs, masks)
            val_loss += loss.item()
            val_iou += iou_metric(outputs, masks.int()).item()

    # Per-epoch figures are the mean of the per-batch values.
    avg_train_loss = train_loss / len(train_loader)
    avg_train_iou = train_iou / len(train_loader)
    avg_val_loss = val_loss / len(val_loader)
    avg_val_iou = val_iou / len(val_loader)

    print(f"Train Loss: {avg_train_loss:.4f}, Train IoU: {avg_train_iou:.4f}")
    print(f"Val Loss: {avg_val_loss:.4f}, Val IoU: {avg_val_iou:.4f}")

    # Both the LR scheduler and early stopping monitor the validation IoU.
    scheduler.step(avg_val_iou)
    es(avg_val_iou)

    if es.early_stop:
        print("Early stopping triggered")
        break


Epoch 1/10


KeyboardInterrupt: 

In [None]:
# Evaluation on the held-out test subset
model.eval()
test_iou = 0.0
with torch.no_grad():
    for images, masks in test_loader:
        # The target is a trimap of class labels (1 = pet, 2 = background,
        # 3 = unclassified border). Scaling by 1/255 and casting to int turned
        # every pixel into 0, so the reported IoU was meaningless. Use an
        # explicit binary pet-vs-rest mask; border pixels count as background.
        masks = (masks == 1).float()
        outputs = model(images)
        test_iou += iou_metric(outputs, masks.int()).item()

# Mean of the per-batch IoU values.
avg_test_iou = test_iou / len(test_loader)
print(f"Test IoU: {avg_test_iou:.2f}")