# Mistral Entity Taxonomy Classifier

Step-by-step implementation following the Mistral Classifier Factory cookbook.
This notebook trains a multi-label classifier for entity taxonomy classification.

## Step 1: Setup and Configuration

In [None]:
# Install required packages
!pip install mistralai pandas matplotlib seaborn wandb

In [3]:
import os
import json
import pandas as pd
import time
import random
from pathlib import Path
from mistralai import Mistral

# Paths: inputs are read from DATA_DIR, generated artifacts are written to OUTPUT_DIR
DATA_DIR = Path("../data")
OUTPUT_DIR = Path("../data/output")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Credentials are taken from the environment; nothing is hard-coded here
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
WANDB_API_KEY = os.getenv("WANDB_API_KEY")

# Report which keys are available (a missing key only warns, it does not stop the cell)
if MISTRAL_API_KEY:
    print("✅ Mistral API key found")
else:
    print("⚠️ Please set MISTRAL_API_KEY:")
    print("os.environ['MISTRAL_API_KEY'] = 'your_key_here'")

if WANDB_API_KEY:
    print("✅ Weights & Biases API key found")
else:
    print("⚠️ Please set WANDB_API_KEY for experiment tracking:")
    print("os.environ['WANDB_API_KEY'] = 'your_key_here'")

✅ Mistral API key found
✅ Weights & Biases API key found


In [4]:
# Build the API client once; later cells reuse this module-level `client`
client = Mistral(api_key=MISTRAL_API_KEY)
print("🤖 Mistral client initialized")

🤖 Mistral client initialized


## Step 2: Load and Prepare Data

In [None]:
# Read the entity table (CSV) ...
entity_df = pd.read_csv(DATA_DIR / "input" / "training_dataset.csv")
print(f"Loaded {len(entity_df)} entities")

# ... and the classification records (JSON object keyed by person ID)
classifications_path = DATA_DIR / "input" / "parallel_classifications.json"
with open(classifications_path, 'r', encoding='utf-8') as f:
    classifications = json.load(f)
print(f"Loaded {len(classifications)} classifications")

# Preview the first entity row
print("\n📊 Sample entity:")
preview_columns = ['personId', 'person', 'composite']
print(entity_df[preview_columns].head(1))

# Preview the first classification record
print("\n📊 Sample classification:")
sample_key = list(classifications)[0]
sample = classifications[sample_key]
for caption, value in (
    ("PersonID", sample_key),
    ("Labels", sample['label']),
    ("Paths", sample['path']),
):
    print(f"{caption}: {value}")

In [None]:
# Create entity lookup: personId (as a string) -> composite text.
# Built column-wise instead of with DataFrame.iterrows(), which constructs a
# Series per row and is needlessly slow for a plain two-column mapping.
# IDs are stringified because they are looked up by string person ID later on.
entity_lookup = dict(zip(entity_df['personId'].astype(str), entity_df['composite']))

print(f"Created entity lookup for {len(entity_lookup)} entities")

In [None]:
# Convert to Mistral format
training_examples = []
skipped_examples = 0  # records dropped for lack of usable text or labels

for person_id, classification_data in classifications.items():
    # Get composite text. A blank CSV cell is loaded by pandas as float NaN,
    # which is truthy, so check for a non-blank string explicitly.
    composite_text = entity_lookup.get(person_id)
    if not isinstance(composite_text, str) or not composite_text.strip():
        skipped_examples += 1
        continue

    # Extract labels and paths (`or []` also covers an explicit null in the JSON)
    labels_list = classification_data.get('label') or []
    paths_list = classification_data.get('path') or []

    if not labels_list:
        skipped_examples += 1
        continue

    # Parent category = first segment of a "Parent > Child" path.
    # Several paths can share a parent, so de-duplicate while keeping order.
    parent_categories = list(dict.fromkeys(
        path.split(" > ")[0] for path in paths_list if " > " in path
    ))

    # Create training example in Mistral format
    training_examples.append({
        "text": composite_text,
        "labels": {
            "domain": labels_list,  # Multi-label list
            "parent_category": parent_categories
        }
    })

print(f"Created {len(training_examples)} training examples")
print(f"Skipped {skipped_examples} records without usable text or labels")

# Show sample (guarded: indexing an empty list would raise IndexError)
if training_examples:
    print("\n📝 Sample training example:")
    sample_ex = training_examples[0]
    print(f"Text: {sample_ex['text'][:100]}...")
    print(f"Domains: {sample_ex['labels']['domain']}")
    print(f"Parents: {sample_ex['labels']['parent_category']}")
else:
    print("\n⚠️ No training examples were created; check that the person IDs in both input files match")

In [None]:
# Split data (80% train, 20% validation)
# Shuffle a copy rather than `training_examples` itself: shuffling in place
# would reshuffle already-shuffled data every time this cell is re-run and
# silently change the split. With a copy, the same seed always gives the same split.
random.seed(42)
shuffled_examples = list(training_examples)
random.shuffle(shuffled_examples)

split_idx = int(len(shuffled_examples) * 0.8)
train_examples = shuffled_examples[:split_idx]
val_examples = shuffled_examples[split_idx:]

print(f"Training set: {len(train_examples)} examples")
print(f"Validation set: {len(val_examples)} examples")

In [None]:
# Save to JSONL files
def save_jsonl(examples, filepath):
    """Write `examples` to `filepath` as JSON Lines.

    Each example becomes one line of JSON; the file is UTF-8 encoded and
    non-ASCII characters are written as-is rather than escaped. Prints how
    many examples were written.
    """
    serialized_lines = (
        json.dumps(record, ensure_ascii=False) + '\n' for record in examples
    )
    with open(filepath, 'w', encoding='utf-8') as handle:
        handle.writelines(serialized_lines)
    print(f"Saved {len(examples)} examples to {filepath}")

train_path = OUTPUT_DIR / "mistral_train.jsonl"
val_path = OUTPUT_DIR / "mistral_val.jsonl"

# Persist both splits, training set first
for split_examples, split_path in (
    (train_examples, train_path),
    (val_examples, val_path),
):
    save_jsonl(split_examples, split_path)

print("✅ Data preparation complete!")

## Step 3: Upload Training Files

In [None]:
# Upload the training data.
# The files are opened in `with` blocks so the handles are closed as soon as
# each upload call returns (previously they were left open).
print("📤 Uploading training data...")
with open(train_path, "rb") as train_file:
    training_data = client.files.upload(
        file={
            "file_name": "mistral_train.jsonl",
            "content": train_file,
        }
    )
print(f"✅ Training file uploaded: {training_data.id}")

# Upload the validation data
print("📤 Uploading validation data...")
with open(val_path, "rb") as val_file:
    validation_data = client.files.upload(
        file={
            "file_name": "mistral_val.jsonl",
            "content": val_file,
        }
    )
print(f"✅ Validation file uploaded: {validation_data.id}")

print("\n📋 File IDs:")
print(f"Training: {training_data.id}")
print(f"Validation: {validation_data.id}")

## Step 4: Create Fine-tuning Job

In [None]:
# Create fine-tuning job with W&B integration
print("🚀 Creating fine-tuning job...")

job_kwargs = {
    "model": "ministral-3b-latest",
    "job_type": "classifier",
    # File IDs come from the upload responses `training_data` / `validation_data`
    "training_files": [{"file_id": training_data.id, "weight": 1}],
    "validation_files": [validation_data.id],
    "hyperparameters": {
        "training_steps": 250,
        "learning_rate": 0.00007,
    },
    "auto_start": True,
}

# Attach the W&B integration only when a key is configured; sending
# `api_key=None` would submit an unusable integration.
if WANDB_API_KEY:
    job_kwargs["integrations"] = [
        {
            "project": "entity_resolver",
            # Timestamp suffix keeps run names distinct across repeated runs
            "name": f"mistral-entity-classifier-{int(time.time())}",
            "api_key": WANDB_API_KEY,
            "tags": ["mistral", "entity-resolution", "multilabel", "taxonomy"]
        }
    ]
else:
    print("ℹ️ WANDB_API_KEY not set; creating the job without W&B tracking")

created_jobs = client.fine_tuning.jobs.create(**job_kwargs)

