基于pytorch-metric-learning开源工具,实现了包括模型训练、模型验证、模型推理的相关代码。
- 使用cifar10数据集训练模型
python train_embedding_model_cifar10.py
- 使用cifar100数据集训练模型
python train_embedding_model_cifar100.py
- 使用flower花朵数据集训练模型,下载数据并解压后放在datasets目录下。flower花朵数据集见(链接:https://pan.baidu.com/s/1TfzLYZrkfwLy8wShy7nyMA 提取码:gxei)
- 修改./config/embedding.yaml配置文件里面的train_dataset_dir指向数据集的位置
python train_embedding_model.py
- 使用pytorch-metric-learning提供的API训练模型
python trainer_model.py
- results目录下提供了几个训练好的模型文件
- model_cifar10_SupervisedContrastiveLoss.pth 使用SupervisedContrastiveLoss在cifar10上训练的模型
- model_cifar100_CircleLoss.pth 使用CircleLoss在cifar100上训练的模型
- model_flower_photos_SupervisedContrastiveLoss.pth 使用SupervisedContrastiveLoss在flower花朵数据集上训练的模型
备注:上面这些模型使用的损失函数可以通过模型的名字得到,训练过程中做了embedding归一化,使用余弦相似度计算特征之间的距离
使用训练好的模型,以及pytorch-metric-learning工具提供的接口进行模型推理。
python model_inference.py
使用训练好的模型,将读入的数据转化为embedding特征。
python feature_extraction.py
使用训练好的模型,将读入的数据转化为embedding特征,并对embedding降维后可视化。
python visualizer.py
将数据按照类别ID存放在不同的目录中,具体格式可以参考flower花朵数据集那样。
- cifar10
embedding特征之间的相似度可视化 | embedding特征降维之后可视化 |
- cifar100
embedding特征之间的相似度可视化 | embedding特征降维之后可视化 |
- 花朵数据集(共5类)
embedding特征之间的相似度可视化 | embedding特征降维之后可视化 |
- 度量学习DML之Contrastive Loss及其变种
- 度量学习DML之Triplet Loss
- 度量学习DML之Lifted Structure Loss
- 度量学习DML之Circle Loss
- 度量学习DML之Cross-Batch Memory
- 度量学习DML之MoCO