This notebook tests the performance of a convolutional CVAE. We aim to predict the diffraction pattern when given the structure factors file.

In [1]:
import os
import sys
import torch
import matplotlib.pyplot as plt
from matplotlib import gridspec
import numpy as np

In [2]:
sys.path.append(os.path.abspath(os.path.join('..')))
from src.models.components.cvae import Model

data_loc = "../data/VAE000/cleaned_data/"

In [3]:
model = Model()
model.load_state_dict(torch.load("models/cvae.pt"))
model.eval()

FileNotFoundError: [Errno 2] No such file or directory: 'models/cvae.pt'

In [7]:
sample_size = 10
CISD_codes = os.listdir(data_loc)

patterns = np.concatenate([np.clip(np.fromfile(
    os.path.join(data_loc, code, code+"_+0+0+0.bin"),
    dtype=np.float64), 0.0, 1.0).reshape((128, 128)) for code in np.random.choice(CISD_codes, sample_size)], axis=0)
patterns_tensor = torch.from_numpy(patterns).to(torch.float32).clone().detach().view(sample_size, 1, 128, 128)
structure_factors = np.concatenate([np.loadtxt(
    os.path.join(data_loc, code, f"{code}_structure_factors.txt"),
    dtype=np.float64) for code in np.random.choice(CISD_codes, sample_size)], axis=0)
structure_factors_tensor = torch.from_numpy(structure_factors).float().clone().detach().view(sample_size, -1)
predictions_tensor = model.decode(torch., structure_factors_tensor)
predictions = predictions_tensor.view(sample_size * 128, 128).detach().numpy()

# Fix images contrasts (due to ZNCC not caring about this)
patterns = patterns - np.min(patterns)
patterns = patterns * (1 / np.max(patterns))
predictions = predictions - np.min(predictions)
predictions = predictions * (1 / np.max(predictions))
difference = np.abs(predictions - patterns)
difference = np.clip(5 * difference, 0.0, 1.0)

fig = plt.figure(figsize=(25, 4))
gs = gridspec.GridSpec(3, 1, height_ratios=[1, 1, 1],
         wspace=0.0, hspace=0.0, top=0.95, bottom=0.05, left=0.17, right=0.845) 

for i, (ax, img) in enumerate(zip([plt.subplot(gs[0, 0]), plt.subplot(gs[1, 0]), plt.subplot(gs[2, 0])],
                   [patterns.T, predictions.T, difference.T])):
    ax.imshow(img, cmap=("CMRmap" if i==2 else "gray"))
    ax.set_axis_off()

plt.show()


RuntimeError: Tensors must have same number of dimensions: got 2 and 1