In [0]:
# imports
import requests
import time
from datetime import datetime, timedelta
from dateutil.relativedelta import relativedelta
from pyspark.sql import SparkSession, Row
from pyspark.sql.functions import col, concat_ws, regexp_replace, to_date, trim, lit, when, length, udf, broadcast
from pyspark.sql.types import StructType, StructField, StringType, TimestampType, IntegerType, FloatType, ArrayType
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass
from functools import partial
import sys
import traceback
import json
from langdetect import detect
import re
import os
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.cluster import DBSCAN
from sklearn.metrics.pairwise import cosine_similarity
from typing import List, Dict, Tuple, Set
import pandas as pd
import ast
from json_repair import repair_json

def get_exercises(api_url, timeout=30):
    '''Retrieve the complete list of exercises from the wger public API.

    Follows the paginated ``next`` links starting at *api_url* and concatenates
    the ``results`` of every page.

    Args:
        api_url: URL of the first page to fetch.
        timeout: Per-request timeout in seconds, so a stalled connection cannot
            hang the notebook indefinitely.

    Returns:
        List of raw exercise dicts. If a request or page decode fails, the error
        is printed and the exercises collected so far (possibly none) are
        returned, rather than ``None``, so callers can always iterate the result.
    '''
    all_exercises = []
    current_url = api_url

    try:
        while current_url:
            response = requests.get(current_url, timeout=timeout)
            response.raise_for_status()
            data = response.json()
            all_exercises.extend(data['results'])
            current_url = data.get('next')
    except Exception as e:
        print('error fetching exercise data from wger API:', e)
        print(f'returning {len(all_exercises)} exercises fetched before the error')
    return all_exercises

def save_data(df, table_name: str, mode: str = "overwrite"):
    """Persist a Spark DataFrame as a Delta table.

    Args:
        df: DataFrame to write.
        table_name: Target table name passed to ``saveAsTable``.
        mode: Save mode; in ``"overwrite"`` mode the table schema is replaced
            as well (``overwriteSchema=true``).
    """
    extra_options = {"overwriteSchema": "true"} if mode == "overwrite" else {}
    (
        df.write
        .format("delta")
        .mode(mode)
        .options(**extra_options)
        .saveAsTable(table_name)
    )

def is_english_content(description):
    """Return True unless *description* is detected as a non-English language.

    None, empty and whitespace-only descriptions count as English. Otherwise the
    first 100 characters are passed to ``langdetect.detect``. Any exception
    (including detection failure) makes the text count as non-English.
    """
    try:
        if not description or not description.strip():
            return True
        return detect(description[:100]) == 'en'
    except Exception:
        # detection failed (or the input is not text): treat as non-English
        return False
    
_HTML_TAG_RE = re.compile(r'<[^>]*>')


def _strip_html(text):
    """Return *text* with HTML tags removed and surrounding whitespace trimmed."""
    return _HTML_TAG_RE.sub('', text).strip()


def _parse_iso_datetime(value):
    """Parse an ISO-8601 timestamp, returning None when it is missing or malformed.

    A trailing 'Z' is rewritten to '+00:00' because ``datetime.fromisoformat``
    rejects the 'Z' suffix before Python 3.11.
    """
    if not value:
        return None
    try:
        return datetime.fromisoformat(value.replace('Z', '+00:00'))
    except (AttributeError, TypeError, ValueError):
        # One bad timestamp should not abort normalization of the whole dataset.
        return None


def normalize_record(record):
    """Flatten one raw wger ``exerciseinfo`` record into a row dict.

    The name/description come from the first translation with ``language == 2``
    (English) whose description passes ``is_english_content``. HTML tags are
    stripped from both, and the name is upper-cased. Nested category, muscle and
    equipment objects are reduced to their ``name`` values.

    Args:
        record: Raw exercise dict as returned by the API.

    Returns:
        A dict with keys id, uuid, name, description, created, last_update,
        category, muscles, muscles_secondary, equipment, variations and
        license_author, or None when no usable English translation exists.
    """
    # Data quality fix: the wger API is known to return non-English records
    # despite language = 2 filtering, so the text itself is checked as well.
    name = None
    description = None
    for translation in record.get('translations') or []:
        if translation.get("language") != 2:
            continue
        if is_english_content(translation.get("description")):
            name = translation.get("name")
            description = translation.get("description")
            break

    if not name:
        return None

    # data cleaning step: remove HTML tags from name and description
    name = _strip_html(name)
    if description:
        description = _strip_html(description)

    def names_of(key):
        # Nested lists may be absent or explicitly null; treat both as empty.
        return [item["name"] for item in record.get(key) or []]

    return {
        "id": record.get("id"),
        "uuid": record.get("uuid"),
        "name": name.upper(),
        "description": description,
        "created": _parse_iso_datetime(record.get("created")),
        "last_update": _parse_iso_datetime(record.get("last_update")),
        # category may be missing/null; yield None instead of raising
        "category": (record.get("category") or {}).get("name"),
        "muscles": names_of("muscles"),
        "muscles_secondary": names_of("muscles_secondary"),
        "equipment": names_of("equipment"),
        "variations": record.get("variations") or [],
        "license_author": record.get("license_author")}

