# Exploratory Data Analysis
A first exploration of the datasets created by compressing pristine images taken from various state-of-the-art (SOTA) papers.

## Libraries

In [None]:
import numpy as np
import glob
import os
import sys
import matplotlib.pyplot as plt
import cv2
from PIL import Image
import pandas as pd
import sys
sys.path.append('../')

## Let's find all the images inside the various folders

In [None]:
# Root folder of the test data and location of the cached index of its files
DATA_ROOT = '/nas/public/exchange/JPEG-AI/data/TEST'
DATA_INFO_CSV = os.path.join(DATA_ROOT, 'data_info.csv')

if os.path.exists(DATA_INFO_CSV):
    # Load the cached CSV directly
    all_images = pd.read_csv(DATA_INFO_CSV)
else:
    def _rel_parts(path):
        """Split ``path`` into its components relative to DATA_ROOT."""
        return os.path.relpath(path, DATA_ROOT).split(os.sep)

    def _target_bpp(path):
        """Return the target BPP encoded in the path (stored as BPP*100), or None."""
        if 'target_bpp' not in path:
            return None
        return float(path.split('target_bpp_')[1].split('/')[0]) / 100

    def _image_size(path):
        """Return the PIL ``(width, height)`` size of the image as a string.

        The size is stored as text so that a freshly built table has the same
        column type as one re-loaded from the CSV cache, and the file handle
        is released as soon as the size has been read.
        """
        with Image.open(path) as img:
            return str(img.size)

    content_dict = {'imagenet': 'various', 'celeba': 'faces', 'ffhq': 'faces', 'coco': 'various', 'raise': 'various', 'laion': 'various'}

    def _content(path, dataset, compressed):
        """Return the semantic category of a sample."""
        if dataset == 'lsun':
            # LSUN keeps the category in the folder tree, one level deeper
            # for the compressed samples than for the pristine ones
            return _rel_parts(path)[3 if compressed else 2]
        if 'lsun' in dataset:
            return None
        return content_dict[dataset]

    # Look for all the images inside the directory (avoid notebook files)
    found_paths = glob.glob(os.path.join(DATA_ROOT, '**', '*.*'), recursive=True)
    all_images = pd.DataFrame([path for path in found_paths if 'ipynb' not in path], columns=['path'])
    # Add other useful info
    all_images['dataset'] = all_images['path'].apply(lambda x: _rel_parts(x)[0])
    all_images['compressed'] = all_images['path'].apply(lambda x: 'compressed' in x)
    all_images['target_bpp'] = all_images['path'].apply(_target_bpp)
    all_images['filename'] = all_images['path'].apply(os.path.basename)
    all_images['content'] = [
        _content(path, dataset, compressed)
        for path, dataset, compressed in zip(all_images['path'], all_images['dataset'], all_images['compressed'])
    ]
    # Let's add info on the single image sizes
    all_images['size'] = all_images['path'].apply(_image_size)
    # Let's save the data into a csv
    all_images.to_csv(DATA_INFO_CSV, index=False)
all_images

## Let's plot some info about the dataset

In [None]:
def _plot_counts(counts, title):
    """Draw ``counts`` (a Series or DataFrame of sample counts) as a bar chart titled ``title``."""
    counts.plot.bar(figsize=(12, 9))
    plt.title(title)
    plt.show()


# Number of samples divided by dataset
_plot_counts(all_images.groupby('dataset').count()['path'], 'Number of samples per dataset')

# Number of compressed samples (counts the rows with a non-null 'content')
_plot_counts(all_images.groupby('compressed').count()['content'], 'Number of compressed samples')

# Number of samples divided by semantic category
_plot_counts(all_images.groupby('content').count()['path'], 'Number of samples per category')

# Number of samples divided by target BPP (compressed samples only)
_plot_counts(all_images.loc[all_images['compressed']].groupby('target_bpp').count()['path'], 'Number of samples per BPP')

# Number of samples divided by size
_plot_counts(all_images.groupby('size').count()['dataset'], 'Number of samples per size')

# Number of samples divided by size and dataset
_plot_counts(all_images.groupby(['size', 'dataset']).count()['path'].unstack(), 'Number of samples per size and dataset')

## Great! Now let's look at some quality metrics
# THIS TAKES TOO LONG!
I computed it offline and placed it inside a different folder.

In [None]:
# sys.path.append('/nas/home/ecannas/third_party_code/jpeg-ai-reference-software')
# from PIL import Image
# import torch
# from src.codec.metrics.metrics import DataClass, MetricsProcessor, MetricsFabric
# from tqdm.notebook import tqdm

In [None]:
# # Prepare the metrics processor
# metrics = MetricsProcessor()
# metrics.internal_bits = 10
# metrics.jvet_psnr = False
# metrics.metrics = MetricsFabric.metrics_list
# metrics.metrics_output = [metric for metric in MetricsFabric.metrics_list]
# metrics.color_conv = '709'
# metrics.max_samples_for_eval_on_gpu = -1
# gpu = 3
# os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu)

