# Yelp Review Classification Task

## Objective
Classify Yelp reviews into 1–5 stars using 3 different prompting approaches.
Evaluate each approach on accuracy, JSON validity, and reliability (agreement between two repeated runs).

## Setup
Ensure `yelp.csv` is in the notebook's working directory and that a valid Gemini API key is configured in the setup cell.

In [None]:
!pip -q install -U google-genai pandas numpy tqdm

import os, json, re, time
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

# Put your Gemini API key here or set it as an environment variable
# os.environ["GEMINI_API_KEY"] = "YOUR_KEY"

from google import genai
# Only fall back to the placeholder when no key is set. The previous
# unconditional assignment clobbered a real GEMINI_API_KEY exported before
# the notebook was started. Avoid committing a real key into this cell.
os.environ.setdefault("GEMINI_API_KEY", "YOUR_API_KEY")
client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])

# Load the Yelp dataset from the working directory.
df = pd.read_csv("yelp.csv")
# Keep only review rows when the file carries a `type` column.
if "type" in df.columns:
    df = df[df["type"] == "review"].copy()
# Rows without text or a label cannot be evaluated.
df = df.dropna(subset=["text", "stars"]).reset_index(drop=True)

# Sample up to SAMPLE_N rows for evaluation. Capping at len(df) keeps a small
# file from raising. When len(df) >= SAMPLE_N the sample is identical to an
# uncapped draw with the same seed.
SAMPLE_N = 200
SEED = 42
eval_df = df.sample(n=min(SAMPLE_N, len(df)), random_state=SEED).reset_index(drop=True)

eval_df[["stars", "text"]].head()


In [None]:
# Prompt templates compared in this notebook. Each is filled with
# str.format(review=...), so the literal JSON braces inside the templates are
# doubled ({{ }}) to escape them. The substituted review text is not re-parsed
# by str.format, so braces inside reviews are safe.

# V1: minimal zero-shot instruction plus an output-format example.
# NOTE(review): the example value 4 in the format block may anchor the model
# toward 4-star predictions; worth checking the predicted-star distribution.
PROMPT_V1 = """
Classify the Yelp review into an integer star rating from 1 to 5.

Return ONLY valid JSON in this exact format:
{{
  "predicted_stars": 4,
  "explanation": "Brief reasoning for the assigned rating."
}}

Review:
{review}
""".strip()

# V2: zero-shot with an explicit per-star rubric to calibrate ratings.
PROMPT_V2 = """
You are a strict Yelp star-rating classifier.

Use this rubric:
- 1 star: terrible experience, strong complaints, would not return.
- 2 stars: below average, several issues, disappointed.
- 3 stars: mixed/average, pros and cons, acceptable.
- 4 stars: good experience, minor issues, would return.
- 5 stars: excellent, enthusiastic praise, highly recommend.

Return ONLY valid JSON exactly:
{{
  "predicted_stars": <1-5 integer>,
  "explanation": "One short sentence justifying the rating."
}}

Review:
{review}
""".strip()

# V3: strict output rules plus three few-shot examples (5, 3 and 1 stars).
# The required keys match the schema enforced by parse_prediction.
PROMPT_V3 = """
You classify Yelp reviews into 1–5 stars.

Rules:
- Output MUST be valid JSON only (no markdown, no extra text).
- Output MUST contain exactly two keys: predicted_stars, explanation.
- predicted_stars must be an integer 1..5.

Examples:
Review: "Absolutely amazing service and the food was perfect. Can't wait to come back!"
Output: {{"predicted_stars": 5, "explanation": "Strong praise and eagerness to return indicates an excellent experience."}}

Review: "It was okay. Some things were good, but the wait was long and the place felt average."
Output: {{"predicted_stars": 3, "explanation": "Mixed feedback with notable downsides fits an average experience."}}

Review: "Never again. Rude staff, cold food, and it took forever."
Output: {{"predicted_stars": 1, "explanation": "Severe complaints and refusal to return indicate a very poor experience."}}

Now classify:
Review: {review}
""".strip()


In [None]:
MODEL = "gemini-2.5-flash"  # fast/cheap for batch eval; switch to pro if needed