def query_databricks_foundation_model(prompt, model_name="databricks-llama-4-maverick", max_tokens=2000, temperature=0.5, timeout=120):
    """
    Query a Databricks foundation model through its serving endpoint REST API.

    Three request body shapes are tried in order ("messages", "prompt",
    "input"), and the first one the endpoint accepts wins. Uses the module-level
    ``DATABRICKS_INSTANCE`` and ``DATABRICKS_TOKEN`` globals.

    Args:
        prompt: Text to send.
        model_name: Serving endpoint name.
        max_tokens: Generation length limit sent in the payload.
        temperature: Sampling temperature sent in the payload.
        timeout: Per-request timeout in seconds.

    Returns:
        The decoded JSON response, or None when every payload format failed.
    """
    api_url = f"https://{DATABRICKS_INSTANCE}/serving-endpoints/{model_name}/invocations"

    headers = {
        "Authorization": f"Bearer {DATABRICKS_TOKEN}",
        "Content-Type": "application/json"}

    generation_params = {"max_tokens": max_tokens, "temperature": temperature}
    payloads = [
        # chat payload
        ("messages", {"messages": [{"role": "user", "content": prompt}], **generation_params}),
        # prompt payload
        ("prompt", {"prompt": prompt, **generation_params}),
        # input payload
        ("input", {"input": prompt, **generation_params}),
    ]

    for payload_name, payload in payloads:
        try:
            print(f"Trying payload format: {payload_name}")
            response = requests.post(api_url, headers=headers, json=payload, timeout=timeout)
            response.raise_for_status()
            print(f"Success with {payload_name} format!")
            return response.json()
        except requests.exceptions.RequestException as e:
            print(f"Failed with {payload_name} format: {e}")
            error_response = getattr(e, 'response', None)
            if error_response is not None:
                print(f"Response content: {error_response.text}")
                if error_response.status_code in (401, 403):
                    # Authentication/authorization failures do not depend on the
                    # payload shape, so the remaining formats cannot succeed.
                    break

    print("all payload formats failed")
    return None

# Alternative approach using the newer Databricks SDK
def query_with_databricks_sdk(prompt, model_name="databricks-meta-llama-3-1-70b-instruct", max_tokens=2000, temperature=0.5):
    """
    Query a serving endpoint through the Databricks SDK (if it is installed).

    Args:
        prompt: User message text.
        model_name: Serving endpoint name.
        max_tokens: Generation length limit.
        temperature: Sampling temperature.

    Returns:
        The SDK's query response object, or None when the SDK is unavailable
        or the call fails (the error is printed).
    """
    try:
        from databricks.sdk import WorkspaceClient
        from databricks.sdk.service.serving import ChatMessage, ChatMessageRole
    except ImportError:
        print("Databricks SDK not available, use the API approach instead")
        return None

    try:
        w = WorkspaceClient()
        # The SDK serializes each message via ChatMessage.as_dict(), so messages
        # must be ChatMessage instances; plain dicts fail with AttributeError.
        return w.serving_endpoints.query(
            name=model_name,
            messages=[ChatMessage(role=ChatMessageRole.USER, content=prompt)],
            max_tokens=max_tokens,
            temperature=temperature,
        )
    except Exception as e:
        print(f"Error with SDK approach: {e}")
        return None

# Default system prompt. The eight leading spaces on the continuation lines
# (and the space after the first sentence) are part of the prompt text.
_DEFAULT_SYSTEM_PROMPT = (
    "You are a JSON-only exercise standardization assistant. \n"
    "        ALWAYS respond with valid JSON in this exact format:\n"
    '        {"groups":[{"standardized_name":"NAME","variations":["var1","var2"]}]}\n'
    "        Never include explanations, markdown, misspellings, or extra text."
)


def format_prompt(instruction, system_prompt=None):
    """
    Wrap *instruction* in an ``[INST]``/``<<SYS>>`` chat template.

    Template format based on:
    https://github.com/databricks/databricks-ml-examples/blob/master/llm-models/mistral/mistral-7b/01_load_inference.py

    Args:
        instruction: User instruction placed between ``<</SYS>>`` and ``[/INST]``.
        system_prompt: System text; when None, a JSON-only exercise
            standardization prompt is used.

    Returns:
        The templated prompt string.
    """
    system_text = _DEFAULT_SYSTEM_PROMPT if system_prompt is None else system_prompt
    return "<s>[INST]<<SYS>>{}<</SYS>>{}[/INST]".format(system_text, instruction)

