#### Evaluate model prediction against human annotations

In [3]:
import sys
sys.path.append("../")

%load_ext autoreload
%autoreload 2

In [13]:
from src.annotation_loader import Phase1Loader, Phase1OutputLoader
from src.utils import read_jsonlines, read_json 
import json 
from tqdm import tqdm 
import re
from collections import defaultdict
from pathlib import Path

In [5]:
loader = Phase1Loader()

In [6]:
total_friction_indices = 0 

all_conv_ids = loader.list_all_convo_ids()

In [None]:
num_frictions = [] 

for elem in tqdm(all_conv_ids):
	friction_turns = loader.get_annotation_data(elem)['friction_turns']
	num_frictions.append(len(friction_turns))

  0%|          | 0/200 [00:00<?, ?it/s]

100%|██████████| 200/200 [00:11<00:00, 17.57it/s]


In [None]:
total_turns = 0

for elem in all_conv_ids:
	length = int(elem.split('.')[0])
	total_turns += length

In [9]:
total_turns

7950

In [10]:
import numpy as np 

np.sum(num_frictions)

238

In [12]:
friction_turns

{1: [6], 2: [20], 3: [37]}

### Load Model results, define overlap metrics

In [14]:
# load model results 
#self_c = Phase1OutputLoader("/fs/clip-political/rupak/common_ground/experiments/phase1_experiments/phase1_outputs/friction_prediction_self_consistency/friction_detection_temp_0.01_gpt-4o_w_gpt_assist_w_self_explanation.jsonl")
results_path = "../data/model_outputs/friction_prediction_outputs"


gpt4o = Phase1OutputLoader(f"{results_path}/friction_detection_temp_0.01_gpt-4o_wo_gpt_assist.jsonl")
gpt4omini = Phase1OutputLoader(f"{results_path}/friction_detection_temp_0.01_gpt-4o-mini_wo_gpt_assist.jsonl")
gpt4omini_w_assist = Phase1OutputLoader(f"{results_path}/friction_detection_temp_0.01_gpt-4o-mini_w_gpt_assist.jsonl")
gpt4o_w_assist = Phase1OutputLoader(f"{results_path}/friction_detection_temp_0.01_gpt-4o_w_gpt_assist.jsonl")
llama8b = Phase1OutputLoader(f"{results_path}/friction_detection_temp_0.01_Llama-3.1-8B-Instruct_wo_gpt_assist.jsonl")
llama70b = Phase1OutputLoader(f"{results_path}/friction_detection_temp_0.01_Llama-3.1-70B-Instruct_wo_gpt_assist.jsonl")
llama8b_w_assist = Phase1OutputLoader(f"{results_path}/friction_detection_temp_0.01_Llama-3.1-8B-Instruct_w_gpt_assist.jsonl")
llama70b_w_assist = Phase1OutputLoader(f"{results_path}/friction_detection_temp_0.01_Llama-3.1-70B-Instruct_w_gpt_assist.jsonl")


In [None]:
def intervals_intersect(interval1, interval2):
	"""Check if two non contiguous intervals intersect."""
	return len(set(interval1).intersection(set(interval2))) > 0

def interval_jaccard(interval1, interval2):
	"""Computer the Jaccard similarity between two intervals."""

	intersection = len(set(interval1).intersection(set(interval2)))
	union = len(set(interval1).union(set(interval2)))

	return intersection / union if union > 0 else 0

def calculate_turn_overlap(human_response_turns, model_response_turns):
	# same as span overlap, but calculate turn overlap in an interval agnostic way 
	turn_score = 0 

	model_response_turns = [ [i for i in range(interval[0], interval[-1] + 1)] for interval in model_response_turns]

	# flatten the two lists
	all_human_turns = set([turn for interval in human_response_turns for turn in interval])
	all_model_turns = set([turn for interval in model_response_turns for turn in interval])

	# output the turn overlap
	turn_score = len(all_human_turns.intersection(all_model_turns))

	total_human_turns = len(all_human_turns)
	total_model_turns = len(all_model_turns)

	recall = turn_score / total_human_turns if total_human_turns > 0 else 0
	precision = turn_score / total_model_turns if total_model_turns > 0 else 0

	f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

	return {
		"overlap": turn_score,
		"total_human_intervals": total_human_turns,
		"total_model_intervals": total_model_turns,
		"precision": precision,
		"recall": recall,
		"f1": f1
	}


