In [11]:
from dotenv import load_dotenv
from openai import OpenAI
from datasets import load_dataset
import pandas as pd
from sklearn.metrics import accuracy_score, f1_score
import json

In [17]:
load_dotenv()
client = OpenAI()

In [6]:
completion = client.chat.completions.create(
    model='gpt-3.5-turbo-0125',
    messages=[
        {'role': 'system', 'content': 'You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.'},
        {'role': 'user', 'content': 'Compose a poem that explains the concept of recursion in programming.'}
    ]
)

In [7]:
print(completion.choices[0].message)

ChatCompletionMessage(content="In the realm of loops and functions sublime,\nResides a technique quite divine,\nRecursion, a method of coding magic,\nA wondrous and powerful programming fabric.\n\nLike a mirror reflecting its own reflection,\nRecursion calls upon itself, with no objection,\nA function that calls itself, with grace,\nUnraveling problems in an elegant embrace.\n\nLike a Russian doll nested inside another,\nRecursion breaks problems down like a brother,\nDivide and conquer, the strategy it employs,\nSolving complexity with recursive joys.\n\nA dance of functions repeating in endless fashion,\nSolving puzzles with recursive passion,\nBase cases to break the cycle's hold,\nIn recursion's intricate tale, beautifully told.\n\nSo tread carefully in this recursive terrain,\nFor in its depths lies both pleasure and pain,\nBut wield it with care, let creativity soar,\nFor recursion's beauty, forevermore.", role='assistant', function_call=None, tool_calls=None)


In [13]:
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_classification_of_news_article",
            "description": "Get classification of news article",
            "parameters": {
                "type": "object",
                "properties": {
                    "classification": {
                        "type": "string",
                        "description": "the classification label of the news article",
                        "enum": [
                            "world",
                            "sports",
                            "business",
                            "technology",
                            "science"
                        ]   
                    }
                },
                "required": [ "classification" ]
            }
        }
    }
]

In [18]:
dataset = load_dataset('ag_news')

In [19]:
# Let's convert them to dataframes
df_test = pd.DataFrame(dataset['test'])

In [20]:
df_test.head(10)

Unnamed: 0,text,label
0,Fears for T N pension after talks Unions repre...,2
1,The Race is On: Second Private Team Sets Launc...,3
2,Ky. Company Wins Grant to Study Peptides (AP) ...,3
3,Prediction Unit Helps Forecast Wildfires (AP) ...,3
4,Calif. Aims to Limit Farm-Related Smog (AP) AP...,3
5,Open Letter Against British Copyright Indoctri...,3
6,"Loosing the War on Terrorism \\""Sven Jaschan, ...",3
7,"FOAFKey: FOAF, PGP, Key Distribution, and Bloo...",3
8,E-mail scam targets police chief Wiltshire Pol...,3
9,"Card fraud unit nets 36,000 cards In its first...",3


In [14]:
completion = client.chat.completions.create(
    model='gpt-3.5-turbo-0125',
    messages=[
        {'role': 'system', 'content': 'You are an agent that classifies news articles into appropriate categories. Return the classification as world if no other classification is appropriate.'},
        {'role': 'user', 'content': 'Fears for T N pension after talks Unions representing workers at Turner   Newall say they are \'disappointed\' after talks with stricken parent firm Federal Mogul.'}
    ],
    tools=tools,
    tool_choice="auto"
)

In [15]:
json.loads(completion.choices[0].message.tool_calls[0].function.arguments)

{'classification': 'world'}

In [24]:
def get_openai_classification(news_text):
    completion = client.chat.completions.create(
        model='gpt-3.5-turbo-0125',
        messages=[
            {'role': 'system', 'content': 'You are an agent that classifies news articles into appropriate categories. Return the classification as world if no other classification is appropriate.'},
            {'role': 'user', 'content': f'{news_text}'}
        ],
        tools=tools,
        tool_choice='auto'
    )

    label = json.loads(completion.choices[0].message.tool_calls[0].function.arguments)['classification']
    
    if label == 'world':
        return 0
    if label == 'sports':
        return 1
    if label == 'business':
        return 2
    if label == 'technology':
        return 3
    if label == 'science':
        return 3

In [25]:
get_openai_classification(df_test.iloc[0].text)

0

In [26]:
res_df_path = '../../data/openai_text_classification.csv'

In [27]:
res_df = pd.read_csv(res_df_path)

In [28]:
res_df.head(50)

Unnamed: 0,index,text,predicted_label,actual_label,execution_time
0,0,Fears for T N pension after talks Unions repre...,0.0,2,1.585969
1,1,The Race is On: Second Private Team Sets Launc...,3.0,3,1.012847
2,2,Ky. Company Wins Grant to Study Peptides (AP) ...,3.0,3,0.998038
3,3,Prediction Unit Helps Forecast Wildfires (AP) ...,3.0,3,1.043703
4,4,Calif. Aims to Limit Farm-Related Smog (AP) AP...,0.0,3,1.288217
5,5,Open Letter Against British Copyright Indoctri...,3.0,3,0.963824
6,6,"Loosing the War on Terrorism \\""Sven Jaschan, ...",3.0,3,0.983094
7,7,"FOAFKey: FOAF, PGP, Key Distribution, and Bloo...",3.0,3,0.907009
8,8,E-mail scam targets police chief Wiltshire Pol...,0.0,3,1.271042
9,9,"Card fraud unit nets 36,000 cards In its first...",2.0,3,1.034202


In [30]:
res_df['predicted_label'] = res_df['predicted_label'].fillna(0)

In [31]:
print(f'Total execution time {res_df['execution_time'].sum()} seconds')

Total execution time 872.2161402629827 seconds


In [32]:
print('Test Accuracy: {}'.format(
    accuracy_score(y_true=res_df['actual_label'], y_pred=res_df['predicted_label'])
))

Test Accuracy: 0.8257491675915649


In [33]:
print('Test f-score: {}'.format(
    f1_score(y_true=res_df['actual_label'], y_pred=res_df['predicted_label'], average='weighted'), 
))

Test f-score: 0.826007147952462
