# Interview a Model
By William Caban

Interview a model for analogous harmony palette.

---

## Interview model

In [1]:
from mlx_lm import generate, load
from tqdm.notebook import tqdm
import sys

In [2]:
## GLOBAL CONFIG PARAMETERS

# Specify the maximum number of tokens
max_tokens = 1_000

# Specify if tokens and timing information will be printed
verbose = False

# Some optional arguments for causal language model generation
generation_args = {
    "temp": 0.1,
    "repetition_penalty": 1.2,
    "repetition_context_size": 20,
    "top_p": 0.95,
}

# System template
_system = "You are a cautious assistant. You are an expert in color palette. " + \
          "Given a Color generate a JSON array of 10 colors for an analogous harmony palette."

In [3]:
# Specify the model checkpoint
checkpoint = "instructlab/granite-7b-lab"

# Load the corresponding model and tokenizer
model0, tokenizer0 = load(path_or_hf_repo=checkpoint, tokenizer_config={'legacy': 'false'})

Fetching 11 files:   0%|          | 0/11 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:
def extract_hex_list(text):
    """
    return a list of up to 10 hex color numbers in the original text
    """
    hexapattern = r'(#[0-9a-fA-F]+)'
    hex_list = re.findall(hexapattern, text)
    # silently drop any invalid hex entry in the form '#123456'
    condition = lambda x: len(x) != 7
    hex_list_clean = [x for x in hex_list if not condition(x)]
    # only return the first 10 entries (workaround response including more colors than expected)
    return hex_list_clean[0:10]

In [5]:
def to_prompt(color, tokenizer):
    global _system
    # Specify the prompt and conversation history
    conversation = [
        {"role": "system", "content" : _system},
        {"role": "user", "content": f"Color: {color}" },
        {"role": "assistant", "content": "" }
    ]

    # Transform the prompt into the chat template
    return tokenizer.apply_chat_template(
        conversation=conversation, 
        tokenize=False, 
        add_generation_prompt=True
    )

In [6]:
def query(prompt, model, tokenizer, DEBUG=False):
    global max_tokens
    global verbose
    global generation_args
    # Generate a response with the specified settings
    response = generate(
        model=model,
        tokenizer=tokenizer,
        prompt=prompt,
        max_tokens=max_tokens,
        verbose=verbose,
        **generation_args,
    )
    
    if DEBUG is True:
        print(f"Answer: {response}")
        
    return extract_hex_list(response)

In [7]:
def interview(df_set, model_cname, model, tokenizer):
    df = df_set.copy()
    
    if model_cname not in df.columns:
        # make sure the column exist for continuation logic
        df[model_cname]=pd.NA
    
    df[model_cname]=pd.NA
    for indx, row in tqdm(df.iterrows(), desc=f"Interviewing {model_cname}"):
        prompt = to_prompt(row['input'], tokenizer)
    
        # only invoke the llm if there is no answer with this model
        try:
            if df.isnull().loc[indx, model_cname].sum() > 0:
                df.loc[indx, model_cname] = str(query(prompt, model, tokenizer))
            else:
                continue
        except Exception as e:
            print(
                f"ERROR: {e}\n df.loc results: {df.isnull().loc[indx, model_cname]} with count={df.isnull().loc[indx, model_cname].sum()}")
            sys.exit()
    return df

In [8]:
print(f"{help(extract_hex_list)} {help(to_prompt)} {help(query)} {help(interview)}")

Help on function extract_hex_list in module __main__:

extract_hex_list(text)
    return a list of up to 10 hex color numbers in the original text

Help on function to_prompt in module __main__:

to_prompt(color, tokenizer)

Help on function query in module __main__:

query(prompt, model, tokenizer, DEBUG=False)

Help on function interview in module __main__:

interview(df_set, model_cname, model, tokenizer)

None None None None
