<a href="https://colab.research.google.com/github/scorzo/colab-bert/blob/main/bert-colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# REST API functions

In [None]:
import os

import pandas as pd
import torch
import weaviate
from flask import Flask, request, jsonify
from langchain import embeddings
from transformers import BertModel, BertTokenizer

# Initialize Flask app
app = Flask(__name__)

# Load pre-trained BERT model and tokenizer
model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name)

# Initialize Weaviate client. Set the WEAVIATE_URL environment variable to
# point at another instance; by default a local instance is used.
client = weaviate.Client(os.environ.get("WEAVIATE_URL", "http://localhost:8080"))

# Schema for stored support tickets. "vectorizer" is "none" because vectors are
# computed locally with BERT and must be passed explicitly when creating objects.
schema = {
    "classes": [
        {
            "class": "SupportTicket",
            "vectorizer": "none",  # As we use BERT for embeddings
            "properties": [
                {
                    "name": "description",
                    "dataType": ["text"],
                },
                {
                    "name": "embedding",
                    "dataType": ["number[]"],
                }
            ],
        }
    ]
}

# Create the class only if it does not exist yet. Creating an existing class
# is rejected by Weaviate, which would make re-running this cell fail.
# `or []` guards against a "classes" value of null on an empty schema.
_existing_classes = {
    cls["class"] for cls in (client.schema.get().get("classes") or [])
}
if "SupportTicket" not in _existing_classes:
    client.schema.create(schema)

# Function to generate embeddings
def generate_embeddings(text):
    """Return a single mean-pooled BERT embedding for ``text``.

    The text is tokenized (truncated to 512 tokens) and run through the
    module-level ``model``. The token vectors of the last hidden layer are
    averaged into one vector.

    Args:
        text: The string to embed.

    Returns:
        A 1-D numpy array with one value per hidden unit of ``model``.
    """
    inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=512)
    # Inference only: skip building the autograd graph (saves memory and time).
    with torch.no_grad():
        outputs = model(**inputs)
    # Average over the token axis (dim=1), then take the only row, since
    # exactly one string was tokenized.
    return outputs.last_hidden_state.mean(dim=1).cpu().numpy()[0]

# REST API to create embeddings for a ticket and store it
@app.route('/create_ticket', methods=['POST'])
def create_ticket():
    """Embed a ticket description and store it in Weaviate.

    Expects a JSON object body with a non-empty string under
    "Detailed Description".

    Returns:
        200 with ``{"ticket_uuid": <uuid>}`` on success, or 400 with an
        ``{"error": ...}`` body when the description is missing or invalid.
    """
    # silent=True returns None for a missing/malformed JSON body instead of
    # raising, so every bad request gets the same clean 400 below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    issue_description = data.get('Detailed Description')
    if not isinstance(issue_description, str) or not issue_description.strip():
        return jsonify({"error": "'Detailed Description' must be a non-empty string"}), 400

    vector = generate_embeddings(issue_description).tolist()

    # Store in Weaviate. The class has no vectorizer, so the vector must be
    # supplied explicitly; otherwise the object has no vector and nearVector
    # queries can never return it.
    ticket_uuid = client.data_object.create(
        data_object={
            "description": issue_description,
            "embedding": vector,
        },
        class_name="SupportTicket",
        vector=vector,
    )

    return jsonify({"ticket_uuid": ticket_uuid})

# REST API to query similar tickets
@app.route('/query_tickets', methods=['POST'])
def query_tickets():
    """Return stored tickets whose vectors are nearest to the query text.

    Expects a JSON object body with a non-empty string under "query". An
    optional numeric "certainty" is forwarded to Weaviate's nearVector search
    as the minimum-similarity threshold.

    Returns:
        200 with ``{"similar_tickets": <Weaviate query response>}``, or 400
        with an ``{"error": ...}`` body for invalid input.
    """
    body = request.get_json(silent=True)
    if not isinstance(body, dict):
        body = {}
    query = body.get('query')
    if not isinstance(query, str) or not query.strip():
        return jsonify({"error": "'query' must be a non-empty string"}), 400

    # Validate the optional threshold before running the (expensive) model.
    certainty = body.get('certainty')
    if certainty is not None:
        try:
            certainty = float(certainty)
        except (TypeError, ValueError):
            return jsonify({"error": "'certainty' must be a number"}), 400

    # with_near_vector() takes one dict; the vector must be a plain list
    # (numpy arrays are not JSON-serializable).
    near_vector = {"vector": generate_embeddings(query).tolist()}
    if certainty is not None:
        near_vector["certainty"] = certainty

    # Query similar tickets from Weaviate
    similar_tickets = (
        client.query.get(class_name="SupportTicket", properties=["description"])
        .with_near_vector(near_vector)
        .do()
    )

    return jsonify({"similar_tickets": similar_tickets})

# Test: load tickets from a CSV file without the REST API

In [None]:
def read_tickets_from_csv(file_path):
    """
    Load a CSV of tickets as one ``{column name: cell value}`` dict per row.

    Rows keep the order they have in the file.
    """
    frame = pd.read_csv(file_path)
    rows = frame.to_dict(orient='records')
    return rows

def process_ticket(ticket_id, issue_description):
    """
    Process a single ticket: generate its embedding and store it in the vector database.

    Args:
        ticket_id: Identifier from the source data, stored as the "ticket_id" property.
        issue_description: Ticket text to embed and store.

    Returns:
        The identifier returned by ``client.data_object.create`` for the new object.
    """
    vector = generate_embeddings(issue_description).tolist()

    # NOTE(review): "ticket_id" is not declared in the SupportTicket schema;
    # storing it presumably relies on Weaviate auto-schema being enabled - confirm.
    ticket_uuid = client.data_object.create(
        data_object={
            "ticket_id": ticket_id,
            "description": issue_description,
            "embedding": vector,
        },
        class_name="SupportTicket",
        # The class has no vectorizer, so pass the vector explicitly; without
        # it the object is invisible to nearVector queries.
        vector=vector,
    )
    return ticket_uuid

# Function to process all tickets in the CSV file
def process_all_tickets(file_path):
    """Embed and store every ticket in the CSV file at ``file_path``.

    Rows whose "Support Issue Description" is not a non-blank string are
    skipped with a printed notice. pandas reads empty cells as NaN floats,
    which the tokenizer cannot embed, so one blank cell would otherwise abort
    the whole run.

    Returns:
        List of identifiers returned by ``process_ticket`` for the stored
        tickets, in CSV row order.
    """
    ticket_uuids = []
    for ticket in read_tickets_from_csv(file_path):
        ticket_id = ticket['Ticket ID']
        description = ticket['Support Issue Description']
        if not isinstance(description, str) or not description.strip():
            print(f"Skipping ticket {ticket_id!r}: missing description")
            continue
        ticket_uuids.append(process_ticket(ticket_id, description))
    return ticket_uuids

# Run the code

In [None]:
if __name__ == '__main__':
    # Batch-ingest tickets from the sample CSV. To serve the REST endpoints
    # instead, start the Flask app here with app.run(debug=True).
    # NOTE(review): the path is relative to the working directory; on Colab
    # (cwd /content) it resolves under /content/content/ - confirm location.
    csv_file_path = './content/sample_data/support_tickets.csv'
    process_all_tickets(file_path=csv_file_path)