In [1]:
import pandas as pd
import requests
import time
import json
import re
import os
from requests.exceptions import Timeout, RequestException


In [None]:
ollama_node = "arctrddgxa004" # TODO: Change this variable to the node where Ollama is running
BASE_URL = f"http://{ollama_node}:11434/api/chat"
model = "gemma3-optimized:27b" # TODO: Change this variable to the model you want to use
MAX_ATTEMPTS = 3  # Maximum attempts per participant

In [None]:


def parse_score_and_explanation(response_text):
    """Extract score and explanation from model response"""
    score_patterns = [
        r'score[:\s]*(\d+)',
        r'(\d+)[/\s]*(?:out of\s*)?10',
        r'(\d+)[/\s]*10',
        r'rating[:\s]*(\d+)',
        r'^(\d+)',  # Number at start of line
    ]
    
    score = None
    for pattern in score_patterns:
        match = re.search(pattern, response_text, re.IGNORECASE | re.MULTILINE)
        if match:
            potential_score = int(match.group(1))
            if 1 <= potential_score <= 10:
                score = potential_score
                break
    
    return score, response_text.strip()

def get_transcript_and_regenerate_assessment(participant_id, dataset_type, ollama_node, model):
    """Re-run the qualitative assessment for a participant"""
    print(f"  Re-generating assessment for {participant_id}...")
    
    # Load transcript (same logic as first script)
    id_transcript = os.path.join("/data/users4/xli/ai-psychiatrist/datasets/daic_woz_dataset/", 
                                f"{participant_id}_P", f"{participant_id}_TRANSCRIPT.csv")
    
    if not os.path.exists(id_transcript):
        print(f"  Transcript not found for {participant_id}")
        return None
    
    try:
        transcript_data = pd.read_csv(id_transcript, sep='\t')
        dialogue_txt = transcript_data.dropna(subset=['speaker', 'value'])
        dialogue_txt['dialogue'] = dialogue_txt['speaker'] + " : " + dialogue_txt['value']
        full_dialogue = "\n".join(dialogue_txt['dialogue'])
        
        system_prompt = "You are a psychiatrist. Your job is to assess and diagnose patients for depression. Be as objective, concise, and accurate as possible."
        qual_assessment = f"""The following is an interview transcript from a psychiatric assessment of a subject who is being screened for depression. Please note that 'Ellie' is the interviewer, and 'Participant' is the subject being assessed.
        {full_dialogue}
      Please:
      1. Provide an overall qualitative assessment of the subject's mental health. 
      2. Summarize social aspects that may influence the subject's mental health. (ex. familial relationships, frienship dynamics, work environment, etc. that are relevant to the subjects mental health)
      3. Summarize biological aspects that may influence the subject's mental health. (ex. famillial history of mental health issues, previous or pre-existing mental health issues, stress levels, etc. that are relevant to the subjects mental health)
      4. Identify potential risk factors the subject may be experiencing.
      5. Use exact quotes from the transcript to support your assessment for each tag.

      Output should be formatted as bullet points with headings for each section using stars. Example: **Tiredness** <explanation of tiredness>. Do not include any additional text outside the bullet points
      Please answer in this XML format with each tag on a new line, properly indented. Use straight quotes instead of curly quotes, and do not include any additional text outside the XML tags:

      <assessment>
        <!-- Summary of participant's overall mental health -->
       <exact_quotes>
        <!-- Quotes from the transcript that support the assessment -->
        </exact_quotes>
      </assessment>

      <social_factors>
        <!-- Summary of social influences on patient's health -->
        <exact_quotes>
      </social_factors>

      <biological_factors>
        <!-- Summary of biological influences on patient's health -->
       <exact_quotes>
        <!-- Quotes from the transcript that support the assessment -->
        </exact_quotes>
      </biological_factors>

      <risk_factors>
        <!-- Summary of potential risk factors -->
         <exact_quotes>
       <!-- Quotes from the transcript that support the assessment -->
       </exact_quotes>
      </risk_factors>
    """
        
        # Make API request
        BASE_URL = f"http://{ollama_node}:11434/api/chat"
        response = requests.post(
            BASE_URL,
            json={
                "model": model,
                "messages": [{"role": "system", "content": system_prompt},
                           {"role": "user", "content": qual_assessment}],
                "stream": False,
                "options": {
                    "temperature": 0.3,  # Slightly higher temperature for variation
                    "top_k": 20,
                    "top_p": 0.9
                }
            },
            timeout=300
        )
        
        if response.status_code == 200:
            return response.json()['message']['content']
        else:
            print(f"  API request failed: {response.status_code}")
            return None
            
    except Exception as e:
        print(f"  Error regenerating assessment: {e}")
        return None

