
sbyebss/monge_map_solver


Neural Monge map estimation and its applications

This is the official Python implementation of the paper Neural Monge map estimation and its applications (TMLR, Featured Certification) by Jiaojiao Fan*, Shu Liu*, Shaojun Ma, Haomin Zhou, and Yongxin Chen.

The repository includes reproducible PyTorch source code for:

  1. Unpaired text-to-image generation on the DALLE2 backbone.
  2. Unpaired class-preserving map.
  3. Unpaired inpainting on the CelebA 128x128 dataset using Monge maps.
  4. Transport between synthetic datasets.
  5. Other toy examples.

Installation

Install PyTorch and torchvision, then run:

pip install --no-deps -r requirements.txt

Repository structure

The repository builds heavily on the pytorch-lightning template. The hyperparameters are stored in configs/.

Reproduction Instructions

Outputs are saved in the logs directory:

  • Training outputs and checkpoints: logs/reproduce

  • Testing outputs: logs/test
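To resume training or run evaluation, you usually need the newest checkpoint. The following is a minimal sketch that assumes checkpoints are written as *.ckpt files (PyTorch Lightning's default extension) somewhere under logs/reproduce. The helper name latest_checkpoint is hypothetical and is not part of the repository.

```python
from pathlib import Path
from typing import Optional


def latest_checkpoint(root: str = "logs/reproduce") -> Optional[Path]:
    """Return the most recently modified *.ckpt file under `root`, or None.

    Assumes Lightning-style checkpoint files with a .ckpt extension;
    no particular sub-directory layout under `root` is assumed.
    """
    root_path = Path(root)
    if not root_path.is_dir():
        return None
    # Search recursively, because run folders may be nested at any depth.
    candidates = [p for p in root_path.rglob("*.ckpt") if p.is_file()]
    if not candidates:
        return None
    # Choose the file with the newest modification time.
    return max(candidates, key=lambda p: p.stat().st_mtime)


if __name__ == "__main__":
    print(latest_checkpoint())
```

Using the modification time avoids guessing at any naming scheme for checkpoint files.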

The dataset directory (data_dir, defined in configs/config.yaml) defaults to datasets.

Note: training is currently supported only on a single GPU. The device is selected with trainer.devices=[gpu_id], where gpu_id defaults to 0; replace gpu_id with the ID of your GPU.

Unpaired text to image

Step 1: prepare text and image CLIP embeddings

bash bash/txt2img_dataset_processing.sh

The text-to-image datasets should have the following folder structure:

./datasets/laion-art_test
|- laion_art
|  |- laion-art.parquet
|- laion-art-en
|  |- laion-art-en.parquet
|- laion-high-resolution-en
|  |- clip_emb
|  |  |- img_emb
|  |  |- img_emb_reorder
|  |  |- text_emb
|  |  |- text_emb_reorder
|  |  |- metadata
|  |  |- stats
|  |- 00000.tar
|  |- 00000.parquet
|  |- 00000_stats.json
|  |- 00001.tar
|  |- 00001.parquet
|  |- 00001_stats.json
|- ...
./datasets/cc3m_test
|- cc3m_no_watermark.tsv
|- cc3m_no_watermark
|  |- clip_emb
|  |  |- img_emb
|  |  |- img_emb_reorder
|  |  |- text_emb
|  |  |- text_emb_reorder
|  |  |- metadata
|  |  |- stats
|  |- 00000.tar
|  |- 00000.parquet
|  |- 00000_stats.json
|  |- 00001.tar
|  |- 00001.parquet
|  |- 00001_stats.json
|- ...
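You may want to confirm that Step 1 produced the expected embedding folders before you run Step 2. The following is a minimal sketch that relies only on the folder names shown in the trees above: clip_emb containing img_emb, img_emb_reorder, text_emb, text_emb_reorder, metadata, and stats. The function name missing_clip_folders is hypothetical. The check does not inspect the files inside those folders.

```python
from pathlib import Path
from typing import List

# Sub-folders of clip_emb/ as listed in the laion-art and cc3m folder trees.
EXPECTED_CLIP_SUBDIRS = (
    "img_emb",
    "img_emb_reorder",
    "text_emb",
    "text_emb_reorder",
    "metadata",
    "stats",
)


def missing_clip_folders(shard_dir: str) -> List[str]:
    """Return the expected clip_emb sub-folders that are absent under shard_dir.

    shard_dir is a dataset folder such as ./datasets/cc3m_test/cc3m_no_watermark.
    """
    clip_root = Path(shard_dir) / "clip_emb"
    return [name for name in EXPECTED_CLIP_SUBDIRS if not (clip_root / name).is_dir()]


if __name__ == "__main__":
    for folder in (
        "./datasets/laion-art_test/laion-high-resolution-en",
        "./datasets/cc3m_test/cc3m_no_watermark",
    ):
        missing = missing_clip_folders(folder)
        print(folder, "OK" if not missing else f"missing: {missing}")
```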

Step 2: training and testing

bash bash/txt2img.sh

Sampling requires loading the diffusion decoder, which is placed on cuda:1 by default (set in txt2img_callbacks).
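With the defaults described above (trainer.devices=[0] for training and cuda:1 for the diffusion decoder), the text-to-image pipeline appears to need at least two visible GPUs. The sanity check below is a minimal sketch that assumes you supply the visible GPU count yourself, for example from torch.cuda.device_count(). The helpers check_txt2img_devices and parse_cuda_index are hypothetical and are not part of the repository.

```python
from typing import List


def parse_cuda_index(device: str) -> int:
    """Extract the integer index from a device string such as 'cuda:1'."""
    prefix = "cuda:"
    if not device.startswith(prefix):
        raise ValueError(f"expected a 'cuda:<index>' device, got {device!r}")
    return int(device[len(prefix):])


def check_txt2img_devices(train_gpus: List[int], decoder_device: str, num_visible_gpus: int) -> List[str]:
    """Return human-readable problems with a device assignment (empty list if none).

    train_gpus mirrors trainer.devices; decoder_device mirrors the cuda:1
    default used for the diffusion decoder in txt2img_callbacks.
    """
    problems = []
    # Training currently supports a single GPU only.
    if len(train_gpus) != 1:
        problems.append("training supports a single GPU; pass one id in trainer.devices")
    decoder_idx = parse_cuda_index(decoder_device)
    # Every requested index must refer to a GPU that actually exists.
    for idx in list(train_gpus) + [decoder_idx]:
        if not 0 <= idx < num_visible_gpus:
            problems.append(f"GPU {idx} is not among the {num_visible_gpus} visible GPUs")
    # Sharing one GPU is possible but may exhaust memory.
    if decoder_idx in train_gpus:
        problems.append(f"decoder and training share GPU {decoder_idx}; check that it has enough memory for both")
    return problems
```

For example, check_txt2img_devices([0], "cuda:1", num_visible_gpus=2) returns an empty list. On a single-GPU machine, the same call reports that GPU 1 is not visible.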

Class-preserving Monge map

bash bash/class_preserving_map.sh

Unpaired inpainting

Download the CelebA dataset and place it in the folder given by the data_dir parameter in configs/config.yaml. Because our algorithm is unpaired, it needs no label information; simply split the images into 📂train_source, 📂train_target, and 📂test folders in a ratio of 0.45 : 0.45 : 0.1. The data_dir folder should have the following structure, although the image file names do not have to follow this template.

📂celeba
 ┣ 📂train_source
  ┣ 📂images
    ┣ 📜000001.jpg
    ┣ 📜000002.jpg
    ┣ 📜...
 ┣ 📂train_target
  ┣ 📂images
    ┣ 📜000001.jpg
    ┣ 📜000002.jpg
    ┣ 📜...
 ┣ 📂test
  ┣ 📂images
    ┣ 📜000001.jpg
    ┣ 📜000002.jpg
    ┣ 📜...
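The 0.45 : 0.45 : 0.1 split described above can be scripted. The following is a minimal sketch that makes three assumptions:

- The raw CelebA images sit in a single flat folder.
- Copying the images, rather than moving them, is acceptable.
- Any rounding remainder goes to the test split.

The helpers split_counts and split_celeba are hypothetical and are not part of the repository. A fixed seed keeps the shuffle, and therefore the split, reproducible.

```python
import random
import shutil
from pathlib import Path
from typing import Dict, List, Tuple

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def split_counts(n: int, ratios: Tuple[float, float, float] = (0.45, 0.45, 0.1)) -> Tuple[int, int, int]:
    """Turn ratios into integer counts; the rounding remainder goes to the test split."""
    n_source = int(n * ratios[0])
    n_target = int(n * ratios[1])
    return n_source, n_target, n - n_source - n_target


def split_celeba(src_dir: str, data_dir: str, seed: int = 0) -> Dict[str, int]:
    """Copy images from src_dir into data_dir/{train_source,train_target,test}/images."""
    files: List[Path] = sorted(
        p for p in Path(src_dir).iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
    )
    random.Random(seed).shuffle(files)  # deterministic shuffle for reproducibility
    n_source, n_target, _ = split_counts(len(files))
    splits = {
        "train_source": files[:n_source],
        "train_target": files[n_source:n_source + n_target],
        "test": files[n_source + n_target:],
    }
    for split, members in splits.items():
        out_dir = Path(data_dir) / split / "images"
        out_dir.mkdir(parents=True, exist_ok=True)
        for f in members:
            shutil.copy2(f, out_dir / f.name)
    return {split: len(members) for split, members in splits.items()}
```

For example, split_celeba("path/to/img_align_celeba", "datasets/celeba") fills the three folders shown in the tree above. Both paths are placeholders for your download location and data_dir.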

Then run the following command to train the model; trainer.devices is the list of GPU IDs to use.

bash bash/inpainting.sh

Toy examples

Please run the notebooks in the toy_examples folder.

Citation

@article{
fan2023neural,
title={Neural Monge Map estimation and its applications},
author={Jiaojiao Fan and Shu Liu and Shaojun Ma and Hao-Min Zhou and Yongxin Chen},
journal={Transactions on Machine Learning Research},
issn={2835-8856},
year={2023},
url={https://openreview.net/forum?id=2mZSlQscj3},
note={Featured Certification}
}

Contact

For any inquiries, please feel free to reach out:
