In [None]:
import os
import h5py
import time
import numpy as np

import matplotlib.pyplot as plt

from sklearn.decomposition import PCA

from extra_data import open_run, RunDirectory, H5File
from extra_data.read_machinery import find_proposal

from threadpoolctl import threadpool_limits

In [None]:
# --- Run selection ---------------------------------------------------------
propno = 2919      # proposal number
runno_dark = 59    # run providing the dark images
runno_flat = 40    # run providing the flat-field images

# --- Camera and PCA settings -----------------------------------------------
camno = 2          # index of the SPB_EHD_HPVX2 camera to process
n_components = 20  # number of PCA components kept for the flat-field model

cam_source = f"SPB_EHD_HPVX2_{camno}/CAM/CAMERA:daqOutput"

# Raw data of each run lives under <proposal>/raw/rNNNN.
propdir = find_proposal(f"p{propno:06d}")
rundir_dark, rundir_flat = (
    os.path.join(propdir, f"raw/r{runno:04d}") for runno in (runno_dark, runno_flat)
)

print("Proposal directory:", propdir)
print("Dark run directory:", rundir_dark)
print("Flat run directory:", rundir_flat)
print("Camera source:", cam_source)

In [None]:
def process_dark(images):
    """Average a stack of dark frames into one 2-D dark image.

    Every axis except the last two (the image axes) is averaged away, so the
    input may be a single image, a flat stack of images, or a nested stack
    such as (train, pulse, x, y).

    Returns
    -------
    ndarray
        Mean over all leading axes, with shape ``images.shape[-2:]``.
    """
    stack_axes = tuple(range(images.ndim - 2))
    return np.mean(images, axis=stack_axes)

def _fit_pca(residuals, n_components, n_threads, random_state):
    """Fit a randomized PCA to a stack of flat-field residual images.

    Shared back end of ``process_flat`` and ``process_flat_orig``.

    Parameters
    ----------
    residuals : ndarray, shape (n_images, nx, ny)
        Images from which a (possibly intensity-scaled) mean flat has
        already been subtracted.
    n_components : int
        Number of principal components to keep.
    n_threads : int
        Upper bound on the number of BLAS threads used during the fit.
    random_state : int, numpy.random.RandomState or None
        Seed for the randomized SVD solver. With None, repeated calls may
        return slightly different components.

    Returns
    -------
    components : ndarray, shape (n_components, nx, ny)
        Principal components reshaped back into images.
    explained_variance_ratio : ndarray, shape (n_components,)
        Fraction of the residual variance captured by each component.
    """
    n_images, nx, ny = residuals.shape
    with threadpool_limits(limits=n_threads, user_api='blas'):
        pca = PCA(
            n_components=n_components,
            svd_solver='randomized',
            whiten=True,
            random_state=random_state,
        )
        pca.fit(residuals.reshape(n_images, nx * ny))
    return pca.components_.reshape(-1, nx, ny), pca.explained_variance_ratio_


def process_flat(images, n_components=20, n_threads=40, random_state=0):
    """Mean flat-field plus a PCA basis of its intensity-corrected residuals.

    Each image is modelled as its relative intensity times the mean flat.
    That scaled flat is subtracted before the PCA, so overall brightness
    changes between images are removed from the residuals rather than
    modelled by a component.

    NOTE(review): the dark level is not subtracted here, so the intensity
    scaling also scales any dark offset contained in the images. Confirm
    this matches how the written constants are applied downstream.

    Parameters
    ----------
    images : ndarray, shape (..., nx, ny)
        Flat-field frames; all leading axes are treated as image indices.
    n_components : int, default 20
        Number of principal components to keep.
    n_threads : int, default 40
        Upper bound on BLAS threads used by the PCA fit.
    random_state : int, numpy.random.RandomState or None, default 0
        Seed for the randomized SVD so repeated runs give the same basis.
        Pass None for an unseeded fit.

    Returns
    -------
    flat : ndarray, shape (nx, ny)
        Mean over all images.
    components : ndarray, shape (n_components, nx, ny)
        Principal components of the residuals.
    explained_variance_ratio : ndarray, shape (n_components,)
        Fraction of residual variance per component.

    Raises
    ------
    ValueError
        If the mean flat intensity is zero, so that relative intensities
        are undefined.
    """
    flat = np.mean(images, axis=tuple(range(images.ndim - 2)))
    nx, ny = flat.shape
    intensity_mean = np.mean(flat)
    if intensity_mean == 0:
        raise ValueError("mean flat-field intensity is zero; cannot normalise image intensities")
    # Relative intensity of every image; it averages to 1 over the stack.
    intensity = np.mean(images, axis=(-2, -1)).ravel() / intensity_mean
    residuals = images.reshape(-1, nx, ny) - intensity[:, None, None] * flat[None, :, :]
    components, explained_variance_ratio = _fit_pca(
        residuals, n_components, n_threads, random_state)
    return flat, components, explained_variance_ratio


