In [2]:
import os
import re
import pickle
import numpy as np
import pandas as pd
from dotenv import dotenv_values
from langchain import PromptTemplate, LLMChain, OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.schema import SystemMessage, HumanMessage, AIMessage

In [3]:
config = dotenv_values("../.env")
os.environ['OPENAI_API_KEY'] = config["OPENAI_API_KEY"]
OPENAI_API_KEY = config["OPENAI_API_KEY"]

In [6]:
labels_to_text = {
    "soccer.football_league": "football league",
    "government.government_office_or_title": "government office or title",
    "organization.non_profit_organization": "non profit organization",
    "olympics.olympic_games": "olympic games",
    "cvg.cvg_genre": "cvg genre",
    "ice_hockey.hockey_position": "ice hockey position",
    "tv.tv_network": "tv network",
    "aviation.airline": "aviation airline",
    "american_football.football_conference": "american football conference",
    "soccer.football_world_cup": "football world cup",
    "american_football.football_coach": "american football coach",
    "military.military_unit": "military unit",
    "military.military_post": "military post",
    "music.media_format": "music media format",
    "tv.tv_personality": "tv personality",
    "baseball.baseball_team": "baseball team",
    "cvg.cvg_developer": "cvg developer",
    "soccer.football_award": "football award",
    "ice_hockey.hockey_team": "ice hockey team",
    "tv.tv_writer": "tv writer",
    "meteorology.tropical_cyclone_season": "tropical cyclone season",
    "soccer.fifa": "soccer fifa",
    "cvg.cvg_publisher": "cvg publisher",
    "baseball.baseball_player": "baseball player",
    "sports.sports_championship": "sports championship",
    "soccer.football_team_manager": "football team manager",
    "sports.golfer": "sports golfer",
    "baseball.baseball_position": "baseball position",
    "military.rank": "military rank",
    "sports.sports_championship": "sports championship",
    "cvg.cvg_platform": "cvg platform",
    "music.musical_group": "musical group",
    "amusement_parks.ride": "amusement parks ride",
    "music.genre": "music genre",
    "music.lyricist": "music lyricist",
    "music.record_label": "music record label",
    "meteorology.tropical_cyclone": "tropical cyclone",
    "aviation.airport": "airport"
}

In [7]:
text_to_label = {
    "football league": "soccer.football_league",
    "government office or title": "government.government_office_or_title",
    "non profit organization": "organization.non_profit_organization",
    "olympic games": "olympics.olympic_games",
    "cvg genre": "cvg.cvg_genre",
    "ice hockey position": "ice_hockey.hockey_position",
    "tv network": "tv.tv_network",
    "aviation airline": "aviation.airline",
    "american football conference": "american_football.football_conference",
    "football world cup": "soccer.football_world_cup",
    "american football coach": "american_football.football_coach",
    "military unit": "military.military_unit",
    "military post": "military.military_post",
    "music media format": "music.media_format",
    "tv personality": "tv.tv_personality",
    "baseball team": "baseball.baseball_team",
    "cvg developer": "cvg.cvg_developer",
    "football award": "soccer.football_award",
    "ice hockey team": "ice_hockey.hockey_team",
    "tv writer": "tv.tv_writer",
    "tropical cyclone season": "meteorology.tropical_cyclone_season",
    "soccer fifa": "soccer.fifa",
    "cvg publisher": "cvg.cvg_publisher",
    "baseball player": "baseball.baseball_player",
    "sports championship": "sports.sports_championship",
    "football team manager": "soccer.football_team_manager",
    "sports golfer": "sports.golfer",
    "baseball position": "baseball.baseball_position",
    "military rank": "military.rank",
    "cvg platform": "cvg.cvg_platform",
    "musical group": "music.musical_group",
    "amusement parks ride": "amusement_parks.ride",
    "music genre": "music.genre",
    "music lyricist": "music.lyricist",
    "music record label": "music.record_label",
    "tropical cyclone": "meteorology.tropical_cyclone",
    "airport": "aviation.airport"
}

## Load test (and training) set

In [8]:
with open('wiki-cta-train-column.pkl', "rb") as f:
    train = pickle.load(f)
with open('wiki-cta-test-column.pkl', "rb") as f:
    test = pickle.load(f)

examples = [example[2] for example in test ]
labels = [example[3] for example in test ]
#labels = [ [labels_to_text[label]] for example in test for label in example[3] ]

train_examples = [ example[2] for example in train ]
train_labels = [ [labels_to_text[label]] for example in train for label in example[3] ]

