In [None]:
import os
import time

#from autograd import grad
#import autograd.numpy as np
import h5py
import numpy as np
import matplotlib.pyplot as plt

from extra_data import open_run, RunDirectory, H5File
from extra_data.read_machinery import find_proposal

from skimage.transform import downscale_local_mean
from scipy.optimize import check_grad, fmin_l_bfgs_b

In [None]:
# Proposal and the run whose frames are corrected.
propno = 2919
runno = 30

# Dark / flat run numbers: used below only to build the PCA constants file name.
runno_dark, runno_flat = 59, 40

# Camera index (selects the HPVX2 source) and the `_r{n}` constants file.
camno = 2
n_components = 20
cam_source = f"SPB_EHD_HPVX2_{camno}/CAM/CAMERA:daqOutput"

# Block-averaging factors for the last two image axes, used when fitting
# the dynamic flat-field weights on reduced-size images.
downsample_factors = (2, 4)

propdir = find_proposal(f"p{propno:06d}")
rundir = os.path.join(propdir, f"raw/r{runno:04d}")

In [None]:
def read_constants(fn, source):
    """Load the PCA flat-field constants for one camera from an HDF5 file.

    The HDF5 group is named after the device part of `source`, i.e.
    everything before the first '/'.

    Returns
    -------
    (dark, flat, components)
        Arrays read from the 'mean_dark', 'mean_flat' and
        'components_matrix' datasets of that group.
    """
    group_name = source.split('/', 1)[0]
    with h5py.File(fn, 'r') as h5file:
        group = h5file[group_name]
        dark, flat, components = [
            group[key][:] for key in ('mean_dark', 'mean_flat', 'components_matrix')
        ]
    return dark, flat, components
    
def downscale_images(downscale_factors, *args):
    """Block-average every array in `args` over its last two axes.

    `downscale_factors` gives the factors for the last two axes. Any leading
    axes (e.g. the component axis of a stack) get a factor of 1.

    Returns a tuple with one downscaled array per input, in input order.
    """
    return tuple(
        downscale_local_mean(img, (1,) * (img.ndim - 2) + tuple(downscale_factors))
        for img in args
    )

def correct_dyn(w, image, flat0, components):
    """Divide `image` by the dynamic flat ``flat0 + sum_k w[k] * components[k]``.

    `w` holds one weight per entry along the first axis of `components`.
    No global rescaling is applied to the result.
    """
    weighted_sum = (w[:, None, None] * components).sum(axis=0)
    return image / (flat0 + weighted_sum)

def discrete_gradient(w, image, flat0, components, components_mean, flat_mean):
    """Total-variation cost of the dynamically flat-field-corrected image.

    The image is divided by ``flat0 + sum_k w[k] * components[k]`` and scaled
    by ``flat_mean + w . components_mean``. The cost is the sum over pixels of
    the magnitude of the forward-difference gradient. Differences past the
    last row/column are taken as 0.

    Parameters
    ----------
    w : 1-D array of weights, one per component.
    image, flat0 : 2-D arrays of the same shape.
    components : stack of components indexed along the first axis.
    components_mean : per-component mean values, same length as `w`.
    flat_mean : scalar; callers in this notebook pass ``np.mean(flat0)``.

    Returns
    -------
    Scalar cost (sum of per-pixel gradient magnitudes).
    """
    n_rows, n_cols = image.shape
    dynamic_flat = flat0 + (w[:, None, None] * components).sum(axis=0)
    scale = flat_mean + w @ components_mean
    corrected = image / dynamic_flat * scale

    row_diff = np.zeros((n_rows, n_cols))
    row_diff[:-1] = corrected[1:] - corrected[:-1]
    col_diff = np.zeros((n_rows, n_cols))
    col_diff[:, :-1] = corrected[:, 1:] - corrected[:, :-1]

    return np.sqrt(row_diff ** 2 + col_diff ** 2).sum()

