In [None]:
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import datasets, transforms
from torchvision.utils import save_image

from tqdm import tqdm, tnrange, tqdm_notebook
from notify_run import Notify

In [None]:
# Client used by the training cell to push progress messages via `notify.send`.
notify = Notify()
# NOTE(review): presumably registers a fresh notify.run channel on every run and
# shows its subscription link — confirm against the notify_run documentation.
notify.register()

In [None]:
import sys
import os

# Put the project checkout on the import path (once) so the `src` package
# imported right below can be resolved.
module_path = os.path.abspath('/users/dli44/tool-presence')
if not any(entry == module_path for entry in sys.path):
    sys.path.append(module_path)
    
from src import constants as c
# from src.model import VAE
from src.model import loss_function
from src import visualization as v

In [None]:
# Preprocessing applied to every image: resize, centre-crop to a
# c.image_size square, then convert to a tensor.
data_transforms = transforms.Compose([
    transforms.Resize(c.image_size),
    transforms.CenterCrop(c.image_size),
    transforms.ToTensor(),
])

# One ImageFolder per split, read from <c.data_home>/surgical_data/<split>.
image_datasets = {
    split: datasets.ImageFolder(
        os.path.join(c.data_home, 'surgical_data', split),
        data_transforms,
    )
    for split in ['train', 'val']
}

# Shuffled loaders over the same two splits, batched by c.batch_size.
dataloaders = {
    split: DataLoader(split_dataset, batch_size=c.batch_size, shuffle=True)
    for split, split_dataset in image_datasets.items()
}

