In [6]:
# train.py — Resume Classifier Training (AR + EN)

#pip install -r requirements.txt


import os, json, warnings
from typing import List, Dict, Tuple, Optional
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
from wordcloud import WordCloud
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support,
    confusion_matrix, ConfusionMatrixDisplay
)

import arabic_reshaper
from bidi.algorithm import get_display

import torch
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer

# ===== Simple constants =====
# Dataset locations: edit these to your own path/URL.
EN_CSV_PATH = "/Users/reemabalharith/Desktop/Capston Project/english_cv_dataset_full_6800_enriched.csv"
AR_CSV_PATH = "/Users/reemabalharith/Desktop/Capston Project/arabic_cv_dataset_full_6800_enriched.csv"

# Pretrained checkpoints fine-tuned for each language.
EN_MODEL_NAME = "distilbert-base-uncased"
AR_MODEL_NAME = "aubmindlab/bert-base-arabertv02"

# Output directories for each language's model, plots and metrics.
OUT_EN = "llm_model_en"
OUT_AR = "llm_model_ar"

# Training hyperparameters.
EPOCHS = 3
BATCH_SIZE = 16
LR = 2e-5
MAX_LENGTH = 256
SEED = 42

# Font file used for Arabic plot text (only used when the file exists).
AR_FONT = "Amiri-Regular.ttf"

# ===== Common column names =====
# Candidate column names, tried in order; the Arabic lists extend the
# English ones with Arabic headers.
TEXT_CAND_EN = ["Resume", "Cleaned_Resume", "summary", "Text", "text"]
TEXT_CAND_AR = [*TEXT_CAND_EN, "النص", "السيرة"]
LABEL_CAND_EN = ["Category", "category", "Label", "label"]
LABEL_CAND_AR = [*LABEL_CAND_EN, "التصنيف"]

# ===== Short helper functions =====
def ar_shape(s: str) -> str:
    """Prepare text for display in plots that lack Arabic text layout.

    Non-string input is first converted with ``str``. The text is then passed
    through ``arabic_reshaper.reshape`` followed by ``bidi``'s ``get_display``,
    and the result of the latter is returned.
    """
    text = s if isinstance(s, str) else str(s)
    reshaped = arabic_reshaper.reshape(text)
    return get_display(reshaped)

def pick_col(df: pd.DataFrame, cands: List[str]) -> str:
    """Return the first column of *df* that matches one of *cands*.

    Matching is done in two passes, each in candidate order:

    1. an exact match against the column labels;
    2. a case-insensitive match (the original column label is returned).

    Column labels that are not strings (e.g. integer headers) are ignored in
    the case-insensitive pass instead of crashing on ``.lower()``.

    Args:
        df: Frame whose columns are searched.
        cands: Candidate column names, most preferred first.

    Returns:
        The matching column label exactly as it appears in ``df.columns``.

    Raises:
        ValueError: If no candidate matches any column.
    """
    for cand in cands:
        if cand in df.columns:
            return cand

    # Only string labels can take part in a case-insensitive comparison.
    # When several labels share a lowercase form, the last one wins.
    by_lower = {col.lower(): col for col in df.columns if isinstance(col, str)}
    for cand in cands:
        match = by_lower.get(cand.lower())
        if match is not None:
            return match

    raise ValueError(
        f"Missing any of {cands}; available columns: {list(df.columns)}"
    )

def ensure_dir(p: str):
    """Create directory *p* (and any missing parents); no-op if it exists."""
    os.makedirs(name=p, exist_ok=True)