print("✅ Fine-tuning job created!")
print(json.dumps(created_jobs.model_dump(), indent=2))



In [None]:
# Keep the job ID in its own variable for the cells that start and monitor the job
job_id = created_jobs.id
print(f"✅ Training job created: {job_id}")

job_details = created_jobs.model_dump()
print(json.dumps(job_details, indent=2))

## Step 5: Start and Monitor Training

In [None]:
# Start the training job — only when it is actually waiting for a manual start.
# The job is created with auto_start=True, in which case the service launches it
# by itself and an unconditional start call is redundant (and may be rejected).
# NOTE(review): "VALIDATED" is taken to be the status of a job that passed
# validation and awaits a manual start — confirm against the Mistral job-status docs.
current_job = client.fine_tuning.jobs.get(job_id=job_id)

if current_job.status == "VALIDATED":
    print(f"▶️ Starting training job {job_id}...")
    started_job = client.fine_tuning.jobs.start(job_id=job_id)
    print(f"✅ Training started! Status: {started_job.status}")
else:
    # Keep `started_job` defined so later references still work
    started_job = current_job
    print(f"ℹ️ Manual start skipped for job {job_id} (current status: {current_job.status})")

In [None]:
# Monitor training progress
def monitor_training(job_id, check_interval=60):
    """Poll a fine-tuning job until it reaches a terminal state.

    Args:
        job_id: ID of the fine-tuning job to watch.
        check_interval: Seconds to sleep between two status checks.

    Returns:
        The fine-tuned model ID when the job ends with status "SUCCESS".
        None when the job fails, is cancelled, or monitoring is interrupted
        with Ctrl+C (the job itself is not touched by the interrupt).

    Side effects:
        On success, writes model_info.json and model_id.txt into OUTPUT_DIR.
    """
    # Terminal statuses that produce no model. Without "FAILED_VALIDATION" and
    # "CANCELLED" the loop would keep polling a dead job forever.
    # NOTE(review): status names follow Mistral's documented job statuses — confirm.
    failure_statuses = {"FAILED", "FAILED_VALIDATION", "CANCELLED"}
    active_statuses = {
        "QUEUED", "STARTED", "VALIDATING", "VALIDATED", "RUNNING",
        "CANCELLATION_REQUESTED",
    }

    print(f"👀 Monitoring job {job_id}...")
    print(f"Checking every {check_interval} seconds. Press Ctrl+C to stop monitoring.")

    start_time = time.time()

    try:
        while True:
            job = client.fine_tuning.jobs.get(job_id=job_id)
            status = job.status
            elapsed = time.time() - start_time

            timestamp = time.strftime('%H:%M:%S')
            print(f"[{timestamp}] Status: {status} (Elapsed: {elapsed/60:.1f}m)")

            if status == "SUCCESS":
                model_id = job.fine_tuned_model
                print("\n🎉 Training completed successfully!")
                print(f"Model ID: {model_id}")
                print(f"Total time: {elapsed/60:.1f} minutes")

                # Save model info
                model_info = {
                    "job_id": job_id,
                    "model_id": model_id,
                    "status": status,
                    "training_time_minutes": elapsed/60,
                    "completed_at": time.strftime('%Y-%m-%d %H:%M:%S')
                }

                with open(OUTPUT_DIR / "model_info.json", 'w', encoding='utf-8') as f:
                    json.dump(model_info, f, indent=2)

                # f.write(None) would raise, so only persist a real model ID
                if model_id:
                    with open(OUTPUT_DIR / "model_id.txt", 'w', encoding='utf-8') as f:
                        f.write(model_id)

                return model_id

            if status in failure_statuses:
                print(f"\n❌ Training failed (status: {status})")
                message = getattr(job, 'message', None)
                if message:
                    print(f"Error: {message}")
                return None

            if status in active_statuses:
                print("   Training in progress...")
            else:
                # Unrecognised status: keep polling rather than guess its meaning
                print(f"   Unknown status: {status}")
            time.sleep(check_interval)

    except KeyboardInterrupt:
        print("\n⏸️ Monitoring stopped. Training continues in background.")
        print(f"Job ID: {job_id}")
        return None

