# Classify Documents

In [1]:
%load_ext autoreload
%autoreload 2

from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

## Load Data
The next cell loads the BBC News articles together with the embeddings generated in the notebook [01-get-embeddings.ipynb](./01-get-embeddings.ipynb). The file is tab-separated despite its `.csv` extension.

In [28]:
import pandas as pd

# The embeddings file is tab-separated despite its .csv extension.
# index_col=False stops pandas from using the first column as the index, so that
# column (headed "Unnamed: 0" in the output below) stays an ordinary column.
df_orig = pd.read_csv(
    "bbc-news-data-embedding.csv",
    sep="\t",
    index_col=False,
)

In [29]:
# Work on an independent (deep) copy so df_orig remains the frame exactly as loaded.
df = df_orig.copy(deep=True)
df

Unnamed: 0,category,filename,title,content,embedding
0,business,001.txt,Ad sales boost Time Warner profit,Quarterly profits at US media giant TimeWarne...,"[-0.02130456641316414, -0.01682969368994236, -..."
1,business,002.txt,Dollar gains on Greenspan speech,The dollar has hit its highest level against ...,"[-0.024546362459659576, -0.013037018477916718,..."
2,business,003.txt,Yukos unit buyer faces loan claim,The owners of embattled Russian oil giant Yuk...,"[-0.021691035479307175, -0.03697184473276138, ..."
3,business,004.txt,High fuel prices hit BA's profits,British Airways has blamed high fuel prices f...,"[-0.021805426105856895, -0.016833839938044548,..."
4,business,005.txt,Pernod takeover talk lifts Domecq,Shares in UK drinks and food firm Allied Dome...,"[-0.008381073363125324, -0.008448663167655468,..."
...,...,...,...,...,...
2241,tech,397.txt,BT program to beat dialler scams,BT is introducing two initiatives to help bea...,
2242,tech,398.txt,Spam e-mails tempt net shoppers,Computer users across the world continue to i...,
2243,tech,399.txt,Be careful how you code,A new European directive could put software w...,
2244,tech,400.txt,US cyber security chief resigns,The man making sure US computer networks are ...,


In [30]:
# Drop rows with a NaN in any column. Many rows have no embedding (see the empty
# `embedding` cells above), and classify_document below needs non-null `content`.
# Reassigning instead of using inplace=True avoids in-place mutation; the resulting
# frame is the same.
df = df.dropna()
len(df)

20

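## Classify documents and predict their category
Reference: [OpenAI Cookbook — Classification using embeddings](https://github.com/openai/openai-cookbook/blob/main/examples/Classification_using_embeddings.ipynb)

Unlike the referenced example, the cell below does not classify from the embeddings. It tokenizes each document's `content` text, sends it to a remote DistilBERT inference endpoint, and stores the predicted label in a new `predicted_label` column (`None` when a request fails).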

In [31]:
import requests
import json
import pandas as pd
from transformers import AutoTokenizer

import os

# Initialize the tokenizer for the DistilBERT model. It should match the tokenizer
# the served model was trained with, since only token ids are sent to the server.
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased-finetuned-sst-2-english')

# Inference endpoint (a /v2/models/<name>/infer route). The cluster-specific URL is
# only a default; set the DISTIL_ENG_API_URL environment variable to target
# another deployment without editing the notebook.
API_URL = os.environ.get(
    "DISTIL_ENG_API_URL",
    "https://distil-eng-tsisodia.apps.gm0e690i8ebff95b7a.eastus.aroapp.io/v2/models/distil-eng/infer",
)

# Model output index -> label name. Indices with no entry become "Unknown" in
# classify_document.
# NOTE(review): the tokenizer checkpoint name suggests an SST-2 (binary sentiment)
# model, yet three news categories are listed here, and the data also contains
# categories such as `tech` that have no entry — confirm against the served
# model's actual label order.
label_mapping = {0: 'BUSINESS', 1: 'ENTERTAINMENT', 2: 'POLITICS'}  # Adjust labels if needed

def _build_infer_payload(input_ids, attention_mask):
    """Wrap one tokenized sequence (flat lists of ints) as a batch-of-1 infer request body."""
    return {
        "inputs": [
            {
                "name": "input_ids",
                "shape": [1, len(input_ids)],
                "datatype": "INT64",
                "data": input_ids,
            },
            {
                "name": "attention_mask",
                "shape": [1, len(attention_mask)],
                "datatype": "INT64",
                "data": attention_mask,
            },
        ]
    }


# Send one document to the remote model and map its highest-scoring output to a label.
def classify_document(document_text, max_length=128):
    """Classify a single document through the remote inference endpoint.

    The text is tokenized locally with the module-level ``tokenizer`` (truncated and
    padded to ``max_length`` tokens) and POSTed to ``API_URL``. The index of the
    largest value in the first output's ``data`` list is looked up in
    ``label_mapping``.

    Args:
        document_text: The document body to classify.
        max_length: Sequence length the input is truncated/padded to. Defaults to
            128, the value previously hard-coded. NOTE(review): confirm the served
            model accepts other lengths before changing it.

    Returns:
        The mapped label string, or ``"Unknown"`` if the predicted index has no
        entry in ``label_mapping``. Returns ``None`` in any of these cases:
        the input is not a string, the request fails, the server answers with a
        non-200 status, or the response body does not have the expected shape.
        Errors are printed rather than raised, so one bad document does not
        abort a ``Series.apply`` over many rows.
    """
    if not isinstance(document_text, str):
        # The tokenizer raises on NaN/None; treat it like any other per-row failure.
        print(f"Error: expected a str document, got {type(document_text).__name__}")
        return None

    encoded = tokenizer(document_text, return_tensors='pt', truncation=True,
                        padding='max_length', max_length=max_length)
    # Tensors have a leading batch dimension of 1; send the single row.
    input_ids = encoded['input_ids'].tolist()[0]
    attention_mask = encoded['attention_mask'].tolist()[0]
    payload = _build_infer_payload(input_ids, attention_mask)

    try:
        response = requests.post(API_URL, json=payload, timeout=30)
    except requests.exceptions.RequestException as e:
        print(f"Error during API call: {e}")
        return None

    if response.status_code != 200:
        print(f"Error: {response.status_code} - {response.text}")
        return None

    try:
        logits = response.json()['outputs'][0]['data']
        # max() raises ValueError on an empty list, handled below.
        predicted_class_id = logits.index(max(logits))
    except (ValueError, KeyError, IndexError, TypeError, AttributeError) as e:
        # Non-JSON body, missing keys, or an empty/non-list output tensor.
        print(f"Error: unexpected response format: {e!r}")
        return None

    return label_mapping.get(predicted_class_id, "Unknown")

# One remote call per row; rows whose classification failed hold None.
df = df.assign(predicted_label=df['content'].apply(classify_document))

## Save Results

In [None]:
# Write every column, including predicted_label, to CSV; the pandas index is omitted.
output_path = 'classified_documents.csv'
df.to_csv(output_path, index=False)

print(f"Classification complete. Results saved to {output_path}.")