# ===== Plots =====
def plot_top7_and_pie(df: pd.DataFrame, label_col: str, out_prefix: str, lang: str, ar_font: Optional[str]=None):
    """Save a bar chart and a pie chart of the seven most frequent labels.

    Writes ``{out_prefix}_top7_bar.png`` and ``{out_prefix}_top7_pie.png``.

    Args:
        df: Frame containing the label column.
        label_col: Name of the label column to count.
        out_prefix: Path prefix for the two output images.
        lang: ``'ar'`` selects Arabic titles and shaped labels (only when
            *ar_font* exists on disk); anything else uses English titles.
        ar_font: Path to a font file used for Arabic text.
    """
    counts = df[label_col].value_counts().head(7)
    labels = counts.index.astype(str).tolist()
    vals = counts.values

    if lang == 'ar' and ar_font and os.path.exists(ar_font):
        # rcParams font lists hold font *family names*, not file paths, so the
        # file has to be registered first and then selected by its family
        # name; assigning the path directly leaves the font unused.
        from matplotlib import font_manager
        font_manager.fontManager.addfont(ar_font)
        family = font_manager.FontProperties(fname=ar_font).get_name()
        fallbacks = [f for f in matplotlib.rcParams['font.sans-serif'] if f != family]
        # NOTE: this is a process-wide setting, as in the original code.
        matplotlib.rcParams['font.sans-serif'] = [family] + fallbacks

        labels_plot = [ar_shape(x) for x in labels]
        title_bar = ar_shape('أكثر ٧ تصنيفات')
        title_pie = ar_shape('أكثر ٧ تصنيفات (دائري)')
        ylabel = ar_shape('العدد')
    else:
        labels_plot = labels
        title_bar = 'Top-7 Categories'
        title_pie = 'Top-7 Categories (Pie)'
        ylabel = 'Count'

    # Bar chart.
    fig, ax = plt.subplots(figsize=(7, 4))
    ax.bar(labels_plot, vals)
    ax.set_title(title_bar)
    ax.set_ylabel(ylabel)
    ax.tick_params(axis='x', rotation=30)
    fig.tight_layout()
    fig.savefig(f"{out_prefix}_top7_bar.png", dpi=150)
    plt.close(fig)

    # Pie chart.
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.pie(vals, labels=labels_plot, autopct='%1.1f%%', startangle=140)
    ax.set_title(title_pie)
    fig.tight_layout()
    fig.savefig(f"{out_prefix}_top7_pie.png", dpi=150)
    plt.close(fig)

def plot_class_imbalance(df: pd.DataFrame, label_col: str, out_prefix: str, lang: str, ar_font: Optional[str]=None):
    """Save a bar chart of the sample count for every label.

    Writes ``{out_prefix}_class_imbalance.png``.

    Args:
        df: Frame containing the label column.
        label_col: Name of the label column to count.
        out_prefix: Path prefix for the output image.
        lang: ``'ar'`` selects an Arabic title and shaped labels (only when
            *ar_font* exists on disk); anything else uses English text.
        ar_font: Path to a font file used for Arabic text.
    """
    counts = df[label_col].value_counts()
    labels = counts.index.astype(str).tolist()

    if lang == 'ar' and ar_font and os.path.exists(ar_font):
        # rcParams font lists hold font *family names*, not file paths, so the
        # file has to be registered first and then selected by its family
        # name; assigning the path directly leaves the font unused.
        from matplotlib import font_manager
        font_manager.fontManager.addfont(ar_font)
        family = font_manager.FontProperties(fname=ar_font).get_name()
        fallbacks = [f for f in matplotlib.rcParams['font.sans-serif'] if f != family]
        # NOTE: this is a process-wide setting, as in the original code.
        matplotlib.rcParams['font.sans-serif'] = [family] + fallbacks

        labels_plot = [ar_shape(x) for x in labels]
        title = ar_shape('عدم توازن الفئات')
        ylabel = ar_shape('العدد')
    else:
        labels_plot = labels
        title = 'Class Imbalance'
        ylabel = 'Count'

    fig, ax = plt.subplots(figsize=(7, 4))
    ax.bar(labels_plot, counts.values)
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.tick_params(axis='x', rotation=30)
    fig.tight_layout()
    fig.savefig(f"{out_prefix}_class_imbalance.png", dpi=150)
    plt.close(fig)

def plot_wordcloud_en(df: pd.DataFrame, text_col: str, out_file: str):
    """Render a word cloud of all non-null values in *text_col* to *out_file*.

    The values are stringified and joined with single spaces before being
    handed to ``WordCloud.generate``.
    """
    documents = df[text_col].dropna().tolist()
    corpus = ' '.join(str(doc) for doc in documents)
    cloud = WordCloud(width=1000, height=500, background_color='white').generate(corpus)

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.imshow(cloud)
    ax.axis('off')
    ax.set_title('English Word Cloud')
    fig.tight_layout()
    fig.savefig(out_file, dpi=150)
    plt.close(fig)

