# LLM Humor Detection with a Subspace-Based Metric

In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import pandas as pd
from more_itertools import batched
from tqdm import tqdm
from pathlib import Path

MODEL_ID = "google/gemma-2-2b-it"

# Load the matching tokenizer and the causal LM (bfloat16 weights on the first GPU).
_model_load_kwargs = {"device_map": "cuda:0", "torch_dtype": torch.bfloat16}
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **_model_load_kwargs)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [2]:
# Parent of the notebook's working directory; the data files are searched for beneath it.
cwd = Path.cwd().parents[0]

In [3]:
def _find_file(pattern: str) -> Path:
    """Return the first match that ``cwd.glob(pattern)`` yields.

    Raises:
        FileNotFoundError: if nothing matches (rather than a bare ``StopIteration``).
    """
    match = next(cwd.glob(pattern), None)
    if match is None:
        raise FileNotFoundError(f"No file matching {pattern!r} found under {cwd}")
    return match


# Reference punchlines (one row per sentence) and the full transcripts.
ground_truth = pd.read_csv(_find_file("**/standup_data.csv"))
transcript = pd.read_csv(_find_file("**/standup_transcripts.csv"))

In [4]:
# Base extraction prompt, used on its own and behind each persona prefix.
_EXTRACTION_PROMPT = (
    "Extract the key humorous lines and punchlines for this stand-up comedy transcript. "
    "Focus on the quotes highlighting the main comedic moments. "
    "List of quotes:"
)


def _persona(humor_style: str) -> str:
    """Prefix the base extraction prompt with a humor-preference persona."""
    return f"You are a person who enjoys {humor_style} humor. {_EXTRACTION_PROMPT}"


# User-turn instructions; every transcript is paired with each of these.
INSTRUCTIONS = [
    _EXTRACTION_PROMPT,
    # NOTE: the double space before "List of quotes:" is part of the prompt as it was run.
    "The following is a stand-up comedy transcript. When performed in front of a live audience, "
    "which jokes do you think made the audience laugh?  List of quotes:",
    _persona("aggressive"),
    _persona("self-enhancing"),
    # Disabled variants kept for reference:
    # _persona("self-deprecating"),
    # _persona("dark"),
    # _persona("affiliative"),
    # "The following is a stand-up comedy transcript. What are the funniest punchlines from the transcript. List of quotes:",
    # "Below is a transcript from a stand-up comedy routine. Analyze the transcript and extract the quotes that are most likely to have made the audience laugh. List of quotes:",
    # "The following is a stand-up comedy transcript. When preformed in front of a live audience, which jokes do you think made the audience laugh? List of quotes:",
    # "Pretend that you are a stand-up comedian reading the following stand-up comedy transcript. Focus on the quotes highlighting the main comedic moments. List of quotes:",
    # "Pretend that you are a stand-up comedy fan reading the following stand-up comedy transcript. Focus on the quotes highlighting the main comedic moments. List of quotes:",
    # "Pretend that you are a stand-up comedy critic reading the following stand-up comedy transcript. Focus on the quotes highlighting the main comedic moments. List of quotes:",
    # "Analyze the stand-up comedy transcript below. Which lines and punchlines do you think delivered the biggest laughs to the audience? List of quotes:",
    # "As a person who enjoys witty, intellectual humor, extract the key humorous lines and punchlines from this stand-up comedy transcript. Focus on the quotes that demonstrate clever wordplay or insights. List of quotes:",
    # "This is a transcript from a stand-up routine. Identify the lines and punchlines that likely had the strongest comedic impact during the performance. List of quotes:",
    # "Pretend you're an audience member at this stand-up show. Which lines do you think got the biggest laughs? Focus on key moments of humor. List of quotes:",
    # "This is a transcript of a live stand-up performance. Which quotes do you believe would have resonated the most with the audience? Focus on key punchlines. List of quotes:",
    # "Imagine you are a comedian reviewing this stand-up routine. Identify the funniest moments and lines where the punchlines landed the hardest. List of quotes:",
    # "Read through the stand-up comedy transcript and extract the lines that best capture the humor and timing of the performance. Focus on punchlines that likely had the audience laughing. List of quotes:",
    # "This is a stand-up comedy transcript. Analyze the content and extract the lines that most effectively build up to or deliver punchlines. List of quotes:",
    # "Pretend you're watching this performance live. What do you think were the standout comedic lines and punchlines that elicited the loudest laughs? List of quotes:",
    # "Imagine you are writing a review of this stand-up performance. What lines and punchlines would you highlight as the funniest moments? List of quotes:",
]

