In [0]:
# imports
import requests
import time
from datetime import datetime, timedelta
from dateutil.relativedelta import relativedelta
from pyspark.sql import SparkSession, Row
from pyspark.sql.functions import col, concat_ws, regexp_replace, to_date, trim, lit, when, length, udf, broadcast
from pyspark.sql.types import StructType, StructField, StringType, TimestampType, IntegerType, FloatType, ArrayType
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass
from functools import partial
import sys
import traceback
import json
from langdetect import detect
import re
import os
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.cluster import DBSCAN
from sklearn.metrics.pairwise import cosine_similarity
from typing import List, Dict, Tuple, Set
import pandas as pd
import ast
from json_repair import repair_json
from nltk.stem import WordNetLemmatizer
from nltk.corpus import wordnet
from nltk import pos_tag, word_tokenize, punkt

def get_exercises(api_url, timeout=30):
    """Retrieve the complete, paginated list of exercises from the wger public API.

    Follows the ``next`` link of every response page until the API reports
    that there is no further page.

    Args:
        api_url: URL of the first page to request.
        timeout: Seconds to wait on each HTTP request before giving up.

    Returns:
        A list holding the ``results`` entries of every page. When any request
        fails an empty list is returned, so callers can always iterate the
        result; a partially downloaded list is discarded rather than being
        passed off as complete.
    """
    all_exercises = []
    current_url = api_url

    try:
        while current_url:
            # a timeout keeps a stalled connection from hanging the notebook
            response = requests.get(current_url, timeout=timeout)
            response.raise_for_status()
            data = response.json()
            all_exercises.extend(data["results"])
            current_url = data.get("next")
    except Exception as e:
        # best-effort fetch: report the problem and hand back an empty list
        print("error fetching exercise data from wger API:", e)
        return []

    return all_exercises

def save_data(df, table_name: str, mode: str = "overwrite"):
    """Persist a PySpark DataFrame as a Delta table.

    When ``mode`` is "overwrite" the writer is additionally allowed to replace
    the schema of an existing table.
    """
    delta_writer = df.write.format("delta").mode(mode)
    overwriting = mode == "overwrite"
    target_writer = delta_writer.option("overwriteSchema", "true") if overwriting else delta_writer
    target_writer.saveAsTable(table_name)

def is_english_content(description):
    """Check whether a description is written in English.

    HTML tags are removed before sampling so that markup does not use up the
    100-character sample handed to the language detector.

    Args:
        description: Raw description text; may contain HTML, be empty or None.

    Returns:
        True when there is no visible text to contradict English (None, empty,
        whitespace or markup only) or when the detector reports "en"; False
        for any other language or when detection fails.
    """
    try:
        if not description:
            return True

        # normalize_record strips the same tag pattern from descriptions, so
        # raw input may contain markup; detect on the visible text only
        visible_text = re.sub(r"<[^>]*>", " ", description).strip()
        if not visible_text:
            return True

        # check description language, sampling first 100 chars
        # NOTE(review): langdetect results can vary between runs unless
        # DetectorFactory.seed is fixed - confirm whether that matters here
        return detect(visible_text[:100]) == "en"
    except Exception:
        # if detection fails, assume not english
        return False
    
def _strip_html(text):
    """Remove HTML tags from *text* and trim surrounding whitespace."""
    return re.sub(r"<[^>]*>", "", text).strip()


def _parse_iso_datetime(value):
    """Parse an ISO-8601 timestamp; return None when it is absent or malformed."""
    if not value:
        return None
    try:
        # rewrite a trailing "Z" as an explicit UTC offset, which
        # datetime.fromisoformat accepts on every Python 3 version
        return datetime.fromisoformat(value.replace("Z", "+00:00"))
    except (AttributeError, ValueError):
        return None


def _names_of(items):
    """Return the "name" of every entry in *items* (None counts as empty)."""
    return [item["name"] for item in items or []]


def normalize_record(record):
    """Normalize a raw exercise record into a flat dictionary.

    The name and description come from the first translation with
    ``language == 2`` (English) whose description also passes
    ``is_english_content``.

    Args:
        record: Exercise dictionary as returned by the exerciseinfo endpoint.

    Returns:
        A dictionary with the keys id, uuid, name (upper-cased), description,
        created, last_update, category, muscles, muscles_secondary, equipment,
        variations and license_author; or None when the record has no usable
        English name. Missing or null nested fields yield None / empty lists
        instead of raising.
    """
    name = None
    description = None

    for translation in record.get("translations") or []:
        if translation.get("language") != 2:
            continue
        candidate_description = translation.get("description")
        # verify the content is english
        if is_english_content(candidate_description):
            name = translation.get("name")
            description = candidate_description
            break

    # data cleaning step: remove HTML tags from name and description;
    # cleaning happens before the emptiness check so that a name consisting
    # only of markup is rejected as well
    name = _strip_html(name) if name else None

    # return None if no English content found
    # this is a data quality fix;
    # there is a known issue in wger API that results in non-english records despite language = 2 filtering
    if not name:
        return None

    if description:
        description = _strip_html(description)

    category = record.get("category") or {}

    # generate normalized dict structure
    return {
        "id": record.get("id"),
        "uuid": record.get("uuid"),
        "name": name.upper(),
        "description": description,
        "created": _parse_iso_datetime(record.get("created")),
        "last_update": _parse_iso_datetime(record.get("last_update")),
        "category": category.get("name"),
        "muscles": _names_of(record.get("muscles")),
        "muscles_secondary": _names_of(record.get("muscles_secondary")),
        "equipment": _names_of(record.get("equipment")),
        "variations": record.get("variations") or [],
        "license_author": record.get("license_author")}