def prime_discrete_gradient(w, image, flat0, components, components_mean, flat_mean):
    """Analytic gradient of `discrete_gradient` with respect to the weights `w`.

    With F(w) = flat0 + sum_k w[k] * components[k] and
    s(w) = flat_mean + w . components_mean, the corrected image is
    c(w) = image / F * s, so

        dc/dw[k] = image / F * (components_mean[k] - s / F * components[k]).

    The cost is sum(sqrt(dx**2 + dy**2)) over forward differences dx, dy of
    c, zero-padded in the last row/column. Its derivative is
    sum((dx * d(dx)/dw + dy * d(dy)/dw) / cost). Pixels where cost == 0
    contribute 0.

    Parameters
    ----------
    Same as `discrete_gradient`: `image` and `flat0` are 2-D, `components`
    is a stack indexed along its first axis, and `w` / `components_mean`
    have one entry per component.

    Returns
    -------
    1-D array with one partial derivative per component.
    """
    nx, ny = image.shape
    ncomponents = len(components)
    flat_dyn = flat0 + np.sum(w[:, None, None] * components, 0)
    factor = flat_mean + np.dot(w, components_mean)

    corr_img = image / flat_dyn
    # d(corrected image)/dw[k] for every component, stacked along axis 0.
    dimg_dw = corr_img[None, :, :] * (components_mean[:, None, None] - factor / flat_dyn[None, :, :] * components)
    corr_img *= factor

    # Forward differences of the corrected image, zero-padded exactly like
    # discrete_gradient so value and gradient refer to the same cost.
    dx = np.zeros([nx, ny], float)
    dx[:-1, :] = np.diff(corr_img, n=1, axis=0)
    dy = np.zeros([nx, ny], float)
    dy[:, :-1] = np.diff(corr_img, n=1, axis=1)

    # Same forward differences applied to each component's image derivative.
    dxdw = np.zeros([ncomponents, nx, ny], float)
    dxdw[:, :-1, :] = np.diff(dimg_dw, n=1, axis=1)
    dydw = np.zeros([ncomponents, nx, ny], float)
    dydw[:, :, :-1] = np.diff(dimg_dw, n=1, axis=2)

    cost = np.sqrt(dx**2 + dy**2)[None, :, :]
    numerator = dx * dxdw + dy * dydw
    # Where cost == 0 (always at least the bottom-right pixel, whose padded
    # differences are both 0) the contribution is taken as 0. `out` must be
    # pre-filled: np.divide(..., where=mask) without `out` leaves the masked
    # entries uninitialised, which injected garbage into every gradient.
    g = np.divide(numerator, cost, out=np.zeros_like(numerator), where=(cost != 0))

    return np.sum(g, axis=(-1, -2))

#prime_discrete_gradient2 = grad(discrete_gradient, (0))

def correct_static(image, dark, flat):
    """Dark-subtract `image` and an intensity-matched version of `flat`.

    `flat` is first scaled by mean(image) / mean(flat). The dark is then
    subtracted from both.

    Returns
    -------
    (image - dark, scaled_flat - dark)
    """
    # NOTE(review): the scale uses raw means (dark still included) and is
    # applied before dark subtraction; confirm this is intended rather than
    # scaling (flat - dark) by mean(image - dark) / mean(flat - dark).
    scaled_flat = flat * (np.mean(image) / np.mean(flat))
    signal = image - dark
    return signal, scaled_flat - dark

def correct_static_orig(image, dark, flat):
    """Subtract the dark from both the image and the flat (no rescaling).

    Returns
    -------
    (image - dark, flat - dark)
    """
    corrected_image = image - dark
    corrected_flat = flat - dark
    return corrected_image, corrected_flat

def plot_corrected(im1, im2):
    """Show `im1`, `im2` and their mean-normalised difference in a 3x1 figure.

    Each panel is drawn with `matshow`, with axes hidden and its own colorbar.
    The third panel is ``im1 / mean(im1) - im2 / mean(im2)``.
    """
    fig, axes = plt.subplots(3, 1, figsize=(12, 18))
    panels = (im1, im2, im1/im1.mean() - im2/im2.mean())
    for ax, panel in zip(axes, panels):
        artist = ax.matshow(panel)
        ax.axis(False)
        fig.colorbar(artist, ax=ax)


In [None]:
# Read all camera frames of the run into memory. The array is unpacked below
# as images[train, pulse, x, y].
tm0 = time.monotonic()
run = RunDirectory(rundir)
pixel_data = run[cam_source, "data.image.pixels"]
images = pixel_data.ndarray()
tm1 = time.monotonic()
ntrain, npulse, nx, ny = images.shape

print(f"Ntrain: {ntrain}, Npulse: {npulse}, Image size: {nx} x {ny}")
print(f"N image: {ntrain * npulse}")
print(f"Read time: {tm1 - tm0:.2f} s")

In [None]:
# Sanity check of the hand-derived gradient on a single frame (frame 0 of the
# first train), using the "_orig" PCA constants and correct_static_orig.
fn = f"pca_cam{camno}_d{runno_dark}_f{runno_flat}_r{n_components}_orig.h5"
dark, flat, components = read_constants(fn, cam_source)