# Assistant-turn prefixes placed before the numbered reference quotes.
CONTENTS = [
    "",  # no prefix
    "Sure, here are the key humorous lines:",
    "Here are some lines and punchlines that could be funny:",
    "Got it! Here are the main punchlines and comedic highlights:",
    # Disabled variants kept for reference:
    # "Here's a selection of the funniest quotes from the transcript:",
    # "I've picked out the key humorous moments for you:",
    # "Below are the standout lines and punchlines from the performance:",
    # "Here's a breakdown of the top quotes that likely got the biggest laughs:",
    # "Take a look at these key comedic lines from the routine:",
    # "Here's a list of the most memorable punchlines from the set:",
    # "Check out these quotes—some of the best comedic moments from the transcript:",
    # "Here are the funniest moments and punchlines I found in the transcript:",
    # "Here's what I've identified as the standout lines and punchlines in this comedy routine:",
]

In [5]:
def _number_lines(sentences):
    """Join sentences into a 1-based numbered list, one sentence per line."""
    return "\n".join(f"{idx}. {text}" for idx, text in enumerate(sentences, start=1))


# One numbered reference answer per comedian, joined onto that comedian's transcript.
gt = ground_truth.groupby("comedian")["sentence"].apply(list).map(_number_lines)
df = (
    transcript.set_index("comedian")
    .join(gt)
    .rename(columns={"sentence": "ground_truth"})
)

# Cross every transcript with every instruction, then with every assistant prefix.
df = df.assign(instruction=[INSTRUCTIONS] * len(df)).explode("instruction")
df = df.assign(content=[CONTENTS] * len(df)).explode("content")

def gt_chat_template(row):
    """Render a row as a two-turn chat string: the prompt, then the reference answer.

    The user turn is the instruction followed by the transcript. The assistant turn is
    the content prefix followed by the numbered ground-truth quotes.
    """
    prompt = row["instruction"] + "\n" + row["transcript"]
    reference_answer = row["content"] + "\n" + row["ground_truth"]
    conversation = [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": reference_answer},
    ]
    return tokenizer.apply_chat_template(conversation, tokenize=False)


df["gt_input"] = df.apply(gt_chat_template, axis=1)

def model_chat_template(row):
    """Render a row's prompt (instruction + transcript) as a chat string ready for generation.

    ``add_generation_prompt=True`` appends the start of the assistant turn. Without it,
    the string ends after the user turn, and ``model.generate`` would have to open the
    reply turn itself instead of answering directly.
    """
    return tokenizer.apply_chat_template([
        # {"role": "system", "content": ""},
        {"role": "user", "content": row["instruction"] + "\n" + row["transcript"]},
    ], tokenize=False, add_generation_prompt=True)


df["model_input"] = df.apply(model_chat_template, axis=1)

df

