In [11]:
import pickle
import chromadb
from chromadb.utils import embedding_functions
import json
import numpy as np
from tqdm import tqdm
from colormath.color_conversions import convert_color
from colormath.color_objects import sRGBColor, LabColor
from colormath.color_diff import delta_e_cie2000

# Patch for numpy.asscalar removed in newer versions
def patch_asscalar(a):
    return a.item()

setattr(np, "asscalar", patch_asscalar)

# Load training data
with open("hexcolor_vf/train_names.pkl", "rb") as f:
    train_names = pickle.load(f)
    
with open("hexcolor_vf/train_palettes_rgb.pkl", "rb") as f:
    train_palettes_rgb = pickle.load(f)

print(f"Loaded {len(train_names)} training examples")
print(f"Sample name: {train_names[0]}")
print(f"Sample palette: {train_palettes_rgb[0]}")

Loaded 9165 training examples
Sample name: ['it', 'is', 'cold']
Sample palette: [(86, 131, 192), (151, 221, 209), (128, 109, 212), (160, 208, 218), (191, 167, 229)]


In [12]:
# Deduplicate training data - keep only first occurrence of duplicate strings
seen_names = {}
unique_names = []
unique_palettes = []

for name_tokens, palette in zip(train_names, train_palettes_rgb):
    # Convert name tokens to string
    name_str = " ".join(name_tokens) if isinstance(name_tokens, list) else name_tokens
    
    # Only add if we haven't seen this name before
    if name_str not in seen_names:
        seen_names[name_str] = True
        unique_names.append(name_tokens)
        unique_palettes.append(palette)

print(f"Original count: {len(train_names)}")
print(f"After deduplication: {len(unique_names)}")
print(f"Duplicates removed: {len(train_names) - len(unique_names)}")

Original count: 9165
After deduplication: 7752
Duplicates removed: 1413


In [13]:
# Split into DB examples (for RAG) and query examples (1500 for JSONL)
# We'll use the last 1500 unique examples as our query examples
# The rest will go into the DB for RAG retrieval

QUERY_SIZE = 1500

if len(unique_names) < QUERY_SIZE:
    raise ValueError(f"Not enough unique examples! Have {len(unique_names)}, need at least {QUERY_SIZE}")

# Split: everything except last 1500 goes to DB, last 1500 are query examples
db_names = unique_names[:-QUERY_SIZE]
db_palettes = unique_palettes[:-QUERY_SIZE]

query_names = unique_names[-QUERY_SIZE:]
query_palettes = unique_palettes[-QUERY_SIZE:]

print(f"DB examples (for RAG): {len(db_names)}")
print(f"Query examples (for JSONL): {len(query_names)}")
assert len(query_names) == QUERY_SIZE

DB examples (for RAG): 6252
Query examples (for JSONL): 1500


In [14]:
# Helper function to convert RGB to hex
def rgb_to_hex(rgb):
    return '#{:02x}{:02x}{:02x}'.format(int(rgb[0]), int(rgb[1]), int(rgb[2]))

# Create ChromaDB collection with DB examples (not the query examples)
CHROMA_PATH = "chroma_db_train"
COLLECTION_NAME = "train_rag"
MODEL_NAME = "all-mpnet-base-v2"

print("Initializing ChromaDB for training RAG...")
client = chromadb.PersistentClient(path=CHROMA_PATH)

# Create embedding function
ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=MODEL_NAME)

# Delete collection if it exists, then create fresh
try:
    client.delete_collection(name=COLLECTION_NAME)
    print("Deleted existing collection")
except:
    pass

collection = client.create_collection(name=COLLECTION_NAME, embedding_function=ef)
print("Created new collection")

Initializing ChromaDB for training RAG...
Deleted existing collection
Created new collection


In [15]:
# Populate ChromaDB with DB examples
print("Populating ChromaDB with training examples...")
batch_size = 128

documents = []
ids = []