nx, ny = dark.shape
ncomponents = components.shape[0]

print(f"Image size: {nx} x {ny} px")
print(f"N components: {ncomponents}")

dark_ds, flat_ds, components_ds = downscale_images(downsample_factors, dark, flat, components)
components_mean = np.mean(components_ds, axis=(-1, -2))

x0 = np.zeros(ncomponents)
i = 0

image_ds, = downscale_images(downsample_factors, images[0, i])
image0, flat0 = correct_static_orig(image_ds, dark_ds, flat_ds)
flat_mean = np.mean(flat0)

args = (image0, flat0, components_ds, components_mean, flat_mean)

# Compare the analytic gradient with a finite-difference estimate. This
# replaces the former (commented-out) autograd comparison.
grad_analytic = prime_discrete_gradient(x0, *args)
grad_err = check_grad(discrete_gradient, prime_discrete_gradient, x0, *args)
print(f"Gradient check: |analytic - finite difference| = {grad_err:.3g} "
      f"(|analytic| = {np.linalg.norm(grad_analytic):.3g})")

grad_analytic

In [None]:
# Dynamic flat-field correction of every frame in the first train, using the
# "_orig" PCA constants with plain dark subtraction (correct_static_orig).
# The PCA weights are fitted on downscaled frames and then applied at full
# resolution.
fn = f"pca_cam{camno}_d{runno_dark}_f{runno_flat}_r{n_components}_orig.h5"
dark, flat, components = read_constants(fn, cam_source)

nx, ny = dark.shape
ncomponents = components.shape[0]

print(f"Image size: {nx} x {ny} px")
print(f"N components: {ncomponents}")

dark_ds, flat_ds, components_ds = downscale_images(downsample_factors, dark, flat, components)
components_mean = np.mean(components_ds, axis=(-1, -2))

# L-BFGS-B `factr`: stop once the relative cost reduction is below
# factr * machine eps (SciPy docs: 1e12 = low, 1e7 = moderate accuracy).
fctr = 1e10
# Process every frame of the first train instead of a hard-coded 128.
n_frames = npulse

tm0 = time.monotonic()

imcorr1 = np.zeros([n_frames, nx, ny], float)
not_converged = []

x0 = np.zeros(ncomponents)
for i in range(n_frames):
    # Fit the component weights on the downscaled frame.
    image_ds, = downscale_images(downsample_factors, images[0, i])
    image0_ds, flat0_ds = correct_static_orig(image_ds, dark_ds, flat_ds)
    flat_mean = np.mean(flat0_ds)

    args = (image0_ds, flat0_ds, components_ds, components_mean, flat_mean)
    w_fit, cost_fit, info = fmin_l_bfgs_b(discrete_gradient, x0, fprime=prime_discrete_gradient,
                                          args=args, factr=fctr, iprint=0)
    if info['warnflag'] != 0:
        not_converged.append(i)

    # Apply the fitted weights at full resolution. Separate names keep the
    # downscaled inputs above from being shadowed.
    image0, flat0 = correct_static_orig(images[0, i], dark, flat)
    imcorr1[i] = correct_dyn(w_fit, image0, flat0, components)

    # Warm-start the next frame from this solution.
    x0 = w_fit

tm1 = time.monotonic()
print(f"Minimization time: {tm1 - tm0: .2f} s, per image: {(tm1 - tm0)/n_frames: .2f} s")
if not_converged:
    print(f"L-BFGS-B stopped without converging for {len(not_converged)} frame(s): {not_converged}")


In [None]:
# Frame 1: dynamic correction (imcorr1) vs. plain static flat-field division.
i = 1
image0, flat0 = correct_static_orig(images[0, i], dark, flat)
plot_corrected(im1=imcorr1[i], im2=image0 / flat0)
plt.show()

In [None]:
# Dynamic flat-field correction of every frame in the first train, using the
# non-"_orig" PCA constants with the intensity-scaled static correction
# (correct_static). The PCA weights are fitted on downscaled frames and then
# applied at full resolution.
fn = f"pca_cam{camno}_d{runno_dark}_f{runno_flat}_r{n_components}.h5"
dark, flat, components = read_constants(fn, cam_source)

nx, ny = dark.shape
ncomponents = components.shape[0]

print(f"Image size: {nx} x {ny} px")
print(f"N components: {ncomponents}")

dark_ds, flat_ds, components_ds = downscale_images(downsample_factors, dark, flat, components)
components_mean = np.mean(components_ds, axis=(-1, -2))