# Generate text for a batch of prompts (REST endpoint first, SDK as fallback)
def _extract_response_text(response):
    """Return the generated text from a serving-endpoint JSON response.

    Understands OpenAI-style ``choices``, MLflow-style ``predictions`` (dict or
    plain-string entries) and Gemini-style ``candidates``. Any other payload is
    returned stringified so the caller still has something to inspect.
    """
    if not isinstance(response, dict):
        return str(response)
    choices = response.get('choices')
    if choices:
        # OpenAI-style response (chat "message" or completion "text")
        first = choices[0]
        message = first.get('message') or {}
        return message.get('content', '') or first.get('text', '')
    if 'predictions' in response:
        # MLflow-style response; entries may be dicts or plain strings
        predictions = response['predictions']
        if not predictions:
            return ''
        prediction = predictions[0]
        if isinstance(prediction, dict):
            return prediction.get('generated_text', '')
        return str(prediction)
    if 'candidates' in response:
        # Gemini-style response
        return response['candidates'][0].get('content', {}).get('parts', [{}])[0].get('text', '')
    return str(response)


def _extract_sdk_text(sdk_response):
    """Return generated text from a Databricks SDK query response object.

    Reads ``choices[0].message.content`` (chat) or ``choices[0].text``
    (completions) when present. Otherwise falls back to ``str(sdk_response)``.
    """
    choices = getattr(sdk_response, 'choices', None)
    if choices:
        first = choices[0]
        content = getattr(getattr(first, 'message', None), 'content', None)
        if content:
            return content
        text = getattr(first, 'text', None)
        if text:
            return text
    return str(sdk_response)


def gen_text_databricks(prompts, use_template=True, **kwargs):
    """
    Generate one completion per prompt with a Databricks foundation model.

    Each prompt is optionally wrapped with ``format_prompt`` and sent through
    ``query_databricks_foundation_model``. If that returns nothing, the prompt
    is retried via ``query_with_databricks_sdk``.

    Args:
        prompts: Iterable of prompt strings.
        use_template: Wrap each prompt with ``format_prompt`` first.
        **kwargs: ``max_new_tokens`` (default 2000) and ``temperature``
            (default 0.5) are forwarded to the model; other keys are ignored.

    Returns:
        List of generated strings, one per prompt, in order. Prompts for which
        both paths fail yield ``"Error: Could not get response from model"``.
    """
    max_tokens = kwargs.get('max_new_tokens', 2000)
    temperature = kwargs.get('temperature', 0.5)
    results = []

    for prompt in prompts:
        formatted_prompt = format_prompt(prompt) if use_template else prompt

        response = query_databricks_foundation_model(
            formatted_prompt,
            max_tokens=max_tokens,
            temperature=temperature
        )
        if response:
            results.append(_extract_response_text(response))
            continue

        # REST call failed for every payload format: fall back to the SDK.
        sdk_response = query_with_databricks_sdk(
            formatted_prompt,
            max_tokens=max_tokens,
            temperature=temperature
        )
        if sdk_response:
            # Append the generated text, not the response object's repr.
            results.append(_extract_sdk_text(sdk_response))
        else:
            results.append("Error: Could not get response from model")

    return results


# Workspace credentials/host read by query_databricks_foundation_model.
DATABRICKS_TOKEN = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().getOrElse(None)
DATABRICKS_INSTANCE = spark.conf.get("spark.databricks.workspaceUrl")

# documentation link: https://exercise.hellogym.io/nl/software/api
# The exerciseinfo endpoint returns denormalized records in which translations,
# category, muscles and equipment are nested objects; normalize_record flattens
# them. language=2 is English (see normalize_record).
wger_api_url = "https://wger.de/api/v2/exerciseinfo/?status=2&language=2"
# Guard against a fetch helper that returns None on error.
exercises = get_exercises(wger_api_url) or []

ex_schema = StructType([
    StructField("id", IntegerType()),
    StructField("uuid", StringType()),
    StructField("name", StringType()),
    StructField("description", StringType()),
    StructField("created", TimestampType()),
    StructField("last_update", TimestampType()),
    StructField("category", StringType()),
    # normalize_record emits these three as lists of names, so declare them as
    # arrays rather than StringType, which does not describe a list value.
    StructField("muscles", ArrayType(StringType())),
    StructField("muscles_secondary", ArrayType(StringType())),
    StructField("equipment", ArrayType(StringType())),
    # NOTE(review): normalize_record passes this through from the API (None
    # becomes []); its element type is not established here, so it stays a string.
    StructField("variations", StringType()),
    StructField("license_author", StringType())])

