In [0]:
# imports
import requests
import time
from datetime import datetime, timedelta
from dateutil.relativedelta import relativedelta
from pyspark.sql import SparkSession, Row
from pyspark.sql.functions import col, concat_ws, regexp_replace, to_date, trim, lit, when, length, udf, broadcast
from pyspark.sql.types import StructType, StructField, StringType, TimestampType, IntegerType, FloatType, ArrayType
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass
from functools import partial
from collections import defaultdict
import sys
import traceback
import json
from langdetect import detect
import re
import os
import numpy as np
import hdbscan
from sentence_transformers import SentenceTransformer
from sklearn.cluster import DBSCAN
from sklearn.cluster import AgglomerativeClustering
from sklearn.preprocessing import normalize
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import pairwise_distances
from scipy.spatial.distance import squareform
from typing import List, Dict, Tuple, Set, Union
import pandas as pd
import ast
from json_repair import repair_json
from nltk.stem import WordNetLemmatizer
from nltk.corpus import wordnet
from nltk import pos_tag, word_tokenize, punkt
from nltk.tokenize import word_tokenize
import nltk 

def get_exercises(api_url, timeout=30):
    """Retrieve every exercise record from the paginated wger public API.

    Follows the ``next`` link of each response page until the API reports
    that there is no further page.

    Args:
        api_url: URL of the first page to request.
        timeout: Seconds to wait for each HTTP response before giving up.

    Returns:
        A list holding the ``results`` entries of every page. When any page
        fails to download or decode, the error is printed and an empty list
        is returned, so callers can always iterate the result and never
        mistake a partial download for the complete catalogue.
    """
    all_exercises = []
    current_url = api_url

    try:
        while current_url:
            # Without a timeout a stalled connection would block the notebook forever.
            response = requests.get(current_url, timeout=timeout)
            response.raise_for_status()
            data = response.json()
            all_exercises.extend(data["results"])
            current_url = data.get("next")
    except Exception as e:
        print("error fetching exercise data from wger API:", e)
        return []

    return all_exercises

def save_data(df, table_name: str, mode: str = "overwrite"):
    """Persist a PySpark DataFrame as a Delta table.

    Args:
        df: DataFrame to write.
        table_name: Name of the target table.
        mode: Save mode handed to the writer; with "overwrite" the existing
            table schema is allowed to be replaced as well.
    """
    delta_writer = df.write.format("delta").mode(mode)
    replaces_schema = mode == "overwrite"
    target_writer = (
        delta_writer.option("overwriteSchema", "true") if replaces_schema else delta_writer
    )
    target_writer.saveAsTable(table_name)

def is_english_content(description):
    """Report whether a description is written in English.

    The language is detected on the first 100 characters of the visible
    text, i.e. after HTML tags and surrounding whitespace are removed, so
    markup never uses up part of the sample.

    Args:
        description: Raw description text, possibly containing HTML. May be
            None or empty.

    Returns:
        True when the detected language is "en" or when there is no text to
        inspect; False when another language is detected or detection fails.
    """
    if not description:
        return True

    try:
        # Tags become spaces so words from adjacent elements are not glued together.
        visible_text = re.sub(r"<[^>]*>", " ", description).strip()
        if not visible_text:
            return True
        return detect(visible_text[:100]) == "en"
    except Exception:
        # if detection fails, assume not english
        return False
    
def _parse_iso_timestamp(value):
    """Parse an ISO-8601 timestamp string; return None when absent or malformed."""
    if not value:
        return None
    try:
        # A trailing "Z" is rewritten as an explicit UTC offset before parsing.
        return datetime.fromisoformat(value.replace("Z", "+00:00"))
    except (AttributeError, ValueError):
        return None


def _entry_names(entries):
    """Return the "name" of every entry of a nested list that may be missing or null."""
    return [entry["name"] for entry in entries or []]


def normalize_record(record):
    """Flatten one wger ``exerciseinfo`` record into a plain dictionary.

    The name and description come from the first translation flagged with
    language id 2 (English) whose description also passes
    ``is_english_content``. HTML tags are stripped from both fields and the
    name is upper-cased.

    Args:
        record: One decoded record from the API response.

    Returns:
        A dictionary with the keys id, uuid, name, description, created,
        last_update, category, muscles, muscles_secondary, equipment,
        variations and license_author, or None when the record has no usable
        English translation. Missing or null nested fields yield None or an
        empty list and do not raise; unparseable timestamps yield None.
    """
    name = None
    description = None

    # "or []" also covers an explicit null, not only a missing key.
    for translation in record.get("translations") or []:
        if translation.get("language") != 2:
            continue
        candidate_description = translation.get("description")
        # verify the content is english
        if is_english_content(candidate_description):
            name = translation.get("name")
            description = candidate_description
            break

    # return None if no English content found
    # this is a data quality fix;
    # there is a known issue in wger API that results in non-english records despite language = 2 filtering
    if not name:
        return None

    # data cleaning step: remove HTML tags from name and description
    name = re.sub(r"<[^>]*>", "", name).strip()
    if description:
        description = re.sub(r"<[^>]*>", "", description).strip()

    category = record.get("category") or {}

    return {
        "id": record.get("id"),
        "uuid": record.get("uuid"),
        "name": name.upper(),
        "description": description,
        "created": _parse_iso_timestamp(record.get("created")),
        "last_update": _parse_iso_timestamp(record.get("last_update")),
        "category": category.get("name"),
        "muscles": _entry_names(record.get("muscles")),
        "muscles_secondary": _entry_names(record.get("muscles_secondary")),
        "equipment": _entry_names(record.get("equipment")),
        "variations": record.get("variations") or [],
        "license_author": record.get("license_author"),
    }