def calculate_span_overlap(human_response_intervals, model_response_intervals):
	# first, go over the human response intervals and match a unique model response interval based on the degree of overlap 
	# if found, mark it as matched and remove it from consideration for future matches
	span_score = 0 

	# fill up all the model intervals
	model_response_intervals = [ [i for i in range(interval[0], interval[-1] + 1)] for interval in model_response_intervals]

	# save a match as a tuple of indices 
	matches = []
	for human_interval_index, human_interval in enumerate(human_response_intervals):
		# calculate the interval with the max overlap 
		max_overlap = 0
		max_overlap_index = -1
		for model_interval_index, model_interval in enumerate(model_response_intervals):
			# first check if this model interval has already been matched
			if model_interval_index in [m[1] for m in matches]:
				continue

			overlap = interval_jaccard(human_interval, model_interval)
			if overlap > max_overlap:
				max_overlap = overlap
				max_overlap_index = model_interval_index

		if max_overlap > 0:
			matches.append((human_interval_index, max_overlap_index))
			span_score += max_overlap

	total_human_intervals = len(human_response_intervals)
	total_model_intervals = len(model_response_intervals)

	# now, recall is the span_score / total_human_intervals
	recall = span_score / total_human_intervals if total_human_intervals > 0 else 0
	precision = span_score / total_model_intervals if total_model_intervals > 0 else 0

	f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

	return {
		"overlap": span_score,
		"total_human_intervals": total_human_intervals,
		"total_model_intervals": total_model_intervals,
		"precision": precision,
		"recall": recall,
		"f1": f1
	}

	# Test cases for the calculate_span_overlap and calculate_overlap functions

def calculate_overlap(human_response_intervals, model_response_intervals):
	"""
	Calculate precision, recall, and F1 score based on the overlap between 
	human response intervals and model response intervals.
	
	Parameters:
	human_response_intervals (list of list): List of interval(s) representing human responses.
	model_response_intervals (list of list): List of interval(s) representing model responses.
	
	Returns:
	dict: A dictionary containing 'overlap', 'precision', 'recall', and 'f1' scores.
	"""
	
	overlap = 0

	for model_interval in model_response_intervals:
		# fill up the model interval - 

		model_interval_filled = [ i for i in range(model_interval[0], model_interval[-1] + 1)]
		for human_interval in human_response_intervals:
			if intervals_intersect(model_interval_filled, human_interval):
				overlap += 1
				break

	total_human_intervals = len(human_response_intervals)
	total_model_intervals = len(model_response_intervals)
	
	recall = overlap / total_human_intervals if total_human_intervals > 0 else 0
	precision = overlap / total_model_intervals if total_model_intervals > 0 else 0

	f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

	return {
		"overlap": overlap,
		"total_human_intervals": total_human_intervals,
		"total_model_intervals": total_model_intervals,
		"precision": precision,
		"recall": recall,
		"f1": f1
	}


### Small snippet showing how the overlap metrics work 

In [18]:
# Define some test intervals
human_response_intervals = [[1, 2, 3], [5, 6, 7], [10, 11, 12]]
model_response_intervals = [[2, 4], [6, 8], [11, 13]]

# Test calculate_span_overlap function
span_overlap_result = calculate_span_overlap(human_response_intervals, model_response_intervals)
print("Span Overlap Result:", span_overlap_result)

# Test calculate_overlap function
overlap_result = calculate_overlap(human_response_intervals, model_response_intervals)
print("Overlap Result:", overlap_result)

# Test calculate_turn_overlap function
turn_overlap_result = calculate_turn_overlap(human_response_intervals, model_response_intervals)
print("Turn Overlap Result:", turn_overlap_result)


Span Overlap Result: {'overlap': 1.5, 'total_human_intervals': 3, 'total_model_intervals': 3, 'precision': 0.5, 'recall': 0.5, 'f1': 0.5}
Overlap Result: {'overlap': 3, 'total_human_intervals': 3, 'total_model_intervals': 3, 'precision': 1.0, 'recall': 1.0, 'f1': 1.0}
Turn Overlap Result: {'overlap': 6, 'total_human_intervals': 9, 'total_model_intervals': 9, 'precision': 0.6666666666666666, 'recall': 0.6666666666666666, 'f1': 0.6666666666666666}


