In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import torchvision
import numpy as np
import h5py
from matplotlib import pyplot as plt
from Utils import save_large_dataset, load_large_dataset
plt.ion()   # interactive mode

In [2]:
# Use the GPU only when one is actually present, so the notebook also runs on
# CPU-only machines instead of failing at the first .cuda() call.
CUDA = torch.cuda.is_available()

In [4]:
# Fetch the arrays stored under the names 'images' and 'labels' through the
# project's dataset helper (images first, then labels).
X, Y = (load_large_dataset(name) for name in ('images', 'labels'))

In [5]:
# Drop every length-1 axis from both arrays.
X = np.squeeze(X)
Y = np.squeeze(Y)

# Keep only indices 58, 59 and 60 of the 4th axis; these three slices are
# treated as image channels from here on.
X = X[:, :, :, slice(58, 61)]

print(X.shape)

(1792, 121, 145, 3)


In [6]:
# Switch to channels-first layout: axis 3 (channels) is moved to position 1,
# directly after the sample axis, giving (N, C, H, W).
X = np.moveaxis(X, 3, 1)
print(X.shape)

(1792, 3, 121, 145)


In [7]:
# Fixed seed so the random split is the same on every run.
np.random.seed(9999)

# One uniform draw per sample; a sample goes to the training split when its
# draw is below 0.9, otherwise to the validation split.
mask = np.random.rand(len(X)) < 0.9

training_set, training_labels = X[mask], Y[mask]
validation_set, validation_labels = X[np.logical_not(mask)], Y[np.logical_not(mask)]

In [8]:
BATCH_SIZE = 64  # mini-batch size handed to the training DataLoader

In [9]:
# Wrap the NumPy training arrays as torch tensors (images first, then labels).
training_set, training_labels = map(torch.from_numpy, (training_set, training_labels))

In [10]:
# Cast the labels to 64-bit integers.
training_labels = training_labels.to(dtype=torch.long)

In [11]:
# Pair each training image with its label, then serve the pairs in shuffled
# mini-batches of BATCH_SIZE.
dataset = torch.utils.data.TensorDataset(training_set, training_labels)
train_loader = DataLoader(dataset=dataset, shuffle=True, batch_size=BATCH_SIZE)

# Autoencoder definition

In [12]:
LATENT_DIM = 10  # number of latent dimensions of the VAE (length of the mu and logvar vectors)