def plot_confusion(y_true: List[int], y_pred: List[int], id2label: Dict[int,str], out_file: str, lang: str, ar_font: Optional[str]=None):
    """Save a confusion-matrix image for the given predictions.

    Args:
        y_true: True class ids.
        y_pred: Predicted class ids.
        id2label: Mapping from class id to display name; its sorted keys
            define the row/column order of the matrix.
        out_file: Path of the image to write.
        lang: ``'ar'`` switches the title to the AR variant and, when
            *ar_font* exists on disk, shapes the tick labels for Arabic.
        ar_font: Path to a font file used for Arabic text.
    """
    labels_sorted = sorted(id2label.keys())
    cm = confusion_matrix(y_true, y_pred, labels=labels_sorted)
    disp_labels = [id2label[i] for i in labels_sorted]

    if lang == 'ar' and ar_font and os.path.exists(ar_font):
        # rcParams font lists hold font *family names*, not file paths, so the
        # file has to be registered first and then selected by its family
        # name; assigning the path directly leaves the font unused.
        from matplotlib import font_manager
        font_manager.fontManager.addfont(ar_font)
        family = font_manager.FontProperties(fname=ar_font).get_name()
        fallbacks = [f for f in matplotlib.rcParams['font.sans-serif'] if f != family]
        # NOTE: this is a process-wide setting, as in the original code.
        matplotlib.rcParams['font.sans-serif'] = [family] + fallbacks
        disp_labels = [ar_shape(x) for x in disp_labels]

    fig, ax = plt.subplots(figsize=(6, 6))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=disp_labels)
    disp.plot(ax=ax, xticks_rotation=45, cmap='Blues', colorbar=False)
    ax.set_title('Confusion Matrix (AR)' if lang == 'ar' else 'Confusion Matrix (EN)')
    fig.tight_layout()
    fig.savefig(out_file, dpi=160)
    plt.close(fig)

def plot_accuracy_comparison(acc_map: Dict[str, float], out_file: str):
    """Save a bar chart comparing accuracies; does nothing for an empty map.

    Args:
        acc_map: Mapping of display name to accuracy; bars follow the
            mapping's iteration order and the y-axis is fixed to [0, 1].
        out_file: Path of the image to write.
    """
    if not acc_map:
        return

    names = list(acc_map)
    vals = list(acc_map.values())

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.bar(names, vals)
    ax.set_ylim(0, 1.0)
    ax.set_ylabel('Accuracy')
    ax.set_title('Model Accuracy Comparison')
    ax.tick_params(axis='x', rotation=20)
    fig.tight_layout()
    fig.savefig(out_file, dpi=150)
    plt.close(fig)

# ===== Data preparation =====
def factorize_labels(series: pd.Series) -> Tuple[np.ndarray, Dict[int,str]]:
    """Encode labels as integer ids.

    Values are converted to strings first; ids are assigned in the sorted
    order of the distinct strings.

    Returns:
        A pair ``(codes, id2label)``: the integer id of each row, and the
        mapping from id back to the label string.
    """
    codes, classes = pd.factorize(series.astype(str), sort=True)
    id2label = {}
    for idx, name in enumerate(classes.tolist()):
        id2label[idx] = str(name)
    return codes.astype(int), id2label

def make_hf_dataset(texts: List[str], labels: List[int]) -> Dataset:
    """Build a ``Dataset`` with a ``text`` column and a ``labels`` column."""
    # Use the 'labels' key (compatible with both old and new versions).
    columns = {"text": texts, "labels": labels}
    return Dataset.from_dict(columns)

def tokenize_batch(batch, tokenizer, max_length: int):
    """Tokenize ``batch["text"]`` with truncation and padding enabled.

    Args:
        batch: Mapping with a ``"text"`` entry holding the texts.
        tokenizer: Callable tokenizer; receives the texts plus the
            ``truncation``, ``padding`` and ``max_length`` keyword arguments.
        max_length: Maximum sequence length passed to the tokenizer.

    Returns:
        Whatever the tokenizer returns for the batch.
    """
    options = {"truncation": True, "padding": True, "max_length": max_length}
    return tokenizer(batch["text"], **options)

