## Convolutional Auto Encoder

In [57]:
# packageのimport
from typing import Any, Union, Callable, Type, TypeVar
from tqdm.std import trange,tqdm
import numpy as np 
import numpy.typing as npt
import pandas as pd 
import matplotlib.pyplot as plt 
import plotly.express as px
import seaborn as sns
from PIL import Image
import cv2
import requests

# pytorch関連のimport
import torch
import torch.nn as nn 
import torch.nn.functional as F 
import torch.optim as optim 
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader
from src.utils import set_seed
from src.layers import Reshape

In [7]:
SEED = 2023_6_27
set_seed(SEED)

### AutoEncoderに畳み込みレイヤを利用しよう

MNISTを学習する例ではデータ行列をflattenすることで通常のMLP構造のAEで画像データへ対応しました．AEで次元削減する対象は画像データだけではないですが，画像に対して利用する場合はやはり，畳み込みレイヤを利用するのが一般的です．

今回はデータセットとしてCIFAR-10を利用します．

In [25]:
train_data = torchvision.datasets.CIFAR10(
    './data',
    train=True,
    download=True,
    transform = T.Compose([T.ToTensor()]),
    )
test_data = torchvision.datasets.CIFAR10(
    './data',
    train=False,
    download=True,
    transform = T.Compose([T.ToTensor()]),
    )

train_loader = DataLoader(train_data, 64, shuffle=True)

Files already downloaded and verified
Files already downloaded and verified


In [56]:
encoder = nn.Sequential(
    nn.Conv2d(3, 16, 3, 1, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
    nn.MaxPool2d(2,2),
    nn.Conv2d(16, 8, 3, 1, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.MaxPool2d(2,2),
    nn.Flatten(),
    nn.Linear(8*8*8, 50),
    nn.ReLU(),
)
print("-----input-----")
batch = next(iter(train_loader))
x,y = batch
print(x.shape)
print(x.shape)

print("-----output-----")
y = encoder(x)
print(y.shape)

-----input-----
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 32, 32])
-----output-----
torch.Size([64, 50])


In [63]:
decoder = nn.Sequential(
    nn.Linear(50, 8*8*8),
    nn.ReLU(),
    Reshape((8,8,8)),
    nn.ConvTranspose2d(8, 16, 3, 1, padding=1),
    nn.ReLU(),
    nn.ConvTranspose2d(16, 3, 3, 1, padding=1),
    nn.Sigmoid(),
)
decoder(y).shape

torch.Size([64, 3, 8, 8])

In [64]:
convae = nn.Sequential(encoder, decoder)

In [65]:
import skorch

ModuleNotFoundError: No module named 'skorch'