for idx, (name_tokens, palette_rgb) in enumerate(tqdm(zip(db_names, db_palettes), total=len(db_names))):
    try:
        desc = " ".join(name_tokens) if isinstance(name_tokens, list) else name_tokens
        if not desc or not palette_rgb or len(palette_rgb) != 5:
            continue
        
        # Convert palette to hex
        hex_palette = [rgb_to_hex(rgb) for rgb in palette_rgb]
        
        # Create document structure similar to query_db.py format
        doc_structure = {
            "description": desc,
            "palette": hex_palette
        }
        
        documents.append(json.dumps(doc_structure))
        ids.append(f"db_{idx}")
        
        # Batch insert
        if len(documents) >= batch_size:
            descriptions = [json.loads(d)["description"] for d in documents]
            batch_embeddings = ef(descriptions)
            collection.add(ids=ids, embeddings=batch_embeddings, documents=documents)
            documents.clear()
            ids.clear()
    
    except Exception as e:
        print(f"Error processing DB entry {idx}: {e}")
        continue

# Add remaining documents
if documents:
    descriptions = [json.loads(d)["description"] for d in documents]
    batch_embeddings = ef(descriptions)
    collection.add(ids=ids, embeddings=batch_embeddings, documents=documents)

print(f"Database creation complete. Total items in DB: {collection.count()}")

Populating ChromaDB with training examples...


100%|██████████| 6252/6252 [00:21<00:00, 295.70it/s]


Database creation complete. Total items in DB: 6252


In [16]:
# Function to calculate palette diversity (from ml/grader/grader.py)
def calculate_palette_diversity(palette_rgb):
    """
    Calculate diversity of a color palette using pairwise CIELAB color distances.
    
    Args:
        palette_rgb: List of RGB tuples [(r1,g1,b1), (r2,g2,b2), ...]
    
    Returns:
        float: Average pairwise distance between colors
    """
    # Convert RGB to LAB color objects
    lab_colors = []
    for rgb in palette_rgb:
        # Create sRGB color object (values should be 0-255)
        rgb_color = sRGBColor(rgb[0], rgb[1], rgb[2], is_upscaled=True)
        # Convert to LAB
        lab_color = convert_color(rgb_color, LabColor)
        lab_colors.append(lab_color)
    
    # Calculate pairwise distances
    pairwise_distance = 0.0
    num_pairs = 0
    N = len(lab_colors)
    
    for i in range(N - 1):
        cur_color = lab_colors[i]
        for j in range(i + 1, N):
            distance = delta_e_cie2000(cur_color, lab_colors[j])
            pairwise_distance += distance
            num_pairs += 1
    
    return pairwise_distance / num_pairs if num_pairs > 0 else 0.0

# Test the function
test_palette = query_palettes[0]
test_diversity = calculate_palette_diversity(test_palette)
print(f"Test palette: {[rgb_to_hex(rgb) for rgb in test_palette]}")
print(f"Test diversity: {test_diversity:.2f}")

Test palette: ['#efd1b3', '#ceaa96', '#9b6162', '#99705b', '#a87c6f']
Test diversity: 18.63


In [17]:
# Define prompts (from ml/evaluator/config.py)
SYSTEM_PROMPT = """
You are an expert Color Theorist and UI Designer. Your task is to generate a cohesive, aesthetically pleasing color palette based on a user's text query.

Generate a JSON response with the following format:
{
    "palette_text": ["color_name1", "color_name2", "color_name3", "color_name4", "color_name5"],
    "palette_hex": ["#XXXXXX", "#XXXXXX", "#XXXXXX", "#XXXXXX", "#XXXXXX"]
}

\\no_think
"""

USER_PROMPT_TEMPLATE = """
What's the best color palette consisting of five colors to describe the text {query}?
Provide the color values using text (hex) format in ascending order.

Here are some associate text-palette pairs for reference:
### REFERENCE PALETTES
{examples}
"""

print("Prompts defined")

Prompts defined


In [18]:
# Generate JSONL dataset
output_file = "train_dataset.jsonl"

print(f"Generating JSONL dataset with {len(query_names)} examples...")