def call_gemini(prompt: str, temperature=0.2, max_retries=3, retry_sleep=1.0,
                max_output_tokens=120):
    """Send ``prompt`` to ``MODEL`` and return the raw response text.

    This function never raises. API failures and empty responses are returned
    as strings starting with ``"__ERROR__:"``, so the evaluation loop records
    them as invalid outputs instead of aborting.

    Args:
        prompt: Fully formatted prompt text.
        temperature: Sampling temperature.
        max_retries: Total number of attempts; values below 1 are treated as 1.
        retry_sleep: Base back-off in seconds. Before retry k (1-based) the
            call sleeps ``retry_sleep * k``.
        max_output_tokens: Output token cap passed to the API. It is kept
            small to limit rambling that breaks JSON.

    Returns:
        The model's text, ``"__ERROR__:empty response"`` when the response
        carries no text, or ``"__ERROR__:<repr of last exception>"`` after all
        attempts failed.
    """
    attempts = max(1, max_retries)
    last_error = None
    for attempt in range(attempts):
        try:
            resp = client.models.generate_content(
                model=MODEL,
                contents=prompt,
                config=genai.types.GenerateContentConfig(
                    temperature=temperature,
                    # NOTE(review): for Gemini 2.5 models, thinking tokens
                    # reportedly count toward this cap; a small value may
                    # truncate or empty the reply. Confirm against the docs.
                    max_output_tokens=max_output_tokens,
                ),
            )
        except Exception as e:  # boundary: network/quota/server errors are retried
            last_error = e
            if attempt < attempts - 1:
                time.sleep(retry_sleep * (attempt + 1))
            continue
        # resp.text can be None when the response has no text part (for
        # example a blocked or budget-exhausted reply). Returning None would
        # crash downstream string handling.
        if resp.text is None:
            return "__ERROR__:empty response"
        return resp.text
    return f"__ERROR__:{repr(last_error)}"

def extract_json_str(s: str) -> str:
    """Return the JSON-object part of a model response.

    If the stripped text already starts with ``{`` and ends with ``}``, it is
    returned unchanged. Otherwise the span from the first ``{`` to the last
    ``}`` (greedy, across newlines) is returned; this drops markdown fences or
    chatty preambles around the object. The stripped input is returned when
    no such span exists, and ``""`` is returned for ``None``.
    """
    if s is None:
        return ""
    s = s.strip()
    if s.startswith("{") and s.endswith("}"):
        return s
    # Greedy so that nested objects stay intact.
    m = re.search(r"\{.*\}", s, flags=re.DOTALL)
    return m.group(0).strip() if m else s


def parse_prediction(raw: str):
    """Parse and validate one model output against the expected schema.

    A valid output is a JSON object with exactly the keys ``predicted_stars``
    (an integer 1-5; JSON booleans are rejected) and ``explanation`` (a
    non-blank string).

    Returns:
        ``(is_valid_json, predicted_stars_or_none, explanation_or_none,
        parsed_obj_or_none)``. The parsed object is returned when it is a dict
        that failed schema validation, to help with debugging. It is ``None``
        for call errors, undecodable text and non-object JSON.
    """
    invalid = (False, None, None, None)
    # call_gemini error markers, or a non-string response, count as invalid.
    if not isinstance(raw, str) or raw.startswith("__ERROR__"):
        return invalid

    try:
        obj = json.loads(extract_json_str(raw))
    except (ValueError, RecursionError):  # JSONDecodeError subclasses ValueError
        return invalid

    if not isinstance(obj, dict):
        return invalid
    if set(obj) != {"predicted_stars", "explanation"}:
        return False, None, None, obj

    ps = obj["predicted_stars"]
    ex = obj["explanation"]
    # bool is a subclass of int, so JSON true/false would otherwise pass as 1/0.
    if isinstance(ps, bool) or not isinstance(ps, int) or not 1 <= ps <= 5:
        return False, None, None, obj
    if not isinstance(ex, str) or not ex.strip():
        return False, None, None, obj
    return True, ps, ex.strip(), obj


In [None]:
def evaluate_prompt(prompt_template: str, name: str, temperature=0.2, data=None):
    """Run one prompt template over an evaluation set and score the outputs.

    Args:
        prompt_template: Template containing a ``{review}`` placeholder,
            filled with ``str.format``.
        name: Label written to the ``approach`` column and shown in the
            progress bar.
        temperature: Sampling temperature forwarded to ``call_gemini``.
        data: DataFrame with ``text`` and ``stars`` columns. Defaults to the
            notebook-level ``eval_df``, which is looked up at call time.

    Returns:
        ``(out, accuracy, json_valid_rate)``:
        * ``out``: one row per review with the raw output, the parsed
          prediction and a boolean ``correct`` column.
        * ``accuracy``: fraction correct among rows with a valid parsed
          prediction, or 0.0 when there are none.
        * ``json_valid_rate``: fraction of rows that passed
          ``parse_prediction`` validation, or 0.0 for an empty set.
    """
    if data is None:
        data = eval_df

    rows = []
    reviews = zip(data["text"], data["stars"])
    progress = tqdm(reviews, total=len(data), desc=f"Running {name}")
    # row_id is the positional index, so two runs over the same data can be
    # joined on it.
    for row_id, (review, stars) in enumerate(progress):
        raw = call_gemini(prompt_template.format(review=review), temperature=temperature)
        is_valid, pred, expl, _ = parse_prediction(raw)
        rows.append({
            "approach": name,
            "row_id": row_id,
            "actual_stars": int(stars),
            "raw_output": raw,
            "json_valid": is_valid,
            "predicted_stars": pred,
            "explanation": expl,
        })

    # Explicit columns keep the schema intact even when `data` is empty.
    # Without them, pd.DataFrame([]) has no columns and the lookups below fail.
    out = pd.DataFrame(rows, columns=[
        "approach", "row_id", "actual_stars", "raw_output",
        "json_valid", "predicted_stars", "explanation",
    ])

    # predicted_stars mixes ints and None. Coercing to numeric turns None into
    # NaN, and NaN never compares equal, so invalid rows come out incorrect.
    preds = pd.to_numeric(out["predicted_stars"], errors="coerce")
    valid_mask = out["json_valid"].astype(bool) & preds.notna()
    out["correct"] = valid_mask & (preds == out["actual_stars"])

    accuracy = float(out.loc[valid_mask, "correct"].mean()) if valid_mask.any() else 0.0
    json_valid_rate = float(out["json_valid"].mean()) if len(out) else 0.0
    return out, accuracy, json_valid_rate


