In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

from skimage import io
from skimage.filters import gaussian

import time
import numpy as np
import scipy as sp 
import scipy.ndimage as ndimage
from scipy.ndimage import uniform_filter, zoom
import hyperspy.api as hs
import matplotlib.pyplot as plt
import plotly.graph_objs as go
from plotly.subplots import make_subplots
from IPython.display import display, clear_output

In [2]:
class CVAE3D(nn.Module):
    def __init__(self, latent_dim):
        super(CVAE3D, self).__init__()
        self.latent_dim = latent_dim
        
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv3d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv3d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv3d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU()
        )
        
        # Latent space
        self.fc_mu = nn.Conv3d(128, latent_dim, kernel_size=1)
        self.fc_logvar = nn.Conv3d(128, latent_dim, kernel_size=1)
        
        # Decoder
        self.decoder_input = nn.Conv3d(latent_dim, 128, kernel_size=1)
        self.decoder = nn.Sequential(
            nn.ConvTranspose3d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.ConvTranspose3d(64, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.ConvTranspose3d(32, 1, kernel_size=3, stride=1, padding=1)
        )
        
    def encode(self, x):
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def decode(self, z, apply_sigmoid=False):
        z = self.decoder_input(z)
        logits = self.decoder(z)
        if apply_sigmoid:
            return torch.sigmoid(logits)
        return logits

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar       

#! kernel of 5 and stride of 1, another layer, confirm there is no compression in the energy direction. 

############################################################################################################
##################################### END OF MODEL DEFINITION ##############################################
############################################################################################################

def load_dm4_data(filepath):
    s = hs.load(filepath)
    data = s.data  # The 3D data array
    return data

import numpy as np
from skimage.filters import gaussian
from scipy.ndimage import uniform_filter

def preprocess_3d_images(image, target_size, sigma, energy_range, xy_window):

    """
    Preprocess 3D images with arbitrary input and output dimensions.
    
    :param image: Input 3D image (height, width, depth)
    :param target_size: Tuple of (height, width, depth) for the target size
    :param sigma: Sigma for Gaussian blur
    :param energy_range: Tuple of (min_energy, max_energy) in eV
    :param xy_window: Window size for spatial-spectral smoothing
    :return: Tuple of (preprocessed_image, reshaped_image)
    """
    # Apply Gaussian blur
    blurred_image = gaussian(image, sigma=sigma, mode='reflect', preserve_range=True)
    
    # Calculate the pixel indices corresponding to the energy range
    start_pixel = int(energy_range[0])
    end_pixel = int(energy_range[1])
    
    # Slice the data array to keep only the desired energy range in the depth dimension
    blurred_image = blurred_image[:, :, start_pixel:end_pixel]
    
    # Normalize the image
    normalized_image = blurred_image / np.max(blurred_image)
    
    # Apply spatial-spectral smoothing
    def smooth_spatial_spectral(arr, window):
        neighborhood_sum = uniform_filter(arr, size=(window, window, 1), mode='reflect')
        neighborhood_count = uniform_filter(np.ones_like(arr), size=(window, window, 1), mode='reflect')
        return neighborhood_sum / neighborhood_count
    
    smoothed_img = smooth_spatial_spectral(normalized_image, xy_window)
    
    # Resize the image to match the target size
    current_size = smoothed_img.shape
    scale_factors = [t / c for t, c in zip(target_size, current_size)]
    
    resized_img = zoom(smoothed_img, scale_factors, order=1)  # order=1 for linear interpolation
    
    # Ensure the final shape matches the target size
    assert resized_img.shape == target_size, f"Shape mismatch: {resized_img.shape} != {target_size}"
    
    # Reshape to (1, 1, height, width, depth) for PyTorch
    reshaped_image = resized_img.reshape((1, 1, *target_size))
    
    return resized_img, reshaped_image.astype('float32')

############################################################################################################
##################################### END OF DATA PREPROCESSING ############################################
############################################################################################################
    
def compute_loss(model, x):
    recon_x, mu, logvar = model(x)
    recon_loss = F.mse_loss(recon_x, x, reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kl_loss

def log_normal_pdf(sample, mean, logvar):
    log2pi = torch.log(torch.tensor(2. * np.pi))
    return torch.sum(-0.5 * ((sample - mean) ** 2 * torch.exp(-logvar) + logvar + log2pi), dim=1)

def gaussian_blur(img, sigma):
    return np.array(gaussian(img, (sigma, sigma)))

def gaussian_blur_arr(images, sigma):
    return np.array([gaussian_blur(img, sigma) for img in images])

def norm_max_pixel(images):
    return np.array([img / np.max(img) for img in images])

def visualize_inference(model, input_image, data, energy_range, x=0, y=0):
    model.eval()
    with torch.no_grad():
        input_tensor = torch.tensor(input_image).unsqueeze(0).to(next(model.parameters()).device)
        mean, logvar = model.encode(input_tensor)
        z = model.reparameterize(mean, logvar)
        prediction = model.decode(z, apply_sigmoid=True)

    # Set the pixel manually
    selected_pixel = (y, x)

    # Create two separate figures
    fig_images = make_subplots(rows=1, cols=2, subplot_titles=('Input Image', 'Prediction'))
    fig_spectra = make_subplots(rows=1, cols=2, subplot_titles=('Input Spectral Graph', 'Predicted Spectral Graph'))

    # Input Image
    middle_slice_input = input_image[0, :, :, input_image.shape[3] // 2].cpu().numpy()
    fig_images.add_trace(go.Heatmap(z=middle_slice_input, colorscale='Viridis', showscale=False), row=1, col=1)
    fig_images.add_trace(go.Scatter(x=[selected_pixel[1]], y=[selected_pixel[0]], mode='markers', 
                                    marker=dict(color='red', size=10), showlegend=False), row=1, col=1)

    # Prediction
    middle_slice_prediction = prediction[0, 0, :, :, prediction.shape[4] // 2].cpu().numpy()
    fig_images.add_trace(go.Heatmap(z=middle_slice_prediction, colorscale='Viridis', showscale=False), row=1, col=2)
    fig_images.add_trace(go.Scatter(x=[selected_pixel[1]], y=[selected_pixel[0]], mode='markers', 
                                    marker=dict(color='red', size=10), showlegend=False), row=1, col=2)

    # Update layout to ensure images are not distorted
    fig_images.update_layout(
        height=600,
        width=1200,
        margin=dict(l=20, r=20, t=40, b=20),
        yaxis=dict(scaleanchor="x", scaleratio=1),
        yaxis2=dict(scaleanchor="x2", scaleratio=1)
    )

    # Input Spectral Graph
    input_spectrum = data[selected_pixel[0], selected_pixel[1], :]
    x_energy = np.linspace(energy_range[0], energy_range[1], input_spectrum.shape[0])
    fig_spectra.add_trace(go.Scatter(x=x_energy, y=input_spectrum), row=1, col=1)

    # Predicted Spectral Graph
    predicted_spectrum = prediction[0, 0, selected_pixel[0], selected_pixel[1], :].cpu().numpy()
    fig_spectra.add_trace(go.Scatter(x=x_energy, y=predicted_spectrum), row=1, col=2)

    fig_spectra.update_layout(
        height=500,
        width=1200,
        margin=dict(l=20, r=20, t=40, b=20)
    )
    
    return fig_images, fig_spectra

def create_loss_plot():
    """Create an empty Plotly figure for the loss plot."""
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=[], y=[], mode='lines', name='Training Loss'))
    fig.update_layout(title='Training Loss',
                      xaxis_title='Epoch',
                      yaxis_title='Loss',
                      height=400,
                      width=800)
    return fig


In [3]:
epochs = 800
sigma = 2 
latent_dim = 20
energy_range = (600, 740)  # eV
target_size = (50, 50, energy_range[1]-energy_range[0])  # (x, y, z)
ev_per_pixel = 0.05
dm4_file = 'data/images_3D/EELS HL SI.dm4'

# Load the data with the specified energy range and resolution
data = load_dm4_data(dm4_file)

#! APPLY 2D Smoothing before feeding to the model 
#! generate the image where every pixel ( spectrum ) is the average of 3 by 3 pixels, also try training only on the "white part"
#! try to overtrain by making it bigger if nothing else works 

In [4]:
train_image_viz, train_image = preprocess_3d_images(data, target_size, sigma, energy_range, 3) 

# Create data loaders
train_dataset = TensorDataset(torch.tensor(train_image))
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

# Instantiate model and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CVAE3D(latent_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Pick a sample of the test set for generating output images
num_examples_to_generate = 1
test_sample = next(iter(train_loader))[0][:num_examples_to_generate].to(device)

In [5]:
# Initialize lists to store epoch and loss data
epochs_list = []
losses = []

# Create the initial loss plot
fig_loss = create_loss_plot()

# Training loop
try:
    for epoch in range(1, epochs + 1):
        start_time = time.time()
        model.train()
        train_loss = 0
        for batch in train_loader:
            train_x = batch[0].to(device)
            loss = compute_loss(model, train_x)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss /= len(train_loader.dataset)
        end_time = time.time()

        # Append data to lists
        epochs_list.append(epoch)
        losses.append(train_loss)

        # Update the loss plot
        fig_loss.data[0].x = epochs_list
        fig_loss.data[0].y = losses

        # Create new inference visualizations
        fig_images, fig_spectra = visualize_inference(model, test_sample[0], train_image_viz, energy_range)

        # Display updated plots
        clear_output(wait=True)
        fig_loss.show()
        fig_images.show()
        fig_spectra.show()

        print(f'Epoch: {epoch}, Test set ELBO: {train_loss:.4f}, time elapsed for current epoch: {end_time - start_time:.2f}s')
        
except KeyboardInterrupt:
    print("Training interrupted. Saving the model...")

Training interrupted. Saving the model...


In [None]:
# Create new inference visualizations
fig_images, fig_spectra = visualize_inference(model, test_sample[0], train_image_viz, energy_range, 24, 36)

# Display updated plots
fig_images.show()
fig_spectra.show()


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



In [None]:
torch.save(model.state_dict(), 'cvae3d_flex.pth')