In [None]:
from typing import List, Dict, Union

import dspy
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

from utils import json_to_dataframe, json_to_string_list

In [None]:
# Path to the JSON input, relative to the notebook's working directory.
filepath = '../../data/vector_veterinary_imaging_2.json'

# Load the same file twice through the project helpers from `utils`:
# once as `df` and once as `rad_strings`.
# NOTE(review): judging by the helper names, `df` is presumably a pandas
# DataFrame and `rad_strings` a list of strings -- confirm in utils.
df = json_to_dataframe(filepath) 
rad_strings = json_to_string_list(filepath)

In [None]:
# Bare expression: as the last statement of a notebook cell this displays `df`.
df

## Retrieval

In [None]:
class SentenceTransformerRetriever(dspy.Retrieve):
    """Retriever that ranks a fixed list of findings by cosine similarity.

    Every finding is embedded once at construction time with a
    SentenceTransformer model. `forward` embeds the query with the same
    model and returns the most similar findings, best match first.
    """

    def __init__(self, model: Union[str, SentenceTransformer], findings: List[str], k: int):
        """Build the retriever and pre-compute the finding embeddings.

        Args:
            model: An already-loaded SentenceTransformer, or a model
                name/path to load one from.
            findings: The texts that can be retrieved.
            k: Default number of results returned by `forward`.
        """
        # Let dspy.Retrieve set up its own state before we add ours.
        super().__init__(k=k)
        if isinstance(model, SentenceTransformer):
            self.model = model
        else:
            # SECURITY NOTE: trust_remote_code=True lets the model repository
            # execute its own Python code on load; only pass model names
            # from sources you trust.
            self.model = SentenceTransformer(model, trust_remote_code=True)
        self.findings = findings
        self.k = k
        self.embeddings = None
        self.init_embeddings()

    def init_embeddings(self):
        """(Re)compute and store the embeddings of `self.findings`."""
        self.embeddings = self.model.encode(self.findings)

    def forward(self, query: str, k: Union[int, None] = None) -> List[Dict[str, Union[str, float]]]:
        """Return the findings most similar to `query`, best match first.

        Args:
            query: Text to search for.
            k: Maximum number of results. Defaults to the `k` given at
                construction time when omitted or None.

        Returns:
            A list of at most `k` dicts with keys 'long_text' (the finding)
            and 'score' (its cosine similarity to the query as a Python
            float), ordered by descending score. Empty when `k` is not
            positive or when there are no findings.
        """
        if k is None:
            k = self.k
        # Never ask for more results than there are findings.
        k = min(k, len(self.findings))
        # Guard explicitly: a slice of [-0:] would select *every* element,
        # and an empty corpus cannot be scored at all.
        if k <= 0:
            return []

        query_embedding = self.model.encode([query])
        similarities = cosine_similarity(query_embedding, self.embeddings)[0]
        # argsort is ascending: take the last k entries and reverse them.
        top_k_indices = np.argsort(similarities)[-k:][::-1]

        return [
            {'long_text': self.findings[idx], 'score': float(similarities[idx])}
            for idx in top_k_indices
        ]