In [None]:
# dataset_dict = []
# for dataset in all_images['dataset'].unique():
#     print(f'Doing dataset {dataset}...')
#     # select the images from the dataset
#     images_df = all_images.loc[all_images['dataset']==dataset]
#     orig_df = images_df.loc[~images_df['compressed']].iloc[:5]
#     compr_df = images_df.loc[images_df['compressed']]
#     
#     # Create the metrics Dataframe
#     images_dict = []
#     
#     # Cycle over the different pristine samples
#     for i, r in tqdm(orig_df.iterrows()):
#         # Load the original sample
#         filename, dataset, content = r['filename'], r['dataset'], r['content']
#         orig_sample, _ = DataClass().load_image(r['path'], color_conv='709', device='cuda')
#         # Find the corresponding compressed samples
#         comp_samples = compr_df.loc[(compr_df['filename']==filename.replace('jpg', 'png')) \
#                                     & (compr_df['dataset']==dataset) \
#                                     & (compr_df['content']==content)]
#         # Cycle over the various BPPs values
#         bpps_dict = []
#         for ii, rr in comp_samples.iterrows():
#             # Load the compressed samples
#             comp_sample, _ = DataClass().load_image(rr['path'], color_conv='709', device='cuda')
#             # Compute the metrics
#             metrics_vals = metrics.process_images(orig_sample, comp_sample)
#             # Save them
#             bpps_dict.append(pd.DataFrame.from_dict({rr['target_bpp']: {metric: metrics_vals[idx] for idx, metric in enumerate(metrics.metrics_output)}},
#                                                     orient='index'))
#         
#         # Append the image path to the Dataframe and save it
#         images_dict.append(pd.concat({r['path']: pd.concat(bpps_dict)}, names=['path', 'bpp']))
#     
#     # Append the images dictionary
#     dataset_dict.append(pd.concat({dataset: pd.concat(images_dict)}, names=['dataset', 'path', 'bpp']))
# 
#             
#         

In [None]:
# The metrics were computed offline, so instead of concatenating `dataset_dict`
# (metrics_info = pd.concat(dataset_dict)) the saved report is loaded from disk.
metrics_info = pd.read_csv(
    filepath_or_buffer='/nas/public/exchange/JPEG-AI/code/quality_report/quality_report.csv',
)
print(metrics_info)

## Let's plot some graphics

In [None]:
# Index the metrics by (path, bpp, dataset), in this order, with a single call
metrics_info.set_index(['path', 'bpp', 'dataset'], inplace=True)
metrics_info

In [None]:
# Plot title for every metric column of the quality report
labels_dict = {'msssim_torch': 'Multi-scale SSIM (MS-SSIM)',
  'msssim_iqa': 'Multi-scale SSIM IQA',
  'psnr': 'Peak Signal-to-Noise Ratio (PSNR)',
  'vif': 'Visual Information Fidelity (VIF)',
  'fsim': 'Feature Similarity (FSIM)',
  'nlpd': 'Normalized Laplacian Pyramid (NLPD)',
  'iw-ssim': 'IW-SSIM',
  'vmaf': ' Video Multimethod Assessment Fusion (VMAF)',
  'psnr_hvs': 'Peak Signal-to-Noise Ratio (PSNR) weighted by the Human Visual System (HVS)'}

# Y-axis label for every metric column
# NOTE(review): the '[0-1]' range shown for VMAF should be confirmed against the computed values
ylabel_dict = {'msssim_torch': 'MS-SSIM [0-1]',
  'msssim_iqa': 'MS-SSIM [0-1]',
  'psnr': 'PSNR [dB]',
  'vif': 'VIF',
  'fsim': 'FSIM',
  'nlpd': 'NLPD [0-1]',
  'iw-ssim': 'IW-SSIM [0-1]',
  'vmaf': ' VMAF [0-1]',
  'psnr_hvs': 'PSNR-HVS [dB]'}

# Every BPP value present in the report: the x ticks are placed exactly on
# these values so that each tick label matches the position of the data
bpp_ticks = sorted(metrics_info.index.get_level_values('bpp').unique())

# Let's plot the single metrics, one figure per metric and one curve per dataset
for metric in metrics_info.columns:
    fig, ax = plt.subplots(1, 1, figsize=(12, 3))
    for dataset in metrics_info.index.get_level_values('dataset').unique():
        dataset_info = metrics_info.xs(dataset, level='dataset')
        # Average over the images of the dataset at each BPP
        avg_metrics = dataset_info.groupby('bpp')[metric].mean().sort_index()
        ax.plot(avg_metrics.index, avg_metrics.values, label=dataset)
    plt.legend()
    plt.grid()
    # Fall back to the raw column name for metrics missing from the dictionaries
    plt.title(f'{labels_dict.get(metric, metric)} results at various BPPs')
    plt.ylabel(ylabel_dict.get(metric, metric))
    plt.xticks(ticks=bpp_ticks, labels=[str(bpp) for bpp in bpp_ticks])
    plt.xlabel('Bit-Per-Pixels')
    plt.show()