def query_databricks_foundation_model(prompt, model_name="databricks-llama-4-maverick", max_tokens=512, temperature=0.5, timeout=300):
    """Query a Databricks foundation model through its serving endpoint.

    The endpoint is tried with three request shapes in turn ("messages",
    "prompt", "input"); the first one that yields a successful JSON response
    wins.

    Args:
        prompt: Text sent to the model.
        model_name: Name of the serving endpoint.
        max_tokens: Generation length limit forwarded to the endpoint.
        temperature: Sampling temperature forwarded to the endpoint.
        timeout: Seconds to wait on each HTTP request before giving up.

    Returns:
        The decoded JSON response, or None when every payload format failed.
    """
    # construct the API endpoint URL
    api_url = f"https://{DATABRICKS_INSTANCE}/serving-endpoints/{model_name}/invocations"

    headers = {
        "Authorization": f"Bearer {DATABRICKS_TOKEN}",
        "Content-Type": "application/json"}

    # generation settings shared by every payload shape
    sampling = {"max_tokens": max_tokens, "temperature": temperature}
    payloads = [
        # chat payload
        ("messages", {"messages": [{"role": "user", "content": prompt}], **sampling}),
        # prompt payload
        ("prompt", {"prompt": prompt, **sampling}),
        # input payload
        ("input", {"input": prompt, **sampling}),
    ]

    # attempt each format above
    for payload_name, payload in payloads:
        try:
            print(f"Trying payload format: {payload_name}")
            # a timeout keeps an unresponsive endpoint from blocking forever
            response = requests.post(api_url, headers=headers, json=payload, timeout=timeout)
            response.raise_for_status()
            result = response.json()
        except (requests.exceptions.RequestException, ValueError) as e:
            # ValueError covers a body that is not valid JSON
            print(f"Failed with {payload_name} format: {e}")
            failed_response = getattr(e, "response", None)
            if failed_response is not None:
                print(f"Response content: {failed_response.text}")
            continue

        print(f"Success with {payload_name} format!")
        return result

    print("all payload formats failed")
    return None

# Alternative approach using the newer Databricks SDK
def query_with_databricks_sdk(prompt, model_name="databricks-llama-4-maverick", max_tokens=5000, temperature=0.5):
    """Query a serving endpoint through the Databricks SDK (if installed).

    Args:
        prompt: Text sent to the model as a single user message.
        model_name: Name of the serving endpoint.
        max_tokens: Generation length limit forwarded to the endpoint.
        temperature: Sampling temperature forwarded to the endpoint.

    Returns:
        The SDK response object, or None when the SDK is not installed or the
        call fails.
    """
    try:
        from databricks.sdk import WorkspaceClient
        from databricks.sdk.service.serving import ChatMessage, ChatMessageRole
    except ImportError:
        print("Databricks SDK not available, use the API approach instead")
        return None

    try:
        w = WorkspaceClient()

        # the SDK's query() takes ChatMessage objects for `messages`,
        # not plain dictionaries
        return w.serving_endpoints.query(
            name=model_name,
            messages=[ChatMessage(role=ChatMessageRole.USER, content=prompt)],
            max_tokens=max_tokens,
            temperature=temperature,
        )
    except Exception as e:
        print(f"Error with SDK approach: {e}")
        return None

def format_prompt(instruction, system_prompt=None):
    """
    Wrap an instruction and a system prompt in the [INST] ... [/INST] template.

    When no system prompt is given, a default one that demands JSON-only
    output in the {"groups": [...]} shape is used.
    Based on documentation here:
    https://github.com/databricks/databricks-ml-examples/blob/master/llm-models/mistral/mistral-7b/01_load_inference.py
    """
    # the embedded newlines and leading spaces are part of the prompt text
    default_system_prompt = (
        "You are a JSON-only exercise standardization assistant. \n"
        "        ALWAYS respond with valid JSON in this exact format:\n"
        '        {"groups":[{"standardized_name":"NAME","variations":["var1","var2"]}]}\n'
        "        Never include explanations, markdown, misspellings, or extra text."
    )
    chosen_system_prompt = default_system_prompt if system_prompt is None else system_prompt
    return "[INST]<>{}<>{}[/INST]".format(chosen_system_prompt, instruction)

