MNIST codeをworm2vec用に変更

In [1]:
# License: BSD
# Author: Ghassen Hamrouni

from __future__ import print_function

import os
import glob

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import transforms

import matplotlib.pyplot as plt
from skimage.io import imread
from PIL import Image

import numpy as np

plt.ion()   # interactive mode

## Define Dataset

In [2]:
class WormDataset(torch.utils.data.Dataset):

    training_dir = '201302081337/main'
    test_dir = '201302081353/main'

    def __init__(self, root, train=True, transform=None):

        self.root = root    # root_dir \Tanimoto_eLife_Fig3B or \unpublished control
        self.train = train  # training set or test set
        self.transform = transform

        if not self._check_exists():
            raise RuntimeError('Dataset not found.')

        if self.train:
            data_dir = self.training_dir
        else:
            data_dir = self.test_dir

        self.data = glob.glob(self.root + data_dir + "/*")
        self.targets = self.data.copy()

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where image == target.
        """
        img, target = self.data[index], self.targets[index]

        img = Image.open(img)
        target = Image.open(target)

        if self.transform is not None:
            img = self.transform(img)
            target = self.transform(target)

        return img, target

    def __len__(self):
        return len(self.data)


    def _check_exists(self):
        print(self.root + self.training_dir)
        return (os.path.exists(self.root + self.training_dir) and
                os.path.exists(self.root + self.test_dir))


## transform

In [3]:
worm_transforms = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor()])

## Dataset

In [4]:
train_set = WormDataset(root="F:\Tanimoto_eLife_Fig3B\\", train=True,
    transform=worm_transforms)

test_set = WormDataset(root="F:\Tanimoto_eLife_Fig3B\\", train=False,
    transform=worm_transforms)


F:\Tanimoto_eLife_Fig3B\201302081337/main
F:\Tanimoto_eLife_Fig3B\201302081337/main


## Dataloader

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training dataset
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=256, shuffle=True)
# Test dataset
test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=256, shuffle=True)

## Define model

In [6]:
"""
vae model from https://github.com/podgorskiy/VAE
"""
class VAE(nn.Module):
    def __init__(self, zsize, layer_count=3, channels=3):
        super(VAE, self).__init__()

        d = 128
        self.d = d
        self.zsize = zsize

        self.layer_count = layer_count

        mul = 1
        inputs = channels
        for i in range(self.layer_count):
            setattr(self, "conv%d" % (i + 1), nn.Conv2d(inputs, d * mul, 3, 2, 1))
            setattr(self, "conv%d_bn" % (i + 1), nn.BatchNorm2d(d * mul))
            inputs = d * mul
            mul *= 2

        self.d_max = inputs

        self.fc1 = nn.Linear(inputs * 4 * 4, zsize)
        self.fc2 = nn.Linear(inputs * 4 * 4, zsize)

        self.d1 = nn.Linear(zsize, inputs * 4 * 4)

        mul = inputs // d // 2

        for i in range(1, self.layer_count):
            setattr(self, "deconv%d" % (i + 1), nn.ConvTranspose2d(inputs, d * mul, 3, 2, 1, 1))
            setattr(self, "deconv%d_bn" % (i + 1), nn.BatchNorm2d(d * mul))
            inputs = d * mul
            mul //= 2

        setattr(self, "deconv%d" % (self.layer_count + 1), nn.ConvTranspose2d(inputs, channels, 3, 2, 1, 1))

    def encode(self, x):

        for i in range(self.layer_count):
            x = F.relu(getattr(self, "conv%d_bn" % (i + 1))(getattr(self, "conv%d" % (i + 1))(x)))

        x = x.view(x.shape[0], self.d_max * 4 * 4)
        h1 = self.fc1(x)
        h2 = self.fc2(x)
        return h1, h2

    def reparameterize(self, mu, logvar):
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return eps.mul(std).add_(mu)
        else:
            return mu

    def decode(self, x):
        x = x.view(x.shape[0], self.zsize)
        x = self.d1(x)
        x = x.view(x.shape[0], self.d_max, 4, 4)
        #x = self.deconv1_bn(x)
        x = F.leaky_relu(x, 0.2)

        for i in range(1, self.layer_count):
            x = F.leaky_relu(getattr(self, "deconv%d_bn" % (i + 1))(getattr(self, "deconv%d" % (i + 1))(x)), 0.2)

        x = torch.tanh(getattr(self, "deconv%d" % (self.layer_count + 1))(x))
#        x = F.tanh(getattr(self, "deconv%d" % (self.layer_count + 1))(x))
        return x

    def forward(self, x):
        mu, logvar = self.encode(x)
        mu = mu.squeeze()
        logvar = logvar.squeeze()
        z = self.reparameterize(mu, logvar)
        return self.decode(z.view(-1, self.zsize, 1, 1)), mu, logvar

    def weight_init(self, mean, std):
        for m in self._modules:
            normal_init(self._modules[m], mean, std)


def normal_init(m, mean, std):
    if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d):
        m.weight.data.normal_(mean, std)
        m.bias.data.zero_()


In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

z_size = 64
vae = VAE(zsize=z_size, layer_count=4, channels=1)
if device == "cuda":
    vae.cuda()
#vae.weight_init(mean=0, std=0.02)
vae.to(device)


VAE(
  (conv1): Conv2d(1, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv1_bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv2_bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv3_bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4): Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv4_bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc1): Linear(in_features=16384, out_features=64, bias=True)
  (fc2): Linear(in_features=16384, out_features=64, bias=True)
  (d1): Linear(in_features=64, out_features=16384, bias=True)
  (deconv2): ConvTranspose2d(1024, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1

## training the model

In [11]:
optimizer = optim.SGD(vae.parameters(), lr=0.01)


def train(epoch):

    def loss_function(recon_x, x, mu, logvar):
        BCE = torch.mean((recon_x - x)**2)
        # see Appendix B from VAE paper:
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        # https://arxiv.org/abs/1312.6114
        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        KLD = -0.5 * torch.mean(torch.mean(1 + logvar - mu.pow(2) - logvar.exp(), 1))
        return BCE, KLD * 0.1

    vae.train()

    for batch_idx, (data, target) in enumerate(train_loader):

        data, _ = data.to(device), target.to(device)
        optimizer.zero_grad()
        rec, mu, logvar = vae(data)
        loss_re, loss_kl = loss_function(rec, data, mu, logvar)
        (loss_re + loss_kl).backward()
        optimizer.step()

        if batch_idx % 1000 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss_re: {:.6f} \tLoss_kl: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss_re.item(), loss_kl.item()))

def test():
    with torch.no_grad():
        vae.eval()
        test_loss = 0
        correct = 0
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)



In [12]:
for epoch in range(1, 15 + 1):
    train(epoch)
    #test()

torch.save(vae.state_dict(), "models/VAEmodel.pkl")




## Evaluation

In [17]:
from torchvision.utils import save_image

vae = VAE(zsize=z_size, layer_count=4, channels=1)
vae.load_state_dict(torch.load("models/VAEmodel.pkl"))

def evaluation(vae, eval_id):
    if not os.path.exists('results/' + eval_id):
        os.mkdir('results/' + eval_id)
    vae.eval()

    z_size = 64
    im_size = 64
    sample_v = torch.randn(128, z_size).view(-1, z_size, 1, 1)

    for batch_idx, (data, target) in enumerate(test_loader):
        x = data
        x_rec, _, _ = vae.forward(x)
        resultsample = torch.cat([x, x_rec]) * 0.5 + 0.5
        resultsample = resultsample.cpu()
        save_image(resultsample.view(-1, 1, im_size, im_size),
                    'results/'+ eval_id +'/sample_encode.png')

        x_rec = vae.decode(sample_v)
        resultsample = x_rec * 0.5 + 0.5
        resultsample = resultsample.cpu()
        save_image(resultsample.view(-1, 1, im_size, im_size),
                    'results/'+ eval_id +'/sample_decode.png')
        break

eval_id = "002"
evaluation(vae, eval_id)