def evaluate_assessment(assessment_text, ollama_node, model):
    """Evaluate a qualitative assessment and return scores"""
    BASE_URL = f"http://{ollama_node}:11434/api/chat"
    
    
    coherence_prompt = f"""Evaluate the following qualitative assessment output for COHERENCE only.

Score from 1-10 where higher scores indicate better performance:
- Coherence (1–10): Is the response logically consistent?

Format your response as:
Score: [your score]

---
{assessment_text}
---"""

    completeness_prompt = f"""Evaluate the following qualitative assessment output for COMPLETENESS only.

Score from 1-10 where higher scores indicate better performance:
- Completeness (1–10): Does the assessment cover all relevant symptoms, severities, duration/frequency?

Format your response as:
Score: [your score]

---
{assessment_text}
---"""

    specificity_prompt = f"""Evaluate the following qualitative assessment output for SPECIFICITY only.

Score from 1-10 where higher scores indicate better performance:
- Specificity (1–10): Does it avoid vague/generic statements like 'the patient seems depressed'?

Format your response as:
Score: [your score]

---
{assessment_text}
---"""

    accuracy_prompt = f"""Evaluate the following qualitative assessment output for ACCURACY only.

Score from 1-10 where higher scores indicate better performance:
- Accuracy (1–10): Are the signs/symptoms aligned with DSM-5 or PHQ-8? Are there factual inconsistencies?

Format your response as:
Score: [your score]

---
{assessment_text}
---"""

    collective_prompt = f"""Provide a brief collective explanation/summary of the overall quality of this qualitative assessment, considering all aspects: coherence, completeness, specificity, and accuracy together. Only use 2-3 sentences.

---
{assessment_text}
---"""

    try:
        result = {}
        
        # Evaluate each metric
        for metric, prompt in [('coherence', coherence_prompt), 
                              ('completeness', completeness_prompt),
                              ('specificity', specificity_prompt), 
                              ('accuracy', accuracy_prompt)]:
            
            response = requests.post(BASE_URL, json={
                "model": model,
                "messages": [{"role": "user", "content": prompt}],
                "stream": False,
                "options": {"temperature": 0, "top_k": 20, "top_p": 0.9}
            }, timeout=290)
            
            if response.status_code == 200:
                content = response.json()['message']['content']
                score, _ = parse_score_and_explanation(content)
                result[metric] = score
            else:
                result[metric] = None
            
            time.sleep(2)
        
        # Get collective explanation
        collective_response = requests.post(BASE_URL, json={
            "model": model,
            "messages": [{"role": "user", "content": collective_prompt}],
            "stream": False,
            "options": {"temperature": 0, "top_k": 20, "top_p": 0.9}
        }, timeout=290)
        
        if collective_response.status_code == 200:
            result['collective_explanation'] = collective_response.json()['message']['content']
        else:
            result['collective_explanation'] = 'API request failed'
            
        return result
        
    except Exception as e:
        print(f"  Error evaluating assessment: {e}")
        return None




# File paths
input_csv_path = "/data/users2/nblair7/analysis_results/qual_resultsfin.csv"  # Original assessments
output_csv_path = "/home/users/nblair7/ai-psychiatrist/eval_results_with_feedback.csv"

# Load dataset info for re-generation
train_path = pd.read_csv("/data/users4/xli/ai-psychiatrist/datasets/daic_woz_dataset/train_split_Depression_AVEC2017.csv")
dev_path = pd.read_csv("/data/users4/xli/ai-psychiatrist/datasets/daic_woz_dataset/dev_split_Depression_AVEC2017.csv")
id_train = train_path.iloc[:, 0].tolist()
id_dev = dev_path.iloc[:, 0].tolist()
dataset_mapping = {subj: 'train' for subj in id_train}
dataset_mapping.update({subj: 'dev' for subj in id_dev})

# Load the CSV file
print("Loading CSV file...")
df = pd.read_csv(input_csv_path)
print(f"Loaded {len(df)} participants")

results = []
processed_count = 0

# Check for existing results
if os.path.exists(output_csv_path):
    print(f"Found existing results file: {output_csv_path}")
    existing_results = pd.read_csv(output_csv_path)
    completed_subjects = set(existing_results['participant_id'].tolist())
    print(f"Already completed {len(completed_subjects)} subjects")
    df = df[~df['participant_id'].isin(completed_subjects)]
    print(f"Remaining subjects to process: {len(df)}")
    results = existing_results.to_dict('records')
else:
    print("No existing results found, starting fresh")

