## Importing Dependencies

In [1]:
# Importing necessary tools
from flask import Flask
from flask_restful import Resource, Api, reqparse
from flask_cors import CORS
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
import tensorflow as tf

In [2]:
# Load the tokenizer from the base DistilBERT checkpoint and the sequence
# classifier from the "distilbert-tweet-emotion" checkpoint with a 4-way head.
# NOTE(review): assumes the checkpoint was trained with this same uncased
# tokenizer, and that num_labels stays in sync with the label mapping below.
tokenizer = AutoTokenizer.from_pretrained(
    "distilbert-base-uncased"
)
model = TFAutoModelForSequenceClassification.from_pretrained(
    "distilbert-tweet-emotion",
    num_labels=4,
)

All model checkpoint layers were used when initializing TFDistilBertForSequenceClassification.

All the layers of TFDistilBertForSequenceClassification were initialized from the model checkpoint at distilbert-tweet-emotion.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFDistilBertForSequenceClassification for predictions without further training.


## Building a processing function for predictions

In [3]:
# Maps each output index of the classification head to its emotion label.
# NOTE(review): assumed to match the label ids used when the checkpoint was
# trained — confirm against the training label encoding.
class_names = {0: "anger", 1: "joy", 2: "optimism", 3: "sadness"}


def logits_to_class_names(predictions):
    """Convert a batch of classifier outputs into emotion label strings.

    Args:
        predictions: The object returned by ``model(...)``, which exposes a
            ``.logits`` tensor, or the logits tensor itself. The logits are
            expected to be 2-D, with classes along axis 1.

    Returns:
        A list holding one label from ``class_names`` per row of logits, in
        the same order as the input batch.
    """
    # Accept either the model output object or raw logits.
    logits = getattr(predictions, "logits", predictions)
    # Softmax is monotonic, so the argmax of the raw logits selects the same
    # class as the argmax of the probabilities. Skipping it avoids a wasted
    # exponentiate-and-normalise pass.
    class_ids = tf.argmax(logits, axis=1).numpy()
    return [class_names[int(class_id)] for class_id in class_ids]

## Building the API

### Creating a Flask application

In [4]:
# Create the Flask application, enable cross-origin requests on it, and wrap
# it in a Flask-RESTful Api so Resource classes can be mounted as endpoints.
# NOTE(review): CORS(app) with no options uses flask-cors defaults, which
# allow any origin; restrict `origins` before exposing this beyond localhost.
app = Flask(__name__)
CORS(app)
api = Api(app)

### Defining arguments for HTTP requests

In [5]:
# Request schema for the inference endpoint. "Sequences" is mandatory, and
# action="append" collects every supplied value into a list of strings so
# one request can classify a whole batch.
parser = reqparse.RequestParser()
parser.add_argument(
    name="Sequences",
    type=str,
    action="append",
    required=True,
    help="The sequence to be classified",
)

<flask_restful.reqparse.RequestParser at 0x2428a324e20>

### Building an endpoint for inference

In [6]:
# REST resource exposing the emotion classifier over HTTP GET.
class Inference(Resource):
    """Classify the emotion of one or more text sequences."""

    def get(self):
        """Handle a GET request carrying the ``Sequences`` argument.

        Returns:
            A ``({"Predictions": [...]}, 200)`` tuple containing one emotion
            label per input sequence, in request order.
        """
        # parse_args() enforces the schema declared on `parser`, in which
        # "Sequences" is required.
        args = parser.parse_args()

        # Pad the batch to its longest member, and truncate anything longer
        # than the tokenizer's maximum length. Without truncation, an
        # over-long input exceeds the model's position embeddings and the
        # forward pass fails, turning the whole request into a server error.
        encoded = tokenizer(
            args["Sequences"], return_tensors="tf", padding=True, truncation=True
        )

        # Run the model and map each row of logits to a label name.
        predictions = logits_to_class_names(model(encoded))

        return {"Predictions": predictions}, 200

In [7]:
# Mount the Inference resource at /inference. The endpoint name is spelled
# out explicitly; it is the same name Flask-RESTful derives by default from
# the lower-cased class name.
api.add_resource(Inference, "/inference", endpoint="inference")

### Launching our application

In [8]:
# Launch Flask's built-in development server. Inside the notebook kernel
# __name__ is "__main__", so this runs and blocks the cell until it is
# interrupted. As the startup warning says, use a production WSGI server for
# real deployments.
if __name__ == "__main__":
    app.run()

 * Serving Flask app '__main__' (lazy loading)
 * Environment: production
[2m   Use a production WSGI server instead.[0m
 * Debug mode: off


 * Running on http://127.0.0.1:5000/ (Press CTRL+C to quit)
127.0.0.1 - - [26/Oct/2021 12:47:21] "GET /inference HTTP/1.1" 200 -
127.0.0.1 - - [26/Oct/2021 12:47:21] "GET /inference HTTP/1.1" 200 -
127.0.0.1 - - [26/Oct/2021 12:47:52] "GET /inference HTTP/1.1" 200 -
127.0.0.1 - - [26/Oct/2021 12:47:53] "GET /inference HTTP/1.1" 200 -
127.0.0.1 - - [26/Oct/2021 12:48:12] "GET /inference HTTP/1.1" 200 -
127.0.0.1 - - [26/Oct/2021 12:48:13] "GET /inference HTTP/1.1" 200 -