## Testing weaker models 

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    """Small convolutional variational autoencoder.

    Encoder: two 3x3 convolutions -> 2x2 max-pool -> dropout -> two fully
    connected layers -> two linear heads producing ``mu`` and ``log_var``.
    Decoder: three fully connected layers -> reshape to a feature map ->
    upsample back to ``image_size`` -> two 3x3 transposed convolutions ->
    sigmoid.

    Args:
        image_channels: number of channels of the input (and output) images.
        image_size: side length of the square input images.
        h_dim1: width of the first (largest) fully connected hidden layer.
        h_dim2: width of the second fully connected hidden layer.
        zdim: dimensionality of the latent code.
        conv_channels: output channels of the two encoder convolutions;
            only the first two entries are used.
    """

    def __init__(self, image_channels, image_size, h_dim1, h_dim2, zdim, conv_channels=[16, 16]):
        super(VAE, self).__init__()

        self.image_channels = image_channels
        self.image_size = image_size
        self.h_dim1 = h_dim1
        self.h_dim2 = h_dim2
        self.zdim = zdim
        # Copy so neither the shared default list nor a caller's list is aliased.
        self.conv_channels = list(conv_channels)

        # Side length after the single 2x2 max-pool and the flattened feature
        # count. Parenthesised explicitly: `s//2 * s//2` parses as
        # `((s//2) * s)//2`, which is only correct for even `s`.
        self._pooled_size = image_size // 2
        self._flat_features = self._pooled_size * self._pooled_size * self.conv_channels[1]

        # Encoder
        self.conv1 = nn.Conv2d(image_channels, self.conv_channels[0], kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(self.conv_channels[0], self.conv_channels[1], kernel_size=3, stride=1, padding=1, bias=False)
        self.pool1 = nn.MaxPool2d(2)

        # Latent vectors
        self.fc1 = nn.Linear(self._flat_features, h_dim1)
        self.fc2 = nn.Linear(h_dim1, h_dim2)
        self.fc31 = nn.Linear(h_dim2, zdim)  # mean head
        self.fc32 = nn.Linear(h_dim2, zdim)  # log-variance head

        # Decoder
        self.fc3 = nn.Linear(zdim, h_dim2)
        self.fc4 = nn.Linear(h_dim2, h_dim1)
        self.fc5 = nn.Linear(h_dim1, self._flat_features)

        self.conv3 = nn.ConvTranspose2d(self.conv_channels[1], self.conv_channels[0], kernel_size=3, stride=1, padding=1, bias=False)
        self.conv4 = nn.ConvTranspose2d(self.conv_channels[0], image_channels, kernel_size=3, stride=1, padding=1, bias=False)

    def sampling(self, mu, log_var):
        """Draw ``z = mu + eps * exp(0.5 * log_var)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu) # return z sample

    def encode(self, x):
        """Map a batch of images to the pair ``(mu, log_var)``."""
        x = F.relu(self.conv1(x))
        x = self.conv2(x)
        # F.dropout defaults to training=True; tie it to the module's mode so
        # dropout is switched off after model.eval().
        x = F.dropout(self.pool1(x), training=self.training)
        x = x.view(-1, self._flat_features)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc31(x), self.fc32(x)

    def decode(self, z):
        """Map latent codes to images with values in [0, 1] (sigmoid output)."""
        z = F.relu(self.fc3(z))
        z = F.relu(self.fc4(z))
        z = F.relu(self.fc5(z))
        z = z.view(-1, self.conv_channels[1], self._pooled_size, self._pooled_size)
        # Upsample to the exact input size. For even image_size this matches a
        # 2x nearest-neighbour upsample; for odd sizes it restores the pixel
        # that the max-pool rounded away.
        z = F.interpolate(z, size=(self.image_size, self.image_size))
        z = F.relu(self.conv3(z))
        z = torch.sigmoid(self.conv4(z))
        return z

    def forward(self, x):
        """Return ``(reconstruction, mu, log_var)`` for a batch of images."""
        mu, log_var = self.encode(x)
        z = self.sampling(mu, log_var)
        return self.decode(z), mu, log_var

In [None]:
# Train one VAE per KL weight (beta). For every beta this writes, under
# c.data_home: a validation reconstruction grid per epoch (samples/), a
# checkpoint every `checkpoint_every` epochs (weights/), and a plot of the
# per-epoch KL / reconstruction totals (figures/).
n_epochs = 50
checkpoint_every = 10
n_train = len(dataloaders['train'].dataset)
n_val = len(dataloaders['val'].dataset)

# Create the output folders up front so the first save cannot fail on a
# missing directory.
samples_dir = os.path.join(c.data_home, 'samples')
weights_dir = os.path.join(c.data_home, 'weights')
figures_dir = os.path.join(c.data_home, 'figures')
for out_dir in (samples_dir, weights_dir, figures_dir):
    os.makedirs(out_dir, exist_ok=True)

for beta in [1, 5, 50, 100]:
    # Fields: beta, image size, epoch, latent size, file extension.
    output_name = 'beta_{}_2fc_vae_{}_epoch_{}_zdim_{}.{}'
    # Per-epoch sums over all training batches (not normalised).
    losses = {'kl':[], 'rl':[]}

    zdim = 64
    model = VAE(image_channels=c.image_channels,
                image_size=c.image_size,
                h_dim1=1024,
                h_dim2=128,
                zdim=zdim).to(c.device)

    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    notify.send("Starting training for zdim {}".format(zdim))
    loss_params = {'input_size': c.image_size,
                   'zdim': zdim,
                   'beta': beta
                  }

    t1 = tnrange(n_epochs)
    for epoch in t1:
        # ---- Training ----
        model.train()
        train_loss, kl, rl = 0, 0, 0
        t2 = tqdm_notebook(dataloaders['train'])
        for data, _ in t2:
            data = data.to(c.device)
            optimizer.zero_grad()
            recon_batch, mu, logvar = model(data)
            # loss_function yields (total loss, reconstruction term, KL term).
            loss, r, k = loss_function(recon_batch, data, mu, logvar, **loss_params)
            loss.backward()
            train_loss += loss.item()
            kl += k.item()
            rl += r.item()
            optimizer.step()

            t2.set_postfix({"Reconstruction Loss":r.item(), "KL Divergence":k.item()})

        losses['kl'].append(kl)
        losses['rl'].append(rl)
        notify.send("z-dim = {}, Training Epoch {}, Training Loss: {:.4f}".format(zdim,
                                                                                  epoch+1,
                                                                                  train_loss / n_train))

        # ---- Validation ----
        model.eval()
        test_loss = 0
        with torch.no_grad():
            for i, (data, _) in enumerate(dataloaders['val']):
                data = data.to(c.device)
                recon_batch, mu, logvar = model(data)
                loss, r, k = loss_function(recon_batch, data, mu, logvar, **loss_params)
                test_loss += loss.item()
                if i == 0:
                    # Save originals (top row) over reconstructions (bottom row)
                    # for up to 8 images of the first validation batch.
                    n = min(data.size(0), 8)
                    # view(-1, ...) rather than view(c.batch_size, ...): the
                    # batch may hold fewer than c.batch_size images.
                    comparison = torch.cat([data[:n],
                                            recon_batch.view(-1,
                                                             c.image_channels,
                                                             c.image_size,
                                                             c.image_size)[:n]])

                    save_image(comparison.cpu(),
                               os.path.join(samples_dir,
                                            output_name.format(beta, c.image_size, epoch+1, zdim, 'png')
                                           ), nrow=n)

        # Report per-sample averages for the epoch, now including validation.
        t1.set_postfix({"KL Divergence":kl/n_train,
                        "Reconstruction Loss":rl/n_train,
                        "Validation Loss":test_loss/n_val})

        if (epoch+1) % checkpoint_every == 0:
            torch.save(model.state_dict(),
                       os.path.join(weights_dir,
                                    output_name.format(beta, c.image_size, epoch+1, zdim, 'torch')))
            notify.send("Saved weights at epoch {}".format(epoch+1))

    # Loss curves for this beta; the file name carries the total epoch count.
    fig, loss_ax = plt.subplots()
    loss_ax.plot(losses['kl'], label='KL divergence')
    loss_ax.plot(losses['rl'], label='Reconstruction loss')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Summed training loss')
    loss_ax.set_title('beta = {}'.format(beta))
    loss_ax.legend()
    fig.savefig(os.path.join(figures_dir, output_name.format(beta, c.image_size, n_epochs, zdim, 'png')))

# Testing: latent-space interpolation between training images

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Preview the two endpoint frames (train indices 0 and 10) side by side.
# Each tensor is moved to channel-last layout before display.
fig = plt.figure()
plt.title("Initial Images\nStart, End")
plt.imshow(np.concatenate(
    [image_datasets['train'][idx][0].numpy().transpose(1, 2, 0) for idx in (0, 10)],
    axis=1,
))

In [None]:
# Interpolate between train images 0 and 10 with the most recently trained
# model and show the ten returned frames in a single row.
images = v.latent_interpolation(
    image_datasets['train'][0][0],
    image_datasets['train'][10][0],
    model=model,
)

# The first and last panels are slightly wider than the eight in between.
fig, ax = plt.subplots(
    1, 10,
    figsize=(10, 2),
    frameon=False,
    gridspec_kw={'wspace': 0.05, 'width_ratios': [1.25] + [1] * 8 + [1.25]},
)
for i, panel in enumerate(ax):
    panel.imshow(images[i])
    panel.axis('off')
ax[0].set_title("Start")
ax[-1].set_title("End")

# plt.savefig(os.path.join(data_home,'figures','tool_different_anatomy_similar.png'), bbox_inches='tight', dpi=400, pad_inches=0.0)

In [None]:
# Preview the two endpoint frames (train indices 10 and 830) side by side.
# Each tensor is moved to channel-last layout before display.
fig = plt.figure()
plt.title("Initial Images\nStart, End")
plt.imshow(np.concatenate(
    [image_datasets['train'][idx][0].numpy().transpose(1, 2, 0) for idx in (10, 830)],
    axis=1,
))

In [None]:
# Interpolate between train images 10 and 830 with the most recently trained
# model and show the ten returned frames in a single row.
images = v.latent_interpolation(
    image_datasets['train'][10][0],
    image_datasets['train'][830][0],
    model=model,
)

# The first and last panels are slightly wider than the eight in between.
fig, ax = plt.subplots(
    1, 10,
    figsize=(10, 2),
    frameon=False,
    gridspec_kw={'wspace': 0.05, 'width_ratios': [1.25] + [1] * 8 + [1.25]},
)
for i, panel in enumerate(ax):
    panel.imshow(images[i])
    panel.axis('off')
ax[0].set_title("Start")
ax[-1].set_title("End")
# plt.savefig(os.path.join(data_home, 'figures', 'tool_different_anatomy_different.png'),bbox_inches='tight', dpi=400, pad_inches=0.0)