# Amazon Comprehend - Custom Text Classification

This lab is based on the blog post found [here](https://aws.amazon.com/blogs/machine-learning/building-a-custom-classifier-using-amazon-comprehend/). It uses data that has already been preprocessed and split into training and test sets in a public S3 bucket. You will need to update this notebook with your own S3 output location and IAM role policy.

We've also reduced the number of training documents to shorten classifier training time.

Please email awsaaron@amazon.com with questions.

# Data


Typically you'd copy data from S3 into the SageMaker notebook instance. In this example, however, we aren't using SageMaker for custom model training; we're using an AI service, Amazon Comprehend, which reads its training data directly from S3. That means our data needs to be in S3.


In [1]:
import os
import boto3
import sagemaker
import pandas as pd

prefix = 'NLP.Classification'
region = boto3.Session().region_name
# Set the environment variable only after `region` has been defined
os.environ["AWS_REGION"] = region
role = sagemaker.get_execution_role()
sagemaker_session = sagemaker.Session()
bucket_name = sagemaker_session.default_bucket()

print(region)
print(bucket_name)

us-east-1
sagemaker-us-east-1-626825435328


In [2]:
training_data = 's3://aws-ml-blog/artifacts/comprehend-custom-classification/comprehend-train.csv'
testing_data = 's3://aws-ml-blog/artifacts/comprehend-custom-classification/comprehend-test.csv'

Copy the data from the public bucket to your notebook instance.

In [3]:
!aws s3 cp {training_data} .
!aws s3 cp {testing_data} .

download: s3://aws-ml-blog/artifacts/comprehend-custom-classification/comprehend-train.csv to ./comprehend-train.csv
download: s3://aws-ml-blog/artifacts/comprehend-custom-classification/comprehend-test.csv to ./comprehend-test.csv
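If you'd rather copy the files with boto3 than with the `!aws s3 cp` shell command, you need the bucket and key as separate values. The helper below is a minimal sketch; `split_s3_uri` is a hypothetical name, not part of boto3 or the SageMaker SDK, and it only uses the standard library's `urllib.parse`.

```python
from typing import Tuple
from urllib.parse import urlparse


def split_s3_uri(uri: str) -> Tuple[str, str]:
    """Split an s3://bucket/key URI into (bucket, key). Hypothetical helper."""
    parsed = urlparse(uri)
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError(f"not an S3 URI: {uri!r}")
    # netloc holds the bucket; path keeps a leading '/', which S3 keys omit
    return parsed.netloc, parsed.path.lstrip("/")


training_data = "s3://aws-ml-blog/artifacts/comprehend-custom-classification/comprehend-train.csv"
bucket, key = split_s3_uri(training_data)
print(bucket)  # aws-ml-blog
print(key)     # artifacts/comprehend-custom-classification/comprehend-train.csv
```

With the pieces separated, `boto3.client('s3').download_file(bucket, key, 'comprehend-train.csv')` performs the same copy as the shell command.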


In [4]:
df = pd.read_csv('comprehend-train.csv', header=None, names=['class', 'text'])
df

|       | class | text |
|-------|-------|------|
| 0     | SCIENCE_AND_MATHEMATICS | `What is an \imaginary number\"? \n What is an ...` |
| 1     | ENTERTAINMENT_AND_MUSIC | `What's the cheapest source for ordering DVDs f...` |
| 2     | BUSINESS_AND_FINANCE | `If I lose lots of money in stock in one year&#...` |
| 3     | SCIENCE_AND_MATHEMATICS | `When can a common man fly to moon? \n My realt...` |
| 4     | SOCIETY_AND_CULTURE | `When do you use a semicolon instead of a colon...` |
| ...   | ... | ... |
| 99986 | POLITICS_AND_GOVERNMENT | `I need help reporting a person who is working ...` |
| 99987 | ENTERTAINMENT_AND_MUSIC | `What happened to the rebate for 'Friends' seas...` |
| 99988 | POLITICS_AND_GOVERNMENT | `Are terrorists allowed to edit \factual inform...` |
| 99989 | SPORTS | `do u think that STEVE NASH deserves the MVP???...` |


In [5]:
df['class'].unique()

array(['SCIENCE_AND_MATHEMATICS', 'ENTERTAINMENT_AND_MUSIC',
       'BUSINESS_AND_FINANCE', 'SOCIETY_AND_CULTURE',
       'EDUCATION_AND_REFERENCE', 'COMPUTERS_AND_INTERNET',
       'POLITICS_AND_GOVERNMENT', 'HEALTH', 'SPORTS',
       'FAMILY_AND_RELATIONSHIPS'], dtype=object)

In [6]:
df['class'].value_counts()

SPORTS                      10000
HEALTH                      10000
FAMILY_AND_RELATIONSHIPS    10000
SOCIETY_AND_CULTURE          9999
POLITICS_AND_GOVERNMENT      9999
SCIENCE_AND_MATHEMATICS      9999
EDUCATION_AND_REFERENCE      9999
BUSINESS_AND_FINANCE         9999
ENTERTAINMENT_AND_MUSIC      9998
COMPUTERS_AND_INTERNET       9998
Name: class, dtype: int64

The full training set contains roughly 100,000 documents. To keep training time short for this lab, take a random sample of 1,000 of them.

In [7]:
a = df.sample(1000)

In [8]:
a['class'].value_counts()

POLITICS_AND_GOVERNMENT     109
HEALTH                      106
FAMILY_AND_RELATIONSHIPS    106
BUSINESS_AND_FINANCE        106
EDUCATION_AND_REFERENCE     100
ENTERTAINMENT_AND_MUSIC     100
SOCIETY_AND_CULTURE          97
SCIENCE_AND_MATHEMATICS      97
COMPUTERS_AND_INTERNET       94
SPORTS                       85
Name: class, dtype: int64
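A plain random sample leaves the classes somewhat unbalanced (85 to 109 documents above) and changes on every run. A stratified sample with a fixed seed gives every class the same number of documents and is reproducible. The sketch below uses only the standard library and works on `(label, text)` pairs; `stratified_sample` is a hypothetical helper, and the toy rows stand in for the real data.

```python
import random
from collections import defaultdict


def stratified_sample(rows, per_class, seed=0):
    """Return up to `per_class` randomly chosen (label, text) rows per label."""
    by_label = defaultdict(list)
    for label, text in rows:
        by_label[label].append((label, text))
    rng = random.Random(seed)  # fixed seed makes the sample reproducible
    sample = []
    for label in sorted(by_label):  # sorted so the result doesn't depend on input order of labels
        group = by_label[label]
        sample.extend(rng.sample(group, min(per_class, len(group))))
    rng.shuffle(sample)  # avoid long runs of a single label in the output file
    return sample


# Toy data standing in for the real training set
rows = (
    [("SPORTS", f"sports question {i}") for i in range(50)]
    + [("HEALTH", f"health question {i}") for i in range(30)]
)
balanced = stratified_sample(rows, per_class=20, seed=42)
print(len(balanced))  # 40
```

On the DataFrame itself, pandas 1.1 and later offer the same thing directly: `df.groupby('class').sample(n=100, random_state=0)` yields 100 documents for each of the 10 classes, 1,000 in total.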

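Save the downsampled dataset as a headerless CSV file with the label in the first column and the text in the second, the layout Comprehend uses for multi-class training data.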

In [9]:
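import csv

# Wrap each document in double quotes ourselves; QUOTE_NONE writes each field
# exactly as given, so pandas won't add a second layer of quotes
a['text'] = '"' + a['text'] + '"'
a.to_csv('limited_dataset.csv', header=False, index=False, quoting=csv.QUOTE_NONE)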
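Manually wrapping text in quotes leaves any quote characters already inside a document unescaped. As an alternative, the standard `csv` module can handle quoting for you: `QUOTE_MINIMAL` quotes only fields that contain commas, quotes, or newlines, and escapes embedded quotes by doubling them. This is a sketch under two assumptions: that Comprehend accepts this standard (RFC 4180-style) quoting, and that each document must fit on a single line. `write_comprehend_csv` is a hypothetical helper.

```python
import csv
import io


def write_comprehend_csv(rows, fileobj):
    """Write (label, text) rows as headerless label,text lines.

    Assumption: one document per line, so embedded newlines are flattened.
    """
    writer = csv.writer(fileobj, quoting=csv.QUOTE_MINIMAL, lineterminator="\n")
    for label, text in rows:
        writer.writerow([label, " ".join(text.splitlines())])


buf = io.StringIO()
write_comprehend_csv(
    [
        ("SPORTS", 'Who said "defense wins championships", anyway?'),
        ("HEALTH", "Line one\nline two"),
    ],
    buf,
)
print(buf.getvalue())
# SPORTS,"Who said ""defense wins championships"", anyway?"
# HEALTH,Line one line two
```

With pandas, `a.to_csv('limited_dataset.csv', header=False, index=False)` uses the same minimal quoting by default, so the manual `'"' + ... + '"'` step would not be needed.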

In [10]:
limited_dataset_path = 's3://'+bucket_name+'/'+prefix+'/limited_dataset.csv'

In [11]:
!aws s3 cp limited_dataset.csv {limited_dataset_path}

upload: ./limited_dataset.csv to s3://sagemaker-us-east-1-626825435328/NLP.Classification/limited_dataset.csv


## Create Classifier
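Create a custom document classifier by supplying a name, the S3 location of the training data, the ARN of a role that Comprehend can assume to read that data, and the document language. An S3 output location can also be supplied through `OutputDataConfig`; it is optional and omitted here.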

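Make sure the notebook's execution role has the following policy attached so that it can pass a role to Comprehend. The role passed as `DataAccessRoleArn` (here, the same SageMaker execution role) must also trust `comprehend.amazonaws.com` and have read access to the training data in S3.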

```
{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Action": [
        "iam:PassRole"
      ],
      "Effect": "Allow",
      "Resource": "*"
    }
  ]
}
```
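The policy above allows passing any role (`"Resource": "*"`). A tighter variant limits `iam:PassRole` to the one role and only when it is handed to Comprehend, using the `iam:PassedToService` condition key. The sketch below builds that policy plus the trust statement Comprehend needs, as plain Python dicts. The function names and the example ARN are hypothetical; substitute the value of `role` from the first cell.

```python
import json


def pass_role_policy(role_arn):
    """Least-privilege policy: pass only `role_arn`, and only to Comprehend."""
    return {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Effect": "Allow",
                "Action": "iam:PassRole",
                "Resource": role_arn,
                "Condition": {
                    "StringEquals": {"iam:PassedToService": "comprehend.amazonaws.com"}
                },
            }
        ],
    }


def comprehend_trust_policy():
    """Trust policy that lets the Comprehend service assume the data access role."""
    return {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Effect": "Allow",
                "Principal": {"Service": "comprehend.amazonaws.com"},
                "Action": "sts:AssumeRole",
            }
        ],
    }


# Hypothetical ARN for illustration only
example_role = "arn:aws:iam::123456789012:role/service-role/AmazonSageMaker-ExecutionRole-example"
print(json.dumps(pass_role_policy(example_role), indent=2))
print(json.dumps(comprehend_trust_policy(), indent=2))
```

Because the SageMaker execution role already trusts `sagemaker.amazonaws.com`, add the Comprehend statement alongside the existing one in its trust relationship rather than replacing it.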

In [17]:
import boto3

# Instantiate the Comprehend client in the notebook's region
client = boto3.client('comprehend', region_name=region)
# Classifier names may contain only letters, digits, and hyphens
classifier_name = 'custom-classification-immersion-day'

# Create a document classifier; training starts asynchronously
create_response = client.create_document_classifier(
    DocumentClassifierName=classifier_name,
    DataAccessRoleArn=role,
    InputDataConfig={
        'S3Uri': limited_dataset_path,
    },
    LanguageCode='en',
)

print(create_response)

{'DocumentClassifierArn': 'arn:aws:comprehend:us-east-1:626825435328:document-classifier/test', 'ResponseMetadata': {'RequestId': 'e189d608-740d-4d5e-a3c0-8e700b2be7c5', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': 'e189d608-740d-4d5e-a3c0-8e700b2be7c5', 'content-type': 'application/x-amz-json-1.1', 'content-length': '94', 'date': 'Mon, 28 Sep 2020 21:41:27 GMT'}, 'RetryAttempts': 0}}


Now let's check the status of the custom classifier. Training takes a while, so re-run the following cell as often as needed until the status changes from `TRAINING` to `TRAINED`.

In [23]:
describe_response = client.describe_document_classifier(
    DocumentClassifierArn=create_response['DocumentClassifierArn'])
print("Describe response: \n",describe_response)
print()

# List all classifiers in account
list_response = client.list_document_classifiers()
print("List response: \n", list_response)

Describe response: 
 {'DocumentClassifierProperties': {'DocumentClassifierArn': 'arn:aws:comprehend:us-east-1:626825435328:document-classifier/test', 'LanguageCode': 'en', 'Status': 'TRAINING', 'SubmitTime': datetime.datetime(2020, 9, 28, 21, 41, 27, 541000, tzinfo=tzlocal()), 'InputDataConfig': {'S3Uri': 's3://sagemaker-us-east-1-626825435328/NLP.Classification/limited_dataset.csv'}, 'OutputDataConfig': {}, 'DataAccessRoleArn': 'arn:aws:iam::626825435328:role/service-role/AmazonSageMaker-ExecutionRole-20200926T132738', 'Mode': 'MULTI_CLASS'}, 'ResponseMetadata': {'RequestId': '2dce2589-4b5d-44bd-814b-e60a7d533cc6', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': '2dce2589-4b5d-44bd-814b-e60a7d533cc6', 'content-type': 'application/x-amz-json-1.1', 'content-length': '489', 'date': 'Mon, 28 Sep 2020 22:01:22 GMT'}, 'RetryAttempts': 0}}

List response: 
 {'DocumentClassifierPropertiesList': [{'DocumentClassifierArn': 'arn:aws:comprehend:us-east-1:626825435328:document-classifie
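Instead of re-running the cell by hand, you can poll until training reaches a terminal state. The sketch below is a generic polling loop; `wait_for_status` is a hypothetical helper, and the set of terminal states is an assumption based on the documented `TRAINED`, `IN_ERROR`, and `STOPPED` statuses, so check the API reference for your SDK version. The status function and `sleep` are injected so the loop can be exercised without AWS.

```python
import time

# Assumed terminal states for a Comprehend document classifier
CLASSIFIER_DONE = {"TRAINED", "IN_ERROR", "STOPPED"}


def wait_for_status(get_status, done_states, poll_seconds=60,
                    timeout_seconds=3 * 3600, sleep=time.sleep):
    """Call get_status() until it returns one of done_states.

    The 3-hour default timeout is an arbitrary choice for this sketch.
    Raises TimeoutError if the timeout elapses first.
    """
    waited = 0
    while True:
        status = get_status()
        if status in done_states:
            return status
        if waited >= timeout_seconds:
            raise TimeoutError(f"still {status!r} after {waited} seconds")
        sleep(poll_seconds)
        waited += poll_seconds


# Demo with a fake status sequence instead of real API calls
fake_statuses = iter(["SUBMITTED", "TRAINING", "TRAINING", "TRAINED"])
final = wait_for_status(lambda: next(fake_statuses), CLASSIFIER_DONE, sleep=lambda s: None)
print(final)  # TRAINED
```

In the notebook, the status function would be `lambda: client.describe_document_classifier(DocumentClassifierArn=create_response['DocumentClassifierArn'])['DocumentClassifierProperties']['Status']`.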

In [20]:
describe_response['DocumentClassifierProperties']['DocumentClassifierArn']

'arn:aws:comprehend:us-east-1:626825435328:document-classifier/test'

# Predictions!

Once the custom classifier has finished training (status `TRAINED`), you can use it for batch or real-time predictions.

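Create an endpoint for real-time predictions. Endpoint creation can take several minutes; wait until `client.describe_endpoint(EndpointArn=...)['EndpointProperties']['Status']` is `IN_SERVICE` before calling `classify_document`.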

In [None]:
# Create a real-time endpoint backed by the trained classifier (billed while it exists)
response = client.create_endpoint(
    EndpointName='my-custom-classification-endpoint2',
    ModelArn=describe_response['DocumentClassifierProperties']['DocumentClassifierArn'],
    DesiredInferenceUnits=1,
)
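`DesiredInferenceUnits` sets the endpoint's throughput. The sketch below estimates how many units a workload needs, assuming one inference unit handles about 100 characters per second; treat that figure as an assumption and confirm it against the current Comprehend documentation and pricing. `inference_units_needed` is a hypothetical helper.

```python
import math

# Assumption: one Comprehend inference unit (IU) ~ 100 characters per second
CHARS_PER_SECOND_PER_IU = 100


def inference_units_needed(requests_per_second, avg_chars_per_request,
                           chars_per_iu=CHARS_PER_SECOND_PER_IU):
    """Smallest whole number of IUs (at least 1) covering the character throughput."""
    chars_per_second = requests_per_second * avg_chars_per_request
    return max(1, math.ceil(chars_per_second / chars_per_iu))


# 5 requests/second of ~200 characters each = 1,000 characters/second
print(inference_units_needed(5, 200))  # 10
```

Endpoints accrue charges while provisioned, so delete this one with `client.delete_endpoint(EndpointArn=response['EndpointArn'])` when you finish the lab.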

In [None]:
print(response)

In [None]:
response['EndpointArn']

In [None]:
txt = 'After my most recent doctors appointment, I came down with the flu'

In [None]:
# real-time
real_time_response = client.classify_document(
    Text=txt,
    EndpointArn=response['EndpointArn']
)
print(real_time_response['Classes'])
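In multi-class mode, `classify_document` returns a score for each class under `Classes`, each entry a dict with `Name` and `Score`. To get a single prediction, take the highest-scoring entry. `top_class` is a hypothetical helper and the scores in the demo are made up, not real model output.

```python
def top_class(classes):
    """Return (name, score) of the highest-scoring entry in a Classes list."""
    if not classes:
        raise ValueError("no classes returned")
    best = max(classes, key=lambda c: c["Score"])
    return best["Name"], best["Score"]


# Made-up scores in the shape of real_time_response['Classes']
example_classes = [
    {"Name": "HEALTH", "Score": 0.91},
    {"Name": "FAMILY_AND_RELATIONSHIPS", "Score": 0.05},
    {"Name": "SCIENCE_AND_MATHEMATICS", "Score": 0.04},
]
print(top_class(example_classes))  # ('HEALTH', 0.91)
```

In the notebook, `top_class(real_time_response['Classes'])` returns the predicted label for `txt`.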

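Next, let's run an asynchronous batch classification job over the test set. The results are written to S3 rather than returned in the response.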

In [None]:
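# batch
# Write results under your own bucket
s3_output_location = 's3://' + bucket_name + '/' + prefix + '/output'

start_response = client.start_document_classification_job(
    InputDataConfig={
        'S3Uri': testing_data,
        # Treat each line as a separate document (the default is one document per file)
        'InputFormat': 'ONE_DOC_PER_LINE',
    },
    OutputDataConfig={
        'S3Uri': s3_output_location
    },
    DataAccessRoleArn=role,
    DocumentClassifierArn=create_response['DocumentClassifierArn']
)

print("Start response:\n", start_response)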


In [None]:
# Check the status of the job (a new name avoids overwriting the classifier's describe_response)
job_describe_response = client.describe_document_classification_job(JobId=start_response['JobId'])
print("Describe response:\n", job_describe_response)

# List all classification jobs in account
list_response = client.list_document_classification_jobs()
print("List response:\n", list_response)
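When the job status is `COMPLETED`, the `describe_document_classification_job` response points (under `DocumentClassificationJobProperties` → `OutputDataConfig` → `S3Uri`) to a compressed archive of results. The sketch below reads predictions from that archive once you've downloaded its bytes. The layout is an assumption: a `predictions.jsonl` file whose lines look like `{"File": ..., "Line": ..., "Classes": [{"Name": ..., "Score": ...}]}`. `read_predictions` and `make_fake_archive` are hypothetical helpers, and the demo archive is synthetic.

```python
import io
import json
import tarfile


def read_predictions(tar_gz_bytes):
    """Return [(line, top_class_name), ...] from a batch job's output archive."""
    with tarfile.open(fileobj=io.BytesIO(tar_gz_bytes), mode="r:gz") as tar:
        member = next(
            (m for m in tar.getmembers() if m.name.endswith("predictions.jsonl")), None
        )
        if member is None:
            raise ValueError("no predictions.jsonl in archive")
        results = []
        for raw in tar.extractfile(member):
            if not raw.strip():
                continue  # skip blank lines
            record = json.loads(raw)
            best = max(record["Classes"], key=lambda c: c["Score"])
            results.append((record.get("Line"), best["Name"]))
        return results


def make_fake_archive(records):
    """Build an in-memory .tar.gz with a predictions.jsonl (demo only)."""
    data = "\n".join(json.dumps(r) for r in records).encode()
    buf = io.BytesIO()
    with tarfile.open(fileobj=buf, mode="w:gz") as tar:
        info = tarfile.TarInfo("predictions.jsonl")
        info.size = len(data)
        tar.addfile(info, io.BytesIO(data))
    return buf.getvalue()


fake = make_fake_archive([
    {"File": "comprehend-test.csv", "Line": "0",
     "Classes": [{"Name": "SPORTS", "Score": 0.8}, {"Name": "HEALTH", "Score": 0.2}]},
    {"File": "comprehend-test.csv", "Line": "1",
     "Classes": [{"Name": "SPORTS", "Score": 0.1}, {"Name": "HEALTH", "Score": 0.9}]},
])
print(read_predictions(fake))  # [('0', 'SPORTS'), ('1', 'HEALTH')]
```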