In [2]:
import pickle
import chromadb
from chromadb.utils import embedding_functions
import json
import numpy as np
from tqdm import tqdm
from colormath.color_conversions import convert_color
from colormath.color_objects import sRGBColor, LabColor
from colormath.color_diff import delta_e_cie2000
from skimage import color

# numpy.asscalar is gone from newer numpy releases; re-install an equivalent
def patch_asscalar(a):
    """Return the plain scalar stored in ``a`` by delegating to ``a.item()``."""
    scalar = a.item()
    return scalar

# Expose the replacement under the name callers look up on the numpy module
np.asscalar = patch_asscalar

# Load training data
# NOTE(review): pickle.load can execute arbitrary code while deserializing;
# only point these paths at files from a trusted source.
with open("hexcolor_vf/train_names.pkl", "rb") as f:
    train_names = pickle.load(f)

with open("hexcolor_vf/train_palettes_rgb.pkl", "rb") as f:
    train_palettes_rgb = pickle.load(f)

# Names and palettes are paired by position further down via zip(), which
# silently drops the tail of the longer list. Fail loudly here instead.
if len(train_names) != len(train_palettes_rgb):
    raise ValueError(
        f"Mismatched training data: {len(train_names)} names vs "
        f"{len(train_palettes_rgb)} palettes"
    )

print(f"Loaded {len(train_names)} training examples")
print(f"Sample name: {train_names[0]}")
print(f"Sample palette: {train_palettes_rgb[0]}")

# Load KD tree for color name lookup
KD_TREE_PATH = "color_kdtree.pkl"
print(f"\nLoading prebuilt KDTree from {KD_TREE_PATH}...")
with open(KD_TREE_PATH, "rb") as f:
    kdtree_data = pickle.load(f)

# The pickle holds a dict with the tree itself and its per-point metadata
kdtree = kdtree_data["tree"]
color_ref_metadata = kdtree_data["metadata"]
print(f"KDTree loaded successfully with {len(color_ref_metadata)} color references.")

Loaded 9165 training examples
Sample name: ['it', 'is', 'cold']
Sample palette: [(86, 131, 192), (151, 221, 209), (128, 109, 212), (160, 208, 218), (191, 167, 229)]

Loading prebuilt KDTree from color_kdtree.pkl...
KDTree loaded successfully with 100000 color references.


In [3]:
# Drop repeated descriptions: the first palette seen for a given text wins
seen_names = {}
unique_names, unique_palettes = [], []

for tokens, colors in zip(train_names, train_palettes_rgb):
    # Token lists are compared by their space-joined text; other values as-is
    key = tokens if not isinstance(tokens, list) else " ".join(tokens)

    if key in seen_names:
        continue

    seen_names[key] = True
    unique_names.append(tokens)
    unique_palettes.append(colors)

removed_count = len(train_names) - len(unique_names)
print(f"Original count: {len(train_names)}")
print(f"After deduplication: {len(unique_names)}")
print(f"Duplicates removed: {removed_count}")

Original count: 9165
After deduplication: 7752
Duplicates removed: 1413


In [4]:
# Partition the unique examples into two disjoint sets:
#   - the trailing QUERY_SIZE examples become the queries written to the JSONL
#   - everything before them is stored in the DB for RAG retrieval

QUERY_SIZE = 1500

if not len(unique_names) >= QUERY_SIZE:
    raise ValueError(f"Not enough unique examples! Have {len(unique_names)}, need at least {QUERY_SIZE}")

# Index of the first query example; all earlier entries belong to the DB
split_at = len(unique_names) - QUERY_SIZE

db_names, query_names = unique_names[:split_at], unique_names[split_at:]
db_palettes, query_palettes = unique_palettes[:split_at], unique_palettes[split_at:]

print(f"DB examples (for RAG): {len(db_names)}")
print(f"Query examples (for JSONL): {len(query_names)}")
assert len(query_names) == QUERY_SIZE

DB examples (for RAG): 6252
Query examples (for JSONL): 1500


In [5]:
# Helper function to convert RGB to HSL
def rgb_to_hsl(rgb):
    """
    Format an RGB colour as an HSL string.

    Args:
        rgb: Sequence whose first three items are the R, G, B channels (0-255)

    Returns:
        String "(H, S%, L%)" with integer H in [0, 359] and integer
        S and L percentages in [0, 100]
    """
    # Scale each channel into the unit interval
    red, green, blue = rgb[0] / 255.0, rgb[1] / 255.0, rgb[2] / 255.0

    hi = max(red, green, blue)
    lo = min(red, green, blue)
    lightness = (hi + lo) / 2.0

    # Greys (all channels equal) keep hue and saturation at zero
    hue = 0.0
    saturation = 0.0
    if hi != lo:
        spread = hi - lo
        if lightness > 0.5:
            saturation = spread / (2.0 - hi - lo)
        else:
            saturation = spread / (hi + lo)

        # Hue sector depends on which channel is the largest
        if hi == red:
            sector = (green - blue) / spread + (6.0 if green < blue else 0.0)
        elif hi == green:
            sector = (blue - red) / spread + 2.0
        else:
            sector = (red - green) / spread + 4.0
        hue = sector / 6.0

    # Round to whole degrees / percent; a hue that rounds to 360 wraps to 0
    hue_deg = int(round(hue * 360)) % 360
    sat_pct = int(round(saturation * 100))
    light_pct = int(round(lightness * 100))

    return f"({hue_deg}, {sat_pct}%, {light_pct}%)"

# Helper function to get color name from RGB using KD tree
def get_color_name_from_rgb(rgb_val):
    """
    Look up the reference colour name closest to an RGB value.

    The value is converted to LAB and matched against the module-level
    ``kdtree``; the name comes from the matching ``color_ref_metadata`` entry.

    Args:
        rgb_val: RGB tuple (r, g, b) with values 0-255

    Returns:
        str: The nearest color name
    """
    # Wrap the single pixel as a 1x1x3 array of 0-1 floats for rgb2lab,
    # then pull the one LAB triple back out
    pixel = np.array([[[channel / 255.0 for channel in rgb_val]]])
    lab_pixel = color.rgb2lab(pixel)[0][0]

    # Only the index of the closest reference is needed, not its distance
    _, match_idx = kdtree.query(lab_pixel, k=1)

    # Each metadata entry is a pair whose first item is the name
    matched_name, _ = color_ref_metadata[match_idx]
    return matched_name

# Test HSL conversion: pure red must come out as hue 0, 100% saturation, 50% lightness
test_rgb = (255, 0, 0)
test_hsl = rgb_to_hsl(test_rgb)
# Assert the known value so a regression stops the notebook instead of just printing
assert test_hsl == "(0, 100%, 50%)", f"Unexpected HSL for {test_rgb}: {test_hsl}"
print(f"Test RGB {test_rgb} -> HSL {test_hsl}")

Test RGB (255, 0, 0) -> HSL (0, 100%, 50%)


In [6]:
# Create ChromaDB collection with DB examples (not the query examples)
# Store colors in HSL format
CHROMA_PATH = "chroma_db_train_hsl"
COLLECTION_NAME = "train_rag_hsl"
MODEL_NAME = "all-mpnet-base-v2"

print("Initializing ChromaDB for training RAG (HSL format)...")
client = chromadb.PersistentClient(path=CHROMA_PATH)

# Create embedding function
ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=MODEL_NAME)

# Delete collection if it exists, then create fresh so re-running this cell
# never appends to stale data.
try:
    client.delete_collection(name=COLLECTION_NAME)
    print("Deleted existing collection")
except Exception as exc:
    # Catch Exception rather than using a bare except so KeyboardInterrupt and
    # SystemExit still propagate. The failure is reported instead of hidden:
    # on a first run this is expected to be a "collection not found" error
    # (NOTE(review): the exact exception class appears to vary between
    # chromadb releases - confirm before narrowing further).
    print(f"No existing collection deleted ({type(exc).__name__}: {exc})")

collection = client.create_collection(name=COLLECTION_NAME, embedding_function=ef)
print("Created new collection")

Initializing ChromaDB for training RAG (HSL format)...


  from .autonotebook import tqdm as notebook_tqdm


Created new collection


In [7]:
# Populate ChromaDB with DB examples in HSL format
print("Populating ChromaDB with training examples (HSL format)...")
batch_size = 128

# Parallel buffers for the batch currently being assembled
documents = []
descriptions = []
ids = []


def _flush_batch():
    """Embed the buffered descriptions, insert the batch, and empty the buffers.

    The buffers are cleared even when embedding or insertion fails, so one bad
    batch is reported once instead of being retried (and failing again) on
    every later entry.
    """
    if not documents:
        return
    try:
        batch_embeddings = ef(descriptions)
        collection.add(ids=ids, embeddings=batch_embeddings, documents=documents)
    finally:
        documents.clear()
        descriptions.clear()
        ids.clear()


for idx, (name_tokens, palette_rgb) in enumerate(tqdm(zip(db_names, db_palettes), total=len(db_names))):
    try:
        desc = " ".join(name_tokens) if isinstance(name_tokens, list) else name_tokens
        # Skip entries without text or without exactly five colors
        if not desc or not palette_rgb or len(palette_rgb) != 5:
            continue

        # One line per color: nearest reference name plus its HSL string
        palette = [
            f"Color: {get_color_name_from_rgb(rgb)}, HSL: {rgb_to_hsl(rgb)}"
            for rgb in palette_rgb
        ]

        # The stored document is JSON; only the description text is embedded
        documents.append(json.dumps({"description": desc, "palette": palette}))
        descriptions.append(desc)
        ids.append(f"db_{idx}")

        # Batch insert
        if len(documents) >= batch_size:
            _flush_batch()

    except Exception as e:
        # Best effort: report the failing entry (or batch) and keep going
        print(f"Error processing DB entry {idx}: {e}")
        continue

# Add remaining documents
_flush_batch()

print(f"Database creation complete. Total items in DB: {collection.count()}")

Populating ChromaDB with training examples (HSL format)...


100%|██████████| 6252/6252 [00:25<00:00, 242.28it/s]


Database creation complete. Total items in DB: 6252


In [8]:
# Function to calculate palette diversity (from ml/grader/grader.py)
def calculate_palette_diversity(palette_rgb):
    """
    Score how spread out a palette is: the mean CIEDE2000 distance over every
    unordered pair of its colors, measured on their LAB conversions.

    Args:
        palette_rgb: List of RGB tuples [(r1,g1,b1), (r2,g2,b2), ...], values 0-255

    Returns:
        float: Average pairwise distance, or 0.0 when there are fewer than two colors
    """
    # Build the LAB representation of every color (inputs are upscaled 0-255 sRGB)
    lab_colors = [
        convert_color(sRGBColor(rgb[0], rgb[1], rgb[2], is_upscaled=True), LabColor)
        for rgb in palette_rgb
    ]

    # Walk each unordered pair exactly once, accumulating in visiting order
    total_distance = 0.0
    pair_count = 0
    for position, first in enumerate(lab_colors):
        for second in lab_colors[position + 1:]:
            total_distance += delta_e_cie2000(first, second)
            pair_count += 1

    if pair_count > 0:
        return total_distance / pair_count
    return 0.0

# Smoke-test the diversity metric on the first query palette
test_palette = query_palettes[0]
test_palette_hsl = list(map(rgb_to_hsl, test_palette))
test_diversity = calculate_palette_diversity(test_palette)
print(f"Test palette: {test_palette_hsl}")
print(f"Test diversity: {test_diversity:.2f}")

Test palette: ['(30, 65%, 82%)', '(21, 36%, 70%)', '(359, 23%, 49%)', '(20, 25%, 48%)', '(14, 25%, 55%)']
Test diversity: 18.63


In [12]:
# System prompt: fixes the JSON schema the model must answer with.
# The HSL list is keyed "palette_hsl" so it agrees with the assistant_response
# built when the dataset is generated; it was previously labelled "palette_hex"
# even though every value is an "(H, S%, L%)" string.
# NOTE(review): confirm no downstream parser/grader still expects "palette_hex".
SYSTEM_PROMPT_HSL = """
You are an expert Color Theorist and UI Designer. Your task is to generate a cohesive, aesthetically pleasing color palette based on a user's text query.

Generate a JSON response with the following format:
{
    "palette_text": ["color_name1", "color_name2", "color_name3", "color_name4", "color_name5"],
    "palette_hsl": ["(H, S%, L%)", "(H, S%, L%)", "(H, S%, L%)", "(H, S%, L%)", "(H, S%, L%)"]
}
where H is between [0, 360), and S and L are between [0, 100]. No decimals.
As a reminder, H stands for hue, S stands for saturation, and L stands for lightness
"""

# User prompt template: filled via str.format with
#   query    - the text to describe
#   examples - the pre-formatted reference palettes retrieved for that text
USER_PROMPT_TEMPLATE_HSL = """
What's the best color palette consisting of five colors to describe the text "{query}"?
Provide the color values using hsl format.

Here are some associate text-palette pairs for reference:
### REFERENCE PALETTES
{examples}
"""

print("HSL prompts defined")

HSL prompts defined


In [14]:
# Write one JSONL record per query example: system + user messages (with
# retrieved reference palettes) plus the ground-truth palette and its diversity
output_file = "train_dataset_hsl.jsonl"

print(f"Generating JSONL dataset with {len(query_names)} examples in HSL format...")

with open(output_file, "w") as f:
    query_pairs = tqdm(zip(query_names, query_palettes), total=len(query_names))
    for idx, (name_tokens, palette_rgb) in enumerate(query_pairs):
        try:
            # Token lists become a single space-joined query string
            query_str = name_tokens if not isinstance(name_tokens, list) else " ".join(name_tokens)

            # A usable entry has text and exactly five colors
            is_valid = bool(query_str) and bool(palette_rgb) and len(palette_rgb) == 5
            if not is_valid:
                print(f"Skipping invalid entry at index {idx}")
                continue

            # Pull the three closest stored examples for this query
            results = collection.query(query_texts=[query_str], n_results=3)

            # Render each retrieved document as a numbered text block
            retrieved = results['documents']
            example_blocks = []
            if retrieved and retrieved[0]:
                for i, result_doc in enumerate(retrieved[0], start=1):
                    data = json.loads(result_doc)
                    block_lines = [
                        f"Palette {i}:\n",
                        f"Description: {data['description']}\n",
                    ]
                    block_lines.extend(f"  - {palette_color}\n" for palette_color in data['palette'])
                    block_lines.append("\n")
                    example_blocks.append("".join(block_lines))
            examples_str = "".join(example_blocks)

            user_prompt = USER_PROMPT_TEMPLATE_HSL.format(query=query_str, examples=examples_str)

            # Ground truth: HSL strings, nearest color names, and diversity score
            gt_palette_hsl = list(map(rgb_to_hsl, palette_rgb))
            gt_palette_text = list(map(get_color_name_from_rgb, palette_rgb))
            gt_diversity = calculate_palette_diversity(palette_rgb)

            # Built for the assistant turn, which is currently left out of
            # "messages" below
            assistant_response = {
                "palette_text": gt_palette_text,
                "palette_hsl": gt_palette_hsl
            }

            messages = [
                {"role": "system", "content": SYSTEM_PROMPT_HSL},
                {"role": "user", "content": user_prompt},
                # {"role": "assistant", "content": json.dumps(assistant_response)}
            ]
            entry = {
                "messages": messages,
                "gt_palette": gt_palette_hsl,
                "gt_diversity": gt_diversity
            }

            f.write(json.dumps(entry) + "\n")

        except Exception as e:
            # Report the failing entry and move on to the next one
            print(f"Error processing query entry {idx}: {e}")
            continue

print(f"JSONL dataset saved to {output_file}")

Generating JSONL dataset with 1500 examples in HSL format...


100%|██████████| 1500/1500 [01:35<00:00, 15.72it/s]

JSONL dataset saved to train_dataset_hsl.jsonl





In [11]:
# Verify the output by reading the file back and showing the first record
print("Verifying JSONL output format:")
print("=" * 80)

with open(output_file, "r") as f:
    lines = f.readlines()

print(f"\nTotal lines in JSONL: {len(lines)}")

if not lines:
    # Nothing was written (every entry skipped or failed); avoid an IndexError below
    print("JSONL file is empty - nothing to verify")
else:
    # Show first example
    print("\n" + "=" * 80)
    print("FIRST EXAMPLE:")
    print("=" * 80)
    first_entry = json.loads(lines[0])
    messages = first_entry['messages']

    print("System prompt (first 200 chars):")
    print(messages[0]['content'][:200] + "...")

    print("\nUser prompt (first 200 chars):")
    print(messages[1]['content'][:200] + "...")

    # The generator may write only system + user messages, so the assistant
    # turn is optional; indexing messages[2] unconditionally would raise.
    print("\nAssistant response:")
    if len(messages) > 2:
        print(messages[2]['content'])
    else:
        print("(no assistant message in this entry)")

    print(f"\nGround truth palette (HSL): {first_entry['gt_palette']}")
    print(f"Ground truth diversity: {first_entry['gt_diversity']:.2f}")

    # Show full structure (truncated to keep the output readable)
    print("\n" + "=" * 80)
    print("FULL STRUCTURE OF FIRST ENTRY:")
    print("=" * 80)
    print(json.dumps(first_entry, indent=2)[:1500] + "...")

Verifying JSONL output format:

Total lines in JSONL: 1500

FIRST EXAMPLE:
System prompt (first 200 chars):

You are an expert Color Theorist and UI Designer. Your task is to generate a cohesive, aesthetically pleasing color palette based on a user's text query.

Generate a JSON response with the following ...

User prompt (first 200 chars):

What's the best color palette consisting of five colors to describe the text iced coffee?
Provide the color values using HSL format (H, S%, L%) in ascending order.

Here are some associate text-palet...

Assistant response:
{"palette_text": ["F2D4B6 (Faded Peach)", "Dandelion Yellow", "Warm Red-Brown", "Chestnut Brown", "Mahogany Rose"], "palette_hsl": ["(30, 65%, 82%)", "(21, 36%, 70%)", "(359, 23%, 49%)", "(20, 25%, 48%)", "(14, 25%, 55%)"]}

Ground truth palette (HSL): ['(30, 65%, 82%)', '(21, 36%, 70%)', '(359, 23%, 49%)', '(20, 25%, 48%)', '(14, 25%, 55%)']
Ground truth diversity: 18.63

FULL STRUCTURE OF FIRST ENTRY:
{
  "messages": [
    {
