In [1]:
from FCA import BinaryFCAClassifier, format_formula_as_str

import pathlib
import re

import numpy as np
import pandas as pd
from tqdm import notebook
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import precision_score, recall_score, f1_score
from sklearn.datasets import fetch_20newsgroups

In [2]:
newsgroups_train = fetch_20newsgroups(subset='train')
newsgroups_test = fetch_20newsgroups(subset='test')

def gen_corpus(texts):
    for text in notebook.tqdm(texts):
        filtered_text = re.sub(r"[\W+]", " ", text.lower())
        yield filtered_text.split()


vectorizer_count = CountVectorizer(tokenizer=lambda doc: doc, lowercase=False, min_df=5)
vectorizer_tfidf = TfidfTransformer()

gen_corpus_train = gen_corpus(newsgroups_train.data)
sparse_train = vectorizer_count.fit_transform(gen_corpus_train)
X_train = vectorizer_tfidf.fit_transform(sparse_train)

class_names = newsgroups_train.target_names

gen_corpus_test = gen_corpus(newsgroups_test.data)
sparse_test = vectorizer_count.transform(gen_corpus_test)
X_test = vectorizer_tfidf.transform(sparse_test)



  0%|          | 0/11314 [00:00<?, ?it/s]

  0%|          | 0/7532 [00:00<?, ?it/s]

In [3]:
binarizer = LabelBinarizer()

y_train = binarizer.fit_transform(newsgroups_train.target)
y_test = binarizer.transform(newsgroups_test.target)

y_train.shape, y_test.shape

((11314, 20), (7532, 20))

In [4]:
class_names

['alt.atheism',
 'comp.graphics',
 'comp.os.ms-windows.misc',
 'comp.sys.ibm.pc.hardware',
 'comp.sys.mac.hardware',
 'comp.windows.x',
 'misc.forsale',
 'rec.autos',
 'rec.motorcycles',
 'rec.sport.baseball',
 'rec.sport.hockey',
 'sci.crypt',
 'sci.electronics',
 'sci.med',
 'sci.space',
 'soc.religion.christian',
 'talk.politics.guns',
 'talk.politics.mideast',
 'talk.politics.misc',
 'talk.religion.misc']

In [5]:
def eval_metrics(y_true, y_pred):
    f1 = f1_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred)
    recall = recall_score(y_true, y_pred)
    return f1, precision, recall


results = []
inverse_idx = None
feature_names = vectorizer_count.get_feature_names_out()


for class_idx, class_name in notebook.tqdm(enumerate(class_names)):
    clf = BinaryFCAClassifier()
    clf.fit(X_train, y_train[:, class_idx], inverse_idx=inverse_idx)

    formula = clf.get_formula(feature_names=feature_names)

    y_pred_train = clf.predict(X_train)
    f1_train, precision_train, recall_train = eval_metrics(y_train[:, class_idx], y_pred_train)

    y_pred_test = clf.predict(X_test)
    f1_test, precision_test, recall_test = eval_metrics(y_test[:, class_idx], y_pred_test)
    
    results.append({'class_name': class_name,
                    'formula': formula,
                    'f1 train': f1_train,
                    'f1 test': f1_test,
                    'precision train': precision_train,
                    'recall train': recall_train,
                    'precision test': precision_test,
                    'recall test': recall_test})

    inverse_idx = clf._inverse_idx

0it [00:00, ?it/s]

In [6]:
results_df = pd.DataFrame(results)

In [7]:
float_columns = ['f1 train', 'f1 test', 'precision train','recall train', 'precision test', 'recall test']

results_df[float_columns] = results_df[float_columns].apply(lambda x: round(x, 4))
results_df['formula'] = results_df['formula'].apply(format_formula_as_str)

In [8]:
from IPython.display import display, HTML
display(HTML( results_df.to_html().replace("\\n", "<br>") ))


Unnamed: 0,class_name,formula,f1 train,f1 test,precision train,recall train,precision test,recall test
0,alt.atheism,atheists || keith && writes || keith || atheism,0.548,0.4555,0.5757,0.5229,0.5267,0.4013
1,comp.graphics,graphics || image && lines && from || 3d && lines && from,0.4695,0.4726,0.4276,0.5205,0.4251,0.5321
2,comp.os.ms-windows.misc,windows,0.5822,0.5742,0.5089,0.6802,0.4981,0.6777
3,comp.sys.ibm.pc.hardware,controller && organization && lines || bus && organization && lines || card && organization && lines || card && lines || ide,0.3924,0.3794,0.337,0.4695,0.3437,0.4235
4,comp.sys.mac.hardware,mac && lines && from || apple && from && subject || quadra && from || centris && from || simms && lines && from || powerbook && lines && from,0.612,0.5451,0.5342,0.7163,0.4569,0.6753
5,comp.windows.x,window && organization && from || window || motif && subject,0.493,0.4348,0.5549,0.4435,0.5331,0.3671
6,misc.forsale,sale || shipping && from,0.682,0.7134,0.6803,0.6838,0.7218,0.7051
7,rec.autos,car && lines && subject || cars && the && lines,0.6219,0.6088,0.5802,0.67,0.6065,0.6111
8,rec.motorcycles,dod || bike && from || ride && lines || bikes && subject || motorcycle && from && subject,0.8124,0.796,0.793,0.8328,0.798,0.794
9,rec.sport.baseball,baseball || team && edu || players && edu || season,0.5374,0.527,0.489,0.5963,0.4852,0.5768
