<a href="https://colab.research.google.com/github/satake12345/pytorch_auto_encoder/blob/main/kobe.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive into the Colab runtime so the image folders used by the
# cells below are reachable under /content/drive.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [16]:
# cd "/content/drive/MyDrive/Colab Notebooks/imgs"
# cd "/content/drive/MyDrive/Colab Notebooks/imgs/create_img"
# rm -r *.jpg
# rm -r *.zip

In [None]:
!bash

In [None]:
!pwd

In [None]:
import os
import glob

# Root folder of the "bottle" images on the mounted Google Drive.
DATA_ROOT = "/content/drive/MyDrive/Colab Notebooks/imgs/bottle"


def _list_files(*parts):
    """Return the sorted paths of every entry directly inside DATA_ROOT/<parts>.

    glob.glob returns matches in an arbitrary, file-system dependent order.
    The cells below split these lists by position, so the result is sorted to
    make every run select the same files.
    """
    return sorted(glob.glob(os.path.join(DATA_ROOT, *parts, '*')))


# Normal images used for training
good_list = _list_files("train", "good")

# Normal images used for evaluation
good_test_list = _list_files("test", "good")

# Defective images used for evaluation (all three defect folders combined)
bad_test_list = [
    path
    for defect in ("broken_large", "broken_small", "contamination")
    for path in _list_files("test", defect)
]

# Report how many normal / defective images were found
print(f"good {len(good_list)} good_test {len(good_test_list)} bad {len(bad_test_list)}")


In [None]:
# Move the first HOLDOUT_COUNT normal training images into the normal test set.
# NOTE(review): 43 presumably balances the normal and defective test sets - confirm.
# Re-running this cell moves another HOLDOUT_COUNT images each time.
HOLDOUT_COUNT = 43
good_test_list.extend(good_list[:HOLDOUT_COUNT])
good_list = good_list[HOLDOUT_COUNT:]
print(f"good {len(good_list)} good_test {len(good_test_list)} bad {len(bad_test_list)}")


In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from PIL import Image
import math

def view_img(img_list,title = "",size = (15,15)):
    """Display the images at the given paths in a two-column grid.

    img_list: sequence of image file paths; each path is also shown as the
              title of its subplot.
    title:    title placed above the whole figure.
    size:     figure size handed to plt.figure(figsize=...).
    """
    figure = plt.figure(figsize=size)
    figure.suptitle(title)
    # Two images per row, rounding up so an odd count still gets a last row.
    n_rows = math.ceil(len(img_list) / 2)
    for position, img_path in enumerate(img_list, start=1):
        axis = figure.add_subplot(n_rows, 2, position)
        axis.imshow(Image.open(img_path))
        axis.set_title(img_path)
        axis.axis("off")
    
# Preview the first four normal images, then the first four defective ones
for caption, sample_paths in (("good", good_list), ("bad", bad_test_list)):
    view_img(sample_paths[:4], title=caption)

In [7]:
import torchvision.transforms as T
from torch.utils.data import DataLoader,Dataset
from PIL import Image

# Dataset definition
class Custom_Dataset(Dataset):
  """Dataset that loads image files and returns them as 128x128 RGB tensors.

  img_list: sequence of image file paths.
  Each item is the image resized to 128x128 and converted with T.ToTensor().
  No label is returned: the auto-encoder uses the image itself as the target.
  """
  def __init__(self,img_list):
    self.img_list = img_list
    # Attribute name (including its spelling) is kept for backward compatibility.
    self.prepocess = T.Compose([T.Resize((128,128)),
                                T.ToTensor(),
                                ])

  def __getitem__(self,idx):
    # Open inside a context manager so the file handle is always released.
    with Image.open(self.img_list[idx]) as img:
      # Force 3 channels: grayscale / RGBA / palette files would otherwise
      # produce tensors that do not fit the 3-channel model input.
      tensor = self.prepocess(img.convert("RGB"))
    return tensor

  def __len__(self):
    return len(self.img_list)

# Split the normal images 80:20 into training and validation subsets
split_at = int(len(good_list)*0.8)
train_list, val_list = good_list[:split_at], good_list[split_at:]