# Score each prompt variant once, in order V1 -> V2 -> V3, at temperature 0.2.
(res_v1, acc_v1, jvr_v1), (res_v2, acc_v2, jvr_v2), (res_v3, acc_v3, jvr_v3) = [
    evaluate_prompt(template, label, temperature=0.2)
    for template, label in (
        (PROMPT_V1, "V1_simple"),
        (PROMPT_V2, "V2_rubric"),
        (PROMPT_V3, "V3_fewshot"),
    )
]

acc_v1, jvr_v1, acc_v2, jvr_v2, acc_v3, jvr_v3


In [None]:
def reliability_test(prompt_template: str, name: str, temperature=0.2):
    """Measure run-to-run agreement of one prompt template.

    The template is evaluated twice, labelled ``name + "_run1"`` and
    ``name + "_run2"``, and the two result frames are joined on ``row_id``.

    Returns:
        ``(consistency, merged)``. ``consistency`` is the share of rows, among
        those valid in both runs, whose predicted stars match; it is 0.0 when
        no row is valid in both runs. ``merged`` is the joined frame with
        ``_1``/``_2`` column suffixes.
    """
    runs = [
        evaluate_prompt(prompt_template, name + suffix, temperature=temperature)[0]
        for suffix in ("_run1", "_run2")
    ]
    merged = runs[0].merge(runs[1], on="row_id", suffixes=("_1", "_2"))

    # Only compare rows where both runs produced a schema-valid prediction.
    valid_in_both = merged["json_valid_1"] & merged["json_valid_2"]
    if not valid_in_both.any():
        return 0.0, merged

    first = merged.loc[valid_in_both, "predicted_stars_1"]
    second = merged.loc[valid_in_both, "predicted_stars_2"]
    return (first == second).mean(), merged

# Two fresh runs per prompt variant; keep only the agreement score.
cons_v1, cons_v2, cons_v3 = [
    reliability_test(template, label, temperature=0.2)[0]
    for template, label in (
        (PROMPT_V1, "V1_simple"),
        (PROMPT_V2, "V2_rubric"),
        (PROMPT_V3, "V3_fewshot"),
    )
]

cons_v1, cons_v2, cons_v3


In [None]:
# One summary row per approach (column-oriented construction, same column order).
comparison = pd.DataFrame(
    {
        "approach": ["V1_simple", "V2_rubric", "V3_fewshot"],
        "accuracy_on_valid": [acc_v1, acc_v2, acc_v3],
        "json_valid_rate": [jvr_v1, jvr_v2, jvr_v3],
        "consistency_2runs": [cons_v1, cons_v2, cons_v3],
    }
)

# Display best-first: accuracy on valid rows, ties broken by JSON validity.
comparison.sort_values(by=["accuracy_on_valid", "json_valid_rate"], ascending=False)


In [None]:
# Print a placeholder discussion to be edited once the numbers are in.
print("""
Short discussion (edit after seeing your numbers):
- V1 (simple) is a baseline; it may drift or output non-JSON occasionally.
- V2 (rubric) usually improves rating calibration by anchoring what each star means.
- V3 (few-shot + strict JSON rules) often improves JSON validity and reduces formatting errors; it may also improve accuracy by demonstrating the target style.
""")

# Stack the per-row results of all three approaches and persist them,
# together with the summary table, to CSV in the working directory.
all_results = pd.concat(
    (res_v1, res_v2, res_v3),
    ignore_index=True,
)
all_results.to_csv(path_or_buf="task1_gemini_results.csv", index=False)
comparison.to_csv(path_or_buf="task1_gemini_comparison.csv", index=False)

all_results.head()