Unnamed: 0_level_0,transcript,ground_truth,instruction,content,gt_input,model_input
comedian,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
Anthony_Jeselnik,"When I was a kid, I used to fantasize about ge...","1. So poor I remember, just so I could go to m...",Extract the key humorous lines and punchlines ...,,<bos><start_of_turn>user\nExtract the key humo...,<bos><start_of_turn>user\nExtract the key humo...
Anthony_Jeselnik,"When I was a kid, I used to fantasize about ge...","1. So poor I remember, just so I could go to m...",Extract the key humorous lines and punchlines ...,"Sure, here are the key humorous lines:",<bos><start_of_turn>user\nExtract the key humo...,<bos><start_of_turn>user\nExtract the key humo...
Anthony_Jeselnik,"When I was a kid, I used to fantasize about ge...","1. So poor I remember, just so I could go to m...",Extract the key humorous lines and punchlines ...,Here are some lines and punchlines that could ...,<bos><start_of_turn>user\nExtract the key humo...,<bos><start_of_turn>user\nExtract the key humo...
Anthony_Jeselnik,"When I was a kid, I used to fantasize about ge...","1. So poor I remember, just so I could go to m...",Extract the key humorous lines and punchlines ...,Got it! Here are the main punchlines and comed...,<bos><start_of_turn>user\nExtract the key humo...,<bos><start_of_turn>user\nExtract the key humo...
Anthony_Jeselnik,"When I was a kid, I used to fantasize about ge...","1. So poor I remember, just so I could go to m...",The following is a stand-up comedy transcript....,,<bos><start_of_turn>user\nThe following is a s...,<bos><start_of_turn>user\nThe following is a s...
...,...,...,...,...,...,...
Tom_Segura_3,Probably checked in to 400 hotels this year. A...,"1. And the guy goes, “Whoa. Are you Japanese?”...",You are a person who enjoys aggressive humor. ...,Got it! Here are the main punchlines and comed...,<bos><start_of_turn>user\nYou are a person who...,<bos><start_of_turn>user\nYou are a person who...
Tom_Segura_3,Probably checked in to 400 hotels this year. A...,"1. And the guy goes, “Whoa. Are you Japanese?”...",You are a person who enjoys self-enhancing hum...,,<bos><start_of_turn>user\nYou are a person who...,<bos><start_of_turn>user\nYou are a person who...
Tom_Segura_3,Probably checked in to 400 hotels this year. A...,"1. And the guy goes, “Whoa. Are you Japanese?”...",You are a person who enjoys self-enhancing hum...,"Sure, here are the key humorous lines:",<bos><start_of_turn>user\nYou are a person who...,<bos><start_of_turn>user\nYou are a person who...
Tom_Segura_3,Probably checked in to 400 hotels this year. A...,"1. And the guy goes, “Whoa. Are you Japanese?”...",You are a person who enjoys self-enhancing hum...,Here are some lines and punchlines that could ...,<bos><start_of_turn>user\nYou are a person who...,<bos><start_of_turn>user\nYou are a person who...


In [6]:
# Unembedding matrix (rows are indexed by token id below), detached and upcast to float32.
U = model.lm_head.weight.detach().to(torch.float32)

In [7]:
# use unembedding tokenization form
def get_gt_representation(batch_of_strs: list, subspace_size: int = 8) -> torch.Tensor:
    """Return a low-rank basis of each string's cloud of token-unembedding vectors.

    Each string is tokenized. Its token ids select rows of the unembedding matrix ``U``,
    and ``torch.pca_lowrank`` is run independently for every batch element.

    Args:
        batch_of_strs: Chat-templated strings (these already begin with ``<bos>``).
        subspace_size: Number of principal directions ``q`` to keep.

    Returns:
        The ``V`` factor of ``torch.pca_lowrank``, shape ``(batch, hidden, subspace_size)``.
    """
    # apply_chat_template(tokenize=False) already emitted <bos>; without
    # add_special_tokens=False the tokenizer would prepend a second one.
    inputs = tokenizer(
        batch_of_strs,
        return_tensors="pt",
        padding=True,
        truncation=False,
        add_special_tokens=False,
    ).to(model.device)
    # NOTE(review): padding-token rows of U also enter the PCA for shorter strings.
    *_, subs_repr = torch.pca_lowrank(U[inputs["input_ids"]], q=subspace_size)
    return subs_repr


# The reference subspace comes from prompt + ground-truth answer ("gt_input").
# "model_input" holds only the prompt and contains no ground truth at all.
gt_representations = {
    comedian: get_gt_representation(batch.tolist())
    for comedian, batch in tqdm(df.groupby("comedian")["gt_input"])
}

100%|██████████| 51/51 [00:00<00:00, 65.13it/s]


In [8]:
# Inspect one comedian's stacked bases: (rows, hidden, subspace_size).
gt_representations["Ali_Wong"].size()

torch.Size([16, 2304, 8])

In [None]:
def get_output_representation(batch_of_strs: list, number_of_tokens: int = 128, subspace_size: int = 8) -> torch.Tensor:
    """Generate a continuation for each prompt and return a low-rank basis of its tokens.

    The full returned sequence (prompt + up to ``number_of_tokens`` new tokens) indexes
    rows of the unembedding matrix ``U``. ``torch.pca_lowrank`` is then run per batch
    element. This mirrors ``get_gt_representation``, which also covers prompt + answer.

    Args:
        batch_of_strs: Chat-templated prompts (these already begin with ``<bos>``).
        number_of_tokens: ``max_new_tokens`` passed to ``model.generate``.
        subspace_size: Number of principal directions ``q`` to keep.

    Returns:
        The ``V`` factor of ``torch.pca_lowrank``, shape ``(batch, hidden, subspace_size)``.
    """
    # Batched generation with a decoder-only model needs left padding, so every prompt
    # ends exactly where generation starts. Restore the tokenizer's setting afterwards.
    previous_padding_side = tokenizer.padding_side
    tokenizer.padding_side = "left"
    try:
        # The chat template already emitted <bos>; don't let the tokenizer add another.
        inputs = tokenizer(
            list(batch_of_strs),
            return_tensors="pt",
            padding=True,
            truncation=False,
            add_special_tokens=False,
        ).to(model.device)
    finally:
        tokenizer.padding_side = previous_padding_side

    with torch.inference_mode():
        ids = model.generate(**inputs, max_new_tokens=number_of_tokens)

    *_, subs_repr = torch.pca_lowrank(U[ids], q=subspace_size)
    return subs_repr