# ===== Training for one language =====
def train_one_language(csv_path: str, lang: str, model_name: str, output_dir: str, max_length=256, batch_size=16, epochs=3, lr=2e-5, seed=42, ar_font: Optional[str]=None) -> Dict[str,float]:
    """Fine-tune a sequence classifier on one language's resume dataset.

    Steps: load the CSV, pick the text/label columns, save exploratory plots,
    split 70/15/15 (train/validation/test, stratified by label), fine-tune
    *model_name*, evaluate on the test split, and save the model, tokenizer,
    label map, metrics and confusion matrix under *output_dir*.

    Args:
        csv_path: Path (or URL) of the CSV file to read.
        lang: ``'ar'`` or ``'en'``; selects column candidates, plot language
            and output file names.
        model_name: Pretrained checkpoint passed to ``from_pretrained``.
        output_dir: Directory that receives all artifacts (created if needed).
        max_length: Maximum token length used when tokenizing.
        batch_size: Per-device batch size for training and evaluation.
        epochs: Number of training epochs.
        lr: Learning rate.
        seed: Seed for the data splits, the classifier-head initialization
            and the ``Trainer``.
        ar_font: Path to a font file used for Arabic plot text.

    Returns:
        ``{f"{lang}_accuracy": <test accuracy>}`` (``0.0`` if the metric is
        missing from the evaluation output).

    Raises:
        ValueError: If no text or label column can be found (from
            ``pick_col``).
    """
    df = pd.read_csv(csv_path)
    text_col = pick_col(df, TEXT_CAND_AR if lang == 'ar' else TEXT_CAND_EN)
    label_col = pick_col(df, LABEL_CAND_AR if lang == 'ar' else LABEL_CAND_EN)

    # Work on an explicit copy so the column assignment below does not write
    # into a slice of the original frame.
    df = df[[text_col, label_col]].dropna().copy()
    df[text_col] = df[text_col].astype(str).str.strip()
    df = df[df[text_col].str.len() > 0]

    ensure_dir(output_dir)
    prefix = os.path.join(output_dir, f"{lang}")
    plot_top7_and_pie(df, label_col, prefix, lang, ar_font)
    plot_class_imbalance(df, label_col, prefix, lang, ar_font)
    if lang == 'en':
        # The word cloud is best-effort: report a failure instead of hiding
        # it, and let KeyboardInterrupt/SystemExit propagate. print() is used
        # because this module silences the warnings machinery globally.
        try:
            plot_wordcloud_en(df, text_col, os.path.join(output_dir, "wordcloud_en.png"))
        except Exception as exc:
            print(f"[warn] word cloud skipped: {exc!r}")

    y, id2label = factorize_labels(df[label_col])
    # 70% train, then the remaining 30% is halved into validation and test.
    X_train, X_tmp, y_train, y_tmp = train_test_split(
        df[text_col].tolist(), y.tolist(),
        test_size=0.3, random_state=seed, stratify=y,
    )
    X_val, X_test, y_val, y_test = train_test_split(
        X_tmp, y_tmp, test_size=0.5, random_state=seed, stratify=y_tmp,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Seed torch before the model is built: the classification head is
    # randomly initialized here, i.e. before the Trainer applies its own seed.
    torch.manual_seed(seed)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=len(id2label),
        id2label=id2label,
        label2id={v: k for k, v in id2label.items()},
    )

    def _tokenize(batch):
        return tokenize_batch(batch, tokenizer, max_length)

    ds_train = make_hf_dataset(X_train, y_train).map(_tokenize, batched=True)
    ds_val = make_hf_dataset(X_val, y_val).map(_tokenize, batched=True)
    ds_test = make_hf_dataset(X_test, y_test).map(_tokenize, batched=True)

    # Collator compatible with older library versions.
    def simple_collate(features):
        return tokenizer.pad(features, return_tensors="pt")

    def compute_metrics(p):
        # Weighted precision/recall/F1 plus plain accuracy on argmax predictions.
        y_hat = np.argmax(p.predictions, axis=1)
        precision, recall, f1, _ = precision_recall_fscore_support(
            p.label_ids, y_hat, average='weighted', zero_division=0,
        )
        return {
            "accuracy": accuracy_score(p.label_ids, y_hat),
            "precision": precision,
            "recall": recall,
            "f1": f1,
        }

    training_args = TrainingArguments(
        output_dir=os.path.join(output_dir, "hf_runs"),
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        learning_rate=lr,
        num_train_epochs=epochs,
        weight_decay=0.01,
        logging_steps=50,
        seed=seed,  # honor the caller's seed instead of the library default
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=ds_train,
        eval_dataset=ds_val,
        tokenizer=tokenizer,
        data_collator=simple_collate,
        compute_metrics=compute_metrics,
    )

    trainer.train()
    eval_test = trainer.evaluate(ds_test)
    # Keep only plain numeric entries so the metrics are JSON-serializable.
    metrics = {k: float(v) for k, v in eval_test.items() if isinstance(v, (int, float, np.floating))}

    preds = np.argmax(trainer.predict(ds_test).predictions, axis=1)
    plot_confusion(y_test, preds, id2label, os.path.join(output_dir, f"cm_{lang}.png"), lang, ar_font)

    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    with open(os.path.join(output_dir, "labels.json"), "w", encoding="utf-8") as f:
        json.dump(id2label, f, ensure_ascii=False, indent=2)
    with open(os.path.join(output_dir, f"metrics_{lang}.json"), "w", encoding="utf-8") as f:
        json.dump(metrics, f, ensure_ascii=False, indent=2)

    return {f"{lang}_accuracy": metrics.get('eval_accuracy', 0.0)}