# Poll the job with the default interval and keep the resulting model ID
# (None unless the job finished successfully)
model_id = monitor_training(job_id=job_id)

In [None]:
# Fetch the job's current state from the API and pretty-print it
retrieved_job = client.fine_tuning.jobs.get(job_id=created_jobs.id)
retrieved_job_json = json.dumps(retrieved_job.model_dump(), indent=4)
print(retrieved_job_json)

In [None]:
# Cancel the fine-tuning job — opt-in only.
# Left unconditional, this cell would issue a cancel request on every
# "Run All", right after the monitoring cell.
CANCEL_JOB = False  # set to True and re-run this cell to cancel the job

canceled_jobs = None
if CANCEL_JOB:
    canceled_jobs = client.fine_tuning.jobs.cancel(job_id=created_jobs.id)
    print(canceled_jobs)
else:
    print("⏭️ Cancellation skipped (set CANCEL_JOB = True to cancel the job)")

In [None]:
# List fine-tuning jobs (no filters applied) and pretty-print the response
jobs = client.fine_tuning.jobs.list()
jobs_json = json.dumps(jobs.model_dump(), indent=4)
print(jobs_json)

## Step 6: Test the Trained Model

In [None]:
# Load the model ID saved by the monitoring step, if there is one
model_id_path = OUTPUT_DIR / "model_id.txt"
if not model_id_path.exists():
    print("⚠️ No trained model found. Please complete training first.")
    model_id = None
else:
    model_id = model_id_path.read_text().strip()
    print(f"📖 Loaded model ID: {model_id}")

In [11]:
# Test classification function
def classify_text(text, model_id):
    """Classify a single text using the trained model.

    Returns whatever `client.classifiers.classify` returns for a one-element
    input list. If the call raises, the error is printed (not re-raised) and
    an empty dict is returned, so callers can test the result for truthiness.
    """
    request = {"model": model_id, "inputs": [text]}
    try:
        return client.classifiers.classify(**request)
    except Exception as exc:
        print(f"Error: {exc}")
        return {}

# Prefer the model ID produced or loaded earlier in this session. The
# previously trained model below is only a fallback; assigning it
# unconditionally would silently discard a freshly trained model.
# Set FALLBACK_MODEL_ID = None to skip the tests when no model is available.
FALLBACK_MODEL_ID = "ft:classifier:ministral-3b-latest:2bec22ef:20250623:ff14496d"
model_id = globals().get("model_id") or FALLBACK_MODEL_ID

if model_id:
    print("🧪 Testing model with sample texts...")

    # Test with some examples
    test_texts = [
        "Roles: Contributor\nTitle: Quartette für zwei Violinen, Viola, Violoncell\nSubjects: String quartets--Scores",
        "Roles: Contributor\nTitle: John Wesley's Sunday service of the Methodists\nSubjects: Methodist Church--Liturgy--Texts",
        "Roles: Contributor\nTitle: The owl of Minerva: poems\nAttribution: by James Laughlin",
        "Roles: Creator\nTitle: Archaeology and photography : the early years, 1868-1880\nAttribution: [text] Ismeth Raheem\nSubjects: Photography in archaeology\nProvision information: Colombo : National Trust Sri Lanka, 2009",
        "Roles: Creator\nTitle: Shakespeare, education and pedagogy : representations, interactions and adaptations\nAttribution: edited by Pamela Bickley and Jenny Stevens\nSubjects: Shakespeare, William, 1564-1616--Criticism and interpretation; Shakespeare, William, 1564-1616--Study and teaching; Education in literature; Education\nProvision information: Abingdon, Oxon ; New York, NY : Routledge, Taylor & Francis Group, 2023"
    ]

    for i, text in enumerate(test_texts, 1):
        print(f"\n📝 Test {i}:")
        print(f"Text: {text[:60]}...")

        # classify_text returns a falsy value when the API call failed
        classification = classify_text(text, model_id)
        if classification:
            print("Classifier Response:", json.dumps(classification.model_dump(), indent=4))
        else:
            print("Classification failed")
else:
    print("⏭️ Skipping tests (no trained model)")

🧪 Testing model with sample texts...

