In [1]:
# パッケージのimport
import random
import math
import time
import glob
import pandas as pd
import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torch.utils.data as data
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchvision import transforms

import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
# Setup seeds
torch.manual_seed(1234)
np.random.seed(1234)
random.seed(1234)

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
class gen_lyr:
    def input_lyr(z_dim, image_size):
        lyr = nn.Sequential(
            nn.ConvTranspose2d(z_dim, image_size ** 2 // 8, 
                              kernel_size=4, stride=1),
            nn.BatchNorm2d(image_size ** 2 // 8),
            nn.ReLU(inplace=True))
        return lyr

    def hdn_lyr(image_size, i):
        lyr = nn.Sequential(
            nn.ConvTranspose2d(image_size * (2 ** (i+1)),
                               image_size * (2 ** i), 
                               kernel_size=4, 
                               stride=2, padding=1),
            nn.BatchNorm2d(image_size * (2 ** i)),
            nn.ReLU(inplace=True))
        return lyr

    def output_lyr(image_size):
        lyr = nn.Sequential(
            nn.ConvTranspose2d(image_size, 1, kernel_size=4,
                               stride=2, padding=1),
                nn.Tanh())
        return lyr

class Generator(nn.Module):
    def __init__(self, z_dim=20, image_size=64):
        super().__init__()
        
        self.input_lyr = gen_lyr.input_lyr(z_dim, image_size)
        self.hdn_lyr2 = gen_lyr.hdn_lyr(image_size, 2)
        self.hdn_lyr1 = gen_lyr.hdn_lyr(image_size, 1)
        self.hdn_lyr0 = gen_lyr.hdn_lyr(image_size, 0)
        self.output_lyr = gen_lyr.output_lyr(image_size)
        
    def forward(self, z):
        out = self.input_lyr(z)
        out = self.hdn_lyr2(out)
        out = self.hdn_lyr1(out)
        out = self.hdn_lyr0(out)
        out = self.output_lyr(out)

        return out
class dis_lyr:
    def input_lyr(z_dim, image_size):
        lyr = nn.Sequential(
            nn.Conv2d(1, image_size, kernel_size=4,
                     stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True))
        return lyr
    
    def hdn_lyr(image_size, i):
        lyr = nn.Sequential(
            nn.Conv2d(image_size * (2 ** i),
                      image_size * (2 ** (i+1)),
                      kernel_size=4,
                      stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True))
        return lyr
    
    def output_lyr(image_size):
        lyr = nn.Sequential(
            nn.Conv2d(image_size ** 2 // 8, 1,
                      kernel_size=4, stride=1))
        return lyr
    
class Discriminator(nn.Module):
    def __init__(self, z_dim=20, image_size=64):
        super().__init__()
        self.input_lyr = dis_lyr.input_lyr(z_dim, image_size)
        self.hdn_lyr0 = dis_lyr.hdn_lyr(image_size, 0)
        self.hdn_lyr1 = dis_lyr.hdn_lyr(image_size, 1)
        self.hdn_lyr2 = dis_lyr.hdn_lyr(image_size, 2)
        self.output_lyr = dis_lyr.output_lyr(image_size)
        
    def forward(self, x):
        out = self.input_lyr(x)
        out = self.hdn_lyr0(out)
        out = self.hdn_lyr1(out)
        out = self.hdn_lyr2(out)

        feature = out  # 最後にチャネルを1つに集約する手前の情報
        feature = feature.view(feature.size()[0], -1)  # 2次元に変換

        out = self.output_lyr(out)

        return out, feature

In [5]:
G = Generator(z_dim=10, image_size=64)
D = Discriminator(z_dim=10, image_size=64)
G.to(device)
D.to(device)

Discriminator(
  (input_lyr): Sequential(
    (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.1, inplace=True)
  )
  (hdn_lyr0): Sequential(
    (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.1, inplace=True)
  )
  (hdn_lyr1): Sequential(
    (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.1, inplace=True)
  )
  (hdn_lyr2): Sequential(
    (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.1, inplace=True)
  )
  (output_lyr): Sequential(
    (0): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1))
  )
)

In [6]:
# 使用するモデル
G_load_path = "/mnt/home/atsuk/Desktop/pytorch_advanced/6_gan_anomaly_detection/G_ano_dc_64_10_30ep.pth"
D_load_path = "/mnt/home/atsuk/Desktop/pytorch_advanced/6_gan_anomaly_detection/D_ano_dc_64_10_30ep.pth"

In [7]:
# GPU上で保存された重みをCPU上でロードする場合
G_load_weights = torch.load(G_load_path, map_location={'cuda:0': 'cpu'})
G.load_state_dict(G_load_weights)
D_load_weights = torch.load(D_load_path, map_location={'cuda:0': 'cpu'})
D.load_state_dict(D_load_weights)

<All keys matched successfully>

In [8]:
def Anomaly_score(x, fake_img, D, Lambda=0.1):

    # テスト画像xと生成画像fake_imgのピクセルレベルの差の絶対値を求めて、ミニバッチごとに和を求める
    residual_loss = torch.abs(x-fake_img)
    residual_loss = residual_loss.view(residual_loss.size()[0], -1)
    residual_loss = torch.sum(residual_loss, dim=1)

    # テスト画像xと生成画像fake_imgを識別器Dに入力し、特徴量を取り出す
    _, x_feature = D(x)
    _, G_feature = D(fake_img)

    # テスト画像xと生成画像fake_imgの特徴量の差の絶対値を求めて、ミニバッチごとに和を求める
    discrimination_loss = torch.abs(x_feature-G_feature)
    discrimination_loss = discrimination_loss.view(
        discrimination_loss.size()[0], -1)
    discrimination_loss = torch.sum(discrimination_loss, dim=1)

    # ミニバッチごとに2種類の損失を足し算する
    loss_each = (1-Lambda)*residual_loss + Lambda*discrimination_loss

    # ミニバッチ全部の損失を求める
    total_loss = torch.sum(loss_each)

    return total_loss, loss_each, residual_loss