# Normalize every record, then drop the ones normalize_record rejected (None),
# i.e. records without usable English content.
exercise_abbv = [normalize_record(record) for record in exercises]
exercise_abbv = [record for record in exercise_abbv if record is not None]

try:
    exercise_df = spark.createDataFrame(exercise_abbv, schema=ex_schema)
    exercise_df = exercise_df.drop("uuid", "created", "last_update", "license_author")
    display(exercise_df.sort("name"))
except Exception as e:
    print("error creating dataframe:", e)
    traceback.print_exc()


In [0]:
from dataclasses import dataclass, field
from typing import Dict, List, Set, Optional, Tuple
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
import re
from collections import defaultdict, Counter
import warnings
# NOTE(review): this hides every warning process-wide (pyspark, sklearn and
# sentence-transformers included), not just those from this cell; consider a
# narrower category/module filter.
warnings.filterwarnings('ignore')

@dataclass
class ExerciseFeatures:
    """Keyword-derived attributes parsed from one exercise name.

    Instances are built by ``SmartExerciseStandardizer.extract_features``.
    Every field defaults to "nothing detected".
    """
    # Canonical equipment keys (e.g. 'BARBELL') detected in the name.
    equipment: Set[str] = field(default_factory=set)
    # Canonical movement key (e.g. 'PRESS'), or None when none matched.
    movement_pattern: Optional[str] = field(default=None)
    # Canonical muscle-group keys detected in the name.
    muscle_groups: Set[str] = field(default_factory=set)
    # Canonical modifier keys (e.g. 'INCLINE', 'SEATED') detected in the name.
    modifiers: Set[str] = field(default_factory=set)
    # 'INCLINE', 'DECLINE', 'FLAT' or None.
    angle_position: Optional[str] = field(default=None)
    # 'WIDE', 'NARROW', 'HAMMER' or None.
    grip_stance: Optional[str] = field(default=None)
    # The exercise name exactly as given.
    raw_name: str = field(default="")
    # The name after _clean_exercise_name normalization.
    cleaned_name: str = field(default="")

def _default_equipment_patterns() -> Dict[str, Set[str]]:
    """Equipment key -> upper-case name fragments that signal that equipment."""
    return dict(
        BARBELL={'BARBELL', 'BB', 'OLYMPIC BAR'},
        DUMBBELL={'DUMBBELL', 'DB', 'DUMBBELLS'},
        CABLE={'CABLE', 'CABLES', 'PULLEY'},
        MACHINE={'MACHINE', 'SMITH MACHINE', 'HACK SQUAT'},
        KETTLEBELL={'KETTLEBELL', 'KB'},
        RESISTANCE_BAND={'RESISTANCE BAND', 'BAND', 'ELASTIC'},
        BODYWEIGHT={'BODYWEIGHT', 'BW', 'CALISTHENICS'},
        EZ_BAR={'EZ BAR', 'EZ-BAR', 'EZBAR', 'CURL BAR'},
        T_BAR={'T-BAR', 'T BAR', 'TBAR'},
    )


def _default_movement_patterns() -> Dict[str, Set[str]]:
    """Movement key -> upper-case name fragments that signal that movement."""
    return dict(
        PRESS={'PRESS', 'PRESSING'},
        CURL={'CURL', 'CURLING'},
        ROW={'ROW', 'ROWING'},
        RAISE={'RAISE', 'LATERAL RAISE', 'FRONT RAISE'},
        EXTENSION={'EXTENSION', 'EXTEND'},
        FLEXION={'FLEXION', 'FLEX'},
        SQUAT={'SQUAT', 'SQUATTING'},
        DEADLIFT={'DEADLIFT', 'DEAD LIFT'},
        PULL={'PULL', 'PULLDOWN', 'PULL-DOWN'},
        PUSH={'PUSH', 'PUSHUP', 'PUSH-UP'},
        FLY={'FLY', 'FLYE'},
        DIP={'DIP', 'DIPS'},
        LUNGE={'LUNGE', 'LUNGES'},
        CRUNCH={'CRUNCH', 'CRUNCHES'},
        PLANK={'PLANK', 'PLANKS'},
    )