def query_databricks_foundation_model(prompt, model_name="databricks-llama-4-maverick", max_tokens=512, temperature=0.5, timeout=300):
    """Query a Databricks foundation model through its serving endpoint.

    The endpoint is called with up to three request shapes, in this order:
    chat ``messages``, plain ``prompt`` and plain ``input``. The first one
    that returns a successful, JSON-decodable response wins.

    Args:
        prompt: Text sent to the model.
        model_name: Name of the serving endpoint.
        max_tokens: Generation limit forwarded in the payload.
        temperature: Sampling temperature forwarded in the payload.
        timeout: Seconds to wait for each HTTP response before the attempt
            counts as failed.

    Returns:
        The decoded JSON response, or None when every payload format failed.
    """
    # construct the API endpoint URL
    api_url = f"https://{DATABRICKS_INSTANCE}/serving-endpoints/{model_name}/invocations"

    headers = {
        "Authorization": f"Bearer {DATABRICKS_TOKEN}",
        "Content-Type": "application/json"}

    generation_params = {"max_tokens": max_tokens, "temperature": temperature}
    payloads = [
        ("messages", {"messages": [{"role": "user", "content": prompt}], **generation_params}),
        ("prompt", {"prompt": prompt, **generation_params}),
        ("input", {"input": prompt, **generation_params}),
    ]

    # attempt each format above
    for payload_name, payload in payloads:
        try:
            print(f"Trying payload format: {payload_name}")
            # Without a timeout a hung endpoint would block the notebook indefinitely.
            response = requests.post(api_url, headers=headers, json=payload, timeout=timeout)
            response.raise_for_status()
            result = response.json()
        except requests.exceptions.RequestException as e:
            print(f"Failed with {payload_name} format: {e}")
            if getattr(e, "response", None) is not None:
                print(f"Response content: {e.response.text}")
            continue
        except ValueError as e:
            # Body was not valid JSON; treat it like any other failed attempt.
            print(f"Failed with {payload_name} format: {e}")
            continue

        print(f"Success with {payload_name} format!")
        return result

    print("all payload formats failed")
    return None

# Alternative approach using the newer Databricks SDK
def query_with_databricks_sdk(prompt, model_name="databricks-llama-4-maverick", max_tokens=5000, temperature=0.5):
    """Query a serving endpoint through the Databricks SDK, if it is installed.

    Args:
        prompt: Text sent to the model as a single user chat message.
        model_name: Name of the serving endpoint.
        max_tokens: Generation limit.
        temperature: Sampling temperature.

    Returns:
        The SDK response object, or None when the SDK cannot be imported or
        the call fails (the reason is printed in both cases).
    """
    try:
        from databricks.sdk import WorkspaceClient
        from databricks.sdk.service.serving import ChatMessage, ChatMessageRole
    except ImportError:
        print("Databricks SDK not available, use the API approach instead")
        return None

    try:
        w = WorkspaceClient()
        # The SDK serializes each message itself, so it needs ChatMessage
        # objects rather than plain dictionaries.
        return w.serving_endpoints.query(
            name=model_name,
            messages=[ChatMessage(role=ChatMessageRole.USER, content=prompt)],
            max_tokens=max_tokens,
            temperature=temperature,
        )
    except Exception as e:
        print(f"Error with SDK approach: {e}")
        return None

def format_prompt(instruction, system_prompt=None):
    """
    Wrap an instruction and a system prompt in the [INST] ... [/INST] template.
    When no system prompt is passed, a JSON-only exercise standardization
    system prompt is used.
    Based on documentation here:
    https://github.com/databricks/databricks-ml-examples/blob/master/llm-models/mistral/mistral-7b/01_load_inference.py
    """
    default_system_prompt = """You are a JSON-only exercise standardization assistant. 
        ALWAYS respond with valid JSON in this exact format:
        {"groups":[{"standardized_name":"NAME","variations":["var1","var2"]}]}
        Never include explanations, markdown, misspellings, or extra text."""

    # Only None selects the default; an empty string is used as given.
    active_system_prompt = default_system_prompt if system_prompt is None else system_prompt

    return f"[INST]<>{active_system_prompt}<>{instruction}[/INST]"

def _extract_generated_text(response):
    """Pull the generated text out of a decoded serving-endpoint response."""
    if "choices" in response and response["choices"]:
        # OpenAI-style response
        first_choice = response["choices"][0]
        return first_choice.get("message", {}).get("content", "") or first_choice.get("text", "")
    if "predictions" in response:
        # MLflow-style response
        return response["predictions"][0].get("generated_text", "")
    if "candidates" in response:
        # Gemini-style response
        return response["candidates"][0].get("content", {}).get("parts", [{}])[0].get("text", "")
    # Unknown shape: fall back to the textual form of the whole response
    return str(response)


# Entry point for text generation against the Databricks serving endpoint
def gen_text_databricks(prompts, use_template=True, **kwargs):
    """Generate one text per prompt using a Databricks foundation model.

    Each prompt is sent through the REST helper first; when that yields
    nothing, the SDK helper is tried as a fallback.

    Args:
        prompts: Iterable of prompt strings.
        use_template: When true, each prompt is wrapped with ``format_prompt``.
        **kwargs: ``max_new_tokens`` (default 5000) and ``temperature``
            (default 0.5) are forwarded to the model call.

    Returns:
        A list with one string per prompt, in order. A prompt for which both
        approaches fail yields "Error: Could not get response from model".
    """
    max_tokens = kwargs.get("max_new_tokens", 5000)
    temperature = kwargs.get("temperature", 0.5)
    results = []

    for prompt in prompts:
        formatted_prompt = format_prompt(prompt) if use_template else prompt

        # First try the API approach
        response = query_databricks_foundation_model(
            formatted_prompt,
            max_tokens=max_tokens,
            temperature=temperature,
        )
        if response:
            results.append(_extract_generated_text(response))
            continue

        # Try SDK approach as fallback
        sdk_response = query_with_databricks_sdk(
            formatted_prompt,
            max_tokens=max_tokens,
            temperature=temperature,
        )
        if not sdk_response:
            results.append("Error: Could not get response from model")
            continue

        # Prefer the generated text over the repr of the SDK object; the repr
        # is kept only as a last resort when the object cannot be unpacked.
        to_dict = getattr(sdk_response, "as_dict", None)
        try:
            text = _extract_generated_text(to_dict()) if callable(to_dict) else str(sdk_response)
        except (AttributeError, IndexError, KeyError, TypeError):
            text = str(sdk_response)
        results.append(text)

    return results


# Workspace credentials used by query_databricks_foundation_model.
# `dbutils` and `spark` are not defined in this file — presumably supplied by the Databricks notebook runtime.
DATABRICKS_TOKEN = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().getOrElse(None)
DATABRICKS_INSTANCE = spark.conf.get("spark.databricks.workspaceUrl")

