In [11]:
import pandas as pd
from sklearn.metrics import confusion_matrix, classification_report


In [18]:
df = pd.read_csv("prediction.csv", index_col=False)
target_names = ['inflation-cause-dominant', 'inflation-related', 'non-inflation-related']


In [25]:
confusion_matrix(df["label"], df["prediction"], labels=target_names)

array([[12, 12,  1],
       [14, 14,  7],
       [ 1,  4, 33]])

In [22]:
print(classification_report(df["label"], df["prediction"], target_names=target_names))

                          precision    recall  f1-score   support

inflation-cause-dominant       0.44      0.48      0.46        25
       inflation-related       0.47      0.40      0.43        35
   non-inflation-related       0.80      0.87      0.84        38

                accuracy                           0.60        98
               macro avg       0.57      0.58      0.58        98
            weighted avg       0.59      0.60      0.60        98


In [38]:
def get_by_label_distribution(df, label_distribution):
    target_columns = ["annotator_5", "annotator_7", "annotator_8"]
    # Get unique labels from the dataset
    unique_labels = df[target_columns].values.flatten()
    unique_labels = list(set(unique_labels))  # Get all possible label categories
    
    # Convert each row into a distribution of label occurrences
    distribution_counts = df[target_columns].apply(lambda row: tuple(row.value_counts().reindex(unique_labels, fill_value=0).values), axis=1)
    
    # Filter rows where the label distribution matches the allowed ones
    filtered_df = df[distribution_counts.isin(label_distribution)]
    
    return filtered_df
    


In [39]:
# define label distribution by count (n_inflation_cause_dominant, n_inflation_related, n_non_inflation_related)
label_distribution = {(1,1,1), (1,0,2)}
filtered_df = get_by_label_distribution(df, label_distribution)
filtered_df



Unnamed: 0,text,annotator_5,annotator_7,annotator_8,label,prediction
5,UPDATE: Euro-Zone Jan Services PMI 57.9 Vs 57....,inflation-cause-dominant,non-inflation-related,inflation-related,inflation-related,inflation-cause-dominant
9,Discount Rate Cut Shows Fed&apos;s Concern Ove...,non-inflation-related,inflation-cause-dominant,non-inflation-related,non-inflation-related,inflation-related
15,Carolina Freight 3rd Qtr Net 21c A Shr Vs 25c\...,inflation-related,inflation-cause-dominant,non-inflation-related,inflation-related,non-inflation-related
20,Energy Dept. Sees 5% Yearly Oil Price Increase...,inflation-cause-dominant,non-inflation-related,non-inflation-related,non-inflation-related,non-inflation-related
51,Japan Central Bank Chief Apologizes Over Infla...,inflation-cause-dominant,inflation-related,non-inflation-related,inflation-related,inflation-cause-dominant
57,War in Ukraine Is Already Taking Its Toll on G...,inflation-cause-dominant,inflation-related,non-inflation-related,inflation-related,inflation-related
70,Australia Employment/Analysts -2: Wage Index D...,inflation-cause-dominant,inflation-related,non-inflation-related,inflation-related,non-inflation-related
73,German Consumer Confidence Is Expected to Cont...,inflation-related,inflation-cause-dominant,non-inflation-related,inflation-related,inflation-cause-dominant
93,CBO Increases FY92 Budget Deficit Estimate To ...,inflation-cause-dominant,inflation-related,non-inflation-related,inflation-related,inflation-related
96,WSJ(4/23) Philippines Pressed To Rethink Rice ...,inflation-cause-dominant,non-inflation-related,non-inflation-related,non-inflation-related,non-inflation-related