Let's plot the metrics again, this time using the absolute scale of each metric on the y-axis.

In [None]:
# Plot title for every metric column of the quality report
labels_dict = {'msssim_torch': 'Multi-scale SSIM (MS-SSIM)',
  'msssim_iqa': 'Multi-scale SSIM IQA',
  'psnr': 'Peak Signal-to-Noise Ratio (PSNR)',
  'vif': 'Visual Information Fidelity (VIF)',
  'fsim': 'Feature Similarity (FSIM)',
  'nlpd': 'Normalized Laplacian Pyramid (NLPD)',
  'iw-ssim': 'IW-SSIM',
  'vmaf': ' Video Multimethod Assessment Fusion (VMAF)',
  'psnr_hvs': 'Peak Signal-to-Noise Ratio (PSNR) weighted by the Human Visual System (HVS)'}

# Y-axis label for every metric column
ylabel_dict = {'msssim_torch': 'MS-SSIM [0-1]',
  'msssim_iqa': 'MS-SSIM [0-1]',
  'psnr': 'PSNR [dB]',
  'vif': 'VIF',
  'fsim': 'FSIM',
  'nlpd': 'NLPD [0-1]',
  'iw-ssim': 'IW-SSIM [0-1]',
  'vmaf': ' VMAF [0-1]',
  'psnr_hvs': 'PSNR-HVS [dB]'}

# Y-axis limits used as the absolute scale of every metric
# NOTE(review): 'psnr_hvs' is labelled in dB but limited to [0, 1], and
# 'vif'/'fsim' share the [0, 56] limits; confirm these ranges against the
# values actually stored in the report
scale_dict = {'msssim_torch': [0, 1],
  'msssim_iqa': [0, 1],
  'psnr': [0, 51],
  'vif': [0, 56],
  'fsim': [0, 56],
  'nlpd': [0, 1],
  'iw-ssim': [0, 1],
  'vmaf': [0, 1],
  'psnr_hvs': [0, 1]}

# Every BPP value present in the report: the x ticks are placed exactly on
# these values so that each tick label matches the position of the data
bpp_ticks = sorted(metrics_info.index.get_level_values('bpp').unique())

# Let's plot the single metrics, one figure per metric and one curve per dataset
for metric in metrics_info.columns:
    fig, ax = plt.subplots(1, 1, figsize=(12, 3))
    for dataset in metrics_info.index.get_level_values('dataset').unique():
        dataset_info = metrics_info.xs(dataset, level='dataset')
        # Average over the images of the dataset at each BPP
        avg_metrics = dataset_info.groupby('bpp')[metric].mean().sort_index()
        ax.plot(avg_metrics.index, avg_metrics.values, label=dataset)
    plt.legend()
    plt.grid()
    # Fall back to the raw column name for metrics missing from the dictionaries
    plt.title(f'{labels_dict.get(metric, metric)} results at various BPPs')
    plt.ylabel(ylabel_dict.get(metric, metric))
    # Metrics without a known absolute scale keep the automatic limits
    if metric in scale_dict:
        plt.ylim(scale_dict[metric])
    plt.xticks(ticks=bpp_ticks, labels=[str(bpp) for bpp in bpp_ticks])
    plt.xlabel('Bit-Per-Pixels')
    plt.show()

## How bad do the samples look for the different datasets?

In [None]:
fsize = 20
for dataset in all_images['dataset'].unique():
    # Select the dataset
    dataset_df = all_images.loc[all_images['dataset'] == dataset]
    # Draw one uncompressed sample (fixed seed, so the choice is reproducible)
    sample = dataset_df.loc[~dataset_df['compressed']].sample(1, random_state=42)
    # Get the compressed versions of that sample (they are stored as png)
    comp_df = dataset_df.loc[dataset_df['compressed']]
    comp_filename = sample['filename'].item().replace('jpg', 'png')
    comp_samples = comp_df.loc[(comp_df['filename'] == comp_filename) & (comp_df['content'] == sample['content'].item())]
    # Plot everything on a 2x3 grid: the uncompressed sample goes top-left
    fig, axs = plt.subplots(2, 3, figsize=(9*3, 9*2))
    for ax in axs.flat:
        ax.axis('off')
    with Image.open(sample['path'].item()) as img:
        axs[0][0].imshow(img.convert('RGB'))
    axs[0][0].set_title('Uncompressed sample', fontsize=fsize-5)
    # The compressed samples fill the remaining cells in row-major order.
    # The position must depend on the enumeration counter, not on the
    # DataFrame index label of the row.
    for idx, (_, r) in enumerate(comp_samples.iterrows()):
        ax = axs.flat[idx + 1]
        with Image.open(r['path']) as img:
            ax.imshow(img.convert('RGB'))
        ax.set_title(f'{r["target_bpp"]} BPP', fontsize=fsize-5)
    fig.suptitle(f'{dataset} sample', fontsize=fsize)
    plt.show()
    
    