In [19]:
# # Some Additional test cases
# human_response_intervals_2 = [[1, 2], [4, 8], [9, 10]]
# model_response_intervals_2 = [[2, 3], [5, 6], [7, 8], [9, 10]]

# span_overlap_result_2 = calculate_span_overlap(human_response_intervals_2, model_response_intervals_2)
# print("Span Overlap Result 2:", span_overlap_result_2)

# overlap_result_2 = calculate_overlap(human_response_intervals_2, model_response_intervals_2)
# print("Overlap Result 2:", overlap_result_2)

# turn_overlap_result_2 = calculate_turn_overlap(human_response_intervals_2, model_response_intervals_2)
# print("Turn Overlap Result 2:", turn_overlap_result_2)


In [20]:
# # Edge case: No overlap
# human_response_intervals_3 = [[1, 2], [4, 5]]
# model_response_intervals_3 = [[6, 7], [8, 9]]

# span_overlap_result_3 = calculate_span_overlap(human_response_intervals_3, model_response_intervals_3)
# print("Span Overlap Result 3:", span_overlap_result_3)

# overlap_result_3 = calculate_overlap(human_response_intervals_3, model_response_intervals_3)
# print("Overlap Result 3:", overlap_result_3)

# turn_overlap_result_3 = calculate_turn_overlap(human_response_intervals_3, model_response_intervals_3)
# print("Turn Overlap Result 3:", turn_overlap_result_3)

In [None]:
batches = [str(i) for i in range(1, 11)]

all_convos = [] 
for batch in batches: 
	loader = Phase1Loader()
	convos = loader.get_convos_in_batch(batch)
	all_convos.extend(convos)

In [22]:
# check if the leader is working
loader.get_annotation_data("57.1")['friction_turns']

{1: [17], 2: [38, 39]}

In [None]:
# load the selected convos 
count = 0 
# count the total number of friction turns found by humans 

for batch in batches:
	loader = Phase1Loader()
	convos = loader.get_convos_in_batch(batch)
	for conv in convos:
		annotation_data = loader.get_annotation_data(conv)
		friction_turns = annotation_data['friction_turns']
		count += len(friction_turns)

In [24]:
count

238

### Methods to extract the spans from model outputs 

In [None]:
def extract_friction_details_corrected_new(text):
	# Initialize pattern to capture various parts of the input
	pattern = re.compile(
		r'friction_present:\s*(true|false)\s*|'          # Match "friction_present: true/false"
		r'"friction(\d+)":\s*\[([^\]]+)\]|'              # Match "frictionN": [x, y]
		r'"explanation(\d+)":\s*"((?:[^"]|\\")*?)"',     # Match "explanationN": including any " escaped with \"
		re.DOTALL
	)

	extracted_data = {}

	for match in pattern.finditer(text):
		if match.group(1) is not None:
			# Extracting the presence of friction
			extracted_data["friction_present"] = match.group(1).lower() == "true"
		elif match.group(2) is not None:
			# Extracting the friction index and values
			index = match.group(2)
			# Handle each item as a string rather than trying to convert to int immediately
			items = [x.strip().strip('"') for x in match.group(3).split(",")]
			extracted_data[f"friction{index}"] = items
		elif match.group(4) is not None:
			# Extracting explanations
			index = match.group(4)
			explanation = match.group(5).replace('\\"', '"')  # Correct escaped quotes
			extracted_data[f"explanation{index}"] = explanation

	# Post-process to convert the friction values to integers where applicable
	to_remove_keys = []
	for key, value in extracted_data.items():
		if re.match(r"friction(\d+)", key):
			# Try converting items to integers or extract from strings
			try:
				extracted_data[key] = [int(item) for item in value]
			except ValueError:
				# Handle cases where items are strings with "Turn X"
				if len(value[0].split()) == 1:
					# Discard this key and value if conversion is not applicable
					to_remove_keys.append(key)
					continue

				# Extract integer from strings like "Turn X"
				extracted_data[key] = [int(item.split(" ")[1]) for item in value]

	for key in to_remove_keys:
		extracted_data.pop(key)

	return extracted_data

