Skip to content

The template code of metric learning based on pytorch framework includes model training, inference, similarity matching and so on

Notifications You must be signed in to change notification settings

xxcheng0708/pytorch-metric-learning-template

Repository files navigation

pytorch-metric-learning-template

基于pytorch-metric-learning开源工具,实现了包括模型训练、模型验证、模型推理的相关代码。


模型训练

  • 使用cifar10数据集训练模型
python train_embedding_model_cifar10.py
  • 使用cifar100数据集训练模型
python train_embedding_model_cifar100.py
  • 使用flower花朵数据集训练模型,下载数据并解压后放在datasets目录下。flower花朵数据集见(链接:https://pan.baidu.com/s/1TfzLYZrkfwLy8wShy7nyMA 提取码:gxei)
    • 修改./config/embedding.yaml配置文件里面的train_dataset_dir指向数据集的位置
python train_embedding_model.py
  • 使用pytorch-metric-learning提供的API训练模型
python trainer_model.py
  • results目录下提供了几个训练好的模型文件
    • model_cifar10_SupervisedContrastiveLoss.pth 使用SupervisedContrastiveLoss在cifar10上训练的模型
    • model_cifar100_CircleLoss.pth 使用CircleLoss在cifar100上训练的模型
    • model_flower_photos_SupervisedContrastiveLoss.pth 使用SupervisedContrastiveLoss在flower花朵数据集上训练的模型

备注:上面这些模型使用的损失函数可以通过模型的名字得到,训练过程中做了embedding归一化,使用余弦相似度计算特征之间的距离

模型推理

使用训练好的模型,以及pytorch-metric-learning工具提供的接口进行模型推理。

python model_inference.py

Embedding特征提取

使用训练好的模型,将读入的数据转化为embedding特征。

python feature_extraction.py

Embedding特征可视化

使用训练好的模型,将读入的数据转化为embedding特征,并对embedding降维后可视化。

python visualizer.py

自定义训练数据

将数据按照类别ID存放在不同的目录中,具体格式可以参考flower花朵数据集那样。

模型效果展示

  • cifar10
embedding特征之间的相似度可视化 embedding特征降维之后可视化
  • cifar100
embedding特征之间的相似度可视化 embedding特征降维之后可视化
  • 花朵数据集(共5类)
embedding特征之间的相似度可视化 embedding特征降维之后可视化


度量学习相关的损失函数介绍:


基于度量学习方法实现音乐特征匹配的系列文章

About

The template code of metric learning based on pytorch framework includes model training, inference, similarity matching and so on

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Languages