### Let's look at the center crop in 224x224

In [None]:
fsize = 20
for dataset in all_images['dataset'].unique():
    # Select the dataset
    dataset_df = all_images.loc[all_images['dataset'] == dataset]
    # Draw one uncompressed sample (fixed seed, so the choice is reproducible)
    sample = dataset_df.loc[~dataset_df['compressed']].sample(1, random_state=42)
    # Get the compressed versions of that sample (they are stored as png)
    comp_df = dataset_df.loc[dataset_df['compressed']]
    comp_filename = sample['filename'].item().replace('jpg', 'png')
    comp_samples = comp_df.loc[(comp_df['filename'] == comp_filename) & (comp_df['content'] == sample['content'].item())]
    # Load the uncompressed image and read its size from the image itself.
    # PIL reports the size as (width, height).
    with Image.open(sample['path'].item()) as img:
        unc_img = img.convert('RGB')
    width, height = unc_img.size
    # Get the crop coordinates for a centred 224x224 patch
    left = (width - 224)/2
    top = (height - 224)/2
    right = (width + 224)/2
    bottom = (height + 224)/2
    crop_box = (left, top, right, bottom)
    # Plot everything on a 2x3 grid: the uncompressed sample goes top-left
    fig, axs = plt.subplots(2, 3, figsize=(9*3, 9*2))
    for ax in axs.flat:
        ax.axis('off')
    axs[0][0].imshow(unc_img.crop(crop_box))
    axs[0][0].set_title('Uncompressed sample', fontsize=fsize-5)
    # The compressed samples fill the remaining cells in row-major order.
    # The position must depend on the enumeration counter, not on the
    # DataFrame index label of the row.
    for idx, (_, r) in enumerate(comp_samples.iterrows()):
        ax = axs.flat[idx + 1]
        with Image.open(r['path']) as img:
            ax.imshow(img.convert('RGB').crop(crop_box))
        ax.set_title(f'{r["target_bpp"]} BPP', fontsize=fsize-5)
    fig.suptitle(f'{dataset} sample', fontsize=fsize)
    plt.show()
    
    

### Are there some more visible artifacts in some datasets than others?

In [None]:
import cv2

def compute_spectrum_module(img: Image) -> np.ndarray:
    """Return the magnitude of the centred 2-D spectrum of a high-pass residual of ``img``.

    The image is converted to 8-bit grayscale, a 5x5 median-blurred copy is
    subtracted from it, and the FFT of the difference is shifted so that the
    zero frequency sits at the centre of the returned array.
    """
    gray = np.array(img.convert('L'))

    # High-pass residual: image minus its median-blurred version.
    # cv2.subtract saturates on 8-bit input, so negative differences become 0.
    residual = cv2.subtract(gray, cv2.medianBlur(gray, 5))

    # Centred spectrum, of which only the module is kept
    centred_fft = np.fft.fftshift(np.fft.fft2(residual))
    return np.abs(centred_fft)

In [None]:
fsize = 20
for dataset in all_images['dataset'].unique():
    # Select the dataset
    dataset_df = all_images.loc[all_images['dataset'] == dataset]
    # Draw one uncompressed sample (fixed seed, so the choice is reproducible)
    sample = dataset_df.loc[~dataset_df['compressed']].sample(1, random_state=42)
    # Get the compressed versions of that sample (they are stored as png)
    comp_df = dataset_df.loc[dataset_df['compressed']]
    comp_filename = sample['filename'].item().replace('jpg', 'png')
    comp_samples = comp_df.loc[(comp_df['filename'] == comp_filename) & (comp_df['content'] == sample['content'].item())]
    # Load the uncompressed image and read its size from the image itself.
    # PIL reports the size as (width, height).
    with Image.open(sample['path'].item()) as img:
        unc_img = img.convert('RGB')
    width, height = unc_img.size
    # Get the crop coordinates for a centred 224x224 patch
    left = (width - 224)/2
    top = (height - 224)/2
    right = (width + 224)/2
    bottom = (height + 224)/2
    crop_box = (left, top, right, bottom)
    # Plot everything on a 2x3 grid: the uncompressed spectrum goes top-left
    # and must be computed from the uncompressed sample itself
    fig, axs = plt.subplots(2, 3, figsize=(9*3, 9*2))
    for ax in axs.flat:
        ax.axis('off')
    spec = compute_spectrum_module(unc_img.crop(crop_box))
    # Log scale to compress the dynamic range of the spectrum
    axs[0][0].imshow(np.log1p(spec))
    axs[0][0].set_title('Uncompressed sample spectrum', fontsize=fsize-5)
    # The compressed samples fill the remaining cells in row-major order.
    # The position must depend on the enumeration counter, not on the
    # DataFrame index label of the row.
    for idx, (_, r) in enumerate(comp_samples.iterrows()):
        ax = axs.flat[idx + 1]
        with Image.open(r['path']) as img:
            spec = compute_spectrum_module(img.convert('RGB').crop(crop_box))
        ax.imshow(np.log1p(spec))
        ax.set_title(f'{r["target_bpp"]} BPP spectrum', fontsize=fsize-5)
    fig.suptitle(f'{dataset} sample', fontsize=fsize)
    plt.show()
    
    

