Skip to content
/ MTUNet Public

MTUNet: Few-shot Image Classification with Visual Explanations (CVPRW 2021)

Notifications You must be signed in to change notification settings

wbw520/MTUNet

Repository files navigation

Match Them Up: Visually Explainable Few-shot Image Classification

PDF

Model Structure

Structure Figure

Abstract

Few-shot learning (FSL) approaches are usually based on an assumption that the pre-trained knowledge can be obtained from base (seen) categories and can be well transferred to novel (unseen) categories. However, there is no guarantee, especially for the latter part. This issue leads to the unknown nature of the inference process in most FSL methods, which hampers its application in some risk-sensitive areas. In this paper, we reveal a new way to perform FSL for image classification, using visual representations from the backbone model and weights generated by a newly-emerged explainable classifier. The weighted representations only include a minimum number of distinguishable features and the visualized weights can serve as an informative hint for the FSL process. Finally, a discriminator will compare the representations of each pair of the images in the support set and the query set. Pairs with the highest scores will decide the classification results. Experimental results prove that the proposed method can achieve both good accuracy and satisfactory explainability on three mainstream datasets.

Usage

Data Set

Download miniImageNet from miniImageNet

Download tieredImageNet from tiered-ImageNet

Download cifarFS from cifarfs

Put all dataset under "../FSL_data", for the same split setting run the following command:

python data/tiered_imagenet.py --data "../FSL_data/tiered-imagenet/"
python data/cifarfs.py --data "../FSL_data/cifar100/" --split "../FSL_data""

MTUNet training follow by three steps. All training default as 5-way 1-shot 15-query.

Training for backbone
python train_base.py --dataset miniImageNet --base_model resnet18 --channel 512 --num_classes 64 --data_root "../FSL_data"
python train_base.py --dataset miniImageNet --base_model wideres --channel 640 --num_classes 64 --data_root "../FSL_data"
python train_base.py --dataset tiered-ImageNet --base_model resnet18 --channel 512 --num_classes 351 --data_root "../FSL_data"
python train_base.py --dataset tiered-ImageNet --base_model wideres --channel 640 --num_classes 351 --data_root "../FSL_data"
python train_base.py --dataset cifarfs --base_model resnet18 --channel 512 --num_classes 64 --data_root "../FSL_data"
python train_base.py --dataset cifarfs --base_model wideres --channel 640 --num_classes 64 --data_root "../FSL_data"
Training for patterns
python train_scouter.py --random False --interval 10 --dataset miniImageNet --base_model resnet18 --channel 512 --num_classes 64 --data_root "../FSL_data"
python train_scouter.py --random False --interval 10 --dataset miniImageNet --base_model wideres --channel 640 --num_classes 64 --data_root "../FSL_data"
python train_scouter.py --random False --interval 10 --dataset tiered-ImageNet --base_model resnet18 --channel 512 --num_classes 351 --data_root "../FSL_data"
python train_scouter.py --random False --interval 10 --dataset tiered-ImageNet --base_model wideres --channel 640 --num_classes 351 --data_root "../FSL_data"
python train_scouter.py --random False --interval 10 --dataset cifarfs --base_model resnet18 --channel 512 --num_classes 64 --data_root "../FSL_data"
python train_scouter.py --random False --interval 10 --dataset cifarfs --base_model wideres --channel 640 --num_classes 64 --data_root "../FSL_data"
Training MTUNet
python train_fsl.py --random False --interval 10 --dataset miniImageNet --base_model resnet18 --channel 512 --num_classes 64 --data_root "../FSL_data"
python train_fsl.py --random False --interval 10 --dataset miniImageNet --base_model wideres --channel 640 --num_classes 64 --data_root "../FSL_data"
python train_fsl.py --random False --interval 10 --dataset tiered-ImageNet --base_model resnet18 --channel 512 --num_classes 351 --data_root "../FSL_data"
python train_fsl.py --random False --interval 10 --dataset tiered-ImageNet --base_model wideres --channel 640 --num_classes 351 --data_root "../FSL_data"
python train_fsl.py --random False --interval 10 --dataset cifarfs --base_model resnet18 --channel 512 --num_classes 64 --data_root "../FSL_data"
python train_fsl.py --random False --interval 10 --dataset cifarfs --base_model wideres --channel 640 --num_classes 64 --data_root "../FSL_data"
Test MTUNet
python test_fsl.py --random False --interval 10 --dataset miniImageNet --base_model resnet18 --channel 512 --num_classes 64 --data_root "../FSL_data"
Visualization of MTUNet
python vis_fsl.py --random False --interval 10 --dataset miniImageNet --base_model resnet18 --channel 512 --num_classes 64 --data_root "../FSL_data"

Patterns visualization

Patterns visualization

Matching matrix

matrix

Publication

If you want to use this work, please consider citing the following paper.

@inproceedings{wang2021mtunet,
 author = {Wang, Bowen and Li, Liangzhi and Verma, Manisha and Nakashima, Yuta and Kawasaki, Ryo and Nagahara, Hajime},
 booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshop},
 pages = {},
 title = {MTUNet: Few-shot Image Classification with Visual Explanations},
 year = {2021}
}

@article{wang2022match,
  title={Match them up: visually explainable few-shot image classification},
  author={Wang, Bowen and Li, Liangzhi and Verma, Manisha and Nakashima, Yuta and Kawasaki, Ryo and Nagahara, Hajime},
  journal={Applied Intelligence},
  pages={1--22},
  year={2022},
  publisher={Springer}
}

About

MTUNet: Few-shot Image Classification with Visual Explanations (CVPRW 2021)

Topics

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages