# LLM humor detection with Subspace based metric

In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import pandas as pd

MODEL_ID = "google/gemma-2-2b-it"

# Initialize tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="cuda:0",
    torch_dtype=torch.bfloat16,
    output_hidden_states=True  # Enable hidden states output
)



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [2]:
ground_truth = pd.read_csv('../data/stand_up_dataset/standup_data.csv')
transcript = pd.read_csv('../data/stand_up_dataset/standup_transcripts.csv')

In [3]:
INSTRUCTIONS = [
    "Extract the key humorous lines and punchlines for this stand-up comedy transcript. Focus on the quotes highlighting the main comedic moments. List of quotes:",
    "The following is a stand-up comedy transcript. When performed in front of a live audience, which jokes do you think made the audience laugh?  List of quotes:",
    "You are a person who enjoys aggressive humor. Extract the key humorous lines and punchlines for this stand-up comedy transcript. Focus on the quotes highlighting the main comedic moments. List of quotes:",
    "You are a person who enjoys self-enhancing humor. Extract the key humorous lines and punchlines for this stand-up comedy transcript. Focus on the quotes highlighting the main comedic moments. List of quotes:",
    # "You are a person who enjoys self-deprecating humor. Extract the key humorous lines and punchlines for this stand-up comedy transcript. Focus on the quotes highlighting the main comedic moments. List of quotes:",
    # "You are a person who enjoys dark humor. Extract the key humorous lines and punchlines for this stand-up comedy transcript. Focus on the quotes highlighting the main comedic moments. List of quotes:",
    # "You are a person who enjoys affiliative humor. Extract the key humorous lines and punchlines for this stand-up comedy transcript. Focus on the quotes highlighting the main comedic moments. List of quotes:",
    # "The following is a stand-up comedy transcript. What are the funniest punchlines from the transcript. List of quotes:",
    # "Below is a transcript from a stand-up comedy routine. Analyze the transcript and extract the quotes that are most likely to have made the audience laugh. List of quotes:",
    # "The following is a stand-up comedy transcript. When preformed in front of a live audience, which jokes do you think made the audience laugh? List of quotes:",
    # "Pretend that you are a stand-up comedian reading the following stand-up comedy transcript. Focus on the quotes highlighting the main comedic moments. List of quotes:",
    # "Pretend that you are a stand-up comedy fan reading the following stand-up comedy transcript. Focus on the quotes highlighting the main comedic moments. List of quotes:",
    # "Pretend that you are a stand-up comedy critic reading the following stand-up comedy transcript. Focus on the quotes highlighting the main comedic moments. List of quotes:" 
]

In [4]:
gt = ground_truth.groupby("comedian")["sentence"].apply(list).apply(lambda sentences: "\n".join([f"{i + 1}. {s}" for i, s in enumerate(sentences)]))
df = transcript.set_index("comedian").join(gt).rename(columns={"sentence": "ground_truth"})

df["instruction"] = [INSTRUCTIONS] * len(df)
df = df.explode("instruction")

def gt_chat_template(row):
    return tokenizer.apply_chat_template([
        # {"role": "system", "content": ""},
        {"role": "user", "content": row["instruction"] + "\n" + row["transcript"]},
        {"role": "assistant", "content": "Sure, here are the key humorous lines:\n" + row["ground_truth"]},
    ], tokenize=False)

df["gt_input"] = df.apply(gt_chat_template, axis=1)

def model_chat_template(row):
    return tokenizer.apply_chat_template([
        # {"role": "system", "content": ""},
        {"role": "user", "content": row["instruction"] + "\n" + row["transcript"]},
    ], tokenize=False)

df["model_input"] = df.apply(model_chat_template, axis=1)

df

Unnamed: 0_level_0,transcript,ground_truth,instruction,gt_input,model_input
comedian,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
Anthony_Jeselnik,"When I was a kid, I used to fantasize about ge...","1. So poor I remember, just so I could go to m...",Extract the key humorous lines and punchlines ...,<bos><start_of_turn>user\nExtract the key humo...,<bos><start_of_turn>user\nExtract the key humo...
Anthony_Jeselnik,"When I was a kid, I used to fantasize about ge...","1. So poor I remember, just so I could go to m...",The following is a stand-up comedy transcript....,<bos><start_of_turn>user\nThe following is a s...,<bos><start_of_turn>user\nThe following is a s...
Anthony_Jeselnik,"When I was a kid, I used to fantasize about ge...","1. So poor I remember, just so I could go to m...",You are a person who enjoys aggressive humor. ...,<bos><start_of_turn>user\nYou are a person who...,<bos><start_of_turn>user\nYou are a person who...
Anthony_Jeselnik,"When I was a kid, I used to fantasize about ge...","1. So poor I remember, just so I could go to m...",You are a person who enjoys self-enhancing hum...,<bos><start_of_turn>user\nYou are a person who...,<bos><start_of_turn>user\nYou are a person who...
Anthony_Jeselnik_2,No one should ever ask me to speak at anyone’...,"1. Was like, ""I've never talked to a group of ...",Extract the key humorous lines and punchlines ...,<bos><start_of_turn>user\nExtract the key humo...,<bos><start_of_turn>user\nExtract the key humo...
...,...,...,...,...,...
Trevor_Noah_3,You know what fascinates me about New York… is...,1. You know what fascinates me about New York…...,You are a person who enjoys self-enhancing hum...,<bos><start_of_turn>user\nYou are a person who...,<bos><start_of_turn>user\nYou are a person who...
Tom_Segura_3,Probably checked in to 400 hotels this year. A...,"1. And the guy goes, “Whoa. Are you Japanese?”...",Extract the key humorous lines and punchlines ...,<bos><start_of_turn>user\nExtract the key humo...,<bos><start_of_turn>user\nExtract the key humo...
Tom_Segura_3,Probably checked in to 400 hotels this year. A...,"1. And the guy goes, “Whoa. Are you Japanese?”...",The following is a stand-up comedy transcript....,<bos><start_of_turn>user\nThe following is a s...,<bos><start_of_turn>user\nThe following is a s...
Tom_Segura_3,Probably checked in to 400 hotels this year. A...,"1. And the guy goes, “Whoa. Are you Japanese?”...",You are a person who enjoys aggressive humor. ...,<bos><start_of_turn>user\nYou are a person who...,<bos><start_of_turn>user\nYou are a person who...