## There does not seem to be a clear pattern here...
Let's try to compute the average spectrum from multiple images

In [None]:
import cv2
from typing import Tuple

# --- THIS IS THE OHJA IMPLEMENTATION ACCORDING TO THEIR PAPER, DOES NOT MAKE SENSE TO ME
# def compute_avg_spectrum_module(df: pd.DataFrame) -> np.ndarray:
    
#     avg_spec = []
#     for i, r in df.iterrows():
        
#         # Convert the image to grayscale
#         img = Image.open(r['path']).convert('L')
        
#         # Crop the image
#         height, width = img.size
#         left = (width - 224)/2
#         top = (height - 224)/2
#         right = (width + 224)/2
#         bottom = (height + 224)/2
#         img = np.array(img.crop((left, top, right, bottom)))

#         # Create the median blurred image
#         blur_img = cv2.medianBlur(img, 5)

#         # High pass filtering
#         hp_img = cv2.subtract(img, blur_img)
        
#         # Remove DC component
#         #hp_img = cv2.subtract(hp_img, hp_img.mean())

#         # Compute the FFT
#         avg_spec.append(hp_img[np.newaxis, :, :])
    
#     # Compute the average
#     avg_spec = np.squeeze(np.mean(avg_spec, axis=0))
    
#     # Compute the spectrum
#     avg_spec = np.fft.fftshift(np.fft.fft2(avg_spec))
    
#     # Return the module
#     return np.abs(avg_spec)

def compute_avg_spectrum_module(df: pd.DataFrame, crop_size: int = 224) -> np.ndarray:
    """Average the spectrum magnitude of the median-filter residual over the images of ``df``.

    Every image listed in the ``path`` column is converted to grayscale and
    centre-cropped to ``crop_size`` x ``crop_size``. A 5x5 median-blurred copy
    is subtracted from the crop, and the magnitude of the centred FFT of this
    residual is accumulated.

    Args:
        df: table with a ``path`` column pointing to the image files.
        crop_size: side of the central square patch, in pixels.

    Returns:
        The element-wise mean of the per-image spectrum magnitudes.

    Raises:
        ValueError: if ``df`` has no rows.
    """
    if len(df) == 0:
        raise ValueError('Cannot compute the average spectrum of an empty DataFrame')

    spectra = []
    for _, row in df.iterrows():

        # Convert the image to grayscale, releasing the file right after
        with Image.open(row['path']) as raw_img:
            img = raw_img.convert('L')

        # Centre crop. PIL reports the size as (width, height).
        width, height = img.size
        left = (width - crop_size)/2
        top = (height - crop_size)/2
        right = (width + crop_size)/2
        bottom = (height + crop_size)/2
        patch = np.array(img.crop((left, top, right, bottom)))

        # High pass filtering: the patch minus its median-blurred version.
        # cv2.subtract saturates on 8-bit input, so negative differences become 0.
        hp_img = cv2.subtract(patch, cv2.medianBlur(patch, 5))

        # Module of the centred FFT of the residual
        spectra.append(np.abs(np.fft.fftshift(np.fft.fft2(hp_img))))

    # Average the spectra over the images
    return np.mean(spectra, axis=0)

In [None]:
from tqdm.notebook import tqdm

fsize = 20
n_samples = 50  # number of pristine images averaged per dataset
avg_specs = dict()
for dataset in tqdm(all_images['dataset'].unique()):
    # Select the dataset
    dataset_df = all_images.loc[all_images['dataset'] == dataset]
    # Get the uncompressed samples. DataFrame.sample raises when asked for
    # more rows than available, so the request is capped to the dataset size.
    unc_df = dataset_df.loc[~dataset_df['compressed']]
    sample = unc_df.sample(min(n_samples, len(unc_df)), random_state=42).copy()
    # Get the compressed samples
    comp_df = dataset_df.loc[dataset_df['compressed']]
    # Merge the dataframes on 'filename' and 'content' (compressed files are png)
    sample['filename'] = sample['filename'].apply(lambda x: x.replace('jpg', 'png'))
    comp_samples = pd.merge(comp_df, sample, on=['filename', 'content'],
                            suffixes=('', '_uncompressed'))
    # Average spectrum for every target BPP, in order of appearance
    avg_specs[dataset] = dict()
    for t_bpp in comp_samples['target_bpp'].unique():
        avg_specs[dataset][t_bpp] = compute_avg_spectrum_module(comp_samples.loc[comp_samples['target_bpp'] == t_bpp])
    # Average spectrum of the uncompressed samples
    avg_specs[dataset]['unc'] = compute_avg_spectrum_module(sample)
    
    

