### Imports

In [4]:
import json
import time
import os
import getpass
import pandas as pd
from datasets import Dataset, load_dataset
from tqdm import tqdm
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

### Load dataset

In [5]:
dataset = load_dataset("wikimedia/wikipedia", "20231101.en")
NUM_SAMPLES = 10000
articles = dataset["train"][:NUM_SAMPLES]["text"]
ids = dataset["train"][:NUM_SAMPLES]["id"]
articles = [x.split("\n")[0] for x in articles]

Resolving data files:   0%|          | 0/41 [00:00<?, ?it/s]

In [6]:
len(articles)

10000

### Initialize OpenAI Environment

In [9]:
os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter your OpenAI API key: ")

Enter your OpenAI API key: ········


In [10]:
llm = ChatOpenAI()

In [11]:
llm.model_name

'gpt-3.5-turbo'

### Create classification prompt

In [12]:
prompt = ChatPromptTemplate.from_messages([
    ("system", """You will be provided with a JSON input which contains id and article content. Your task is to assess the article and categorize the article into one of the following predfined categories:
'History', 'Geography', 'Science', 'Technology', 'Mathematics', 'Literature', 'Art', 'Music', 'Film', 'Television', 'Sports', 'Politics', 'Philosophy', 'Religion', 'Sociology', 'Psychology', 'Economics', 'Business', 'Medicine', 'Biology', 'Chemistry', 'Physics', 'Astronomy', 'Environmental Science', 'Engineering', 'Computer Science', 'Linguistics', 'Anthropology', 'Archaeology', 'Education', 'Law', 'Military', 'Architecture', 'Fashion', 'Cuisine', 'Travel', 'Mythology', 'Folklore', 'Biography', 'Mythology', 'Social Issues', 'Human Rights', 'Technology Ethics', 'Climate Change', 'Conservation', 'Urban Studies', 'Demographics', 'Journalism', 'Cryptocurrency', 'Artificial Intelligence'
you will output a json object containing the following information:
[
{{
    "id": string // id of the artice1
    "category": string // category of the article1
}},
{{
    "id": string // id of the artice2
    "category": string // category of the article2
}},
...
]
"""),
    ("human", "{input}")
])

### Build llm chain

In [13]:
chain = prompt | llm

### sample Inference

In [14]:
content = json.dumps([
    {"id": ids[0], "article": articles[0]},
    {"id": ids[1], "article": articles[1]}
])
response = chain.invoke(content)
response.content

'[\n    {\n        "id": "12",\n        "category": "Politics"\n    },\n    {\n        "id": "39",\n        "category": "Environmental Science"\n    }\n]'

In [15]:
json.loads(response.content)

[{'id': '12', 'category': 'Politics'},
 {'id': '39', 'category': 'Environmental Science'}]

### Run inference

In [15]:
# results = []
# for article in tqdm(articles[:100]):
#     try:
#         result = chain.invoke({"article": article})
#         results.append(result)
#     except Exception as e:
#         print("Exception Occured", e)
#         results.append("")

In [16]:
results = []
BATCH_SIZE = 8
inputs = []

for index, article in tqdm(enumerate(articles[:100])):
    inputs.append({"id": ids[index], "article": articles[index]})
    
    if len(inputs) == BATCH_SIZE:
        response = chain.invoke(json.dumps(inputs))
        results.append(response)
        inputs = []
if inputs:
    response = chain.invoke(json.dumps(inputs))
    results.append(response)


100it [00:30,  3.30it/s]


### Postprocessing

In [25]:
pd.DataFrame([x.response_metadata["token_usage"] for x in results])["total_tokens"].sum()

17511

In [28]:
success = []
failure = []

for output in results:
    content = output.content
    try:
        content = json.loads(content)
        success += content
    except ValueError as e:
        failure.append(content)

In [29]:
pd.DataFrame(success)

Unnamed: 0,id,category
0,12,Politics
1,39,Science
2,290,Linguistics
3,303,Geography
4,305,Mythology
...,...,...
71,746,Geography
72,748,Astronomy
73,751,Martial Arts
74,752,Art
