In [1]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Feb 17 15:11:03 2025

@author: randy

This notebook builds and queries a semantic search index using an embedding model trained on custom data.

Before running it, unpack the model from the tar file in the repo and copy its contents to the top-level
directory, i.e. the model files must end up in the directory given by embedding_model_filename below.

"""

# %pip install accelerate==1.3.0 #0.26.0
# %pip install sentence-transformers==3.4.1
# %pip install datasets==3.3.1

import datetime as dt
import pandas as pd
from sklearn.neighbors import NearestNeighbors
import numpy as np
import pickle

#for vector embeddings
from sentence_transformers import SentenceTransformer, losses, InputExample, util

# Paths to the data and model artifacts. They all live under one base directory,
# which can be overridden with the SUPPORTIV_DIR environment variable instead of
# editing this cell; the default is the original location.
import os

BASE_DIR = os.environ.get("SUPPORTIV_DIR", "/home/randy/supportiv")

data_filename = os.path.join(BASE_DIR, "mle_screening_dataset.csv")
embedding_model_filename = os.path.join(BASE_DIR, "custom_model")
nn_model_filename = os.path.join(BASE_DIR, "trained_nn_model.pkl")
test_set_filename = os.path.join(BASE_DIR, "test_set.csv")

In [2]:

#############################################################################################################
# Step one: Build the search model <-------------------- you only need to do this once
#   1. load the data
#   2. load the embedding model from the local directory embedding_model_filename
#   3. encode all the answers
#   4. fit the nearest neighbor search model on the encoded answers
#   5. save the nearest neighbor model (as a pickle file)

# read the raw data
train_data = pd.read_csv(data_filename)

# separate the answer space; row i of the embeddings below corresponds to row
# position i of train_data, so neighbor indices can be mapped back positionally
docs = train_data['answer'].astype(str).tolist()

# load the embedding model from the local (unpacked) model directory
model = SentenceTransformer(embedding_model_filename)

# encode the answers with the embedding model
doc_emb = model.encode(docs)

# fit a nearest neighbor model on the answer embeddings
# (ball tree, scikit-learn's default Minkowski p=2 / Euclidean metric)
nbrs = NearestNeighbors(n_neighbors=10, algorithm='ball_tree').fit(doc_emb)

# save the model to storage; the context manager guarantees the file is flushed and closed
with open(nn_model_filename, 'wb') as nn_model_file:
    pickle.dump(nbrs, nn_model_file)

  return torch._C._cuda_getDeviceCount() > 0


In [3]:

#############################################################################################################
# Step two: inference
#   1. load the nearest neighbor model
#   2. load the custom embedding model from local storage
#   3. load the original answers
#   4. ask a question
#   5. show the answers that are closest to the question

# load the nearest neighbor model
# NOTE: pickle.load can execute arbitrary code -- only load a file this notebook produced itself.
with open(nn_model_filename, 'rb') as nn_model_file:
    loaded_nn_model = pickle.load(nn_model_file)

# load the embedding model
# (the global name `loaded_embedding_mode` is kept as-is because ask_question reads it)
loaded_embedding_mode = SentenceTransformer(embedding_model_filename)

# load the original data set; its row order must match the order used when the
# nearest neighbor model was fitted, because neighbor indices are row positions
train_data = pd.read_csv(data_filename)

In [4]:

def ask_question(the_question, num_results = 5, embedding_model=None, nn_model=None, data=None):
    """Ask the semantic search a question and return the closest answers.

    Parameters
    ----------
    the_question : str
        Free-text question to embed and search with.
    num_results : int, default 5
        How many nearest answers to return.
    embedding_model : optional
        Object with an ``encode`` method; defaults to the notebook global
        ``loaded_embedding_mode``.
    nn_model : optional
        Fitted nearest-neighbor model exposing ``kneighbors``; defaults to the
        notebook global ``loaded_nn_model``.
    data : pandas.DataFrame, optional
        Frame with an ``answer`` column, in the same row order the neighbor
        model was fitted on; defaults to the notebook global ``train_data``.

    Returns
    -------
    list of [rank, question, distance, answer]
        One entry per returned neighbor, in the order ``kneighbors`` returns
        them (closest first); ``rank`` starts at 1.
    """
    # Fall back to the notebook-level globals so existing calls keep working.
    if embedding_model is None:
        embedding_model = loaded_embedding_mode
    if nn_model is None:
        nn_model = loaded_nn_model
    if data is None:
        data = train_data

    new_embedding = embedding_model.encode(the_question)
    # kneighbors expects a 2-D batch of queries, so wrap the single embedding.
    distances, indices = nn_model.kneighbors([new_embedding], n_neighbors=num_results)

    # Neighbor indices are row *positions* in the fitted embeddings, so look the
    # answers up positionally (.iloc); label-based .loc breaks on a non-default index.
    answer_column = data['answer']
    return [
        [rank, the_question, float(distance), answer_column.iloc[position]]
        for rank, (distance, position) in enumerate(zip(distances[0], indices[0]), start=1)
    ]


In [5]:
# Ask a question

answers = ask_question("can someone be allergic to water?")

# every entry echoes the question, so print it once from the first entry
print(answers[0][1])

# each entry is [rank, question, distance, answer]
for answer_rank, _asked, answer_distance, this_answer in answers:
    print("******************************************************************************")
    print(answer_rank, ":", answer_distance, ":", this_answer)

can someone be allergic to water?
******************************************************************************
1 : 0.8751718586794703 : Summary : We all need clean water. People need it to grow crops and to operate factories, and for drinking and recreation. Fish and wildlife depend on it to survive.     Many different pollutants can harm our rivers, streams, lakes, and oceans. The three most common are soil, nutrients, and bacteria. Rain washes soil into streams and rivers. The soil can kill tiny animals and fish eggs. It can clog the gills of fish and block light, causing plants to die. Nutrients, often from fertilizers, cause problems in lakes, ponds, and reservoirs. Nitrogen and phosphorus make algae grow and can turn water green. Bacteria, often from sewage spills, can pollute fresh or salt water.     You can help protect your water supply:       - Don't pour household products such as cleansers, beauty products, medicines, auto fluids, paint, and lawn care products down the dra

In [6]:
# Ask a question

answers = ask_question("tell me about glaucoma research?")

# every entry echoes the question, so print it once from the first entry
print(answers[0][1])

# each entry is [rank, question, distance, answer]
for answer_rank, _asked, answer_distance, this_answer in answers:
    print("******************************************************************************")
    print(answer_rank, ":", answer_distance, ":", this_answer)

tell me about glaucoma research?
******************************************************************************
1 : 0.5502713660438636 : Through studies in the laboratory and with patients, the National Eye Institute is seeking better ways to detect, treat, and prevent vision loss in people with glaucoma. For example, researchers have discovered genes that could help explain how glaucoma damages the eye. NEI also is supporting studies to learn more about who is likely to get glaucoma, when to treat people who have increased eye pressure, and which treatment to use first.
******************************************************************************
2 : 0.6147933817954384 : National Eye Institute  National Institutes of Health  2020 Vision Place  Bethesda, MD 20892-3655  301-496-5248  E-mail: 2020@nei.nih.gov  www.nei.nih.gov The Glaucoma Foundation  80 Maiden Lane, Suite 700  New York, NY 10038  212-285-0080 Glaucoma Research Foundation  251 Post Street, Suite 600  San Francisco, CA 94

In [7]:
# Ask a question

answers = ask_question("Are pets good to have around?")

# every entry echoes the question, so print it once from the first entry
print(answers[0][1])

# each entry is [rank, question, distance, answer]
for answer_rank, _asked, answer_distance, this_answer in answers:
    print("******************************************************************************")
    print(answer_rank, ":", answer_distance, ":", this_answer)

    

Are pets good to have around?
******************************************************************************
1 : 0.8111947444944919 : Summary : Pets can add fun, companionship and a feeling of safety to your life. Before getting a pet, think carefully about which animal is best for your family. What is each family member looking for in a pet? Who will take care of it? Does anyone have pet allergies? What type of animal suits your lifestyle and budget?    Once you own a pet, keep it healthy. Know the signs of medical problems. Take your pet to the veterinarian if you notice:       - Loss of appetite    - Drinking a lot of water    - Gaining or losing a lot of weight quickly    - Strange behavior    - Being sluggish and tired    - Trouble getting up or down    - Strange lumps
******************************************************************************
2 : 0.9261778560184258 : Summary : You can't remove all the safety hazards from your life, but you can reduce them. To avoid many major 

In [10]:
# Test the model accuracy
#
# Criteria (top-5 hit rate):
#   Success: the expected answer is among the top 5 returned results
#   Failure: the expected answer is NOT among the top 5 returned results
#
# Each question is counted at most once: if the same answer text appears in
# several of the returned results it must not inflate the score.

TOP_K = 5

test_set = pd.read_csv(test_set_filename)

num_correct = 0
for the_question, the_expected_answer in zip(test_set['question'], test_set['answer']):
    answers = ask_question(the_question, TOP_K)

    # each returned entry is [rank, question, distance, answer]
    if any(result[3] == the_expected_answer for result in answers):
        num_correct += 1

num_questions = test_set.shape[0]
# guard against an empty test set instead of raising ZeroDivisionError
accuracy = num_correct / num_questions if num_questions else float('nan')
print("test set accuracy: ", accuracy, "  number correct:", num_correct)

# last recorded output (produced before counting each question only once, so it
# may over-count): test set accuracy:  0.9710542352224254   number correct: 3187

test set accuracy:  0.9710542352224254   number correct: 3187