def _default_muscle_groups() -> Dict[str, Set[str]]:
    """Muscle-group key -> upper-case name fragments that signal that group."""
    return dict(
        CHEST={'CHEST', 'PECTORAL', 'PECS'},
        BACK={'BACK', 'LATISSIMUS', 'LATS', 'RHOMBOIDS'},
        SHOULDERS={'SHOULDER', 'SHOULDERS', 'DELTOID', 'DELTS'},
        BICEPS={'BICEP', 'BICEPS'},
        TRICEPS={'TRICEP', 'TRICEPS'},
        LEGS={'LEG', 'LEGS', 'QUADRICEPS', 'QUADS', 'HAMSTRING', 'HAMSTRINGS'},
        GLUTES={'GLUTE', 'GLUTES', 'GLUTEUS'},
        CALVES={'CALF', 'CALVES', 'GASTROCNEMIUS'},
        CORE={'ABS', 'ABDOMINAL', 'CORE', 'OBLIQUES'},
    )


def _default_modifiers() -> Dict[str, Set[str]]:
    """Modifier key -> upper-case name fragments that signal that modifier."""
    return dict(
        INCLINE={'INCLINE', 'INCLINED'},
        DECLINE={'DECLINE', 'DECLINED'},
        FLAT={'FLAT', 'HORIZONTAL'},
        SEATED={'SEATED', 'SITTING'},
        STANDING={'STANDING', 'UPRIGHT'},
        LYING={'LYING', 'SUPINE', 'PRONE'},
        SINGLE_ARM={'SINGLE ARM', 'ONE ARM', 'UNILATERAL'},
        ALTERNATING={'ALTERNATING', 'ALTERNATE'},
        HAMMER={'HAMMER', 'NEUTRAL GRIP'},
        REVERSE={'REVERSE', 'PRONATED'},
        WIDE={'WIDE', 'WIDE GRIP'},
        NARROW={'NARROW', 'CLOSE GRIP', 'NARROW GRIP'},
        OVERHEAD={'OVERHEAD', 'MILITARY'},
    )


@dataclass
class StandardizationConfig:
    """Tuning knobs and keyword tables for the standardization pipeline.

    Each table maps a canonical key to the upper-case fragments that signal it
    in an exercise name. Every instance gets its own fresh copy of each table.
    """
    # Turned into the clustering distance threshold as 1 - similarity_threshold.
    similarity_threshold: float = 0.85
    # Weight of keyword-feature cosine similarity in the hybrid score.
    feature_weight: float = 0.3
    # Weight of sentence-embedding cosine similarity in the hybrid score.
    semantic_weight: float = 0.7
    # NOTE(review): not read anywhere in this notebook.
    min_cluster_size: int = 2

    equipment_patterns: Dict[str, Set[str]] = field(default_factory=_default_equipment_patterns)
    movement_patterns: Dict[str, Set[str]] = field(default_factory=_default_movement_patterns)
    muscle_groups: Dict[str, Set[str]] = field(default_factory=_default_muscle_groups)
    modifiers: Dict[str, Set[str]] = field(default_factory=_default_modifiers)