# documentation link: https://exercise.hellogym.io/nl/software/api
# the exerciseinfo endpoint returns denormalized records with nested fields
# (translations, category, muscles, equipment) that normalize_record flattens
wger_api_url = "https://wger.de/api/v2/exerciseinfo/?status=2&language=2"
# get_exercises prints the error and yields nothing usable when the download
# fails; fall back to an empty list so the steps below still run.
exercises = get_exercises(wger_api_url) or []

# NOTE(review): normalize_record produces Python lists for muscles,
# muscles_secondary, equipment and variations while these columns are declared
# as strings — confirm the resulting string rendering is what downstream expects.
ex_schema = StructType([
    StructField("id", IntegerType()),
    StructField("uuid", StringType()), 
    StructField("name", StringType()), 
    StructField("description", StringType()),
    StructField("created", TimestampType()),
    StructField("last_update", TimestampType()),
    StructField("category", StringType()),
    StructField("muscles",  StringType()),
    StructField("muscles_secondary",  StringType()),
    StructField("equipment",  StringType()),
    StructField("variations", StringType()),
    StructField("license_author", StringType())])
    
# normalize every record; normalize_record returns None for non-English
# records, which are dropped here for data quality
exercise_abbv = [
    normalized
    for normalized in map(normalize_record, exercises)
    if normalized is not None]

try:
    exercise_df = spark.createDataFrame(exercise_abbv, schema=ex_schema)
    exercise_df = exercise_df.drop("uuid", "created", "last_update", "license_author")
    display(exercise_df.sort("name"))
except Exception as e:
    print("error creating dataframe:", e)
    traceback.print_exc()

In [0]:

# Tokenizer, part-of-speech tagger and WordNet data needed by the name preprocessing below.
for _nltk_resource in (
    "punkt",
    "averaged_perceptron_tagger",
    "punkt_tab",
    "averaged_perceptron_tagger_eng",
    "wordnet",
):
    nltk.download(_nltk_resource)

# Approach: exercises are clustered BEFORE the LLM is prompted, so that each prompt
#   only contains a small, tightly scoped group of similar exercise names.
# - Less ambiguity per prompt improves the chance of converging on the correct
#   atomic standardized naming convention.
# - Fewer names per prompt means fewer tokens sent to the LLM, and therefore lower cost.
# - Domain knowledge about equipment, movement patterns and muscle groups is used
#   both to improve the clustering and to give the LLM context.
# - Small, semantically meaningful groups are easier to validate or correct by hand.

lemmatizer = WordNetLemmatizer()
STOPWORDS = {
    "WITH", "AND", "ON", "INTO",
    "IN", "AT", "BY", "OF",
    "FOR", "TO", "FROM", "USING",
    "THE", "A", "AN", "ONTO",
}

def validate_json_structure(result):
    """Check that a parsed LLM answer has the expected grouping structure.

    The expected shape is a dict with a "groups" list whose items are dicts
    carrying a standardized name (key "standardized_name" or "standard_name")
    and a "variations" list.

    Returns:
        True when the structure is valid.

    Raises:
        ValueError: describing the first structural problem encountered.
    """
    name_keys = ("standardized_name", "standard_name")

    def first_problem():
        # Checks run top-down; the first failing one decides the message.
        if not isinstance(result, dict):
            return "Result is not a dictionary"
        if "groups" not in result:
            return "Missing 'groups' key"
        groups = result["groups"]
        if not isinstance(groups, list):
            return "'groups' is not a list"

        for index, entry in enumerate(groups):
            if not isinstance(entry, dict):
                return f"Group {index} is not a dictionary"
            if all(key not in entry for key in name_keys):
                return f"Group {index} missing standardized name key"
            if "variations" not in entry:
                return f"Group {index} missing 'variations' key"
            if not isinstance(entry["variations"], list):
                return f"Group {index} 'variations' is not a list"
        return None

    problem = first_problem()
    if problem is not None:
        raise ValueError(problem)
    return True

def escalate_prompt(base_prompt, level, feedback=None):
    """Append increasingly strict instructions to a prompt that failed validation.

    Args:
        base_prompt: Prompt text to extend.
        level: Attempt number. 1 and 2 each add a specific reminder, 3 and
            above add a final warning, any other value adds nothing.
        feedback: Optional description of the previous validation failure;
            when truthy it is inserted ahead of the level-specific text.

    Returns:
        The base prompt followed by the applicable additions.
    """
    reminders = {
        1: "\nSTRICTLY enforce JSON compliance. Only output a parsable JSON object.",
        2: "\nYour previous response was invalid. Ensure the response is a clean JSON object with no extra formatting or text.",
    }
    final_warning = "\nFinal warning: Return ONLY a valid JSON object. No natural language, no markdown, no surrounding explanations."

    additions = []
    if feedback:
        additions.append(f"\nNOTE: Your previous output failed validation because: {feedback}")
        additions.append("\nCarefully reprocess the input and correct this mistake.")

    if level in reminders:
        additions.append(reminders[level])
    elif level >= 3:
        additions.append(final_warning)

    if not additions:
        return base_prompt
    return base_prompt + "".join(additions)

def get_wordnet_pos(tag):
    """Translate an nltk POS tag into the matching WordNet POS constant.

    Tags starting with J, V, N or R map to adjective, verb, noun and adverb
    respectively; every other tag falls back to noun.
    """
    prefix_to_pos = (
        ('J', wordnet.ADJ),
        ('V', wordnet.VERB),
        ('N', wordnet.NOUN),
        ('R', wordnet.ADV),
    )
    for prefix, wordnet_pos in prefix_to_pos:
        if tag.startswith(prefix):
            return wordnet_pos
    return wordnet.NOUN