BATCH_SIZE = 16

# Generate from the prompt only ("model_input"). Generating from "gt_input" would
# continue after the reference answer instead of producing the model's own answer.
output_representations = {
    comedian: torch.cat([get_output_representation(chunk) for chunk in batched(batch.tolist(), BATCH_SIZE)])
    for comedian, batch in tqdm(df.groupby("comedian")["model_input"])
}

  0%|          | 0/51 [00:00<?, ?it/s]The 'max_batch_size' argument of HybridCache is deprecated and will be removed in v4.46. Use the more precisely named 'batch_size' argument instead.
Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)
 10%|▉         | 5/51 [00:36<05:34,  7.28s/it]

In [None]:
def _subspace_similarity(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Per-pair similarity ||a^T b||_F^2 / q between batched subspace bases.

    For bases with orthonormal columns (as returned by ``torch.pca_lowrank``), this is
    the mean squared cosine of the principal angles. It is therefore invariant to the
    choice of basis, including the arbitrary column signs PCA produces, and lies in [0, 1].
    The previous ``trace((a^T b)^2)`` had neither property.

    Args:
        a, b: Tensors of shape ``(batch, hidden, q)``.

    Returns:
        Tensor of shape ``(batch,)``.
    """
    overlap = a.mT @ b  # (batch, q, q)
    return overlap.square().sum(dim=(-2, -1)) / overlap.shape[-1]


scores = {}
for comedian in tqdm(gt_representations.keys()):
    gt_reference_subspaces = gt_representations[comedian]
    out_reference_subspaces = output_representations[comedian]

    scores[comedian] = _subspace_similarity(gt_reference_subspaces, out_reference_subspaces).mean().item()

In [None]:
# One row per comedian with its mean subspace score; show the overall mean in percent.
df = pd.DataFrame({"Comedian": list(scores.keys()), "Score": list(scores.values())})
df["Score"].mean() * 100

In [None]:
# Locate the `subspace_scores` directory itself. The old lookup found it through a CSV
# already inside it, so it failed with StopIteration on the first run.
scores_dir = next((p for p in cwd.glob("**/subspace_scores") if p.is_dir()), None)
if scores_dir is None:
    raise FileNotFoundError(f"No 'subspace_scores' directory found under {cwd}")
# File name carries the model name without its org prefix, e.g. scores_gemma-2-2b-it.csv.
filepath = scores_dir / f"scores_{MODEL_ID.rsplit('/', 1)[-1]}.csv"
df.to_csv(filepath)

Spearman's rank correlation between the subspace scores and the bipartite-metric scores

In [None]:
import sys
import pandas as pd
sys.path.append("..")
from scipy.stats import spearmanr
from humor.bipartite_metric import bipartite_metric

ground_truth = pd.read_csv('../data/stand_up_dataset/standup_data.csv')
# Use a distinct name so the loaded LLM stays bound to `model`.
gemma_answers = pd.read_csv('../data/stand_up_dataset/gemma_answers.csv')

# NOTE(review): assumes bipartite_metric returns 'comedian' and 'score' columns
# (inferred from the merge and column access below) — confirm.
gemma_metric = bipartite_metric(gemma_answers, ground_truth)
# The subspace table uses 'Comedian'/'Score'. Lower-case them so the merge key exists
# and the overlapping 'score' columns receive the _df1/_df2 suffixes read below.
subspace_scores = df.rename(columns={"Comedian": "comedian", "Score": "score"})
merged_df = pd.merge(gemma_metric, subspace_scores, on='comedian', suffixes=('_df1', '_df2'))
correlation, p_value = spearmanr(merged_df['score_df1'], merged_df['score_df2'])
print("Correlation: ", correlation)
print("p_value:", p_value)