# Neural Imaging - Getting Started

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import json

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from tqdm import tqdm

from helpers import utils, tf_helpers

utils.setup_logging()

tf_helpers.print_versions()

# 1. General Toolbox Structure

Let's start with the high-level toolbox structure, just to have a general feel of how things are organized.

**Python Modules**
```
>> neural-imaging:
├── compression                     - learned image codec, helper functions for JPEG and BPG codecs
├── config                          - configuration files
├── data                            - data directory (datasets, models, results)
├── docs                            - documentation
├── helpers                         - helper functions (working with datasets, results, plotting, image processing, etc.)
├── models                          - Tensorflow models (camera ISPs, codecs, forensics, custom layers, etc.)
├── pyfse                           - wrappers for the FSE entropy codec (Cython)
├── training                        - training routines
└── workflows                       - workflows for various applications
 
```

**Command Line Tools**
```
>> Training scripts:
├── train_dcn.py                    - train lossy image codecs
├── train_manipulation.py           - train manipulation classification workflows
├── train_nip.py                    - train camera ISP models
└── train_prepare_training_set.py   - prepare training data

>> Testing scripts:
├── test_dcn.py                     - test learned image codecs (plots, rate-distortion) 
├── test_dcn_rate_dist.py           - compare rate-distortion profiles for various codecs
├── test_fan.py                     - test forensic manipulation classification 
├── test_framework.py               - high-level framework test
├── test_jpeg.py                    - test differentiable JPEG codec
└── test_nip.py                     - test camera ISPs

>> Others:
├── develop_images.py               - render RAW images with selected ISP
├── diff_nip.py                     - compare output from 2 ISPs
├── results.py                      - search & display results from manipulation classification
└── summarize_nip.py                - summarize ISP training statistics

```

# 2. Working with Datasets

The toolbox features a `Dataset` class to help with loading images and feeding them into the model. The class splits images into training and validation subsets. The **training subset** is loaded as **full resolution** images, and patches are randomly sampled on demand. The **validation subset** is **sampled upon loading** and only the patches are kept in memory. The dataset can load either RAW-RGB pairs (default) or just the RGB images - the behavior is controlled via the `load` argument.
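To make the difference between the two subsets concrete, here is a minimal NumPy sketch of the idea. It is not the toolbox's `Dataset` implementation: the random arrays stand in for loaded images, and `sample_patch` and `next_training_batch` are hypothetical helpers written only for this illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_patch(image, patch_size, rng):
    """Crop a random (patch_size x patch_size) window from an (H, W, C) image."""
    h, w = image.shape[:2]
    y = rng.integers(0, h - patch_size + 1)
    x = rng.integers(0, w - patch_size + 1)
    return image[y:y + patch_size, x:x + patch_size]

# Stand-ins for loaded full-resolution RGB images (assumption: values in [0, 1])
training_images = [rng.random((256, 384, 3)) for _ in range(4)]
validation_images = [rng.random((256, 384, 3)) for _ in range(2)]

# Training subset: full images stay in memory, patches are cropped on demand
def next_training_batch(batch_size, patch_size):
    ids = rng.integers(0, len(training_images), size=batch_size)
    return np.stack([sample_patch(training_images[i], patch_size, rng) for i in ids])

# Validation subset: patches are cropped once, upfront, and only they are kept
val_n_patches = 4
validation_patches = np.stack([
    sample_patch(img, 128, rng)
    for img in validation_images
    for _ in range(val_n_patches)
])
del validation_images  # the full-resolution validation images are no longer needed

print(next_training_batch(10, 128).shape, validation_patches.shape)
```

Sampling the validation patches once keeps validation scores comparable across epochs, while on-demand cropping gives the training loop fresh patches in every batch.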

<div class="alert alert-block alert-info">
<b>Notice:</b> The following sections assume you have downloaded example datasets. Feel free to get example data from <a href="https://pkorus.pl/downloads/neural-imaging-resources">pkorus.pl/downloads/neural-imaging-resources</a> or adjust the paths accordingly.</div>

Once you download the data and models, they should be placed in the following folders:

```
data/raw/                               - RAW images used for camera ISP training
  └ training_data/{camera name}          Bayer stacks (*.npy) and developed (*.png)
data/models                             - pre-trained TF models
  ├ nip/{camera name}/{nip}              NIP models (TF checkpoints)
  ├ isp                                  Classic ISP models (TF checkpoints)
  └ dcn/baselines/{dcn model}            DCN models (TF checkpoints)
```


In [None]:
from helpers import dataset, plots, utils, image

In [None]:
data = dataset.Dataset('data/raw/training_data/D90/', n_images=10, v_images=10, val_n_patches=4, load='xy')

## 2.1. Plotting 101

Plotting and visualization functions are available in the `helpers.plots` module. We'll start with plotting images. Your Swiss Army knife is the `plots.images` function, which shows image collections as grids.

In [None]:
# Sample an example training batch
batch_raw, batch_rgb = data.next_training_batch(0, 10, 128)
plots.images(batch_rgb, '{}', ncols=-1)

## 2.2. Patch Selection

Sampling from the dataset can be fully random or biased. The behavior can be controlled via the `discard` argument. The following modes are available:

- `flat` - attempts to discard flat patches based on patch variance (not strict)
- `flat-aggressive` - a more aggressive version that avoids patches with variance < 0.01
- `dark-n-textured` - avoid dark (mean < 0.35) and textured patches (variance > 0.005)
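The thresholds listed above can be read as a simple rejection-sampling rule. The sketch below is an assumption about how such a policy might work, not the toolbox's code: `keep_patch` and `sample_with_policy` are hypothetical names, the `flat` mode is left out because its exact rule is not stated here, and `dark-n-textured` is interpreted as rejecting a patch that is either dark or textured.

```python
import numpy as np

def keep_patch(patch, discard=None):
    """Decide whether a candidate patch passes the given discard policy."""
    mean, var = float(patch.mean()), float(patch.var())
    if discard == 'flat-aggressive':
        return var >= 0.01                     # reject flat patches
    if discard == 'dark-n-textured':
        return mean >= 0.35 and var <= 0.005   # reject dark or textured patches (assumed "or")
    return True                                # None: fully random sampling

def sample_with_policy(image, patch_size, discard, rng, max_attempts=25):
    """Draw random crops until one passes the policy (or attempts run out)."""
    h, w = image.shape[:2]
    for _ in range(max_attempts):
        y = rng.integers(0, h - patch_size + 1)
        x = rng.integers(0, w - patch_size + 1)
        patch = image[y:y + patch_size, x:x + patch_size]
        if keep_patch(patch, discard):
            break
    return patch  # falls back to the last candidate if none was accepted

rng = np.random.default_rng(0)
image = np.full((128, 256), 0.5)
image[:, 128:] = rng.random((128, 128))  # left half flat, right half noisy
patch = sample_with_policy(image, 32, 'flat-aggressive', rng)
print(f'selected patch variance: {patch.var():.4f}')
```

Capping the number of attempts keeps sampling from stalling on images where no patch satisfies the policy, which is one way a policy can end up "not strict".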

In [None]:
# Show differences between patch discarding policies
image_id = 4
patches = []
discard_modes = [None, 'flat', 'flat-aggressive', 'dark-n-textured']

for discard in discard_modes:
    patches_current = [data.next_training_batch(image_id, 1, 128, discard)[-1].squeeze() for _ in range(10)]
    patches.extend(patches_current)
    print(f'discard={discard} -> mean: {np.mean(patches_current):.3f} & var : {np.var(patches_current, axis=(1, 2, 3)).mean():.3f}')

plots.images(patches, '', ncols=-4, rowlabels=discard_modes)

In [None]:
# Show the distribution of patch variances for fully random sampling (no discard policy)
patches = image.cati([data.next_training_batch(image_id, 1, 128, discard=None)[-1] for _ in range(500)])

_ = plt.hist(np.var(patches, axis=(1,2,3)), 30)
plt.xlabel('patch variance')
plt.title(f'random sampling ({data.files["training"][image_id]})')
plt.legend([f'mean: {np.mean(patches):.3f}, var : {np.var(patches, axis=(1, 2, 3)).mean():.3f}'])

# 3. Working with Models

The toolbox contains several models:
- camera ISP models: both **classic ISP** and **neural ISP**
- image compression: a **differentiable JPEG codec** and a **fully learned codec**
- image forensics: state-of-the-art **constrained CNN model**

The models are defined in sub-modules of `models` and need to be derived from `tfmodel.TFModel`. Certain types of models have corresponding abstract model classes that simplify definition of new architectures by providing common functionality (e.g., `NIP` class for defining camera ISP models, or `DCN` for learned image codecs).

For trainable models, creating new instances produces uninitialized models with random weights. Such models need to be either trained or restored with pre-trained weights. The snippet below shows 4 different ways in which models can be restored. The following sections will show more diverse practical examples.

```python
# 1. Manual - create model instance and load weights
dcn = compression.TwitterDCN(rounding="soft-codebook", n_features=16)
dcn.load_model('data/models/dcn/baselines/16c')

# 2. Restore a pre-trained model of a known class
compression.TwitterDCN.restore('data/models/dcn/baselines/16c/', key='codec')

# 3. Restore a pre-trained model from a known Python module
tfmodel.restore('data/models/dcn/baselines/16c/', compression, key='codec')

# 4. Convenience function to restore specific model classes (here, compression)
codec.restore('16c')
```

First, let's load a dataset for our experiments.

In [None]:
data = dataset.Dataset('D90', n_images=10, v_images=10)

## 3.1. Camera ISP

In this demonstration, we will compare the output of a classic ISP and a neural ISP.

In [None]:
from models import pipelines

# Restore a Classic ISP and a neural imaging pipeline
isp = pipelines.ClassicISP.restore(camera='D90')
nip = pipelines.DNet.restore('data/models/nip/D90/')
batch_size = 6

# Fetch a validation batch from the dataset
batch_raw, batch_rgb = data.next_validation_batch(0, batch_size)

# Compare ground truth images with ISP & NIP output
plots.images(image.cati(batch_rgb, isp.process(batch_raw), nip.process(batch_raw)), '{}', ncols=-3, rowlabels=['Ground truth', 'Classic ISP', f'Neural Pipeline ({nip.class_name})'])

## 3.2. JPEG Compression

The toolbox provides a fully differentiable JPEG codec, which is useful for training components that happen earlier in the workflow (e.g., various types of watermark or fingerprint embedders). The best way of using the codec is via the `jpeg.JPEG` class which serves as a generic interface, and allows for switching the codec in a seamless way. 

In this example, we show how to take advantage of this behavior and compare the results of JPEG compression using the differentiable codec and the open source **libJPEG**.
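The main obstacle to differentiating through JPEG is the quantization step: rounding has a zero gradient almost everywhere, so nothing upstream of the codec would receive a training signal. The sketch below compares hard rounding with a sinusoidal approximation. This particular approximation is an assumption chosen for illustration and is not necessarily what `jpeg.JPEG` uses internally.

```python
import numpy as np

def hard_round(x):
    """Standard rounding, as used in JPEG quantization."""
    return np.round(x)

def soft_round(x):
    """Smooth approximation of rounding; exact at integers, differentiable everywhere."""
    return x - np.sin(2 * np.pi * x) / (2 * np.pi)

def numerical_gradient(fn, x, eps=1e-4):
    """Central finite-difference estimate of d fn / d x."""
    return (fn(x + eps) - fn(x - eps)) / (2 * eps)

# Example quantized DCT coefficients (values away from the x.5 rounding boundaries)
coeffs = np.array([0.2, 1.4, 2.3, -3.7])

print('hard:', hard_round(coeffs), 'gradient:', numerical_gradient(hard_round, coeffs))
print('soft:', soft_round(coeffs).round(3), 'gradient:', numerical_gradient(soft_round, coeffs).round(3))
```

The hard version reports a zero gradient for every coefficient, while the smooth version passes a usable gradient back to whatever produced the coefficients.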

In [None]:
from models import jpeg

batch_raw, batch_rgb = data.next_training_batch(0, batch_size, 64)

codec_differentiable = jpeg.JPEG(50)
codec_libjpeg = jpeg.JPEG(50, 'libjpeg')

plots.images(image.cati(batch_rgb, codec_differentiable.process(batch_rgb), codec_libjpeg.process(batch_rgb)), '', ncols=-3, rowlabels=['Uncompressed', 'Diff. JPEG', 'LibJPEG'])

## 3.3. Learned Compression

The toolbox provides a fully learned image compression codec and three pre-trained quality configurations. In this example, we use the low-quality settings and compare the results against
a standard JPEG codec at a similar SSIM level.
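Matching a JPEG quality factor to a target SSIM can be treated as a search over integer quality levels. The sketch below uses bisection and assumes that quality is non-decreasing in the quality factor. The `measure` callable is a toy stand-in for "compress at this quality and compute SSIM", and `match_quality_sketch` is a hypothetical helper, so the toolbox's `jpeg_helpers.match_quality` may well work differently.

```python
import math

def match_quality_sketch(measure, target, lo=1, hi=100):
    """Find the smallest integer quality in [lo, hi] with measure(quality) >= target.

    Assumes `measure` is non-decreasing in quality. If even `hi` falls short
    of the target, `hi` is returned as the best available setting.
    """
    while lo < hi:
        mid = (lo + hi) // 2
        if measure(mid) >= target:
            hi = mid          # mid is good enough; try lower qualities
        else:
            lo = mid + 1      # mid is too low; search above it
    return lo

# Toy monotone stand-in for "SSIM after JPEG compression at quality qf"
def toy_ssim(qf):
    return 1.0 - 0.5 * math.exp(-qf / 20.0)

qf = match_quality_sketch(toy_ssim, target=0.9)
print(f'QF={qf} -> SSIM={toy_ssim(qf):.3f}')
```

Bisection needs about seven codec evaluations per image instead of up to a hundred for an exhaustive sweep, which matters when every evaluation involves a full compress-and-measure round trip.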

In [None]:
from compression import codec, jpeg_helpers
from helpers import metrics

batch_size = 6

# Restore a learned codec with low-quality settings (16 channels)
codec_dcn = codec.restore('16c')
codec_jpg = jpeg.JPEG(codec='libjpeg')

batch_rgb = data.next_validation_batch(0, batch_size)[-1]
batch_dcn = codec_dcn.process(batch_rgb).numpy()

# Find SSIM measurements for DCN-compressed images
ssims_d = metrics.ssim(batch_dcn, batch_rgb)

# Find JPEG quality factors that lead to the same SSIM levels
jpeg_qf = [jpeg_helpers.match_quality(rgb, ssim) for rgb, ssim in zip(batch_rgb, ssims_d)]

# Compress with JPEG to match the quality level
batch_jpg = [codec_jpg.process(np.expand_dims(rgb, axis=0), qf) for rgb, qf in zip(batch_rgb, jpeg_qf)]
ssims_j = metrics.ssim(batch_rgb, np.concatenate(batch_jpg))

# Plot results
labels_o = batch_size * ['']
labels_j = [f'QF={qf} -> SSIM={ssim:.3f}' for qf, ssim in zip(jpeg_qf, ssims_j) ]
labels_d = [f'SSIM={ssim:.3f}' for ssim in ssims_d]

plots.images(image.cati(batch_rgb, batch_dcn, batch_jpg), labels_o + labels_d + labels_j, ncols=-3, rowlabels=['Uncompressed', f'Learned codec ({codec_dcn.summary_compact()})', 'JPEG @ similar SSIM'])

# 4. Workflows

Models can be combined into **workflows** that model specific applications. The current version of the toolbox provides an example workflow that models **manipulation classification**. The model involves training a forensic analysis network (FAN) to identify subtle post-processing operations applied to the image. The process starts with the camera ISP and is followed by photo manipulations and a distribution channel. The FAN can access images after they have been degraded (e.g., down-sampled and compressed) by the channel. The generic model is shown below:   

![](docs/manipulation_detection_training_architecture.png)

In this example, we will train a toy version of the model with:
- a dummy ISP (feeds RGB images), 
- 2 easy manipulations (sharpening and Gaussian blurring), 
- mild JPEG compression,
- a small forensic CNN.
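Before building the real workflow, it may help to see the data flow in plain NumPy. The sketch below creates one class per manipulation (plus the unmodified patch) and then passes everything through a lossy channel. All names are hypothetical, the kernels are generic textbook choices rather than the toolbox's filters, and coarse quantization only stands in for JPEG compression.

```python
import numpy as np

def filter3x3(image, kernel):
    """Apply a symmetric 3x3 kernel to an (H, W) image using edge padding."""
    padded = np.pad(image, 1, mode='edge')
    h, w = image.shape
    out = np.zeros((h, w))
    for dy in range(3):
        for dx in range(3):
            out += kernel[dy, dx] * padded[dy:dy + h, dx:dx + w]
    return np.clip(out, 0.0, 1.0)

# Generic kernels (assumptions): binomial blur and an unsharp-mask style sharpener
blur_kernel = np.outer([1, 2, 1], [1, 2, 1]) / 16.0
identity = np.zeros((3, 3))
identity[1, 1] = 1.0
sharpen_kernel = identity + 0.75 * (identity - blur_kernel)

class_names = ['native', 'sharpen', 'gaussian']

def make_variants(patch):
    """Return a batch with one variant per class and the matching class labels."""
    variants = [patch, filter3x3(patch, sharpen_kernel), filter3x3(patch, blur_kernel)]
    return np.stack(variants), np.arange(len(variants))

def channel(batch, levels=32):
    """Crude stand-in for a lossy distribution channel: coarse uniform quantization."""
    return np.round(batch * (levels - 1)) / (levels - 1)

rng = np.random.default_rng(0)
patch = rng.random((32, 32))
batch, labels = make_variants(patch)
fan_input = channel(batch)  # what a forensic classifier would get to see
print(fan_input.shape, [class_names[i] for i in labels])
```

The classifier only ever sees `fan_input`, so any traces of the manipulations must survive the channel to be detectable.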

## 4.1. Construct the workflow to our specification

In [None]:
from workflows import manipulation_classification

# Define a list of manipulations and their strength
manipulations = ['sharpen:0.75', 'gaussian:1', 'jpeg:50']

# Define the distribution channel: no resizing and JPEG compression
distribution = {
    'downsampling': 'none',
    'compression': 'jpeg',
    'compression_params': {
        'quality': 90,
        'codec': 'soft',
    }
}

# Construct the workflow with a dummy ISP that allows for feeding RGB images
flow = manipulation_classification.ManipulationClassification('ONet', manipulations, distribution, {'kernel': 3, 'n_dense': 2}, {}, raw_patch_size=32)

print(flow.details())

## 4.2. Load an RGB dataset

In [None]:
from helpers import dataset

data = dataset.Dataset('D90', n_images=50, v_images=50, val_n_patches=2, val_rgb_patch_size=64, load='y')
print(data.summary())

## 4.3. Visualize manipulations

In [None]:
batch_m = flow.run_manipulations(data.next_training_batch(0, 1, 64, 'flat-aggressive')).numpy()

plots.images(batch_m, ['{} ()'.format(x) for x in flow._forensics_classes], ncols=flow.n_classes, figwidth=4)

## 4.4. Train the entire workflow

Now, let's train the entire workflow. In this case, we are training only the forensics model. 

We're playing with a toy example and small patches - the whole process should take no more than 10-15 minutes on a mid-tier desktop.

In [None]:
from training.manipulation import train_manipulation_nip

# Training setup
training = {
    'camera_name': 'D90',
    'patch_size': 32,
    'batch_size': 10,
    'n_epochs': 501,
    'validation_schedule': 10,
    'lambda_nip': 0,
}

output_directory = train_manipulation_nip(flow, training, data, {'root': 'data/getting_started/m'}, overwrite=True)

## 4.5. Validate the trained model

First, we'll test the model on the entire (validation) dataset.

In [None]:
from training.validation import validate_fan

accuracy, conf_mat = validate_fan(flow, data)

print(f'\nAccuracy: {accuracy:.2f}')
print('Confusion:')
print(conf_mat.round(2))
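For reference, accuracy and a confusion matrix can be computed from class labels in a few lines. This is a generic sketch rather than the body of `validate_fan`. It assumes rows correspond to true classes and are normalized to sum to 1, which matches the later use of `100 * conf_mat` as percentages but is still an assumption.

```python
import numpy as np

def confusion_matrix(true_labels, predicted_labels, n_classes):
    """Row-normalized confusion matrix: entry [t, p] is the fraction of class t predicted as p."""
    counts = np.zeros((n_classes, n_classes))
    for t, p in zip(true_labels, predicted_labels):
        counts[t, p] += 1
    row_sums = np.maximum(counts.sum(axis=1, keepdims=True), 1)  # avoid division by zero
    return counts / row_sums

true_labels = np.array([0, 0, 1, 1, 2, 2])
predicted_labels = np.array([0, 1, 1, 1, 2, 0])

toy_accuracy = np.mean(true_labels == predicted_labels)
toy_confusion = confusion_matrix(true_labels, predicted_labels, n_classes=3)

print(f'Accuracy: {toy_accuracy:.2f}')
print(toy_confusion)
```

With equally sized classes, the overall accuracy equals the mean of the diagonal of the row-normalized matrix.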

And finally, let's run FAN predictions on a sample batch:

In [None]:
# Classify different post-processed variations of a sample patch
batch_x = data.next_validation_batch(1, 1)

# Fetch processed & distributed patches - as seen by the FAN
batch_f = flow.run_rgb_to_fan(batch_x)

# Run the patches through the FAN
predicted_class, confidence = flow.fan.process_and_decide(batch_f, True)

# Prepare labels: real_class -> predicted_class (confidence)
class_labels = flow._forensics_classes
labels = [f'{class_labels[real_class]} -> {class_labels[pred_class]} (confidence={conf:.2f})' for real_class, (pred_class, conf) in enumerate(zip(predicted_class, confidence))]

# Plot all images
plots.images(batch_f, labels, ncols=-1, figwidth=6)

# 5. Working with Results

The toolbox provides helper functions to work with results, e.g.:
- finding results stored before (e.g., the results of our previous training runs)
- rendering & visualization

As an example, let's print the above confusion matrix in a more accessible form: 

In [None]:
from helpers import results_data, plots

print(results_data.confusion_to_text(100 * conf_mat, flow._forensics_classes, flow.summary_compact()))

## 5.1. The Result Cache

Instead of recomputing validation statistics, we can also fetch them from the training log. We will explore the most general way of accessing saved results - via the `ResultCache`.

`ResultCache` is a helper class which simplifies searching for and accessing results stored as either `.json` or `.npz` (numpy archives). It uses templates for building the path and allows you to incrementally provide more specific search terms (see details in the docstrings). In this example, we will look up a predefined template for `manipulation_classification`.
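The idea of template-based search can be sketched with the standard library alone: unspecified terms become wildcards, and every term you provide narrows the pattern. The template layout, `build_pattern` and `find` below are hypothetical and only illustrate the mechanism; the real `ResultCache` templates are documented in its docstrings.

```python
import fnmatch
import string

# Hypothetical path template; the toolbox's actual templates may differ
TEMPLATE = '{root}/{camera}/{isp}/{run}/training.json'

def build_pattern(template, **terms):
    """Fill known search terms into the template and replace the rest with '*'."""
    fields = [name for _, name, _, _ in string.Formatter().parse(template) if name]
    return template.format(**{name: terms.get(name, '*') for name in fields})

def find(paths, template, **terms):
    """Return the search pattern and all paths that match it."""
    pattern = build_pattern(template, **terms)
    return pattern, fnmatch.filter(paths, pattern)

# Stand-in for files on disk
stored = [
    'data/m/D90/ONet/001/training.json',
    'data/m/D90/DNet/001/training.json',
    'data/m/D7000/ONet/001/training.json',
]

print(find(stored, TEMPLATE, root='data/m', camera='D90'))
print(find(stored, TEMPLATE, root='data/m', camera='D90', isp='ONet'))
```

The second query adds one search term and, as a result, one wildcard in the pattern is replaced by a concrete directory name.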

In [None]:
cache = results_data.ResultCache('manipulation_classification', 'data/getting_started/m', camera='D90')
print(str(cache))

By using the `find` method we see the current pattern used for searching and the list of all matching results. In this example, we only have one training run, so we don't have to be too specific.

In [None]:
cache.find()

We can request more specific results by specifying values for some (or all) of the search terms. Note how the search pattern changes. Search terms that are expected to be common to all queries can be specified in the constructor.

In [None]:
cache.find(isp='ONet')

**Enough introduction, let's load our training log!**

The returned data is exposed as a `dict` object. For brevity, we'll print only the relevant section about the forensic training. You will see several keys:
- `model` and `args` contain class name and arguments that will allow you to restore this model later,
- `performance` contains the values of the metrics that we tracked during training and validation
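Storing a class name together with constructor arguments is enough to rebuild an (untrained) model instance later. The sketch below shows that pattern in isolation. `ToyFAN`, `toy_models` and `instantiate` are hypothetical stand-ins, not toolbox classes, and loading the trained weights would be a separate step that is not shown.

```python
import types

class ToyFAN:
    """Hypothetical stand-in for a model class."""
    def __init__(self, n_classes, kernel=5):
        self.n_classes = n_classes
        self.kernel = kernel

# Stand-in for a Python module that exposes model classes by name
toy_models = types.SimpleNamespace(ToyFAN=ToyFAN)

def instantiate(log_entry, module):
    """Look up the class named in the log and construct it with the logged arguments."""
    model_class = getattr(module, log_entry['model'])
    return model_class(**log_entry['args'])

entry = {'model': 'ToyFAN', 'args': {'n_classes': 3, 'kernel': 3}}
model = instantiate(entry, toy_models)
print(type(model).__name__, model.n_classes, model.kernel)
```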

In [None]:
results = cache.load()
utils.printd(results['forensics'])

We can plot the training progress using helper functions from the `plots` module. We'll use the `utils.get` function to recursively find keys in our dict.

In [None]:
plots.perf(utils.get(results, 'forensics.performance'), alpha=0.1)
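A dotted-path lookup of this kind is easy to sketch. The version below simply walks a dot-separated path through nested dictionaries; `get_nested` is a hypothetical helper, and the actual `utils.get` may support more (for example, searching for a key at any depth).

```python
def get_nested(data, path, default=None):
    """Follow a dot-separated path through nested dicts; return default if any key is missing."""
    node = data
    for key in path.split('.'):
        if not isinstance(node, dict) or key not in node:
            return default
        node = node[key]
    return node

# Toy structure that mimics the shape of a training log
toy_log = {'forensics': {'performance': {'confusion': [[1, 0], [0, 1]]}}}

print(get_nested(toy_log, 'forensics.performance.confusion'))
print(get_nested(toy_log, 'forensics.missing.key', default='n/a'))
```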

## 5.2. Rendering Results

Let's have another look at the confusion matrix. We can compare the recently computed validation results with those recorded during training.

To make things more interesting, we can **render the new table as a LaTeX table** (useful when preparing your manuscript).

In [None]:
conf_log = np.array(utils.get(results, 'forensics.performance.confusion'))
conf_labels = utils.get(results, 'manipulations')

tex = results_data.confusion_to_text(100 * conf_log, conf_labels, fmt='tex')

print(tex)
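If you want to see what such a conversion involves, the sketch below builds a plain `tabular` environment from a matrix and its labels. It is a simplified, hypothetical stand-in for `confusion_to_text(..., fmt='tex')`: the toolbox's output format will likely differ, and labels containing characters such as `_` or `%` would need escaping, which is not handled here.

```python
import numpy as np

def confusion_to_tex(matrix, labels):
    """Render a square matrix as a LaTeX tabular with row and column labels."""
    col_spec = 'l' + 'r' * len(labels)
    header = ' & '.join([''] + list(labels)) + r' \\'
    rows = [
        ' & '.join([label] + [f'{value:.0f}' for value in row]) + r' \\'
        for label, row in zip(labels, matrix)
    ]
    lines = [rf'\begin{{tabular}}{{{col_spec}}}', r'\hline', header, r'\hline',
             *rows, r'\hline', r'\end{tabular}']
    return '\n'.join(lines)

toy_matrix = np.array([[90.0, 10.0], [25.0, 75.0]])
toy_tex = confusion_to_tex(toy_matrix, ['native', 'sharpen'])
print(toy_tex)
```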

If you're lucky enough to have your LaTeX environment ready, you can give it a shot and render the table directly in here.

In [None]:
results_data.render_tex(tex)

Confusion matrices are rather specific, so you can also use a more general function to convert any 2D array. It has more options and features than this short introduction can cover, so I'll leave you with one example and the docstrings.

In [None]:
results_data.convert_table(100 * conf_log, conf_labels, fmt='df')

# 6. Final Words

**Congratulations, you made it through the tutorial!**

We barely scratched the surface of the toolbox, but this should be enough to get you started. 

If you have any comments/suggestions or would like to contribute your code or models, feel free to reach out.

# Appendix

## Python Modules
```
>> neural-imaging:
├── compression                     - learned image codec, helper functions for JPEG and BPG codecs
│   ├── bpg_helpers
│   ├── codec
│   ├── jpeg_helpers
│   └── ratedistortion
├── config                          - configuration files
├── data                            - data directory (datasets, models, results)
├── docs                            - documentation
├── helpers                         - helper functions:
│   ├── dataset                       . loading and serving image datasets
│   ├── debugging                     . memory usage
│   ├── fsutil                        . dealing with filenames and directory listing
│   ├── image                         . image processing
│   ├── imdiff                        . comparing images
│   ├── kernels                       . image filtering kernels
│   ├── loading                       . loading images and extracting patches
│   ├── metrics                       . image quality metrics
│   ├── paramspec                     . handling and verification of hyper-parameters
│   ├── plots                         . plotting & visualization
│   ├── raw                           . rendering RAW files and working with Bayer data
│   ├── results_data                  . rendering, saving & loading results
│   ├── stats                         . detection statistics (tpr, roc, auc)
│   ├── summaries                     . writing complex data to TF logs
│   ├── tf_helpers                    . various Tensorflow helpers 
│   └── utils                         . various helpers (logging, printing, number tests)
├── models                          - Tensorflow models
│   ├── compression                   . learned image codecs (base + TwitterDCN)
│   ├── forensics                     . forensic analysis network (FAN)
│   ├── jpeg                          . differentiable JPEG
│   ├── layers                        . various layers (constrained convolution, quantization)
│   ├── pipelines                     . camera ISP (classic & neural)
│   └── tfmodel                       . base class for all models (TFModel)
├── pyfse                           - wrappers for the FSE entropy codec (Cython)
├── training                        - training routines
│   ├── compression
│   ├── manipulation
│   ├── pipeline
│   └── validation
└── workflows                       - workflows for various applications
    └── manipulation_classification
 
```

## Command Line Tools
```
>> neural-imaging:
├── train_dcn.py                    - train lossy image codecs
├── train_manipulation.py           - train manipulation classification workflows
├── train_nip.py                    - train camera ISP models
├── train_prepare_training_set.py   - prepare training data
├── develop_images.py               - render RAW images with selected ISP
├── diff_nip.py                     - compare output from 2 ISPs
├── results.py                      - search & display results from manipulation classification
├── summarize_nip.py                - summarize ISP training statistics
├── test_dcn.py                     - test learned image codecs (plots, rate-distortion) 
├── test_dcn_rate_dist.py           - compare rate-distortion profiles for various codecs
├── test_fan.py                     - test forensic manipulation classification 
├── test_framework.py               - high-level framework test
├── test_jpeg.py                    - test differentiable JPEG codec
└── test_nip.py                     - test camera ISPs
```