# L-BFGS-B `factr`: stop once the relative cost reduction is below
# factr * machine eps (SciPy docs: 1e12 = low, 1e7 = moderate accuracy).
# The tiny pgtol below means stopping is in practice governed by factr.
fctr = 1e10
# Process every frame of the first train instead of a hard-coded 128.
n_frames = npulse

tm0 = time.monotonic()

imcorr2 = np.zeros([n_frames, nx, ny], float)
not_converged = []

x0 = np.zeros(ncomponents)
for i in range(n_frames):
    image_ds, = downscale_images(downsample_factors, images[0, i])
    # correct_static rescales the flat by the frame's mean intensity, so the
    # static flat is recomputed for every frame.
    image0_ds, flat0_ds = correct_static(image_ds, dark_ds, flat_ds)
    flat_mean = np.mean(flat0_ds)

    args = (image0_ds, flat0_ds, components_ds, components_mean, flat_mean)
    w_fit, cost_fit, info = fmin_l_bfgs_b(discrete_gradient, x0, fprime=prime_discrete_gradient,
                                          args=args, factr=fctr, iprint=0, pgtol=1e-15)
    if info['warnflag'] != 0:
        not_converged.append(i)

    # Apply the fitted weights at full resolution. Separate names keep the
    # downscaled inputs above from being shadowed.
    image0, flat0 = correct_static(images[0, i], dark, flat)
    imcorr2[i] = correct_dyn(w_fit, image0, flat0, components)

    # Warm-start the next frame from this solution.
    x0 = w_fit

tm1 = time.monotonic()
print(f"Minimization time: {tm1 - tm0: .2f} s, per image: {(tm1 - tm0)/n_frames: .2f} s")
if not_converged:
    print(f"L-BFGS-B stopped without converging for {len(not_converged)} frame(s): {not_converged}")


In [None]:
# Frame 1: dynamic correction with the non-"_orig" constants (imcorr2) vs.
# the intensity-scaled static flat-field division.
i = 1
image0, flat0 = correct_static(images[0, i], dark, flat)
plot_corrected(im1=imcorr2[i], im2=image0 / flat0)
plt.show()

In [None]:
# Frame 1: compare the two notebook corrections ("_orig" vs. non-"_orig" constants).
i = 1
plot_corrected(im1=imcorr1[i], im2=imcorr2[i])
plt.show()

In [None]:
# Make the online DFFC implementation importable and import only the entry
# point used below. A star import could silently replace helpers defined
# earlier in this notebook (read_constants, downscale_images, ...) that the
# following cells still call.
import sys

_dffc_dir = '../onlineVisual_OnNorm'  # relative to the notebook's working directory
if _dffc_dir not in sys.path:  # avoid piling up duplicates when the cell is re-run
    sys.path.append(_dffc_dir)
from dffc_functions_online import dffc_correct_2d

In [None]:
# Same correction via the online implementation (dffc_correct_2d) with the
# "_orig" constants, for comparison with the notebook results below.
fn = f"pca_cam{camno}_d{runno_dark}_f{runno_flat}_r{n_components}_orig.h5"
dark, flat, components = read_constants(fn, cam_source)

nx, ny = dark.shape
ncomponents = components.shape[0]

print(f"Image size: {nx} x {ny} px")
print(f"N components: {ncomponents}")

pca_info = {
    'rank': ncomponents,
    'image_dimensions': (nx, ny),
    'mean_flat': flat,
    'mean_dark': dark,
    'components_matrix': components.reshape(ncomponents, nx * ny),
}

# Process every frame of the first train instead of a hard-coded 128.
n_frames = npulse

tm0 = time.monotonic()
imcorr3 = np.zeros([n_frames, nx, ny], float)

for i in range(n_frames):
    imcorr3[i] = dffc_correct_2d(images[0, i], pca_info, downsample_factors, fctr=1e10)

tm1 = time.monotonic()
print(f"Minimization time: {tm1 - tm0: .2f} s, per image: {(tm1 - tm0)/n_frames: .2f} s")

In [None]:
# Frame 1: notebook correction with "_orig" constants vs. dffc_correct_2d.
i = 1
plot_corrected(im1=imcorr1[i], im2=imcorr3[i])
plt.show()

In [None]:
# Frame 1: notebook correction with non-"_orig" constants vs. dffc_correct_2d.
i = 1
plot_corrected(im1=imcorr2[i], im2=imcorr3[i])
plt.show()