class AdvancedExerciseStandardizer:
    def __init__(self, model_name="all-MiniLM-L6-v2"):
        """
        Load the sentence-transformer model used to embed exercise names.

        Args:
            model_name: Model identifier handed to ``SentenceTransformer``.
        """
        encoder = SentenceTransformer(model_name)
        self.model = encoder
        
    def preprocess_exercise_name(self, name: str) -> str:
        """
        Normalize an exercise name before it is embedded.

        Separators (- _ / \\ |) become spaces, remaining punctuation is
        removed and whitespace is collapsed. The cleaned text is then
        tokenized, POS-tagged and lemmatized; tokens listed in STOPWORDS are
        dropped. The surviving lemmas are returned upper-cased and joined by
        single spaces.
        """
        cleaned = name.upper()
        substitutions = (
            (r'[-_/\\|]+', ' '),   # separators -> space
            (r'[^\w\s]', ''),      # drop remaining punctuation
            (r'\s+', ' '),         # collapse whitespace runs
        )
        for pattern, replacement in substitutions:
            cleaned = re.sub(pattern, replacement, cleaned)
        cleaned = cleaned.strip()

        kept_lemmas = []
        # Tokenizing and tagging operate on the lower-cased text.
        for token, tag in pos_tag(word_tokenize(cleaned.lower())):
            if token.upper() in STOPWORDS:
                continue
            lemma = lemmatizer.lemmatize(token, get_wordnet_pos(tag))
            kept_lemmas.append(lemma.upper())

        return ' '.join(kept_lemmas)

    def extract_exercise_features(self, exercise_name: str) -> Dict:
        """
        Extract the domain keywords (equipment, movement, body part,
        modifier) mentioned in an exercise name.

        A keyword counts as mentioned when it occurs at the START of a word
        of the upper-cased name. Trailing characters are allowed, so plurals
        and inflections still match ("CURLS" -> "CURL"), while keywords
        buried inside another word no longer do ("CHIN" in "MACHINE", "AB" in
        "CABLE", "ROW" in "NARROW").

        Args:
            exercise_name: Exercise name in any casing.

        Returns:
            Dict with the keys "equipment", "movement", "body_parts" and
            "modifiers" (each a set of matched keywords) plus "raw_name"
            (the name exactly as passed in).
        """
        name_upper = exercise_name.upper()

        equipment_list = [
            "BARBELL", "DUMBBELL", "KETTLEBELL",
            "EZ BAR", "EZ-BAR", "EZBAR", "T-BAR", "T BAR", "HEX", "HEX BAR",
            "CABLE", "ROPE", "V-BAR",
            "HANDLE", "SINGLE HANDLE", "SMITH",
            "PREACHER", "HACK", "TRX", "BAG",
            "BAND", "MINI BAND", "THERABAND",
            "BOX", "BALL"]

        movement_list = [
            "SQUAT", "DEADLIFT", "LUNGE", "STEP-UP",
            "PUSH", "PRESS", "ROW",
            "DIP", "CURL", "EXTENSION", "RAISE", "FLY", "FLYE",
            "PULLOVER", "SHRUG",
            "CRUNCH", "PLANK", "HOLD", "BRIDGE", "BEND",
            "SPRINT", "KICKBACK",
            "THRUST", "TAP", "THROWS",
            "CLIMB", "CARRY", "WALK", "THROW"]

        body_parts_list = [
            "NECK", "TRAP", "SHOULDER", "DELTOID", "DELT",
            "BICEP", "BI", "TRICEP", "TRI",
            "FOREARM", "CHEST", "PECTORAL", "PEC",
            "BACK", "LAT", "UPPER BACK", "LOWER BACK",
            "ABS", "ABDOMINALS", "AB", "OBLIQUE", "CORE",
            "GLUTE", "BUTT", "HIP", "HIP FLEXOR",
            "ADDUCTOR", "ABDUCTOR", "QUAD", "QUADRICEP", "HAMSTRING", "HAM",
            "CALF", "LEG", "ARM", "CHIN"]

        # Same keywords as before with the repeated entries removed.
        modifiers_list = [
            "WIDE", "WIDE-GRIP", "NARROW", "NARROW-GRIP", "CLOSE", "CLOSE-GRIP", "CROSSBODY", "CROSS-BODY",
            "HAMMER", "NEUTRAL-GRIP", "OVERHAND", "UNDERHAND", "SUPINATED", "PRONATED",
            "REVERSE", "MIXED-GRIP", "SINGLE-ARM", "ONE-ARM", "ONE-ARMED", "SINGLE-LEG", "ONE-LEG", "ONE-LEGGED",
            "UNILATERAL", "ALTERNATING", "ALTERNATE", "ISOLATED", "ISOLATION", "CROSSOVER", "BILATERAL",
            "DOUBLE-ARM", "DOUBLE-LEG", "SUMO", "WIDE-STANCE", "NARROW-STANCE", "CLOSE-STANCE", "SPLIT", "BULGARIAN",
            "STAGGERED", "STEP-BACK", "STEP-UP", "LATERAL", "SIDE", "SIDE-STEP", "INCLINE",
            "DECLINE", "FLAT", "VERTICAL", "HORIZONTAL", "DIAGONAL", "ROTATIONAL", "TWISTING", "STANDING", "SEATED",
            "SITTING", "LYING", "PRONE", "SUPINE", "SIDE-LYING", "KNEELING", "HALF-KNEELING", "LEANING",
            "FORWARD-LEANING", "PARTIAL", "FULL", "HALF", "HOLD", "ISOMETRIC", "ISO", "STATIC-HOLD"]

        def mentioned(keywords):
            # "(?<![^\W_])" rejects a hit that is directly preceded by a
            # letter or digit, i.e. one that starts in the middle of a word.
            return {
                keyword
                for keyword in keywords
                if re.search(r"(?<![^\W_])" + re.escape(keyword), name_upper)
            }

        return {
            "equipment": mentioned(equipment_list),
            "movement": mentioned(movement_list),
            "body_parts": mentioned(body_parts_list),
            "modifiers": mentioned(modifiers_list),
            "raw_name": exercise_name}

    def create_feature_aware_embeddings(self, exercise_names: List[str]) -> np.ndarray:
        """
        Embed exercise names with their distinguishing keywords emphasized.

        For every name the text to embed is the preprocessed name followed by
        its extracted keywords, each repeated according to its category:
        equipment 4x, movement 3x, modifiers 2x, body parts 1x. Repetition is
        the weighting mechanism — the more often a keyword appears in the
        text, the more it contributes to the embedding.

        Keywords are appended in sorted order so that the generated text, and
        with it the embedding, is identical from run to run.

        Args:
            exercise_names: Names to embed.

        Returns:
            The result of ``self.model.encode`` for the enhanced texts, in
            input order.
        """
        # (feature key, number of repetitions), most distinguishing first
        feature_weights = (
            ("equipment", 4),
            ("movement", 3),
            ("modifiers", 2),
            ("body_parts", 1),
        )

        enhanced_texts = []
        for name in exercise_names:
            features = self.extract_exercise_features(name)
            # Start from the preprocessed name (lemmatized, stopwords removed)
            enhanced_parts = [self.preprocess_exercise_name(name)]

            for feature_key, repetitions in feature_weights:
                # Iteration order of a set of strings changes between
                # interpreter runs; sorting keeps the text reproducible.
                for keyword in sorted(features[feature_key]):
                    enhanced_parts.extend([keyword] * repetitions)

            enhanced_texts.append(" ".join(enhanced_parts))

        return self.model.encode(enhanced_texts)

    def feature_based_pre_clustering(self, exercise_names: List[str]) -> Dict[str, List[str]]:
        """
        Pre-cluster exercises by equipment type to create more balanced groups. 
        """
        equipment_groups = {}
        
        for name in exercise_names:
            features = self.extract_exercise_features(name)
            
            # Create equipment signature (primary equipment)
            if features["equipment"]:
                # Prioritize more specific equipment
                equipment_priority = [
                    "BARBELL", "DUMBBELL", "KETTLEBELL", "EZ BAR", "EZ-BAR", "EZBAR",
                    "T-BAR", "T BAR", "HEX BAR", "CABLE", "MACHINE", "SMITH MACHINE"]
                
                primary_equipment = None
                for equip in equipment_priority:
                    if equip in features["equipment"]:
                        primary_equipment = equip
                        break
                
                if not primary_equipment:
                    primary_equipment = next(iter(features["equipment"]))
            else:
                primary_equipment = "BODYWEIGHT"
            
            # Add movement pattern for finer grouping
            if features["movement"]:
                primary_movement = next(iter(features["movement"]))
                group_key = f"{primary_equipment}_{primary_movement}"
            else:
                group_key = primary_equipment
            
            if group_key not in equipment_groups:
                equipment_groups[group_key] = []
            equipment_groups[group_key].append(name)
        
        return equipment_groups

    def hdbscan_clustering_within_group(self, exercise_names: List[str], 
                                        min_cluster_size: int = 2, min_samples: int = None) -> List[List[str]]:
        """
        Cluster the names of one pre-cluster group with HDBSCAN.

        The feature-aware embeddings are L2-normalized and clustered with the
        euclidean metric. Names that HDBSCAN labels as noise (-1) are left
        out of the result. Groups with at most ``min_cluster_size`` names are
        returned unchanged as a single cluster.

        Args:
            exercise_names: Names to cluster.
            min_cluster_size: Minimum cluster size passed to HDBSCAN.
            min_samples: HDBSCAN ``min_samples``; a falsy value means
                ``min_cluster_size`` is used instead.

        Returns:
            A list of clusters, each a list of names.
        """
        if not len(exercise_names) > min_cluster_size:
            return [exercise_names]

        # Unit-length embeddings
        unit_vectors = normalize(
            self.create_feature_aware_embeddings(exercise_names), norm='l2')

        hdbscan_options = {
            "metric": 'euclidean',
            "min_cluster_size": min_cluster_size,
            "min_samples": min_samples or min_cluster_size,
            "cluster_selection_method": 'eom',
        }
        assignments = hdbscan.HDBSCAN(**hdbscan_options).fit_predict(unit_vectors)

        # Collect names per cluster id, leaving out noise (-1)
        clusters = {}
        for exercise, cluster_id in zip(exercise_names, assignments):
            if cluster_id == -1:
                continue
            clusters.setdefault(cluster_id, []).append(exercise)

        return list(clusters.values())

    def get_base_prompt(self, exercise_names):
        """
        Build the LLM prompt that asks for equivalent exercise names to be grouped.

        The prompt lists the disambiguation and output rules, embeds
        ``exercise_names`` (formatted the way an f-string renders the object)
        in the INPUT section, and ends with the expected JSON layout: a
        "groups" list whose items carry "standardized_name" and "variations".

        NOTE(review): llm_validate_cluster keys its mapping on the returned
        "variations" strings, so they presumably need to equal the input
        names — rule 10 ("Expand abbreviations in variations") may work
        against that; confirm.

        Args:
            exercise_names: The names to be grouped.

        Returns:
            The prompt text.
        """

        prompt = f"""
        You are a highly accurate exercise classification system.

        Your task is to analyze a list of exercise names and group all equivalent variations under a single standardized name, following strict disambiguation rules. If two exercises differ in function, form, or equipment, they must remain separate.

        OBJECTIVE:
        - Group only truly equivalent exercises together by movement pattern, equipment, grip, and body position.
        - Assign each group a clean, atomic, and standardized name.
        - List all original input names that belong to each group.

        DISAMBIGUATION RULES (strictly enforce):
        1. Different equipment = DIFFERENT exercises (e.g., Barbell ≠ Dumbbell ≠ EZ Bar).
        2. Different grips or hand positions (e.g., Wide, Narrow, Neutral, Close, Reverse) = DIFFERENT.
        3. Different body positions or angles (e.g., Incline, Decline, Overhead, Seated, Standing, Lying) = DIFFERENT.
        4. Different movement patterns = DIFFERENT (e.g., “Clean” ≠ “Clean and Jerk”).
        5. Differences in spelling, punctuation, plurality, or casing = SAME.
        6. Left/right or unilateral/bilateral indicators = SAME.
        7. DO NOT use hyphens in standardized names (e.g., use "Pull Up", not "Pull-Up").
        8. NEVER use acronyms or abbreviations (e.g., DB, BB, OHP) in standardized names.
        9. NEVER use commas or parentheses in standardized names.
        10. Expand abbreviations in variations, but use full terms in standardized names.
        11. Do NOT generalize specific names into broader categories (e.g., “Bird Dog” ≠ “Arabesque”).
        12. Do NOT drop disambiguating words (e.g., “Close-Grip Lat Pull Down” ≠ “Lat Pull Down”; “Reverse Preacher Curl” ≠ “Preacher Curl”).
        13. IF equipment name is included, ALWAYS included it at the beginning of the name (e.g., use "Barbell Curl", NOT "Curl - Barbell")
        14. NEVER add an equipment name to an exercise without one. For example, NEVER standardize "Pullover" to "Dumbbell Pullover".

        ADDITIONAL RULES:
        - When unsure, **keep exercises separate**. It is better to under-group than over-group.
        - Variations with minor typos should be grouped (e.g., “Low Pulley Cable Ffly” = “Low Pulley Cable Fly”).

        OUTPUT RULES:
        - Output must be valid **JSON only**. No extra text.
        - Each standardized name must be unique, readable, and not overly generic.
        - Do not merge distinct exercises into one name.
        - Ensure casing and spacing are consistent (Title Case).
        - Avoid merged tokens (e.g., no “BENTOVERROWTOEXTERNALROTATION”).
        - ALWAYS adhere to the disambiguation rules.

        INPUT:
        Exercise names: {exercise_names}

        OUTPUT FORMAT:
        {{
        "groups": [
            {{
            "standardized_name": "STANDARDIZED EXERCISE NAME",
            "variations": ["Original Name 1", "Original Name 2", ...]
            }},
            ...
        ]
        }}

        EXAMPLE:
        {{
        "groups": [
            {{
            "standardized_name": "BARBELL BICEP CURL",
            "variations": ["Barbell Biceps Curl", "Bar Bell Bicep Curls"]
            }},
            {{
            "standardized_name": "EZ BAR BICEP CURL",
            "variations": ["EZ Bar Biceps Curl", "Ez-Bar Bicep Curl"]
            }}
        ]
        }}
        """

        return prompt

    def enforce_disambiguation_rules(self, parsed_json) -> Union[bool, str]:
        """Validate that the parsed JSON adheres to disambiguation rules.
        For example, check that different equipment or movements are not referenced in the matched group.
        Specifically, Dumbbell Bench Press and Barbell Bench Press should not be grouped.
        The result of a rule violation is early termination of the LLM cluster validation and
            default to the individual exercise name."""
        for group in parsed_json["groups"]:
            variations = group["variations"]
            standard_name = group.get("standardized_name") or group.get("standard_name")

            # Collect feature sets
            equipment_signatures = []
            grip_signatures = []
            position_signatures = []
            movement_signatures = []

            for v in variations:
                features = self.extract_exercise_features(v)
                equipment_signatures.extend(features.get("equipment", []))
                grip_signatures.extend(features.get("grip", []))
                position_signatures.extend(features.get("position", []))
                movement_signatures.extend(features.get("movement", []))

            # Equipment rule
            if len(set(equipment_signatures)) > 1:
                return f"Equipment mismatch in group '{standard_name}': {set(equipment_signatures)}"

            # Grip rule
            if len(set(grip_signatures)) > 1:
                return f"Grip mismatch in group '{standard_name}': {set(grip_signatures)}"

            # Position rule
            if len(set(position_signatures)) > 1:
                return f"Body position mismatch in group '{standard_name}': {set(position_signatures)}"

            # Movement pattern rule
            if len(set(movement_signatures)) > 1:
                return f"Movement pattern mismatch in group '{standard_name}': {set(movement_signatures)}"

        return True

    def llm_validate_cluster(self, exercise_names: List[str]) -> Dict[str, str]:
        """
        Use the LLM to validate and standardize a cluster of exercise names.

        The prompt is sent up to five times. Each retry adds stricter wording
        (see ``escalate_prompt``) and, once a disambiguation check has
        failed, the most recent failure message as feedback. An answer is
        accepted when it can be repaired and parsed as JSON, passes
        ``validate_json_structure`` and passes
        ``enforce_disambiguation_rules``.

        Args:
            exercise_names: Names belonging to one cluster.

        Returns:
            Dict mapping every variation listed in the LLM answer to the
            standardized name of its group.

        Raises:
            Exception: when no attempt produced an acceptable answer; the
                last underlying error, if any, is attached as the cause.
        """
        base_prompt = self.get_base_prompt(exercise_names)
        max_attempts = 5
        feedback = None
        last_error = None

        for attempt in range(max_attempts):
            response = None
            try:
                escalated_prompt = escalate_prompt(base_prompt, attempt, feedback=feedback)

                # leveraging Databricks foundation model (free tier...)
                # One prompt in, one generated text out.
                response = gen_text_databricks(
                    [escalated_prompt], temperature=0.1, max_new_tokens=5000, use_template=True)[0]
                response = response.replace("```json", "").replace("```", "").strip()

                # repair_json returns JSON text, which json.loads turns into objects.
                parsed = json.loads(repair_json(response))
                validate_json_structure(parsed)

                validation_result = self.enforce_disambiguation_rules(parsed)
                if validation_result is not True:
                    feedback = validation_result
                    print("Invalid group detected:", feedback)
                    print('invalid PARSED:', parsed)
                    continue

                print("Success:")

                # Create mapping from variations to standardized names
                mapping = {}
                for group in parsed["groups"]:
                    # The LLM sometimes uses "standard_name" as the key instead.
                    if "standardized_name" in group:
                        standardized = group["standardized_name"]
                    else:
                        standardized = group["standard_name"]
                    for variation in group["variations"]:
                        mapping[variation] = standardized

                return mapping

            except Exception as e:
                last_error = e
                print(f"\nAttempt {attempt + 1} failed.")
                print("Error:", str(e))
                print("Response:", response[:500] if isinstance(response, str) else "No response")

        # All attempts failed; keep the last underlying error as the cause
        raise Exception(f"LLM validation failed after {max_attempts} attempts") from last_error

    def smart_clustering(self, exercise_names: List[str], min_cluster_size: int = 2) -> Dict[str, str]:
        """
        Cluster exercise names in stages and map each name to a standardized name.

        Names are first split into equipment/movement groups, each group is
        clustered with HDBSCAN, and every multi-name cluster is sent to the LLM
        for validation. Names that end up alone (or whose cluster fails LLM
        validation) map to themselves.

        Args:
            exercise_names: names to cluster.
            min_cluster_size: forwarded to ``hdbscan_clustering_within_group``.

        Returns:
            Mapping of exercise name to standardized name.
        """
        print(f"Starting smart clustering for {len(exercise_names)} exercises...")

        # Stage 1: coarse split by equipment/movement
        pre_clusters = self.feature_based_pre_clustering(exercise_names)
        print(f"Pre-clustering created {len(pre_clusters)} equipment/movement groups")

        result = {}

        # Stage 2: handle every coarse group on its own
        for label, members in pre_clusters.items():
            print(f"Processing {label} group with {len(members)} exercises")

            # A lone member has nothing to be clustered with
            if len(members) == 1:
                lone = members[0]
                result[lone] = lone
                continue

            # Stage 3: density clustering inside the group
            # Stage 4: LLM validation of each resulting cluster
            for candidate in self.hdbscan_clustering_within_group(members, min_cluster_size):
                print(f"  Processing cluster with {len(candidate)} exercises")

                if len(candidate) == 1:
                    result[candidate[0]] = candidate[0]
                    continue

                try:
                    result.update(self.llm_validate_cluster(candidate))
                except Exception:
                    print("  Warning: LLM validation failed for cluster, treating as individual exercises")
                    # Fallback: every name stands for itself
                    result.update({name: name for name in candidate})

        print(f"Smart clustering complete. Processed {len(result)} exercises")
        return result

    def standardize_exercise_dataset(self, exercise_names: List[str]) -> Tuple[Dict[str, str], pd.DataFrame]:
        """
        Complete standardization pipeline: filter, cluster, lemmatize, summarize.

        Steps:
          1. Drop names containing REST, JOGGING, WALKING or CYCLING
             (case-insensitive substring match).
          2. Strip the literal substring 'NP' from the remaining names.
          3. Cluster the names with ``smart_clustering``.
          4. Lemmatize every standardized name, drop stopwords, upper-case it
             and replace hyphens with spaces.
          5. Build a review table with one row per standardized name.

        Args:
            exercise_names: raw exercise names.

        Returns:
            ``(mapping, review_df)`` where ``mapping`` maps each (cleaned) name
            to its normalized standardized name and ``review_df`` has the
            columns ``standardized_name``, ``original_count`` and
            ``original_names``, sorted by ``original_count`` descending. When
            no valid names remain, an empty mapping and an empty review table
            with those columns are returned.
        """
        print("Starting exercise standardization pipeline...")

        review_columns = ["standardized_name", "original_count", "original_names"]
        invalid_markers = ("REST", "JOGGING", "WALKING", "CYCLING")

        # Remove obvious invalid exercises
        valid_exercises = [
            name for name in exercise_names
            if not any(marker in name.upper() for marker in invalid_markers)]

        # cleaning known undesired acronyms
        # NOTE(review): the mapping returned below is keyed by these cleaned
        # strings, so a name that contained 'NP' no longer equals the raw input
        # name - confirm this is intended for callers that join on raw names.
        valid_exercises = [name.replace('NP', '') for name in valid_exercises]

        print(f"Filtered out {len(exercise_names) - len(valid_exercises)} invalid exercises")

        # Nothing left to cluster: return an empty result instead of failing
        # later on the empty review table and the percentage calculation.
        if not valid_exercises:
            print("No valid exercises left to standardize")
            return {}, pd.DataFrame(columns=review_columns)

        # Perform smart clustering
        raw_mapping = self.smart_clustering(valid_exercises)

        # Normalize each distinct standardized name once and reuse the result
        normalized_by_name = {}
        mapping = {}
        for original, standardized in raw_mapping.items():
            if standardized not in normalized_by_name:
                tagged = pos_tag(word_tokenize(standardized.lower()))
                lemmas = [
                    lemmatizer.lemmatize(word, get_wordnet_pos(pos))
                    for word, pos in tagged
                    if word.upper() not in STOPWORDS]
                normalized_by_name[standardized] = ' '.join(t.upper() for t in lemmas).replace('-', ' ')
            mapping[original] = normalized_by_name[standardized]

        # Group by standardized name for review
        groups = {}
        for original, standardized in mapping.items():
            groups.setdefault(standardized, []).append(original)

        # Create DataFrame for review
        review_data = [
            {
                "standardized_name": standardized,
                "original_count": len(originals),
                "original_names": " | ".join(originals),
            }
            for standardized, originals in groups.items()]
        review_df = pd.DataFrame(review_data, columns=review_columns).sort_values("original_count", ascending=False)

        # Count names after normalization so the summary matches the review table
        standardized_count = len(groups)

        print(f"Standardization complete:")
        print(f"  Original exercises: {len(exercise_names)}")
        print(f"  Valid exercises: {len(valid_exercises)}")
        print(f"  Standardized exercises: {standardized_count}")
        print(f"  Reduction: {((len(valid_exercises) - standardized_count) / len(valid_exercises) * 100):.1f}%")

        return mapping, review_df