class SmartExerciseStandardizer:
    """Group near-duplicate exercise names and pick one canonical name per group.

    Pairwise similarity is a weighted blend of two cosine similarities:
    sentence embeddings, and keyword feature vectors built from the pattern
    tables in ``StandardizationConfig``. Names are grouped with average-linkage
    agglomerative clustering on ``1 - similarity``. Clusters whose members
    disagree on equipment or movement are split back into singletons, and each
    remaining cluster gets one canonical (cleaned) name.
    """

    def __init__(self, config: Optional["StandardizationConfig"] = None):
        """Create a standardizer.

        Args:
            config: Thresholds, weights and keyword tables. A default
                ``StandardizationConfig()`` is used when omitted.
        """
        self.config = config or StandardizationConfig()
        self.model = SentenceTransformer('all-MiniLM-L6-v2')
        # NOTE(review): not used by any method of this class; kept so callers
        # that read the attribute keep working.
        self.feature_scaler = StandardScaler()

    @staticmethod
    def _contains_term(text: str, term: str) -> bool:
        """Return True if *term* occurs in *text* at the start of a word.

        Only the start is anchored, so suffixed forms still match ("ROW" in
        "ROWS"). Fragments buried inside another word do not match ("ROW" in
        "NARROW", "BB" in "DUMBBELL").
        """
        return re.search(r'(?<!\w)' + re.escape(term), text) is not None

    def extract_features(self, exercise_name: str) -> "ExerciseFeatures":
        """Extract equipment, movement, muscle, modifier, angle and grip keywords.

        Matching works on the upper-cased name via ``_contains_term``.

        Args:
            exercise_name: Raw exercise name.

        Returns:
            An ``ExerciseFeatures`` whose ``raw_name`` is the input and whose
            ``cleaned_name`` comes from ``_clean_exercise_name``.
        """
        name_upper = exercise_name.upper()
        features = ExerciseFeatures(raw_name=exercise_name)

        def has(term: str) -> bool:
            return self._contains_term(name_upper, term)

        # Equipment: keep only the most specific type. Each type is scored by
        # its LONGEST matching alias, so the result does not depend on set
        # iteration order (which varies between runs for strings).
        equipment_scores = {}
        for equip_type, patterns in self.config.equipment_patterns.items():
            matched_lengths = [len(pattern) for pattern in patterns if has(pattern)]
            if matched_lengths:
                equipment_scores[equip_type] = max(matched_lengths)
        if equipment_scores:
            features.equipment.add(max(equipment_scores, key=equipment_scores.get))

        # Movement: the first entry (in config order) with a matching alias wins.
        features.movement_pattern = next(
            (movement
             for movement, patterns in self.config.movement_patterns.items()
             if any(has(pattern) for pattern in patterns)),
            None,
        )

        features.muscle_groups.update(
            muscle for muscle, patterns in self.config.muscle_groups.items()
            if any(has(pattern) for pattern in patterns)
        )
        features.modifiers.update(
            modifier for modifier, patterns in self.config.modifiers.items()
            if any(has(pattern) for pattern in patterns)
        )

        # Angle/position precedence: INCLINE, then DECLINE, then FLAT.
        if has('INCLINE'):
            features.angle_position = 'INCLINE'
        elif has('DECLINE'):
            features.angle_position = 'DECLINE'
        elif has('FLAT'):
            features.angle_position = 'FLAT'

        # Grip/stance precedence: WIDE, then NARROW/CLOSE, then HAMMER.
        if has('WIDE'):
            features.grip_stance = 'WIDE'
        elif has('NARROW') or has('CLOSE'):
            features.grip_stance = 'NARROW'
        elif has('HAMMER'):
            features.grip_stance = 'HAMMER'

        features.cleaned_name = self._clean_exercise_name(exercise_name)
        return features

    def _clean_exercise_name(self, name: str) -> str:
        """Normalize spelling variants while preserving important distinctions.

        Drops parenthetical text and canonicalizes common plurals/hyphenations
        (e.g. "Dumbbells" -> "DUMBBELL", "Push-ups" -> "PUSH UP"). Hyphens and
        underscores become spaces, and runs of whitespace collapse to one space.
        """
        name = re.sub(r'\s*\([^)]*\)', '', name)

        replacements = {
            r'DUMBBELLS?': 'DUMBBELL',
            r'BARBELLS?': 'BARBELL',
            r'PULL[-\s]?UPS?': 'PULL UP',
            r'PUSH[-\s]?UPS?': 'PUSH UP',
            r'SIT[-\s]?UPS?': 'SIT UP',
            r'BICEPS?': 'BICEP',
            r'TRICEPS?': 'TRICEP'
        }
        for pattern, replacement in replacements.items():
            name = re.sub(pattern, replacement, name, flags=re.IGNORECASE)

        name = re.sub(r'[-_]+', ' ', name)
        name = ' '.join(name.split())
        return name.strip()

    def create_feature_vector(self, features: "ExerciseFeatures") -> np.ndarray:
        """One-hot encode *features* over every key in the config's pattern tables.

        Layout: equipment keys, movement keys, muscle-group keys, then modifier
        keys, each in config order. Deriving the layout from the config keeps
        the vector in sync with every keyword ``extract_features`` can produce,
        including custom configs.
        """
        cfg = self.config
        vector = [int(key in features.equipment) for key in cfg.equipment_patterns]
        vector += [int(features.movement_pattern == key) for key in cfg.movement_patterns]
        vector += [int(key in features.muscle_groups) for key in cfg.muscle_groups]
        vector += [int(key in features.modifiers) for key in cfg.modifiers]
        return np.array(vector)

    @staticmethod
    def _combine_similarities(semantic_sim: np.ndarray, feature_vectors: np.ndarray,
                              semantic_weight: float, feature_weight: float) -> np.ndarray:
        """Blend semantic similarity with cosine similarity of feature vectors.

        A name with no recognised keywords has an all-zero feature vector, whose
        cosine similarity would be 0 with every name, itself included. Blending
        that in caps such pairs at ``semantic_weight`` (0.7 by default, below the
        default 0.85 threshold), so they could never cluster. For any pair where
        either name lacks features, the semantic similarity is used alone.
        """
        vectors = np.asarray(feature_vectors, dtype=float)
        norms = np.linalg.norm(vectors, axis=1)
        has_features = norms > 0
        unit_vectors = vectors / np.where(has_features, norms, 1.0)[:, np.newaxis]
        feature_sim = unit_vectors @ unit_vectors.T
        blended = semantic_weight * semantic_sim + feature_weight * feature_sim
        return np.where(np.outer(has_features, has_features), blended, semantic_sim)

    def compute_hybrid_similarity(self, exercises: List[str]) -> np.ndarray:
        """Return the (n, n) hybrid similarity matrix for *exercises*."""
        all_features = [self.extract_features(ex) for ex in exercises]
        semantic_embeddings = self.model.encode(exercises)
        feature_vectors = np.array([self.create_feature_vector(f) for f in all_features])
        semantic_sim = cosine_similarity(semantic_embeddings)
        return self._combine_similarities(
            semantic_sim, feature_vectors,
            self.config.semantic_weight, self.config.feature_weight)

    def smart_clustering(self, exercises: List[str]) -> Dict[str, List[str]]:
        """Cluster names by hybrid similarity.

        Returns:
            Mapping of cluster label -> member names, in input order.
        """
        print(f"Clustering {len(exercises)} exercises...")

        if len(exercises) < 2:
            # AgglomerativeClustering needs at least two samples.
            clusters = {0: list(exercises)} if exercises else {}
            print(f"Found {len(clusters)} clusters")
            return clusters

        similarity_matrix = self.compute_hybrid_similarity(exercises)

        # Convert similarity to distance. Clip float round-off below zero and
        # force an exact zero diagonal, as a precomputed metric expects.
        distance_matrix = np.clip(1.0 - similarity_matrix, 0.0, None)
        np.fill_diagonal(distance_matrix, 0.0)

        clustering = AgglomerativeClustering(
            n_clusters=None,
            distance_threshold=1 - self.config.similarity_threshold,
            linkage='average',
            metric='precomputed'
        )
        labels = clustering.fit_predict(distance_matrix)

        clusters = defaultdict(list)
        for exercise, label in zip(exercises, labels):
            clusters[label].append(exercise)

        print(f"Found {len(clusters)} clusters")
        return dict(clusters)

    def select_canonical_name(self, exercise_group: List[str]) -> str:
        """Pick the most representative name in a group and return it cleaned.

        Names score points for having equipment, a movement pattern, modifiers,
        fewer words, and a BARBELL/DUMBBELL mention. The first highest-scoring
        name wins.
        """
        if len(exercise_group) == 1:
            return self._clean_exercise_name(exercise_group[0])

        def score(exercise: str) -> int:
            features = self.extract_features(exercise)
            points = 0
            if features.equipment:
                points += 2
            if features.movement_pattern:
                points += 2
            points += len(features.modifiers)
            # Prefer shorter names (less likely to carry extra info).
            points += max(0, 10 - len(exercise.split()))
            if any(pattern in exercise.upper() for pattern in ('BARBELL', 'DUMBBELL')):
                points += 1
            return points

        best_exercise = max(exercise_group, key=score)
        return self._clean_exercise_name(best_exercise)

    def validate_cluster_quality(self, cluster: List[str]) -> bool:
        """Return False when cluster members disagree on equipment or movement."""
        if len(cluster) <= 1:
            return True

        features_list = [self.extract_features(ex) for ex in cluster]
        equipments = [f.equipment for f in features_list if f.equipment]
        movements = [f.movement_pattern for f in features_list if f.movement_pattern]

        # Different equipment across members -> likely different exercises.
        if len(set().union(*equipments)) > 1 and len(equipments) > 1:
            return False

        # Different movement patterns -> likely different exercises.
        if len(set(movements)) > 1:
            return False

        return True

    def standardize_exercises(self, exercises: List[str]) -> Tuple[Dict[str, str], pd.DataFrame]:
        """Run the full pipeline: dedupe, cluster, validate and name.

        Args:
            exercises: Exercise names; duplicates are removed (order kept).

        Returns:
            ``(mapping, summary_df)``. ``mapping`` maps each unique name to its
            canonical name. ``summary_df`` has columns canonical_name,
            variations and count, sorted by count descending (empty for empty
            input).
        """
        print("Starting exercise standardization...")

        unique_exercises = list(dict.fromkeys(exercises))
        print(f"Processing {len(unique_exercises)} unique exercises...")

        clusters = self.smart_clustering(unique_exercises)

        mapping = {}
        validated_clusters = []
        for cluster_exercises in clusters.values():
            if self.validate_cluster_quality(cluster_exercises):
                canonical_name = self.select_canonical_name(cluster_exercises)
                for exercise in cluster_exercises:
                    mapping[exercise] = canonical_name
                validated_clusters.append({
                    'canonical_name': canonical_name,
                    'variations': cluster_exercises,
                    'count': len(cluster_exercises)
                })
            else:
                # Split invalid clusters back into singletons.
                for exercise in cluster_exercises:
                    cleaned = self._clean_exercise_name(exercise)
                    mapping[exercise] = cleaned
                    validated_clusters.append({
                        'canonical_name': cleaned,
                        'variations': [exercise],
                        'count': 1
                    })

        # Explicit columns keep sort_values working when there are no clusters.
        summary_df = pd.DataFrame(validated_clusters, columns=['canonical_name', 'variations', 'count'])
        summary_df = summary_df.sort_values('count', ascending=False)

        original_count = len(unique_exercises)
        standardized_count = len(set(mapping.values()))
        reduction = (1 - standardized_count / original_count) * 100 if original_count else 0.0

        print("\nStandardization Results:")
        print(f"Original exercises: {original_count}")
        print(f"Standardized exercises: {standardized_count}")
        print(f"Reduction: {reduction:.1f}%")
        print(f"Largest groups: {summary_df.head(3)['canonical_name'].tolist()}")

        return mapping, summary_df

