In [1]:
import numpy as np
import pywt
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.linear_model import LinearRegression

In [2]:
def load_image(image_path):
    """Load an image and convert it to grayscale."""
    image = Image.open(image_path).convert('L')
    return np.array(image)

def compute_multilevel_wavelet_transform(image, wavelet='haar', level=2):
    """Compute multi-level wavelet decomposition of an image."""
    coeffs = pywt.wavedec2(image, wavelet, level=level)
    return coeffs


In [3]:
def compute_difference(coeffs1, coeffs2):
    """Compute the difference in wavelet coefficients between two images."""
    diff_coeffs = []
    for level1, level2 in zip(coeffs1[1:], coeffs2[1:]):  # Skip level 0 (approximation)
        # Each level is a tuple (cA, (cH, cV, cD))
        cH1, cV1, cD1 = level1[1]
        cH2, cV2, cD2 = level2[1]
        # Compute differences in detail coefficients
        diff_coeffs.append((cH1, cV1, cD1, cH2, cV2, cD2))
    return diff_coeffs


In [4]:
def fit_wavelet_models(coeffs_diff):
    """Fit linear regression models to minimize error between wavelet coefficients of two images."""
    model = LinearRegression()
    fitted_models = []
    for cH1, cV1, cD1, cH2, cV2, cD2 in coeffs_diff:
        # Flatten arrays for regression
        fitted_models.append(
            (model.fit(cH1.reshape(-1, 1), cH2.reshape(-1, 1)).predict(cH1.reshape(-1, 1)),
             model.fit(cV1.reshape(-1, 1), cV2.reshape(-1, 1)).predict(cV1.reshape(-1, 1)),
             model.fit(cD1.reshape(-1, 1), cD2.reshape(-1, 1)).predict(cD1.reshape(-1, 1)))
        )
    return fitted_models


In [5]:
def plot_wavelet_difference(image1_path, image2_path, wavelet='haar', level=2):
    """Plot the differences in multi-level wavelet transforms of two images using matplotlib."""
    image1 = load_image(image1_path)
    image2 = load_image(image2_path)
    coeffs1 = compute_multilevel_wavelet_transform(image1, wavelet, level)
    coeffs2 = compute_multilevel_wavelet_transform(image2, wavelet, level)
    coeffs_diff = compute_difference(coeffs1, coeffs2)
    fitted_models = fit_wavelet_models(coeffs_diff)

    nrows, ncols = level, 3  # Define grid for subplots
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols*4, nrows*4))
    plt.subplots_adjust(hspace=0.5, wspace=0.3)

    for i, (dh, dv, dd) in enumerate(fitted_models):
        for j, data in enumerate([dh, dv, dd]):
            ax = axes[i, j] if nrows > 1 else axes[j]
            ax.imshow(data.reshape(coeffs_diff[i][j*2].shape), cmap='gray', aspect='auto')
            ax.set_title(f'Level {i+1} - {"Horizontal" if j==0 else "Vertical" if j==1 else "Diagonal"} Fit')
            ax.axis('off')

    plt.show()


In [6]:

# Example usage
image1_path = '/home/admyyh/python_workspace/BasicSR/experiments/train_HMA_SRx2_from_DF2K_250k_smaller/visualization/baby/baby_18500.png'
image2_path = '/home/admyyh/python_workspace/BasicSR/datasets/Set5/GTmod12/baby.png'
plot_wavelet_difference(image1_path, image2_path, wavelet='haar', level=4)


ValueError: too many values to unpack (expected 3)