<a href="https://colab.research.google.com/github/scorzo/bert-colab-inmem/blob/main/bert_colab_inmem.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install langchain flask transformers torch pandas

In [None]:
from langchain.vectorstores import Chroma
from flask import Flask, request, jsonify
from transformers import BertModel, BertTokenizer
import torch
import pandas as pd

# Initialize Flask app; the @app.route decorators further down register
# their handlers on this instance.
app = Flask(__name__)

# Load pre-trained BERT model and tokenizer. Both are kept at module level
# so generate_embeddings() reuses the same objects on every call.
model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name)

# Initialize Chroma vector store with default arguments. store_ticket() and
# find_similar_tickets() both operate on this shared instance.
# NOTE(review): LangChain's Chroma wrapper needs the `chromadb` package,
# which the pip cell above does not list — confirm it is installed.
chroma = Chroma()

# Function to generate embeddings
def generate_embeddings(text):
    """Return a single embedding vector for ``text``.

    The text is tokenized (truncated to at most 512 tokens), passed through
    the module-level BERT ``model``, and the last hidden states are averaged
    over the token dimension (mean pooling).

    Args:
        text: The string to embed.

    Returns:
        A NumPy array containing the mean-pooled last hidden state of the
        first (and only) sequence in the batch.
    """
    inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=512)
    # Inference only: disable autograd so no computation graph is built and
    # held in memory for every request.
    with torch.no_grad():
        outputs = model(**inputs)
    pooled = outputs.last_hidden_state.mean(dim=1)
    # .cpu() is a no-op on CPU tensors and keeps .numpy() valid if the model
    # is ever moved to an accelerator.
    return pooled[0].detach().cpu().numpy()

# Function to store ticket in Chroma
def store_ticket(ticket_id, issue_description, embedding):
    """Store a ticket's precomputed embedding in the shared Chroma store.

    Args:
        ticket_id: Identifier of the ticket; stored as its string form.
        issue_description: Ticket text, stored alongside the vector as the
            document body.
        embedding: Sequence of numbers (e.g. the array returned by
            ``generate_embeddings``) representing the ticket.

    Returns:
        ``ticket_id`` unchanged.
    """
    # LangChain's Chroma wrapper has no ``add(id, vector)`` method, and its
    # public ``add_texts`` would embed the text again instead of using the
    # vector computed by the caller. The precomputed vector is therefore
    # written through the wrapped chromadb collection.
    # NOTE(review): ``_collection`` is a private attribute of the wrapper —
    # confirm it exists in the installed LangChain version.
    chroma._collection.upsert(
        ids=[str(ticket_id)],  # chromadb requires string ids
        embeddings=[[float(value) for value in embedding]],
        documents=[str(issue_description)],
    )
    return ticket_id

# REST API to create embeddings for a ticket and store it
@app.route('/create_ticket', methods=['POST'])
def create_ticket():
    """Embed and store the ticket described by the JSON request body.

    Expects a JSON object with the keys ``'Ticket ID'`` and
    ``'Detailed Description'``.

    Returns:
        A JSON response ``{"ticket_id": <id>}`` on success, or a JSON error
        message with HTTP status 400 when the body is not a JSON object or a
        required field is missing.
    """
    # silent=True yields None instead of raising on a missing/invalid body,
    # so malformed requests get a clear 400 rather than an unhandled error.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    required_fields = ('Ticket ID', 'Detailed Description')
    missing = [field for field in required_fields if field not in data]
    if missing:
        return jsonify({"error": "Missing required field(s): " + ", ".join(missing)}), 400

    issue_description = data['Detailed Description']
    embedding = generate_embeddings(issue_description)
    ticket_id = data['Ticket ID']

    # Store in Chroma
    store_ticket(ticket_id, issue_description, embedding)
    return jsonify({"ticket_id": ticket_id})

# Function to find similar tickets using Chroma
def find_similar_tickets(query_embedding):
    """Return the id of the stored ticket nearest to ``query_embedding``.

    Args:
        query_embedding: Sequence of numbers (e.g. the array returned by
            ``generate_embeddings``) to search with.

    Returns:
        The id of the closest stored entry, or ``None`` when the query
        yields no match.
    """
    # ``get_nns_by_vector`` is not part of LangChain's Chroma wrapper, and the
    # wrapper's public similarity search returns documents rather than ids.
    # The wrapped chromadb collection is queried directly to get the id.
    # NOTE(review): ``_collection`` is a private attribute of the wrapper —
    # confirm it exists in the installed LangChain version.
    result = chroma._collection.query(
        query_embeddings=[[float(value) for value in query_embedding]],
        n_results=1,
    )
    # "ids" holds one list of matches per query vector; only one was sent.
    matches = result.get("ids") or [[]]
    return matches[0][0] if matches[0] else None

# REST API to query similar tickets
@app.route('/query_tickets', methods=['POST'])
def query_tickets():
    """Return the stored ticket most similar to the submitted query text.

    Expects a JSON object with the key ``'query'``.

    Returns:
        A JSON response ``{"similar_ticket": <result>}`` where ``<result>``
        is whatever ``find_similar_tickets`` returns, or a JSON error message
        with HTTP status 400 when the body is not a JSON object or ``'query'``
        is missing.
    """
    # silent=True yields None instead of raising on a missing/invalid body,
    # so malformed requests get a clear 400 rather than an unhandled error.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict) or 'query' not in payload:
        return jsonify({"error": "Request body must be a JSON object with a 'query' field"}), 400

    query = payload['query']
    query_embedding = generate_embeddings(query)

    # Find similar tickets
    similar_ticket = find_similar_tickets(query_embedding)
    return jsonify({"similar_ticket": similar_ticket})

def read_tickets_from_csv(file_path):
    """Load the CSV at ``file_path`` and return its rows as dictionaries.

    Each returned dictionary maps the CSV's column names to that row's values.
    """
    return pd.read_csv(file_path).to_dict(orient='records')

def process_ticket(ticket_id, issue_description):
    """Embed one ticket's description and store it under ``ticket_id``.

    Returns whatever ``store_ticket`` returns for the stored ticket.
    """
    vector = generate_embeddings(issue_description)
    stored = store_ticket(ticket_id, issue_description, vector)
    return stored

def process_all_tickets(file_path):
    """Embed and store every ticket listed in the CSV at ``file_path``.

    Each row is read through ``read_tickets_from_csv`` and must provide the
    'Ticket ID' and 'Support Issue Description' columns; rows are handled in
    file order via ``process_ticket``.
    """
    for record in read_tickets_from_csv(file_path):
        process_ticket(
            record['Ticket ID'],
            record['Support Issue Description'],
        )

if __name__ == '__main__':
    # Entry point: bulk-load tickets from a CSV file into the vector store.
    # The Flask server start-up below is left commented out, so the HTTP
    # routes are defined but not served from this entry point.
    # app.run(debug=True)
    # NOTE(review): this path is relative to the current working directory —
    # confirm it resolves to the intended file in the target runtime.
    csv_file_path = './content/sample_data/support_tickets.csv' # path to CSV file
    process_all_tickets(csv_file_path)
