In [31]:
import torch
import torch.nn as nn
from analyzer.data import Dataloader, folder2Vol
from vae.model import vae_3d
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
import torch.optim as optim

In [32]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device: ", device)

Device:  cpu


In [35]:
dl = Dataloader("../datasets/r_em/", "../datasets/r_gt/", chunk_size=(5,4096,4096))
em, gt = dl.load_chunk(vol='both')

dl.volume = folder2Vol("../datasets/r_em/")


em data loaded:  (5, 4096, 4096)
gt data loaded:  (5, 4096, 4096)


In [36]:
chunk_size = 32

test = dl.volume

test = np.reshape(test, (-1,chunk_size,4096,4096))[:16]
test = np.expand_dims(test, axis=1)
test = test[:,:,:,:chunk_size,:chunk_size]
test.shape

(4, 1, 32, 32, 32)

In [37]:
model = vae_3d.Conv3dVAE_simple().to(device)
criterion = nn.BCELoss(reduction='sum')
optimizer = optim.Adam(model.parameters(), lr=0.5)

In [38]:
def final_loss(bce_loss, mu, logvar):
    """
    This function will add the reconstruction loss (BCELoss) and the
    KL-Divergence.
    KL-Divergence = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    :param bce_loss: recontruction loss
    :param mu: the mean from the latent vector
    :param logvar: log variance from the latent vector
    """
    BCE = bce_loss
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

In [39]:
data = torch.tensor(test.astype(np.float32))

In [40]:
model.train()
running_loss = 0.0
data = data.to(device)
optimizer.zero_grad()
reconstruction, mu, logvar = model(data)
bce_loss = criterion(reconstruction, data)
loss = final_loss(bce_loss, mu, logvar)
running_loss += loss.item()
loss.backward()
optimizer.step()
train_loss = running_loss/16

In [42]:
train_loss

71232.2265625