In [1]:
import pandas as pd
import ollama
import json
from tqdm import tqdm
import ast


In [2]:
# Theme labels the model may assign (10 in total). The same list is shown to
# the model inside the prompt and yields one 0/1 output column per label.
LABELS = [
    'husband and wife',
    'murder',
    'the desire for vengeance',
    'infatuation',
    'romantic love',
    'friendship',
    'spouse murder',
    'greed for riches',
    'father and son',
    'extramarital affair',
]


In [3]:
# Held-out data to classify; the runs below read its "content" column.
test = pd.read_csv(filepath_or_buffer='test_sub.csv')

In [7]:
def classify_text_with_ollama(text: str, model: str) -> str:
    """Ask an Ollama chat model which themes appear in ``text``.

    The prompt is built from the module-level ``PROMPT_TEMPLATE``, filled with
    ``LABELS`` (joined by newlines) and the input text. It is sent as a single
    user message.

    Args:
        text: The text to classify.
        model: Ollama model name/tag to query, e.g. "mistral:7b-instruct".

    Returns:
        The model's raw reply, lower-cased (not parsed; see ``process_content``).
        An empty string is returned if the reply carries no content.
    """
    # PROMPT_TEMPLATE is looked up at call time, so it must already be defined
    # (in this notebook it is assigned in a later cell than this function).
    prompt = PROMPT_TEMPLATE.format(labels="\n".join(LABELS), text=text)
    response = ollama.chat(model=model, messages=[{"role": "user", "content": prompt}])
    # Guard against a missing/None content field before lower-casing.
    content = response['message']['content'] or ""
    # Lower-case so the reply can be matched against the lower-case LABELS.
    return content.lower()


def _extract_list_literal(content):
    """Return the first Python list literal found in ``content``, or None.

    Tries the whole (stripped) reply first, then the span from the first ``[``
    to the last ``]``. This lets replies such as ``"output: ['murder'] ..."``
    still parse.
    """
    candidates = [content.strip()]
    start, end = content.find("["), content.rfind("]")
    if 0 <= start < end:
        candidates.append(content[start:end + 1])
    for candidate in candidates:
        try:
            value = ast.literal_eval(candidate)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            # Not a valid literal: try the next candidate / the text fallback.
            continue
        if isinstance(value, list):
            return value
    return None


def _find_labels_in_text(content, labels):
    """Return the labels mentioned anywhere in free text, in ``labels`` order.

    Longer labels are matched first and blanked out. This way, e.g.
    "spouse murder" is not also counted as a mention of "murder".
    """
    remaining = content
    found = set()
    for label in sorted(labels, key=len, reverse=True):
        if label in remaining:
            found.add(label)
            remaining = remaining.replace(label, " ")
    return [label for label in labels if label in found]


def process_content(content, labels=None):
    """Parse a model reply into the list of recognised theme labels.

    Args:
        content: Raw (lower-cased) model reply, ideally a list literal such as
            "['murder', 'friendship']".
        labels: Allowed label names; defaults to the module-level ``LABELS``.

    Returns:
        Allowed labels found in the reply, without duplicates:

        * If the reply contains a list literal, its string items (stripped)
          that are allowed labels, in the model's order.
        * Otherwise, labels mentioned anywhere in the text, in ``labels`` order.
        * ``[]`` for an empty or non-string reply.
    """
    if labels is None:
        labels = LABELS
    if not content or not isinstance(content, str):
        return []

    parsed = _extract_list_literal(content)
    if parsed is not None:
        allowed = set(labels)
        cleaned = (item.strip() for item in parsed if isinstance(item, str))
        # dict.fromkeys drops duplicates while keeping first-seen order.
        return list(dict.fromkeys(label for label in cleaned if label in allowed))

    # No parseable list: scan the free text for label names instead.
    return _find_labels_in_text(content, labels)