📝 Test 1:
Text: Roles: Contributor
Title: Quartette für zwei Violinen, Viola...
Classifier Response: {
    "id": "b189e1ddb3684911afb67502fec639a2",
    "model": "ft:classifier:ministral-3b-latest:2bec22ef:20250623:ff14496d",
    "results": [
        {
            "domain": {
                "scores": {
                    "Politics, Policy, and Government": 1.0128285794053227e-05,
                    "Music, Sound, and Sonic Arts": 0.9998216032981873,
                    "Philosophy and Ethics": 5.092821993457619e-06,
                    "History, Heritage, and Memory": 3.392550979697262e-06,
                    "Media, Journalism, and Communication": 9.420773494639434e-07,
                    "Mathematics and Quantitative Sciences": 2.4056787424342474e-06,
                    "Law, Justice, and Jurisprudence": 3.611351530707907e-06,
                    "Performing Arts and Media": 1.952278398675844e-05,
                    "Military, Security, an

## Step 7: Evaluate on Validation Set

In [None]:
# Evaluate model performance
EVAL_SAMPLE_SIZE = 50    # cap on validation examples scored, to keep the cell fast
SCORE_THRESHOLD = 0.5    # assumed cut-off for counting a label as predicted — tune as needed


def predicted_labels(response, target="domain", threshold=SCORE_THRESHOLD):
    """Return the set of `target` labels whose score is at least `threshold`.

    `response` is what classify_text returns: the classifier response object,
    or an empty dict after a failed call. The layout read here,
    results[0][target]["scores"] mapping label -> score, is the one shown in
    the printed response of the Step 6 test cell.
    """
    if not response:
        return set()
    results = response.model_dump().get("results") or []
    if not results:
        return set()
    scores = (results[0].get(target) or {}).get("scores") or {}
    return {label for label, score in scores.items() if score >= threshold}


if model_id:
    eval_subset = val_examples[:EVAL_SAMPLE_SIZE]
    tested = len(eval_subset)
    print(f"📊 Evaluating model on {tested} of {len(val_examples)} validation examples...")

    correct_exact = 0
    correct_partial = 0
    failed_requests = 0  # API calls that returned nothing; these count as misses

    for i, example in enumerate(eval_subset):
        if i % 10 == 0:
            print(f"  Progress: {i}/{tested}")

        true_domains = set(example['labels']['domain'])

        classification = classify_text(example['text'], model_id)
        if classification:
            # The response object has no .get(); labels are derived from its scores
            pred_domains = predicted_labels(classification)

            # Exact match
            if true_domains == pred_domains:
                correct_exact += 1

            # Partial match (any overlap)
            if true_domains & pred_domains:
                correct_partial += 1
        else:
            failed_requests += 1

        time.sleep(0.1)  # Rate limiting

    # Guard against an empty validation set (division by zero)
    exact_accuracy = correct_exact / tested if tested else 0.0
    partial_accuracy = correct_partial / tested if tested else 0.0

    print(f"\n📈 Evaluation Results (on {tested} examples):")
    print(f"Exact match accuracy: {exact_accuracy:.3f} ({correct_exact}/{tested})")
    print(f"Partial match accuracy: {partial_accuracy:.3f} ({correct_partial}/{tested})")
    if failed_requests:
        print(f"Failed classification requests: {failed_requests}")

    # Save results
    eval_results = {
        "model_id": model_id,
        "tested_examples": tested,
        "exact_accuracy": exact_accuracy,
        "partial_accuracy": partial_accuracy,
        "failed_requests": failed_requests,
        "score_threshold": SCORE_THRESHOLD,
        "evaluated_at": time.strftime('%Y-%m-%d %H:%M:%S')
    }

    with open(OUTPUT_DIR / "evaluation_results.json", 'w', encoding='utf-8') as f:
        json.dump(eval_results, f, indent=2)

    print(f"\n💾 Results saved to {OUTPUT_DIR / 'evaluation_results.json'}")

else:
    print("⏭️ Skipping evaluation (no trained model)")

## Step 8: Integration with Entity Resolution Pipeline

In [None]:
# Create taxonomy dissimilarity calculator for entity resolution
class MistralTaxonomyFeature:
    """Mistral-based taxonomy feature for entity resolution.

    Classifies entities through the module-level `classify_text` helper,
    caches the answers per person ID, and scores a pair of entities by the
    Jaccard dissimilarity of their predicted domain labels.
    """

    def __init__(self, model_id, entity_lookup, score_threshold=0.5):
        """Store the model ID, the person ID -> text lookup and the score cut-off.

        `score_threshold` is the minimum score for a label to count as
        predicted; 0.5 is an assumed default — tune it for your model.
        """
        self.model_id = model_id
        self.entity_lookup = entity_lookup
        self.score_threshold = score_threshold
        self.cache = {}
        self.client = client

    def get_classification(self, person_id):
        """Get classification for a person ID with caching.

        Returns an empty dict when the ID has no usable text or the
        classifier call failed.
        """
        if person_id in self.cache:
            return self.cache[person_id]

        text = self.entity_lookup.get(person_id)
        # A missing value may be None or a pandas NaN (a truthy float)
        if not isinstance(text, str) or not text.strip():
            return {}

        classification = classify_text(text, self.model_id)
        # Cache successful answers only, so a transient API error is retried
        # on the next call instead of being remembered forever
        if classification:
            self.cache[person_id] = classification
        return classification

    def _domains(self, classification):
        """Return the predicted domain labels of a classification as a set.

        Accepts either a plain dict of the form {'domain': [labels]} or a
        response object exposing model_dump(), whose dump is read as
        results[0]['domain']['scores'] (label -> score) and thresholded.
        """
        if not classification:
            return set()
        if isinstance(classification, dict):
            return set(classification.get('domain', []))

        results = classification.model_dump().get("results") or []
        if not results:
            return set()
        scores = (results[0].get("domain") or {}).get("scores") or {}
        return {
            label for label, score in scores.items()
            if score >= self.score_threshold
        }

    def calculate_dissimilarity(self, person_id1, person_id2):
        """Calculate taxonomy dissimilarity between two entities.

        Returns 1 - Jaccard similarity of the two domain sets, or the
        neutral value 0.5 when either entity has no usable classification.
        """
        class1 = self.get_classification(person_id1)
        class2 = self.get_classification(person_id2)

        if not class1 or not class2:
            return 0.5  # Neutral dissimilarity

        domains1 = self._domains(class1)
        domains2 = self._domains(class2)

        if not domains1 or not domains2:
            return 0.5

        # Calculate Jaccard dissimilarity (the union is non-empty here)
        intersection = len(domains1 & domains2)
        union = len(domains1 | domains2)

        similarity = intersection / union
        return 1.0 - similarity

if not model_id:
    print("⏭️ Skipping integration demo (no trained model)")

else:
    # Create feature calculator
    taxonomy_feature = MistralTaxonomyFeature(model_id, entity_lookup)
    print("🔧 Taxonomy feature calculator created")

    # Try it on consecutive pairs drawn from the first five entity IDs
    print("\n🧪 Testing dissimilarity calculation:")
    sample_ids = list(entity_lookup.keys())[:5]
    consecutive_pairs = list(zip(sample_ids, sample_ids[1:]))[:3]  # at most three pairs

    for id1, id2 in consecutive_pairs:
        dissimilarity = taxonomy_feature.calculate_dissimilarity(id1, id2)
        print(f"  {id1} vs {id2}: {dissimilarity:.3f}")
        time.sleep(0.2)  # Rate limiting

    print("\n✅ Ready for pipeline integration!")
    print("Use taxonomy_feature.calculate_dissimilarity(id1, id2) in your entity resolution pipeline")

## Summary

This notebook has:

1. ✅ Loaded and prepared your entity taxonomy data
2. ✅ Converted to Mistral's multi-label format
3. ✅ Uploaded training files to Mistral
4. ✅ Created and started fine-tuning job
5. ✅ Monitored training progress
6. ✅ Tested the trained model
7. ✅ Evaluated performance
8. ✅ Created integration components

Your Mistral classifier is now ready to replace SetFit in your entity resolution pipeline!

### Next Steps:
- Replace `src/taxonomy_feature.py` with `MistralTaxonomyFeature`
- Update your pipeline configuration
- Test entity resolution performance
- Monitor API usage and costs