In [9]:
def make_datapath_list(path, i):
    """学習、検証の画像データとアノテーションデータへのファイルパスリストを作成する。 """
    til = glob.glob(path)
    random.shuffle(til)
    if i == "a":
        train_img_list = til[:]
    else:
        train_img_list = til[:i]
    return train_img_list

In [10]:
class ImageTransform():
    """画像の前処理クラス"""

    def __init__(self, resize):
        self.data_transform = transforms.Compose([
            transforms.Grayscale(),
            transforms.Resize(resize),  # リサイズ
            transforms.ToTensor(),
            # transforms.Normalize(mean, std)
        ])

    def __call__(self, img):
        return self.data_transform(img)

In [11]:
class GAN_Img_Dataset(data.Dataset):
    """画像のDatasetクラス。PyTorchのDatasetクラスを継承"""

    def __init__(self, file_list, transform):
        self.file_list = file_list
        self.transform = transform

    def __len__(self):
        '''画像の枚数を返す'''
        return len(self.file_list)

    def __getitem__(self, index):
        '''前処理をした画像のTensor形式のデータを取得'''

        img_path = self.file_list[index]
        img = Image.open(img_path)  # [高さ][幅]白黒

        # 画像の前処理
        img_transformed = self.transform(img)

        return img_transformed

In [12]:
# DataLoaderの作成と動作確認

# ファイルリストを作成
num_of_data = "a"
# path = '/media/tamahassam/Extreme SSD/normal_ct_visual/*.tif'
path = '/mnt/home/atsuk/Desktop/abnormal_detection_data/covid_ct_good3_final/*.png'
test_img_list=make_datapath_list(path, num_of_data)
random.shuffle(test_img_list)

# 異常のDataLoaderの作成

# Datasetを作成
# mean = (0.5,)
# std = (0.5,)
resize = 64
test_dataset = GAN_Img_Dataset(
    file_list= test_img_list, transform=ImageTransform(resize))

# DataLoaderを作成
batch_size = 1

test_dataloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False)


In [13]:
test_img_list[:5]

['/mnt/home/atsuk/Desktop/abnormal_detection_data/covid_ct_good3_final/covid_59.png',
 '/mnt/home/atsuk/Desktop/abnormal_detection_data/covid_ct_good3_final/covid_9118.png',
 '/mnt/home/atsuk/Desktop/abnormal_detection_data/covid_ct_good3_final/covid_8374.png',
 '/mnt/home/atsuk/Desktop/abnormal_detection_data/covid_ct_good3_final/covid_15403.png',
 '/mnt/home/atsuk/Desktop/abnormal_detection_data/covid_ct_good3_final/covid_14795.png']

In [14]:
import csv
with open('/mnt/home/atsuk/Desktop/abnormal_detection_data/abnormal_path_30_a.csv', 'w') as f:
    writer = csv.writer(f)
    writer.writerow(test_img_list)

In [15]:
from csv import reader

with open('/mnt/home/atsuk/Desktop/abnormal_detection_data/abnormal_path_30_a.csv', 'r') as csv_file:
    csv_reader = reader(csv_file)
    # Passing the cav_reader object to list() to get a list of lists
    list_of_rows = list(csv_reader)
    print(list_of_rows[0][:5])
    print(len(list_of_rows[0]))
    # print(type(csv_reader))

['/mnt/home/atsuk/Desktop/abnormal_detection_data/covid_ct_good3_final/covid_59.png', '/mnt/home/atsuk/Desktop/abnormal_detection_data/covid_ct_good3_final/covid_9118.png', '/mnt/home/atsuk/Desktop/abnormal_detection_data/covid_ct_good3_final/covid_8374.png', '/mnt/home/atsuk/Desktop/abnormal_detection_data/covid_ct_good3_final/covid_15403.png', '/mnt/home/atsuk/Desktop/abnormal_detection_data/covid_ct_good3_final/covid_14795.png']
11008


In [None]:
losses = []

for i in tqdm(range(len(test_img_list))):
# for i in tqdm(range(2500)):
    # 異常検知したい画像
    # テストデータの確認
    batch_iterator = iter(test_dataloader)  # イテレータに変換
    imges = next(batch_iterator)  

    x = imges[[0]]
    x = x.to(device)

    # 異常検知したい画像を生成するための、初期乱数
    z = torch.randn(1, 10).to(device)
    z = z.view(z.size(0), z.size(1), 1, 1)

    # 変数zを微分を求めることが可能なように、requires_gradをTrueに設定
    z.requires_grad = True

    # 変数zを更新できるように、zの最適化関数を求める
    z_optimizer = torch.optim.Adam([z], lr=1e-3)

    # zを求める
    for epoch in range(5000+1):
        fake_img = G(z)
        loss, _, _ = Anomaly_score(x, fake_img, D, Lambda=0.1)

        z_optimizer.zero_grad()
        loss.backward()
        z_optimizer.step()

        # if epoch % 1000 == 0:
        #     print('epoch {} || loss_total:{:.0f} '.format(epoch, loss.item()))
        if epoch == 5000:
            print(f'loss: {round(loss.item())}')
            losses.append(round(loss.item()))

  0%|          | 1/11008 [03:24<624:31:11, 204.26s/it]

loss: 690


In [None]:
import csv
with open('/mnt/home/atsuk/Desktop/abnormal_detection_data/30_abnormal_losses_a.csv', 'w') as f:
    writer = csv.writer(f)
    writer.writerow(losses)