def main(df1: pd.DataFrame, output_csv: str, model: str, text_column: str = 'file_content'):
    """Classify every row of ``df1`` with an Ollama model and save the results.

    The output frame keeps ``df1``'s index and column layout, but only
    ``text_column`` is copied over. Other input columns stay empty (NaN), and
    any input column named like a label is overwritten with 0. It adds:

    * one 0/1 column per entry in ``LABELS`` (1 = theme predicted),
    * ``raw_prediction``: ``str()`` of the parsed label list,
    * ``message``: the model's raw (lower-cased) reply.

    Args:
        df1: Input rows; must contain ``text_column``.
        output_csv: Destination CSV path. Missing parent directories are
            created up front, so a long run cannot fail only at the final write.
        model: Ollama model name/tag passed to ``classify_text_with_ollama``.
        text_column: Column of ``df1`` holding the text to classify.

    Returns:
        The prediction DataFrame that was written to ``output_csv``.

    Raises:
        KeyError: If ``text_column`` is not a column of ``df1``.
    """
    from pathlib import Path

    # Fail fast on an unwritable destination rather than after hours of inference.
    Path(output_csv).parent.mkdir(parents=True, exist_ok=True)

    # Same index/columns as the input, but empty except for the text column.
    # NOTE(review): the other input columns are written to the CSV as empty
    # columns; confirm that is intended (e.g. to avoid copying ground truth).
    df = pd.DataFrame(index=df1.index, columns=df1.columns)
    df[text_column] = df1[text_column]

    # One binary indicator per theme, then the parsed and raw model outputs.
    for label in LABELS:
        df[label] = 0
    df["raw_prediction"] = ""
    df["message"] = ""

    print(f"Processing {len(df)} rows using model: {model}")
    for idx, text in tqdm(df[text_column].items(), total=len(df)):
        content = classify_text_with_ollama(text, model=model)
        predicted_labels = process_content(content)
        df.at[idx, "message"] = content
        df.at[idx, "raw_prediction"] = str(predicted_labels)
        for label in predicted_labels:
            df.at[idx, label] = 1

    df.to_csv(output_csv, index=False)
    print(f"✅ Done. Output saved to: {output_csv}")
    return df

In [8]:
# Prompt sent to the model for every row. classify_text_with_ollama fills it
# via str.format with:
#   {labels} - LABELS joined with newlines
#   {text}   - the row's text
# Because str.format is used, any literal braces added to this template must be
# written as {{ and }}.
# The example and instructions ask for a Python list literal, which is what
# process_content tries to parse with ast.literal_eval.
# NOTE(review): "10 EXACT themes" / "ten themes" and the definitions below are
# hard-coded; keep them in sync if LABELS changes.
# NOTE(review): this cell comes after the one defining classify_text_with_ollama;
# that only works because PROMPT_TEMPLATE is looked up when the function is called.
PROMPT_TEMPLATE = """
You are a theme classifier for TV episode subtitles. Based on the subtitles provided, identify which of the following 10 EXACT themes are present as **main themes** in the episode. 

You must only choose from the following ten themes: {labels}. Use the definitions to guide your selections.

--- Theme Definitions ---

husband and wife: The relationship between husband and wife is featured.  
murder: The crime of unlawful and intentional homicide is featured.  
the desire for vengeance: A character seeks retribution over a perceived injury or wrong.  
infatuation: An intense but (typically) short-lived passion that may peter out or settle into more enduring romantic love.  
romantic love: Featured is that peculiar sort of love between people so often associated with sexual attraction.  
friendship: The friendship between two characters is featured.  
spouse murder: One spouse in a married couple seeks to bring about the death of their partner.  
greed for riches: A character exhibits an inordinate desire for wealth such as money, luxuries, and the like.  
father and son: The relationship between a father and his son is featured.  
extramarital affair: A character who is married engages in a sexual encounter or relationship outside the marriage, and deals with the consequences.


--- Example Output ---

Output:  
['murder', 'greed for riches']

--- Instructions ---

- Return a valid Python list of strings.  
- Use only the exact theme names listed above.  
- Do NOT include any explanation.  
- Do NOT invent new labels.  



--- New Input ---

Text:  
{text}

Output:
"""


In [9]:
# Mistral run over the test set; results go to the prompt3 folder.
# NOTE(review): the file is named "mistral:latest" but the model tag used is
# "mistral:7b-instruct" (the other runs name the file after the exact tag);
# confirm both tags refer to the same Ollama model before comparing results.
# NOTE(review): ":" is not a legal file-name character on Windows.
models = "mistral:7b-instruct"
output = 'prompting_sub/instruct/prompt3/test_mistral:latest.csv'
df1 = main(test, output, models, text_column="content")




Processing 192 rows using model: mistral:7b-instruct


100%|███████████████████████████████████████| 192/192 [1:19:31<00:00, 24.85s/it]


In [13]:
# Gemma 3 run; the results file is named after the exact model tag.
models = "gemma3:12b-it-qat"
output = 'prompting_sub/instruct/prompt3/test_gemma3:12b-it-qat.csv'
df4 = main(test, output, models, text_column="content")


Processing 192 rows using model: gemma3:12b-it-qat


100%|███████████████████████████████████████| 192/192 [3:06:55<00:00, 58.41s/it]


In [9]:
# Llama 3.1 run; the results file is named after the exact model tag.
models = "llama3.1:8b-instruct-q8_0"
output = 'prompting_sub/instruct/prompt3/test_llama3.1:8b-instruct-q8_0.csv'
df5 = main(test, output, models, text_column="content")

Processing 192 rows using model: llama3.1:8b-instruct-q8_0


100%|███████████████████████████████████████| 192/192 [1:28:23<00:00, 27.62s/it]