def process_flat_orig(images, n_components=20, n_threads=40, random_state=0):
    """Mean flat-field plus a PCA basis of plain mean-subtracted flats.

    Original variant of ``process_flat``: the unscaled mean flat is
    subtracted from every image, so brightness changes between images stay
    in the residuals and must be captured by the PCA components.

    Parameters, return values: as for ``process_flat``.
    """
    flat = np.mean(images, axis=tuple(range(images.ndim - 2)))
    nx, ny = flat.shape
    residuals = images.reshape(-1, nx, ny) - flat[None, :, :]
    components, explained_variance_ratio = _fit_pca(
        residuals, n_components, n_threads, random_state)
    return flat, components, explained_variance_ratio
    
def plot_images(images, figsize=None, ncols=4):
    """Show a sequence of images on a grid, each with its own colorbar.

    Parameters
    ----------
    images : sequence of 2-D arrays
        Images to draw, filled into the grid row by row.
    figsize : (float, float), optional
        Figure size in inches; defaults to (16, 10).
    ncols : int, default 4
        Number of images per grid row.

    Returns
    -------
    fig : matplotlib.figure.Figure
    axs : ndarray of Axes, shape (nrows, ncols)
        Always 2-D, even when all images fit on a single row.

    Raises
    ------
    ValueError
        If ``images`` is empty.
    """
    nimage = len(images)
    if nimage == 0:
        raise ValueError("plot_images needs at least one image")
    nrows = -(-nimage // ncols)  # ceiling division

    if figsize is None:
        figsize = (16, 10)
    # squeeze=False keeps axs 2-D. Without it, a single row comes back 1-D
    # and 2-D indexing fails.
    fig, axs = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)

    for k, ax in enumerate(axs.flat):
        if k >= nimage:
            ax.set_visible(False)  # hide the unused panels in the last row
            continue
        im = ax.matshow(images[k])
        ax.axis(False)
        fig.colorbar(im, ax=ax)

    return fig, axs

def write_constants(fn, source, dark, flat, components, explained_variance_ratio):
    """Write dark, flat and PCA constants for one camera to an HDF5 file.

    The file is opened in 'w' mode, which overwrites an existing file. It
    receives a single group named after the camera: the part of ``source``
    before the first '/'.

    Parameters
    ----------
    fn : str or path-like
        Output file name.
    source : str
        Data source name; only the text before the first '/' is used.
    dark, flat : ndarray, shape (nx, ny)
        Mean dark and mean flat images.
    components : ndarray, shape (rank, nx, ny)
        PCA components of the flat-field residuals.
    explained_variance_ratio : ndarray, shape (rank,)
        Explained variance ratio of each component.

    Raises
    ------
    ValueError
        If the array shapes are inconsistent. The check runs before the file
        is opened, so a bad call never truncates an existing file.
    """
    image_shape = np.shape(dark)
    if np.shape(flat) != image_shape:
        raise ValueError(f"dark {image_shape} and flat {np.shape(flat)} shapes differ")
    if np.shape(components)[1:] != image_shape:
        raise ValueError(
            f"components {np.shape(components)} do not match image shape {image_shape}")
    if len(explained_variance_ratio) != len(components):
        raise ValueError(
            f"{len(explained_variance_ratio)} variance ratios for {len(components)} components")

    camera_name = source.partition('/')[0]
    with h5py.File(fn, 'w') as f:
        g = f.create_group(camera_name)
        g['rank'] = len(components)
        g['shape'] = image_shape
        g['mean_dark'] = dark
        g['mean_flat'] = flat
        g['components_matrix'] = components
        g['explained_variance_ratio'] = explained_variance_ratio

In [None]:
# Load all dark frames of the camera as a NumPy array and time the read.
tm0 = time.monotonic()
run_dark = RunDirectory(rundir_dark)
images_dark = run_dark[cam_source, "data.image.pixels"].ndarray()
# The unpacking requires 4-D data: (train, pulse, x, y).
ntrain, npulse, nx, ny = images_dark.shape
tm_rd = time.monotonic() - tm0