In [None]:
fsize = 20
for dataset, specs in avg_specs.items():
    # Prepare a 2x3 grid: the uncompressed average spectrum goes top-left
    fig, axs = plt.subplots(2, 3, figsize=(9*3, 9*2))
    for ax in axs.flat:
        ax.axis('off')
    # Clip the colour scale at the 99.9th percentile so that a few very
    # strong peaks do not flatten the rest of the spectrum
    axs[0][0].imshow(specs['unc'], vmax=np.quantile(specs['unc'], 0.999))
    axs[0][0].set_title('Uncompressed avg spectrum', fontsize=fsize-5)
    # Use the BPPs stored for THIS dataset rather than a variable left over
    # from the cell that computed the spectra
    bpp_keys = [key for key in specs if not (isinstance(key, str) and key == 'unc')]
    # The compressed spectra fill the remaining cells in row-major order
    for idx, t_bpp in enumerate(bpp_keys):
        ax = axs.flat[idx + 1]
        ax.imshow(specs[t_bpp], vmax=np.quantile(specs[t_bpp], 0.999))
        ax.set_title(f'{t_bpp} BPP avg spectrum', fontsize=fsize-5)
    fig.suptitle(f'{dataset} average spectrums', fontsize=fsize)
    plt.show()

### Let's try again but with Laplace filtering instead

In [None]:
import cv2
from typing import Tuple

# --- THIS IS THE OHJA IMPLEMENTATION ACCORDING TO THEIR PAPER, DOES NOT MAKE SENSE TO ME
# def compute_avg_spectrum_module(df: pd.DataFrame) -> np.ndarray:
    
#     avg_spec = []
#     for i, r in df.iterrows():
        
#         # Convert the image to grayscale
#         img = Image.open(r['path']).convert('L')
        
#         # Crop the image
#         height, width = img.size
#         left = (width - 224)/2
#         top = (height - 224)/2
#         right = (width + 224)/2
#         bottom = (height + 224)/2
#         img = np.array(img.crop((left, top, right, bottom)))

#         # Create the median blurred image
#         blur_img = cv2.medianBlur(img, 5)

#         # High pass filtering
#         hp_img = cv2.subtract(img, blur_img)
        
#         # Remove DC component
#         #hp_img = cv2.subtract(hp_img, hp_img.mean())

#         # Compute the FFT
#         avg_spec.append(hp_img[np.newaxis, :, :])
    
#     # Compute the average
#     avg_spec = np.squeeze(np.mean(avg_spec, axis=0))
    
#     # Compute the spectrum
#     avg_spec = np.fft.fftshift(np.fft.fft2(avg_spec))
    
#     # Return the module
#     return np.abs(avg_spec)

def compute_avg_spectrum_module(df: pd.DataFrame, crop_size: int = 224) -> np.ndarray:
    """Average the spectrum magnitude of the Laplacian-filtered images of ``df``.

    Every image listed in the ``path`` column is converted to grayscale and
    centre-cropped to ``crop_size`` x ``crop_size``. The crop is filtered with
    ``cv2.Laplacian`` (64-bit float output), and the magnitude of the centred
    FFT of the result is accumulated.

    Args:
        df: table with a ``path`` column pointing to the image files.
        crop_size: side of the central square patch, in pixels.

    Returns:
        The element-wise mean of the per-image spectrum magnitudes.

    Raises:
        ValueError: if ``df`` has no rows.
    """
    if len(df) == 0:
        raise ValueError('Cannot compute the average spectrum of an empty DataFrame')

    spectra = []
    for _, row in df.iterrows():

        # Convert the image to grayscale, releasing the file right after
        with Image.open(row['path']) as raw_img:
            img = raw_img.convert('L')

        # Centre crop. PIL reports the size as (width, height).
        width, height = img.size
        left = (width - crop_size)/2
        top = (height - crop_size)/2
        right = (width + crop_size)/2
        bottom = (height + crop_size)/2
        patch = np.array(img.crop((left, top, right, bottom)))

        # High pass filtering
        hp_img = cv2.Laplacian(patch, cv2.CV_64F)

        # Module of the centred FFT of the filtered patch
        spectra.append(np.abs(np.fft.fftshift(np.fft.fft2(hp_img))))

    # Average the spectra over the images
    return np.mean(spectra, axis=0)

In [None]:
from tqdm.notebook import tqdm