# Implementation: entry points that apply the standardizer to a Spark DataFrame
def apply_standardization(df, exercise_column="name"):
    """
    Apply standardization to an exercise DataFrame using a broadcast join.

    Args:
        df: Spark DataFrame holding the raw exercise names.
        exercise_column: name of the column with the raw exercise names
            (defaults to "name"). It is used both to collect the distinct
            names and as the join key.

    Returns:
        ``(df, mapping, review_df)``: the input DataFrame left-joined with a
        ``standardized_name`` column (rows without a standardized name are
        dropped), the mapping produced by the standardizer, and its review
        table.
    """
    standardizer = AdvancedExerciseStandardizer()
    unique_exercises = [
        row[exercise_column]
        for row in df.select(exercise_column).distinct().collect()
        if row[exercise_column] is not None]
    mapping, review_df = standardizer.standardize_exercise_dataset(unique_exercises)

    print("Type of mapping:", type(mapping))
    print("First few items:", list(mapping.items())[:5] if hasattr(mapping, "items") else mapping.head())

    # mapping df containing original name and standardized name
    if isinstance(mapping, dict):
        spark = SparkSession.getActiveSession()
        mapping_df = spark.createDataFrame(list(mapping.items()), [exercise_column, "standardized_name"])
    else:
        # mapping is already a DataFrame; join it as-is
        mapping_df = mapping

    # can safely use a broadcast join here, mapping_df is quite small (< 500 records)
    df = df.join(broadcast(mapping_df), on=exercise_column, how="left")

    # drop records where a standardized name was not successfully generated
    df = df.dropna(subset=["standardized_name"])

    return df, mapping, review_df

