In [1]:
import json

from pymongo import MongoClient

# Load the DDXPlus splits from a local MongoDB instance (not from a JSONL file).
client = MongoClient("mongodb://localhost:27017/")
db = client["ddxplus"]
collection_train = db["train-semistructured"]
collection_val = db["validate-semistructured"]
collection_test = db["test-semistructured"]

# Key of the field that holds the list of diagnoses in each sample
# (visible in the example sample printed at the end of this cell).
TARGET_KEY = "DIFFERENTIAL_DIAGNOSIS_NOPROB"

# Projection shared by all three splits: a value of 0 asks MongoDB to omit
# that field from every returned document.
EXCLUDED_FIELDS = {
    "_id": 0,
    "PATHOLOGY": 0,
    "EVIDENCES": 0,
    "EVIDENCES_JSON_V2": 0,
    "DIFFERENTIAL_DIAGNOSIS": 0,
}


def load_split(collection):
    """Fetch every document of ``collection`` into an in-memory list.

    Args:
        collection: A pymongo collection to read from.

    Returns:
        list: All documents of the collection, queried with the
        ``EXCLUDED_FIELDS`` projection. The whole cursor is materialized,
        so the full split is held in memory.
    """
    return list(collection.find({}, EXCLUDED_FIELDS))


train_data = load_split(collection_train)
val_data = load_split(collection_val)
test_data = load_split(collection_test)

print(f"Train samples: {len(train_data)}")
print(f"Validation samples: {len(val_data)}")
print(f"Test samples: {len(test_data)}")

# show one train sample
print("Example train sample:")
print(json.dumps(train_data[0], indent=2))

Train samples: 1025602
Validation samples: 132448
Test samples: 134529
Example train sample:
{
  "AGE": 18,
  "SEX": "M",
  "INITIAL_EVIDENCE": "E_91",
  "EVIDENCES_JSON_V1": {
    "E_48": [],
    "E_50": [],
    "E_53": [],
    "E_54": [
      "V_161",
      "V_183"
    ],
    "E_55": [
      "V_89",
      "V_108",
      "V_167"
    ],
    "E_56": [
      "4"
    ],
    "E_57": [
      "V_123"
    ],
    "E_58": [
      "3"
    ],
    "E_59": [
      "3"
    ],
    "E_77": [],
    "E_79": [],
    "E_91": [],
    "E_97": [],
    "E_201": [],
    "E_204": [
      "V_10"
    ],
    "E_222": []
  },
  "DIFFERENTIAL_DIAGNOSIS_NOPROB": [
    "Bronchitis",
    "Pneumonia",
    "URTI",
    "Bronchiectasis",
    "Tuberculosis",
    "Influenza",
    "HIV (initial infection)",
    "Chagas"
  ]
}


In [2]:
import random

import torch

from origami import DataConfig, ModelConfig, OrigamiConfig, OrigamiPipeline, TrainingConfig
from origami.training import TableLogCallback, array_f1, array_jaccard

# Seed Python's `random` module and PyTorch so repeated runs start from the
# same random state.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Build each configuration section separately, then combine them.
data_config = DataConfig(numeric_mode="none")

model_config = ModelConfig(
    backbone="cached_transformer",
    d_model=256,
    n_layers=6,
    n_heads=8,
)

training_config = TrainingConfig(
    learning_rate=0.001,
    # Evaluate every 100 steps on a sample of size 100.
    eval_strategy="steps",
    eval_steps=100,
    eval_sample_size=100,
    eval_metrics={"jaccard": array_jaccard, "f1": array_f1},
    shuffle_keys=True,
    batch_size=100,
    target_key=TARGET_KEY,
    target_loss_weight=1000.0,
)

config = OrigamiConfig(
    data=data_config,
    model=model_config,
    training=training_config,
)

pipeline = OrigamiPipeline(config)
# Callback passed to `pipeline.fit` for tabular progress logging.
callback = TableLogCallback(print_every=10)

In [3]:
# Fit on the training split for one epoch, using the validation split for
# evaluation and the table-logging callback for progress output.
pipeline.fit(
    train_data,
    eval_data=val_data,
    callbacks=[callback],
    epochs=1,
    verbose=True,
)