# Spark integration: attach a standardized_name column to a DataFrame
def apply_standardization(df, exercise_column='name'):
    """Add a ``standardized_name`` column to *df* via a broadcast join.

    The distinct non-null values of *exercise_column* are collected to the
    driver and standardized there. The resulting name -> canonical-name mapping
    is then left-joined back onto *df* (rows without a mapping get null).

    Args:
        df: Spark DataFrame containing exercise names.
        exercise_column: Name of the column to standardize.

    Returns:
        Tuple ``(result_df, mapping, summary)``.
    """
    from pyspark.sql import SparkSession
    from pyspark.sql.functions import broadcast, col
    from pyspark.sql.types import StructType, StructField, StringType

    standardizer = SmartExerciseStandardizer()

    # Column-object filter instead of a SQL string, so column names that need
    # quoting (spaces, reserved words) still work.
    unique_exercises_df = df.select(exercise_column).distinct().filter(col(exercise_column).isNotNull())
    unique_exercises = [row[exercise_column] for row in unique_exercises_df.collect()]

    print(f"Collected {len(unique_exercises)} unique exercises for standardization")

    mapping, summary = standardizer.standardize_exercises(unique_exercises)

    spark = SparkSession.getActiveSession() or df.sparkSession
    # Explicit schema: schema inference fails when the mapping is empty.
    mapping_schema = StructType([
        StructField(exercise_column, StringType()),
        StructField("standardized_name", StringType()),
    ])
    mapping_df = spark.createDataFrame(list(mapping.items()), schema=mapping_schema)

    # Broadcast join: the mapping has one row per distinct name.
    result_df = df.join(broadcast(mapping_df), on=exercise_column, how="left")

    return result_df, mapping, summary