fsize = 20
n_samples = 50  # number of pristine images averaged per dataset
avg_specs = dict()
for dataset in tqdm(all_images['dataset'].unique()):
    # Select the dataset
    dataset_df = all_images.loc[all_images['dataset'] == dataset]
    # Get the uncompressed samples. DataFrame.sample raises when asked for
    # more rows than available, so the request is capped to the dataset size.
    unc_df = dataset_df.loc[~dataset_df['compressed']]
    sample = unc_df.sample(min(n_samples, len(unc_df)), random_state=42).copy()
    # Get the compressed samples
    comp_df = dataset_df.loc[dataset_df['compressed']]
    # Merge the dataframes on 'filename' and 'content' (compressed files are png)
    sample['filename'] = sample['filename'].apply(lambda x: x.replace('jpg', 'png'))
    comp_samples = pd.merge(comp_df, sample, on=['filename', 'content'],
                            suffixes=('', '_uncompressed'))
    # Average spectrum for every target BPP, in order of appearance
    avg_specs[dataset] = dict()
    for t_bpp in comp_samples['target_bpp'].unique():
        avg_specs[dataset][t_bpp] = compute_avg_spectrum_module(comp_samples.loc[comp_samples['target_bpp'] == t_bpp])
    # Average spectrum of the uncompressed samples
    avg_specs[dataset]['unc'] = compute_avg_spectrum_module(sample)
    
    

In [None]:
fsize = 20
for dataset, specs in avg_specs.items():
    # Prepare a 2x3 grid: the uncompressed average spectrum goes top-left
    fig, axs = plt.subplots(2, 3, figsize=(9*3, 9*2))
    for ax in axs.flat:
        ax.axis('off')
    # Clip the colour scale at the 99.9th percentile so that a few very
    # strong peaks do not flatten the rest of the spectrum
    axs[0][0].imshow(specs['unc'], vmax=np.quantile(specs['unc'], 0.999))
    axs[0][0].set_title('Uncompressed avg spectrum', fontsize=fsize-5)
    # Use the BPPs stored for THIS dataset rather than a variable left over
    # from the cell that computed the spectra
    bpp_keys = [key for key in specs if not (isinstance(key, str) and key == 'unc')]
    # The compressed spectra fill the remaining cells in row-major order
    for idx, t_bpp in enumerate(bpp_keys):
        ax = axs.flat[idx + 1]
        ax.imshow(specs[t_bpp], vmax=np.quantile(specs[t_bpp], 0.999))
        ax.set_title(f'{t_bpp} BPP avg spectrum', fontsize=fsize-5)
    fig.suptitle(f'{dataset} average spectrums', fontsize=fsize)
    plt.show()

### These figures don't convince me a lot... Can we compute the residual using a DnCNN?

In [None]:
import cv2
from typing import Tuple
import torch

# --- THIS IS THE OHJA IMPLEMENTATION ACCORDING TO THEIR PAPER, DOES NOT MAKE SENSE TO ME
# def compute_avg_spectrum_module(df: pd.DataFrame) -> np.ndarray:
    
#     avg_spec = []
#     for i, r in df.iterrows():
        
#         # Convert the image to grayscale
#         img = Image.open(r['path']).convert('L')
        
#         # Crop the image
#         height, width = img.size
#         left = (width - 224)/2
#         top = (height - 224)/2
#         right = (width + 224)/2
#         bottom = (height + 224)/2
#         img = np.array(img.crop((left, top, right, bottom)))

#         # Create the median blurred image
#         blur_img = cv2.medianBlur(img, 5)

#         # High pass filtering
#         hp_img = cv2.subtract(img, blur_img)
        
#         # Remove DC component
#         #hp_img = cv2.subtract(hp_img, hp_img.mean())

#         # Compute the FFT
#         avg_spec.append(hp_img[np.newaxis, :, :])
    
#     # Compute the average
#     avg_spec = np.squeeze(np.mean(avg_spec, axis=0))
    
#     # Compute the spectrum
#     avg_spec = np.fft.fftshift(np.fft.fft2(avg_spec))
    
#     # Return the module
#     return np.abs(avg_spec)

def compute_avg_spectrum_module(df: pd.DataFrame, model: torch.nn.Module, device: str,
                                crop_size: int = 224) -> np.ndarray:
    """Average the spectrum magnitude of the denoising residual over the images of ``df``.

    Every image listed in the ``path`` column is centre-cropped to
    ``crop_size`` x ``crop_size``, converted to a grayscale float image in
    [0, 1] and passed through ``model``. The model output is subtracted from
    its input, and the magnitude of the centred FFT of this residual is
    accumulated.

    Args:
        df: table with a ``path`` column pointing to the image files.
        model: network applied to a (1, 1, H, W) float tensor; its output is
            treated as the denoised image.
        device: torch device the input tensor is moved to.
        crop_size: side of the central square patch, in pixels.

    Returns:
        The element-wise mean of the per-image spectrum magnitudes.

    Raises:
        ValueError: if ``df`` has no rows.
    """
    if len(df) == 0:
        raise ValueError('Cannot compute the average spectrum of an empty DataFrame')

    spectra = []
    for _, row in df.iterrows():

        # Load the image as RGB, releasing the file right after
        with Image.open(row['path']) as raw_img:
            img = raw_img.convert('RGB')

        # Centre crop. PIL reports the size as (width, height).
        width, height = img.size
        left = (width - crop_size)/2
        top = (height - crop_size)/2
        right = (width + crop_size)/2
        bottom = (height + crop_size)/2
        rgb = np.array(img.crop((left, top, right, bottom))).astype(np.float32)

        # Convert the patch to grayscale (weighted sum of the R, G, B channels)
        gray = 0.299 * rgb[:, :, 0] + 0.587 * rgb[:, :, 1] + 0.114 * rgb[:, :, 2]

        # Convert the dynamic to 0-1 (the maximum 8-bit value is 255)
        gray /= 255.0

        # De-noised image. No gradient is needed, so none is tracked.
        with torch.no_grad():
            net_input = torch.from_numpy(gray).unsqueeze(0).unsqueeze(0).to(device)
            dn_img = model(net_input).squeeze().cpu().numpy()

        # High-pass filtered img: what the denoiser removed
        hp_img = gray - dn_img

        # Module of the centred FFT of the residual
        spectra.append(np.abs(np.fft.fftshift(np.fft.fft2(hp_img))))

    # Average the spectra over the images
    return np.mean(spectra, axis=0)