def extract_friction_details_corrected(text):
	# Pattern to match the JSON structure specifically for the desired keys
	pattern = re.compile(
		r'"friction_present":\s*(true|false)|'           # Match "friction_present": true/false
		r'"friction(\d+)":\s*\[([^\]]+)\]|'              # Match "frictionN": [x, y]
		r'"explanation(\d+)":\s*"((?:[^"]|\\")*?)"',     # Match "explanationN": including any " escaped with \"
		re.DOTALL
	)

	extracted_data = {}

	for match in pattern.finditer(text):
		if match.group(1) is not None:
			# Extracting the presence of friction
			extracted_data["friction_present"] = match.group(1) == "true"
		elif match.group(2) is not None:
			# Extracting the friction index and values
			index = match.group(2)
			# Handle each item as a string rather than trying to convert to int
			items = [x.strip().strip('"') for x in match.group(3).split(",")]
			extracted_data[f"friction{index}"] = items
		elif match.group(4) is not None:
			# Extracting explanations
			index = match.group(4)
			explanation = match.group(5).replace('\\"', '"')  # Correct escaped quotes
			extracted_data[f"explanation{index}"] = explanation

	# Do some post processing to convert the friction values to integers
	# sometimes, the model will say None, and sometimes it wil say Turn X instead of X 

	to_remove_keys = []
	for key, value in extracted_data.items():
		# if key is of the format "frictionN", then we need to convert the values to integers
		if re.match(r"friction(\d+)", key):
			# check if all items can be converted to integers and if so, convert them
			try : 
				extracted_data[key] = [int(item) for item in value]
			except ValueError:
				# if it's a single item, then it can be converted to an integer and we can move on 
				if len(value[0].split()) == 1: 
					# discard this key and value 
					to_remove_keys.append(key)
					continue
				
				# if it's two items, then we need to extract the integer from the string
				extracted_data[key] = [int(item.split(" ")[1]) for item in value]

	for key in to_remove_keys:
		extracted_data.pop(key)

	return extracted_data



def get_macro_metrics(convo_ids: list[str],
					model_outputs: Phase1OutputLoader, 
					parsing_function: callable, 
					interval_overlap_function: callable) -> dict: 
	
	overlap = 0 
	total_human_intervals = 0
	total_model_intervals = 0

	for convo_id in convo_ids:
		#print(convo_id)

		# load the human annotations 
		human_annotations = loader.get_annotation_data(convo_id)
		friction_turns = human_annotations["friction_turns"]

		# if no friction turns are present, skip the conversation
		if len(friction_turns) == 0:
			#continue
			pass 

		human_response_intervals = [interval for order, interval in friction_turns.items()]

		
		model_response = model_outputs.get_output_of_convo(convo_id)
		model_response_dict = parsing_function(model_response["response"])

		if "friction_present" not in model_response_dict:
			print("Something has gone wrong")

		# if keys start with "friction", then it is a friction interval
		# make all intervals ints 
		model_response_intervals = [interval for order, interval in model_response_dict.items() if re.match(r"friction(\d+)", order)] 

		metrics = interval_overlap_function(human_response_intervals, model_response_intervals)
		
		overlap += metrics["overlap"]
		total_human_intervals += metrics["total_human_intervals"]
		total_model_intervals += metrics["total_model_intervals"]

	results =  {
		"overlap": overlap,
		"total_human_intervals": total_human_intervals,
		"total_model_intervals": total_model_intervals,
	}

	precision = overlap / total_model_intervals if total_model_intervals > 0 else 0
	recall = overlap / total_human_intervals if total_human_intervals > 0 else 0
	f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

	results["precision"] = round(precision*100, 4)
	results["recall"] = round(recall*100, 4)
	results["f1"] = round(f1*100, 4)

	return results