In [13]:
class VAE(nn.Module):
    """Convolutional variational autoencoder for 3-channel 121x145 images.

    The encoder applies four conv + ReLU + 2x2 max-pool stages and two linear
    heads that output the mean and the log-variance of the latent Gaussian.
    The decoder maps a latent vector back to a (3, 121, 145) image with values
    in [0, 1] through a linear layer, five convolutions and four
    nearest-neighbour upsampling steps.

    Args:
        latent_dim: size of the latent space. When omitted, the notebook-level
            constant LATENT_DIM is used, so ``VAE()`` behaves as before.
    """

    # Number of features of the flattened encoder output: 32 channels of 7x9.
    _FLAT_FEATURES = 32 * 7 * 9

    def __init__(self, latent_dim=None):
        super(VAE, self).__init__()

        if latent_dim is None:
            latent_dim = LATENT_DIM
        self.latent_dim = latent_dim

        # layers for encoder
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(32, 32, kernel_size=3, padding=1)

        self.fc1 = nn.Linear(self._FLAT_FEATURES, latent_dim)  # mean head
        self.fc2 = nn.Linear(self._FLAT_FEATURES, latent_dim)  # log-variance head

        # layers for decoder
        self.fc_decoder = nn.Linear(latent_dim, self._FLAT_FEATURES)

        self.conv1_decoder = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.conv2_decoder = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.conv3_decoder = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.conv4_decoder = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.conv5_decoder = nn.Conv2d(32, 3, kernel_size=3, padding=1)

    def encode(self, x):
        """Encode a batch of images.

        Args:
            x: tensor of shape (batch, 3, 121, 145).

        Returns:
            Tuple ``(mu, logvar)``, each of shape (batch, latent_dim).
        """
        x = F.relu(self.conv1(x))            # (8, 121, 145)
        x = F.max_pool2d(x, kernel_size=2)   # (8, 60, 72)

        x = F.relu(self.conv2(x))            # (16, 60, 72)
        x = F.max_pool2d(x, kernel_size=2)   # (16, 30, 36)

        x = F.relu(self.conv3(x))            # (32, 30, 36)
        x = F.max_pool2d(x, kernel_size=2)   # (32, 15, 18)

        x = F.relu(self.conv4(x))            # (32, 15, 18)
        x = F.max_pool2d(x, kernel_size=2)   # (32, 7, 9)

        # Flatten per sample; keeping the batch axis explicit means a wrongly
        # sized input fails in the linear layer instead of being silently
        # reshaped into a different batch size.
        x = x.view(x.size(0), -1)
        return self.fc1(x), self.fc2(x)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Args:
            mu: mean of the latent Gaussian.
            logvar: log-variance of the latent Gaussian (same shape as ``mu``).

        Returns:
            A sample with the same shape as ``mu``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Decode latent vectors into images.

        Args:
            z: tensor of shape (batch, latent_dim).

        Returns:
            Tensor of shape (batch, 3, 121, 145) with values in [0, 1].
        """
        z = F.relu(self.fc_decoder(z))
        z = z.view(-1, 32, 7, 9)  # reshape to (32, 7, 9)

        z = F.relu(self.conv1_decoder(z))                        # (32, 7, 9)
        z = F.interpolate(z, size=(14, 18), mode='nearest')      # (32, 14, 18)

        z = F.relu(self.conv2_decoder(z))                        # (32, 14, 18)
        z = F.interpolate(z, size=(28, 36), mode='nearest')      # (32, 28, 36)

        z = F.relu(self.conv3_decoder(z))                        # (32, 28, 36)
        z = F.interpolate(z, size=(56, 72), mode='nearest')      # (32, 56, 72)

        # conv4_decoder is applied without an activation.
        z = self.conv4_decoder(z)                                # (32, 56, 72)
        z = F.interpolate(z, size=(112, 144), mode='nearest')    # (32, 112, 144)

        z = self.conv5_decoder(z)                                # (3, 112, 144)

        # Pad to the input size (3, 121, 145): 0 left / 1 right on the width,
        # 4 top / 5 bottom on the height. The fill value -10 becomes ~0 after
        # the sigmoid, so the added border is (almost) black.
        z = F.pad(z, (0, 1, 4, 5), "constant", -10)

        return torch.sigmoid(z)

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Returns:
            Tuple ``(reconstruction, mu, logvar)``.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

In [None]:
# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(reconstruced_x, x, mu, logvar):
    """Compute the VAE loss: reconstruction term plus KL divergence.

    Both terms are summed over every element and every sample of the batch
    (no averaging), so the result grows with the batch size.

    Args:
        reconstruced_x: decoder output with values in [0, 1]; same shape as ``x``.
        x: target images with values in [0, 1].
        mu: latent means.
        logvar: latent log-variances (same shape as ``mu``).

    Returns:
        Tuple ``(loss, BCE, KLD)`` of scalar tensors, where ``loss = BCE + KLD``.
    """
    # A summed binary cross-entropy does not depend on the tensor layout, so
    # no reshape to a fixed image size is needed; any matching shapes work.
    BCE = F.binary_cross_entropy(reconstruced_x, x, reduction='sum')

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # KL(q(z|x) || N(0, I)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    loss = BCE + KLD
    return loss, BCE, KLD

In [14]:
# Instantiate the autoencoder, place it on the GPU when CUDA is enabled, and
# create an Adam optimizer (default hyper-parameters) over all its weights.
net = VAE()
if CUDA:
    net = net.cuda()
optimizer = optim.Adam(params=net.parameters())

In [None]:
# Zero-filled buffer with 22 slots of shape (3, 121, 145); the training loop
# writes one reconstruction snapshot into it every 10 epochs.
reconstructed_images = torch.zeros((22, 3, 121, 145))

In [None]:
for epoch in range(150):  # loop over the dataset multiple times

    # Per-epoch accumulators. loss_function returns values that are already
    # summed over the batch, so they are added up as-is and divided by the
    # number of training samples once, when printing.
    running_loss = 0.0  # total loss
    running_BCE = 0.0   # reconstruction loss
    running_KLD = 0.0   # divergence loss

    for inputs, _ in train_loader:  # labels are not needed to train the VAE
        if CUDA:
            inputs = inputs.cuda()

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize (a single forward pass per batch)
        reconstructed_batch, mu, logvar = net(inputs)
        loss, BCE, KLD = loss_function(reconstructed_batch, inputs, mu, logvar)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        running_BCE += BCE.item()
        running_KLD += KLD.item()

    # Average loss per training sample over the epoch.
    print('Epoch %d, Total loss: %.3f' % (epoch + 1, running_loss / len(training_set)))
    print('Epoch %d, BCE loss: %.3f' % (epoch + 1, running_BCE / len(training_set)))
    print('Epoch %d, KLD loss: %.3f' % (epoch + 1, running_KLD / len(training_set)))
    print('------------')

    if epoch % 10 == 0:  # saving examples of reconstructed images every 10 epochs
        sample = train_loader.dataset[0][0].unsqueeze(0)  # first training image as a batch of 1
        if CUDA:
            sample = sample.cuda()
        with torch.no_grad():  # snapshot only; no gradients needed
            output = net(sample)[0]
        reconstructed_images[epoch // 10] = output.view(3, 121, 145).cpu()

print('Finished Training')

In [None]:
# Store only the model's parameters (its state dict), as they are after the
# 150 training epochs.
torch.save(obj=net.state_dict(), f="VAE_150_epochs.pt")

# Data augmentation

In [15]:
# Restore the weights saved after training. When CUDA is disabled, tensors that
# were saved from a GPU run are remapped to the CPU so the checkpoint can still
# be loaded; with CUDA enabled the default placement is kept.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a
# trusted source.
net.load_state_dict(
    torch.load("VAE_150_epochs.pt", map_location=None if CUDA else "cpu")
)

In [23]:
# Pre-allocate the outputs: five generated images per training image, plus one
# label entry for each generated image.
augmented_images = np.zeros(shape=(5 * len(training_set), 3, 121, 145))
augmented_labels = np.zeros(shape=len(augmented_images))

In [24]:
# Generate 5 new images per training image: encode the image once, then draw 5
# latent samples from its posterior and decode each of them. Gradients are not
# needed here, so autograd is switched off for the whole loop.
with torch.no_grad():
    for i in range(len(training_set)):
        image, label = train_loader.dataset[i]
        image = image.unsqueeze(0)  # batch of 1
        if CUDA:
            image = image.cuda()

        # The encoder output is the same for all 5 samples, so compute it once.
        mu, logvar = net.encode(image)  # Encoder

        for j in range(5):
            z = net.reparameterize(mu, logvar)  # z = mu + eps * std
            output = net.decode(z)  # Decoder

            augmented_images[(5*i)+j] = output[0].cpu().numpy()
            augmented_labels[(5*i)+j] = int(label)

In [25]:
# Hand the generated images and their labels to the project's save helper
# under the given names (presumably written to disk - see Utils).
for _name, _array in (("augmented_train_set_1", augmented_images),
                      ("augmented_train_set_labels_1", augmented_labels)):
    save_large_dataset(_name, _array)