<a href="https://colab.research.google.com/github/tocom242242/notebooks/blob/master/pytorch/seg/simple_unet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [17]:
import torch

x = torch.randn((1,3,256,256))
x.shape

torch.Size([1, 3, 256, 256])

In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DownConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.max_pool= nn.MaxPool2d(2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool(x)
        return x

class UpConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up_conv1 = nn.ConvTranspose2d(in_channels, out_channels, 3, stride=2, padding=1, output_padding=1)
        self.conv1 = nn.Conv2d(out_channels*2, out_channels, 3, stride=1, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x1, x2):
        x = self.up_conv1(x1)
        x = self.relu(x)
        x = torch.cat([x, x2], dim=1)
        x = self.conv1(x)
        x = self.relu(x)
        return x


class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=1, features=[64, 128, 256, 512]):
        super().__init__()
        self.down_conv1 = DownConv(in_channels, features[0])
        self.down_conv2 = DownConv(features[0], features[1])
        self.down_conv3 = DownConv(features[1], features[2])
        self.down_conv4 = DownConv(features[2], features[3])
        self.up_conv1 = UpConv(features[3], features[2])
        self.up_conv2 = UpConv(features[2], features[1])
        self.up_conv3 = UpConv(features[1], features[0])
        self.out_upconv = nn.ConvTranspose2d(features[0], out_channels, 3, stride=2, padding=1, output_padding=1)


    def forward(self, x):
        x1 = self.down_conv1(x)
        x2 = self.down_conv2(x1)
        x3 = self.down_conv3(x2)
        x4 = self.down_conv4(x3)
        x = self.up_conv1(x4, x3)
        x = self.up_conv2(x, x2)
        x = self.up_conv3(x, x1)
        x = self.out_upconv(x)
        return x

print("input x.shape:",x.shape)
model = UNet(in_channels=3, out_channels=1, features=[64, 128, 256, 512])
model(x).shape

input x.shape: torch.Size([1, 3, 256, 256])


torch.Size([1, 1, 256, 256])

In [19]:
import torch
import numpy as np
import torchvision.transforms as T
from torchvision.transforms import functional as F
from torchvision.datasets import VOCSegmentation
from torch.utils.data import DataLoader
from PIL import Image

# 同時に画像とマスクに変換を適用するためのヘルパークラス
class ComposeJoint(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target

class ResizeJoint(object):
    def __init__(self, size):
        self.size = size

    def __call__(self, image, target):
        image = F.resize(image, self.size, Image.BILINEAR)
        target = F.resize(target, self.size, Image.NEAREST)
        return image, target

# トランスフォームの定義
transform = ComposeJoint([
    ResizeJoint((256, 256)),
    # lambda image, target: (T.ToTensor()(image), torch.as_tensor(np.array(target), dtype=torch.int64))
    lambda image, target: (T.ToTensor()(image), T.ToTensor()(target))
])

# VOCデータセットの読み込み
dataset = VOCSegmentation(root='./data', year='2012', image_set='train', download=True, transforms=transform)

# DataLoaderの作成
data_loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)



Using downloaded and verified file: ./data/VOCtrainval_11-May-2012.tar
Extracting ./data/VOCtrainval_11-May-2012.tar to ./data


In [20]:
# データの確認
# for images, labels in data_loader:
#     print(images.shape, labels.shape)

In [21]:
# prompt: modelをgpuに載せて

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)


In [22]:
a, b = next(iter(data_loader))
torch.unique(b)

tensor([0.0000, 0.0235, 0.0275, 0.0314, 0.0627, 0.0745, 1.0000])

In [28]:
# prompt: 上で読み込んだデータセットを先ほど定義したUNetを学習するコードを書いて

# 損失関数の定義
criterion = nn.CrossEntropyLoss()

# 最適化手法の定義
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# モデルの学習
for epoch in range(50):
    for images, labels in data_loader:
        images = images.to(device)
        labels = labels.to(device)
        # モデルの出力を計算
        outputs = model(images)
        outputs = outputs.squeeze(1)
        labels = torch.squeeze(labels, 1)

        # 損失を計算
        loss = criterion(outputs,labels.float())

        # 勾配の計算
        loss.backward()

        # パラメータの更新
        optimizer.step()

    # エポックごとに損失を表示
    print('Epoch {}: Loss: {}'.format(epoch, loss.item()))


Epoch 0: Loss: 74.68472290039062
Epoch 1: Loss: 85.310302734375
Epoch 2: Loss: 113.46865844726562
Epoch 3: Loss: 66.2655029296875
Epoch 4: Loss: 46.95069885253906
Epoch 5: Loss: 116.72239685058594
Epoch 6: Loss: 138.3501434326172
Epoch 7: Loss: 91.05636596679688
Epoch 8: Loss: 62.99249267578125
Epoch 9: Loss: 50.462982177734375
Epoch 10: Loss: 103.19679260253906
Epoch 11: Loss: 56.5571174621582
Epoch 12: Loss: 100.02284240722656
Epoch 13: Loss: 61.33900451660156
Epoch 14: Loss: 40.6544189453125
Epoch 15: Loss: 61.72590255737305
Epoch 16: Loss: 136.92759704589844
Epoch 17: Loss: 74.5814208984375
Epoch 18: Loss: 89.07098388671875
Epoch 19: Loss: 63.640045166015625
Epoch 20: Loss: 107.43726348876953
Epoch 21: Loss: 62.74976348876953
Epoch 22: Loss: 61.31978225708008
Epoch 23: Loss: 84.89266967773438
Epoch 24: Loss: 74.0421371459961
Epoch 25: Loss: 67.83377075195312
Epoch 26: Loss: 86.34810638427734
Epoch 27: Loss: 154.90670776367188
Epoch 28: Loss: 77.73731231689453
Epoch 29: Loss: 64.250

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=1, features=[64, 128, 256, 512]):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, features[0], 3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(features[0], features[1], 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(features[1], features[2], 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(features[2], features[3], 3, stride=2, padding=1),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(features[3], features[2], 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(features[2], features[1], 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(features[1], features[0], 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(features[0], out_channels, 3, stride=1, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

model = UNet()
model(x).shape

In [None]:
# prompt: create or load dataset for above unet

from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os

class CustomDataset(Dataset):
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.image_paths = os.listdir(os.path.join(root_dir, 'images'))
        self.mask_paths = os.listdir(os.path.join(root_dir, 'masks'))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = os.path.join(self.root_dir, 'images', self.image_paths[idx])
        mask_path = os.path.join(self.root_dir, 'masks', self.mask_paths[idx])
        image = Image.open(image_path).convert('RGB')
        mask = Image.open(mask_path).convert('L')
        return image, mask

train_dataset = CustomDataset('path/to/train_dataset')
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)


In [None]:
# prompt: シンプルなU-Netを作成して

import torch
import torch.nn as nn
import torch.nn.functional as F

class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=1, features=[64, 128, 256, 512]):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, features[0], 3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(features[0], features[1], 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(features[1], features[2], 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(features[2], features[3], 3, stride=2, padding=1),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(features[3], features[2], 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(features[2], features[1], 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(features[1], features[0], 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(features[0], out_channels, 3, stride=1, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

model = UNet()
model(x).shape