
rami0205/RAMiT


RAMiT

Reciprocal Attention Mixing Transformer for Lightweight Image Restoration (CVPR 2024 Workshop NTIRE)

Haram Choi*, Cheolwoong Na, Jihyeon Oh, Seungjae Lee, Jinseop Kim, Subeen Choe, Jeongmin Lee, Taehoon Kim, and Jihoon Yang+

*: This work was done during the Master's course at Sogang University.

+: Corresponding author.


  • Proposes RAMiT, which employs the Dimensional Reciprocal Attention Mixing Transformer (D-RAMiT) and the Hierarchical Reciprocal Attention Mixer (H-RAMi)
  • D-RAMiT: computes bi-dimensional (spatial and channel) self-attention in parallel to capture both local and global dependencies
  • H-RAMi: uses multi-scale attention to decide where, and how much, attention should be paid, both semantically and globally
  • Achieves state-of-the-art results on five lightweight image restoration tasks: Super-Resolution, Color Denoising, Grayscale Denoising, Low-Light Enhancement, and Deraining
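
To make the D-RAMiT bullet more concrete, here is a toy, pure-Python sketch of spatial and channel self-attention running side by side on a (tokens × channels) matrix. It is a simplification under loudly stated assumptions: Q = K = V = the input (no learned projections), a single head, no windows, and a hypothetical channel split in which the first half of the channels gets spatial attention and the second half gets channel attention before the results are concatenated. The actual D-RAMiT block is defined in the repository and may differ in all of these details.

```python
import math

# Toy sketch of bi-dimensional self-attention computed in parallel.
# ASSUMPTIONS (not the repository's code): Q = K = V = input, one head,
# no windowing, and a simple half/half channel split between the branches.

def transpose(m):
    return [list(row) for row in zip(*m)]

def matmul(a, b):
    cols = transpose(b)
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]

def softmax_rows(m):
    out = []
    for row in m:
        top = max(row)  # subtract the max for numerical stability
        exps = [math.exp(v - top) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

def spatial_attention(x):
    """Tokens attend to tokens: (N, C) input, N x N map, (N, C) output."""
    scale = 1.0 / math.sqrt(len(x[0]))
    attn = softmax_rows([[s * scale for s in row] for row in matmul(x, transpose(x))])
    return matmul(attn, x)

def channel_attention(x):
    """Channels attend to channels: (N, C) input, C x C map, (N, C) output."""
    xt = transpose(x)  # (C, N)
    scale = 1.0 / math.sqrt(len(x))
    attn = softmax_rows([[s * scale for s in row] for row in matmul(xt, x)])
    return transpose(matmul(attn, xt))

def bi_dimensional_attention(x):
    half = len(x[0]) // 2
    spatial_out = spatial_attention([row[:half] for row in x])
    channel_out = channel_attention([row[half:] for row in x])
    # Concatenate along the channel axis so the output keeps shape (N, C).
    return [s + c for s, c in zip(spatial_out, channel_out)]

if __name__ == "__main__":
    tokens = [[1.0, 0.0, 0.5, 0.2],
              [0.0, 1.0, 0.3, 0.9],
              [0.5, 0.5, 1.0, 0.1]]  # 3 tokens (pixels), 4 channels
    for row in bi_dimensional_attention(tokens):
        print([round(v, 3) for v in row])
```

The two branches scale differently: spatial attention builds an N × N map whose cost grows with the number of tokens, while channel attention builds a C × C map whose cost grows with the channel count.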

News

- April 17, 2024: Accepted at CVPR 2024 Workshop NTIRE (New Trends in Image Restoration and Enhancement)

- July 12, 2023: Code released publicly

- May 19, 2023: Preprint posted on arXiv

Model Architecture

[Figure: RAMiT model architecture]

Dimensional Reciprocal Self-Attentions

[Figure: dimensional reciprocal self-attentions]

Lightweight Image Restoration Results

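Super-Resolution (SR)

[Image: SR results]

slimSR

[Image: slimSR results]

SR trade-off

[Image: SR trade-off]

Color Denoising (CDN)

[Image: CDN results]

Grayscale Denoising (GDN)

[Image: GDN results]

Low-Light Enhancement (LLE)

[Image: LLE results]

Deraining (DR)

[Image: DR results]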

Visual Results

* Visual results on the remaining images can be downloaded from my drive.

Super-Resolution (SR)

[Image: SR visual results]

Color Denoising (CDN)

[Image: CDN visual results]

Low-Light Enhancement (LLE)

[Image: LLE visual results]

Deraining (DR)

[Image: DR visual results]

Testing Instructions (with pre-trained models)

Edit the first five arguments (--total_nodes, --gpus_per_node, --node_rank, --ip_address, --backend) to match your devices.

RAMiT SR

python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT --pretrain_path ./pretrained/RAMiT_X2.pth --task lightweight_sr --target_mode light_x2 --result_image_save --img_norm

python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT --pretrain_path ./pretrained/RAMiT_X3.pth --task lightweight_sr --target_mode light_x3 --result_image_save --img_norm

python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT --pretrain_path ./pretrained/RAMiT_X4.pth --task lightweight_sr --target_mode light_x4 --result_image_save --img_norm
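
The three SR test commands differ only in the checkpoint file and --target_mode, and the same pattern holds for RAMiT-1 and RAMiT-slimSR. The hypothetical helper below (not part of the repository) assembles these commands from the flags shown in this README; the IP address is a placeholder you must replace.

```python
import shlex

# Hypothetical helper: rebuilds the SR test commands listed in this README.
# Every flag is copied from the README; the function itself is illustrative.
def build_sr_test_cmd(model_name, scale, ip_address, backend="gloo",
                      total_nodes=1, gpus_per_node=1, node_rank=0):
    return [
        "python3", "ddp_main_test.py",
        "--total_nodes", str(total_nodes),
        "--gpus_per_node", str(gpus_per_node),
        "--node_rank", str(node_rank),
        "--ip_address", ip_address,
        "--backend", backend,
        "--model_name", model_name,
        "--pretrain_path", f"./pretrained/{model_name}_X{scale}.pth",
        "--task", "lightweight_sr",
        "--target_mode", f"light_x{scale}",
        "--result_image_save",
        "--img_norm",
    ]

if __name__ == "__main__":
    for scale in (2, 3, 4):
        # 127.0.0.1 is a placeholder address; use your own.
        print(shlex.join(build_sr_test_cmd("RAMiT", scale, "127.0.0.1")))
```

To actually launch a run, the returned list can be passed to subprocess.run(cmd, check=True) from the repository root.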

RAMiT-1 SR

python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT-1 --pretrain_path ./pretrained/RAMiT-1_X2.pth --task lightweight_sr --target_mode light_x2 --result_image_save --img_norm

python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT-1 --pretrain_path ./pretrained/RAMiT-1_X3.pth --task lightweight_sr --target_mode light_x3 --result_image_save --img_norm

python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT-1 --pretrain_path ./pretrained/RAMiT-1_X4.pth --task lightweight_sr --target_mode light_x4 --result_image_save --img_norm

RAMiT-slimSR SR

python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT-slimSR --pretrain_path ./pretrained/RAMiT-slimSR_X2.pth --task lightweight_sr --target_mode light_x2 --result_image_save --img_norm

python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT-slimSR --pretrain_path ./pretrained/RAMiT-slimSR_X3.pth --task lightweight_sr --target_mode light_x3 --result_image_save --img_norm

python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT-slimSR --pretrain_path ./pretrained/RAMiT-slimSR_X4.pth --task lightweight_sr --target_mode light_x4 --result_image_save --img_norm

RAMiT CDN

python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT --pretrain_path ./pretrained/RAMiT_CDN.pth --task lightweight_dn --target_mode light_dn --result_image_save --img_norm

RAMiT-1 CDN

python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT-1 --pretrain_path ./pretrained/RAMiT-1_CDN.pth --task lightweight_dn --target_mode light_dn --result_image_save --img_norm

RAMiT GDN

python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT --pretrain_path ./pretrained/RAMiT_GDN.pth --task lightweight_dn --target_mode light_graydn --result_image_save --img_norm

RAMiT LLE

python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT --pretrain_path ./pretrained/RAMiT_LLE.pth --task lightweight_lle --target_mode light_lle --result_image_save --img_norm

RAMiT-slimLLE LLE

python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT-slimLLE --pretrain_path ./pretrained/RAMiT-slimLLE_LLE.pth --task lightweight_lle --target_mode light_lle --result_image_save --img_norm

RAMiT DR

python3 ddp_main_test.py --total_nodes 1 --gpus_per_node 1 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT --pretrain_path ./pretrained/RAMiT_DR.pth --task lightweight_dr --target_mode light_dr --result_image_save --img_norm
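
Across the test commands above, each task uses a fixed --task / --target_mode pair; note that CDN and GDN share --task lightweight_dn and differ only in --target_mode. The lookup table below simply records those pairs as they appear in this README; the dictionary and function names are illustrative, not repository code.

```python
# (--task, --target_mode) pairs copied from the test commands in this README.
# TASK_FLAGS and task_flags are illustrative names, not repository code.
TASK_FLAGS = {
    "SR": ("lightweight_sr", None),  # target mode depends on the scale
    "CDN": ("lightweight_dn", "light_dn"),
    "GDN": ("lightweight_dn", "light_graydn"),
    "LLE": ("lightweight_lle", "light_lle"),
    "DR": ("lightweight_dr", "light_dr"),
}

def task_flags(abbrev, scale=None):
    """Return (--task, --target_mode) for a README task abbreviation."""
    task, target_mode = TASK_FLAGS[abbrev]
    if abbrev == "SR":
        if scale not in (2, 3, 4):
            raise ValueError("SR checkpoints are listed for x2, x3 and x4 only")
        target_mode = f"light_x{scale}"
    return task, target_mode

if __name__ == "__main__":
    print(task_flags("GDN"))
    print(task_flags("SR", scale=4))
```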

Training Instructions

Edit the first five arguments (--total_nodes, --gpus_per_node, --node_rank, --ip_address, --backend) to match your devices.
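
The first five arguments describe the distributed setup. Assuming ddp_main.py follows the usual PyTorch DDP convention of one process per GPU (an assumption; check the script), the world size and each process's global rank follow directly from them. The sketch below shows only that arithmetic; the function names are hypothetical.

```python
# Standard DDP bookkeeping, ASSUMING one process per GPU as in common PyTorch
# DDP launchers; a sketch, not code taken from this repository.
def world_size(total_nodes, gpus_per_node):
    return total_nodes * gpus_per_node

def global_rank(node_rank, gpus_per_node, local_rank):
    return node_rank * gpus_per_node + local_rank

if __name__ == "__main__":
    # Values from the training commands: --total_nodes 1 --gpus_per_node 2.
    nodes, gpus = 1, 2
    print("world size:", world_size(nodes, gpus))
    for local in range(gpus):
        print("local rank", local, "-> global rank", global_rank(0, gpus, local))
```

Under this convention, a second machine joining the same job would pass --node_rank 1 and the same --ip_address, and its processes would take global ranks 2 and 3.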

RAMiT SR

(x2) from scratch
python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT --target_mode light_x2 --task lightweight_sr --training_patch_size 64 --batch_size 32 --progressive_epoch 0 --data_name DIV2K --total_epochs 500 --half_list 200,300,400,425,450,475 --img_norm

(x3) warm-start
python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --finetune --pretrain_path [pretrain PATH] --warm_start --warm_start_epoch 50 --model_name RAMiT --target_mode light_x3 --task lightweight_sr --training_patch_size 64 --batch_size 32 --progressive_epoch 0 --data_name DIV2K --total_epochs 300 --warmup_epoch 10 --half_list 50,100,150,175,200,225 --img_norm

(x4) warm-start
python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --finetune --pretrain_path [pretrain PATH] --warm_start --warm_start_epoch 50 --model_name RAMiT --target_mode light_x4 --task lightweight_sr --training_patch_size 64 --batch_size 32 --progressive_epoch 0 --data_name DIV2K --total_epochs 300 --warmup_epoch 10 --half_list 50,100,150,175,200,225 --img_norm
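
The SR runs pass a list of epochs to --half_list. Assuming this names the epochs at which the learning rate is halved (an interpretation, not stated in this README), the schedule can be sketched as follows. The initial learning rate is not given here, so base_lr is a hypothetical placeholder.

```python
# Step schedule sketch, ASSUMING --half_list lists the epochs at which the
# learning rate is halved. base_lr is a hypothetical placeholder value.
def lr_at_epoch(epoch, base_lr, half_list):
    halvings = sum(1 for milestone in half_list if epoch >= milestone)
    return base_lr * 0.5 ** halvings

if __name__ == "__main__":
    x2_half_list = [200, 300, 400, 425, 450, 475]     # (x2) from scratch, 500 epochs
    warm_half_list = [50, 100, 150, 175, 200, 225]    # (x3)/(x4) warm-start, 300 epochs
    for epoch in (0, 199, 200, 300, 499):
        print("x2", epoch, lr_at_epoch(epoch, 1.0, x2_half_list))
    print("x3/x4 final", lr_at_epoch(299, 1.0, warm_half_list))
```

This is equivalent in spirit to PyTorch's torch.optim.lr_scheduler.MultiStepLR with gamma=0.5; whether the repository uses that class is not stated here.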

RAMiT-1 SR

(x2) from scratch
python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT-1 --target_mode light_x2 --task lightweight_sr --training_patch_size 64 --batch_size 32 --progressive_epoch 0 --data_name DIV2K --total_epochs 500 --half_list 200,300,400,425,450,475 --img_norm

(x3) warm-start
python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --finetune --pretrain_path [pretrain PATH] --warm_start --warm_start_epoch 50 --model_name RAMiT-1 --target_mode light_x3 --task lightweight_sr --training_patch_size 64 --batch_size 32 --progressive_epoch 0 --data_name DIV2K --total_epochs 300 --warmup_epoch 10 --half_list 50,100,150,175,200,225 --img_norm

(x4) warm-start
python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --finetune --pretrain_path [pretrain PATH] --warm_start --warm_start_epoch 50 --model_name RAMiT-1 --target_mode light_x4 --task lightweight_sr --training_patch_size 64 --batch_size 32 --progressive_epoch 0 --data_name DIV2K --total_epochs 300 --warmup_epoch 10 --half_list 50,100,150,175,200,225 --img_norm

RAMiT-slimSR SR

(x2) from scratch
python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT-slimSR --target_mode light_x2 --task lightweight_sr --training_patch_size 64 --batch_size 32 --progressive_epoch 0 --data_name DIV2K --total_epochs 500 --half_list 200,300,400,425,450,475 --img_norm

(x3) warm-start
python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --finetune --pretrain_path [pretrain PATH] --warm_start --warm_start_epoch 50 --model_name RAMiT-slimSR --target_mode light_x3 --task lightweight_sr --training_patch_size 64 --batch_size 32 --progressive_epoch 0 --data_name DIV2K --total_epochs 300 --warmup_epoch 10 --half_list 50,100,150,175,200,225 --img_norm

(x4) warm-start
python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --finetune --pretrain_path [pretrain PATH] --warm_start --warm_start_epoch 50 --model_name RAMiT-slimSR --target_mode light_x4 --task lightweight_sr --training_patch_size 64 --batch_size 32 --progressive_epoch 0 --data_name DIV2K --total_epochs 300 --warmup_epoch 10 --half_list 50,100,150,175,200,225 --img_norm

RAMiT CDN (blind noise level)

python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT --target_mode light_dn --task lightweight_dn --sigma 0,50 --training_patch_size 64,96,128 --batch_size 32,16,8 --progressive_epoch 0,100,200 --data_name DFBW --total_epochs 400 --half_list 200,300,350,375 --img_norm

RAMiT-1 CDN (blind noise level)

python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT-1 --target_mode light_dn --task lightweight_dn --sigma 0,50 --training_patch_size 64,96,128 --batch_size 32,16,8 --progressive_epoch 0,100,200 --data_name DFBW --total_epochs 400 --half_list 200,300,350,375 --img_norm

RAMiT GDN (blind noise level)

python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT --target_mode light_graydn --task lightweight_dn --sigma 0,50 --training_patch_size 64,96,128 --batch_size 32,16,8 --progressive_epoch 0,100,200 --data_name DFBW --total_epochs 400 --half_list 200,300,350,375 --img_norm
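
The denoising runs are trained for a blind noise level with --sigma 0,50. Assuming this means a Gaussian noise level drawn uniformly from [0, 50] on the 0-255 scale for each training sample (an interpretation, not confirmed by this README), the data corruption step looks roughly like this. The function name and the [0, 1] pixel range are illustrative choices; the sketch does not clip the noisy values.

```python
import random

# Blind-noise sketch, ASSUMING "--sigma 0,50" means sigma ~ Uniform[0, 50]
# on the 0-255 scale per sample. Pixels are floats in [0, 1]; no clipping.
def add_blind_gaussian_noise(pixels, sigma_range=(0.0, 50.0), rng=random):
    lo, hi = sigma_range
    sigma = rng.uniform(lo, hi)
    std = sigma / 255.0  # convert the 0-255 noise level to the [0, 1] scale
    noisy = [p + rng.gauss(0.0, std) for p in pixels]
    return noisy, sigma

if __name__ == "__main__":
    clean = [0.2, 0.5, 0.8]
    noisy, sigma = add_blind_gaussian_noise(clean, rng=random.Random(0))
    print(f"sigma={sigma:.2f}", [round(v, 3) for v in noisy])
```

Because a new sigma is drawn for every sample, a single model sees the whole range of noise levels during training and does not need the level as an input at test time.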

RAMiT LLE

python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT --target_mode light_lle --task lightweight_lle --training_patch_size 64,96,128 --batch_size 32,16,8 --progressive_epoch 0,100,200 --data_name LLE --total_epochs 400 --half_list 200,300,350,375 --img_norm

RAMiT-slimLLE LLE

python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT-slimLLE --target_mode light_lle --task lightweight_lle --training_patch_size 64,96,128 --batch_size 32,16,8 --progressive_epoch 0,100,200 --data_name LLE --total_epochs 400 --half_list 200,300,350,375 --img_norm

RAMiT DR

python3 ddp_main.py --total_nodes 1 --gpus_per_node 2 --node_rank 0 --ip_address [ip address XXX.XXX.XXX.XXX] --backend gloo --model_name RAMiT --target_mode light_dr --task lightweight_dr --training_patch_size 64,96,128 --batch_size 32,16,8 --progressive_epoch 0,100,200 --data_name DR --total_epochs 400 --half_list 200,300,350,375 --img_norm
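
The CDN, GDN, LLE, and DR commands pass comma-separated lists to --training_patch_size, --batch_size, and --progressive_epoch. Assuming --progressive_epoch lists the epochs at which each (patch size, batch size) pair takes effect (an interpretation of the flags, not confirmed here), training switches stages as sketched below. The constant and function names are hypothetical.

```python
# Progressive-training sketch, ASSUMING stage i starts at START_EPOCHS[i].
PATCH_SIZES = [64, 96, 128]    # --training_patch_size 64,96,128
BATCH_SIZES = [32, 16, 8]      # --batch_size 32,16,8
START_EPOCHS = [0, 100, 200]   # --progressive_epoch 0,100,200

def stage_for_epoch(epoch):
    """Return the (patch_size, batch_size) pair active at the given epoch."""
    stage = 0
    for i, start in enumerate(START_EPOCHS):
        if epoch >= start:
            stage = i
    return PATCH_SIZES[stage], BATCH_SIZES[stage]

if __name__ == "__main__":
    for epoch in (0, 99, 100, 250, 399):
        patch, batch = stage_for_epoch(epoch)
        # Pixels per batch: 131072, then 147456, then 131072 again.
        print(epoch, patch, batch, patch * patch * batch)
```

With these values the number of pixels per batch stays roughly constant (64²·32 = 128²·8 = 131072, and 96²·16 = 147456) while the receptive context per patch grows.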

Citation

@article{choi2023ramit,
  title={RAMiT: Reciprocal Attention Mixing Transformer for Lightweight Image Restoration},
  author={Choi, Haram and Na, Cheolwoong and Oh, Jihyeon and Lee, Seungjae and Kim, Jinseop and Choe, Subeen and Lee, Jeongmin and Kim, Taehoon and Yang, Jihoon},
  journal={arXiv preprint arXiv:2305.11474},
  year={2023}
}

My Related Works

  • N-Gram in Swin Transformers for Efficient Lightweight Image Super-Resolution, CVPR 2023 (proceedings, arXiv, code)
  • Exploration of Lightweight Single Image Denoising with Transformers and Truly Fair Training, ICMR 2023 (proceedings, arXiv, code)
