<a href="https://colab.research.google.com/github/prinaldi3/Denoising/blob/main/Load_and_Plot.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook loads a trained denoising model from its saved training checkpoints and plots its reconstructions of noisy data alongside the clean and noisy inputs.

In [None]:
# Mount Google Drive; the training checkpoints restored later are read from under this mount point.
from google.colab import drive

GDRIVE_MOUNT_POINT = '/content/gdrive'
drive.mount(GDRIVE_MOUNT_POINT)

In [None]:
# Pin pip to 21.0.1 before installing the xca package.
# NOTE(review): the reason for this exact pin is not recorded here (presumably a dependency-resolver
# issue with xca's requirements) -- confirm and document, or drop the pin if no longer needed.
!python -m pip install pip==21.0.1

In [None]:
# Remove any previous clone first so re-running this cell always yields a fresh checkout.
# NOTE(review): this clones the default branch, so results may drift as xca changes;
# consider pinning a specific commit for reproducibility.
!rm -rf xca
!git clone https://github.com/maffettone/xca

In [None]:
%%bash
# Install the freshly cloned xca package using the `python` found on PATH.
cd xca
python -m pip install .

In [None]:
#Importing from XCA package

import xca
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from xca.ml.tf_models import build_CNN_encoder_model, build_CNN_decoder_model, VAE_denoising_training, VAE

In [None]:
# Path(s) to the dataset file(s) handed to build_dataset.
# TODO: replace the placeholder with real path(s) before running the notebook.
dataset_paths = [
    'specify_path_here',
]

In [None]:
# Define the denoising model and restore its trained weights -- currently configured for the Gaussian dataset.
# The architecture must match the one used for training, otherwise the checkpoint will not map onto the layers.
encoder, last_conv_shape = build_CNN_encoder_model(
    data_shape=(500, 1),
    latent_dim=100,
    dense_dims=[],
    filters=[32, 32, 16],
    kernel_sizes=[16, 16, 16],
    strides=[1, 1, 1],
    pool_sizes=[1, 1, 1],
    paddings=["same"] * 3,
    verbose=True,
)

# NOTE(review): the decoder receives 4 paddings for 3 filter/kernel/stride entries (the encoder gets 3);
# presumably build_CNN_decoder_model needs one extra for its output layer -- confirm against xca.
decoder = build_CNN_decoder_model(
    data_shape=(500, 1),
    latent_dim=100,
    last_conv_layer_shape=last_conv_shape,
    filters=[16, 32, 32],
    kernel_sizes=[16, 16, 16],
    strides=[1, 1, 1],
    paddings=["same"] * 4,
    verbose=True,
)
# NOTE(review): the third positional argument (0.) is interpreted by xca's VAE -- confirm its meaning there.
vae = VAE(encoder, decoder, 0.)

cnn_checkpoint_path = '/content/gdrive/MyDrive/guassian_denoise_95/training_checkpoints'
latest_ckpt = tf.train.latest_checkpoint(cnn_checkpoint_path)
if latest_ckpt is None:
    # latest_checkpoint returns None when no checkpoint exists; restore(None) would then silently
    # keep the randomly initialised weights and every reconstruction plotted below would be meaningless.
    raise FileNotFoundError(
        f"No checkpoint found in {cnn_checkpoint_path!r}; is Google Drive mounted and the path correct?"
    )
checkpoint = tf.train.Checkpoint(model=vae)
# Keep the status object; e.g. restore_status.assert_existing_objects_matched() can verify the restore.
restore_status = checkpoint.restore(latest_ckpt)

In [None]:
from xca.ml.tf_data_proc import build_dataset

# Build the evaluation dataset (clean + noisy pairs) -- configured for the Gaussian dataset.
batch_size = 128


def _minmax_normalize(x):
    """Rescale ``x`` to [0, 1] independently along axis 0.

    Uses ``tf.math.divide_no_nan`` so a constant (flat) input maps to 0 instead of NaN.
    """
    x_min = tf.math.reduce_min(x, axis=0, keepdims=True)
    x_max = tf.math.reduce_max(x, axis=0, keepdims=True)
    return tf.math.divide_no_nan(x - x_min, x_max - x_min)


def preprocess(data, label):
    """Turn one raw pattern into a clean/noisy pair for the denoiser.

    Args:
        data: a single pattern; presumably shape (500, 1) to match ``data_shape`` passed to
            build_dataset -- TODO confirm build_dataset applies this per example (before batching),
            since the min-max normalisation is taken over axis 0.
        label: passed through unchanged.

    Returns:
        dict with ``"X"`` (clean pattern scaled to [0, 1]), ``"X_noisy"`` (pattern plus Gaussian
        noise, also scaled to [0, 1]) and ``"label"``.
    """
    raw = tf.cast(data, tf.float32)
    X = _minmax_normalize(raw)
    # Draw the noise level with TF ops rather than np.random: when this function is mapped through
    # tf.data, Python/NumPy calls execute only once at trace time, which would freeze a single
    # stddev for every example instead of drawing a fresh one per pattern.
    stddev = tf.pow(10.0, tf.random.uniform([], minval=-1.0, maxval=-0.5, dtype=tf.float32))
    # NOTE(review): as before, noise is added to the raw (un-normalised) data; the noise scale matches
    # the [0, 1] scale of X only if the raw data is already roughly in [0, 1] -- confirm against training.
    noisy = raw + tf.random.normal(tf.shape(raw), stddev=stddev, dtype=tf.float32)
    return {"X": X, "X_noisy": _minmax_normalize(noisy), "label": label}

# Presumably build_dataset returns a (train, validation) pair; with val_split=0.0 the validation
# part is not needed and is discarded.
data, _ = build_dataset(
    dataset_paths=dataset_paths,
    data_shape=(500, 1),
    batch_size=batch_size,
    preprocess=preprocess,
    categorical=True,
    multiprocessing=1,
    val_split=0.0,
)


In [None]:
# Plot one reconstruction: noisy input, clean target, denoised output and the absolute residual.

# Index of the sample to plot within the first batch (must be smaller than that batch's size).
sample_idx = 0

# Set font sizes *before* creating the figure: text artists read rcParams when they are created,
# so setting them afterwards would only take effect the next time this cell runs.
plt.rc('font', size=20)          # controls default text sizes
plt.rc('axes', titlesize=15)     # fontsize of the axes title
plt.rc('axes', labelsize=15)     # fontsize of the x and y labels
plt.rc('figure', titlesize=20)   # fontsize of the figure title

batch = next(iter(data))  # only the first batch is needed
n_in_batch = batch["X"].shape[0]
assert 0 <= sample_idx < n_in_batch, f"sample_idx={sample_idx} is outside the batch (size {n_in_batch})"

output = vae(batch["X_noisy"], training=False)  # apply the trained denoiser to the noisy data
# Use the same index for every array so all curves belong to the same sample.
X_denoise = np.asarray(output["reconstruction"][sample_idx])
X = np.asarray(batch["X"][sample_idx])
X_noisy = np.asarray(batch["X_noisy"][sample_idx])
label = batch["label"][sample_idx]

fig, ax = plt.subplots(figsize=(10, 10))
# Vertical offsets stack the curves so they do not overlap; the residual stays at the baseline.
ax.plot(X_noisy + 1.2, label="Raw", color='gray')
ax.plot(X + .2, label="True", color='k')
ax.plot(X_denoise + .2, 'r--', label="Denoise")
ax.plot(np.abs(X - X_denoise), color='b', label="Residual")
ax.set_title("Comparison of True and Denoised Patterns")
ax.set_xlabel("Pixel #")
ax.set_ylabel("Intensity [Arb.]")
ax.get_yaxis().set_ticks([])  # offsets make absolute y values meaningless
ax.legend()
# Uncomment to save the figure:
#fig.savefig('/content/gdrive/MyDrive/poster_figs/guassian_reconstruction', facecolor='white')
plt.show()