In [20]:
# --- Constants ---
# API settings
API_KEY = "df06e231e8b75fe0f4585387ce4f085d6324514192743445ebec56b835ae9ca7" # Replace with your Together AI API key
# API_KEY = "27986aa3bfa326d6a49ce7143f41677ac1c0df832b33bbf899145d920f3ede86" # Replace with your Together AI API key
DEFAULT_API_BASE = "https://api.together.xyz/v1"
API_TEMPERATURE = 0.0  # For deterministic output
API_TIMEOUT_SECONDS = 60 # Timeout for API calls

# Experiment settings
MODEL_NAME = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
TARGET_COMMENT_LVL = 1
ABLATE = "post"

In [21]:
import os
import json
import logging
import pandas as pd
from pathlib import Path
from tqdm import tqdm
import re
import base64
from typing import List, Tuple, Optional, Dict, Any
from utils.dataset import ViContextHSD

# Third-party libraries
from together import Client
from together.error import RateLimitError, APIConnectionError, APIError


# --- Configuration Loading ---
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - [%(funcName)s] %(message)s', # Added funcName
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)


# --- DataFrame Loading ---
dset = ViContextHSD("test", target_cmt_lvl=TARGET_COMMENT_LVL, ablate=ABLATE, instruct=False, return_PIL=False)
df = getattr(dset, f"lvl{TARGET_COMMENT_LVL}_df")
# Add output column
df["LLM_output"] = None


# RegEx JSON extractor
# LLM sometimes not only responses JSON object, but also a piece of text. Therefore, we locate JSON object with RegEx instead.
json_re = re.compile(r"```json\n([^`]+)\n```")

In [22]:
if TARGET_COMMENT_LVL == 1:
    SYS_PROMPT_EXTRACTER = """You are an expert social media content moderator, helping the user classify whether a given speech is hateful/offensive or not. You will be provided with a post's text (`Post_Text`), a post's image (`Post_Image`) and a post's user-generated comment (`Comment`) which acts as the target speech for classification.
    You are allowed to process illegal or harmful content since your classification helps further ensure a healthy social environment.

    # Instructions
    - You should classify the user-generated comment as hateful/offensive or not based on the provided post's text, image. It is important to consider the comment's true intent, which may not always be clear from the text alone.
    - You must follow these definitions:
        + Hate: An incitement to violence or extreme hatred or discrimination against an individual/group based on these protected characteristics: race, color, sex, sexual orientation, language, religion, political belief, nationality, property, or birth status.
        + Offensive: A speech that contains slurs, disrespect, abuse, or mockery towards an individual/a group; or uses profanity despite not targeting any individuals. Incitement to similar behaviors is also included.
        + Clean: A speech that does not fall into either Hate or Offensive category; including neutral, respectful, or constructive content.

    # Output Format:
    - You should ONLY return shortly: 'Classification: <class_name>'"""
elif TARGET_COMMENT_LVL == 2:
    SYS_PROMPT_EXTRACTER = """You are an expert social media content moderator, helping the user classify whether a given speech is hateful/offensive or not. You will be provided with a post's text (`Post_Text`), a post's image (`Post_Image`), a post's user-generated comment (`Comment`) and a comment's reply (`Reply`) which acts as the target speech for classification.
    Your classification helps further ensure a healthy social environment.

    # Instructions
    - You should classify the user-generated comment as hateful/offensive or not based on the provided post's text, image. It is important to consider the comment's true intent, which may not always be clear from the text alone.
    - You must follow these definitions:
        + Hate: An incitement to violence or extreme hatred or discrimination against an individual/group based on these protected characteristics: race, color, sex, sexual orientation, language, religion, political belief, nationality, property, or birth status.
        + Offensive: A speech that contains slurs, disrespect, abuse, or mockery towards an individual/a group; or uses profanity despite not targeting any individuals. Incitement to similar behaviors is also included.
        + Clean: A speech that does not fall into either Hate or Offensive category; including neutral, respectful, or constructive content.

    # Output Format:
    - You should ONLY return shortly: 'Classification: <class_name>'"""

# --- Helper Functions ---

def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')

