In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

from skimage import io
from skimage.filters import gaussian

import time
import numpy as np
import scipy as sp 
import scipy.ndimage as ndimage
from scipy.ndimage import uniform_filter
import hyperspy.api as hs
import matplotlib.pyplot as plt
import plotly.graph_objs as go
from plotly.subplots import make_subplots
from IPython.display import display, clear_output

In [3]:
class CVAE3D(nn.Module):
    """Convolutional variational autoencoder."""

    def __init__(self, latent_dim, size):
        super(CVAE3D, self).__init__()
        self.latent_dim = latent_dim
        self.size = size
        reduced_size = size  # Size doesn't change because stride = 1
        
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv3d(1, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv3d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv3d(128, 128, kernel_size=3, stride=1, padding=1), 
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(128 * reduced_size * reduced_size * reduced_size, latent_dim * 2)
        )
        
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, reduced_size * reduced_size * reduced_size * 32),
            nn.ReLU(),
            nn.Unflatten(1, (32, reduced_size, reduced_size, reduced_size)),
            nn.ConvTranspose3d(32, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.ConvTranspose3d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.ConvTranspose3d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.ConvTranspose3d(64, 1, kernel_size=3, stride=1, padding=1)
        )
        
    def encode(self, x):
        h = self.encoder(x)
        mean, logvar = torch.chunk(h, 2, dim=1)
        return mean, logvar

    def reparameterize(self, mean, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mean + eps * std

    def decode(self, z, apply_sigmoid=False):
        logits = self.decoder(z)
        if apply_sigmoid:
            probs = torch.sigmoid(logits)
            return probs
        return logits

    def forward(self, x):
        mean, logvar = self.encode(x)
        z = self.reparameterize(mean, logvar)
        return self.decode(z)
            
#! kernel of 5 and stride of 1, another layer, confirm there is no compression in the energy direction. 

############################################################################################################
##################################### END OF MODEL DEFINITION ##############################################
############################################################################################################

def load_dm4_data(filepath):
    s = hs.load(filepath)
    data = s.data  # The 3D data array
    return data

def preprocess_3d_images(image, size, sigma, energy_range, xy_window):
    # Apply Gaussian blur directly to the single image
    blurred_image = gaussian(image, sigma=sigma, mode='reflect', preserve_range=True)
    
    # Calculate the pixel indices corresponding to the energy range
    start_pixel = int((energy_range[0] - 0))
    end_pixel = int((energy_range[1] - 0))
    
    # Slice the data array to keep only the desired energy range in the third dimension
    blurred_image = blurred_image[:, :, start_pixel:end_pixel]
    
    # Normalize the image
    normalized_image = blurred_image / np.max(blurred_image)
    
    # Apply spatial-spectral smoothing
    def smooth_spatial_spectral(arr, window):
        # Use uniform_filter to compute the sum of spectra in the neighborhood
        neighborhood_sum = uniform_filter(arr, size=(window, window, 1), mode='reflect')
        # Compute the number of pixels in the neighborhood
        neighborhood_count = uniform_filter(np.ones_like(arr), size=(window, window, 1), mode='reflect')
        # Compute the average
        return neighborhood_sum / neighborhood_count
    
    smoothed_img = smooth_spatial_spectral(normalized_image, xy_window)
    
    # Calculate the padding for each dimension
    padding = [(max(0, size - dim_size) // 2, max(0, size - dim_size) - max(0, size - dim_size) // 2) 
               for dim_size in smoothed_img.shape]
    
    # Apply padding
    padded_img = np.pad(smoothed_img, padding, mode='constant')
    
    # Calculate the crop for each dimension
    crop = [(max(0, padded_img.shape[i] - size) // 2, 
             max(0, padded_img.shape[i] - size) - max(0, padded_img.shape[i] - size) // 2) 
            for i in range(len(padded_img.shape))]
    
    # Apply cropping
    cropped_img = padded_img[crop[0][0]:padded_img.shape[0]-crop[0][1],
                             crop[1][0]:padded_img.shape[1]-crop[1][1],
                             crop[2][0]:padded_img.shape[2]-crop[2][1]]
    
    # Ensure the final shape matches the target size
    assert cropped_img.shape == (size, size, size), f"Shape mismatch: {cropped_img.shape} != {(size, size, size)}"
    
    # Reshape to (1, size, size, size) for PyTorch
    reshaped_image = cropped_img.reshape((1, 1, size, size, size))
    
    return cropped_img, reshaped_image.astype('float32')

############################################################################################################
##################################### END OF DATA PREPROCESSING ############################################
############################################################################################################
    
def compute_loss(model, x):
    mean, logvar = model.encode(x)
    z = model.reparameterize(mean, logvar)
    x_logit = model.decode(z)
    cross_ent = F.binary_cross_entropy_with_logits(x_logit, x, reduction='sum')
    logpx_z = -cross_ent
    logpz = log_normal_pdf(z, torch.zeros_like(z), torch.zeros_like(z))
    logqz_x = log_normal_pdf(z, mean, logvar)
    return -torch.mean(logpx_z + logpz - logqz_x)

def log_normal_pdf(sample, mean, logvar):
    log2pi = torch.log(torch.tensor(2. * np.pi))
    return torch.sum(-0.5 * ((sample - mean) ** 2 * torch.exp(-logvar) + logvar + log2pi), dim=1)

def gaussian_blur(img, sigma):
    return np.array(gaussian(img, (sigma, sigma)))

def gaussian_blur_arr(images, sigma):
    return np.array([gaussian_blur(img, sigma) for img in images])

def norm_max_pixel(images):
    return np.array([img / np.max(img) for img in images])

def visualize_inference(model, input_image, data, energy_range, x = 0, y = 0):
    model.eval()
    with torch.no_grad():
        input_tensor = torch.tensor(input_image).unsqueeze(0).to(device)
        mean, logvar = model.encode(input_tensor)
        z = model.reparameterize(mean, logvar)
        prediction = model.decode(z, apply_sigmoid=True)

    # Set the pixel manually
    selected_pixel = (y, x) #! SET THE PIXEL MANUALLY 

    # Create two separate figures
    fig_images = make_subplots(rows=1, cols=2, subplot_titles=('Input Image', 'Prediction'))
    fig_spectra = make_subplots(rows=1, cols=2, subplot_titles=('Input Spectral Graph', 'Predicted Spectral Graph'))

    # Input Image
    middle_slice_input = input_image[0, :, :, input_image.shape[2] // 2].cpu().numpy()
    fig_images.add_trace(go.Heatmap(z=middle_slice_input, colorscale='Viridis', showscale=False), row=1, col=1)
    fig_images.add_trace(go.Scatter(x=[selected_pixel[1]], y=[selected_pixel[0]], mode='markers', 
                                    marker=dict(color='red', size=10), showlegend=False), row=1, col=1)

    # Prediction
    middle_slice_prediction = prediction[0, 0, :, :, prediction.shape[2] // 2].cpu().numpy()
    fig_images.add_trace(go.Heatmap(z=middle_slice_prediction, colorscale='Viridis', showscale=False), row=1, col=2)
    fig_images.add_trace(go.Scatter(x=[selected_pixel[1]], y=[selected_pixel[0]], mode='markers', 
                                    marker=dict(color='red', size=10), showlegend=False), row=1, col=2)

    # Update layout to ensure images are not distorted
    fig_images.update_layout(
        height=600,
        width=1200,
        margin=dict(l=20, r=20, t=40, b=20),
        yaxis=dict(scaleanchor="x", scaleratio=1),
        yaxis2=dict(scaleanchor="x2", scaleratio=1)
    )

    # Input Spectral Graph
    input_spectrum = data[selected_pixel[0], selected_pixel[1], :]
    x_energy = np.linspace(energy_range[0], energy_range[1], input_spectrum.shape[0])
    fig_spectra.add_trace(go.Scatter(x=x_energy, y=input_spectrum), row=1, col=1)

    # Predicted Spectral Graph
    predicted_spectrum = prediction[0, 0, selected_pixel[0], selected_pixel[1], :].cpu().numpy()
    fig_spectra.add_trace(go.Scatter(x=x_energy, y=predicted_spectrum), row=1, col=2)

    fig_spectra.update_layout(
        height=500,
        width=1200,
        margin=dict(l=20, r=20, t=40, b=20)
    )
    
    return fig_images, fig_spectra

def create_loss_plot():
    """Create an empty Plotly figure for the loss plot."""
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=[], y=[], mode='lines', name='Training Loss'))
    fig.update_layout(title='Training Loss',
                      xaxis_title='Epoch',
                      yaxis_title='Loss',
                      height=400,
                      width=800)
    return fig


In [4]:
SIZE = 48
SIGMA = 2
epochs = 8000
latent_dim = 20
energy_range = (1630, 1678)  # eV
ev_per_pixel = 0.05
dm4_file = 'data/images_3D/EELS HL SI.dm4'

# Load the data with the specified energy range and resolution
data = load_dm4_data(dm4_file)

#! APPLY 2D Smoothing before feeding to the model 
#! generate the image where every pixel ( spectrum ) is the average of 3 by 3 pixels, also try training only on the "white part"
#! try to overtrain by making it bigger if nothing else works 

In [5]:
train_image_viz, train_image = preprocess_3d_images(data, size=SIZE, sigma=SIGMA, energy_range=energy_range, xy_window=3)

# Create data loaders
train_dataset = TensorDataset(torch.tensor(train_image))
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

# Instantiate model and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CVAE3D(latent_dim, SIZE).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Pick a sample of the test set for generating output images
num_examples_to_generate = 1
test_sample = next(iter(train_loader))[0][:num_examples_to_generate].to(device)

In [10]:
# Initialize lists to store epoch and loss data
epochs_list = []
losses = []

# Create the initial loss plot
fig_loss = create_loss_plot()

# Training loop
try:
    for epoch in range(1, epochs + 1):
        start_time = time.time()
        model.train()
        train_loss = 0
        for batch in train_loader:
            train_x = batch[0].to(device)
            loss = compute_loss(model, train_x)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss /= len(train_loader.dataset)
        end_time = time.time()

        # Append data to lists
        epochs_list.append(epoch)
        losses.append(train_loss)

        # Update the loss plot
        fig_loss.data[0].x = epochs_list
        fig_loss.data[0].y = losses

        # # Create new inference visualizations
        # fig_images, fig_spectra = visualize_inference(model, test_sample[0], train_image_viz, energy_range)

        # Display updated plots
        clear_output(wait=True)
        fig_loss.show()
        # fig_images.show()
        # fig_spectra.show()

        print(f'Epoch: {epoch}, Test set ELBO: {train_loss:.4f}, time elapsed for current epoch: {end_time - start_time:.2f}s')
        
except KeyboardInterrupt:
    print("Training interrupted. Saving the model...")

Epoch: 832, Test set ELBO: 66252.8594, time elapsed for current epoch: 0.06s
Training interrupted. Saving the model...


In [11]:
# Create new inference visualizations
fig_images, fig_spectra = visualize_inference(model, test_sample[0], train_image_viz, energy_range, 24, 20)

# Display updated plots
fig_images.show()
fig_spectra.show()


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



In [12]:
torch.save(model.state_dict(), 'cvae3d_modified_stride.pth')