In [None]:
# --- Load the network and device
from utils.third_party.KAIR_master.models.network_dncnn import DnCNN as net

# Prepare the device (the CPU is used when CUDA is not available)
gpu = 3
device = f'cuda:{gpu}' if torch.cuda.is_available() else 'cpu'

# Load the network
model_path = '/nas/public/exchange/JPEG-AI/code/utils/third_party/KAIR_master/model_zoo'
nb = 20
model_name = 'dncnn_gray_blind.pth'
model = net(in_nc=1, out_nc=1, nc=64, nb=nb, act_mode='R')
# model = net(in_nc=n_channels, out_nc=n_channels, nc=64, nb=nb, act_mode='BR')  # use this if BN is not merged by utils_bnorm.merge_bn(model)
# Map the checkpoint to the CPU first, so that loading also works on machines
# without CUDA; the model is moved to the chosen device afterwards
state_dict = torch.load(os.path.join(model_path, model_name), map_location='cpu')
model.load_state_dict(state_dict, strict=True)
# Inference only: evaluation mode and no gradients on the parameters
model.eval()
model.requires_grad_(False)
model = model.to(device)

In [None]:
from tqdm.notebook import tqdm

fsize = 20
n_samples = 50  # number of pristine images averaged per dataset
avg_specs = dict()
for dataset in tqdm(all_images['dataset'].unique()):
    # Select the dataset
    dataset_df = all_images.loc[all_images['dataset'] == dataset]
    # Get the uncompressed samples. DataFrame.sample raises when asked for
    # more rows than available, so the request is capped to the dataset size.
    unc_df = dataset_df.loc[~dataset_df['compressed']]
    sample = unc_df.sample(min(n_samples, len(unc_df)), random_state=42).copy()
    # Get the compressed samples
    comp_df = dataset_df.loc[dataset_df['compressed']]
    # Merge the dataframes on 'filename' and 'content' (compressed files are png)
    sample['filename'] = sample['filename'].apply(lambda x: x.replace('jpg', 'png'))
    comp_samples = pd.merge(comp_df, sample, on=['filename', 'content'],
                            suffixes=('', '_uncompressed'))
    # Average spectrum for every target BPP, in order of appearance
    avg_specs[dataset] = dict()
    for t_bpp in comp_samples['target_bpp'].unique():
        avg_specs[dataset][t_bpp] = compute_avg_spectrum_module(comp_samples.loc[comp_samples['target_bpp'] == t_bpp], model, device)
    # Average spectrum of the uncompressed samples
    avg_specs[dataset]['unc'] = compute_avg_spectrum_module(sample, model, device)
    
    

Plot them

In [None]:
fsize = 20
for dataset, specs in avg_specs.items():
    # Prepare a 2x3 grid: the uncompressed average spectrum goes top-left
    fig, axs = plt.subplots(2, 3, figsize=(9*3, 9*2))
    for ax in axs.flat:
        ax.axis('off')
    # Clip the colour scale at the 99.9th percentile so that a few very
    # strong peaks do not flatten the rest of the spectrum
    axs[0][0].imshow(specs['unc'], vmax=np.quantile(specs['unc'], 0.999))
    axs[0][0].set_title('Uncompressed avg spectrum', fontsize=fsize-5)
    # Use the BPPs stored for THIS dataset rather than a variable left over
    # from the cell that computed the spectra
    bpp_keys = [key for key in specs if not (isinstance(key, str) and key == 'unc')]
    # The compressed spectra fill the remaining cells in row-major order
    for idx, t_bpp in enumerate(bpp_keys):
        ax = axs.flat[idx + 1]
        ax.imshow(specs[t_bpp], vmax=np.quantile(specs[t_bpp], 0.999))
        ax.set_title(f'{t_bpp} BPP avg spectrum', fontsize=fsize-5)
    fig.suptitle(f'{dataset} average spectrums', fontsize=fsize)
    plt.show()