# Schema Linking Accuracy Metric (SLAM):

### Imports

In [116]:
import pandas as pd
import ast
import json

## Creating DFs

### Function to import JSON

In [123]:
def import_json_file(file_path):
    # Load the JSON data from the file
    with open(file_path, 'r') as json_file:
        json_data = json.load(json_file)

    # Create a DataFrame from the extracted data
    df_dev_set = pd.DataFrame(json_data, columns=["question_id", "db_id", "tables"])

    #df_dev_set.head()
    return df_dev_set

### Function to import CSV

In [121]:
# Function to import a CSV file into a pandas DataFrame with the given schema
def import_csv_file(file_path):
    # Define a custom converter to parse the string representation of lists
    def parse_list(x):
        try:
            return ast.literal_eval(x)
        except (ValueError, SyntaxError):
            return []

    # Specify the column data types
    dtype_dict = {
        'question_id': int,
        'tables': str,
        'gen_tables': str,
        'total_tables': int
    }

    # Specify the converters for list columns
    converters = {
        'tables': parse_list,
        'gen_tables': parse_list
    }

    # Read the CSV file into a DataFrame
    df = pd.read_csv(file_path, dtype=dtype_dict, converters=converters)
    return df

### Function to compare tables from dev and gen

In [122]:
def compare_lists(ground_truth, predicted):
    # Initialize variables for TP, TN, FP, and FN
    true_positives = 0
    true_negatives = 0
    false_positives = 0
    false_negatives = 0

    true_positives = len(set(ground_truth) & set(predicted))
    false_positives = len(set(predicted) - set(ground_truth))
    false_negatives = len(set(ground_truth) - set(predicted))
    true_negatives = 0  # Not applicable for this scenario

    return true_positives, true_negatives, false_positives, false_negatives



### Processing

In [124]:

df1 = import_json_file('dev/dev.json') 
#df1.head()
df2 = import_csv_file('gen_tables.csv')
#df2.head()
merged_df = df1.merge(df2, on='question_id', how='inner')
merged_df = merged_df.rename(columns={'tables_x': 'tables_dev', 'tables_y': 'tables_csv'})

merged_df.head()

final_df = pd.DataFrame(columns=['question_id','db_id','tables_dev','tables_csv','gen_tables','total_tables','tp','tn','fp','fn','precision','recall','f1_score'])

for index, row in enumerate(merged_df.iterrows()):
    row = row[1]
    #row2 = row2[1]
    tp, tn, fp, fn = compare_lists(row['tables_dev'], row['gen_tables'])


    # Calculate precision, recall, and F1 score for each row
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    if precision == 0 and recall == 0:
        f1_score = 0.0
    else:
        f1_score = 2 * (precision * recall) / (precision + recall)
    
    #print('question_id: ',row['question_id'],'|| gen_tables:',row['gen_tables'],'|| true_tables:',row['tables_dev'])
    #print("True Positives:", tp)
    #print("True Negatives:", tn)
    #print("False Positives:", fp)
    #print("False Negatives:", fn)
    #print("Precision:", precision)
    #print("Recall:", recall)
    #print("F1-Score:", f1_score)
    #print('-----------------------------------')

    metrics_dict = {
        'question_id': row['question_id'],
        'db_id': row['db_id'],
        'tables_dev': row['tables_dev'],
        'tables_csv': row['tables_csv'],
        'gen_tables': row['gen_tables'],
        'total_tables': row['total_tables'],
        'tp': tp,
        'tn': tn,
        'fp': fp,
        'fn': fn,
        'precision': precision,
        'recall': recall,
        'f1_score': f1_score
    }

    final_df = pd.concat([final_df, pd.DataFrame([metrics_dict])], ignore_index=True)

final_df.to_csv('metrics_result.csv', index=False)
final_df.head()

  df = pd.read_csv(file_path, dtype=dtype_dict, converters=converters)
  df = pd.read_csv(file_path, dtype=dtype_dict, converters=converters)
  final_df = pd.concat([final_df, pd.DataFrame([metrics_dict])], ignore_index=True)


Unnamed: 0,question_id,db_id,tables_dev,tables_csv,gen_tables,total_tables,tp,tn,fp,fn,precision,recall,f1_score
0,1,california_schools,[frpm],[frpm],"[frpm, Grapes, Lemon]",12,1,0,2,0,0.333333,1.0,0.5
1,2,california_schools,"[frpm, schools]","[Orange, Mango, Strawberry]","[Apple, Blueberry, Pineapple]",17,0,0,3,2,0.0,0.0,0.0
2,3,california_schools,"[frpm, schools]","[Grapes, Kiwi, Peach]","[Lemon, Banana, Cherry]",25,0,0,3,2,0.0,0.0,0.0
3,4,california_schools,"[frpm, schools]","[Blueberry, Pear, Watermelon]","[Mango, Strawberry, Apple]",8,0,0,3,2,0.0,0.0,0.0
4,5,california_schools,"[satscores, schools]","[Pineapple, Lime, Raspberry]","[Kiwi, Grapes, Peach]",36,0,0,3,2,0.0,0.0,0.0
