# GPT-2 Flask REST Endpoint

This notebook demonstrates how to create a local REST API endpoint using Flask to serve a GPT-2 text generation model. You can send text prompts to this endpoint and receive generated text back.

## 1. Installation

First, we need to install the required libraries: `flask`, `transformers`, and `torch`.

In [None]:
!pip install Flask transformers torch

## 2. Import Libraries and Load Model

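We'll import Flask and the `pipeline` and `set_seed` helpers from the `transformers` library. `pipeline('text-generation', model='gpt2')` downloads and loads both the GPT-2 model and its tokenizer, and `set_seed` fixes the random seed so that sampled generations are repeatable across runs. We then define the generation function.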

In [None]:
from flask import Flask, request, jsonify
from transformers import pipeline, set_seed

# Load GPT-2 and its tokenizer as a text-generation pipeline
generator = pipeline('text-generation', model='gpt2')
# Fix the random seed so sampled generations are repeatable across runs
set_seed(42)

# Define the text generation function
def run_gpt2(prompt, answer_length=30):
    # Generate at most `answer_length` new tokens after the prompt.
    # GPT-2 has no padding token, so reuse end-of-sequence to avoid the pad_token_id warning.
    outputs = generator(
        prompt,
        max_new_tokens=answer_length,
        num_return_sequences=1,
        pad_token_id=generator.tokenizer.eos_token_id,
    )
    generated_text = outputs[0]['generated_text']
    # The pipeline returns the prompt followed by the continuation; strip the prompt
    if generated_text.startswith(prompt):
        return generated_text[len(prompt):].strip()
    return generated_text.strip()
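
The prompt-stripping step at the end of `run_gpt2` can be checked without loading the model. A minimal sketch, using a hypothetical helper `strip_prompt` that mirrors that logic, applied to made-up output strings (not real GPT-2 output):

```python
def strip_prompt(prompt: str, generated_text: str) -> str:
    """Return only the continuation, mirroring the check in run_gpt2 (hypothetical helper)."""
    # By default the text-generation pipeline returns prompt + continuation,
    # so drop the prompt when the output starts with it.
    if generated_text.startswith(prompt):
        return generated_text[len(prompt):].strip()
    # Otherwise return the whole output, trimmed of surrounding whitespace.
    return generated_text.strip()


# Made-up outputs used only to illustrate the two branches.
print(strip_prompt("Once upon a time,", "Once upon a time, there was a dragon."))
# → there was a dragon.
print(strip_prompt("Hello", "hello world"))
# → hello world
```

If you pass `return_full_text=False` to the pipeline, it returns only the continuation and this stripping step becomes unnecessary.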


## 3. Create Flask App and Endpoint

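Now, we'll create a Flask application and define a route (`/generate`) that handles text generation requests. The endpoint accepts POST requests whose JSON body contains a `text` field, passes that prompt to the `run_gpt2` function defined above, and returns the result as `{"generated_text": "..."}`. A request without a `text` field gets a 400 error, and an exception during generation gets a 500 error.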

In [None]:
# Create a Flask application instance
app = Flask(__name__)

# Define the route for text generation
@app.route('/generate', methods=['POST'])
def generate_text():
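    # Parse the JSON body; silent=True returns None instead of raising on a missing or invalid body
    data = request.get_json(silent=True)

    # Reject requests that are not a JSON object with a 'text' field
    if not isinstance(data, dict) or 'text' not in data:
        return jsonify({'error': 'Missing "text" field in request'}), 400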

    prompt = data['text']

    try:
        # Use the run_gpt2 function to generate text
        generated_text = run_gpt2(prompt)

        # Return the generated text as a JSON response
        return jsonify({'generated_text': generated_text})

    except Exception as e:
        # Return an error response if something goes wrong
        return jsonify({'error': str(e)}), 500
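
Because `app.run` blocks the kernel, it can help to exercise the route before starting the server. Flask's built-in test client (`app.test_client()`) sends requests to an app in-process, without opening a port. The sketch below is self-contained under two assumptions: `create_app` is a hypothetical factory that rebuilds a route with the same contract as `generate_text` above, and `fake_generate` is a stub standing in for `run_gpt2` so that no model is loaded.

```python
from flask import Flask, jsonify, request


def create_app(generate_fn):
    """Build an app with the same /generate contract; generate_fn is injected (hypothetical factory)."""
    app = Flask(__name__)

    @app.route('/generate', methods=['POST'])
    def generate_text():
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or 'text' not in data:
            return jsonify({'error': 'Missing "text" field in request'}), 400
        try:
            return jsonify({'generated_text': generate_fn(data['text'])})
        except Exception as e:
            return jsonify({'error': str(e)}), 500

    return app


# Stub standing in for run_gpt2, so no model is downloaded
def fake_generate(prompt):
    return prompt.upper()


client = create_app(fake_generate).test_client()

ok = client.post('/generate', json={'text': 'hello'})
print(ok.status_code, ok.get_json())  # → 200 {'generated_text': 'HELLO'}

missing = client.post('/generate', json={'prompt': 'hello'})
print(missing.status_code, missing.get_json())  # → 400 {'error': 'Missing "text" field in request'}
```

In this notebook you could call `test_client()` on the `app` defined above instead, at the cost of running real GPT-2 generation for each request.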


## 4. Run the Flask App

Finally, run the Flask development server. By default, it listens on `http://127.0.0.1:5000/`. The call blocks the cell: while the server is running, the kernel cannot execute any other cell in this notebook. To stop the server, interrupt the kernel.

In [None]:
# Run the Flask app on the default port 5000
# debug=False: the debug reloader restarts the process, which typically fails inside a notebook kernel
if __name__ == '__main__':
    app.run(host='127.0.0.1', port=5000, debug=False)
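
If you would rather keep the notebook usable while the server runs, one common alternative to the blocking `app.run` call is to serve the app from a background thread. A sketch, assuming Werkzeug (installed as a Flask dependency) and its `werkzeug.serving.make_server` helper; `demo_app` and its prompt-reversing route are hypothetical stand-ins so the cell runs on its own, and in this notebook you would pass `app` instead. The client uses the standard-library `urllib` only to avoid extra dependencies.

```python
import json
import threading
import urllib.request

from flask import Flask, jsonify, request
from werkzeug.serving import make_server

# Hypothetical stand-in app; pass the notebook's `app` here instead
demo_app = Flask(__name__)


@demo_app.route('/generate', methods=['POST'])
def generate_text():
    data = request.get_json(silent=True) or {}
    # Stub: reverse the prompt instead of calling GPT-2
    return jsonify({'generated_text': data.get('text', '')[::-1]})


# make_server binds the socket immediately, so requests work as soon as the thread starts
server = make_server('127.0.0.1', 5000, demo_app)
thread = threading.Thread(target=server.serve_forever, daemon=True)
thread.start()

# The kernel is free, so a client can run in the same notebook
req = urllib.request.Request(
    'http://127.0.0.1:5000/generate',
    data=json.dumps({'text': 'abc'}).encode('utf-8'),
    headers={'Content-Type': 'application/json'},
    method='POST',
)
with urllib.request.urlopen(req, timeout=10) as resp:
    result = json.loads(resp.read().decode('utf-8'))
print(result)  # → {'generated_text': 'cba'}

# Stop the server and release port 5000
server.shutdown()
thread.join()
```

Leave out the last two lines to keep the server running while you test it from other cells, and call `server.shutdown()` when you are done.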

## 5. Testing the Endpoint (Optional)

You can test the endpoint by sending a POST request to `http://127.0.0.1:5000/generate` with a JSON body like `{"text": "Your prompt here"}`. Use the `requests` library, or a tool like `curl` or Postman. Because `app.run` keeps this notebook's kernel busy, the request has to come from a different process, such as a second notebook or a separate Python session.

Example using `requests` (run it in a separate notebook or Python session while the server cell above is running):

In [None]:
import requests

url = 'http://127.0.0.1:5000/generate'
data = {'text': 'Once upon a time,'}

try:
    response = requests.post(url, json=data)
    response.raise_for_status() # Raise an exception for bad status codes
    result = response.json()
    print(result)
except requests.exceptions.RequestException as e:
    print(f"Error: {e}")