In [1]:
import os
import re
import pickle
import numpy as np
import pandas as pd
from dotenv import dotenv_values
from langchain import PromptTemplate, LLMChain, OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.schema import SystemMessage, HumanMessage, AIMessage

In [2]:
config = dotenv_values("../.env")
# Prefer the key from ../.env, fall back to an already-exported environment variable,
# and fail with a clear message instead of a KeyError (key absent) or a TypeError
# (key present without a value, which dotenv_values returns as None).
OPENAI_API_KEY = config.get("OPENAI_API_KEY") or os.environ.get("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    raise RuntimeError("OPENAI_API_KEY not found in ../.env or in the environment")
os.environ['OPENAI_API_KEY'] = OPENAI_API_KEY

In [3]:
# Freebase relation id -> short human-readable text shown to the model.
# Insertion order matters: it fixes the order of the label list in the prompts
# and the class order used by the evaluation cells.
labels_to_text = {
    'time.event.locations': 'time event locations',
    'music.artist.album': 'music artist album',
    'sports.sports_team.sport': 'sports team',
    'baseball.baseball_team.league': 'baseball team league',
    'tv.tv_program.country_of_origin': 'tv program origin country',
    'music.album.artist': 'music album artist',
    'sports.sports_team.location': 'sports team location',
    'time.event.instance_of_recurring_event': 'instance of recurring event',
    'aviation.airline.hubs': 'airline hubs',
    'sports.sports_championship_event.champion': 'sports championship event',
    'sports.sports_facility.teams': 'sports facility teams',
    'baseball.baseball_player.position_s': 'baseball player positions',
    'sports.sports_league.teams-sports.sports_league_participation.team': 'sports league participation team',
    'tv.tv_network.programs-tv.tv_network_duration.program': 'tv duration program',
    'sports.sports_league_season.league': 'sports season league',
    'olympics.olympic_athlete.country-olympics.olympic_athlete_affiliation.country': 'olympic athlete affiliation country',
    'american_football.football_player.position_s': 'american football player position',
    'music.composer.compositions': 'music composer compositions',
    'meteorology.tropical_cyclone.tropical_cyclone_season': 'tropical cyclone season',
    'cvg.computer_videogame.developer': 'cvg videogame developer',
    'tv.tv_character.appeared_in_tv_program-tv.regular_tv_appearance.actor': 'tv actor appearance',
    'cvg.computer_videogame.publisher': 'cvg videogame publisher',
    'soccer.football_player.position_s': 'football player position',
    'tv.tv_program.original_network-tv.tv_network_duration.network': 'tv duration network',
    'music.composition.composer': 'music composition composer',
    'ice_hockey.hockey_player.hockey_position': 'ice hockey player position',
    'book.author.works_written': 'authors written books',
    'film.film.genre': 'film genre',
    'film.film.directed_by': 'film directed by',
    'film.film.produced_by': 'film produced by',
    'film.film.language': 'film language',
    'broadcast.broadcast.area_served': 'broadcast area served',
    'award.award_category.category_of': 'award category',
    'location.location.nearby_airports': 'location of nearby airports',
    'location.country.official_language': 'country official language',
}

In [4]:
# Inverse mapping (prompt text -> Freebase relation id), used to parse model answers.
# Derived from labels_to_text instead of being maintained by hand, so the two tables
# cannot drift apart; the insertion order is the same as in labels_to_text.
text_to_label = {text: label for label, text in labels_to_text.items()}
# Each prompt text must identify exactly one relation, otherwise parsing is ambiguous.
assert len(text_to_label) == len(labels_to_text), "label texts in labels_to_text must be unique"

In [5]:
# Every Freebase relation id in the label space, in labels_to_text order.
all_labels = [*labels_to_text]
print(all_labels)
len(all_labels)

['time.event.locations', 'music.artist.album', 'sports.sports_team.sport', 'baseball.baseball_team.league', 'tv.tv_program.country_of_origin', 'music.album.artist', 'sports.sports_team.location', 'time.event.instance_of_recurring_event', 'aviation.airline.hubs', 'sports.sports_championship_event.champion', 'sports.sports_facility.teams', 'baseball.baseball_player.position_s', 'sports.sports_league.teams-sports.sports_league_participation.team', 'tv.tv_network.programs-tv.tv_network_duration.program', 'sports.sports_league_season.league', 'olympics.olympic_athlete.country-olympics.olympic_athlete_affiliation.country', 'american_football.football_player.position_s', 'music.composer.compositions', 'meteorology.tropical_cyclone.tropical_cyclone_season', 'cvg.computer_videogame.developer', 'tv.tv_character.appeared_in_tv_program-tv.regular_tv_appearance.actor', 'cvg.computer_videogame.publisher', 'soccer.football_player.position_s', 'tv.tv_program.original_network-tv.tv_network_duration.net

35

## Load test (and training) set

In [6]:
# NOTE: pickle.load can execute arbitrary code -- only load these files from a trusted source.
with open('wiki-cpa-train-column.pkl', "rb") as f:
    train = pickle.load(f)
with open('wiki-cpa-test-column.pkl', "rb") as f:
    test = pickle.load(f)

# Records are indexed positionally: [2] is the text sent to the model and [3] is the
# iterable of gold Freebase relation ids.
examples = [example[2] for example in test]
labels = [example[3] for example in test]

train_examples = [example[2] for example in train]
# One list of label texts per training record, index-aligned with train_examples.
# Flattening over the labels would shift every index after a multi-label record,
# so few-shot demonstrations would pair one record's column with another's labels.
train_labels = [[labels_to_text[label] for label in example[3]] for example in train]
assert len(train_labels) == len(train_examples)


In [8]:
# Number of training records (the pool that few-shot demonstrations are sampled from).
len(train)

18340

In [9]:
# Number of test records; each prompt cell produces one answer per record.
len(test)

535

In [8]:
# Comma-separated label texts inserted into every role prompt.
labels_joined = ", ".join(labels_to_text.values())
labels_joined

'time event locations, music artist album, sports team, baseball team league, tv program origin country, music album artist, sports team location, instance of recurring event, airline hubs, sports championship event, sports facility teams, baseball player positions, sports league participation team, tv duration program, sports season league, olympic athlete affiliation country, american football player position, music composer compositions, tropical cyclone season, cvg videogame developer, tv actor appearance, cvg videogame publisher, football player position, tv duration network, music composition composer, ice hockey player position, authors written books, film genre, film directed by, film produced by, film language, broadcast area served, award category, location of nearby airports, country official language'

In [9]:
model_name = 'gpt-3.5-turbo-1106'
# temperature=0 to minimize sampling randomness; max_tokens is deliberately left unset.
chat = ChatOpenAI(
    openai_api_key=OPENAI_API_KEY,
    model=model_name,
    temperature=0,
)

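## Choose setup: zero-shot, one-shot or five-shot

Run exactly one of the prompt cells below. Each one overwrites `preds`, `nr` and `prompt_name`, which the save and evaluation cells read.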

CPA (Column Property Annotation) — column input format


ZERO-SHOT (no demonstrations in the prompt)

In [132]:
# Zero-shot, prompt "r": role description with the label set, no demonstrations.
# Overwrites preds; nr and prompt_name determine the file name preds are saved under.
nr = "zero"
prompt_name = "r"

role_prompt = f"You are a great Table Annotation Specialist and your task is to classify the relationship between two columns with one or more of the following labels that are separated with comma: {labels_joined}."

preds = []
for test_example in examples:
    conversation = [
        SystemMessage(content=role_prompt),
        HumanMessage(content=f"Classify this relation: {test_example}"),
    ]
    preds.append(chat(conversation).content)

In [15]:
# Zero-shot, prompt "r+i": role description plus numbered instructions.
# Overwrites preds; nr and prompt_name determine the file name preds are saved under.
nr = "zero"
prompt_name = "r+i"

role_prompt = f"You are a great Table Annotation Specialist and your task is to classify the relationship between two columns with one or more of the following labels that are separated with comma: {labels_joined}."
instructions_prompt = "Your instructions are: 1. Review the provided column values. 2. Carefully examine the values of the two columns. 3. Select one label or more only if needed, that best represents the relationship between these two columns. 4. Answer with your final selected labels. Avoid duplicate labels when responding. Provide a single unique label, or if multiple, ensure they are distinct. 5. Ensure that you answer using only the labels from the provided label-set."

preds = []
for test_example in examples:
    conversation = [
        SystemMessage(content=role_prompt),
        SystemMessage(content=instructions_prompt),
        HumanMessage(content=f"Classify this relation: {test_example}"),
    ]
    preds.append(chat(conversation).content)

In [143]:
# Zero-shot, prompt "r+i+s_b_s": role, a "think step by step" cue, then instructions.
# Overwrites preds; nr and prompt_name determine the file name preds are saved under.
nr = "zero"
prompt_name = "r+i+s_b_s"

role_prompt = f"You are a great Table Annotation Specialist and your task is to classify the relationship between two columns with one or more of the following labels that are separated with comma: {labels_joined}."
step_prompt = "Let's think step by step."
instructions_prompt = "Your instructions are: 1. Review the provided column values. 2. Carefully examine the values of the two columns. 3. Select one label or more only if needed, that best represents the relationship between these two columns. 4. Answer with your final selected labels. Avoid duplicate labels when responding. Provide a single unique label, or if multiple, ensure they are distinct. 5. Ensure that you answer using only the labels from the provided label-set."

preds = []
for test_example in examples:
    conversation = [
        SystemMessage(content=role_prompt),
        SystemMessage(content=step_prompt),
        SystemMessage(content=instructions_prompt),
        HumanMessage(content=f"Classify this relation: {test_example}"),
    ]
    preds.append(chat(conversation).content)

In [146]:
# Zero-shot, prompt "r+i+m": role, instructions, then a motivational sentence.
# Overwrites preds; nr and prompt_name determine the file name preds are saved under.
nr = "zero"
prompt_name = "r+i+m"

role_prompt = f"You are a great Table Annotation Specialist and your task is to classify the relationship between two columns with one or more of the following labels that are separated with comma: {labels_joined}."
instructions_prompt = "Your instructions are: 1. Review the provided column values. 2. Carefully examine the values of the two columns. 3. Select one label or more only if needed, that best represents the relationship between these two columns. 4. Answer with your final selected labels. Avoid duplicate labels when responding. Provide a single unique label, or if multiple, ensure they are distinct. 5. Ensure that you answer using only the labels from the provided label-set."
motivation_prompt = "Your answer is very important. Take your time and think well before answering!"

preds = []
for test_example in examples:
    conversation = [
        SystemMessage(content=role_prompt),
        SystemMessage(content=instructions_prompt),
        SystemMessage(content=motivation_prompt),
        HumanMessage(content=f"Classify this relation: {test_example}"),
    ]
    preds.append(chat(conversation).content)

In [149]:
# Zero-shot, prompt "r+i+c": role, instructions, then a short task definition (CONTEXT).
# Overwrites preds; nr and prompt_name determine the file name preds are saved under.
nr = "zero"
prompt_name = "r+i+c"

role_prompt = f"You are a great Table Annotation Specialist and your task is to classify the relationship between two columns with one or more of the following labels that are separated with comma: {labels_joined}."
instructions_prompt = "Your instructions are: 1. Review the provided column values. 2. Carefully examine the values of the two columns. 3. Select one label or more only if needed, that best represents the relationship between these two columns. 4. Answer with your final selected labels. Avoid duplicate labels when responding. Provide a single unique label, or if multiple, ensure they are distinct. 5. Ensure that you answer using only the labels from the provided label-set."
context_prompt = "CONTEXT: Column Property Annotation is a sub-task of Table Annotation and refers to predicting the semantic relation between two or more columns. You have the same task, you are required to annotate the relation between two given columns."

preds = []
for test_example in examples:
    conversation = [
        SystemMessage(content=role_prompt),
        SystemMessage(content=instructions_prompt),
        SystemMessage(content=context_prompt),
        HumanMessage(content=f"Classify this relation: {test_example}"),
    ]
    preds.append(chat(conversation).content)

ONE-SHOT (one randomly sampled training example per test example)

In [None]:
# One-shot, prompt "r+i": role + instructions, plus one training record replayed as an
# earlier user/assistant exchange before each test example.
# Overwrites preds; nr and prompt_name determine the file name preds are saved under.
import random

FEW_SHOT_SEED = 42  # fixed seed so the sampled demonstrations are reproducible across runs

nr = "one"
prompt_name = "r+i"

role_prompt = f"You are a great Table Annotation Specialist and your task is to classify the relationship between two columns with one or more of the following labels that are separated with comma: {labels_joined}."
instructions_prompt = "Your instructions are: 1. Review the provided column values. 2. Carefully examine the values of the two columns. 3. Select one label or more only if needed, that best represents the relationship between these two columns. 4. Answer with your final selected labels. Avoid duplicate labels when responding. Provide a single unique label, or if multiple, ensure they are distinct. 5. Ensure that you answer using only the labels from the provided label-set."

rng = random.Random(FEW_SHOT_SEED)
preds = []
for example in examples:
    demo = rng.randrange(len(train_examples))
    messages = [
        SystemMessage(content=role_prompt),
        SystemMessage(content=instructions_prompt),
        HumanMessage(content=f"Classify this relation: {train_examples[demo]}"),
        # Demo answer as comma-separated label texts (the format the role prompt asks
        # for), not a Python list repr such as "['film genre']".
        AIMessage(content=", ".join(train_labels[demo])),
        HumanMessage(content=f"Classify this relation: {example}"),
    ]
    preds.append(chat(messages).content)

In [None]:
# One-shot, prompt "r+i+s_b_s": role, "think step by step" cue, instructions, plus one
# training record replayed as an earlier user/assistant exchange.
# Overwrites preds; nr and prompt_name determine the file name preds are saved under.
import random

FEW_SHOT_SEED = 42  # fixed seed so the sampled demonstrations are reproducible across runs

nr = "one"
prompt_name = "r+i+s_b_s"

role_prompt = f"You are a great Table Annotation Specialist and your task is to classify the relationship between two columns with one or more of the following labels that are separated with comma: {labels_joined}."
step_prompt = "Let's think step by step."
instructions_prompt = "Your instructions are: 1. Review the provided column values. 2. Carefully examine the values of the two columns. 3. Select one label or more only if needed, that best represents the relationship between these two columns. 4. Answer with your final selected labels. Avoid duplicate labels when responding. Provide a single unique label, or if multiple, ensure they are distinct. 5. Ensure that you answer using only the labels from the provided label-set."

rng = random.Random(FEW_SHOT_SEED)
preds = []
for example in examples:
    demo = rng.randrange(len(train_examples))
    messages = [
        SystemMessage(content=role_prompt),
        SystemMessage(content=step_prompt),
        SystemMessage(content=instructions_prompt),
        HumanMessage(content=f"Classify this relation: {train_examples[demo]}"),
        # Demo answer as comma-separated label texts (the format the role prompt asks
        # for), not a Python list repr such as "['film genre']".
        AIMessage(content=", ".join(train_labels[demo])),
        HumanMessage(content=f"Classify this relation: {example}"),
    ]
    preds.append(chat(messages).content)

In [138]:
# One-shot, prompt "r+i+m": role, instructions, motivational sentence, plus one
# training record replayed as an earlier user/assistant exchange.
# Overwrites preds; nr and prompt_name determine the file name preds are saved under.
import random

FEW_SHOT_SEED = 42  # fixed seed so the sampled demonstrations are reproducible across runs

nr = "one"
prompt_name = "r+i+m"

role_prompt = f"You are a great Table Annotation Specialist and your task is to classify the relationship between two columns with one or more of the following labels that are separated with comma: {labels_joined}."
instructions_prompt = "Your instructions are: 1. Review the provided column values. 2. Carefully examine the values of the two columns. 3. Select one label or more only if needed, that best represents the relationship between these two columns. 4. Answer with your final selected labels. Avoid duplicate labels when responding. Provide a single unique label, or if multiple, ensure they are distinct. 5. Ensure that you answer using only the labels from the provided label-set."
motivation_prompt = "Your answer is very important. Take your time and think well before answering!"

rng = random.Random(FEW_SHOT_SEED)
preds = []
for example in examples:
    demo = rng.randrange(len(train_examples))
    messages = [
        SystemMessage(content=role_prompt),
        SystemMessage(content=instructions_prompt),
        SystemMessage(content=motivation_prompt),
        HumanMessage(content=f"Classify this relation: {train_examples[demo]}"),
        # Demo answer as comma-separated label texts (the format the role prompt asks
        # for), not a Python list repr such as "['film genre']".
        AIMessage(content=", ".join(train_labels[demo])),
        HumanMessage(content=f"Classify this relation: {example}"),
    ]
    preds.append(chat(messages).content)

In [None]:
# One-shot, prompt "r+i+c": role, instructions, task definition (CONTEXT), plus one
# training record replayed as an earlier user/assistant exchange.
# Overwrites preds; nr and prompt_name determine the file name preds are saved under.
import random

FEW_SHOT_SEED = 42  # fixed seed so the sampled demonstrations are reproducible across runs

nr = "one"
prompt_name = "r+i+c"

role_prompt = f"You are a great Table Annotation Specialist and your task is to classify the relationship between two columns with one or more of the following labels that are separated with comma: {labels_joined}."
instructions_prompt = "Your instructions are: 1. Review the provided column values. 2. Carefully examine the values of the two columns. 3. Select one label or more only if needed, that best represents the relationship between these two columns. 4. Answer with your final selected labels. Avoid duplicate labels when responding. Provide a single unique label, or if multiple, ensure they are distinct. 5. Ensure that you answer using only the labels from the provided label-set."
context_prompt = "CONTEXT: Column Property Annotation is a sub-task of Table Annotation and refers to predicting the semantic relation between two or more columns. You have the same task, you are required to annotate the relation between two given columns."

rng = random.Random(FEW_SHOT_SEED)
preds = []
for example in examples:
    demo = rng.randrange(len(train_examples))
    messages = [
        SystemMessage(content=role_prompt),
        SystemMessage(content=instructions_prompt),
        SystemMessage(content=context_prompt),
        HumanMessage(content=f"Classify this relation: {train_examples[demo]}"),
        # Demo answer as comma-separated label texts (the format the role prompt asks
        # for), not a Python list repr such as "['film genre']".
        AIMessage(content=", ".join(train_labels[demo])),
        HumanMessage(content=f"Classify this relation: {example}"),
    ]
    preds.append(chat(messages).content)

In [None]:
# One-shot, prompt "r+i+c.example": role, instructions, a task definition containing a
# worked toy example, plus one training record replayed as an earlier exchange.
# Overwrites preds; nr and prompt_name determine the file name preds are saved under.
import random

FEW_SHOT_SEED = 42  # fixed seed so the sampled demonstrations are reproducible across runs

nr = "one"
prompt_name = "r+i+c.example"

role_prompt = f"You are a great Table Annotation Specialist and your task is to classify the relationship between two columns with one or more of the following labels that are separated with comma: {labels_joined}."
instructions_prompt = "Your instructions are: 1. Review the provided column values. 2. Carefully examine the values of the two columns. 3. Select one label or more only if needed, that best represents the relationship between these two columns. 4. Answer with your final selected labels. Avoid duplicate labels when responding. Provide a single unique label, or if multiple, ensure they are distinct. 5. Ensure that you answer using only the labels from the provided label-set."
# Adjacent literals are concatenated with no separator, so each piece ends with a space;
# without it the sentences fused ("brauni.'First", "Dog.Now", "Column 1Answer").
context_example_prompt = (
    "CONTEXT: Column Property Annotation is a sub-task of Table Annotation and refers to predicting the semantic relation between two or more columns. You have the same task, here is an example how you could solve a CPA task: 'Classify the relationship between these two columns: Column1: Dog, Cat, Dog.  Column2: lis, moli, brauni.' "
    "First we check Column1: Dog, Cat, Dog. "
    "Now we check Column2: lis, moli, brauni. Analyze Column 2 in relation to Column 1. Predict the relation between Column 2 and Column 1. "
    "Answer: Column 2: animal name, pet name"
)

rng = random.Random(FEW_SHOT_SEED)
preds = []
for example in examples:
    demo = rng.randrange(len(train_examples))
    messages = [
        SystemMessage(content=role_prompt),
        SystemMessage(content=instructions_prompt),
        SystemMessage(content=context_example_prompt),
        HumanMessage(content=f"Classify this relation: {train_examples[demo]}"),
        # Demo answer as comma-separated label texts (the format the role prompt asks
        # for), not a Python list repr such as "['film genre']".
        AIMessage(content=", ".join(train_labels[demo])),
        HumanMessage(content=f"Classify this relation: {example}"),
    ]
    preds.append(chat(messages).content)

FIVE-SHOT (five randomly sampled training examples per test example)

In [149]:
# Five-shot, prompt "r+i": role + instructions, plus five distinct training records
# replayed as earlier user/assistant exchanges before each test example.
# Overwrites preds; nr and prompt_name determine the file name preds are saved under.
import random

FEW_SHOT_SEED = 42  # fixed seed so the sampled demonstrations are reproducible across runs
N_SHOTS = 5

nr = "five"
prompt_name = "r+i"

role_prompt = f"You are a great Table Annotation Specialist and your task is to classify the relationship between two columns with one or more of the following labels that are separated with comma: {labels_joined}."
instructions_prompt = "Your instructions are: 1. Review the provided column values. 2. Carefully examine the values of the two columns. 3. Select one label or more only if needed, that best represents the relationship between these two columns. 4. Answer with your final selected labels. Avoid duplicate labels when responding. Provide a single unique label, or if multiple, ensure they are distinct. 5. Ensure that you answer using only the labels from the provided label-set."

rng = random.Random(FEW_SHOT_SEED)
preds = []
for example in examples:
    messages = [
        SystemMessage(content=role_prompt),
        SystemMessage(content=instructions_prompt),
    ]
    # Sample without replacement so the same demonstration is never shown twice.
    for demo in rng.sample(range(len(train_examples)), N_SHOTS):
        messages.append(HumanMessage(content=f"Classify this relation: {train_examples[demo]}"))
        # Comma-separated label texts (the format the role prompt asks for), not a list repr.
        messages.append(AIMessage(content=", ".join(train_labels[demo])))
    messages.append(HumanMessage(content=f"Classify this relation: {example}"))
    preds.append(chat(messages).content)

In [None]:
# Five-shot, prompt "r+i+s_b_s": role, "think step by step" cue, instructions, plus five
# distinct training records replayed as earlier user/assistant exchanges.
# Overwrites preds; nr and prompt_name determine the file name preds are saved under.
import random

FEW_SHOT_SEED = 42  # fixed seed so the sampled demonstrations are reproducible across runs
N_SHOTS = 5

nr = "five"
prompt_name = "r+i+s_b_s"

role_prompt = f"You are a great Table Annotation Specialist and your task is to classify the relationship between two columns with one or more of the following labels that are separated with comma: {labels_joined}."
step_prompt = "Let's think step by step."
instructions_prompt = "Your instructions are: 1. Review the provided column values. 2. Carefully examine the values of the two columns. 3. Select one label or more only if needed, that best represents the relationship between these two columns. 4. Answer with your final selected labels. Avoid duplicate labels when responding. Provide a single unique label, or if multiple, ensure they are distinct. 5. Ensure that you answer using only the labels from the provided label-set."

rng = random.Random(FEW_SHOT_SEED)
preds = []
for example in examples:
    messages = [
        SystemMessage(content=role_prompt),
        SystemMessage(content=step_prompt),
        SystemMessage(content=instructions_prompt),
    ]
    # Sample without replacement so the same demonstration is never shown twice.
    for demo in rng.sample(range(len(train_examples)), N_SHOTS):
        messages.append(HumanMessage(content=f"Classify this relation: {train_examples[demo]}"))
        # Comma-separated label texts (the format the role prompt asks for), not a list repr.
        messages.append(AIMessage(content=", ".join(train_labels[demo])))
    messages.append(HumanMessage(content=f"Classify this relation: {example}"))
    preds.append(chat(messages).content)

In [164]:
# Five-shot, prompt "r+i+m": role, instructions, motivational sentence, plus five
# distinct training records replayed as earlier user/assistant exchanges.
# Overwrites preds; nr and prompt_name determine the file name preds are saved under.
import random

FEW_SHOT_SEED = 42  # fixed seed so the sampled demonstrations are reproducible across runs
N_SHOTS = 5

nr = "five"
prompt_name = "r+i+m"

role_prompt = f"You are a great Table Annotation Specialist and your task is to classify the relationship between two columns with one or more of the following labels that are separated with comma: {labels_joined}."
instructions_prompt = "Your instructions are: 1. Review the provided column values. 2. Carefully examine the values of the two columns. 3. Select one label or more only if needed, that best represents the relationship between these two columns. 4. Answer with your final selected labels. Avoid duplicate labels when responding. Provide a single unique label, or if multiple, ensure they are distinct. 5. Ensure that you answer using only the labels from the provided label-set."
motivation_prompt = "Your answer is very important. Take your time and think well before answering!"

rng = random.Random(FEW_SHOT_SEED)
preds = []
for example in examples:
    messages = [
        SystemMessage(content=role_prompt),
        SystemMessage(content=instructions_prompt),
        SystemMessage(content=motivation_prompt),
    ]
    # Sample without replacement so the same demonstration is never shown twice.
    for demo in rng.sample(range(len(train_examples)), N_SHOTS):
        messages.append(HumanMessage(content=f"Classify this relation: {train_examples[demo]}"))
        # Comma-separated label texts (the format the role prompt asks for), not a list repr.
        messages.append(AIMessage(content=", ".join(train_labels[demo])))
    messages.append(HumanMessage(content=f"Classify this relation: {example}"))
    preds.append(chat(messages).content)

In [None]:
# Five-shot, prompt "r+i+c": role, instructions, task definition (CONTEXT), plus five
# distinct training records replayed as earlier user/assistant exchanges.
# Overwrites preds; nr and prompt_name determine the file name preds are saved under.
import random

FEW_SHOT_SEED = 42  # fixed seed so the sampled demonstrations are reproducible across runs
N_SHOTS = 5

nr = "five"
prompt_name = "r+i+c"

role_prompt = f"You are a great Table Annotation Specialist and your task is to classify the relationship between two columns with one or more of the following labels that are separated with comma: {labels_joined}."
instructions_prompt = "Your instructions are: 1. Review the provided column values. 2. Carefully examine the values of the two columns. 3. Select one label or more only if needed, that best represents the relationship between these two columns. 4. Answer with your final selected labels. Avoid duplicate labels when responding. Provide a single unique label, or if multiple, ensure they are distinct. 5. Ensure that you answer using only the labels from the provided label-set."
context_prompt = "CONTEXT: Column Property Annotation is a sub-task of Table Annotation and refers to predicting the semantic relation between two or more columns. You have the same task, you are required to annotate the relation between two given columns."

rng = random.Random(FEW_SHOT_SEED)
preds = []
for example in examples:
    messages = [
        SystemMessage(content=role_prompt),
        SystemMessage(content=instructions_prompt),
        SystemMessage(content=context_prompt),
    ]
    # Sample without replacement so the same demonstration is never shown twice.
    for demo in rng.sample(range(len(train_examples)), N_SHOTS):
        messages.append(HumanMessage(content=f"Classify this relation: {train_examples[demo]}"))
        # Comma-separated label texts (the format the role prompt asks for), not a list repr.
        messages.append(AIMessage(content=", ".join(train_labels[demo])))
    messages.append(HumanMessage(content=f"Classify this relation: {example}"))
    preds.append(chat(messages).content)

In [None]:
# Raw model answers from the prompt cell that ran last (one entry per test example).
preds

In [167]:
# Persist the raw answers under a path that encodes model, shot count and prompt variant.
file_name = f'Predictions/{model_name}/column/{nr}-shot/chat-column-{prompt_name}-{nr}-shot.pkl'
# Create the directory on first use instead of failing with FileNotFoundError.
os.makedirs(os.path.dirname(file_name), exist_ok=True)
# `with` closes the file even if pickling fails part-way.
with open(file_name, 'wb') as f:
    pickle.dump(preds, f)

## Evaluation

In [None]:
def parse_prediction(pred, example_idx, lookup=text_to_label):
    """Map one raw model answer to a list of Freebase relation ids.

    Parameters
    ----------
    pred : str
        Raw text returned by the chat model for one test example.
    example_idx : int
        Position of the example in `preds`; only used in diagnostic prints.
    lookup : dict
        Label text -> relation id (defaults to `text_to_label`).

    Returns
    -------
    list of str
        Relation ids, with '-' for any answer (or answer part) outside the label
        space. Always a list -- also for quoted answers -- so each entry has the
        same shape as the gold label lists the scorer iterates over.
    """
    quoted = re.findall('"([^"]*)"', pred)
    if quoted:
        # Quoted answer: only the first quoted span is considered.
        first_quoted = quoted[0].lower()
        if first_quoted in lookup:
            return [lookup[first_quoted]]
        print(f"For test example {example_idx} out of label space prediction: {pred}")
        return ['-']

    # Free-text answer: keep the part after the first ':' and before the first '.'.
    if ":" in pred:
        pred = pred.split(':')[1]
    if "." in pred:
        pred = pred.split('.')[0]
    pred = pred.strip().lower()

    if "," in pred:
        column_predictions = []
        for multi in pred.split(","):
            # Keep only the label itself: cut at a newline or a parenthetical remark.
            multi = multi.split('\n')[0].split('(')[0].strip().lower()
            if multi in lookup:
                column_predictions.append(lookup[multi])
            else:
                print(f"For test example {example_idx} out of label space prediction: {multi}")
                column_predictions.append('-')
        return column_predictions

    if pred in lookup:
        return [lookup[pred]]
    # Fall back to a label text mentioned inside a longer sentence. Prefer the longest
    # match so e.g. "sports team location" is not shadowed by its prefix "sports team".
    mentioned = [text for text in lookup if text in pred]
    if mentioned:
        return [lookup[max(mentioned, key=len)]]
    print(f"For test example {example_idx} out of label space prediction: {pred}")
    return ['-']


predictions = [parse_prediction(pred, idx) for idx, pred in enumerate(preds)]

In [None]:
# Parsed predictions: relation ids per test example, '-' for answers outside the label space.
predictions

### Calculate Precision, Recall, Macro-F1 and Micro-F1

In [172]:
def calculate_f1_scores(y_tests, y_preds, num_classes, types):
    """Compute micro/macro precision, recall and F1 for multi-label predictions.

    The LAST entry of ``types`` (index ``num_classes - 1``) is a catch-all bucket:
    every predicted label that is not in ``types`` is mapped to it, and it is
    excluded from all aggregates and from the per-class report. Callers should
    therefore append a sentinel such as ``"-"`` to ``types``.

    Args:
        y_tests: gold label lists, one per example; every label must be in ``types``.
        y_preds: predicted label lists, index-aligned with ``y_tests``.
        num_classes: number of classes, normally ``len(types)``.
        types: ordered label vocabulary with the catch-all entry last.

    Returns:
        ``[evaluation, per_class_eval]`` where ``evaluation`` holds "Micro-F1",
        "Macro-F1" and the macro-averaged "Precision" / "Recall", and
        ``per_class_eval`` maps every label in ``types[:-1]`` to its
        {"Precision", "Recall", "F1"}. Undefined ratios (0/0) are reported as 0.

    Raises:
        ValueError: if the inputs are not aligned or a gold label is not in ``types``.
    """
    if len(y_tests) != len(y_preds):
        raise ValueError(
            f"y_tests and y_preds must be aligned, got {len(y_tests)} and {len(y_preds)} examples"
        )

    # Same as list.index: the first occurrence of a label determines its class index.
    label_index = {}
    for position, label in enumerate(types):
        label_index.setdefault(label, position)
    catch_all = num_classes - 1

    def _gold_index(label):
        if label not in label_index:
            raise ValueError(f"gold label {label!r} is not in types")
        return label_index[label]

    gold = [[_gold_index(label) for label in y] for y in y_tests]
    pred = [[label_index.get(label, catch_all) for label in y] for y in y_preds]

    # Per-class true positives, false positives and false negatives.
    tp = np.zeros(num_classes)
    fp = np.zeros(num_classes)
    fn = np.zeros(num_classes)
    for gold_row, pred_row in zip(gold, pred):
        for label in gold_row:
            if label in pred_row:
                tp[label] += 1
            else:
                fn[label] += 1
        for label in pred_row:
            if label not in gold_row:
                fp[label] += 1

    def _safe_div(num, den):
        # 0/0 (class never predicted / never in gold) counts as 0 instead of NaN + warning.
        return num / den if den else 0.0

    class_p, class_r, class_f1s = [], [], []
    for j in range(num_classes):
        precision = _safe_div(tp[j], tp[j] + fp[j])
        recall = _safe_div(tp[j], tp[j] + fn[j])
        class_p.append(precision)
        class_r.append(recall)
        class_f1s.append(_safe_div(2 * precision * recall, precision + recall))

    # Aggregates skip the catch-all class (last position).
    n_real = num_classes - 1
    macro_f1 = _safe_div(sum(class_f1s[:-1]), n_real)
    p = _safe_div(sum(class_p[:-1]), n_real)
    r = _safe_div(sum(class_r[:-1]), n_real)

    all_tp = tp[:-1].sum()
    all_fp = fp[:-1].sum()
    all_fn = fn[:-1].sum()
    micro_f1 = _safe_div(all_tp, all_tp + 0.5 * (all_fp + all_fn))

    per_class_eval = {
        t: {"Precision": class_p[idx], "Recall": class_r[idx], "F1": class_f1s[idx]}
        for idx, t in enumerate(types[:-1])
    }

    evaluation = {
        "Micro-F1": micro_f1,
        "Macro-F1": macro_f1,
        "Precision": p,
        "Recall": r
    }

    return [evaluation, per_class_eval]


In [None]:
list_set_labels = list(labels_to_text.keys())
# calculate_f1_scores treats the LAST entry of `types` as the catch-all class for
# unparseable answers and excludes it from every score, so the '-' sentinel must
# always be present. Testing `"-" in predictions` against a list of per-example
# label lists was normally False, which made the last real relation the excluded class.
types = list_set_labels + ["-"]
evaluation, per_class_eval = calculate_f1_scores(labels, predictions, len(types), types)

In [None]:
# Aggregate scores: Micro-F1, Macro-F1 and macro-averaged Precision / Recall.
evaluation

In [None]:
# Per-relation Precision / Recall / F1 (the last entry of `types` is not reported).
per_class_eval

## Error Analysis

In [None]:
# Print every test example whose predicted label set differs from the gold set, and
# count the gold labels that were missed overall (false negatives).
errors = 0
for idx, predicted in enumerate(predictions):
    gold_set = set(labels[idx])
    predicted_set = set(predicted)
    if gold_set == predicted_set:
        continue
    errors += len(gold_set - predicted_set)
    print(f"Predicted as {predicted} when it was {gold_set}")
errors

### Reload previously saved predictions

In [None]:
# Reload raw answers saved earlier for the current model / nr / prompt_name combination.
saved_preds_path = f'Predictions/{model_name}/column/{nr}-shot/chat-column-{prompt_name}-{nr}-shot.pkl'
with open(saved_preds_path, "rb") as f:
    preds = pickle.load(f)

In [None]:
# Raw answers loaded from disk; re-run the evaluation cells to score them.
preds