編集部注：2023年5月29日最終更新．コードの一部がお手元の書籍と異なる可能性がございます．正誤・更新情報は弊社ウェブサイトの[本書詳細ページ](https://www.yodosha.co.jp/jikkenigaku/book/9784758122634/index.html)をご参照ください．

In [None]:
from google.colab import drive
# Mount Google Drive inside the Colab runtime at /content/drive.
drive.mount('/content/drive')
# Base folder on the mounted Drive. The trailing slash matters: later code
# builds file paths by plain string concatenation onto this value.
path = "/content/drive/MyDrive/sec6/"

In [None]:
from skimage.io import imread
from skimage.transform import resize
from skimage.color import rgb2gray
from sklearn.model_selection import train_test_split

from time import time
from copy import deepcopy
from tqdm import tqdm 
import glob

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

In [None]:
# Class sub-folders of the dataset that load_data() walks through.
folder = ['benign', 'malignant', 'normal']


def load_data():
    """Load every image/mask pair of the dataset as 128x128 grayscale arrays.

    For each class sub-folder listed in the module-level ``folder``, every
    ``*_mask.png`` file under ``path + "Dataset_BUSI_with_GT/<class>/"`` is
    paired with the file of the same name without the ``_mask`` suffix.
    Both files are read with ``as_gray=True`` and resized to 128x128.

    Returns:
        tuple: ``(images, masks)`` as NumPy arrays with one entry per pair;
        ``images[k]`` and ``masks[k]`` belong together.
    """
    mask_suffix = "_mask.png"
    images = []
    masks = []
    for folname in tqdm(folder):
        class_dir = path + "Dataset_BUSI_with_GT/" + folname + "/"
        # glob returns matches in arbitrary order; sorting keeps the sample
        # order (and therefore any seeded split made later) reproducible.
        mask_files = sorted(glob.glob(class_dir + "*" + mask_suffix))
        for mask_file in tqdm(mask_files):
            # Derive the image path by swapping the suffix on the full path,
            # so nothing depends on which path separator glob hands back.
            image_file = mask_file[:-len(mask_suffix)] + ".png"
            image = resize(imread(image_file, as_gray=True), output_shape=(128, 128))
            mask = resize(imread(mask_file, as_gray=True), output_shape=(128, 128))
            images.append(image)
            masks.append(mask)
    return np.array(images), np.array(masks)
# Load the whole dataset into memory and report the resulting array shapes.
images, masks = load_data()
print(f"Images shape: {images.shape}",
      f"Masks shape: {masks.shape}\n")

In [None]:
def flip(images, labels, axis):
    """Reverse ``images`` and ``labels`` along the same ``axis``.

    Both arrays get the identical flip so each image stays aligned with its
    label. Returns the pair ``(flipped_images, flipped_labels)``.
    """
    flipped_images, flipped_labels = (
        np.flip(batch, axis) for batch in (images, labels)
    )
    return flipped_images, flipped_labels
def augment(images, labels):
    """Quadruple the dataset by appending mirrored copies.

    Two passes are made, first along axis 2 and then along axis 1. Each pass
    appends the flipped version of everything accumulated so far, so the
    second pass also mirrors the copies produced by the first one and the
    sample count grows 1x -> 2x -> 4x.

    Assumes the arrays are stacked as (sample, height, width), in which case
    axis 2 mirrors left-right and axis 1 mirrors up-down -- TODO confirm.

    Returns the enlarged ``(images, labels)`` pair.
    """
    for flip_axis in (2, 1):
        mirrored_images, mirrored_labels = flip(images, labels, axis=flip_axis)
        images = np.concatenate([images, mirrored_images])
        labels = np.concatenate([labels, mirrored_labels])
    return images, labels
# Replace the dataset with its augmented version (original + flipped copies)
# and report the new array shapes.
images, masks = augment(images, masks)
print(f"Images shape: {images.shape}",f"Masks shape: {masks.shape}\n")

In [None]:
# Preview figure: the first four samples, image on the left and its mask on the right.
n_preview = 4
f, axis = plt.subplots(
    nrows=n_preview,
    ncols=2,
    constrained_layout=True,
    figsize=(10, 10),
)
for row in range(n_preview):
    image_ax, mask_ax = axis[row]
    image_ax.imshow(images[row], cmap="gray")
    mask_ax.imshow(masks[row], cmap="gray")
plt.show()

In [None]:
# Insert a size-1 axis at index 3 (used as the channel axis further down).
images = np.expand_dims(images, axis=3)
masks = np.expand_dims(masks, axis=3)
# First split: keep 90% for training and hold out 10%; the hold-out is
# temporarily stored in x_val / y_val.
# NOTE(review): augment() was applied to images/masks before this split, so
# flipped copies of one source image can land in different splits. Consider
# splitting first and augmenting only the training part.
x_train, x_val, y_train, y_val = train_test_split(images, masks, test_size= 0.1, shuffle=True, random_state=1111)
# Second split: halve the hold-out into the test set and the final validation
# set (x_val / y_val are overwritten with the validation half).
x_test, x_val, y_test, y_val = train_test_split(x_val, y_val, test_size=0.5, shuffle=True, random_state=11)

print(f"Train arrays shape: {x_train.shape}, {y_train.shape}")
print(f"Test arrays shape: {x_test.shape}, {y_test.shape}")
print(f"Validation arrays shape: {x_val.shape}, {y_val.shape}")

In [None]:
# Move the data to the GPU (when available) as tensors
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
images_format = torch.float32
masks_format = torch.float32

# Free memory
del images, masks


def _to_device_channels_first(array, dtype):
    """Turn a NumPy array into a `dtype` tensor on `device`, moving axis 3 to axis 1."""
    tensor = torch.from_numpy(array).to(dtype).to(device)
    return tensor.permute(0, 3, 1, 2)


# Convert Numpy to tensors
train_inputs = _to_device_channels_first(x_train, images_format)
train_outputs = _to_device_channels_first(y_train, masks_format)
val_inputs = _to_device_channels_first(x_val, images_format)
val_outputs = _to_device_channels_first(y_val, masks_format)
test_inputs = _to_device_channels_first(x_test, images_format)
test_outputs = _to_device_channels_first(y_test, masks_format)

print(f"Train tensor shape: {train_inputs.shape}, {train_outputs.shape}")
print(f"Test tensor shape: {test_inputs.shape}, {test_outputs.shape}")
print(f"Validation tensor shape: {val_inputs.shape}, {val_outputs.shape}")

In [None]:
#Model specification
class conv2(nn.Module):
    """Two stacked 3x3 convolutions, each followed by batch norm and ReLU.

    A 3x3 kernel with padding=1 leaves height and width unchanged; only the
    channel count changes, from ``in_ch`` to ``out_ch`` (the second
    convolution maps ``out_ch`` to ``out_ch``).
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        stage_in = in_ch
        for _ in range(2):
            stages.extend([
                nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
            stage_in = out_ch
        # One nn.Sequential stored as `conv`, so the parameter names stay
        # conv.0.*, conv.1.*, conv.3.*, conv.4.*.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv/norm/ReLU stages to ``x``."""
        return self.conv(x)
class up(nn.Module):
    """Upsample a feature map 2x and concatenate it with a skip tensor."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Kernel 2 with stride 2 doubles height and width.
        self.up_scale = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)

    def forward(self, x1, x2):
        """Upsample ``x2``, fit it to ``x1``'s spatial size, then concatenate.

        ``x1`` supplies the target height/width and is appended after the
        upsampled ``x2`` along the channel axis (dim 1).
        """
        upsampled = self.up_scale(x2)
        height_gap = x1.shape[2] - upsampled.shape[2]
        width_gap = x1.shape[3] - upsampled.shape[3]
        # Split each gap between both sides; an odd leftover goes to the
        # right/bottom. F.pad takes the last dimension first.
        left, top = width_gap // 2, height_gap // 2
        right, bottom = width_gap - left, height_gap - top
        upsampled = F.pad(upsampled, [left, right, top, bottom])
        return torch.cat([upsampled, x1], dim=1)
class down_layer(nn.Module):
    """Encoder step: 2x2 average pooling followed by a ``conv2`` block."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
        self.conv = conv2(in_ch, out_ch)

    def forward(self, x):
        """Halve the spatial size of ``x``, then convolve to ``out_ch`` channels."""
        pooled = self.pool(x)
        return self.conv(pooled)
class up_layer(nn.Module):
    """Decoder step: ``up`` (upsample + concatenate skip) followed by ``conv2``.

    The concatenated tensor is fed to ``conv2(in_ch, out_ch)``, so the skip
    tensor has to contribute ``in_ch - out_ch`` channels.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = up(in_ch, out_ch)
        self.conv = conv2(in_ch, out_ch)

    def forward(self, x1, x2):
        """Merge skip tensor ``x1`` with upsampled ``x2`` and convolve the result."""
        merged = self.up(x1, x2)
        return self.conv(merged)
    
class unet(nn.Module):
    """U-Net style encoder/decoder mapping a 1-channel image to a 1-channel map.

    The encoder applies seven ``down_layer`` steps (each halves height and
    width) while the channel count doubles from 8 up to 1024. The decoder
    mirrors this with seven ``up_layer`` steps, each fed the encoder output
    of matching resolution as a skip connection. Two 1x1 convolutions and a
    sigmoid produce per-pixel values in (0, 1).
    """

    def __init__(self):
        # Zero-argument super(): the two-argument form looks up the global
        # name `unet` at call time, and this notebook later rebinds that name
        # to the model instance, which would make the class impossible to
        # instantiate again (e.g. via type(model)()).
        super().__init__()
        self.conv1 = conv2(1, 8)
        # Encoder: channels double at every step.
        self.down1 = down_layer(8, 16)
        self.down2 = down_layer(16, 32)
        self.down3 = down_layer(32, 64)
        self.down4 = down_layer(64, 128)
        self.down5 = down_layer(128, 256)
        self.down6 = down_layer(256, 512)
        self.down7 = down_layer(512, 1024)
        # Decoder: channels halve at every step.
        self.up1 = up_layer(1024, 512)
        self.up2 = up_layer(512, 256)
        self.up3 = up_layer(256, 128)
        self.up4 = up_layer(128, 64)
        self.up5 = up_layer(64, 32)
        self.up6 = up_layer(32, 16)
        self.up7 = up_layer(16, 8)
        # 1x1 convolutions: 8 -> 1 channel, then an extra 1 -> 1 mapping.
        self.last_conv = nn.Conv2d(8, 1, 1)
        self.dilute = nn.Conv2d(1, 1, 1)

    def forward(self, x):
        """Return the sigmoid-activated single-channel prediction for ``x``."""
        # Encoder path; every intermediate result is kept for the skips.
        x1 = self.conv1(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x6 = self.down5(x5)
        x7 = self.down6(x6)
        x8 = self.down7(x7)
        # Decoder path: first argument is the skip, second the deeper tensor.
        x1_up = self.up1(x7, x8)
        x2_up = self.up2(x6, x1_up)
        x3_up = self.up3(x5, x2_up)
        x4_up = self.up4(x4, x3_up)
        x5_up = self.up5(x3, x4_up)
        x6_up = self.up6(x2, x5_up)
        x7_up = self.up7(x1, x6_up)

        output = self.last_conv(x7_up)
        output = self.dilute(output)
        output = torch.sigmoid(output)
        return output

In [None]:
#Loss function
class DiceLoss(nn.Module):
    """Soft Dice loss computed over all elements of the batch at once.

    ``weight`` and ``size_average`` are accepted for call compatibility but
    are not used.
    """

    def __init__(self, weight=None, size_average=True):
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        """Return ``1 - dice`` for predictions ``inputs`` against ``targets``.

        Args:
            inputs: predicted values; any shape.
            targets: ground-truth values with the same number of elements.
            smooth: constant added to numerator and denominator so the ratio
                stays defined when both tensors sum to zero.

        Returns:
            A scalar tensor; 0 means perfect overlap.
        """
        # reshape (not view) so non-contiguous tensors, e.g. permuted or
        # transposed ones, are flattened instead of raising an error.
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)
        intersection = (inputs * targets).sum()
        dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
        return 1 - dice
# Build the model and move it to the selected device.
# NOTE(review): this rebinds the name `unet` from the class to the model
# instance, so running this line a second time calls the instance instead of
# the class; re-run the model-definition cell before re-running this one.
unet = unet().to(device)
# Adam optimizer over all model parameters, learning rate 1e-3.
optimizer = torch.optim.Adam(unet.parameters(), lr=0.001)
# Training criterion: the Dice loss defined above.
criterion = DiceLoss()

In [None]:
# Model training with early stopping on the validation loss
best_loss = np.inf
epochs = 200
patience = 5
batch_size = 20

train_losses = []
val_losses = []
n_epochs = []

total_train = train_inputs.size()[0]
total_val = val_inputs.size()[0]

# Early-stopping state is initialised up front: if the very first validation
# loss were NaN, the "improved" branch would not run (NaN < inf is False) and
# these names would otherwise be undefined.
cost_patience = patience
state_dict = deepcopy(unet.state_dict())

t0 = time()
for epoch in range(epochs):
    n_epochs.append(epoch)

    # ---- training pass ----
    unet.train()
    running_loss = 0.0
    train_perm = torch.randperm(total_train)
    for i in range(0, total_train, batch_size):
        indices = train_perm[i:i+batch_size]
        batch_x, batch_y = train_inputs[indices], train_outputs[indices]
        optimizer.zero_grad()
        outputs = unet(batch_x)
        loss = criterion(outputs, batch_y)
        loss.backward()
        optimizer.step()
        # .item() gives a plain float, so the running total does not keep
        # every batch's autograd graph alive. Weighting by the batch size
        # makes the division below a true per-sample average, including for
        # a smaller final batch.
        running_loss += loss.item() * len(indices)
    running_loss /= total_train
    train_losses.append(running_loss)

    # ---- validation pass ----
    # eval() once per epoch; no_grad() because no backward pass is needed.
    unet.eval()
    running_valloss = 0.0
    with torch.no_grad():
        # Fixed batch order: the Dice loss is computed per batch, so
        # reshuffling would change the score of an unchanged model and add
        # noise to the early-stopping decision.
        for j in range(0, total_val, batch_size):
            batch_x = val_inputs[j:j+batch_size]
            batch_y = val_outputs[j:j+batch_size]
            loss = criterion(unet(batch_x), batch_y)
            running_valloss += loss.item() * batch_x.size(0)
    running_valloss /= total_val
    val_losses.append(running_valloss)

    if running_valloss < best_loss:
        # New best: remember the weights and reset the patience counter.
        best_loss = running_valloss
        cost_patience = patience
        state_dict = deepcopy(unet.state_dict())
        print(f"\tEpoch: {epoch+1}/{epochs}, ", f"Train Loss: {running_loss:.3g}, ", f"Val Loss: {running_valloss:.3g}")
        continue

    cost_patience -= 1
    if cost_patience < 0:
        print(f"\nEarly stopping after {patience} epochs of no improvements")
        break
    print(f"\tEpoch: {epoch+1}/{epochs}, ",f"Train Loss: {running_loss:.3g}, ", f"Val Loss: {running_valloss:.3g} - No improvement", f"-> Remaining patience: {cost_patience}")

# Put the best weights back into the model, so what is used afterwards is the
# checkpoint with the lowest validation loss rather than the last epoch.
unet.load_state_dict(state_dict)
print(f"Training time: {time() - t0:.1f} s")



In [None]:
# Loss curves for training and validation, one point per epoch.
for curve, curve_label in ((train_losses, 'train_loss'), (val_losses, 'val_loss')):
    plt.plot(n_epochs, curve, label=curve_label)
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend(loc='upper center')
# Red vertical marker placed `patience` epochs before the last recorded epoch.
no_improvement_line = n_epochs[-1] - patience
plt.axvline(no_improvement_line, color='r')
plt.show()

In [None]:
# Write the weights captured in `state_dict` (the copy taken at the epoch
# with the lowest validation loss) to model_unet.pth in the working directory.
torch.save(state_dict, "model_unet.pth")

In [None]:
# Prediction on the held-out test set

# Explicit inference mode (BatchNorm uses its running statistics) and no
# autograd: the whole test set goes through in a single forward pass, and
# recording a graph for it would only cost memory.
unet.eval()
with torch.no_grad():
    predictions = unet(test_inputs).cpu().numpy()
images = test_inputs.cpu().detach().numpy()
masks = test_outputs.cpu().detach().numpy()

# First four test samples: input image, ground-truth mask, predicted mask.
f, axis = plt.subplots(nrows=4, ncols=3, constrained_layout=True, figsize=(10, 10))
for i in range(4):
    # Index [0] selects the single channel of each (channel, height, width) sample.
    axis[i, 0].imshow(images[i][0], cmap="gray")
    axis[i, 0].set_title("image")
    axis[i, 1].imshow(masks[i][0], cmap="gray")
    axis[i, 1].set_title("target")
    axis[i, 2].imshow(predictions[i][0], cmap="gray")
    axis[i, 2].set_title("predict")

plt.show()