tm0 = time.monotonic()
dark = process_dark(images_dark)
tm_cm = time.monotonic() - tm0

print(f"N image: {ntrain * npulse} (ntrain: {ntrain}, npulse: {npulse})")
print(f"Image size: {nx} x {ny} px")
print(f"Read time: {tm_rd:.2f} s, comp time: {tm_cm:.2f} s")

fig, ax = plt.subplots()
im = ax.matshow(dark)
fig.colorbar(im, ax=ax)
ax.set_title(f"Mean dark, run {runno_dark}")
plt.show()

In [None]:
# Load all flat-field frames of the camera as a NumPy array and time the read.
tm0 = time.monotonic()
run_flat = RunDirectory(rundir_flat)
images_flat = run_flat[cam_source, "data.image.pixels"].ndarray()
# The unpacking requires 4-D data: (train, pulse, x, y).
ntrain, npulse, nx, ny = images_flat.shape
tm_rd = time.monotonic() - tm0

# Intensity-normalised flat model and its PCA basis (see process_flat).
tm0 = time.monotonic()
flat, components, explained_variance_ratio = process_flat(images_flat, n_components)
tm_cm = time.monotonic() - tm0

print(f"N image: {ntrain * npulse} (ntrain: {ntrain}, npulse: {npulse})")
print(f"Image size: {nx} x {ny} px")
print(f"Read time: {tm_rd:.2f} s, comp time: {tm_cm:.2f} s")

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 5))
im = ax1.matshow(flat)
fig.colorbar(im, ax=ax1)
ax1.set_title(f"Mean flat, run {runno_flat}")
ax2.plot(explained_variance_ratio, 'o')
ax2.set(title="PCA explained variance ratio (intensity-normalised)",
        xlabel="component index", ylabel="explained variance ratio")
plt.show()

In [None]:
# Save the intensity-normalised PCA constants computed in the cell above.
fn = f"pca_cam{camno}_d{runno_dark}_f{runno_flat}_r{n_components}.h5"
write_constants(
    fn,
    source=cam_source,
    dark=dark,
    flat=flat,
    components=components,
    explained_variance_ratio=explained_variance_ratio,
)

In [None]:
# Plot all n_components components (previously a hard-coded [:20]).
fig_components, axs_components = plot_images(components[:n_components])
fig_components.suptitle("PCA components (intensity-normalised flats)")
plt.show()

In [None]:
# Comparison: the original PCA without intensity normalisation.
# NOTE: this overwrites flat / components / explained_variance_ratio from the
# intensity-normalised run; the cells above must have been executed first.
# One extra component is used, presumably to absorb the brightness variation
# that process_flat removes explicitly. TODO confirm.
tm0 = time.monotonic()
flat, components, explained_variance_ratio = process_flat_orig(images_flat, n_components+1)
tm_cm = time.monotonic() - tm0

print(f"N image: {ntrain * npulse} (ntrain: {ntrain}, npulse: {npulse})")
print(f"Image size: {nx} x {ny} px")
print(f"Read time: {tm_rd:.2f} s, comp time: {tm_cm:.2f} s")

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 5))
im = ax1.matshow(flat)
fig.colorbar(im, ax=ax1)
ax1.set_title(f"Mean flat, run {runno_flat}")
ax2.plot(explained_variance_ratio, 'o')
ax2.set(title="PCA explained variance ratio (plain mean subtraction)",
        xlabel="component index", ylabel="explained variance ratio")
plt.show()

In [None]:
# Save the constants of the original (non-normalised) PCA variant.
# NOTE(review): the filename says r{n_components}, but the previous cell fitted
# n_components + 1 components. Confirm that this label is intended.
fn = f"pca_cam{camno}_d{runno_dark}_f{runno_flat}_r{n_components}_orig.h5"
write_constants(
    fn,
    source=cam_source,
    dark=dark,
    flat=flat,
    components=components,
    explained_variance_ratio=explained_variance_ratio,
)

In [None]:
# process_flat_orig was fitted with n_components + 1 components (previously
# hard-coded here as [:21]).
fig_components_orig, axs_components_orig = plot_images(components[:n_components + 1])
fig_components_orig.suptitle("PCA components (plain mean-subtracted flats)")
plt.show()

In [None]:
# Pulse 5 of each of the first 20 flat-field trains. Indexing the
# (train, pulse) axes directly replaces the flat-index slice
# [5:128*20+5:128], which selected these frames only when npulse == 128.
plot_images(images_flat[:20, 5])
plt.show()