# Batch helper: sends each prompt to the model and collects the generated text
def _extract_response_text(response):
    """Pull the generated text out of a decoded serving-endpoint JSON response."""
    if response.get("choices"):
        # OpenAI-style response
        first_choice = response["choices"][0]
        return (first_choice.get("message") or {}).get("content") or first_choice.get("text", "")
    if response.get("predictions"):
        # MLflow-style response; a prediction may be a dict or already a string
        prediction = response["predictions"][0]
        if isinstance(prediction, dict):
            return prediction.get("generated_text", "")
        return str(prediction)
    if response.get("candidates"):
        # Gemini-style response
        parts = (response["candidates"][0].get("content") or {}).get("parts") or [{}]
        return parts[0].get("text", "")
    # Unknown shape: hand back the whole response as text
    return str(response)


def _extract_sdk_text(sdk_response):
    """Return the first choice's message content, or the response's string form."""
    try:
        content = sdk_response.choices[0].message.content
    except (AttributeError, IndexError, TypeError):
        content = None
    if isinstance(content, str) and content:
        return content
    return str(sdk_response)


def gen_text_databricks(prompts, use_template=True, **kwargs):
    """Generate text for each prompt using a Databricks foundation model.

    Args:
        prompts: Iterable of prompt strings.
        use_template: When True each prompt is wrapped with ``format_prompt``.
        **kwargs: ``max_new_tokens`` (default 512) and ``temperature``
            (default 0.5) are forwarded to the model call.

    Returns:
        A list with one string per prompt: the generated text, or
        "Error: Could not get response from model" when both the REST call
        and the SDK fallback fail.
    """
    max_tokens = kwargs.get("max_new_tokens", 512)
    temperature = kwargs.get("temperature", 0.5)
    results = []

    for prompt in prompts:
        formatted_prompt = format_prompt(prompt) if use_template else prompt

        # First try the API approach
        response = query_databricks_foundation_model(
            formatted_prompt,
            max_tokens=max_tokens,
            temperature=temperature
        )
        if response:
            results.append(_extract_response_text(response))
            continue

        # Try SDK approach as fallback
        sdk_response = query_with_databricks_sdk(
            formatted_prompt,
            max_tokens=max_tokens,
            temperature=temperature
        )
        if sdk_response:
            results.append(_extract_sdk_text(sdk_response))
        else:
            results.append("Error: Could not get response from model")

    return results


# credentials and workspace host used by query_databricks_foundation_model
DATABRICKS_TOKEN = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().getOrElse(None)
DATABRICKS_INSTANCE = spark.conf.get("spark.databricks.workspaceUrl")

# documentation link: https://exercise.hellogym.io/nl/software/api
# the exerciseinfo endpoint returns denormalized records with nested objects
# (translations, category, muscles, equipment) that normalize_record flattens
wger_api_url = "https://wger.de/api/v2/exerciseinfo/?status=2&language=2"
# get_exercises may hand back nothing when the download fails; fall back to
# an empty list so the normalization below still runs
exercises = get_exercises(wger_api_url) or []

# NOTE(review): normalize_record produces lists for muscles, muscles_secondary,
# equipment and variations, while they are declared StringType here - confirm
# that storing them as strings is intended
ex_schema = StructType([
    StructField("id", IntegerType()),
    StructField("uuid", StringType()), 
    StructField("name", StringType()), 
    StructField("description", StringType()),
    StructField("created", TimestampType()),
    StructField("last_update", TimestampType()),
    StructField("category", StringType()),
    StructField("muscles",  StringType()),
    StructField("muscles_secondary",  StringType()),
    StructField("equipment",  StringType()),
    StructField("variations", StringType()),
    StructField("license_author", StringType())])
    
# filter out None values for data quality
# (normalize_record returns None for records without usable English content)
exercise_abbv = [
    normalized
    for normalized in map(normalize_record, exercises)
    if normalized is not None]

try:
    exercise_df = spark.createDataFrame(exercise_abbv, schema=ex_schema)
    exercise_df = exercise_df.drop("uuid", "created", "last_update", "license_author")
    display(exercise_df.sort("name"))
except Exception as e:
    print("error creating dataframe:", e)
    traceback.print_exc()

In [0]:
from nltk.tokenize import word_tokenize
import nltk

# single lemmatizer instance shared by every lemmatization pass below
lemmatizer = WordNetLemmatizer()

# upper-case filler words dropped from exercise names during lemmatization
STOPWORDS = {
    'A', 'AN', 'THE',
    'AND', 'WITH', 'USING',
    'AT', 'BY', 'FOR', 'FROM', 'IN', 'INTO', 'OF', 'ON', 'ONTO', 'TO',
}