In [7]:
len(test) #nr i kolonave te test

807

In [8]:
len(train) #nr i kolonave te train

38720

In [7]:
labels_joined = ", ".join([labels_to_text[l] for l in labels_to_text])
labels_joined

'football league, government office or title, non profit organization, olympic games, cvg genre, ice hockey position, tv network, aviation airline, american football conference, football world cup, american football coach, military unit, military post, music media format, tv personality, baseball team, cvg developer, football award, ice hockey team, tv writer, tropical cyclone season, soccer fifa, cvg publisher, baseball player, sports championship, football team manager, sports golfer, baseball position, military rank, cvg platform, musical group, amusement parks ride, music genre, music lyricist, music record label, tropical cyclone, airport'

In [8]:
model_name = 'gpt-3.5-turbo-1106'
chat = ChatOpenAI(openai_api_key=OPENAI_API_KEY, temperature=0, model=model_name)

## Choose setup: zero-shot, one-shot or five-shot

CTA COLUMN

ZERO-SHOT

In [10]:
#role
nr="zero"
prompt_name = "r"

preds = []
for example in examples:
    messages = []

    messages.append(SystemMessage(content=f"You are a great Table Annotation Specialist and your task is to classify a given column with one or more of the following labels that are separated with comma: {labels_joined}."))

    messages.append(HumanMessage(content=f"Classify this column: {example}"))

    res = chat(messages)
    preds.append(res.content)

In [10]:
#role + instructions 
nr="zero"
prompt_name = "r+i"

preds = []
for example in examples:
    messages = []

    messages.append(SystemMessage(content=f"You are a great Table Annotation Specialist and your task is to classify a given column with one or more of the following labels that are separated with comma: {labels_joined}."))
    messages.append(SystemMessage(content= "Your instructions are: \n"
                                            f"1. Examine the column and review the given labels: {labels_joined}. \n"
                                            "2. Analyze the values within the column. \n"
                                            "3. Choose one label or more if needed, that best represents the meaning of the column. \n"
                                            f"4. Respond using only labels from the provided set. Ensure that your answer contains ONLY labels from the set and no additional text or characters."))

    messages.append(HumanMessage(content=f"Classify this column: {example}"))

    #res = chat(messages)
    #preds.append(res.content)

In [20]:
#role + instructions + step by step
nr="zero"
prompt_name = "r+i+s_b_s"

preds = []
for example in examples:
    messages = []

    messages.append(SystemMessage(content=f"You are a great Table Annotation Specialist and your task is to classify a given column with one or more of the following labels that are separated with comma: {labels_joined}."))
    messages.append(SystemMessage(content="Let's think step by step."))
    messages.append(SystemMessage(content= "Your instructions are: \n"
                                            f"1. Examine the column and review the given labels: {labels_joined}. \n"
                                            "2. Analyze the values within the column. \n"
                                            "3. Choose one label or more if needed, that best represents the meaning of the column. \n"
                                            f"4. Respond using only labels from the provided set. Ensure that your answer contains ONLY labels from the set and no additional text or characters."))


    messages.append(HumanMessage(content=f"Classify this column: {example}"))
    res = chat(messages)
    preds.append(res.content)

In [26]:
#role + instructions + motivation
nr="zero"
prompt_name = "r+i2+m"

preds = []
for example in examples:
    messages = []

    messages.append(SystemMessage(content=f"You are a great Table Annotation Specialist and your task is to classify a given column with one or more of the following labels that are separated with comma: {labels_joined}."))
    messages.append(SystemMessage(content= "Your instructions are: \n"
                                            f"1. Examine the column and review the given labels: {labels_joined}. \n"
                                            "2. Analyze the values within the column. \n"
                                            "3. Choose one label or more if needed, that best represents the meaning of the column. \n"
                                            f"4. Avoid duplicate labels when responding. Provide a single unique label, or if multiple, ensure they are distinct. Ensure that your answer contains ONLY labels from the set and no additional text or characters."))
                                           
    messages.append(SystemMessage(content="Your answer is very important. Take your time and think well before answering!"))

    messages.append(HumanMessage(content=f"Classify this column: {example}"))
    res = chat(messages)
    preds.append(res.content)

In [31]:
#role + instructions + CONTEXT
nr="zero"
prompt_name = "r+i2+c"

