In [None]:


import os
import matlab.engine # the matlab engine for python
import cv2
import numpy as np
from pathlib import Path
from diffcurve.plot_utils import plot_images, remove_frame
import matplotlib.pyplot as plt
from diffcurve.fdct2d.curvelet_2d import get_curvelet_system
from diffcurve.utils import get_project_root
from diffcurve.fdct2d.torch_frontend import torch_fdct_2d, torch_ifdct_2d
import torch
from diffcurve.fdct2d.numpy_frontend import perform_ifft2



: 

## Load the sample image

Read Lena as a grayscale image and scale its intensities to [0, 1].

In [None]:

project_root = get_project_root()

# Build the path with `/`; wrapping in Path() also tolerates a str root.
lena_file = Path(project_root) / "data" / "Lena.jpg"

# cv2.imread does not raise on a missing/unreadable file -- it returns None,
# which would otherwise surface later as a confusing AttributeError on .astype.
lena_raw = cv2.imread(str(lena_file), cv2.IMREAD_GRAYSCALE)
if lena_raw is None:
    raise FileNotFoundError(f"Could not read sample image: {lena_file}")

# Scale intensities from [0, 255] to floats in [0, 1].
lena_img = lena_raw.astype(float) / 255

fig, ax = plt.subplots(figsize=(3, 3))
ax.imshow(lena_img, cmap='gray')
ax.set_title('input image');

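## Build the curvelet system

Each entry of `curvelet_system` is one curvelet represented in the frequency domain; its inverse FFT gives the corresponding spatial-domain curvelet.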

In [None]:

# Transform options. All values are given as floats, matching the original
# configuration (presumably because they are forwarded to MATLAB -- confirm
# in get_curvelet_system).
dct_kwargs = dict(
    is_real=0.0,           # complex-valued curvelets
    finest=2.0,            # use wavelets at the finest level
    nbscales=6.0,
    nbangles_coarse=16.0,
)

# Size the curvelet system to the sample image.
img_height, img_width = lena_img.shape[:2]
curvelet_system, curvelet_coeff_dim = get_curvelet_system(
    img_height, img_width, dct_kwargs)

# One entry per row of curvelet_coeff_dim: the product of that row's dims
# (presumably the size of each curvelet's coefficient array).
curvelet_support_size = np.prod(np.array(curvelet_coeff_dim), axis=1)

## Show a curvelet

In [None]:
curvelet_idx = 1
# Entries of curvelet_system live in the frequency domain; the inverse FFT
# gives the spatial-domain curvelet. Only real parts are displayed.
curvelet_freq = curvelet_system[curvelet_idx]

fig, axes = plt.subplots(1, 2)
freq_ax, spatial_ax = axes

freq_ax.imshow(curvelet_freq.real, cmap='gray')
freq_ax.set_title('freq')

spatial_ax.imshow(perform_ifft2(curvelet_freq).real, cmap='gray')
spatial_ax.set_title('spatial')

## Run the curvelet transform and its inverse

In [None]:
# Wrap the curvelet system as a tensor once and reuse it for both directions;
# torch.from_numpy shares memory with the NumPy array (no copy).
curvelet_system_t = torch.from_numpy(curvelet_system)

# Forward transform: image -> curvelet coefficients.
torch_coeff = torch_fdct_2d(torch.from_numpy(lena_img), curvelet_system_t)

# Inverse transform: coefficients -> per-curvelet image components.
torch_decomp = torch_ifdct_2d(torch_coeff, curvelet_system_t,
                              torch.from_numpy(curvelet_support_size))

# Back to NumPy for plotting; .detach() drops any autograd history first.
coeff = torch_coeff.detach().cpu().numpy()
decomp = torch_decomp.detach().cpu().numpy()

# Later cells sum decomp over axis 0 and compare it pixel-wise with lena_img,
# so every slice must be image-sized.
assert decomp.shape[1:] == lena_img.shape, (decomp.shape, lena_img.shape)


## Show some curvelet coefficients

In [None]:
# Thumbnail grid for the plots; the number of curvelets shown is derived from
# it so the grid and the count cannot drift apart.
grid_rows, grid_cols = 5, 5
num_curvelets_to_show = grid_rows * grid_cols

# Real part of each of the first few curvelets' coefficients, transposed for display.
curvelets_coeff_to_show = [
    coeff[idx].real.T for idx in range(num_curvelets_to_show)]

im, axes = plot_images(curvelets_coeff_to_show, nrows=grid_rows, ncols=grid_cols,
                       vrange='individual', cbar='none', cmap='gray',
                       fig_size=(6, 6))

# remove_frame is called for its side effect only -- use a loop, not a list.
for ax in axes.flatten():
    remove_frame(ax)

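## Show some weighted curvelets

Each slice of `decomp` is a single curvelet weighted by its coefficients, i.e. that curvelet's contribution to the reconstruction; summing all slices recovers the image.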

In [None]:


# Same thumbnail layout as the coefficient plot; make the coupling between
# the grid and num_curvelets_to_show explicit instead of silent.
grid_rows, grid_cols = 5, 5
assert num_curvelets_to_show == grid_rows * grid_cols, (
    f"num_curvelets_to_show must equal {grid_rows} * {grid_cols}")

# Real part of each of the first few per-curvelet components, transposed for display.
weighted_curvelets_to_show = [
    decomp[idx].real.T for idx in range(num_curvelets_to_show)]

im, axes = plot_images(weighted_curvelets_to_show, nrows=grid_rows, ncols=grid_cols,
                       vrange='individual', cbar='none', cmap='gray',
                       fig_size=(6, 6))

# remove_frame is called for its side effect only -- use a loop, not a list.
for ax in axes.flatten():
    remove_frame(ax)

## Show the reconstructed image from the curvelet transform

In [None]:
# Summing every curvelet's contribution inverts the transform. The curvelets
# are complex-valued, so keep only the real part to compare with the real
# input. Computed once and reused for the plot and the error metric.
recon = decomp.sum(0).real
recon_err = recon - lena_img
mse = np.mean(recon_err ** 2)

fig, axes = plt.subplots(1, 3)

axes[0].imshow(lena_img, cmap='gray')
axes[0].set_title('input')

axes[1].imshow(recon, cmap='gray')
axes[1].set_title('recon')

# imshow auto-scales, so this panel shows where errors occur, not their size;
# see the printed MSE for magnitude.
axes[2].imshow(np.abs(recon_err), cmap='gray')
axes[2].set_title('|recon - input|')

print(f'MSE = {mse}')

# remove_frame is called for its side effect only -- use a loop, not a list.
for ax in axes.flatten():
    remove_frame(ax)