<a href="https://colab.research.google.com/github/shitkov/categorizer/blob/main/tag3_catboost_gridsearch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture
!wget https://github.com/shitkov/categorizer/raw/main/data_split.zip
!unzip /content/data_split.zip
!pip install catboost

In [2]:
import pandas as pd
from catboost import Pool, CatBoostClassifier
from sklearn.metrics import f1_score, precision_score, recall_score

In [3]:
data_train = pd.read_csv('/content/data_train.csv').fillna('')
data_eval = pd.read_csv('/content/data_valid.csv').fillna('')
data_test = pd.read_csv('/content/data_test.csv').fillna('')

In [4]:
target_columns = ['target', 'tag_0', 'tag_1', 'tag_2', 'tag_3', 'tag_4', 'tag_5', 'tag_6', 'tag_7', 'tag_8']

In [5]:
labels_train = list(data_train['tag_3'])
labels_eval = list(data_eval['tag_3'])
labels_test = list(data_test['tag_3'])

In [6]:
data_train = data_train.drop(columns=target_columns)
data_eval = data_eval.drop(columns=target_columns)
data_test = data_test.drop(columns=target_columns)

In [7]:
def get_equal(data):
    equal_list = []
    positive_texts = list(data['clean_positive'])
    negative_texts = list(data['clean_negative'])
    for positive, negative in zip(positive_texts, negative_texts):
        if positive == negative:
            equal_list.append('yes')
        else:
            equal_list.append('no')
    data['equal'] = equal_list
    return data

In [8]:
data_train = get_equal(data_train)
data_eval = get_equal(data_eval)
data_test = get_equal(data_test)

In [9]:
cat_features = [
        'position',
        'sentiment_positive_label',
        'sentiment_negative_label',
        'emotion_positive_label',
        'emotion_negative_label',
        'toxic_positive_label',
        'toxic_negative_label',
        'equal'
       ]

In [10]:
text_features = [
                 'clean_positive',
                 'clean_negative'
]

In [11]:
columns2drop = [
                'city',
                'positive',
                'negative', 
                'clean_positive',
                'clean_negative'
]

In [12]:
dataset_train = Pool(
    data_train.drop(columns=columns2drop),
    labels_train,
    cat_features=cat_features,
    # text_features=text_features
    )

In [13]:
dataset_eval = Pool(
    data_eval.drop(columns=columns2drop),
    labels_eval,
    cat_features=cat_features,
    # text_features=text_features
    )

In [None]:
N=1000
weights_list = [1, 2, 3, 5, 10, 20, 30, 50, 100]
metric_list = ['F1', 'TotalF1', 'CrossEntropy', 'BalancedAccuracy']
gridsearch = []
for w in weights_list:
    for metric in metric_list:
        model = CatBoostClassifier(
            iterations=N,
            eval_metric=metric,
            task_type='GPU',
            use_best_model=True,
            silent=True,
            class_weights=[1, w]
            )
        model.fit(dataset_train, eval_set=dataset_eval)
        predicted = model.predict(data_test.drop(columns=columns2drop))
        ans = {
            'metric': metric,
            'weight': w,
            'f1': f1_score(predicted, labels_test, average=None),
            'precision': precision_score(predicted, labels_test, average=None),
            'recall': recall_score(predicted, labels_test, average=None)
        }
        gridsearch.append(ans)

In [None]:
model = CatBoostClassifier(
    iterations=10000,
    eval_metric='CrossEntropy',
    task_type='GPU',
    use_best_model=True,
    silent=True,
    class_weights=[1, 2]
    )
model.fit(dataset_train, eval_set=dataset_eval)
predicted = model.predict(data_test.drop(columns=columns2drop))

In [None]:
ans = {
    'metric': 'CrossEntropy',
    'weight': 2,
    'f1': f1_score(predicted, labels_test, average=None),
    'precision': precision_score(predicted, labels_test, average=None),
    'recall': recall_score(predicted, labels_test, average=None)
}

In [21]:
ans

{'f1': array([0.99400816, 0.78923767]),
 'metric': 'CrossEntropy',
 'precision': array([0.99362814, 0.8       ]),
 'recall': array([0.99438847, 0.77876106]),
 'weight': 2}