preds = []
for example in examples:
    messages = []
    
    messages.append(SystemMessage(content=f"You are a great Table Annotation Specialist and your task is to classify a given column with one or more of the following labels that are separated with comma: {labels_joined}."))
    messages.append(SystemMessage(content= "Your instructions are: \n"
                                            f"1. Examine the column and review the given labels: {labels_joined}. \n"
                                            "2. Analyze the values within the column. \n"
                                            "3. Choose one label or more if needed, that best represents the meaning of the column. \n"
                                            f"4. Avoid duplicate labels when responding. Provide a single unique label, or if multiple, ensure they are distinct. Ensure that your answer contains ONLY labels from the set and no additional text or characters."))
    
    messages.append(SystemMessage(content=f"CONTEXT: Column Type Annotation is a sub-task of Table Annotation and involves categorizing each column of a table based on its content.  Your task is the same, to analyze and then predict the column type with one or more of the provided labels from the label-set!")) 
    
    messages.append(HumanMessage(content=f"Classify this column: {example}"))
    res = chat(messages)
    preds.append(res.content)

In [None]:
#role + instructions + CONTEXT example  

nr="zero"
prompt_name = "r+i2+c.example"

preds = []
for example in examples:
    messages = []
    
    messages.append(SystemMessage(content=f"You are a great Table Annotation Specialist and your task is to classify a given column with one or more of the following labels that are separated with comma: {labels_joined}."))
    messages.append(SystemMessage(content= "Your instructions are: \n"
                                            f"1. Examine the column and review the given labels: {labels_joined}. \n"
                                            "2. Analyze the values within the column. \n"
                                            "3. Choose one label or more if needed, that best represents the meaning of the column. \n"
                                            f"4. Avoid duplicate labels when responding. Provide a single unique label, or if multiple, ensure they are distinct. Ensure that your answer contains ONLY labels from the set and no additional text or characters."))
    messages.append(SystemMessage(content=f"CONTEXT: Here is an example of how Column Type Annotation task is solved: 'Classify the given column: Columm1: Dog, Cat, Dog, parrot.'"
                                  "First check the values of Columm1: Dog, Cat, Dog, parrot"   
                                  "Carefully think and analyze the values, decide and predict the label that best fits"
                                  "Answer: Column 1 : animal, pet  "))
    
    messages.append(HumanMessage(content=f"Classify this column: {example}"))
    res = chat(messages)
    preds.append(res.content)

ONE-SHOT

In [37]:
#role + instructions
import random
 
nr="one"
prompt_name = "r+i2"

preds = []
for example in examples:
    messages = []

    messages.append(SystemMessage(content=f"You are a great Table Annotation Specialist and your task is to classify a given column with one or more of the following labels that are separated with comma: {labels_joined}."))
    messages.append(SystemMessage(content= "Your instructions are: \n"
                                            f"1. Examine the column and review the given labels: {labels_joined}. \n"
                                            "2. Analyze the values within the column. \n"
                                            "3. Choose one label or more if needed, that best represents the meaning of the column. \n"
                                            f"4. Avoid duplicate labels when responding. Provide a single unique label, or if multiple, ensure they are distinct. Ensure that your answer contains ONLY labels from the set and no additional text or characters."))
                                           
    index = random.randint(0, len(train_examples)-1)
    messages.append(HumanMessage(content=f"Classify this column: {train_examples[index]}"))
    messages.append(AIMessage(content=f"{train_labels[index]}"))
    
    messages.append(HumanMessage(content=f"Classify this column: {example}"))
    res = chat(messages)
    preds.append(res.content)

Retrying langchain.chat_models.openai.ChatOpenAI.completion_with_retry.<locals>._completion_with_retry in 4.0 seconds as it raised ServiceUnavailableError: The server is overloaded or not ready yet..


In [None]:
#role + instructions + step by step
import random

nr="one"
prompt_name = "r+i2+s_b_s"

preds = []
for example in examples:
    messages = []

    messages.append(SystemMessage(content=f"You are a great Table Annotation Specialist and your task is to classify a given column with one or more of the following labels that are separated with comma: {labels_joined}."))
    messages.append(SystemMessage(content="Let's think step by step."))
    messages.append(SystemMessage(content= "Your instructions are: \n"
                                            f"1. Examine the column and review the given labels: {labels_joined}. \n"
                                            "2. Analyze the values within the column. \n"
                                            "3. Choose one label or more if needed, that best represents the meaning of the column. \n"
                                            f"4. Avoid duplicate labels when responding. Provide a single unique label, or if multiple, ensure they are distinct. Ensure that your answer contains ONLY labels from the set and no additional text or characters."))
                                           
   
    index = random.randint(0, len(train_examples)-1)
    messages.append(HumanMessage(content=f"Classify this column: {train_examples[index]}"))
    messages.append(AIMessage(content=f"{train_labels[index]}"))
    
    messages.append(HumanMessage(content=f"Classify this column: {example}"))
    res = chat(messages)
    preds.append(res.content)