def validate_json_structure(result):
    """Check that a parsed LLM answer has the expected {"groups": [...]} shape.

    Every group must be a dictionary carrying a standardized name (under
    "standardized_name" or "standard_name") and a "variations" list.

    Returns:
        True when the structure is valid.

    Raises:
        ValueError: describing the first structural problem found.
    """
    if not isinstance(result, dict):
        raise ValueError("Result is not a dictionary")
    if "groups" not in result:
        raise ValueError("Missing 'groups' key")

    groups = result["groups"]
    if not isinstance(groups, list):
        raise ValueError("'groups' is not a list")

    # the model sometimes uses the shorter key, so either one is accepted
    name_keys = ("standardized_name", "standard_name")

    for position, entry in enumerate(groups):
        if not isinstance(entry, dict):
            raise ValueError(f"Group {position} is not a dictionary")
        if all(key not in entry for key in name_keys):
            raise ValueError(f"Group {position} missing standardized name key")
        if "variations" not in entry:
            raise ValueError(f"Group {position} missing 'variations' key")
        if not isinstance(entry["variations"], list):
            raise ValueError(f"Group {position} 'variations' is not a list")

    return True

def escalate_prompt(base_prompt, level):
    """Append an increasingly strict JSON-only reminder to a prompt.

    Level 1 and 2 each add their own reminder, level 3 and above add the
    final warning; any other level returns the prompt unchanged.
    """
    if level == 1:
        return base_prompt + "\nSTRICTLY enforce JSON compliance. Only output a parsable JSON object."
    if level == 2:
        return base_prompt + "\nYour previous response was invalid. Ensure the response is a clean JSON object with no extra formatting or text."
    if level >= 3:
        return base_prompt + "\nFinal warning: Return ONLY a valid JSON object. No natural language, no markdown, no surrounding explanations."
    return base_prompt

def get_wordnet_pos(tag):
    """Translate a POS tag into the matching WordNet part-of-speech constant.

    Tags starting with J, V or R map to adjective, verb and adverb; every
    other tag (including those starting with N) maps to noun.
    """
    for prefix, wordnet_attr in (('J', 'ADJ'), ('V', 'VERB'), ('R', 'ADV')):
        if tag.startswith(prefix):
            return getattr(wordnet, wordnet_attr)
    # nouns and anything unrecognized share the noun default
    return wordnet.NOUN