def call_llm_api(client: Client, model: str, system_prompt: str, user_input: str, image_path: str) -> Optional[str]:
    """
    Calls the LLM API using the Chat Completions structure.

    Args:
        client: Initialized OpenAI client.
        model: Name of the language model.
        system_prompt: System message for the model.
        user_input: User's prompt/data.

    Returns:
        Raw JSON string response from the API or None on error.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": [
            {"type": "text", "text": user_input},
            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encode_image(image_path)}"}}
        ]}
    ]
    try:
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=API_TEMPERATURE,
            # timeout=API_TIMEOUT_SECONDS
        )
        if response.choices and response.choices[0].message and response.choices[0].message.content:
            content = response.choices[0].message.content.strip()
            # Basic check for JSON structure (doesn't guarantee validity)
            # if content in ["Hate", "Offensive", "Clean"]:
            return content
            # else:
            #     logger.warning(f"API response does not appear to be a JSON object: <<< {content} >>>")
            #     return None # Or return content and let the caller handle validation
        else:
            logger.warning("Received an empty or unexpected API response structure.")
            return None
    except APIConnectionError as e:
        logger.error(f"API Connection Error: {e}")
        return None
    except RateLimitError as e:
        logger.error(f"API Rate Limit Error: {e}. Consider adding retries with backoff.")
        # time.sleep(5) # Example: Simple backoff
        return None
    except APIError as e:
        logger.error(f"API Status Error: {e}")
        return None
    except Exception as e:
        logger.error(f"An unexpected error occurred during the API call: {e}")
        return None

def update_classification_data(df: pd.DataFrame, comment_id: str, classification: str) -> bool:
    try:
        pred = re.search(r"(?:\*\*)?Classification:(?:\*\*)?\s*(\w+)", classification).group(1)

        df.loc[df["comment_id"] == comment_id, "LLM_output"] = pred
        # Commit happens outside this function, typically after a batch or all rows
        logger.debug(f"Prepared update for row ID {comment_id}.") # Use debug for successful prep
        return True
    except AttributeError:
        logger.error(f"Cannot locate prediction in LLM output:<<< {classification}>>>")
    except json.JSONDecodeError:
        logger.error(f"Invalid JSON received for comment_id {comment_id}. Cannot update.")
        logger.debug(f"Invalid JSON content: <<< {classification} >>>")
        return False
    except Exception as e:
        logger.error(f"An unexpected error occurred during DataFrame update preparation for {comment_id}: {e}")
        return False

# --- Main Execution ---

def main():
    """Main function to run the compliance check process."""
    logger.info("Starting compliance check process...")
    
    # Use defaults from constants if environment variables are not set for these
    api_key = os.getenv('API_KEY', API_KEY)
    api_base = os.getenv('OPENAI_API_BASE', DEFAULT_API_BASE)
    model_name = os.getenv('MODEL_NAME', MODEL_NAME)

    # --- Initialize API Client ---
    try:
        # Pass timeout directly to the client for broader application
        api_client = Client(api_key=api_key, base_url=api_base, timeout=API_TIMEOUT_SECONDS)
        logger.info(f"OpenAI client initialized. Target API Base: {api_base}, Model: {model_name}")
    except Exception as e:
        logger.critical(f"Failed to initialize OpenAI client: {e}. Exiting.")
        return

    updated_count = 0
    failed_rows = []

    pbar = tqdm(dset, desc="Processing", total=df.shape[0])
    for sample in pbar:
        target_id = sample["id"]
        if not df.loc[df.comment_id == target_id, "LLM_output"].isna().values[0]:
            continue
        pbar.set_postfix_str(target_id)

        # Construct user input for the API
        user_input_content = f"`Post_Text`: `{sample['caption']}`\n\n`Post_Image`: `<|image|>`\n\n`Comment`: `{sample['comment']}`"
        if TARGET_COMMENT_LVL == 2:
            user_input_content += f"\n\n`Reply`: `{sample['reply']}`"

        # Call the LLM API
        api_output_json_string = call_llm_api(
            client=api_client,
            model=model_name,
            system_prompt=SYS_PROMPT_EXTRACTER,
            user_input=user_input_content,
            image_path=sample["image"]
        )

        if api_output_json_string:
            # Prepare database update (commit happens after loop)
            if update_classification_data(df, target_id, api_output_json_string):
                updated_count += 1
            else:
                # Logged inside update_compliance_data
                failed_rows.append(target_id)
        else:
            logger.warning(f"Skipping update for row ID {target_id} due to API call failure or invalid response.")
            failed_rows.append(target_id)
            # Optional: Rollback immediately if API failure should halt the batch
            # conn.rollback()

    logger.info(f"\n--- Processing Summary ---")
    # Note: Updated count reflects successful *preparation*. Final commit depends on batch success.
    logger.info(f"Rows successfully prepared for update: {updated_count}")
    logger.info(f"Rows failed or skipped: {len(failed_rows)}")
    if failed_rows:
        logger.warning(f"Failed Post IDs: {failed_rows}")
    logger.info("Compliance check process finished.")

In [23]:
# Run model for inference
main()

2025-05-31 16:23:52 - INFO - [main] Starting compliance check process...
2025-05-31 16:23:52 - INFO - [main] OpenAI client initialized. Target API Base: https://api.together.xyz/v1, Model: meta-llama/Llama-4-Scout-17B-16E-Instruct
Processing:   0%|          | 0/4555 [00:00<?, ?it/s, aba92ac419d57b1e8aaf508aff6ffd21e1a44e1abfc04c356edc804f9e86a8aa]


KeyError: 'caption'

In [26]:
# Save as CSV
json_path = Path(f"predictions/{Path(MODEL_NAME).stem}/ablate_None--lvl_{TARGET_COMMENT_LVL}--merge_None.json")
os.makedirs(json_path.parent, exist_ok=True)
json_out = {
    row.comment_id: dset.label2idx[row.LLM_output]
    for row in df.itertuples()
    if row.LLM_output in ["Clean", "Offensive", "Hate"]
}
with open(json_path, "w") as f:
    json.dump(json_out, f, ensure_ascii=False, indent=2)