def get_micro_metrics(convo_ids: list[str],
					model_outputs: Phase1OutputLoader, 
					parsing_function: callable, 
					interval_overlap_function: callable) -> dict: 
	
	precision = 0
	recall = 0
	f1 = 0

	for convo_id in convo_ids:
		#print(convo_id)

		# load the human annotations 
		human_annotations = loader.get_annotation_data(convo_id)
		friction_turns = human_annotations["friction_turns"]

		# if no friction turns are present, skip the conversation
		if len(friction_turns) == 0:
			#continue
			pass 

		human_response_intervals = [interval for order, interval in friction_turns.items()]

		
		model_response = model_outputs.get_output_of_convo(convo_id)
		model_response_dict = parsing_function(model_response["response"])

		if "friction_present" not in model_response_dict:
			print("Something has gone wrong")

		# if keys start with "friction", then it is a friction interval
		# make all intervals ints 
		model_response_intervals = [interval for order, interval in model_response_dict.items() if re.match(r"friction(\d+)", order)] 

		metrics = interval_overlap_function(human_response_intervals, model_response_intervals)
		
		precision += metrics["precision"]
		recall += metrics["recall"]
		f1 += metrics["f1"]

	precision = precision / len(convo_ids)
	recall = recall / len(convo_ids)
	f1 = f1 / len(convo_ids)

	results = {}
	results["precision"] = round(precision*100, 4)
	results["recall"] = round(recall*100, 4)
	results["f1"] = round(f1*100, 4)

	return results


In [58]:
model_map = {
	"gpt4o": gpt4o,
	"gpt4o_w_explanation": gpt4o_w_assist,
	"gpt4o_mini": gpt4omini,
	"gpt4o_mini_w_explanation": gpt4omini_w_assist,
	"llama_8b": llama8b,
	"llama_8b_w_explanation": llama8b_w_assist, 
	"llama_70b": llama70b,
	"llama_70b_w_explanation": llama70b_w_assist
}

In [59]:
import pandas as pd

for metric_function in [get_micro_metrics, get_macro_metrics]:	
	print(f"Metric Function: {metric_function.__name__}")
	for interval_matching_function in [calculate_overlap,  calculate_span_overlap]: 
		print(f"Interval Matching Function: {interval_matching_function.__name__}")
		for model_name, model_output in model_map.items():
			
			results = metric_function(all_convos, model_output, extract_friction_details_corrected, interval_matching_function)
			precision, recall, f1 = results["precision"], results["recall"], results["f1"]
			print(f"Model: {model_name}: Precision: {precision}, Recall: {recall}, F1: {f1}")
		
		print("\n")
	print("\n")
	

Metric Function: get_micro_metrics
Interval Matching Function: calculate_overlap
Model: gpt4o: Precision: 31.5, Recall: 43.6937, F1: 34.011
Model: gpt4o_w_explanation: Precision: 31.6333, Recall: 37.4568, F1: 32.2195
Model: gpt4o_mini: Precision: 32.75, Recall: 27.8585, F1: 28.0058
Model: gpt4o_mini_w_explanation: Precision: 28.5411, Recall: 28.6689, F1: 26.5134
Model: llama_8b: Precision: 16.7171, Recall: 47.2779, F1: 22.531
Model: llama_8b_w_explanation: Precision: 15.977, Recall: 46.3289, F1: 21.7343
Model: llama_70b: Precision: 21.6966, Recall: 48.0949, F1: 27.9698
Model: llama_70b_w_explanation: Precision: 16.7179, Recall: 39.8273, F1: 22.0649


Interval Matching Function: calculate_span_overlap
Model: gpt4o: Precision: 13.4995, Recall: 18.7378, F1: 14.6067
Model: gpt4o_w_explanation: Precision: 13.5358, Recall: 16.5915, F1: 14.003
Model: gpt4o_mini: Precision: 13.6653, Recall: 12.3207, F1: 12.1031
Model: gpt4o_mini_w_explanation: Precision: 13.6279, Recall: 14.1069, F1: 12.8093
M

### Rewrite the macro and micro averaging code to work with DistilRoberta outputs