# ===== Main =====
def main():
    """Train the English and Arabic classifiers and compare their accuracy.

    A language is skipped when its CSV path constant is empty. If at least
    one language was trained, writes ``accuracy_comparison.png`` and
    ``accuracy_comparison.json`` to the working directory and prints the
    accuracies.
    """
    np.random.seed(SEED)

    # (csv path, language code, checkpoint, output dir, Arabic font or None)
    runs = (
        (EN_CSV_PATH, 'en', EN_MODEL_NAME, OUT_EN, None),
        (AR_CSV_PATH, 'ar', AR_MODEL_NAME, OUT_AR, AR_FONT),
    )

    accs = {}
    for csv_path, lang, model_name, out_dir, font in runs:
        if not csv_path:
            continue
        ensure_dir(out_dir)
        result = train_one_language(
            csv_path, lang, model_name, out_dir,
            MAX_LENGTH, BATCH_SIZE, EPOCHS, LR, SEED, font,
        )
        accs.update(result)

    if not accs:
        return

    plot_accuracy_comparison(accs, 'accuracy_comparison.png')
    with open('accuracy_comparison.json', 'w', encoding='utf-8') as f:
        json.dump(accs, f, ensure_ascii=False, indent=2)
    print("Accuracy:", accs)


if __name__ == '__main__':
    main()



Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Map: 100%|████████████████████████| 4760/4760 [00:00<00:00, 33282.30 examples/s]
Map: 100%|████████████████████████| 1020/1020 [00:00<00:00, 27706.69 examples/s]
Map: 100%|████████████████████████| 1020/1020 [00:00<00:00, 31428.16 examples/s]
You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss
50,5.0565
100,4.9171
150,4.6447
200,4.3021
250,3.9495
300,3.6799
350,3.4221
400,3.201
450,3.013
500,2.8845


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at aubmindlab/bert-base-arabertv02 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Map: 100%|████████████████████████| 4760/4760 [00:00<00:00, 32134.41 examples/s]
Map: 100%|████████████████████████| 1020/1020 [00:00<00:00, 32650.71 examples/s]
Map: 100%|████████████████████████| 1020/1020 [00:00<00:00, 34141.12 examples/s]
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss
50,5.1344
100,4.8736
150,4.4814
200,4.019
250,3.5525
300,3.1206
350,2.7293
400,2.4374
450,2.2558
500,2.0016


Accuracy: {'en_accuracy': 0.9833333333333333, 'ar_accuracy': 1.0}