Retrying langchain.chat_models.openai.ChatOpenAI.completion_with_retry.<locals>._completion_with_retry in 4.0 seconds as it raised RateLimitError: Rate limit reached for gpt-3.5-turbo-1106 in organization org-JnRe4IF9kM1kFkJzhKlurHyV on tokens per min (TPM): Limit 60000, Used 59586, Requested 669. Please try again in 255ms. Visit https://platform.openai.com/account/rate-limits to learn more..


In [40]:
#role + instructions + motivation
import random

nr="one"
prompt_name = "r+i2+m"

preds = []
for example in examples:
    messages = []

    messages.append(SystemMessage(content=f"You are a great Table Annotation Specialist and your task is to classify a given column with one or more of the following labels that are separated with comma: {labels_joined}."))
    messages.append(SystemMessage(content= "Your instructions are: \n"
                                            f"1. Examine the column and review the given labels: {labels_joined}. \n"
                                            "2. Analyze the values within the column. \n"
                                            "3. Choose one label or more if needed, that best represents the meaning of the column. \n"
                                            f"4. Avoid duplicate labels when responding. Provide a single unique label, or if multiple, ensure they are distinct. Ensure that your answer contains ONLY labels from the set and no additional text or characters."))
    messages.append(SystemMessage(content="Your answer is very important. Take your time and think well before answering!"))                         

    index = random.randint(0, len(train_examples)-1)
    messages.append(HumanMessage(content=f"Classify this column: {train_examples[index]}"))
    messages.append(AIMessage(content=f"{train_labels[index]}"))
    
    messages.append(HumanMessage(content=f"Classify this column: {example}"))
    res = chat(messages)
    preds.append(res.content)

Retrying langchain.chat_models.openai.ChatOpenAI.completion_with_retry.<locals>._completion_with_retry in 4.0 seconds as it raised ServiceUnavailableError: The server is overloaded or not ready yet..


In [None]:
#role + instructions + CONTEXT
import random

nr="one"
prompt_name = "r+i2+c"

preds = []
for example in examples:
    messages = []


    messages.append(SystemMessage(content=f"You are a great Table Annotation Specialist and your task is to classify a given column with one or more of the following labels that are separated with comma: {labels_joined}."))
    messages.append(SystemMessage(content= "Your instructions are: \n"
                                            f"1. Examine the column and review the given labels: {labels_joined}. \n"
                                            "2. Analyze the values within the column. \n"
                                            "3. Choose one label or more if needed, that best represents the meaning of the column. \n"
                                            f"4. Avoid duplicate labels when responding. Provide a single unique label, or if multiple, ensure they are distinct. Ensure that your answer contains ONLY labels from the set and no additional text or characters."))
                                           
    messages.append(SystemMessage(content=f"CONTEXT: Column Type Annotation is a sub-task of Table Annotation and involves categorizing each column of a table based on its content.  Your task is the same, to analyze and then predict the column type with one or more of the provided labels from the label-set!")) 
    
   
    index = random.randint(0, len(train_examples)-1)
    messages.append(HumanMessage(content=f"Classify this column: {train_examples[index]}"))
    messages.append(AIMessage(content=f"{train_labels[index]}"))
    
    messages.append(HumanMessage(content=f"Classify this column: {example}"))
    res = chat(messages)
    preds.append(res.content)

FIVE-SHOT

In [208]:
#role + instructions
import random 
 
nr= "five"
prompt_name = "r+i2"

preds = []
for example in examples:
    messages = []

    messages.append(SystemMessage(content=f"You are a great Table Annotation Specialist and your task is to classify a given column with one or more of the following labels that are separated with comma: {labels_joined}."))
    messages.append(SystemMessage(content= "Your instructions are: \n"
                                            f"1. Examine the column and review the given labels: {labels_joined}. \n"
                                            "2. Analyze the values within the column. \n"
                                            "3. Choose one label or more if needed, that best represents the meaning of the column. \n"
                                            f"4. Avoid duplicate labels when responding. Provide a single unique label, or if multiple, ensure they are distinct. Ensure that your answer contains ONLY labels from the set and no additional text or characters."))
                                           
    for i in range(0,5):
        index = random.randint(0, len(train_examples)-1)
        messages.append(HumanMessage(content=f"Classify this column: {train_examples[index]}"))
        messages.append(AIMessage(content=f"{train_labels[index]}"))
    
    messages.append(HumanMessage(content=f"Classify this column: {example}"))
    res = chat(messages)
    preds.append(res.content)