In [50]:
def get_macro_metrics_dr(convo_ids: list[str],
					model_outputs: dict,
					interval_overlap_function: callable) -> dict: 
	
	overlap = 0 
	total_human_intervals = 0
	total_model_intervals = 0

	for convo_id in convo_ids:
		#print(convo_id)

		# load the human annotations 
		human_annotations = loader.get_annotation_data(convo_id)
		model_response_intervals = model_outputs[convo_id]
		friction_turns = human_annotations["friction_turns"]

		# if no friction turns are present, skip the conversation
		if len(friction_turns) == 0:
			#continue
			pass 

		human_response_intervals = [interval for order, interval in friction_turns.items()]
		# model_response = model_outputs.get_output_of_convo(convo_id)
		# model_response_dict = parsing_function(model_response["response"])

		# if "friction_present" not in model_response_dict:
		# 	print("Something has gone wrong")

		# # if keys start with "friction", then it is a friction interval
		# # make all intervals ints 
		# model_response_intervals = [interval for order, interval in model_response_dict.items() if re.match(r"friction(\d+)", order)] 

		metrics = interval_overlap_function(human_response_intervals, model_response_intervals)
		
		overlap += metrics["overlap"]
		total_human_intervals += metrics["total_human_intervals"]
		total_model_intervals += metrics["total_model_intervals"]

	results =  {
		"overlap": overlap,
		"total_human_intervals": total_human_intervals,
		"total_model_intervals": total_model_intervals,
	}

	precision = overlap / total_model_intervals if total_model_intervals > 0 else 0
	recall = overlap / total_human_intervals if total_human_intervals > 0 else 0
	f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

	results["precision"] = round(precision*100, 4)
	results["recall"] = round(recall*100, 4)
	results["f1"] = round(f1*100, 4)

	return results


def get_micro_metrics_dr(convo_ids: list[str],
					model_outputs: dict,
					interval_overlap_function: callable) -> dict: 
	
	precision = 0
	recall = 0
	f1 = 0

	for convo_id in convo_ids:
		#print(convo_id)

		# load the human annotations 
		human_annotations = loader.get_annotation_data(convo_id)
		model_response_intervals = model_outputs[convo_id]
		friction_turns = human_annotations["friction_turns"]

		# if no friction turns are present, skip the conversation
		if len(friction_turns) == 0:
			#continue
			pass 

		human_response_intervals = [interval for order, interval in friction_turns.items()]

		metrics = interval_overlap_function(human_response_intervals, model_response_intervals)
		
		precision += metrics["precision"]
		recall += metrics["recall"]
		f1 += metrics["f1"]

	precision = precision / len(convo_ids)
	recall = recall / len(convo_ids)
	f1 = f1 / len(convo_ids)

	results = {}
	results["precision"] = round(precision*100, 4)
	results["recall"] = round(recall*100, 4)
	results["f1"] = round(f1*100, 4)

	return results

#### Finally, calculate results for distilroberta with context 3 and context 5



In [55]:
# in case of distilroberta, we are doing k-fold cross validation, so we need to load the batch to test map
batch_test_map = read_json("../data/model_outputs/distilroberta_outputs/batch_to_test_map.json")

import pandas as pd

def find_intervals_by_conv_id(conv_id, df):
	# Load the CSV file into a DataFrame
	# first, turn all conv_id to strings
	df['conv_id'] = df['conv_id'].astype(str)

	# Filter rows by the given conv_id
	filtered_df = df[df['conv_id'] == conv_id]

	# if that conv_id is not found, raise an error 
	if filtered_df.empty:
		# try to see if the float() of the conv_id is in the dataframe
		try:
			filtered_df = df[df['conv_id'] == str(int(float(conv_id)))]
		except:
			raise ValueError(f"Conversation ID {conv_id} not found in the DataFrame.")
	
	# Sort the filtered DataFrame by turn_id
	sorted_df = filtered_df.sort_values(by='turn_id')
	
	# Initialize a list to store the intervals
	intervals = []
	
	# Initialize a list to collect turn_ids for the current interval
	current_interval = []

	# Iterate over each row in the sorted DataFrame
	for _, row in sorted_df.iterrows():
		turn_id = row['turn_id']
		label = row['prediction']

		if label == 1:
			current_interval.append(turn_id)  # Add to current interval
		else:
			if current_interval:
				intervals.append(current_interval)  # Save the interval
				current_interval = []  # Clear for next interval

	# Check if there's an unfinished interval
	if current_interval:
		intervals.append(current_interval)

	return intervals

