### Imports

In [76]:
import json
import time
import os
import getpass
import pandas as pd
from datasets import Dataset, load_dataset
from tqdm import tqdm
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

### Load dataset

In [39]:
dataset = load_dataset("wikimedia/wikipedia", "20231101.en")
NUM_SAMPLES = 10000
articles = dataset["train"][:NUM_SAMPLES]["text"]
ids = dataset["train"][:NUM_SAMPLES]["id"]
articles = [x.split("\n")[0] for x in articles]

Resolving data files:   0%|          | 0/41 [00:00<?, ?it/s]

In [13]:
len(articles)

10000

In [11]:
articles[4]

"In Greek mythology, Achilles ( ) or Achilleus () was a hero of the Trojan War who was known as being the greatest of all the Greek warriors. A central character in Homer's Iliad, he was the son of the Nereid Thetis and Peleus, king of Phthia and famous Argonaut. Achilles was raised in Phthia along his childhood companion Patroclus and received his education by the centaur Chiron. In the Iliad, he is presented as the commander of the mythical tribe of the Myrmidons. "

### Initialize OpenAI Environment

In [14]:
os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter your OpenAI API key: ")

Enter your OpenAI API key: ········


In [15]:
llm = ChatOpenAI()

In [16]:
llm.model_name

'gpt-3.5-turbo'

### Create classification prompt

In [37]:
prompt = ChatPromptTemplate.from_messages([
    ("system", """Your task is to assess the article and categorize the article into one of the following predfined categories:
'History', 'Geography', 'Science', 'Technology', 'Mathematics', 'Literature', 'Art', 'Music', 'Film', 'Television', 'Sports', 'Politics', 'Philosophy', 'Religion', 'Sociology', 'Psychology', 'Economics', 'Business', 'Medicine', 'Biology', 'Chemistry', 'Physics', 'Astronomy', 'Environmental Science', 'Engineering', 'Computer Science', 'Linguistics', 'Anthropology', 'Archaeology', 'Education', 'Law', 'Military', 'Architecture', 'Fashion', 'Cuisine', 'Travel', 'Mythology', 'Folklore', 'Biography', 'Mythology', 'Social Issues', 'Human Rights', 'Technology Ethics', 'Climate Change', 'Conservation', 'Urban Studies', 'Demographics', 'Journalism', 'Cryptocurrency', 'Artificial Intelligence'
you will output a json object containing the following information:

{{
    "id": string
    "category": string
}}
"""),
    ("human", "{input}")
])

### Build llm chain

In [38]:
chain = prompt | llm

### sample Inference

In [43]:
content = json.dumps({"id": ids[0], "article": articles[0]})
response = chain.invoke(content)
response.content

'{\n    "id": "12",\n    "category": "Politics"\n}'

In [44]:
batches = []
for index in range(3):
    batches.append(json.dumps({"id": ids[index], "article": articles[index]}))
chain.batch(batches)

[AIMessage(content='{\n    "id": "12",\n    "category": "Politics"\n}', response_metadata={'token_usage': {'completion_tokens': 16, 'prompt_tokens': 361, 'total_tokens': 377}, 'model_name': 'gpt-3.5-turbo', 'system_fingerprint': None, 'finish_reason': 'stop', 'logprobs': None}, id='run-d481a8e8-3135-4c07-a896-ed054aefb9de-0'),
 AIMessage(content='{\n    "id": "39",\n    "category": "Physics"\n}', response_metadata={'token_usage': {'completion_tokens': 16, 'prompt_tokens': 320, 'total_tokens': 336}, 'model_name': 'gpt-3.5-turbo', 'system_fingerprint': None, 'finish_reason': 'stop', 'logprobs': None}, id='run-214a446c-9863-4de0-ae66-634796d9d491-0'),
 AIMessage(content='{\n    "id": "290",\n    "category": "Linguistics"\n}', response_metadata={'token_usage': {'completion_tokens': 18, 'prompt_tokens': 408, 'total_tokens': 426}, 'model_name': 'gpt-3.5-turbo', 'system_fingerprint': None, 'finish_reason': 'stop', 'logprobs': None}, id='run-30f1744a-a7ee-4488-9a7d-d97a21d4eef1-0')]

In [47]:
json.loads(response.content)

{'id': '12', 'category': 'Politics'}

### Run inference

In [36]:
results = []
for article in tqdm(articles[:100]):
    try:
        result = chain.invoke({"article": article})
        results.append(result)
    except Exception as e:
        print("Exception Occured", e)
        results.append("")

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [01:14<00:00,  1.35it/s]


In [98]:
740 / 60

12.333333333333334

In [97]:
results = []
BATCH_SIZE = 8
inputs = []

for index, article in tqdm(enumerate(articles[:1000])):
    inputs.append(json.dumps({"id": ids[index], "article": articles[index]}))
    
    if len(inputs) == BATCH_SIZE:
        time.sleep(1.5)
        response = chain.batch(inputs)
        results += response
        inputs = []
        
if inputs:
    response = chain.batch(inputs)
    results += response
           


1000it [05:13,  3.19it/s]


### Postprocessing

In [85]:
pd.DataFrame([x.response_metadata["token_usage"] for x in results])

Unnamed: 0,completion_tokens,prompt_tokens,total_tokens
0,16,361,377
1,17,320,337
2,18,408,426
3,17,328,345
4,18,379,397
...,...,...,...
95,17,354,371
96,18,363,381
97,18,465,483
98,16,296,312


In [89]:
success = []
failure = []

for output in results:
    content = output.content
    try:
        content = json.loads(content)
        success.append(content)
    except ValueError as e:
        failure.append(content)

In [90]:
pd.DataFrame(success)

Unnamed: 0,id,category
0,12,Politics
1,39,Environmental Science
2,290,Linguistics
3,303,Geography
4,305,Mythology
...,...,...
95,746,Geography
96,748,Astronomy
97,751,Martial Arts
98,752,Art