Vocabulary size: 668
Model parameters: 3,571,356
Training device: mps
| step: 10 | epoch: 0 | lr: 1.00e-05 | batch_dt: 323ms | loss: 6.1953 |
| step: 20 | epoch: 0 | lr: 2.00e-05 | batch_dt: 247ms | loss: 5.9347 |
| step: 30 | epoch: 0 | lr: 3.00e-05 | batch_dt: 341ms | loss: 5.5404 |
| step: 40 | epoch: 0 | lr: 4.00e-05 | batch_dt: 288ms | loss: 5.1551 |
| step: 50 | epoch: 0 | lr: 5.00e-05 | batch_dt: 278ms | loss: 4.7891 |
| step: 60 | epoch: 0 | lr: 6.00e-05 | batch_dt: 246ms | loss: 4.3854 |
| step: 70 | epoch: 0 | lr: 7.00e-05 | batch_dt: 264ms | loss: 4.1454 |
| step: 80 | epoch: 0 | lr: 8.00e-05 | batch_dt: 233ms | loss: 3.8173 |
| step: 90 | epoch: 0 | lr: 9.00e-05 | batch_dt: 247ms | loss: 3.6044 |
| step: 100 | epoch: 0 | lr: 1.00e-04 | batch_dt: 252ms | loss: 3.3144 | val_f1: 0.1136 | val_jaccard: 0.0653 | val_loss: 4.2379 |
| step: 110 | epoch: 0 | lr: 1.10e-04 | batch_dt: 258ms | loss: 3.1417 |
| step: 120 | epoch: 0 | lr: 1.20e-04 | batch_dt: 311ms | loss: 3.0290 |
| ste

OrigamiPipeline(numeric_mode='none', fitted)

In [None]:
# Persist the fitted pipeline to a file in the current working directory.
pipeline.save("ddxplus_origami_pipeline.pt")

In [None]:
# NOTE(review): OrigamiPipeline is already imported in the configuration cell;
# presumably it is re-imported here so this cell can run without the training
# cells. Later cells still need `test_data` and `TARGET_KEY` from the
# data-loading cell.
from origami import OrigamiPipeline

# Replace `pipeline` with the one stored at the path used by `pipeline.save`.
pipeline = OrigamiPipeline.load("ddxplus_origami_pipeline.pt")

In [9]:
from origami.training import array_f1, array_precision, array_recall

# Metrics to report, keyed by the name they get in the result.
test_metrics = {"f1": array_f1, "precision": array_precision, "recall": array_recall}

# Evaluate on the first 1000 test samples only; the call is left as the last
# expression so the notebook displays its result.
pipeline.evaluate(test_data[:1000], metrics=test_metrics, batch_size=256, verbose=True)

Computing loss:   0%|          | 0/4 [00:00<?, ?it/s]

Predicting:   0%|          | 0/4 [00:00<?, ?it/s]

{'loss': 1.0865601003170013,
 'f1': 0.8926597943234562,
 'precision': 0.9153976904929064,
 'recall': 0.8936634039425163}

In [None]:
# Number of leading test samples to predict and print side by side.
N_PREVIEW = 100

# Slice once so predictions and ground truth are guaranteed to come from the
# same samples (previously the slice was written out twice).
preview_samples = test_data[:N_PREVIEW]

preds = pipeline.predict_batch(
    preview_samples,
    target_key=TARGET_KEY,
    allow_complex_values=True,
    batch_size=256,
    profile=False,
)

actuals = [sample[TARGET_KEY] for sample in preview_samples]

# strict=True makes zip raise ValueError if the two lists differ in length.
# Both lists are sorted before printing so they are easy to compare by eye.
for pred, actual in zip(preds, actuals, strict=True):
    print(f"Predicted: {sorted(pred)}\nActual: {sorted(actual)}\n")

Predicted: ['Anaphylaxis', 'Boerhaave', 'Bronchitis', 'GERD', 'Pericarditis', 'Possible NSTEMI / STEMI', 'Spontaneous rib fracture']
Actual: ['Anemia', 'Boerhaave', 'Bronchitis', 'GERD', 'Pericarditis', 'Possible NSTEMI / STEMI', 'Stable angina', 'Unstable angina']

Predicted: ['Acute dystonic reactions', 'Acute laryngitis', 'Allergic sinusitis', 'Anemia', 'Atrial fibrillation', 'Bronchitis', 'Bronchospasm / acute asthma exacerbation', 'Chagas', 'Croup', 'Guillain-Barré syndrome', 'Myasthenia gravis', 'Myocarditis', 'PSVT', 'Pneumonia', 'Sarcoidosis', 'Tuberculosis', 'URTI', 'Viral pharyngitis']
Actual: ['Acute dystonic reactions', 'Acute laryngitis', 'Allergic sinusitis', 'Anaphylaxis', 'Anemia', 'Atrial fibrillation', 'Bronchitis', 'Bronchospasm / acute asthma exacerbation', 'Chagas', 'Croup', 'Guillain-Barré syndrome', 'Influenza', 'Myasthenia gravis', 'Myocarditis', 'PSVT', 'Pneumonia', 'SLE', 'Sarcoidosis', 'Scombroid food poisoning', 'Spontaneous pneumothorax', 'Tuberculosis', 'U