# Setup

## Libraries

In [1]:
import os

# When the kernel starts in a directory whose name ends with 'notebooks',
# step up one level so relative paths such as 'data/train.csv' resolve.
if os.getcwd().endswith('notebooks'):
    os.chdir(os.pardir)

import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split

from openai import OpenAI
from dotenv import load_dotenv

## Functions

In [2]:
def classify_tweet(client, prompt, keyword, tweet, model="gpt-4o"):
    """
    Classify a tweet as related to a natural disaster (1) or not (0) using an OpenAI chat model.

    The prompt template is filled in with the keyword and the tweet, sent to the
    chat-completions endpoint together with a fixed system message, and the
    single-token reply is validated and returned as the class label.

    Parameters
    ----------
    client : object
        OpenAI client (or any object exposing ``chat.completions.create``)
        used to send the request.
    prompt : str
        A string template for the classification prompt, with ``{keyword}``
        and ``{tweet}`` placeholders.
    keyword : str
        A relevant keyword related to potential disasters, or 'N/A' if not available.
    tweet : str
        The full text of the tweet to be classified.
    model : str, default "gpt-4o"
        Name of the chat model passed to the API.

    Returns
    -------
    int
        1 if the tweet is classified as being about a natural disaster, 0 otherwise.

    Raises
    ------
    ValueError
        If the model's reply is empty or is anything other than ``0`` or ``1``.

    Examples
    --------
    >>> prompt_template = "Classify the tweet. Keyword: {keyword}\\nTweet: {tweet}\\nOutput 1 for disaster, 0 for not."
    >>> result = classify_tweet(client, prompt_template, "earthquake", "Just felt a huge shake! #earthquake")
    >>> print(result)
    1
    """

    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": "You are a precise tweet classifier for natural disasters."},
            {"role": "user", "content": prompt.format(keyword=keyword, tweet=tweet)}
        ],
        # A single token is enough for the one-digit answer; temperature 0
        # requests the least random output.
        max_tokens=1,
        temperature=0
    )

    # The reply content may be missing (None); treat that like an empty answer.
    raw_reply = response.choices[0].message.content
    label = (raw_reply or "").strip()

    # Accept only the two valid labels, so a stray digit such as "2" or a word
    # fragment cannot slip into the predictions unnoticed.
    if label not in ("0", "1"):
        raise ValueError(
            f"Expected the model to answer '0' or '1', got {raw_reply!r} for tweet: {tweet!r}"
        )
    return int(label)

## Parameters

In [3]:
# Read the environment file (python-dotenv) before looking up the API key
load_dotenv()

# OpenAI client authenticated with the OPENAI_API_KEY environment variable
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

# Default classification prompt; {keyword} and {tweet} are filled in per tweet
prompt = (
    "Classify the following tweet as related to a natural disaster (1) or not (0).\n"
    "Keyword: {keyword}\n"
    "Tweet: {tweet}\n"
    "\n"
    "Consider the keyword if provided (N/A means not available). Focus on identifying mentions or implications of natural disasters like earthquakes, hurricanes, floods, wildfires, etc.\n"
    "\n"
    "Output only a single digit: 1 for tweets about natural disasters, 0 for unrelated tweets.\n"
)

## Load data

In [4]:
# Read the train and test CSV files, using the 'id' column as the index
df_train = pd.read_csv('data/train.csv', index_col='id')
df_test = pd.read_csv('data/test.csv', index_col='id')

# Separate the 'target' column (labels) from the remaining columns (features)
X, y = df_train.drop(columns='target'), df_train['target']

# Hold out 20% of the training rows for validation, with a fixed seed
X_train, X_val, y_train, y_val = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)

# Main

## Preprocess Keywords

In [5]:
# Replace missing keywords with the 'N/A' marker that the prompt describes as
# "not available". `assign` returns new frames instead of writing into the
# existing ones, which avoids a SettingWithCopyWarning on X_val (a slice
# produced by train_test_split) and keeps this cell safe to re-run.
X_val = X_val.assign(keyword=X_val['keyword'].fillna('N/A'))
df_test = df_test.assign(keyword=df_test['keyword'].fillna('N/A'))

## Validation set

In [6]:
# Compute predictions for the validation set (one classify_tweet call per row)
y_pred = np.array([
    classify_tweet(client, prompt, keyword, text)
    for keyword, text in tqdm(zip(X_val['keyword'], X_val['text']), total=len(X_val))
])

100%|██████████| 1523/1523 [12:53<00:00,  1.97it/s]


In [7]:
# Per-class precision, recall and F1 of the validation predictions
print(classification_report(y_true=y_val, y_pred=y_pred))

              precision    recall  f1-score   support

           0       0.64      0.99      0.77       874
           1       0.94      0.24      0.38       649

    accuracy                           0.67      1523
   macro avg       0.79      0.61      0.58      1523
weighted avg       0.77      0.67      0.61      1523



## Test set

In [8]:
# Compute predictions for the test set (one classify_tweet call per row)
y_pred = np.array([
    classify_tweet(client, prompt, keyword, text)
    for keyword, text in tqdm(zip(df_test['keyword'], df_test['text']), total=len(df_test))
])

100%|██████████| 3263/3263 [27:37<00:00,  1.97it/s] 


In [9]:
# Save predictions
# Pair each test id (the frame's index) with its predicted label.
df_test_submission = pd.DataFrame({
    'id': df_test.index,
    'target': y_pred
})

# DataFrame.to_csv does not create missing folders, so make sure the output
# directory exists; otherwise the long prediction run would be lost here.
os.makedirs('results', exist_ok=True)
df_test_submission.to_csv('results/OpenAi_submission.csv', index=False)