Melan-Dx: a knowledge-enhanced vision-language framework improves differential diagnosis of melanocytic neoplasm pathology
A training framework for melanocytic neoplasm classification using pre-computed embeddings and hierarchical disease taxonomy.
This framework accompanies the following peer-reviewed publication:
Melan-Dx: a knowledge-enhanced vision-language framework improves differential diagnosis of melanocytic neoplasm pathology
Jialu Yao, Songhao Li, Peixian Liang, Xiaowei Xu, David Elder, Zhi Huang
npj Digital Medicine, 2026
📄 Paper: https://www.nature.com/articles/s41746-026-02357-3
🔗 DOI: https://doi.org/10.1038/s41746-026-02357-3
If you use this code or the Melan-Dx framework in your research, please cite the paper above.
This version allows you to train directly from pre-computed embedding files without data preprocessing or embedding generation steps.
Four .pt files, each containing:
- embeddings: torch.Tensor with shape (N, embed_dim)
- disease_names: List[str] of length N
Required files:
- train_embeddings.pt - Training set embeddings
- val_embeddings.pt - Validation set embeddings
- test_embeddings.pt - Test set embeddings
- knowledge_embeddings.pt - Knowledge base embeddings
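Before training, it can help to sanity-check each file against the format above. The sketch below is not part of the repository: check_embedding_payload is a hypothetical helper that assumes each .pt file loads (for example via torch.load) to a dict with the keys embeddings and disease_names. It relies only on len(), .shape and iteration, so it also runs on plain Python lists.

```python
from collections import Counter
from typing import Any, Mapping

REQUIRED_KEYS = ("embeddings", "disease_names")


def check_embedding_payload(payload: Mapping[str, Any]) -> Counter:
    """Validate one loaded embedding file and return per-disease sample counts.

    Hypothetical helper: assumes a dict with "embeddings" (N x embed_dim)
    and "disease_names" (N strings).
    """
    missing = [key for key in REQUIRED_KEYS if key not in payload]
    if missing:
        raise KeyError(f"missing keys: {missing}")

    embeddings = payload["embeddings"]
    names = payload["disease_names"]

    # Exactly one disease label per embedding row.
    if len(embeddings) != len(names):
        raise ValueError(f"{len(embeddings)} embeddings vs {len(names)} disease names")
    if not all(isinstance(name, str) for name in names):
        raise TypeError("disease_names must contain only str values")

    # Tensors expose .shape; plain nested lists are checked row by row.
    shape = getattr(embeddings, "shape", None)
    if shape is not None:
        if len(shape) != 2:
            raise ValueError(f"expected a 2-D (N, embed_dim) tensor, got {tuple(shape)}")
    else:
        dims = {len(row) for row in embeddings}
        if len(dims) > 1:
            raise ValueError(f"inconsistent embed_dim values: {sorted(dims)}")

    return Counter(names)


if __name__ == "__main__":
    toy = {
        "embeddings": [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
        "disease_names": [
            "Dysplastic naevus",
            "Simple lentigo and lentiginous melanocytic naevus",
            "Dysplastic naevus",
        ],
    }
    print(check_embedding_payload(toy).most_common())
```

With PyTorch installed, a real file could be checked with something like check_embedding_payload(torch.load("train_embeddings.pt")); the returned counts also make class imbalance visible before the stratified split.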
Important: Preparing Your Embeddings
If you have a single merged embedding file, you MUST split it into train/val/test sets using stratified sampling:
# Split embeddings with 70/15/15 ratio (stratified by disease label)
python split_embeddings.py merged_all_embeddings.pt ./output_dir --train_ratio 0.7 --val_ratio 0.15 --test_ratio 0.15
config/who_44_classes_tree.json - A 3-level hierarchical structure:
{
"Level 2 (Grandparent)": {
"Level 3 (Parent)": [
"Level 4 (Disease 1)",
"Level 4 (Disease 2)",
...
]
}
}
Example:
{
"Melanocytic neoplasms in intermittently sun-exposed skin": {
"Naevi": [
"Junctional, compound, and dermal naevi",
"Simple lentigo and lentiginous melanocytic naevus",
"Dysplastic naevus"
]
}
}
- Edit Melan_Dx_musk.sh to set your embedding file paths:
TRAIN_EMBEDDING="/path/to/train_embeddings.pt"
VAL_EMBEDDING="/path/to/val_embeddings.pt"
TEST_EMBEDDING="/path/to/test_embeddings.pt"
KNOWLEDGE_EMBEDDING="/path/to/knowledge_embeddings.pt"
SAVE_DIR="output_model"
- Run the script:
bash Melan_Dx_musk.sh
Alternatively, run train_model.py directly:
python train_model.py \
--config config/melandx_musk_config.json \
--train_embedding /path/to/train_embeddings.pt \
--val_embedding /path/to/val_embeddings.pt \
--test_embedding /path/to/test_embeddings.pt \
--knowledge_embedding /path/to/knowledge_embeddings.pt \
--tree_json_path config/who_44_classes_tree.json \
--loss_type basic \
--learning_rates 1e-5 1e-4 1e-3 \
--save_dir output_model
After training, the following files are generated in {SAVE_DIR}/:
{SAVE_DIR}/
├── best_model_lr_1e_5.pth # Best model for each learning rate
├── val_metrics_lr_1e_5.csv # Validation metrics per epoch
├── test_metrics_lr_1e_5.csv # Test metrics per epoch
└── predictions/ # Prediction results
├── val_predictions_epoch_X_lr_1e_5.npz
└── test_predictions_epoch_X_lr_1e_5.npz
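To confirm that a run produced everything listed above for each learning rate, a small check over SAVE_DIR can help. This is a hypothetical sketch, not repository code. It assumes the filename suffix is the --learning_rates value with "-" replaced by "_", which is inferred from the example name best_model_lr_1e_5.pth for 1e-5 and may not match every value the script accepts.

```python
from pathlib import Path
from typing import Dict, List


def lr_suffix(lr: str) -> str:
    # Assumption: suffix = --learning_rates value with "-" -> "_"
    # (inferred from best_model_lr_1e_5.pth for 1e-5).
    return lr.replace("-", "_")


def artifact_status(save_dir: str, learning_rates: List[str]) -> Dict[str, Dict[str, List[Path]]]:
    """For each learning rate, list missing model/metric files and found prediction files."""
    root = Path(save_dir)
    status: Dict[str, Dict[str, List[Path]]] = {}
    for lr in learning_rates:
        suffix = lr_suffix(lr)
        required = [
            root / f"best_model_lr_{suffix}.pth",
            root / f"val_metrics_lr_{suffix}.csv",
            root / f"test_metrics_lr_{suffix}.csv",
        ]
        # Prediction files carry an epoch number, so match them with a glob pattern.
        predictions = sorted((root / "predictions").glob(f"*_predictions_epoch_*_lr_{suffix}.npz"))
        status[lr] = {
            "missing": [path for path in required if not path.exists()],
            "predictions": predictions,
        }
    return status


if __name__ == "__main__":
    for lr, info in artifact_status("output_model", ["1e-5", "1e-4", "1e-3"]).items():
        print(lr, "missing:", [p.name for p in info["missing"]],
              "prediction files:", len(info["predictions"]))
```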
Pipeline overview:
Input Files
├── Embedding Files (.pt)
│ ├── embeddings (Tensor)
│ └── disease_names (List)
│
└── Hierarchy JSON
└── 3-level tree structure
↓
Automatic Data Structure Construction
├── train_data
│ ├── paths: Placeholder list
│ ├── disease_names: From embedding file
│ ├── disease_to_parent: Built from JSON
│ └── parent_to_grandparent: Built from JSON
│
├── val_data, test_data
│ ├── paths: Placeholder list
│ └── disease_names: From embedding file
│
└── knowledge_data
├── texts: Placeholder list
└── disease_names: From embedding file
↓
Training Loop
├── Initialize MainModel
├── Initialize ModelTrainer
└── Start training iterations
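The diagram shows disease_to_parent and parent_to_grandparent being built from the hierarchy JSON. The sketch below illustrates one way such maps could be derived from the 3-level tree, plus a check that every disease name in an embedding file has a parent. The function names are hypothetical and this is not the repository's implementation.

```python
import json
from typing import Dict, Iterable, List, Tuple

Tree = Dict[str, Dict[str, List[str]]]


def load_tree(path: str) -> Tree:
    """Read a hierarchy JSON such as config/who_44_classes_tree.json."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def build_hierarchy_maps(tree: Tree) -> Tuple[Dict[str, str], Dict[str, str]]:
    """Return (disease_to_parent, parent_to_grandparent) for a 3-level tree."""
    disease_to_parent: Dict[str, str] = {}
    parent_to_grandparent: Dict[str, str] = {}
    for grandparent, parents in tree.items():
        if not isinstance(parents, dict):
            raise ValueError(f"{grandparent!r}: expected an object of parent categories")
        for parent, diseases in parents.items():
            if not isinstance(diseases, list):
                raise ValueError(f"{parent!r}: expected a list of disease names")
            # A parent category should sit under exactly one grandparent.
            if parent_to_grandparent.get(parent, grandparent) != grandparent:
                raise ValueError(f"parent {parent!r} appears under several grandparents")
            parent_to_grandparent[parent] = grandparent
            for disease in diseases:
                if disease in disease_to_parent:
                    raise ValueError(f"disease {disease!r} is listed more than once")
                disease_to_parent[disease] = parent
    return disease_to_parent, parent_to_grandparent


def unmapped_diseases(names: Iterable[str], disease_to_parent: Dict[str, str]) -> List[str]:
    """Disease names from an embedding file that have no entry in the tree."""
    return sorted(set(names).difference(disease_to_parent))
```

A non-empty result from unmapped_diseases usually points to a spelling mismatch between the disease_names stored in the .pt files and the leaf names in the JSON.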
If you have a single merged embedding file, split it into train/val/test:
# Basic usage (70/15/15 split)
python split_embeddings.py merged_all_embeddings.pt ./split_output
# Custom split ratios
python split_embeddings.py merged_all_embeddings.pt ./split_output \
--train_ratio 0.8 --val_ratio 0.1 --test_ratio 0.1 --seed 42
Output:
./split_output/
├── train_embeddings.pt
├── val_embeddings.pt
└── test_embeddings.pt
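Stratified splitting keeps each disease's proportion roughly equal across train, validation and test. The internals of split_embeddings.py are not described here; the sketch below is a minimal, hypothetical illustration of the idea using Python's random module: group sample indices by disease label, shuffle each group with a fixed seed, then slice each group by the requested ratios.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence


def stratified_split(
    labels: Sequence[str],
    train_ratio: float = 0.7,
    val_ratio: float = 0.15,
    test_ratio: float = 0.15,
    seed: int = 42,
) -> Dict[str, List[int]]:
    """Return train/val/test index lists with per-label proportions preserved."""
    if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-6:
        raise ValueError("ratios must sum to 1")
    rng = random.Random(seed)

    # Group sample indices by disease label.
    by_label: Dict[str, List[int]] = defaultdict(list)
    for idx, label in enumerate(labels):
        by_label[label].append(idx)

    splits: Dict[str, List[int]] = {"train": [], "val": [], "test": []}
    for label in sorted(by_label):  # sorted so the result depends only on the seed
        idxs = by_label[label]
        rng.shuffle(idxs)
        n_train = round(len(idxs) * train_ratio)
        n_val = round(len(idxs) * val_ratio)
        splits["train"] += idxs[:n_train]
        splits["val"] += idxs[n_train:n_train + n_val]
        splits["test"] += idxs[n_train + n_val:]  # remainder goes to test
    return {name: sorted(idxs) for name, idxs in splits.items()}


if __name__ == "__main__":
    labels = ["Dysplastic naevus"] * 10 + ["Simple lentigo and lentiginous melanocytic naevus"] * 4
    print({name: len(idxs) for name, idxs in stratified_split(labels).items()})
```

With very rare classes, per-class rounding can leave a disease with no validation or test samples, so the class counts are worth checking after splitting. Applying the indices to a real file would require PyTorch, for example by indexing the embeddings tensor with the train index list and selecting the matching disease_names.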
Edit Melan_Dx_musk.sh to point to your split embeddings:
TRAIN_EMBEDDING="./split_output/train_embeddings.pt"
VAL_EMBEDDING="./split_output/val_embeddings.pt"
TEST_EMBEDDING="./split_output/test_embeddings.pt"
KNOWLEDGE_EMBEDDING="/path/to/knowledge_embeddings.pt"
Complete workflow:
# 1. Split embeddings (if needed)
python split_embeddings.py merged_all_embeddings.pt ./split_output
# 2. Start training
bash Melan_Dx_musk.sh
# 3. Monitor training progress (if using WandB)
# Open the WandB run link in your browser