# Alternative approach using pandas UDF if you need to process in batches
def apply_standardization_batched(df, exercise_column='name', batch_size=1000):
    """Add a ``standardized_name`` column to *df* using a pandas UDF.

    The mapping is computed once on the driver from the distinct non-null
    values of *exercise_column*. The UDF then looks up each value and falls
    back to the original value when it has no mapping.

    Args:
        df: Spark DataFrame containing exercise names.
        exercise_column: Name of the column to standardize.
        batch_size: NOTE(review): accepted for compatibility but not used. The
            Arrow batch size for pandas UDFs comes from Spark configuration
            (e.g. ``spark.sql.execution.arrow.maxRecordsPerBatch``), not from
            this argument.

    Returns:
        Tuple ``(result_df, mapping, summary)``.
    """
    from pyspark.sql.functions import pandas_udf, col
    from pyspark.sql.types import StringType
    import pandas as pd

    # Column-object filter instead of a SQL string, so column names that need
    # quoting (spaces, reserved words) still work.
    unique_exercises_df = df.select(exercise_column).distinct().filter(col(exercise_column).isNotNull())
    unique_exercises = [row[exercise_column] for row in unique_exercises_df.collect()]

    standardizer = SmartExerciseStandardizer()
    mapping, summary = standardizer.standardize_exercises(unique_exercises)

    @pandas_udf(StringType())
    def standardize_exercise_udf(exercise_series: pd.Series) -> pd.Series:
        # `mapping` is captured in the closure and serialized with the UDF.
        return exercise_series.map(mapping).fillna(exercise_series)

    result_df = df.withColumn("standardized_name", standardize_exercise_udf(col(exercise_column)))

    return result_df, mapping, summary

# Run the join-based standardization on the exercise names loaded above.
standardized_df, exercise_mapping, summary_report = apply_standardization(
    exercise_df, exercise_column='name'
)

In [0]:
# Preview each original name next to its standardized form, grouped by the latter.
display(standardized_df.select("name", "standardized_name").orderBy("standardized_name"))