In [5]:
from tqdm import tqdm

# use unembedding tokenization form
def get_gt_representation(batch_of_strs: list[str], number_of_tokens: int = 128) -> torch.Tensor:
    inputs = tokenizer(batch_of_strs, return_tensors="pt", padding=True, truncation=False).to(model.device)
    with torch.inference_mode():
        return model(**inputs).hidden_states[-1][:, -number_of_tokens:].flatten(1)

gt_representations = {
    comedian: get_gt_representation(batch.tolist())
    for comedian, batch in tqdm(df.groupby("comedian")["model_input"])
}

  0%|          | 0/51 [00:00<?, ?it/s]Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)
100%|██████████| 51/51 [00:04<00:00, 10.49it/s]


In [6]:
def get_output_representation(batch_of_strs: list[str], number_of_tokens: int = 128) -> torch.Tensor:
    inputs = tokenizer(batch_of_strs, return_tensors="pt", padding=True, truncation=False).to(model.device)
    with torch.inference_mode():
        return model(input_ids=model.generate(**inputs, max_new_tokens=128)).hidden_states[-1][:, -number_of_tokens:].flatten(1)

output_representations = {
    comedian: get_output_representation(batch.tolist())
    for comedian, batch in tqdm(df.groupby("comedian")["gt_input"])
}

  0%|          | 0/51 [00:00<?, ?it/s]The 'max_batch_size' argument of HybridCache is deprecated and will be removed in v4.46. Use the more precisely named 'batch_size' argument instead.
100%|██████████| 51/51 [01:49<00:00,  2.14s/it]


In [7]:
gt_representations["Ali_Wong"].shape

torch.Size([4, 294912])

In [8]:
output_representations["Ali_Wong"].shape

torch.Size([4, 294912])

In [22]:
def make_subspace(data: torch.FloatTensor, q: int = 3) -> torch.Tensor:
    data = torch.nn.functional.normalize(data, p=2, dim=-1)
    data = data - data.mean(0, keepdim=True)
    *_, Vh = torch.pca_lowrank(data, q=q)
    return Vh

scores = {}
for comedian in tqdm(gt_representations.keys()):
    gt_reference_subspace = make_subspace(gt_representations[comedian].float())
    out_reference_subspace = make_subspace(output_representations[comedian].float())

    A = gt_reference_subspace.mT @ out_reference_subspace
    scores[comedian] = (A.mT @ A).trace() / A.shape[0]

100%|██████████| 51/51 [00:00<00:00, 239.06it/s]


In [34]:
df.loc["Ali_Wong", "gt_input"].tolist()

["<bos><start_of_turn>user\nExtract the key humorous lines and punchlines for this stand-up comedy transcript. Focus on the quotes highlighting the main comedic moments. List of quotes:\nThe last time I was at home in San Francisco, I was trying to help her get rid of shit. Don’t ever do that with your mom. It was like the worst experience of my life. It was so emotional. We were screaming and fighting and yelling and it all came to a climax when she refused to let go of a Texas Instruments TI-82… manual. The manual. She don’t even know… where the calculator is. Those of you under 25 probably don’t know what that calculator is. It was this calculator that bamboozled my generation. We were all required to buy it when we were in eight grade. It cost like $200. And everybody thought it was like this Judy Jetson’s laptop from the future. All because what? It could graph. It was like the Tesla of my time. And my mom got so emotional about the manual and she was like, “You never know when yo

In [30]:
data = gt_representations["Ali_Wong"].float()
data = torch.nn.functional.normalize(data, p=2, dim=-1)

(data @ make_subspace(data)).square().mean(-1).sqrt()

tensor([0.0282, 0.0282, 0.0282, 0.0282], device='cuda:0')

In [11]:
gt_reference_list = []
for i in all_representations:
    for comedian, rep in i.items():
        gt_references = torch.stack(rep)
        *_, gt_reference_subspace = torch.pca_lowrank(gt_references.float(), q=10)
        gt_reference_list.append({comedian: gt_reference_subspace})
    
# gt_reference_subspace.shape

NameError: name 'all_representations' is not defined

In [38]:
#sentences = [f"{inst}\n{text}\n" for inst, text in product(INSTRUCTIONS, TRANSCRIPTS)]

#inputs = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True).to(model.device)
#with torch.inference_mode():
    #outputs = model(input_ids=model.generate(**inputs, max_new_tokens=128))


# representations = torch.cat(outputs.hidden_states)[-1, -16:].flatten()
# all_representations.append(representation)