In [13]:
from dotenv import load_dotenv
import os
import google.generativeai as genai
import google.ai.generativelanguage as glm
import textwrap
from datasets import load_dataset
import pandas as pd
from sklearn.metrics import accuracy_score, f1_score

In [3]:
load_dotenv()

True

In [20]:
dataset = load_dataset('ag_news')

In [21]:
# Let's convert them to dataframes
df_train = pd.DataFrame(dataset['train'])
df_test = pd.DataFrame(dataset['test'])

In [22]:
df_train.head(10)

Unnamed: 0,text,label
0,Wall St. Bears Claw Back Into the Black (Reute...,2
1,Carlyle Looks Toward Commercial Aerospace (Reu...,2
2,Oil and Economy Cloud Stocks' Outlook (Reuters...,2
3,Iraq Halts Oil Exports from Main Southern Pipe...,2
4,"Oil prices soar to all-time record, posing new...",2
5,"Stocks End Up, But Near Year Lows (Reuters) Re...",2
6,Money Funds Fell in Latest Week (AP) AP - Asse...,2
7,Fed minutes show dissent over inflation (USATO...,2
8,Safety Net (Forbes.com) Forbes.com - After ear...,2
9,Wall St. Bears Claw Back Into the Black NEW Y...,2


In [11]:
genai.configure(api_key=os.environ['GOOGLE_KEY'])
model = genai.GenerativeModel('gemini-pro')

In [12]:
response = model.generate_content("Write a story about a magic backpack.")

In [13]:
print(response.text)

In the bustling town of Willow Creek, amidst the quaint shops and cobblestone streets, there existed a remarkable backpack—a backpack imbued with extraordinary magic.

Crafted from ancient, shimmering fabric, the backpack possessed an unassuming appearance. But within its four compartments resided secrets that defied all logic.

Emily, a curious and imaginative 12-year-old, stumbled upon the backpack at a dusty antique shop. Intrigued by its enigmatic presence, she decided to take it home. Little did she know that her life was about to change forever.

As Emily unzipped the first compartment, a brilliant array of colors erupted before her eyes. Books of all shapes and sizes floated weightlessly, their pages turning by themselves—a testament to the backpack's unspoken knowledge.

With a whisper of excitement, Emily reached for the second compartment. Suddenly, she was enveloped in a warm breeze that carried the scent of blooming flowers. Clothes of every imaginable style and color dance

In [29]:
classification = glm.Schema(
    type = glm.Type.OBJECT,
    properties = {
        'classification_label':  glm.Schema(type=glm.Type.STRING),
    },
    required=['classification_label']
)

In [30]:
classify_news_article = glm.FunctionDeclaration(
    name="classify_news_article",
    description=textwrap.dedent("""\
        Classify the news article into one of the four categories: world, sports, business, technology
        """),
    parameters=glm.Schema(
        type=glm.Type.OBJECT,
        properties = {
            'classification_label': classification
        }
    )
)

In [31]:
model = genai.GenerativeModel(model_name='gemini-1.0-pro', tools = [classify_news_article])

In [46]:
article = df_test.iloc[0].text

In [47]:
article

"Fears for T N pension after talks Unions representing workers at Turner   Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul."

In [48]:
result = model.generate_content(f"""
    Please classify the provided news article into one of the following categories: world, sports, business, technology. Do not return null as a classification label.

{article}
""")

In [54]:
fc = result.candidates[0].content.parts[0].function_call
print(type(fc).to_dict(fc)['args']['classification_label']['classification_label'])

business


In [57]:
def get_gemini_classification(news_text):
    result = model.generate_content(f"""
    Please classify the provided news article into one of the following categories: world, sports, business, technology. Do not return null as a classification label.

        {article}
    """)
    
    fc_result = result.candidates[0].content.parts[0].function_call
    label = type(fc_result).to_dict(fc_result)['args']['classification_label']['classification_label']
    if label == 'world':
        return 0
    if label == 'sports':
        return 1
    if label == 'business':
        return 2
    if label == 'technology':
        return 3

In [58]:
get_gemini_classification(df_test.iloc[5].text)

2

We will now find the accuracy and other metrics. 

In [4]:
res_df_path = '../../data/gemini_text_classification.csv'

In [5]:
res_df = pd.read_csv(res_df_path)

In [6]:
res_df.head(10)

Unnamed: 0,index,text,predicted_label,actual_label,execution_time
0,0,Fears for T N pension after talks Unions repre...,2.0,2,2.983072
1,1,The Race is On: Second Private Team Sets Launc...,,3,1.62559
2,2,Ky. Company Wins Grant to Study Peptides (AP) ...,2.0,3,1.397071
3,3,Prediction Unit Helps Forecast Wildfires (AP) ...,0.0,3,1.525932
4,4,Calif. Aims to Limit Farm-Related Smog (AP) AP...,0.0,3,1.432235
5,5,Open Letter Against British Copyright Indoctri...,0.0,3,1.773903
6,6,"Loosing the War on Terrorism \\""Sven Jaschan, ...",3.0,3,1.356559
7,7,"FOAFKey: FOAF, PGP, Key Distribution, and Bloo...",3.0,3,1.621155
8,8,E-mail scam targets police chief Wiltshire Pol...,0.0,3,1.395739
9,9,"Card fraud unit nets 36,000 cards In its first...",2.0,3,1.535353


In [8]:
res_df['predicted_label'] = res_df['predicted_label'].fillna(0)

In [10]:
res_df.head(10)

Unnamed: 0,index,text,predicted_label,actual_label,execution_time
0,0,Fears for T N pension after talks Unions repre...,2.0,2,2.983072
1,1,The Race is On: Second Private Team Sets Launc...,0.0,3,1.62559
2,2,Ky. Company Wins Grant to Study Peptides (AP) ...,2.0,3,1.397071
3,3,Prediction Unit Helps Forecast Wildfires (AP) ...,0.0,3,1.525932
4,4,Calif. Aims to Limit Farm-Related Smog (AP) AP...,0.0,3,1.432235
5,5,Open Letter Against British Copyright Indoctri...,0.0,3,1.773903
6,6,"Loosing the War on Terrorism \\""Sven Jaschan, ...",3.0,3,1.356559
7,7,"FOAFKey: FOAF, PGP, Key Distribution, and Bloo...",3.0,3,1.621155
8,8,E-mail scam targets police chief Wiltshire Pol...,0.0,3,1.395739
9,9,"Card fraud unit nets 36,000 cards In its first...",2.0,3,1.535353


In [11]:
print(f'Total execution time {res_df['execution_time'].sum()} seconds')

Total execution time 11056.712309272007 seconds


In [14]:
print('Test Accuracy: {}'.format(
    accuracy_score(y_true=res_df['actual_label'], y_pred=res_df['predicted_label'])
))

Test Accuracy: 0.7794960671910411


In [15]:
print('Test f-score: {}'.format(
    f1_score(y_true=res_df['actual_label'], y_pred=res_df['predicted_label'], average='weighted'), 
))

Test f-score: 0.7602538701317094
