Official PyTorch implementation of Deep Optimal Transport: A Practical Algorithm for Photo-realistic Image Restoration
We propose an image restoration algorithm that can control the perceptual quality and/or the mean squared error (MSE) of any pre-trained model, trading one for the other at test time. Our algorithm is few-shot: given about a dozen images restored by the model, it can significantly improve the perceptual quality and/or the MSE of the model for newly restored images without further training.
We used miniconda3 and pip3 to manage dependencies:
conda create -n dmax python=3.8
conda activate dmax
git clone https://github.com/theoad/dot-dmax
cd dot-dmax
pip install -e .
Example of our algorithm applied to the SwinIR model for SISRx4.
NB: This is a simplified example. For the full algorithm implementation, see dmax/main.py and Reproducing results.
from PIL import Image
import torch
from torchvision.transforms.functional import to_tensor, resize
from torchvision.utils import save_image
from datasets import load_dataset
from dmax.latent_w2 import LatentW2
from dmax.models import swinir
device = "cuda" if torch.cuda.is_available() else "cpu"
dmax = LatentW2("stabilityai/sd-vae-ft-ema").to(device)
model = swinir("classical_sr-4", pretrained=True).to(device) # <-- replace with any restoration model !
dataset = iter(load_dataset("imagenet-1k", split="train", streaming=True))
for _ in range(100): # arbitrary resolution & aspect ratios
    x = to_tensor(next(dataset)['image']).to(device).unsqueeze(0)
    dmax.update(x, distribution="target") # update nat. image statistics
for _ in range(100): # unpaired updates
    x = to_tensor(next(dataset)['image']).to(device).unsqueeze(0)
    y = resize(x, (x.size(-2)//4, x.size(-1)//4), antialias=True) # degrade image
    x_star = model(y) # restore with the pre-trained model
    dmax.update(x_star, distribution="source") # update model statistics
w2 = dmax.compute() # compute the latent transport operator & W2 distance
print(f"Latent w2 distance: {w2.cpu().item():.2f}")
x = to_tensor(Image.open("../assets/baboon.png")).to(device).unsqueeze(0)
y = resize(x, (x.size(-2)//4, x.size(-1)//4), antialias=True)
x_star = model(y)
xhat_0 = dmax.transport(x_star) # enhance new images
collage = torch.cat([resize(y, x.shape[-2:]), x_star, xhat_0, x], dim=-1).to(device)
save_image(collage, "demo.png", nrow=1, padding=0)
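The update, compute and transport calls above accumulate statistics for two distributions and then map restored images toward the natural-image one. As a rough illustration of that idea, and not the repository's LatentW2 implementation, the sketch below assumes each distribution is modeled as a Gaussian over a single scalar feature, for which the Wasserstein-2 distance and the optimal transport map have closed forms. The names RunningGaussian, gaussian_w2 and transport are hypothetical and exist only for this example.

```python
import math


class RunningGaussian:
    """Streaming mean and variance of a scalar feature (Welford's algorithm).

    Hypothetical helper for illustration; not part of the dmax package.
    """

    def __init__(self):
        self.count = 0
        self.mean = 0.0
        self._m2 = 0.0  # running sum of squared deviations from the mean

    def update(self, values):
        """Fold a new batch of samples into the running statistics."""
        for value in values:
            self.count += 1
            delta = value - self.mean
            self.mean += delta / self.count
            self._m2 += delta * (value - self.mean)

    @property
    def std(self):
        """Population standard deviation of the samples seen so far."""
        return math.sqrt(self._m2 / self.count)


def gaussian_w2(source, target):
    """Closed-form Wasserstein-2 distance between two 1-D Gaussians."""
    return math.hypot(source.mean - target.mean, source.std - target.std)


def transport(x, source, target):
    """Optimal (W2) map from the source Gaussian to the target Gaussian."""
    return target.mean + (target.std / source.std) * (x - source.mean)


# Toy data standing in for "model output" and "natural image" features.
source, target = RunningGaussian(), RunningGaussian()
source.update([0.0, 2.0])  # statistics can be accumulated over several batches
source.update([4.0])
target.update([10.0, 14.0, 18.0])

w2 = gaussian_w2(source, target)
moved = [transport(x, source, target) for x in (0.0, 2.0, 4.0)]
print(f"W2 distance: {w2:.2f}")
print("transported samples:", moved)
```

After transport, the toy source samples have the mean and spread of the target samples. Because the statistics are accumulated incrementally, the source and target updates do not need to be paired, which mirrors the unpaired loops in the demo.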
Our algorithm enhances existing methods (we tested SwinIR, Swin2SR, Restormer, ESRGAN and DDRM). Instead of requiring users to manually download third-party code, data or weights, we automate everything using Google Drive's API.
Note: The following only applies to your script, and does not give access to other users. Nevertheless, we recommend revoking the script's access after the download is complete.
- Follow the steps of Google's PyDrive quickstart and place your credentials.json under the dot-dmax repository.
- Run python data/gdrive.py init (must be on a local machine, connected to a display). If the warning "Google hasn’t verified this app" appears, click Advanced and then Go to <Your App Name> (unsafe).
- [Optional]: To access the API from a remote machine, simply upload the token.pickle file generated by the previous step to that machine.
We abstract hardware dependency using hugging-face's accelerate library. Configure your environment before launching the scripts by running
accelerate config
Because evaluation is quite heavy (we computed results for many settings), the default is batch_size=10.
Reduce this value if you encounter CUDA out-of-memory issues:
export batch_size=8 # replace with your batch size
After configuring your hardware, launch distributed jobs by replacing python main.py <args> with accelerate launch main.py <args>.
If you enabled the PyDrive API, you do not need to download any dataset manually, except for ImageNet. Once ImageNet is downloaded, declare its location with the following environment variable:
export imagenet_path=~/data/ImageNet # replace with your path
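The scripts are configured through environment variables such as imagenet_path, batch_size and num_workers. How main.py actually reads them is not shown here, so the following is only a minimal sketch of how a Python script could pick up such variables. The function read_settings is hypothetical, and the fallback values are assumptions taken from the example commands in this README.

```python
import os


def read_settings(env=None):
    """Collect run settings from environment variables (hypothetical helper).

    The fallback values are assumptions copied from this README's examples.
    """
    env = os.environ if env is None else env
    return {
        # Expand a leading "~" in case the variable was set without shell expansion.
        "imagenet_path": os.path.expanduser(env.get("imagenet_path", "~/data/ImageNet")),
        # Environment variables are strings, so numeric settings need a cast.
        "batch_size": int(env.get("batch_size", "10")),
        "num_workers": int(env.get("num_workers", "10")),
    }


# Simulate `export batch_size=8` without touching the real environment.
settings = read_settings({"batch_size": "8"})
print(settings)
```

Passing a plain dictionary instead of os.environ keeps the example side-effect free; calling read_settings() with no argument reads the real environment.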
[Optional]: see Hardware Setup for distributed commands
cd dmax # we must run main.py under the source directory
python main.py --help # displays all optional arguments
# export imagenet_path=~/data/ImageNet # <-- replace with your path
# export batch_size=10 # <-- replace with your batch size
# export num_workers=10 # <-- replace with your number of workers
# NB: Replace `python` with `accelerate launch` for a distributed run
python main.py ESRGAN classical_sr-4 # ESRGAN (SISRx4)
python main.py SwinIR classical_sr-4 # SwinIR (SISRx4)
python main.py SwinIR jpeg_car-10 # SwinIR (JPEGq10)
python main.py Restormer gaussian_color_denoising_sigma50 # Restormer (AWGNs50)
python precomputed_results.py DDRM classical_sr_4_dn_25 imagenet-1k # DDRM (SISRx4 + AWGNs25)
# Swin2SR (SISRx4 + JPEGq10)
python main.py Swin2SR compressed_sr-4 \
--natural_image_set ["compressed_sr_swin2sr"] \
--degraded_set ["compressed_sr_swin2sr"] \
--quantitative_set ["compressed_sr_swin2sr"] \
--qualitative_set ["compressed_sr_swin2sr"]
# NLM has a significant memory footprint so we use a batch-size of 1
export batch_size=1
export num_workers=0
python main.py NLM color_dn-50 # NLM (AWGNs50)
If you found our research useful, you can cite our work with the following BibTeX entry:
@misc{
adrai2023deep,
title={Deep Optimal Transport: A Practical Algorithm for Photo-realistic Image Restoration},
author={Theo Adrai and Guy Ohayon and Tomer Michaeli and Michael Elad},
year={2023},
eprint={2306.02342},
archivePrefix={arXiv},
primaryClass={cs.AI}
}