# dataset augmentation module
class AdvancedExerciseStandardizer:
    """Group near-duplicate exercise names and map them to standardized names.

    Pipeline (see ``standardize_exercise_dataset``): names are pre-grouped by
    keyword features, oversized groups are split with embedding-based
    clustering, and every resulting cluster is sent to an LLM that decides
    which names denote the same exercise.
    """

    # Keyword vocabularies used by extract_exercise_features.
    # NOTE: matching is by plain substring, so short terms also fire inside
    # longer words (e.g. "AB" inside "CABLE", "ROW" inside "THROW").
    EQUIPMENT_TERMS = (
        "BARBELL", "DUMBBELL", "KETTLEBELL",
        "EZ BAR", "EZ-BAR", "EZBAR", "T-BAR", "T BAR", "HEX", "HEX BAR",
        "CABLE", "ROPE", "STRAIGHT BAR", "V-BAR", "TRICEP BAR",
        "HANDLE", "SINGLE HANDLE", "MACHINE", "SMITH MACHINE", "CABLE MACHINE",
        "PREACHER", "LEVER", "HACK", "TRX", "SUSPENSION", "BAR", "BAG",
        "RESISTANCE BAND", "BAND", "BANDED", "MINI BAND", "THERABAND",
        "BENCH", "INCLINE BENCH", "DECLINE BENCH", "ADJUSTABLE BENCH", "STEP", "BOX", "PULL BLOCK", "BALL")

    MOVEMENT_TERMS = (
        "SQUAT", "DEADLIFT", "LUNGE", "STEP-UP", "PULL-APART",
        "HINGE", "PUSH", "PULL", "PRESS", "ROW",
        "DIP", "CURL", "EXTENSION", "RAISE", "FLY", "FLYE",
        "PULLOVER", "SHRUG", "ROTATION", "TWIST",
        "CRUNCH", "PLANK", "HOLD", "BRIDGE", "BEND",
        "LIFT", "JUMP", "HOP", "SPRINT", "KICKBACK",
        "CLAP", "THRUST", "TAP", "THROWS", "PULL UP",
        "PUSH DOWN", "CHIN UP", "CLIMB", "CARRY", "WALK", "THROW", "ROLL")

    BODY_PART_TERMS = (
        "NECK", "TRAP", "SHOULDER", "DELTOID", "DELT",
        "BICEP", "BI", "TRICEP", "TRI",
        "FOREARM", "CHEST", "PECTORAL", "PEC",
        "BACK", "LAT", "UPPER BACK", "LOWER BACK",
        "ABS", "ABDOMINALS", "AB", "OBLIQUE", "CORE",
        "GLUTE", "BUTT", "HIP", "HIP FLEXOR",
        "ADDUCTOR", "ABDUCTOR", "QUAD", "QUADRICEP", "HAMSTRING", "HAM",
        "CALF", "LEG", "ARM", "CHIN")

    MODIFIER_TERMS = (
        "WIDE", "WIDE-GRIP", "NARROW", "NARROW-GRIP", "CLOSE", "CLOSE-GRIP", "CROSSBODY", "CROSS-BODY",
        "HAMMER", "NEUTRAL-GRIP", "OVERHAND", "UNDERHAND", "SUPINATED", "PRONATED",
        "REVERSE", "MIXED-GRIP", "SINGLE-ARM", "ONE-ARM", "ONE-ARMED", "SINGLE-LEG", "ONE-LEG", "ONE-LEGGED",
        "UNILATERAL", "ALTERNATING", "ALTERNATE", "ISOLATED", "ISOLATION", "CROSSOVER", "BILATERAL",
        "DOUBLE-ARM", "DOUBLE-LEG", "SUMO", "WIDE-STANCE", "NARROW-STANCE", "CLOSE-STANCE", "SPLIT", "BULGARIAN",
        "STAGGERED", "STEP-BACK", "STEP-UP", "LATERAL", "SIDE", "SIDE-STEP", "INCLINE",
        "DECLINE", "FLAT", "VERTICAL", "HORIZONTAL", "DIAGONAL", "ROTATIONAL", "TWISTING", "STANDING", "SEATED",
        "SITTING", "LYING", "PRONE", "SUPINE", "SIDE-LYING", "KNEELING", "HALF-KNEELING", "LEANING",
        "FORWARD-LEANING", "PARTIAL", "FULL", "HALF", "HOLD", "ISOMETRIC", "ISO", "STATIC-HOLD")

    # Order in which equipment is preferred as a group's primary equipment.
    # "SMITH MACHINE" precedes "MACHINE": any name containing the former also
    # contains the latter, so the reverse order would never select it.
    EQUIPMENT_PRIORITY = (
        "BARBELL", "DUMBBELL", "KETTLEBELL", "EZ BAR", "EZ-BAR", "EZBAR",
        "T-BAR", "T BAR", "HEX BAR", "CABLE", "SMITH MACHINE", "MACHINE")

    # (feature key, repetitions) used to weight features in the embedding text
    FEATURE_WEIGHTS = (("equipment", 3), ("movement", 2), ("modifiers", 2), ("body_parts", 1))

    # names containing any of these markers are dropped before clustering
    INVALID_NAME_MARKERS = ("REST", "JOGGING", "WALKING", "CYCLING")

    def __init__(self, model_name="all-MiniLM-L6-v2"):
        """
        Initialize with a sentence transformer model optimized for semantic similarity
        """
        self.model = SentenceTransformer(model_name)

    @staticmethod
    def _find_terms(name_upper: str, terms) -> set:
        """Return the terms that occur as substrings of the upper-cased name."""
        return {term for term in terms if term in name_upper}

    @staticmethod
    def _most_specific(terms) -> str:
        """Pick the longest term (ties broken alphabetically).

        Gives a deterministic choice from a set; the longest match is the
        most specific one (e.g. "PULL UP" over "PULL").
        """
        return max(terms, key=lambda term: (len(term), term))

    @staticmethod
    def _lemmatize_upper(text: str) -> str:
        """Lemmatize *text*, drop stopwords and return the tokens upper-cased."""
        tagged = pos_tag(word_tokenize(text.lower()))
        lemmas = [
            lemmatizer.lemmatize(word, get_wordnet_pos(pos))
            for word, pos in tagged
            if word.upper() not in STOPWORDS]
        return ' '.join(lemma.upper() for lemma in lemmas)

    def preprocess_exercise_name(self, name: str) -> str:
        """
        Preprocess exercise name with lemmatization and stopword removal

        Separators (- _ / \\ |) become spaces, remaining punctuation is
        removed and whitespace is collapsed before lemmatizing.
        """
        name = name.upper()
        name = re.sub(r'[-_/\\|]+', ' ', name)
        name = re.sub(r'[^\w\s]', '', name)
        name = re.sub(r'\s+', ' ', name).strip()
        return self._lemmatize_upper(name)

    def extract_exercise_features(self, exercise_name: str) -> Dict:
        """
        Extract key features that distinguish exercises

        Returns a dict with the sets "equipment", "movement", "body_parts"
        and "modifiers" (vocabulary terms found in the upper-cased name) plus
        "raw_name" (the name as passed in).
        """
        name_upper = exercise_name.upper()
        return {
            "equipment": self._find_terms(name_upper, self.EQUIPMENT_TERMS),
            "movement": self._find_terms(name_upper, self.MOVEMENT_TERMS),
            "body_parts": self._find_terms(name_upper, self.BODY_PART_TERMS),
            "modifiers": self._find_terms(name_upper, self.MODIFIER_TERMS),
            "raw_name": exercise_name}

    def create_feature_aware_embeddings(self, exercise_names: List[str]) -> np.ndarray:
        """
        Create embeddings that emphasize critical distinguishing features

        Each name is turned into its preprocessed form followed by its
        feature terms, repeated according to FEATURE_WEIGHTS (equipment 3x,
        movement and modifiers 2x, body parts 1x).
        """
        enhanced_texts = []

        for name in exercise_names:
            features = self.extract_exercise_features(name)
            enhanced_parts = [self.preprocess_exercise_name(name)]

            for feature_key, weight in self.FEATURE_WEIGHTS:
                # sorted: set iteration order varies between interpreter runs,
                # which would change the text (and embedding) of the same name
                for term in sorted(features[feature_key]):
                    enhanced_parts.extend([term] * weight)

            enhanced_texts.append(" ".join(enhanced_parts))

        return self.model.encode(enhanced_texts)

    def feature_based_pre_clustering(self, exercise_names: List[str]) -> Dict[str, List[str]]:
        """
        Pre-cluster exercises by equipment type to create more balanced groups

        The group key is "<primary equipment>_<primary movement>", or just the
        primary equipment when no movement term is found. Names without any
        equipment term use "BODYWEIGHT" as equipment.
        """
        equipment_groups = {}

        for name in exercise_names:
            features = self.extract_exercise_features(name)
            found_equipment = features["equipment"]

            if found_equipment:
                # first entry of the priority list wins; otherwise fall back
                # to a deterministic pick among the terms that were found
                primary_equipment = next(
                    (equip for equip in self.EQUIPMENT_PRIORITY if equip in found_equipment),
                    None)
                if primary_equipment is None:
                    primary_equipment = self._most_specific(found_equipment)
            else:
                primary_equipment = "BODYWEIGHT"

            # Add movement pattern for finer grouping
            if features["movement"]:
                primary_movement = self._most_specific(features["movement"])
                group_key = f"{primary_equipment}_{primary_movement}"
            else:
                group_key = primary_equipment

            equipment_groups.setdefault(group_key, []).append(name)

        return equipment_groups

    def balanced_clustering_within_group(self, exercise_names: List[str], max_cluster_size: int = 25) -> List[List[str]]:
        """
        Perform balanced clustering within a pre-filtered group

        Groups no larger than ``max_cluster_size`` are returned as a single
        cluster. Larger groups are split with agglomerative clustering on the
        feature-aware embeddings; clusters that are still too large are split
        again recursively.
        """
        if len(exercise_names) <= max_cluster_size:
            return [exercise_names]

        # imported here because the file's import block only brings in DBSCAN
        from sklearn.cluster import AgglomerativeClustering

        # Create embeddings for this group
        embeddings = self.create_feature_aware_embeddings(exercise_names)

        # Calculate initial number of clusters (ceiling division)
        n_clusters = max(1, (len(exercise_names) + max_cluster_size - 1) // max_cluster_size)

        clustering = AgglomerativeClustering(n_clusters=n_clusters, linkage='average')
        cluster_labels = clustering.fit_predict(embeddings)

        # Group exercises by cluster label, keeping the input order inside each
        members_by_label = {}
        for name, label in zip(exercise_names, cluster_labels):
            members_by_label.setdefault(int(label), []).append(name)

        clusters = []
        for label in sorted(members_by_label):
            cluster_exercises = members_by_label[label]

            # Recursively split if still too large
            if len(cluster_exercises) > max_cluster_size:
                clusters.extend(self.balanced_clustering_within_group(cluster_exercises, max_cluster_size))
            else:
                clusters.append(cluster_exercises)

        return clusters

    def get_base_prompt(self, exercise_names):
        """Build the classification prompt for a list of exercise names."""
        prompt = f"""
        You are an expert exercise classification system. Your task is to analyze a list of exercise names and determine whether they represent the same or distinct exercises.

        For any exercises that refer to the same underlying movement, group them together and provide a single standardized name for the group. 
        If exercises are functionally distinct, group them separately with unique standardized names.

        Use the following strict disambiguation rules:

        RULES FOR CLASSIFICATION:
        1. Exercises using different equipment (e.g., Barbell vs Dumbbell vs EZ Bar) = DIFFERENT exercises.
        2. Exercises with different grips or hand positions (e.g., Wide, Narrow, Neutral, Reverse) = DIFFERENT exercises.
        3. Exercises performed at different body positions or angles (e.g., Incline, Decline, Flat, Overhead) = DIFFERENT exercises.
        4. Variations due only to spelling, punctuation, pluralization, or casing = SAME exercise.
        5. Left/right or single-side references (e.g., "Left Arm", "Right Leg") = SAME exercise.
        6. NEVER use hyphens in your standardized names (e.g., "Pull Up", NOT "Pull-Up")
        7. Do not use acronyms or abbreviations in your standardized names.
        8. Be aware of common acronyms and abbreviations in exercise names (e.g., "OHP" for "Overhead Press", "DB" for "Dumbbell", "BB" for "Barbell") and do not persist acronyms or abbreviations in standardized names.
        9. Do NOT generalize specific exercise names into broader categories (e.g., "Spider Curl" should not be generalized to "Bicep Curl", "Bird Dog" should not be generalized to "Arabesque")

        INSTRUCTIONS:
        - Return output as **valid JSON** only — do not include explanations, markdown, or extra text.
        - Do not combine individual words into single tokens (e.g., avoid malformed outputs like "BENTOVERROWTOEXTERNALROTATION").
        - Ensure standardized names are clean, readable, and semantically meaningful.

        INPUT:
        Exercise names: {exercise_names}

        EXPECTED OUTPUT FORMAT:
        {{
        "groups": [
            {{
            "standardized_name": "STANDARDIZED EXERCISE NAME",
            "variations": ["Variation 1", "Variation 2", "..."]
            }},
            {{
            "standardized_name": "ANOTHER STANDARDIZED EXERCISE NAME",
            "variations": ["Variation A", "Variation B"]
            }}
        ]
        }}

        EXAMPLE OUTPUT:
        {{
        "groups": [
            {{
            "standardized_name": "BARBELL BICEP CURL",
            "variations": ["Barbell Biceps Curl", "Bar Bell Bicep Curls"]
            }},
            {{
            "standardized_name": "EZ BAR BICEP CURL",
            "variations": ["EZ Bar Biceps Curl", "Ez-Bar Bicep Curl"]
            }}
        ]
        }}
        """
        return prompt

    @staticmethod
    def _mapping_from_groups(parsed: Dict) -> Dict[str, str]:
        """Turn validated LLM groups into a variation -> standardized name dict."""
        mapping = {}
        for group in parsed["groups"]:
            # plan for the LLM to hallucinate the standardized name key
            if "standardized_name" in group:
                standardized = group["standardized_name"]
            else:
                standardized = group["standard_name"]
            if not isinstance(standardized, str) or not standardized.strip():
                raise ValueError("standardized name is not a non-empty string")
            for variation in group["variations"]:
                if not isinstance(variation, str):
                    raise ValueError("variation is not a string")
                mapping[variation] = standardized
        return mapping

    def llm_validate_cluster(self, exercise_names: List[str]) -> Dict[str, str]:
        """
        use LLM to validate and standardize a cluster of exercise names

        The prompt is retried up to five times, with a stricter reminder
        appended on each retry (see ``escalate_prompt``).

        Returns:
            Mapping from each variation named by the LLM to its standardized
            name.

        Raises:
            Exception: when no attempt produced a valid response; the error of
                the last attempt is attached as the cause.
        """
        base_prompt = self.get_base_prompt(exercise_names)
        max_attempts = 5
        last_error = None

        for attempt in range(max_attempts):
            response = None
            try:
                escalated_prompt = escalate_prompt(base_prompt, attempt)

                # leveraging Databricks foundation model (free tier...)
                generations = gen_text_databricks([escalated_prompt], temperature=0.1, max_new_tokens=5000, use_template=True)
                # work on the generated text itself, not on the repr of the list
                response = str(generations[0])
                response = response.replace("```json", "").replace("```", "").strip()

                parsed = json.loads(repair_json(response))
                validate_json_structure(parsed)

                print("Success:")
                return self._mapping_from_groups(parsed)

            except Exception as e:
                last_error = e
                print(f"\nAttempt {attempt + 1} failed.")
                print("Error:", str(e))
                print("Response:", response[:500] if response is not None else "No response")

        # All attempts failed: raise, keeping the last error as the cause
        raise Exception(f"LLM validation failed after {max_attempts} attempts") from last_error

    def smart_clustering(self, exercise_names: List[str], max_cluster_size: int = 25) -> Dict[str, str]:
        """
        Perform intelligent clustering with balanced cluster sizes

        Returns a mapping from exercise name to standardized name. Names that
        end up alone in a group or cluster, and names of clusters for which
        the LLM validation fails, map to themselves.
        """
        print(f"Starting smart clustering for {len(exercise_names)} exercises...")

        # Step 1: Pre-cluster by equipment/movement to create balanced groups
        equipment_groups = self.feature_based_pre_clustering(exercise_names)
        print(f"Pre-clustering created {len(equipment_groups)} equipment/movement groups")

        final_mapping = {}

        # Step 2: Process each equipment group separately
        for group_name, group_exercises in equipment_groups.items():
            print(f"Processing {group_name} group with {len(group_exercises)} exercises")

            if len(group_exercises) == 1:
                # Single exercise - no clustering needed
                final_mapping[group_exercises[0]] = group_exercises[0]
                continue

            # Step 3: Balanced clustering within each group
            balanced_clusters = self.balanced_clustering_within_group(group_exercises, max_cluster_size)

            # Step 4: LLM validation for each balanced cluster
            for cluster in balanced_clusters:
                print(f"  Processing cluster with {len(cluster)} exercises")

                if len(cluster) == 1:
                    final_mapping[cluster[0]] = cluster[0]
                    continue

                try:
                    final_mapping.update(self.llm_validate_cluster(cluster))
                except Exception:
                    print(f"  Warning: LLM validation failed for cluster, treating as individual exercises")
                    # Fallback: treat each as individual exercise
                    for exercise in cluster:
                        final_mapping[exercise] = exercise

        print(f"Smart clustering complete. Processed {len(final_mapping)} exercises")
        return final_mapping

    def standardize_exercise_dataset(self, exercise_names: List[str]) -> Tuple[Dict[str, str], pd.DataFrame]:
        """
        Complete standardization pipeline

        Returns:
            A tuple ``(mapping, review_df)``: ``mapping`` maps each name to
            its lemmatized, hyphen-free standardized name; ``review_df`` has
            one row per standardized name with the columns standardized_name,
            original_count and original_names, sorted by original_count in
            descending order.
        """
        print("Starting exercise standardization pipeline...")

        # Remove obvious invalid exercises
        valid_exercises = [
            name for name in exercise_names
            if not any(marker in name.upper() for marker in self.INVALID_NAME_MARKERS)]

        print(f"Filtered out {len(exercise_names) - len(valid_exercises)} invalid exercises")

        # Perform smart clustering
        raw_mapping = self.smart_clustering(valid_exercises)

        # update mapping with lemmatized version
        mapping = {
            original: self._lemmatize_upper(standardized).replace('-', ' ')
            for original, standardized in raw_mapping.items()}

        # Group by standardized name for review
        groups = {}
        for original, standardized in mapping.items():
            groups.setdefault(standardized, []).append(original)

        # counted after lemmatization, which can merge names that only
        # differed in inflection or hyphenation
        standardized_names = list(groups)

        # Create DataFrame for review; explicit columns keep the sort working
        # when there is nothing to review
        review_data = [
            {
                "standardized_name": standardized,
                "original_count": len(originals),
                "original_names": " | ".join(originals)}
            for standardized, originals in groups.items()]
        review_df = pd.DataFrame(
            review_data,
            columns=["standardized_name", "original_count", "original_names"],
        ).sort_values("original_count", ascending=False)

        # guard the percentage against an empty set of valid exercises
        if valid_exercises:
            reduction = (len(valid_exercises) - len(standardized_names)) / len(valid_exercises) * 100
        else:
            reduction = 0.0

        print(f"Standardization complete:")
        print(f"  Original exercises: {len(exercise_names)}")
        print(f"  Valid exercises: {len(valid_exercises)}")
        print(f"  Standardized exercises: {len(standardized_names)}")
        print(f"  Reduction: {reduction:.1f}%")

        return mapping, review_df

# implementation
def apply_standardization(df, exercise_column="name"):
    """
    Apply standardization to your exercise DataFrame using broadcast join

    Args:
        df: Spark DataFrame holding the exercise names.
        exercise_column: Name of the column that contains the exercise names;
            it is also the join key.

    Returns:
        A tuple ``(df, mapping, review_df)``: the input DataFrame with a
        ``standardized_name`` column added (rows without a standardized name
        are dropped), the name mapping, and the review DataFrame produced by
        ``standardize_exercise_dataset``.
    """
    standardizer = AdvancedExerciseStandardizer()
    # index rows by column name so any column name works as exercise_column
    unique_exercises = [
        row[exercise_column]
        for row in df.select(exercise_column).distinct().collect()
        if row[exercise_column] is not None]
    mapping, review_df = standardizer.standardize_exercise_dataset(unique_exercises)

    print("Type of mapping:", type(mapping))
    print("First few items:", list(mapping.items())[:5] if hasattr(mapping, "items") else mapping.head())

    # mapping df containing original name and standardized name
    if isinstance(mapping, dict):
        session = SparkSession.getActiveSession()
        # explicit schema: Spark cannot infer column types from an empty mapping
        mapping_schema = StructType([
            StructField(exercise_column, StringType()),
            StructField("standardized_name", StringType())])
        mapping_df = session.createDataFrame(list(mapping.items()), schema=mapping_schema)

        # can safely use a broadcast join here, mapping_df is quite small (< 500 records)
        df = df.join(broadcast(mapping_df), on=exercise_column, how="left")
    else:
        # presumably a Spark DataFrame that already has the join column - verify against caller
        df = df.join(broadcast(mapping), on=exercise_column, how="left")

    # drop records where a standardized name was not successfully generated
    df = df.dropna(subset=["standardized_name"])

    return df, mapping, review_df

# run the gen AI name standardization over the exercise table, then show each
# original name next to its standardized name
(standardized_df, name_mapping, review_report) = apply_standardization(exercise_df)
display(
    standardized_df
    .select("name", "standardized_name")
    .orderBy("standardized_name")
)