In [1]:
from src.api import API
from src.dataloader import DataLoader
from src.prompter import prompter_factory
from src.evaluator import evaluate_model

In [2]:
TOGETHER_API_KEY = "92a6ac4a8feb39c91b4a3f77219e9c452d927f5f4d543d5969cc11c210795719"
BASE_URL = "https://api.together.xyz"
ALLOWED_SYMPTOMS = ['anxiety', 'concentration problems', 'constipation', 'cough',
                    'diarrhea', 'fatigue', 'fever', 'headache', 'nausea', 
                    'numbness and tingling', 'pain', 'poor appetite', 'rash', 
                    'shortness of breath', 'trouble drinking fluids', 'vomiting', 'other']

Using the API class provides a modular approach to client definition.

In [3]:
api = API(api_key=TOGETHER_API_KEY, base_url=BASE_URL)
client = api.get_openai()

The DataLoader class loads the csv files. It verifies their existence and provides testing tools.

DataLoader allows for dataframe creation, with uniformization of the labels and column formatting.

In [4]:
dataloader = DataLoader(path="data/")
print(dataloader.list_csv_files())

['data/batch_1_gs.csv', 'data/batch_2_gs.csv', 'data/batch_3_gs.csv', 'data/batch_4_gs.csv', 'data/batch_5_gs.csv', 'data/batch_6_gs.csv', 'data/batch_7_gs.csv', 'data/batch_8_gs.csv', 'data/batch_9_gs.csv', 'data/batch_10_gs.csv']


In [5]:
df = dataloader.get_standardized_dataframe(context_col="Text Data",
                                           target_binary_col="symptom_status_gs",
                                           target_multilabel_col="symptom_detail_gs",
                                           keep_other_cols=True)

In [6]:
dataloader.check_symptoms_validity(allowed_symptoms=ALLOWED_SYMPTOMS,
                                   symptoms_col="symptom_detail_gs")

Symptoms in dataframe are valid.


In [7]:
prompter = prompter_factory(prompter_type="binary",
                            client=client,
                            model="mistralai/Mixtral-8x7B-Instruct-v0.1")

In [8]:
prompt = "Are any medical symptoms mentioned in the transcript"
context = "i have a headache every time i see a cat and i hear voices that are not there"
print(prompter.generate_single(prompt=prompt, context=context))

{'status': True}


In [9]:
results = prompter.generate(df=df, prompt=prompt)

binary task using: mistralai/Mixtral-8x7B-Instruct-v0.1: 100%|██████████| 550/550 [00:51<00:00, 10.59it/s]


In [10]:
results

Unnamed: 0,status
0,True
1,True
2,True
3,False
4,False
...,...
545,True
546,False
547,True
548,True


In [11]:
evaluate_model(data=df, results=results, verbose=True)

Accuracy: 0.8109090909090909
Precision: 0.8556701030927835
Recall: 0.8736842105263158
F1: 0.8645833333333334
Confusion Matrix: [[114  56]
 [ 48 332]]


  y_true = y_true.astype(str).replace({"Positive": True, "Negative": False})