# Wrap each subset in a dataset and serve it in batches of 32
train_dataset, val_dataset = Custom_Dataset(train_list), Custom_Dataset(val_list)
train_loader = DataLoader(train_dataset, batch_size=32)
val_loader = DataLoader(val_dataset, batch_size=32)

In [8]:
import torch.nn as nn
import torchvision

class CustomModel(nn.Module):
    """Convolutional auto-encoder for 3-channel images.

    The encoder alternates convolution blocks with 2x2 max pooling (five
    poolings, so height and width are halved five times) while widening the
    channels 3 -> 16 -> 32 -> 64 -> 128 -> 256 -> 512.  The decoder mirrors it
    with five stride-2 transposed convolutions, each followed by a convolution
    block, narrowing the channels back to 16.  A final 1x1 convolution maps to
    3 output channels; no activation is applied to the output.
    """

    def __init__(self):
        super().__init__()

        # Encoder: conv block, then (pool + conv block) for every further width.
        encoder_widths = [3, 16, 32, 64, 128, 256, 512]
        encoder_layers = []
        for stage, (c_in, c_out) in enumerate(zip(encoder_widths, encoder_widths[1:])):
            if stage > 0:
                encoder_layers.append(nn.MaxPool2d(kernel_size=(2, 2)))
            encoder_layers.append(self.create_convblock(c_in, c_out))
        self.Encoder = nn.Sequential(*encoder_layers)

        # Decoder: (transposed conv block + conv block) for every width step.
        decoder_widths = [512, 256, 128, 64, 32, 16]
        decoder_layers = []
        for c_in, c_out in zip(decoder_widths, decoder_widths[1:]):
            decoder_layers.append(self.create_deconvblock(c_in, c_out))
            decoder_layers.append(self.create_convblock(c_out, c_out))
        self.Decoder = nn.Sequential(*decoder_layers)

        # Final adjustment before output: 1x1 convolution from 16 to 3 channels.
        self.last_layer = nn.Conv2d(16, 3, kernel_size=1, stride=1)

    def create_convblock(self,i_fn,o_fn):
        """Build two (3x3 conv -> batch norm -> ReLU) stages.

        i_fn: input channels of the first convolution.
        o_fn: output channels of both convolutions.
        Padding 1 with stride 1 keeps the spatial size unchanged.
        """
        stages = [
            nn.Conv2d(i_fn, o_fn, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(o_fn),
            nn.ReLU(),
            nn.Conv2d(o_fn, o_fn, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(o_fn),
            nn.ReLU(),
        ]
        return nn.Sequential(*stages)

    def create_deconvblock(self,i_fn , o_fn):
        """Build a 2x2, stride-2 transposed convolution -> batch norm -> ReLU.

        i_fn: input channels, o_fn: output channels.
        """
        stages = [
            nn.ConvTranspose2d(i_fn, o_fn, kernel_size=2, stride=2),
            nn.BatchNorm2d(o_fn),
            nn.ReLU(),
        ]
        return nn.Sequential(*stages)

    def forward(self,x):
        """Encode x, decode it again and return the 3-channel reconstruction."""
        encoded = self.Encoder(x)
        decoded = self.Decoder(encoded)
        return self.last_layer(decoded)

In [None]:
from tqdm import tqdm
import torch.optim as optim
import torch.nn.functional as F
import torch

# ---- Training configuration ----
epoch_num = 1000                 # maximum number of epochs
limit_epoch = 100                # early-stopping patience: epochs without a better validation loss
model_path = 'kobe_model.pth'    # where the best model so far is saved

# Use the GPU when one is available, otherwise fall back to the CPU so the
# cell still runs on a CPU-only runtime.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = CustomModel().to(device)
optimizer = optim.Adam(model.parameters())
# The input image itself is the reconstruction target.
criterion = nn.MSELoss()
# Alternative loss kept for reference:
#criterion = nn.BCEWithLogitsLoss()

# Mean per-batch loss of every epoch, one entry per epoch for both phases
loss_list = {"train":[],"val":[]}

best_loss = None    # lowest mean validation loss seen so far
counter = 0         # consecutive epochs without improvement
for e in range(epoch_num):
    # ---- training pass ----
    total_loss = 0
    train_batches = 0
    model.train()
    with tqdm(train_loader) as pbar:
        for itr , data in enumerate(pbar):
            optimizer.zero_grad()
            data = data.to(device)
            output = model(data)
            loss = criterion(output , data)
            total_loss += loss.item()
            train_batches = itr + 1
            pbar.set_description(f"[train] Epoch {e+1:03}/{epoch_num:03} Itr {itr+1:02}/{len(pbar):02} Loss {total_loss/(itr+1):.3f}")
            loss.backward()
            optimizer.step()
    # max(..., 1) guards against an empty loader
    loss_list["train"].append(total_loss / max(train_batches, 1))

    # ---- validation pass (no gradients needed) ----
    total_loss = 0
    val_batches = 0
    model.eval()
    with torch.no_grad(), tqdm(val_loader) as pbar:
        for itr , data in enumerate(pbar):
            data = data.to(device)
            output = model(data)
            loss = criterion(output , data)
            total_loss += loss.item()
            val_batches = itr + 1
            pbar.set_description(f"[ val ] Epoch {e+1:03}/{epoch_num:03} Itr {itr+1:02}/{len(pbar):02} Loss {total_loss/(itr+1):.3f}")
    val_loss = total_loss / max(val_batches, 1)
    # Record before the early-stopping check so the stopping epoch is logged too.
    loss_list["val"].append(val_loss)

    # ---- checkpointing and early stopping ----
    if best_loss is None or val_loss < best_loss:
        if best_loss is not None:
            print(f"update best_loss {best_loss:.6f} to {val_loss:.6f}")
        best_loss = val_loss
        torch.save(model.state_dict(), model_path)
        counter = 0
    else:
        counter += 1
        if limit_epoch <= counter:
            break

In [None]:
import numpy as np
from PIL import Image


import os

# cv2 is needed for the colour map and for writing the images; it is not
# imported anywhere else in the visible notebook.
import cv2

# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = CustomModel().to(device)

model_path = 'kobe_model.pth'
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only runtime.
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

margin_w = 10                   # width in pixels of the white gap between panels
output_dir = "./create_img"
# Make sure the output folder exists before any image is written into it.
os.makedirs(output_dir, exist_ok=True)

# Same preprocessing as Custom_Dataset
prepocess = T.Compose([T.Resize((128,128)),
                                T.ToTensor(),
                                ])

loss_list = []   # per-image anomaly score: sum of absolute pixel differences
labels = [0]*len(good_test_list) + [1]*len(bad_test_list)   # 0 = good, 1 = bad
for idx , path in enumerate(tqdm(good_test_list + bad_test_list)):

    with Image.open(path) as pil_img:
        # Force 3 channels so every file matches the model input.
        img = prepocess(pil_img.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(img)[0]

    # CHW float tensors -> HWC uint8 images. The reconstruction is clipped
    # because the network output is not bounded to [0, 1].
    output = output.cpu().numpy().transpose(1,2,0)
    output = np.uint8(np.clip(output*255, 0, 255))
    origin = np.uint8(img[0].cpu().numpy().transpose(1,2,0)*255)

    # Absolute difference computed in float to avoid uint8 wrap-around
    diff = np.uint8(np.abs(output.astype(np.float32) - origin.astype(np.float32)))
    loss_list.append(np.sum(diff))
    heatmap = cv2.applyColorMap(diff , cv2.COLORMAP_JET)
    # uint8 so the concatenated image keeps the uint8 dtype of the other panels
    margin = np.full((diff.shape[0], margin_w, 3), 255, dtype=np.uint8)

    # Panels: original | reconstruction | heat map.
    # [:,:,::-1] reverses the channel axis (RGB -> BGR, the order cv2 writes).
    result = np.concatenate([origin[:,:,::-1],margin,output[:,:,::-1],margin,heatmap],axis = 1)
    label = 'good' if labels[idx] == 0 else 'bad'
    out_path = f"{output_dir}/{idx}_{label}.jpg"
    # cv2.imwrite reports failure through its return value instead of raising.
    if not cv2.imwrite(out_path, result):
        tqdm.write(f"failed to write {out_path}")

In [None]:
# Marks the end of the evaluation run in the notebook output
print("finish.")

In [None]:
# cd "/content/drive/MyDrive/Colab Notebooks/imgs"
# cd "/content/drive/MyDrive/Colab Notebooks/imgs/create_img"
# rm -r *.jpg
# rm -r *.zip

In [None]:
# !bash