def apply_standardization_with_category_validation(df, exercise_column="name", category_column="category"):
    """
    Apply standardization using category information to prevent erroneous cross-category matches
    while still allowing legitimate matches across inconsistent category assignments.

    Args:
        df: Spark DataFrame holding exercise names and their categories.
        exercise_column: column with the raw exercise names.
        category_column: column with the category of each exercise.

    Returns:
        ``(df_standardized, validated_mappings, review_df)``: the input
        DataFrame left-joined with ``standardized_name`` (rows without one are
        dropped), the category-validated name mapping, and the review table
        produced by the standardizer (built before category validation).
    """
    standardizer = AdvancedExerciseStandardizer()

    # Get unique exercises with their category assignments
    unique_exercises_with_categories = (
        df.select(exercise_column, category_column)
        .distinct()
        .filter(col(exercise_column).isNotNull() & col(category_column).isNotNull())
        .collect())

    # Create exercise-to-categories mapping (some exercises may appear in multiple categories)
    exercise_to_categories = {}
    for row in unique_exercises_with_categories:
        exercise_to_categories.setdefault(row[exercise_column], set()).add(row[category_column])

    # Identify exercises with inconsistent categorization
    inconsistent_exercises = {name: cats for name, cats in exercise_to_categories.items() if len(cats) > 1}

    print(f"Found {len(inconsistent_exercises)} exercises with inconsistent categorization:")
    for exercise, categories in list(inconsistent_exercises.items())[:10]:  # Show first 10
        print(f"  '{exercise}': {categories}")
    if len(inconsistent_exercises) > 10:
        print(f"  ... and {len(inconsistent_exercises) - 10} more")

    # Get all unique exercise names for clustering
    all_exercises = list(exercise_to_categories)

    # Apply standardization to all exercises
    all_mappings, review_df = standardizer.standardize_exercise_dataset(all_exercises)

    # Post-process: validate that grouped exercises don't span incompatible categories
    validated_mappings = validate_category_consistency(all_mappings, exercise_to_categories)

    # Create mapping DataFrame for join. An explicit schema is supplied because
    # Spark cannot infer column types from an empty list of rows.
    spark = SparkSession.getActiveSession()
    mapping_schema = StructType([
        StructField(exercise_column, StringType(), True),
        StructField("standardized_name", StringType(), True)])
    mapping_df = spark.createDataFrame(list(validated_mappings.items()), schema=mapping_schema)

    # Join with original DataFrame using broadcast join
    df_standardized = df.join(broadcast(mapping_df), on=exercise_column, how="left")

    # Drop records where standardized name was not generated
    df_standardized = df_standardized.dropna(subset=["standardized_name"])

    # Print summary statistics (guard the percentage against an empty mapping)
    total_original = len(validated_mappings)
    total_standardized = len(set(validated_mappings.values()))
    reduction = (total_original - total_standardized) / total_original * 100 if total_original else 0.0

    print(f"\nOverall Summary:")
    print(f"  Total original exercises: {total_original}")
    print(f"  Total standardized exercises: {total_standardized}")
    print(f"  Overall reduction: {reduction:.1f}%")

    return df_standardized, validated_mappings, review_df


