In [66]:
import os
import datetime

import torch 
import torch.nn as nn

from torch.utils.data import Dataset
from torch.autograd import Variable

from torchvision import datasets, transforms



from natsort import natsorted
from PIL import Image
from skimage import io, transform
from tqdm import tqdm

from torchinfo import summary

In [50]:
!pip3 install scikit-image

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting scikit-image
  Downloading scikit_image-0.17.2-cp36-cp36m-manylinux1_x86_64.whl (12.4 MB)
     |████████████████████████████████| 12.4 MB 4.5 MB/s            
Collecting tifffile>=2019.7.26
  Downloading tifffile-2020.9.3-py3-none-any.whl (148 kB)
     |████████████████████████████████| 148 kB 93.4 MB/s            
Collecting PyWavelets>=1.1.1
  Downloading PyWavelets-1.1.1-cp36-cp36m-manylinux1_x86_64.whl (4.4 MB)
     |████████████████████████████████| 4.4 MB 39.9 MB/s            
Collecting decorator<5,>=4.3
  Downloading decorator-4.4.2-py2.py3-none-any.whl (9.2 kB)
Installing collected packages: decorator, tifffile, PyWavelets, scikit-image
  Attempting uninstall: decorator
    Found existing installation: decorator 5.1.1
    Uninstalling decorator-5.1.1:
      Successfully uninstalled decorator-5.1.1
[31mERROR: pip's dependency resolver does not currently take into account all the packages that a

In [67]:
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet34', pretrained=True)

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


In [68]:
resnet = nn.Sequential(*list(model.children())[:-2])


In [69]:
INPUT_SHAPE = 256

In [70]:
summary(resnet, (1, 3, INPUT_SHAPE, INPUT_SHAPE), depth=6)

Layer (type:depth-idx)                   Output Shape              Param #
Sequential                               --                        --
├─Conv2d: 1-1                            [1, 64, 128, 128]         9,408
├─BatchNorm2d: 1-2                       [1, 64, 128, 128]         128
├─ReLU: 1-3                              [1, 64, 128, 128]         --
├─MaxPool2d: 1-4                         [1, 64, 64, 64]           --
├─Sequential: 1-5                        [1, 64, 64, 64]           --
│    └─BasicBlock: 2-1                   [1, 64, 64, 64]           --
│    │    └─Conv2d: 3-1                  [1, 64, 64, 64]           36,864
│    │    └─BatchNorm2d: 3-2             [1, 64, 64, 64]           128
│    │    └─ReLU: 3-3                    [1, 64, 64, 64]           --
│    │    └─Conv2d: 3-4                  [1, 64, 64, 64]           36,864
│    │    └─BatchNorm2d: 3-5             [1, 64, 64, 64]           128
│    │    └─ReLU: 3-6                    [1, 64, 64, 64]           --
│

In [71]:

class BasicBlockDec(nn.Module):

    def __init__(self, shape):
        super().__init__()
        if shape == 512:
            shape2 = 512
        else:
            shape2 = int(shape * 2)
        
        self.convtrans1 = nn.ConvTranspose2d(shape2, shape, kernel_size=4, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(shape)
        self.convtrans2 = nn.ConvTranspose2d(shape, shape, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(shape)
        
    def forward(self, x):
        out1 = self.convtrans1(x)
        out2 = torch.relu(self.bn1(out1))
        out2 = self.convtrans2(out2)
        out2 = torch.relu(self.bn2(out2))
        final = torch.add(out1, out2)
        
        return final


class ResNet18Dec(nn.Module):

    def __init__(self, num_Blocks=[2,2,2,2], nc=3):
        super().__init__()
        self.layer1 = BasicBlockDec(512)
        self.layer2 = BasicBlockDec(256)
        self.layer3 = BasicBlockDec(128)
        self.layer4 = BasicBlockDec(64)
        self.conv1 = nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1)

#         self.conv1 = ResizeConv2d(64, nc, kernel_size=3, scale_factor=2)

    def _make_layer(self, BasicBlockDec, shape):
        return 
        strides = [stride] + [1]*(num_Blocks-1)
        layers = []
        for stride in reversed(strides):
            layers += [BasicBlockDec(self.in_planes, stride)]
        self.in_planes = planes
        return nn.Sequential(*layers)

    def forward(self, x):

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = torch.sigmoid(self.conv1(x))
        return x

class AutoEncoder(nn.Module):

    def __init__(self):
        super().__init__()
        self.encoder = resnet
        self.decoder = ResNet18Dec()
        self.conv1 = nn.Conv2d(512, 512, kernel_size=1, stride=1)
        self.conv2 = nn.Conv2d(512, 512, kernel_size=1, stride=1)

    def forward(self, x):
#         mean, logvar = self.encoder(x)
        x = self.encoder(x)
        x = torch.relu(self.conv1(x))
#         x = torch.relu(self.conv2(x))
#         z = self.reparameterize(mean, logvar)
        x = self.decoder(x)
        return x
    


In [72]:
ae = AutoEncoder().cuda()
summary(ae, (1, 3, INPUT_SHAPE, INPUT_SHAPE), depth=6)

Layer (type:depth-idx)                        Output Shape              Param #
AutoEncoder                                   --                        --
├─Sequential: 1-1                             [1, 512, 8, 8]            --
│    └─Conv2d: 2-1                            [1, 64, 128, 128]         9,408
│    └─BatchNorm2d: 2-2                       [1, 64, 128, 128]         128
│    └─ReLU: 2-3                              [1, 64, 128, 128]         --
│    └─MaxPool2d: 2-4                         [1, 64, 64, 64]           --
│    └─Sequential: 2-5                        [1, 64, 64, 64]           --
│    │    └─BasicBlock: 3-1                   [1, 64, 64, 64]           --
│    │    │    └─Conv2d: 4-1                  [1, 64, 64, 64]           36,864
│    │    │    └─BatchNorm2d: 4-2             [1, 64, 64, 64]           128
│    │    │    └─ReLU: 4-3                    [1, 64, 64, 64]           --
│    │    │    └─Conv2d: 4-4                  [1, 64, 64, 64]           36,864
│    │ 

In [73]:
IMAGE_SIZE = 256
BATCH_SIZE = 64
EPOCHS = 75
LR = 0.0004

In [74]:
class LoadFromFolder(Dataset):
    def __init__(self, main_dir, transform):
        self.main_dir = main_dir
        self.transform = transform
        all_imgs = natsorted(os.listdir(main_dir))
        self.all_imgs_name = natsorted(all_imgs)
        self.imgs_loc = [os.path.join(self.main_dir, i) for i in self.all_imgs_name]

    def __len__(self):
        return len(self.all_imgs_name)
    
    def load_image(self, path):
        image = Image.open(path).convert("RGB")
        tensor_image = self.transform(image)
        return tensor_image
    
    def __getitem__(self, idx):
        
        # 後ほどsliceで画像を複数枚取得したいのでsliceでも取れるようにする
        if type(idx) == slice:
            paths = self.imgs_loc[idx]
            tensor_image = [self.load_image(path) for path in paths]
            tensor_image = torch.cat(tensor_image).reshape(len(tensor_image), *tensor_image[0].shape)
        elif type(idx) == int:
            path = self.imgs_loc[idx]
            tensor_image = self.load_image(path)
        return tensor_image

In [75]:
transform_dict = {
    "train": transforms.Compose(
        [
            transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),  # IMAGE_SIZEにreshape
            transforms.ToTensor(),
        ]
    ),
    "test": transforms.Compose(
        [
            transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),  # IMAGE_SIZEにreshape
            transforms.ToTensor(),
        ]
    ),
}