with open(output_file, "w") as f:
    for idx, (name_tokens, palette_rgb) in enumerate(tqdm(zip(query_names, query_palettes), total=len(query_names))):
        try:
            # Get query string
            query_str = " ".join(name_tokens) if isinstance(name_tokens, list) else name_tokens
            
            if not query_str or not palette_rgb or len(palette_rgb) != 5:
                print(f"Skipping invalid entry at index {idx}")
                continue
            
            # Retrieve RAG examples from DB (similar to query_db.py)
            results = collection.query(
                query_texts=[query_str],
                n_results=3
            )
            
            # Format examples
            examples_str = ""
            if results['documents'] and results['documents'][0]:
                for i, result_doc in enumerate(results['documents'][0], start=1):
                    data = json.loads(result_doc)
                    examples_str += f"Palette {i}:\n"
                    examples_str += f"Description: {data['description']}\n"
                    for color in data['palette']:
                        examples_str += f"  - {color}\n"
                    examples_str += "\n"
            
            # Create full query with system prompt + user prompt
            user_prompt = USER_PROMPT_TEMPLATE.format(query=query_str, examples=examples_str)
            full_query = SYSTEM_PROMPT + "\n" + user_prompt
            
            # Convert palette to hex
            gt_palette_hex = [rgb_to_hex(rgb) for rgb in palette_rgb]
            
            # Calculate diversity
            gt_diversity = calculate_palette_diversity(palette_rgb)
            
            # Create JSONL entry
            entry = {
                "input": {
                    "messages": [
                        {"role": "assistant", "content": full_query}
                    ],
                    "gt_palette": gt_palette_hex,
                    "gt_diversity": gt_diversity
                }
            }
            
            # Write to file
            f.write(json.dumps(entry) + "\n")
        
        except Exception as e:
            print(f"Error processing query entry {idx}: {e}")
            continue

print(f"JSONL dataset saved to {output_file}")

Generating JSONL dataset with 1500 examples...


100%|██████████| 1500/1500 [01:28<00:00, 16.89it/s]

JSONL dataset saved to train_dataset.jsonl





In [19]:
# Verify the output by reading a few lines
print("Verifying JSONL output format:")
print("=" * 80)

with open(output_file, "r") as f:
    # Count lines
    lines = f.readlines()
    print(f"\nTotal lines in JSONL: {len(lines)}")
    
    # Show first example
    print("\n" + "=" * 80)
    print("FIRST EXAMPLE:")
    print("=" * 80)
    first_entry = json.loads(lines[0])
    print(f"Query preview (first 200 chars): {first_entry['input']['messages'][0]['content'][:200]}...")
    print(f"\nGround truth palette: {first_entry['input']['gt_palette']}")
    print(f"Ground truth diversity: {first_entry['input']['gt_diversity']:.2f}")
    
    # Show structure
    print("\n" + "=" * 80)
    print("FULL STRUCTURE OF FIRST ENTRY:")
    print("=" * 80)
    print(json.dumps(first_entry, indent=2)[:1000] + "...")

Verifying JSONL output format:

Total lines in JSONL: 1500

FIRST EXAMPLE:
Query preview (first 200 chars): 
You are an expert Color Theorist and UI Designer. Your task is to generate a cohesive, aesthetically pleasing color palette based on a user's text query.

Generate a JSON response with the following ...

Ground truth palette: ['#efd1b3', '#ceaa96', '#9b6162', '#99705b', '#a87c6f']
Ground truth diversity: 18.63

FULL STRUCTURE OF FIRST ENTRY:
{
  "input": {
    "messages": [
      {
        "role": "assistant",
        "content": "\nYou are an expert Color Theorist and UI Designer. Your task is to generate a cohesive, aesthetically pleasing color palette based on a user's text query.\n\nGenerate a JSON response with the following format:\n{\n    \"palette_text\": [\"color_name1\", \"color_name2\", \"color_name3\", \"color_name4\", \"color_name5\"],\n    \"palette_hex\": [\"#XXXXXX\", \"#XXXXXX\", \"#XXXXXX\", \"#XXXXXX\", \"#XXXXXX\"]\n}\n\n\\no_think\n\n\nWhat's the best color p