In [61]:
def compute_model_metrics(context, batch_test_map, metric_function, model_name="distilroberta"):
	"""
	Input:
	context: 3 or 5 
	batch_test_map: dict mapping batch index to list of conversation ids
	metric_function: micro or macro averaging 
	
	Output: 
	
	"""
	import pandas as pd
	import numpy as np
	from src.utils import read_jsonlines
	
	# Initialize metric lists
	overlap_precisions = []
	overlap_recalls = []
	overlap_f1s = []
	
	span_overlap_precisions = []
	span_overlap_recalls = [] 
	span_overlap_f1s = []
	
	print(f"Metric Function: {metric_function.__name__}")
		
	for index, convo_id_list in tqdm(enumerate(batch_test_map.values())):
		path_to_data_file = f"../data/model_outputs/{model_name}_outputs/context_{context}/split_{index}/test.csv"
		path_to_predictions = f"../data/model_outputs/{model_name}_outputs/context_{context}/split_{index}/predict_results_None.txt"
		
		df = pd.read_csv(path_to_data_file)
		predictions = read_jsonlines(path_to_predictions)
		preds = [p["prediction"] for p in predictions]
		df["prediction"] = preds

		model_output_dict = {}

		for conv_id in convo_id_list: 
			intervals = find_intervals_by_conv_id(conv_id, df)
			model_output_dict[conv_id] = intervals

		for interval_matching_function in [calculate_overlap, calculate_span_overlap]:
			#print(f"Interval Matching Function: {interval_matching_function.__name__}")
			results = metric_function(convo_id_list, model_output_dict, interval_matching_function)
			precision, recall, f1 = results["precision"], results["recall"], results["f1"]
			#print(f"Precision: {precision}, Recall: {recall}, F1: {f1}")

			if interval_matching_function == calculate_overlap:
				overlap_precisions.append(precision)
				overlap_recalls.append(recall)
				overlap_f1s.append(f1)

			elif interval_matching_function == calculate_span_overlap:
				span_overlap_precisions.append(precision)
				span_overlap_recalls.append(recall)
				span_overlap_f1s.append(f1)
	
				#print("\n")
			#print("\n")
	
	# Compute means
	results = {
		"overlap_precision_mean": np.mean(overlap_precisions),
		"overlap_recall_mean": np.mean(overlap_recalls),
		"overlap_f1_mean": np.mean(overlap_f1s),
		"span_overlap_precision_mean": np.mean(span_overlap_precisions),
		"span_overlap_recall_mean": np.mean(span_overlap_recalls),
		"span_overlap_f1_mean": np.mean(span_overlap_f1s)
	}
	
	return results

In [62]:
# get macro-averaged results for distilroberta with context 3 and 5

macro_3 = compute_model_metrics(3, batch_test_map, get_macro_metrics_dr, model_name="distilroberta")
macro_5 = compute_model_metrics(5, batch_test_map, get_macro_metrics_dr, model_name="distilroberta")


Metric Function: get_macro_metrics_dr


0it [00:00, ?it/s]

5it [00:31,  6.35s/it]


Metric Function: get_macro_metrics_dr


5it [00:31,  6.32s/it]


In [63]:
# compute micro-averaged results for distilroberta with context 3 and 5

micro_3 = compute_model_metrics(3, batch_test_map, get_micro_metrics_dr, model_name="distilroberta")
micro_5 = compute_model_metrics(5, batch_test_map, get_micro_metrics_dr, model_name="distilroberta")

Metric Function: get_micro_metrics_dr


5it [00:31,  6.32s/it]


Metric Function: get_micro_metrics_dr


5it [00:31,  6.36s/it]


In [65]:
micro_3

{'overlap_precision_mean': 14.5155,
 'overlap_recall_mean': 26.32228,
 'overlap_f1_mean': 16.36644,
 'span_overlap_precision_mean': 5.93104,
 'span_overlap_recall_mean': 10.12008,
 'span_overlap_f1_mean': 6.56902}

In [66]:
micro_5

{'overlap_precision_mean': 13.0267,
 'overlap_recall_mean': 26.13882,
 'overlap_f1_mean': 15.5173,
 'span_overlap_precision_mean': 6.3838,
 'span_overlap_recall_mean': 11.84358,
 'span_overlap_f1_mean': 7.431660000000001}