In [None]:
#role + instructions + step by step
import random

nr= "five"
prompt_name = "r+i2+s_b_s"

preds = []
for example in examples:
    messages = []


    messages.append(SystemMessage(content=f"You are a great Table Annotation Specialist and your task is to classify a given column with one or more of the following labels that are separated with comma: {labels_joined}."))
    messages.append(SystemMessage(content="Let's think step by step."))
    messages.append(SystemMessage(content= "Your instructions are: \n"
                                            f"1. Examine the column and review the given labels: {labels_joined}. \n"
                                            "2. Analyze the values within the column. \n"
                                            "3. Choose one label or more if needed, that best represents the meaning of the column. \n"
                                            f"4. Avoid duplicate labels when responding. Provide a single unique label, or if multiple, ensure they are distinct. Ensure that your answer contains ONLY labels from the set and no additional text or characters."))
                                           
    for i in range(0,5):
        index = random.randint(0, len(train_examples)-1)
        messages.append(HumanMessage(content=f"Classify this column: {train_examples[index]}"))
        messages.append(AIMessage(content=f"{train_labels[index]}"))
    
    messages.append(HumanMessage(content=f"Classify this column: {example}"))
    res = chat(messages)
    preds.append(res.content)

In [None]:
#role + instructions + motivation
import random

nr= "five"
prompt_name = "r+i2+m"

preds = []
for example in examples:
    messages = []

    messages.append(SystemMessage(content=f"You are a great Table Annotation Specialist and your task is to classify a given column with one or more of the following labels that are separated with comma: {labels_joined}."))
    messages.append(SystemMessage(content= "Your instructions are: \n"
                                            f"1. Examine the column and review the given labels: {labels_joined}. \n"
                                            "2. Analyze the values within the column. \n"
                                            "3. Choose one label or more if needed, that best represents the meaning of the column. \n"
                                            f"4. Avoid duplicate labels when responding. Provide a single unique label, or if multiple, ensure they are distinct. Ensure that your answer contains ONLY labels from the set and no additional text or characters."))
    messages.append(SystemMessage(content="Your answer is very important. Take your time and think well before answering!"))                                      

    
    for i in range(0,5):
        index = random.randint(0, len(train_examples)-1)
        messages.append(HumanMessage(content=f"Classify this column: {train_examples[index]}"))
        messages.append(AIMessage(content=f"{train_labels[index]}"))
    
    messages.append(HumanMessage(content=f"Classify this column: {example}"))
    res = chat(messages)
    preds.append(res.content)

In [228]:
preds

