# Simple Example

Before running this example, install the `phre` package. For installation details, see `README.rst` in the repository root.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from phre import ImageData, measure_data
from phre.optimizer import find_initial_image, RelaxSplit, ADMM

## Load data

In [None]:
# Load the example digit image and wrap it in an ImageData container.
image_path = '../data/mnist_digit.npy'  # resolved against the current working directory
image = np.load(image_path)
data = ImageData(image)

In [None]:
# Display the loaded image using ImageData's own show_image method.
data.show_image()

## Hyperparameters

In [None]:
# Objective variant handed to both solvers.
obj_type = 'hard'
# Solver parameters: nu is passed to RelaxSplit, rho to ADMM.
nu = 1.0
rho = 1.0
# Observation-noise standard deviation; 0.0 selects the 'no_noise' simulation.
obs_std = 0.0

if obs_std == 0.0:
    sim_type = 'no_noise'
else:
    sim_type = 'with_noise'

# Output file for the convergence plot, e.g. 'hard_no_noise.pdf'.
figure_name = f'{obj_type}_{sim_type}.pdf'

## Create observations

In [None]:
# Request 4x data.image_size measurements (data.image_size is presumably the
# pixel count — confirm in ImageData), with noise level obs_std and a
# normalized measurement matrix.
num_obs = data.image_size * 4
obs_mat, obs = measure_data(
    data,
    num_obs=num_obs,
    obs_std=obs_std,
    normalize_obs_mat=True,
)

## Initialize variables

In [None]:
# Initial image estimate computed from the measurement matrix and observations.
init_x = find_initial_image(obs_mat, obs)
# Initial w = obs_mat · init_x; passed as init_w to both solvers below.
init_w = obs_mat.dot(init_x)

In [None]:
# Show the initial estimate reshaped to the image grid. A fresh figure keeps
# this plot independent of pyplot's current figure (relevant when the file is
# run as a plain script instead of an inline notebook).
plt.figure()
plt.title('Initial image')
plt.imshow(init_x.reshape(data.image_shape))

## Create solvers and optimize the problem

In [None]:
# RelaxSplit solver (labelled 'RS' in the convergence plot), configured with obj_type and nu.
solver_rs = RelaxSplit(obs_mat, obs, obj_type=obj_type, nu=nu)

In [None]:
# Run RelaxSplit from the shared initial point (init_x, init_w) with
# max_iter=500 and verbose output enabled.
rs_kwargs = dict(init_x=init_x, init_w=init_w, max_iter=500, verbose=True)
result_rs = solver_rs.phase_retrieval(**rs_kwargs)

In [None]:
# ADMM solver (labelled 'ADMM' in the convergence plot), configured with obj_type and rho.
solver_admm = ADMM(obs_mat, obs, obj_type=obj_type, rho=rho)

In [None]:
# Run ADMM from the same initial point (init_x, init_w) as RelaxSplit, with
# max_iter=500 and verbose output enabled.
admm_kwargs = dict(init_x=init_x, init_w=init_w, max_iter=500, verbose=True)
result_admm = solver_admm.phase_retrieval(**admm_kwargs)

## Final results

In [None]:
# Show the RelaxSplit reconstruction on its own labelled figure so it does
# not draw over a previous plot when run outside an inline notebook.
plt.figure()
plt.title('RelaxSplit result')
plt.imshow(result_rs.reshape(data.image_shape))

In [None]:
# Show the ADMM reconstruction on its own labelled figure so it does not
# draw over a previous plot when run outside an inline notebook.
plt.figure()
plt.title('ADMM result')
plt.imshow(result_admm.reshape(data.image_shape))

## Convergence History

In [None]:
# Plot the per-iteration objective history of both solvers on a log scale and
# save it to figure_name.
#
# Start a dedicated figure: savefig writes pyplot's *current* figure, which
# (outside an inline notebook backend) would still hold the ADMM imshow
# from the previous cell and end up in the saved PDF.
plt.figure()
plt.semilogy(solver_rs.obj_his, '.-', label='RS')

admm_his = solver_admm.obj_his
if sim_type == 'with_noise':
    # With noisy observations, show ADMM only up to 50 iterations past the
    # length of the RS history (presumably to keep the x-range comparable).
    admm_his = admm_his[:len(solver_rs.obj_his) + 50]
plt.semilogy(admm_his, '.-', label='ADMM')

plt.legend()
plt.xlabel('number of iterations')
plt.ylabel('objective value')
plt.savefig(figure_name, bbox_inches='tight')