In [80]:
# train_root = './cap_dataset/white_omote_crop/train/good/'  # train dataの保存してあるディレクトリ
train_root = './cap_dataset/white_omote_crop_shadow/train/good/'  # train dataの保存してあるディレクトリ

train_dataset = LoadFromFolder(train_root, transform=transform_dict["train"])

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [81]:
model = AutoEncoder().cuda()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())

In [82]:
def train(model,dataloader,otpimizer,criterion,num_epochs):

    for epoch in range(1,num_epochs+1):
                
        total_loss = 0
        corrects = 0
        total = 0

        with tqdm(total=len(dataloader),unit="batch") as pbar:
            pbar.set_description(f"Epoch[{epoch}/{num_epochs}]")
            for imgs in dataloader: 
                imgs = Variable(imgs).cuda()
                output = model(imgs)
                loss = criterion(output, imgs)
                
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total += imgs.size(0)


                total_loss += loss.data / total
                pbar.set_postfix({"loss":total_loss.item()})
                pbar.update(1)

In [83]:
EPOCHS = 100
train(model, train_loader, optimizer, criterion, EPOCHS)

Epoch[1/100]: 100%|██████████| 6/6 [00:11<00:00,  1.87s/batch, loss=0.00387]
Epoch[2/100]: 100%|██████████| 6/6 [00:09<00:00,  1.57s/batch, loss=0.000859]
Epoch[3/100]: 100%|██████████| 6/6 [00:09<00:00,  1.57s/batch, loss=0.000355]
Epoch[4/100]: 100%|██████████| 6/6 [00:09<00:00,  1.58s/batch, loss=0.000207]
Epoch[5/100]: 100%|██████████| 6/6 [00:09<00:00,  1.57s/batch, loss=0.000156]
Epoch[6/100]: 100%|██████████| 6/6 [00:09<00:00,  1.57s/batch, loss=0.000212]
Epoch[7/100]: 100%|██████████| 6/6 [00:09<00:00,  1.59s/batch, loss=0.000171]
Epoch[8/100]: 100%|██████████| 6/6 [00:09<00:00,  1.58s/batch, loss=0.000136]
Epoch[9/100]: 100%|██████████| 6/6 [00:09<00:00,  1.58s/batch, loss=0.000122]
Epoch[10/100]: 100%|██████████| 6/6 [00:09<00:00,  1.59s/batch, loss=0.000111]
Epoch[11/100]: 100%|██████████| 6/6 [00:09<00:00,  1.58s/batch, loss=0.000108]
Epoch[12/100]: 100%|██████████| 6/6 [00:09<00:00,  1.59s/batch, loss=8.94e-5]
Epoch[13/100]: 100%|██████████| 6/6 [00:09<00:00,  1.61s/batch,

In [84]:
today = datetime.datetime.today()
today = today.strftime('%m%d%H%M')
pkl_path = "resnet34AE_{}_{}epoch.pkl".format(today, EPOCHS)

In [85]:
torch.save(model.state_dict(), pkl_path)