["['football league']",
 "['baseball team']",
 "['baseball team']",
 "['football league', 'ice hockey team']",
 "['baseball player']",
 "['american football coach', 'baseball player']",
 "['ice hockey team']",
 "['football league']",
 "['music record label']",
 "['musical group']",
 "['sports golfer']",
 "['ice hockey position']",
 "['baseball player']",
 "['baseball team']",
 "['baseball position']",
 "['american football conference']",
 "['american football conference', 'football league']",
 "['football league']",
 "['sports golfer', 'baseball player']",
 "['airport']",
 "['ice hockey team', 'sports championship']",
 "['tv personality']",
 "['ice hockey position']",
 "['ice hockey position']",
 "['tv personality']",
 "['football league']",
 "['tv network']",
 "['baseball player']",
 "['football league']",
 "['baseball team']",
 "['baseball team']",
 "['baseball player']",
 "['baseball player']",
 "['sports championship']",
 "['football award', 'sports championship']",
 "['football le

In [229]:
file_name=f'Predictions/{model_name}/column/{nr}-shot/chat-column-{prompt_name}-{nr}-shot.pkl'
f = open(file_name,'wb')
pickle.dump(preds,f)
f.close()

In [None]:
""" with open(f'Predictions/{model_name}/column/one-shot/chat-column-r+i2+m-one-shot.pkl', "rb") as f:
    preds = pickle.load(f)
preds    """ 

## Evaluation

In [None]:
""" predictions = []
for i, pred in enumerate(preds):
    from_sent = re.findall('"([^"]*)"',pred)
    if len(from_sent) == 0:
        if ":" in pred:
            pred = pred.split(':')[1]
        if "." in pred:
            pred = pred.split('.')[0]
        pred = pred.strip().lower()
        
        if pred in text_to_label:
            predictions.append(text_to_label[pred])
        else:
            if any(label in pred for label in text_to_label):
                for label in text_to_label:
                    if label in pred:
                        predictions.append(text_to_label[label])
                        break
            else:
                print(f"For test example {i} out of label space prediction: {pred}")
                predictions.append('-')

    else:
        if from_sent[0].lower() in text_to_label:
            predictions.append(text_to_label[from_sent[0].lower()])
        else:
            print(f"For test example {i} out of label space prediction: {pred}")
            predictions.append('-') """

In [230]:
predictions = []
i = 0

for j, pred in enumerate(preds):
    from_sent = re.findall('"([^"]*)"', pred)

    if len(from_sent) == 0:
        if ":" in pred:
            pred = pred.split(':')[1]
        if "." in pred:
            pred = pred.split('.')[0]
        pred = pred.strip().lower()

        column_predictions = []

        if "," in pred:
            separator = ","

            multilabels = pred.split(separator)

            for multi in multilabels:
                if "\n" in multi:
                    multi = multi.split('\n')[0].strip()
                if "," in multi:
                    multi = multi.split(",")[0].strip()
                if '(' in multi:
                    multi = multi.split("(")[0].strip()
                if '.' in multi:
                    multi = multi.split(".")[0].strip()
                multi = multi.strip().lower()

                if multi in text_to_label:
                    column_predictions.append(text_to_label[multi])
                else:
                    print(f"For test example {i} out of label space prediction: {multi}")
                    column_predictions.append('-')
        else:
            if pred in text_to_label:
                column_predictions.append(text_to_label[pred])
            else:
                if any(label in pred for label in text_to_label):
                    for label in text_to_label:
                        if label in pred:
                            column_predictions.append(text_to_label[label])
                            break
                else:
                    print(f"For test example {i} out of label space prediction: {pred}")
                    column_predictions.append('-')

        predictions.append(column_predictions)

    else:
        if from_sent[0].lower() in text_to_label:
            predictions.append(text_to_label[from_sent[0].lower()])
        else:
            print(f"For test example {i} out of label space prediction: {pred}")
            predictions.append('-')
            
    i += 1


For test example 3 out of label space prediction: ['football league'
For test example 3 out of label space prediction: 'ice hockey team']
For test example 5 out of label space prediction: ['american football coach'
For test example 5 out of label space prediction: 'baseball player']
For test example 16 out of label space prediction: ['american football conference'
For test example 16 out of label space prediction: 'football league']
For test example 18 out of label space prediction: ['sports golfer'
For test example 18 out of label space prediction: 'baseball player']
For test example 20 out of label space prediction: ['ice hockey team'
For test example 20 out of label space prediction: 'sports championship']
For test example 34 out of label space prediction: ['football award'
For test example 34 out of label space prediction: 'sports championship']
For test example 38 out of label space prediction: ['tv personality'
For test example 38 out of label space prediction: 'tv writer']
For t

In [231]:
predictions[:10]

[['soccer.football_league'],
 ['baseball.baseball_team'],
 ['baseball.baseball_team'],
 ['-', '-'],
 ['baseball.baseball_player'],
 ['-', '-'],
 ['ice_hockey.hockey_team'],
 ['soccer.football_league'],
 ['music.record_label'],
 ['music.musical_group']]

### Calculate Precision, Recall, Macro-F1 and Micro-F1

In [None]:
""" def calculate_f1_scores(y_tests, y_preds, num_classes, types):

    y_tests = [types.index(y) for y in y_tests]
    y_preds = [types.index(y) for y in y_preds]
    
    cm = np.zeros(shape=(num_classes,num_classes))
    
    for i in range(len(y_tests)):
        cm[y_preds[i]][y_tests[i]] += 1
        
    report = {}
    
    for j in range(len(cm[0])):
        report[j] = {}
        report[j]['FN'] = 0
        report[j]['FP'] = 0
        report[j]['TP'] = cm[j][j]

        for i in range(len(cm)):
            if i != j:
                report[j]['FN'] += cm[i][j]
        for k in range(len(cm[0])):
            if k != j:
                report[j]['FP'] += cm[j][k]

        precision = report[j]['TP'] / (report[j]['TP'] + report[j]['FP'])
        recall = report[j]['TP'] / (report[j]['TP'] + report[j]['FN'])
        f1 = 2*precision*recall / (precision + recall)
        
        if np.isnan(f1):
            f1 = 0
        if np.isnan(precision):
            f1 = 0
        if np.isnan(recall):
            f1 = 0

        report[j]['p'] =  precision
        report[j]['r'] =  recall
        report[j]['f1'] = f1
    
    all_fn = 0
    all_tp = 0
    all_fp = 0

    for r in report:
        if r != num_classes-1:
            all_fn += report[r]['FN']
            all_tp += report[r]['TP']
            all_fp += report[r]['FP']
        
    class_f1s = [ report[class_]['f1'] for class_ in report]
    class_p = [ 0 if np.isnan(report[class_]['p']) else report[class_]['p'] for class_ in report]
    class_r = [ 0 if np.isnan(report[class_]['r']) else report[class_]['r'] for class_ in report]
    macro_f1 = sum(class_f1s[:-1]) / (num_classes-1)
    
    p =  sum(class_p[:-1]) / (num_classes-1)
    r =  sum(class_r[:-1]) / (num_classes-1)
    micro_f1 = all_tp / ( all_tp + (1/2 * (all_fp + all_fn) )) 
    
    per_class_eval = {}
    for index, t in enumerate(types[:-1]):
        per_class_eval[t] = {"Precision":class_p[index], "Recall": class_r[index], "F1": class_f1s[index]}
    
    evaluation = {
        "Micro-F1": micro_f1,
        "Macro-F1": macro_f1,
        "Precision": p,
        "Recall": r
    }
    
    return [ evaluation, per_class_eval] """


In [232]:
def calculate_f1_scores(y_tests, y_preds, num_classes, types):
    y_tests = [[types.index(l) for l in y] for y in y_tests]
    y_preds = [[types.index(l) if l in types and l != -1 else num_classes - 1  for l in y] for y in y_preds]
  
    cm = np.zeros(shape=(num_classes, num_classes))

    for i, labels in enumerate(y_tests):
        for label in labels:
            if label not in y_preds[i]:
                cm[-1][label] += 1  # FN
            else:
                cm[label][label] += 1  # TP

    for i, labels in enumerate(y_preds):
        for label in labels:
            if label not in y_tests[i]:
                cm[label][-1] += 1  # FP

    report = {}

    for j in range(len(cm[0])):
        report[j] = {}
        report[j]['FN'] = 0
        report[j]['FP'] = 0
        report[j]['TP'] = cm[j][j]

        for i in range(len(cm)):
            if i != j:
                report[j]['FN'] += cm[i][j]
        for k in range(len(cm[0])):
            if k != j:
                report[j]['FP'] += cm[j][k]

        precision = report[j]['TP'] / (report[j]['TP'] + report[j]['FP'])
        recall = report[j]['TP'] / (report[j]['TP'] + report[j]['FN'])
        f1 = 2 * precision * recall / (precision + recall)

        if np.isnan(f1):
            f1 = 0
        if np.isnan(precision):
            precision = 0
        if np.isnan(recall):
            recall = 0

        report[j]['p'] = precision
        report[j]['r'] = recall
        report[j]['f1'] = f1

    all_fn = 0
    all_tp = 0
    all_fp = 0

    for r in report:
        if r != num_classes - 1:
            all_fn += report[r]['FN']
            all_tp += report[r]['TP']
            all_fp += report[r]['FP']

    class_f1s = [report[class_]['f1'] for class_ in report]
    class_p = [report[class_]['p'] for class_ in report]
    class_r = [report[class_]['r'] for class_ in report]
    macro_f1 = sum(class_f1s[:-1]) / (num_classes - 1)

    p = sum(class_p[:-1]) / (num_classes - 1)
    r = sum(class_r[:-1]) / (num_classes - 1)
    micro_f1 = all_tp / (all_tp + (1 / 2 * (all_fp + all_fn)))

    per_class_eval = {}
    for index, t in enumerate(types[:-1]):
        per_class_eval[t] = {"Precision": class_p[index], "Recall": class_r[index], "F1": class_f1s[index]}

    evaluation = {
        "Micro-F1": micro_f1,
        "Macro-F1": macro_f1,
        "Precision": p,
        "Recall": r
    }

    return [evaluation, per_class_eval]


In [233]:
list_set_labels = list(labels_to_text.keys())
types = list_set_labels
types = types + ["-"] if ["-"] in predictions else types
evaluation, per_class_eval = calculate_f1_scores(labels, predictions, len(types), types)

  precision = report[j]['TP'] / (report[j]['TP'] + report[j]['FP'])
  recall = report[j]['TP'] / (report[j]['TP'] + report[j]['FN'])


In [234]:
evaluation

{'Micro-F1': 0.8256274768824307,
 'Macro-F1': 0.6694920488426908,
 'Precision': 0.8168270664065763,
 'Recall': 0.6702602720367603}

In [235]:
per_class_eval

{'soccer.football_league': {'Precision': 0.9439252336448598,
  'Recall': 0.9181818181818182,
  'F1': 0.9308755760368663},
 'government.government_office_or_title': {'Precision': 0.9,
  'Recall': 0.9,
  'F1': 0.9},
 'organization.non_profit_organization': {'Precision': 0,
  'Recall': 0,
  'F1': 0},
 'olympics.olympic_games': {'Precision': 1.0,
  'Recall': 0.7142857142857143,
  'F1': 0.8333333333333333},
 'cvg.cvg_genre': {'Precision': 1.0, 'Recall': 0.3, 'F1': 0.4615384615384615},
 'ice_hockey.hockey_position': {'Precision': 0.9333333333333333,
  'Recall': 0.875,
  'F1': 0.9032258064516129},
 'tv.tv_network': {'Precision': 1.0,
  'Recall': 0.9230769230769231,
  'F1': 0.9600000000000001},
 'aviation.airline': {'Precision': 0.4444444444444444,
  'Recall': 1.0,
  'F1': 0.6153846153846153},
 'american_football.football_conference': {'Precision': 1.0,
  'Recall': 0.4,
  'F1': 0.5714285714285715},
 'soccer.football_world_cup': {'Precision': 0.5, 'Recall': 0.5, 'F1': 0.5},
 'american_football.

## Error Analysis

In [236]:
errors = 0
for i in range(len(predictions)):
    label_set = set(labels[i]) 
    
    prediction_set = set(predictions[i])
    
    if len(list(label_set-prediction_set) + list(prediction_set-label_set)) != 0:
        for y in label_set:
            if y not in prediction_set:
                errors +=1
        print(f"Predicted as {predictions[i]} when it was {label_set}")
errors

Predicted as ['-', '-'] when it was {'cvg.cvg_genre'}
Predicted as ['-', '-'] when it was {'baseball.baseball_team'}
Predicted as ['-', '-'] when it was {'american_football.football_conference'}
Predicted as ['-', '-'] when it was {'sports.golfer'}
Predicted as ['-', '-'] when it was {'ice_hockey.hockey_team'}
Predicted as ['ice_hockey.hockey_position'] when it was {'tv.tv_personality'}
Predicted as ['ice_hockey.hockey_position'] when it was {'tv.tv_personality'}
Predicted as ['soccer.football_league'] when it was {'soccer.fifa'}
Predicted as ['baseball.baseball_player'] when it was {'american_football.football_coach'}
Predicted as ['sports.sports_championship'] when it was {'soccer.football_award'}
Predicted as ['-', '-'] when it was {'soccer.football_award'}
Predicted as ['-', '-'] when it was {'tv.tv_personality'}
Predicted as ['-', '-'] when it was {'sports.golfer'}
Predicted as ['sports.sports_championship'] when it was {'soccer.football_award'}
Predicted as ['-', '-'] when it was

207

### Re-load previous preds files

In [None]:
with open(f'Predictions/{model_name}/column/{nr}-shot/chat-column-{prompt_name}-{nr}-shot.pkl', "rb") as f:
    preds = pickle.load(f)

In [48]:
with open(f'Predictions/{model_name}/column/zero-shot/chat-column-r-zero-shot.pkl', "rb") as f:
    preds = pickle.load(f)

In [49]:
preds

['football league',
 'baseball team',
 'baseball team',
 'football league, ice hockey team',
 'baseball player',
 'baseball team',
 'ice hockey team',
 'football league',
 'music record label',
 'musical group',
 'sports golfer',
 'football league',
 'baseball player',
 'baseball team',
 'baseball position',
 'american football conference',
 'football league',
 'football league',
 'sports golfer',
 'aviation airline',
 'ice hockey team',
 'tv personality',
 'ice hockey player',
 'ice hockey position',
 'tv personality',
 'football league',
 'tv network',
 'american football coach',
 'football league',
 'baseball team',
 'baseball team',
 'baseball player',
 'baseball player',
 'football league',
 'football award',
 'football league',
 'ice hockey team',
 'football league',
 'tv personality',
 'sports golfer',
 'sports golfer',
 'football league',
 'football league',
 'baseball team',
 'baseball player',
 'american football coach',
 'sports championship',
 'football league',
 'cvg devel