for index, row in df.iterrows():
    participant_id = row['participant_id']
    original_assessment = row['qualitative_assessment']
    dataset_type = dataset_mapping.get(participant_id, 'unknown')
    
    print(f"\n--- Processing {index + 1}/{len(df)}: {participant_id} ---")
    
    current_assessment = original_assessment
    best_scores = {'coherence': 0, 'completeness': 0, 'specificity': 0, 'accuracy': 0}
    best_result = None
    
    for attempt in range(MAX_ATTEMPTS):
        print(f"  Attempt {attempt + 1}/{MAX_ATTEMPTS}")
        
        # Evaluate current assessment
        evaluation = evaluate_assessment(current_assessment, OLLAMA_NODE, model)
        
        if evaluation is None:
            print(f"  Failed to evaluate assessment")
            continue
            
        # Check scores
        scores = [evaluation.get(metric) for metric in ['coherence', 'completeness', 'specificity', 'accuracy']]
        valid_scores = [s for s in scores if s is not None]
        
        if not valid_scores:
            print(f"  No valid scores obtained")
            continue
            
        min_score = min(valid_scores)
        avg_score = sum(valid_scores) / len(valid_scores)
        
        print(f"  Scores: C={evaluation.get('coherence')}, Comp={evaluation.get('completeness')}, "
              f"S={evaluation.get('specificity')}, A={evaluation.get('accuracy')} (min={min_score:.1f}, avg={avg_score:.1f})")
        
        # Store result
        result = {
            'participant_id': participant_id,
            'attempt': attempt + 1,
            'coherence': evaluation.get('coherence'),
            'completeness': evaluation.get('completeness'),
            'specificity': evaluation.get('specificity'),
            'accuracy': evaluation.get('accuracy'),
            'collective_explanation': evaluation.get('collective_explanation'),
            'qualitative_assessment': current_assessment,
            'min_score': min_score,
            'avg_score': avg_score
        }
        
        # Keep track of best result
        if best_result is None or avg_score > best_result['avg_score']:
            best_result = result.copy()
        
        # Check if we should continue (any score below threshold)
        if min_score >= FEEDBACK_THRESHOLD:
            print(f"  ✓ All scores >= {FEEDBACK_THRESHOLD}, stopping at attempt {attempt + 1}")
            break
        elif attempt < MAX_ATTEMPTS - 1:
            print(f"  ↻ Min score {min_score} < {FEEDBACK_THRESHOLD}, regenerating assessment...")
            
            # Re-generate assessment
            new_assessment = get_transcript_and_regenerate_assessment(
                participant_id, dataset_type, OLLAMA_NODE, model)
            
            if new_assessment:
                current_assessment = new_assessment
                time.sleep(3)  # Brief pause between regenerations
            else:
                print(f"  Failed to regenerate assessment, using previous version")
                break
        else:
            print(f"  Reached maximum attempts, using best result")
    
    # Use the best result obtained
    if best_result:
        results.append(best_result)
        processed_count += 1
        print(f"Completed participant {participant_id} - Final: min={best_result['min_score']:.1f}, "
              f"avg={best_result['avg_score']:.1f} (attempt {best_result['attempt']})")
    
    # Save progress periodically
    if len(results) % 5 == 0 or len(results) == 1:
        resultsdf = pd.DataFrame(results)
        resultsdf.to_csv(output_csv_path, index=False)
        print(f"Saved progress: {len(results)} results")
    
    time.sleep(2)

# Final save and summary
print(f"\nSummary of processing:")
print(f"Total subjects processed: {processed_count}")
print(f"Total results collected: {len(results)}")

if results:
    resultsdf = pd.DataFrame(results)
    resultsdf.to_csv(output_csv_path, index=False)
    print(f"Final save completed: {output_csv_path}")
    
    # Summary statistics
    print(f"\nFeedback loop statistics:")
    attempt_counts = resultsdf['attempt'].value_counts().sort_index()
    for attempt, count in attempt_counts.items():
        print(f"  Attempt {attempt}: {count} participants")
    
    avg_min_score = resultsdf['min_score'].mean()
    avg_avg_score = resultsdf['avg_score'].mean()
    print(f"  Average min score: {avg_min_score:.2f}")
    print(f"  Average avg score: {avg_avg_score:.2f}")
    
    improved_count = len(resultsdf[resultsdf['min_score'] >= FEEDBACK_THRESHOLD])
    print(f"  Participants with all scores >= {FEEDBACK_THRESHOLD}: {improved_count}/{len(resultsdf)}")
else:
    print("No results to save!")

Loading CSV file...
Loaded 142 participants
No existing results found, starting fresh

--- Processing 1/142: 303 ---
  Attempt 1/3
  Error evaluating assessment: HTTPConnectionPool(host='arctrddgxa003', port=11434): Max retries exceeded with url: /api/chat (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x72a64b6dc390>: Failed to establish a new connection: [Errno 111] Connection refused'))
  Failed to evaluate assessment
  Attempt 2/3
  Error evaluating assessment: HTTPConnectionPool(host='arctrddgxa003', port=11434): Max retries exceeded with url: /api/chat (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x72a64b139f10>: Failed to establish a new connection: [Errno 111] Connection refused'))
  Failed to evaluate assessment
  Attempt 3/3
  Error evaluating assessment: HTTPConnectionPool(host='arctrddgxa003', port=11434): Max retries exceeded with url: /api/chat (Caused by NewConnectionError('<urllib3.connection.HTTPConnection obj

KeyboardInterrupt: 