# Custom text classification using OCI Language Service Endpoint

This notebook demonstrates how to call the OCI Language batch text classification API to classify text with a custom text classification model.

In [18]:
import time
import oci
import pandas as pd
import math
import datetime

from typing import Any, Dict, List, Union
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix


## Initialize OCI AI Client

Make sure you have set up the OCI config file by following the steps in<br>
[OCI Language Service LiveLab, Task 1](https://apexapps.oracle.com/pls/apex/dbpm/r/livelabs/run-workshop?p210_wid=887&p210_wec=&session=108183149172107)

In [19]:
# Authenticate with the 'OCASCUST' profile of the local OCI config file.
_oci_config = oci.config.from_file(profile_name='OCASCUST')
ai_client = oci.ai_language.AIServiceLanguageClient(_oci_config)
# Timeout setting on the SDK's underlying base client.
ai_client.base_client.timeout = 60
# Pauses (in seconds, passed to time.sleep) between consecutive batch calls and
# between retries of a failed call.
wait_between_batch, wait_between_retries = 3, 5

Read the input dataset

In [20]:
# Labelled ticket dataset (needs 'text' and 'labels' columns further down).
# For a quick smoke run, pass nrows=200 to read_csv.
TRAIN_CSV_PATH = '~/Downloads/TicketData_train.csv'
df = pd.read_csv(TRAIN_CSV_PATH)

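### Split the dataset into batches
The OCI endpoint can process up to 500 characters per second. The OCI Language batch API accepts at most 100 documents and 20k characters per request, so this notebook groups rows into chunks of roughly 19,900 characters and sends at most 10 documents per call.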

In [21]:
# Documents sent per API call (well below the service's 100-document cap).
max_rec_for_batch_call = 10
# Character budget used to bucket rows before they are split into calls
# (kept just under the service's 20k-character cap).
max_tex_per_batch = 19_900
# OCID of the deployed custom text-classification model endpoint.
model_endpoint = 'ocid1.ailanguageendpoint.oc1.phx.amaaaaaasxs7gpyajgpwo3txxfk7o7lgfdsclauvvjzacfztilm57p2rt3oa'

In [22]:
def _collect_batch_results(keys, result):
    """Align one batch classification result with the request keys.

    Args:
        keys: document keys in request order (str of each row's index label).
        result: the ``data`` of a batch_detect_language_text_classification response,
            exposing ``documents`` (each with ``key`` and ``text_classification``) and
            ``errors`` (each with ``key``).

    Returns:
        (prediction, confidence): lists aligned with ``keys``. An entry is None when the
        service reported an error for that key, returned no classification for it, or
        omitted the document entirely.
    """
    documents = result.documents or []
    failed_keys = {err.key for err in (result.errors or [])}

    if len(keys) != len(documents):
        print(f'{datetime.datetime.now()} failed inference for {len(keys)-len(documents)} records')

    # Look results up by key rather than by position: a failed or missing document
    # must not shift later predictions onto the wrong rows. The first classification
    # entry is taken as the prediction (assumed to be the top-ranked one).
    top_by_key = {
        doc.key: (doc.text_classification[0].label, doc.text_classification[0].score)
        for doc in documents
        if doc.text_classification
    }

    prediction, confidence = [], []
    for key in keys:
        if key in failed_keys:
            # Predicting a class can fail due to max/min length, wrong encoding, etc.
            print(f'there was an error at {key}')
            label, score = None, None
        else:
            label, score = top_by_key.get(key, (None, None))
        prediction.append(label)
        confidence.append(score)
    return prediction, confidence


def process_and_upadate_batch(df):
    """Classify one batch of rows with the custom OCI Language text-classification endpoint.

    Each row is sent as a TextDocument keyed by ``str(index label)``. The call is
    attempted up to ``MAX_RETRYCOUNT`` times, sleeping ``wait_between_retries`` seconds
    between failed attempts.

    Args:
        df: DataFrame slice with a ``text`` column; it should respect the service's
            per-request document and character limits.

    Returns:
        (prediction, confidence): lists aligned positionally with the rows of ``df``.
        ``prediction[i]`` is the predicted label and ``confidence[i]`` its score; both
        are None for rows that failed, or for every row if all attempts failed.
    """
    if df.empty:
        return [], []

    keys = [str(idx) for idx in df.index]
    documents = [
        oci.ai_language.models.TextDocument(key=key, text=text)
        for key, text in zip(keys, df['text'])
    ]
    classification_details = oci.ai_language.models.BatchDetectLanguageTextClassificationDetails(
        endpoint_id=model_endpoint, documents=documents
    )
    first_row, last_row = df.index.min(), df.index.max()

    MAX_RETRYCOUNT = 3
    for attempt in range(MAX_RETRYCOUNT):
        try:
            output = ai_client.batch_detect_language_text_classification(classification_details)
            return _collect_batch_results(keys, output.data)
        except Exception as e:
            # ServiceError = the service answered with an error status; anything else
            # (e.g. oci ClientError, network or parsing failures) is reported generically.
            if isinstance(e, oci.exceptions.ServiceError):
                print(f'{datetime.datetime.now()} Unable to process these records {first_row}: {last_row}. Retrying {attempt} time')
            else:
                print(f'{datetime.datetime.now()} Error occurred while processing records {first_row}: {last_row}. Retrying {attempt} time')
            if attempt == 0:
                print(f'Error details:{e}')
            # No point waiting after the final attempt.
            if attempt < MAX_RETRYCOUNT - 1:
                time.sleep(wait_between_retries)

    print(f'{datetime.datetime.now()} giving up on records {first_row}: {last_row} after {MAX_RETRYCOUNT} attempts')
    return [None] * len(keys), [None] * len(keys)
    

In [23]:
def process_and_update_slice(df):
    """Classify every row of ``df`` in size-limited batches, writing results in place.

    Rows are first bucketed by cumulative text length (about ``max_tex_per_batch``
    characters per bucket). Each bucket is then sent in calls of at most
    ``max_rec_for_batch_call`` rows, sleeping ``wait_between_batch`` seconds after each
    call. Results are written to ``df``'s 'predicted' and 'confidence' columns.

    Args:
        df: DataFrame with a ``text`` column; modified in place.
    """
    # Missing / non-string texts have no str length (NaN). Count them as 0 characters
    # so bucketing does not crash (math.floor(NaN) raises ValueError).
    # NOTE(review): such rows are still sent to the service as-is; TODO decide whether
    # to skip them before the call.
    text_lengths = df['text'].str.len().fillna(0)
    char_bucket = text_lengths.cumsum() // max_tex_per_batch

    for _, group in df.groupby(char_bucket):
        for row_start in range(0, len(group), max_rec_for_batch_call):
            rows = group.iloc[row_start:row_start + max_rec_for_batch_call]
            print(f'{datetime.datetime.now()} processing rows:{rows.index.min()}:{rows.index.max()}')

            prediction, confidence = process_and_upadate_batch(rows)
            df.loc[rows.index, 'predicted'] = prediction
            df.loc[rows.index, 'confidence'] = confidence

            # Throttle between calls to stay within the endpoint's throughput.
            time.sleep(wait_between_batch)
    print(f'{datetime.datetime.now()} completed processing {len(df)} rows')

Predicting classes

In [24]:
# Start every row with no prediction; process_and_update_slice fills both columns
# batch by batch (rows that fail stay None).
for column in ('predicted', 'confidence'):
    df[column] = None
process_and_update_slice(df)

2023-01-14 15:05:59.073531 processing rows:0:9
2023-01-14 15:06:03.815451 processing rows:10:19
2023-01-14 15:06:07.431467 processing rows:20:29
2023-01-14 15:06:11.067057 processing rows:30:39
2023-01-14 15:06:14.713705 processing rows:40:49
2023-01-14 15:06:18.405699 processing rows:50:59
2023-01-14 15:06:22.036903 processing rows:60:69
2023-01-14 15:06:25.589011 processing rows:70:79
2023-01-14 15:06:29.222927 processing rows:80:89
2023-01-14 15:06:32.914631 processing rows:90:99
2023-01-14 15:06:36.585881 processing rows:100:109
2023-01-14 15:06:40.322552 processing rows:110:119
2023-01-14 15:06:44.117983 processing rows:120:129
2023-01-14 15:06:48.005123 processing rows:130:139
2023-01-14 15:06:51.656080 processing rows:140:149
2023-01-14 15:06:55.220648 processing rows:150:159
2023-01-14 15:06:58.950416 processing rows:160:169
2023-01-14 15:07:02.704937 processing rows:170:179
2023-01-14 15:07:06.230715 processing rows:180:189
2023-01-14 15:07:09.891015 processing rows:190:199
20

2023-01-14 15:16:12.357249 processing rows:1584:1593
2023-01-14 15:16:16.302891 processing rows:1594:1603
2023-01-14 15:16:20.144762 processing rows:1604:1613
2023-01-14 15:16:23.940801 processing rows:1614:1623
2023-01-14 15:16:27.840091 processing rows:1624:1633
2023-01-14 15:16:32.044622 processing rows:1634:1643
2023-01-14 15:16:35.930550 processing rows:1644:1653
2023-01-14 15:16:39.904142 processing rows:1654:1663
2023-01-14 15:16:43.751789 processing rows:1664:1673
2023-01-14 15:16:47.445438 processing rows:1674:1683
2023-01-14 15:16:51.335701 processing rows:1684:1693
2023-01-14 15:16:55.037157 processing rows:1694:1703
2023-01-14 15:16:58.955880 processing rows:1704:1713
2023-01-14 15:17:02.827554 processing rows:1714:1723
2023-01-14 15:17:06.439937 processing rows:1724:1733
2023-01-14 15:17:10.547302 processing rows:1734:1743
2023-01-14 15:17:14.641422 processing rows:1744:1753
2023-01-14 15:17:18.732639 processing rows:1754:1763
2023-01-14 15:17:22.626050 processing rows:176

2023-01-14 15:26:10.308591 processing rows:3122:3131
2023-01-14 15:26:14.063449 processing rows:3132:3141
2023-01-14 15:26:17.753864 processing rows:3142:3151
2023-01-14 15:26:21.642972 processing rows:3152:3161
2023-01-14 15:26:25.338786 processing rows:3162:3171
2023-01-14 15:26:29.342011 processing rows:3172:3181
2023-01-14 15:26:33.261668 processing rows:3182:3191
2023-01-14 15:26:37.040092 processing rows:3192:3201
2023-01-14 15:26:40.751873 processing rows:3202:3211
2023-01-14 15:26:44.658300 processing rows:3212:3221
2023-01-14 15:26:48.353393 processing rows:3222:3231
2023-01-14 15:26:52.357830 processing rows:3232:3241
2023-01-14 15:26:56.286139 processing rows:3242:3251
2023-01-14 15:27:00.042475 processing rows:3252:3261
2023-01-14 15:27:03.947809 processing rows:3262:3271
2023-01-14 15:27:07.552627 processing rows:3272:3281
2023-01-14 15:27:11.459333 processing rows:3282:3291
2023-01-14 15:27:15.261293 processing rows:3292:3301
2023-01-14 15:27:18.945909 processing rows:330

2023-01-14 15:35:58.871798 processing rows:4671:4680
2023-01-14 15:36:02.462881 processing rows:4681:4690
2023-01-14 15:36:06.454827 processing rows:4691:4700
2023-01-14 15:36:10.335308 processing rows:4701:4710
2023-01-14 15:36:14.199846 processing rows:4711:4720
2023-01-14 15:36:17.977979 processing rows:4721:4730
2023-01-14 15:36:21.871493 processing rows:4731:4740
2023-01-14 15:36:25.768647 processing rows:4741:4750
2023-01-14 15:36:29.479734 processing rows:4751:4760
2023-01-14 15:36:33.487600 processing rows:4761:4770
2023-01-14 15:36:37.351674 processing rows:4771:4780
2023-01-14 15:36:41.161873 processing rows:4781:4782
2023-01-14 15:36:44.751654 processing rows:4783:4792
2023-01-14 15:36:48.555632 processing rows:4793:4802
2023-01-14 15:36:52.261478 processing rows:4803:4812
2023-01-14 15:36:55.914440 processing rows:4813:4822
2023-01-14 15:36:59.980896 processing rows:4823:4832
2023-01-14 15:37:03.776783 processing rows:4833:4842
2023-01-14 15:37:07.587296 processing rows:484

2023-01-14 15:45:52.569733 processing rows:6205:6214
2023-01-14 15:45:56.175790 processing rows:6215:6224
2023-01-14 15:46:00.086364 processing rows:6225:6234
2023-01-14 15:46:03.878760 processing rows:6235:6244
2023-01-14 15:46:07.676914 processing rows:6245:6254
2023-01-14 15:46:11.375267 processing rows:6255:6264
2023-01-14 15:46:15.174978 processing rows:6265:6274
2023-01-14 15:46:19.123575 processing rows:6275:6284
2023-01-14 15:46:22.893463 processing rows:6285:6294
2023-01-14 15:46:26.837304 processing rows:6295:6304
2023-01-14 15:46:30.566924 processing rows:6305:6314
2023-01-14 15:46:34.275298 processing rows:6315:6324
2023-01-14 15:46:37.989326 processing rows:6325:6334
2023-01-14 15:46:41.675063 processing rows:6335:6344
2023-01-14 15:46:45.475844 processing rows:6345:6354
2023-01-14 15:46:49.185495 processing rows:6355:6364
2023-01-14 15:46:52.891138 processing rows:6365:6374
2023-01-14 15:46:56.794325 processing rows:6375:6384
2023-01-14 15:47:01.086145 processing rows:638

2023-01-14 15:55:41.981752 processing rows:7746:7755
2023-01-14 15:55:45.593667 processing rows:7756:7765
2023-01-14 15:55:49.479898 processing rows:7766:7767
2023-01-14 15:55:52.981595 processing rows:7768:7777
2023-01-14 15:55:56.899287 processing rows:7778:7787
2023-01-14 15:56:00.687010 processing rows:7788:7797
2023-01-14 15:56:04.484945 processing rows:7798:7807
2023-01-14 15:56:08.110388 processing rows:7808:7817
2023-01-14 15:56:11.685893 processing rows:7818:7827
2023-01-14 15:56:15.797832 processing rows:7828:7837
2023-01-14 15:56:19.560327 processing rows:7838:7847
2023-01-14 15:56:23.289747 processing rows:7848:7857
2023-01-14 15:56:27.258773 processing rows:7858:7867
2023-01-14 15:56:30.916589 processing rows:7868:7877
2023-01-14 15:56:34.687103 processing rows:7878:7887
2023-01-14 15:56:38.499228 processing rows:7888:7897
2023-01-14 15:56:42.087816 processing rows:7898:7907
2023-01-14 15:56:45.884605 processing rows:7908:7917
2023-01-14 15:56:49.986433 processing rows:791

2023-01-14 16:05:27.602647 processing rows:9284:9293
2023-01-14 16:05:31.669869 processing rows:9294:9303
2023-01-14 16:05:35.300429 processing rows:9304:9313
2023-01-14 16:05:39.223243 processing rows:9314:9323
2023-01-14 16:05:42.820327 processing rows:9324:9333
2023-01-14 16:05:46.821554 processing rows:9334:9343
2023-01-14 16:05:50.404785 processing rows:9344:9353
2023-01-14 16:05:54.096762 processing rows:9354:9363
2023-01-14 16:05:57.903994 processing rows:9364:9373
2023-01-14 16:06:01.818286 processing rows:9374:9383
2023-01-14 16:06:05.589441 processing rows:9384:9393
2023-01-14 16:06:09.400641 processing rows:9394:9403
2023-01-14 16:06:13.131346 processing rows:9404:9413
2023-01-14 16:06:17.124335 processing rows:9414:9423
2023-01-14 16:06:20.896710 processing rows:9424:9433
2023-01-14 16:06:24.678038 processing rows:9434:9443
2023-01-14 16:06:28.499727 processing rows:9444:9453
2023-01-14 16:06:32.711451 processing rows:9454:9463
2023-01-14 16:06:36.400843 processing rows:946

2023-01-14 16:15:05.141809 processing rows:10787:10789
2023-01-14 16:15:08.678575 processing rows:10790:10799
2023-01-14 16:15:12.480777 processing rows:10800:10809
2023-01-14 16:15:16.327793 processing rows:10810:10819
2023-01-14 16:15:20.120556 processing rows:10820:10829
2023-01-14 16:15:24.024714 processing rows:10830:10839
2023-01-14 16:15:27.620534 processing rows:10840:10849
2023-01-14 16:15:31.328808 processing rows:10850:10859
2023-01-14 16:15:35.312763 processing rows:10860:10869
2023-01-14 16:15:39.036657 processing rows:10870:10879
2023-01-14 16:15:42.938224 processing rows:10880:10889
2023-01-14 16:15:46.623507 processing rows:10890:10899
2023-01-14 16:15:50.424559 processing rows:10900:10909
2023-01-14 16:15:54.050847 processing rows:10910:10919
2023-01-14 16:15:57.815313 processing rows:10920:10929
2023-01-14 16:16:01.515403 processing rows:10930:10939
2023-01-14 16:16:05.222031 processing rows:10940:10949
2023-01-14 16:16:08.891920 processing rows:10950:10959
2023-01-14

In [25]:
# Ignore failed inferences: rows left as None because the service reported an error
# for them or every retry failed. Possibly caused by a wrong encoding format - TBD: investigate further.
predicted_df = df.dropna(subset=['predicted', 'confidence'])

## Generating class metrics and confusion matrix

In [26]:
# Single-label task: each ticket has exactly one class, so this is False.
# When True, _get_evaluation_report emits an empty confusion matrix instead of computing one.
IS_MULTI_LABEL = False

In [27]:
def run_sklearn_report(y_true, y_pred, labels=None):
    """Prepare the classification evaluation report.

    Args:
        y_true: gold labels.
        y_pred: predicted labels, same length as ``y_true``.
        labels: ordered list of labels for report rows and confusion-matrix axes.
            Defaults to the sorted distinct values of ``y_true``.

    Returns:
        dict built by ``_get_evaluation_report``: 'ClassMetrics',
        'TextClassificationModelMetrics', 'confusionMatrix' and 'labels'.
    """
    if labels is None:
        labels = sorted(set(y_true))

    # Pass `labels` explicitly so `target_names` lines up with it. Otherwise sklearn
    # infers the label set from y_true and y_pred together, and any predicted class
    # missing from `labels` makes the two lists differ in length (ValueError).
    clf_report = classification_report(
        y_true=y_true,
        y_pred=y_pred,
        labels=labels,
        target_names=labels,
        output_dict=True,
    )

    return _get_evaluation_report(clf_report, y_true, y_pred, labels)


def _get_confusion_matrix(y_pred: List[Any], y_true: List[Any], labels: Union[List[Any], None] = None, format_matrix: bool = True) -> Union[Dict[Any, Dict[Any, int]], "numpy.ndarray"]:
    """Compute the confusion matrix, optionally formatted as a nested dict keyed by label.

    Args:
        y_pred: predicted labels (list or array-like).
        y_true: gold labels, same length as ``y_pred``.
        labels: label order for both axes; defaults to the sorted distinct values of ``y_true``.
        format_matrix: if True, return ``{true_label: {predicted_label: count}}``;
            otherwise return sklearn's raw matrix (rows = true, columns = predicted).

    Returns:
        The nested count dict, or the raw matrix when ``format_matrix`` is False.
    """
    if labels is None:
        labels = sorted(set(y_true))
    # TODO: Check how to handle missing classes if any
    confusion = confusion_matrix(y_true=y_true, y_pred=y_pred, labels=labels)
    if not format_matrix:
        return confusion
    # int() turns numpy integer counts into plain Python ints
    # (presumably so the report stays serializable).
    return {
        row_label: {col_label: int(confusion[i][j]) for j, col_label in enumerate(labels)}
        for i, row_label in enumerate(labels)
    }


def _get_evaluation_report(report, y_true, y_pred, labels):
    """Reshape an sklearn ``classification_report(output_dict=True)`` into the evaluation format.

    Produces:
    - 'ClassMetrics': per-class precision / recall / f1 / support
    - 'TextClassificationModelMetrics': accuracy, micro/macro/weighted (and, for
      multi-label input, samples) precision, recall and F1
    - 'confusionMatrix': nested label-keyed counts (empty when IS_MULTI_LABEL)
    - 'labels': the label order used

    From the docs: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.classification_report.html
    Micro average (averaging the total true positives, false negatives and false positives) is only shown
    for multi-label or multi-class with a subset of classes,
    because it corresponds to accuracy otherwise and would be the same for all metrics.
    """
    evaluation_results = {'ClassMetrics': {}, 'TextClassificationModelMetrics': {}}
    model_metrics = evaluation_results['TextClassificationModelMetrics']
    # Aggregate rows sklearn appends after the per-class rows. 'samples avg' only appears
    # for multi-label input; without it here it would be recorded as if it were a class.
    metric_names = ['micro avg', 'macro avg', 'accuracy', 'weighted avg', 'samples avg']

    for class_name, metrics in report.items():
        if class_name not in metric_names:
            # Per-class row: rename sklearn's 'f1-score' to 'f1', keep the other keys.
            evaluation_results['ClassMetrics'][class_name] = {
                ('f1' if metric_name == 'f1-score' else metric_name): metric_value
                for metric_name, metric_value in metrics.items()
            }
        elif class_name == 'accuracy':
            # Reported as a single float when micro P/R/F1 all equal accuracy.
            accuracy = float(metrics)
            model_metrics['microF1'] = accuracy
            model_metrics['microPrecision'] = accuracy
            model_metrics['microRecall'] = accuracy
            model_metrics['accuracy'] = accuracy
        elif class_name == 'micro avg':
            model_metrics['microF1'] = metrics['f1-score']
            model_metrics['microPrecision'] = metrics['precision']
            model_metrics['microRecall'] = metrics['recall']
            model_metrics['accuracy'] = accuracy_score(y_true, y_pred)  # TODO: Check if we can use some other score here
        else:
            # macro / weighted / samples averages -> e.g. 'macroPrecision', 'weightedF1'
            key_head = class_name.split(' ')[0]
            for metric_name, metric_value in metrics.items():
                if metric_name == 'support':
                    continue
                key_tail = 'f1' if metric_name == 'f1-score' else metric_name
                model_metrics[f'{key_head}{key_tail.capitalize()}'] = metric_value

    if IS_MULTI_LABEL:
        confusion_matrix_data = {}
    else:
        confusion_matrix_data = _get_confusion_matrix(y_pred=y_pred, y_true=y_true, labels=labels, format_matrix=True)
    evaluation_results['confusionMatrix'] = confusion_matrix_data
    evaluation_results['labels'] = labels

    return evaluation_results

In [28]:
# Evaluate only rows that received a prediction. The label axis covers every class seen
# in either the gold labels or the predictions.
true_values = predicted_df['labels'].values
pred_values = predicted_df['predicted'].values
label_order = sorted(set(true_values) | set(pred_values))
eval_report = run_sklearn_report(y_true=true_values, y_pred=pred_values, labels=label_order)


In [29]:
# Confusion matrix: rows = true label, columns = predicted label.
# The report stores it as {true_label: {predicted_label: count}}. pd.DataFrame(dict) would
# turn the outer keys into columns and silently transpose the matrix, so use orient='index'.
confusion_df = pd.DataFrame.from_dict(eval_report['confusionMatrix'], orient='index')
confusion_df.to_csv('confusionMatrix.csv', index_label='true\\predicted')
pd.DataFrame(eval_report['ClassMetrics']).to_csv('ClassMetrics.csv')
pd.Series(eval_report['TextClassificationModelMetrics']).to_csv('TextClassificationModelMetrics.csv')
