In [12]:
from pathlib import Path
import duckdb
from datetime import datetime
from transformers import (
    AutoTokenizer,
)
from core_pro.ultilities import make_sync_folder, update_df
import os
import sys

# Make the sibling `model_train` checkout importable for the `src.model_train.*`
# imports below. Set MODEL_TRAIN_ROOT to point somewhere else; when it is unset
# (or empty) the original ~/PycharmProjects/model_train location is used.
model_train_root = str(Path(os.environ.get('MODEL_TRAIN_ROOT') or Path.home() / 'PycharmProjects/model_train'))
# Guard so re-running this cell does not pile up duplicate sys.path entries.
if model_train_root not in sys.path:
    sys.path.append(model_train_root)

from src.model_train.data_loading import TrainDistribution
from src.model_train.pipeline_train import Pipeline
from src.model_train.func import training_report

In [2]:
# Sync folder holding the category-tag dataset (resolved by core_pro).
path = make_sync_folder('dataset/category_tag')

# The parquet path is inlined into a SQL string literal, so double any single
# quote (SQL literal escaping) — otherwise a path like "/home/o'neil/..." would
# break or alter the query.
parquet_file = str(path / 'clean.parquet').replace("'", "''")
query = f"""
select *
from read_parquet('{parquet_file}')
"""
# Target column; values look like "Women Bags >> Top-handle Bags" (level1 >> level2).
label = 'label'
df = duckdb.sql(query).pl()  # materialize as a polars DataFrame
df.head()

item_id,item_name,level1_global_be_category,level2_global_be_category,total_item,total_shop,label,word
i64,str,str,str,i64,i64,str,str
25681560302,"""túi xách hình bánh cute, túi t…","""Women Bags""","""Top-handle Bags""",1,1,"""Women Bags >> Top-handle Bags""","""túi_xách hình bánh cute , túi …"
23413088890,"""Set nước hoa Juliette Not A Pe…","""Beauty""","""Perfumes & Fragrances""",1,1,"""Beauty >> Perfumes & Fragrance…","""Set nước_hoa Juliette Not A_Pe…"
29268958644,"""Tất dài qua đầu gối, tất cotto…","""Women Clothes""","""Socks & Stockings""",1,1,"""Women Clothes >> Socks & Stock…","""Tất_dài qua đầu_gối , tất cott…"
28608703297,"""1kg tim gà CP""","""Food & Beverages""","""Fresh & Frozen Food""",1,1,"""Food & Beverages >> Fresh & Fr…","""1kg tim gà CP"""
1312377872,"""Set nguyên liệu làm Vỏ bánh dẻ…","""Food & Beverages""","""Baking Needs""",1,1,"""Food & Beverages >> Baking Nee…","""Set nguyên_liệu làm Vỏ bánh_dẻ…"


In [3]:
# polars' Series.unique() does not guarantee output order by default, so sort the
# labels: the id2label / label2id maps exposed by TrainDistribution (presumably
# built from label_list — confirm) are then identical on every run.
label_list = df[label].unique().sort().to_list()
dist_check = TrainDistribution(df, col_label=label, col_item='word', label_list=label_list)
train, val, test = dist_check.split_train_valid_test(test_size=.2)
# Optional diagnostic, disabled by default:
# dist_check.check_distribution()

In [4]:
# Columns kept when converting the splits: item id, word-segmented text, target.
col = ['item_id', 'word', label]
# Return value is not needed here; assigning to `_` keeps the cell output quiet.
_ = dist_check.df_to_dataset(
    col,
    show_index=3,
)

In [5]:
# Pretrained Vietnamese encoder checkpoint. The same name is reused for the
# Pipeline, the run-folder name and the report sheet name further down.
pretrain_name = 'bkai-foundation-models/vietnamese-bi-encoder'
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=pretrain_name)

In [6]:
# Tokenize every split; later cells index the result by 'train', 'valid', 'test'.
dict_train = dist_check.ds_tokenize(
    tokenizer,
    show_index=1,
)

Map:   0%|          | 0/3187420 [00:00<?, ? examples/s]

Map:   0%|          | 0/796856 [00:00<?, ? examples/s]

Map:   0%|          | 0/996069 [00:00<?, ? examples/s]

In [7]:
# Fine-tuning wrapper from the model_train project; the label maps come from the
# TrainDistribution split so class ids stay consistent with the datasets.
pipe = Pipeline(
    pretrain_name=pretrain_name,
    id2label=dist_check.id2label,
    label2id=dist_check.label2id,
    bf16=True,
    flash_attention_2=False,
    hub_model_id="kevinkhang2909/l2_category"
)

# One timestamped output folder per run, under the dataset sync folder.
# The short model name is computed outside the f-string: reusing the same quote
# type inside an f-string replacement field only parses on Python >= 3.12 (PEP 701).
time_now = datetime.now().strftime("%Y%m%d%H%M%S")
model_short_name = pretrain_name.split('/')[-1]
folder = path / f'model_multi_classes/{model_short_name}/{time_now}'
config = dict(
    log_step=5000,  # eval rows in the training log appear every 5000 steps
    num_train_epochs=5,
    learning_rate=1e-4,
)
trainer = pipe.train(
    folder=folder,
    train=dict_train['train'],
    val=dict_train['valid'],
    **config
)

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at bkai-foundation-models/vietnamese-bi-encoder and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss,Validation Loss,F1,Accuracy
5000,2.951,1.459079,0.733395,0.733395
10000,1.4286,1.283615,0.765118,0.765118
15000,1.3381,1.255377,0.772217,0.772217
20000,1.3115,1.248705,0.774367,0.774367
25000,1.3034,1.246454,0.774855,0.774855
30000,1.3015,1.246419,0.774865,0.774865


***** train metrics *****
  epoch                    =         5.0
  total_flos               = 382282625GF
  train_loss               =      1.5945
  train_runtime            =  1:51:07.64
  train_samples_per_second =    2390.214
  train_steps_per_second   =       4.669


In [11]:
# Score the held-out *test* split (the variable name notwithstanding).
valid_result = trainer.predict(dict_train['test'])
scores = valid_result.predictions  # per-class model scores, presumably logits
y_true = valid_result.label_ids
y_pred = scores.argmax(-1)  # highest-scoring class id per example

df_report = training_report(
    y_true=y_true,
    y_pred=y_pred,
    id2label=dist_check.id2label,
)

                                                             precision    recall  f1-score   support

                         Fashion Accessories >> Hats & Caps       0.89      0.93      0.91      5668
                                        Men Clothes >> Sets       0.85      0.92      0.89      4818
                                       Stationery >> Others       0.31      0.12      0.18      4966
                         Home Appliances >> Remote Controls       0.85      0.91      0.88      5221
                                    Women Clothes >> Shorts       0.85      0.80      0.82      5630
                          Women Clothes >> Traditional Wear       0.87      0.84      0.85      5788
          Fashion Accessories >> Investment Precious Metals       1.00      0.00      0.00       124
                                     Audio >> Media Players       0.84      0.70      0.76      1747
                                 Beauty >> Bath & Body Care       0.78      0.84      0.81

In [13]:
# Destination spreadsheet id passed to core_pro's update_df (presumably a
# Google Sheets id — confirm).
sh = '1L-4z-SrAWXee-ScQ9dZCEVcPrUGaJUv5U4hw_jiDeOI'
# Take the last path segment, matching the run-folder naming used for training;
# `[-1]` (not `[1]`) also works for names without '/' or with nested segments.
update_df(df_report, pretrain_name.split('/')[-1], sh)