# Create Strips of Stop and Stare Data for Reconstruction
This notebook takes a list of existing datasets and performs registration and reconstruction of each stop-and-stare dataset in the stack

In [None]:
%matplotlib notebook
%load_ext autoreload
%autoreload 2

# Load motiondeblur module and Dataset class
import libwallerlab.projects.motiondeblur as md
from libwallerlab.utilities.io import Dataset, isDataset
import libwallerlab.utilities.noise as noise

# Platform imports
import os, glob
import numpy as np

# Debugging imports
import llops as yp
import matplotlib.pyplot as plt

yp.config.setDefaultBackend('numpy')

## Load Data

In [None]:
# Root folder containing one sub-folder per acquisition.
# Raw strings: the backslashes are literal Windows path separators. In a
# normal string '\Z', '\G', '\e' and '\p' are invalid escape sequences, which
# Python currently tolerates with a warning but has announced will become errors.
dataset_path = r'D:\Zack\Google Drive\einstein_data/'

# Define output directory
output_directory = r'D:\Zack\Google Drive\einstein_data\patches'

# Find sub-folders in the dataset directory and keep those isDataset() accepts
folder_list = glob.glob(os.path.join(dataset_path, '*/'))
dataset_list = [folder for folder in folder_list if isDataset(folder)]

# Split the datasets by acquisition type. These are filtered from dataset_list
# (not the raw folder_list) so that non-dataset folders can never be selected.
sns_dataset_list = [folder_name for folder_name in dataset_list if 'stopandstare' in folder_name]
coded_dataset_list = [folder_name for folder_name in dataset_list if 'coded' in folder_name]

# Keep only the dataset(s) whose path contains '173e' (debugging filter)
sns_dataset_list = [s for s in sns_dataset_list if '173e' in s]
# coded_dataset_list = [s for s in coded_dataset_list if 'res' in s]

# Select dataset (for now)
dataset_index = 0

# Fail with an explanatory message instead of a bare "list index out of range"
if dataset_index >= len(sns_dataset_list):
    raise IndexError("dataset_index=%d but only %d stop-and-stare dataset(s) containing '173e' were found in %s"
                     % (dataset_index, len(sns_dataset_list), dataset_path))

# Create dataset object for stop and stare
dataset = Dataset(sns_dataset_list[dataset_index])
dataset.metadata.type = 'motiondeblur'
dataset.channel_mask = [0]
md.preprocess(dataset)

# # Create dataset for coded illumination
# dataset_coded = Dataset(coded_dataset_list[dataset_index])
# dataset_coded.metadata.type = 'motiondeblur'
# dataset_coded.channel_mask = [0]
# md.preprocess(dataset_coded)

# Get linear segment count
linear_segment_count = len(dataset.position_segment_indicies)
# linear_segment_count_coded = len(dataset_coded.position_segment_indicies)
# assert linear_segment_count_coded == linear_segment_count



# First Step: Generate Raw Data and Corresponding Blurred Overlap

1. Take full unblurred measurement
2. Decimate
3. Convolve with decimated blur kernel (return valid kernel)
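4. Crop decimated ground truth to the ground-truth support: along the blur axis, the measurement width minus the overlap width, starting at the origin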

Kernel offset to left should have true values on the right

In [None]:
# Clear all frames from memory by replacing every entry of the frame list
# with None (the same reset the patch-generation loop performs per iteration).
# The assignment was previously commented out, which left a bare expression
# that only echoed the whole frame list into the cell output.
dataset._frame_list = [None] * len(dataset._frame_list)

# Set frame mask
# dataset.frame_mask = [frame_index]

# Load frame
# frame = dataset.frame_list[0]

In [None]:
import scipy as sp
import scipy.fftpack  # plain `import scipy` does not load the fftpack submodule on every SciPy release

# ---- Settings ----
downsample_factor = 8
compress_output = True
blur_kernel_fov_fraction = 0.2
frame_overlap_fraction = 0.25
blur_axis = 1
blur_direction = 'left'
debug = True

# Define noise models. Dict items are kwargs for libwallerlab.utilities.noise.add()
# NOTE(review): nothing in this cell seeds a random generator, so the
# 'random_phase' kernel and the added noise presumably differ between runs --
# confirm whether a seed should be set for reproducibility.
noise_models = {'gaussian': {'snr': 255}}

# Decimated frame shape, each axis rounded up to the next FFT-friendly length
frame_shape = [sp.fftpack.next_fast_len(int(sz / downsample_factor)) for sz in dataset.frame_shape]

# Length of the blur vector along the blur axis
blur_vector_length = int((blur_kernel_fov_fraction * frame_shape[blur_axis]))

# Determine primary measurement shape (frame extent minus blur length, plus one)
measurement_shape = yp.dcopy(frame_shape)
measurement_shape[blur_axis] = frame_shape[blur_axis] - blur_vector_length + 1
measurement_start = [0, 0]

# Determine shape of overlap: a strip at the far end of the frame along the blur axis
overlap_shape = yp.dcopy(frame_shape)
overlap_shape[blur_axis] = int(frame_shape[blur_axis] * frame_overlap_fraction)
overlap_start = [0, 0]
overlap_start[blur_axis] = frame_shape[blur_axis] - overlap_shape[blur_axis]

# Get ground truth shape (measurement extent minus the overlap extent)
ground_truth_shape = yp.dcopy(frame_shape)
ground_truth_shape[blur_axis] = measurement_shape[blur_axis] - overlap_shape[blur_axis]
ground_truth_start = [0, 0]

if debug:
    print('Frame shape is %s' % str(frame_shape))
    print('Blur vector length is %d' % blur_vector_length)
    print('Measurement shape is %s' % str(measurement_shape))
    print('Overlap shape is %s' % str(overlap_shape))
    print('Ground Truth shape is %s' % str(ground_truth_shape))

# ROI of the first (left) frame: the primary measurement
frame_1_shape = measurement_shape
frame_1_start = measurement_start

# ROI of the second frame: the overlap strip
frame_2_shape = overlap_shape
frame_2_start = overlap_start

# Generate blurring function at the decimated frame size
blur_kernel = md.blurkernel.generate(frame_shape, blur_vector_length, axis=blur_axis,
                                     position='center_' + blur_direction, method='random_phase')

# Also crop blur kernel for storage in output file
blur_kernel_crop = yp.crop_to_support(blur_kernel)

# np.savez does not create missing folders, so make sure the target exists
os.makedirs(output_directory, exist_ok=True)

# The loop narrows dataset.frame_mask to one frame at a time. Remember the
# full mask so it can be restored afterwards; otherwise re-running this cell
# would only see the single frame left over from the previous run.
original_frame_mask = list(dataset.frame_mask)
frame_count = len(original_frame_mask)

try:
    # Loop over measurements and generate datapoints
    for frame_index in yp.display.progressBar(range(frame_count), name='Frames Saved'):
        # Clear all frames from memory
        dataset._frame_list = [None] * len(dataset._frame_list)

        # Set frame mask
        dataset.frame_mask = [frame_index]

        # Load frame
        frame = dataset.frame_list[0]

        # Decimate frame
        frame_decimated = yp.filter.decimate(frame, downsample_factor)

        # Convolve with blurring function; the ROI crops below select the two measurements
        frame_convolved = yp.convolve(frame_decimated, blur_kernel, mode='same', padded=False)

        # Crop first frame's roi
        frame_1 = yp.crop(frame_convolved, frame_1_shape, frame_1_start)

        # Crop to second frame's ROI
        frame_2 = yp.crop(frame_convolved, frame_2_shape, frame_2_start)

        # Add noise to measurements
        for noise_mode, noise_kwargs in noise_models.items():
            frame_1 = noise.add(frame_1, noise_mode, **noise_kwargs)
            frame_2 = noise.add(frame_2, noise_mode, **noise_kwargs)

        # Generate ground truth image with correct support
        ground_truth = yp.crop(frame_decimated, ground_truth_shape, ground_truth_start)

        # Output file path without extension (np.savez appends '.npz')
        output_path = os.path.join(output_directory, dataset.metadata.file_header) + '_%d' % frame_index

        # Define data structure
        data = {'ground_truth': {'array': ground_truth, 'start': tuple(ground_truth_start)},
                'measurements': [{'array': frame_1, 'start': frame_1_start},
                                 {'array': frame_2, 'start': frame_2_start}],
                'metadata': {'blur_direction': blur_direction,
                             'blur_axis': blur_axis,
                             'blur_kernel_fov_fraction': blur_kernel_fov_fraction,
                             'frame_overlap_fraction': frame_overlap_fraction,
                             'measurement_shape': measurement_shape,
                             'ground_truth_shape': ground_truth_shape},
                'blur_vector': {'array': blur_kernel_crop, 'start': yp.boundingBox(blur_kernel)[0]}}

        # Save data. The dict is passed positionally, so it is stored as a
        # pickled object array under the key 'arr_0'; loading it back needs
        # np.load(..., allow_pickle=True).
        save = np.savez_compressed if compress_output else np.savez
        save(output_path, data)
finally:
    # Hand the dataset back with its original frame selection
    dataset.frame_mask = original_frame_mask



## Load and Display a Data Point

In [None]:
import re

# Set data point index here
frame_index = 6

# Find the saved data points INSIDE the output directory. Plain string
# concatenation (output_directory + '*.npz') omitted the path separator and
# therefore matched 'patches*.npz' next to the folder instead of its contents.
files = glob.glob(os.path.join(output_directory, '*.npz'))

# Natural sort so that '..._2.npz' comes before '..._10.npz'; a plain
# lexicographic sort would make files[frame_index] point at the wrong frame
# once there are more than ten. re.split with a capturing group alternates
# text / digit tokens, so odd positions are always digit runs.
files.sort(key=lambda path: [int(token) if position % 2 else token
                             for position, token in enumerate(re.split(r'(\d+)', path))])

if not files:
    raise FileNotFoundError('No .npz data points found in %s' % output_directory)
if frame_index >= len(files):
    raise IndexError('frame_index=%d but only %d data point(s) exist' % (frame_index, len(files)))

# Load data point. The dict was saved positionally, so it is stored as a pickled
# 0-d object array under 'arr_0'; allow_pickle=True is required to read it.
# NOTE: unpickling executes arbitrary code -- only load files this notebook wrote.
with np.load(files[frame_index], allow_pickle=True) as archive:
    data = archive['arr_0'].item()

# Show both measurements, the ground truth and the blur sequence side by side
fig, axes = plt.subplots(1, 4)
axes[0].imshow(yp.real(data['measurements'][0]['array']))
axes[0].set_title('frame 1')
axes[1].imshow(yp.real(data['measurements'][1]['array']))
axes[1].set_title('frame 2')
axes[2].imshow(yp.real(data['ground_truth']['array']))
axes[2].set_title('ground truth')
axes[3].plot(yp.real(np.squeeze(data['blur_vector']['array'])))
axes[3].set_title('Blur Sequence');


# Deconvolve a Data Point Using L2 Deconvolution

In [None]:
import math
import re

import ndoperators as ops

# Set data point index here
frame_index = 4

# Find the saved data points INSIDE the output directory. Plain string
# concatenation (output_directory + '*.npz') omitted the path separator and
# therefore matched 'patches*.npz' next to the folder instead of its contents.
files = glob.glob(os.path.join(output_directory, '*.npz'))

# Natural sort so that '..._2.npz' comes before '..._10.npz'. re.split with a
# capturing group alternates text / digit tokens, so odd positions are digit runs.
files.sort(key=lambda path: [int(token) if position % 2 else token
                             for position, token in enumerate(re.split(r'(\d+)', path))])

if not files:
    raise FileNotFoundError('No .npz data points found in %s' % output_directory)
if frame_index >= len(files):
    raise IndexError('frame_index=%d but only %d data point(s) exist' % (frame_index, len(files)))

# Load data point. The dict is stored as a pickled 0-d object array under
# 'arr_0'; allow_pickle=True is required to read it.
# NOTE: unpickling executes arbitrary code -- only load files this notebook wrote.
with np.load(files[frame_index], allow_pickle=True) as archive:
    data = archive['arr_0'].item()

# Unpack the stored fields
blur_vector = data['blur_vector']['array']
measurement_shape = data['metadata']['measurement_shape']
ground_truth_shape = data['metadata']['ground_truth_shape']
blur_axis = data['metadata']['blur_axis']
blur_direction = data['metadata']['blur_direction']

frame_1 = data['measurements'][0]['array']
frame_2 = data['measurements'][1]['array']
ground_truth = data['ground_truth']['array']

# Get ROIs
frame_1_roi = yp.Roi(start=data['measurements'][0]['start'], shape=yp.shape(frame_1))
frame_2_roi = yp.Roi(start=data['measurements'][1]['start'], shape=yp.shape(frame_2))

# Average measurements: sum both frames at their ROI positions and divide by
# the number of frames covering each pixel
coverage_weights = yp.zeros(measurement_shape)
coverage_weights[frame_1_roi.slice] += 1.0
coverage_weights[frame_2_roi.slice] += 1.0
measurements = (yp.pad(frame_1, measurement_shape, frame_1_roi.start) + yp.pad(frame_2, measurement_shape, frame_2_roi.start)) / coverage_weights

# Create blur kernel with the correct position in the frame
# blur_kernel_crop = yp.roll(md.blurkernel.fromVector(blur_vector, measurement_shape, axis=blur_axis, position='center_' + blur_direction), (-1, 2))
# NOTE(review): the trailing "+ 3" in the roll is an unexplained offset -- confirm
# it against the kernel position used when the data point was generated.
blur_kernel_crop = yp.roll(yp.pad(blur_vector, measurement_shape, center=True), (0, -math.ceil(yp.size(blur_vector) / 2) + 3))

# Generate operators: crop to the ground-truth support, circular convolution with the kernel
CR = ops.Crop(measurement_shape, ground_truth_shape, crop_start=(0,0))
C = ops.Convolution(blur_kernel_crop, dtype='complex32', mode='circular')
y = measurements

# Define deconvolution method ('gd' or 'direct')
method = 'direct'

if method == 'gd':
    objective = ops.solvers.objectivefunctions.L2(C, y)
    gd = ops.solvers.GradientDescent(objective)
    x_opt = CR * gd.solve(iteration_count=1000, step_size=1e-3)

elif method == 'direct':
    C.inverse_regularizer = 1e-1
    x_opt = CR * C.inv * y

else:
    # Without this branch a typo in `method` would leave x_opt undefined, or
    # silently reuse a stale value from an earlier run of this cell.
    raise ValueError("Unknown deconvolution method '%s' (expected 'gd' or 'direct')" % method)

# Compare raw measurement, reconstruction, ground truth and absolute error
fig, axes = plt.subplots(1, 4)
axes[0].imshow(yp.abs(frame_1))
axes[0].set_title('Raw')
axes[1].imshow(yp.abs(x_opt))
axes[1].set_title('Recovered')
axes[2].imshow(yp.abs(ground_truth))
axes[2].set_title('True')
axes[3].imshow(yp.abs(ground_truth - x_opt))
axes[3].set_title('Error');