def validate_category_consistency(mappings, exercise_to_categories):
    """
    Validate that exercises grouped together don't span incompatible categories.
    Split groups if they contain exercises from incompatible categories.

    Args:
        mappings: dict of exercise name -> standardized name.
        exercise_to_categories: dict of exercise name -> set of categories.

    Returns:
        A new dict of exercise name -> standardized name. Within a group that
        had to be split, a subgroup of one exercise maps to the exercise's own
        name, the first multi-exercise subgroup keeps the standardized name and
        later ones get a " V<n>" suffix. Exercises that have no entry in
        ``exercise_to_categories`` keep their standardized name and take no
        part in the compatibility check, since there is no category to compare.
    """

    # Define compatible category groups (exercises can match across these)
    compatible_groups = [
        {"ARMS", "CHEST"},
        {"LEGS", "CALVES", "CARDIO"},
        {"CARDIO", "ABS"},
        {"ARMS", "BACK"},
        {"ARMS", "SHOULDERS"},
        {"LEGS", "BACK"},
        {"BACK", "ABS"}]

    def categories_are_compatible(cat_set1, cat_set2):
        """Check if two sets of categories are compatible for grouping"""
        # Identical sets or any shared category: compatible
        if cat_set1 == cat_set2 or cat_set1 & cat_set2:
            return True
        # Otherwise both must touch the same compatible group
        return any((cat_set1 & group) and (cat_set2 & group) for group in compatible_groups)

    # Group exercises by their standardized names
    standardized_groups = {}
    for original, standardized in mappings.items():
        standardized_groups.setdefault(standardized, []).append(original)

    validated_mappings = {}
    split_count = 0

    for standardized_name, original_exercises in standardized_groups.items():
        # Names without category information cannot be validated; keep them
        # under the group's standardized name instead of raising KeyError.
        categorized = []
        for exercise in original_exercises:
            if exercise in exercise_to_categories:
                categorized.append(exercise)
            else:
                validated_mappings[exercise] = standardized_name

        if len(categorized) <= 1:
            # Nothing to compare against, no validation needed
            for exercise in categorized:
                validated_mappings[exercise] = standardized_name
            continue

        # Greedily place each exercise in the first subgroup whose members are
        # all compatible with it; open a new subgroup when none fits.
        compatible_subgroups = []
        for exercise in categorized:
            categories = exercise_to_categories[exercise]
            for subgroup in compatible_subgroups:
                if all(categories_are_compatible(categories, exercise_to_categories[member])
                       for member in subgroup):
                    subgroup.append(exercise)
                    break
            else:
                compatible_subgroups.append([exercise])

        if len(compatible_subgroups) == 1:
            # All exercises are compatible
            for exercise in categorized:
                validated_mappings[exercise] = standardized_name
            continue

        # Split into multiple groups
        split_count += 1
        print(f"Splitting group '{standardized_name}' into {len(compatible_subgroups)} subgroups due to category incompatibility")

        for i, subgroup in enumerate(compatible_subgroups):
            if len(subgroup) == 1:
                # Single exercise gets its own name
                validated_mappings[subgroup[0]] = subgroup[0]
                continue

            # Multiple exercises share the standardized name; subgroups after
            # the first get a suffix for disambiguation
            suffix = f" V{i+1}" if i > 0 else ""
            subgroup_name = f"{standardized_name}{suffix}"
            for exercise in subgroup:
                validated_mappings[exercise] = subgroup_name

    if split_count > 0:
        print(f"Split {split_count} groups due to category incompatibility")

    return validated_mappings

# Apply gen-AI standardization of exercise names to exercise_df.
# apply_standardization returns the joined DataFrame (rows without a
# standardized name are dropped), the name -> standardized-name mapping,
# and a review report of the resulting groups.
standardized_df, name_mapping, review_report = apply_standardization(exercise_df)
# Show each original name next to its standardized name, ordered by standardized name.
display(standardized_df.select("name", "standardized_name").sort("standardized_name"))