In [None]:
# =====================================================
# FAKE NEWS DETECTION – ALBERT-base x FAKENEWSNET
# =====================================================

# 1. INSTALL & IMPORT (CÀI ĐẶT & IMPORT)
!pip install -q transformers datasets torch scikit-learn pandas numpy psutil accelerate sentencepiece

import os, re, shutil, psutil, warnings
import pandas as pd
import numpy as np
import torch
from datasets import load_dataset, Dataset, DatasetDict, concatenate_datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score
from sklearn.utils.class_weight import compute_class_weight
from transformers import (
    AutoTokenizer, AutoModelForSequenceClassification,
    Trainer, TrainingArguments, EarlyStoppingCallback,
    DataCollatorWithPadding
)
from google.colab import drive

# Silence every Python warning for the rest of the session (this also hides
# library deprecation notices).
warnings.filterwarnings("ignore")

# 2. GPU CHECK
if not torch.cuda.is_available():
    print("Device: CPU")
else:
    device_name = torch.cuda.get_device_name(0)
    vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"Device: {device_name} | VRAM: {vram_gb:.1f} GB")

# 3. MOUNT GOOGLE DRIVE & PREPARE THE OUTPUT FOLDER
# Checkpoints and the exported model are written under OUTPUT_DIR on Drive
# (presumably so they outlive the Colab VM).
drive.mount('/content/drive', force_remount=False)
OUTPUT_DIR = "/content/drive/MyDrive/FakeNewsNet_ALBERT_Pro"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# H√†m qu·∫£n l√Ω checkpoint (gi·ªØ ·ªï c·ª©ng kh√¥ng b·ªã ƒë·∫ßy)
def manage_checkpoints(output_dir, keep_latest=1):
    ckpts = [c for c in os.listdir(output_dir) if c.startswith("checkpoint-")]
    if len(ckpts) <= keep_latest: return
    # S·∫Øp x·∫øp theo th·ªùi gian s·ª≠a ƒë·ªïi
    ckpts_sorted = sorted(ckpts, key=lambda x: os.path.getmtime(os.path.join(output_dir, x)))
    for ck in ckpts_sorted[:-keep_latest]:
        shutil.rmtree(os.path.join(output_dir, ck), ignore_errors=True)
        print(f"üßπ ƒê√£ x√≥a checkpoint c≈©: {ck}")

# 4. LOAD THE FAKENEWSNET DATASET
print("\n‚è≥ ƒêang t·∫£i dataset FakeNewsNet...")
try:
    # Load both domains (GossipCop + PolitiFact) and stack them into one table.
    ds_gossip = load_dataset("rickstello/FakeNewsNet", "gossipcop", split="train")
    ds_politi = load_dataset("rickstello/FakeNewsNet", "politifact", split="train")
    dataset_full = concatenate_datasets([ds_gossip, ds_politi])
    # Dataset.to_pandas() converts the underlying table column-wise instead of
    # iterating the dataset row by row as pd.DataFrame(dataset) does.
    df = dataset_full.to_pandas()
except Exception as e:
    # Any failure loading/concatenating the per-domain configs falls back to
    # the repository's default configuration.
    print(f"‚ö†Ô∏è T·∫£i config th·∫•t b·∫°i ({e}), th·ª≠ t·∫£i default...")
    dataset = load_dataset("rickstello/FakeNewsNet", split="train")
    df = dataset.to_pandas()

print(f"T·ªïng s·ªë m·∫´u: {len(df)}")

# 5. COLUMN DETECTION & CONTENT ASSEMBLY (auto-detect columns)
# For each role, use the first candidate column name present in the data.
text_col = next((c for c in ['news_content', 'text', 'content', 'body'] if c in df.columns), None)
title_col = next((c for c in ['title', 'news_title', 'headline'] if c in df.columns), None)
label_col = next((c for c in ['real', 'label', 'class', 'fake'] if c in df.columns), None)

print(f"Mapping: Text='{text_col}' | Title='{title_col}' | Label='{label_col}'")

if not label_col: raise ValueError("‚ùå Kh√¥ng t√¨m th·∫•y c·ªôt nh√£n!")

# Normalise labels into an integer 'label' column following the convention
# the model is configured with later on: 0 = Fake, 1 = Real.
df = df.dropna(subset=[label_col])  # rows without a label cannot be trained on
raw_labels = df[label_col]
if not pd.api.types.is_numeric_dtype(raw_labels):
    # Textual labels (e.g. "fake"/"real"): map them explicitly rather than
    # letting string values reach the class-weight / loss computation.
    raw_labels = raw_labels.astype(str).str.strip().str.lower().map({"fake": 0, "real": 1})
    if raw_labels.isna().any():
        raise ValueError(
            f"Unrecognised values in label column '{label_col}'; expected 'fake'/'real'"
        )
elif label_col == 'fake':
    # A 'fake' flag column marks fake news with 1 -- the inverse of 0 = Fake / 1 = Real.
    raw_labels = 1 - raw_labels
df['label'] = raw_labels.astype(int)

# Join title and body. Both parts are coerced to str so missing or
# non-string cells cannot break the concatenation. The separator is only
# inserted when both parts are non-empty, so title-only rows (or a dataset
# without a body column) do not all end with a dangling separator.
blank = pd.Series('', index=df.index)
title_data = df[title_col].fillna('').astype(str) if title_col else blank
text_data = df[text_col].fillna('').astype(str) if text_col else blank
has_both = title_data.str.strip().ne('') & text_data.str.strip().ne('')
df['content'] = np.where(has_both, title_data + " [SEP] " + text_data, title_data + " " + text_data)

# Text normalisation helper.
def clean_text(s):
    """Return a lower-cased, ASCII-alphanumeric-only version of *s*.

    Processing order: lower-case; blank out ``http://``/``https://`` URLs
    (up to the next whitespace) and ``<...>`` tags that sit on one line;
    turn every character other than ``a-z``, ``0-9`` and whitespace
    (including accented letters and apostrophes) into a space; collapse
    whitespace runs to one space and trim the ends.
    Non-string input (e.g. ``NaN`` from a missing cell) yields ``""``.
    """
    if isinstance(s, str):
        out = s.lower()
        # URLs and tags must be removed before punctuation is stripped,
        # otherwise their delimiters would already be gone.
        for pattern in (r'https?://\S+', r'<.*?>', r'[^a-z0-9\s]'):
            out = re.sub(pattern, ' ', out)
        return ' '.join(out.split())
    return ""

print("üßπ ƒêang l√†m s·∫°ch vƒÉn b·∫£n...", end="")
df['content'] = df['content'].map(clean_text)
# Keep documents longer than 20 characters after cleaning, then drop exact
# duplicate texts (first occurrence wins).
long_enough = df['content'].str.len() > 20
df = df.loc[long_enough].drop_duplicates(subset=['content'])
print(f" ‚Üí Sau x·ª≠ l√Ω: {len(df):,}")

# Label distribution as fractions of the remaining rows.
label_share = df['label'].value_counts(normalize=True).to_dict()
print(f"Ph√¢n b·ªë nh√£n: {label_share}")

# Class weights for the weighted loss: scikit-learn's 'balanced' mode weights
# each class inversely to its frequency. np.unique returns the classes sorted.
classes = np.unique(df['label'])
class_weights = compute_class_weight('balanced', classes=classes, y=df['label'])
class_weight_dict = dict(zip(classes, map(float, class_weights)))
print("Class weights:", class_weight_dict)

# 6. SPLIT DATA: 80% train; the remaining 20% is halved into validation and
# test (10% / 10%). Both splits are stratified on the label with a fixed seed.
train_df, temp_df = train_test_split(df, test_size=0.2, random_state=42, stratify=df['label'])
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42, stratify=temp_df['label'])

# Only the model inputs are kept. reset_index(drop=True) discards the shuffled
# pandas index (otherwise from_pandas may store it as an extra column).
split_frames = {"train": train_df, "validation": val_df, "test": test_df}
dataset_dict = DatasetDict({
    split_name: Dataset.from_pandas(frame[['content', 'label']].reset_index(drop=True))
    for split_name, frame in split_frames.items()
})

# 7. TOKENIZER (ALBERT)
MODEL_NAME = "albert-base-v2"
MAX_SEQ_LEN = 384
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


def tokenize_fn(batch):
    """Tokenize a batch of ``content`` strings, truncated to MAX_SEQ_LEN tokens.

    No padding is applied here; DataCollatorWithPadding pads each batch
    when it is assembled.
    """
    return tokenizer(
        batch["content"],
        truncation=True,
        max_length=MAX_SEQ_LEN,
        padding=False,
    )


print("‚öôÔ∏è ƒêang tokenize...")
tokenized = dataset_dict.map(
    tokenize_fn,
    batched=True,
    batch_size=1000,
    remove_columns=['content'],  # raw text is no longer needed after tokenizing
)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# 8. MODEL SETUP
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)

# Label names stored in the model config: 0 = Fake, 1 = Real.
# NOTE(review): this assumes the source label column (e.g. 'real') uses
# 1 for real news and 0 for fake -- confirm for other datasets.
ID2LABEL = {0: "Fake", 1: "Real"}
model.config.id2label = dict(ID2LABEL)
model.config.label2id = {name: idx for idx, name in ID2LABEL.items()}

# 9. CUSTOM TRAINER (weighted loss)
class WeightedTrainer(Trainer):
    """Trainer that optimises class-weighted cross-entropy.

    Per-class weights are read from the module-level ``class_weight_dict``
    (class id -> weight).
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute weighted cross-entropy for one batch.

        Args:
            model: The (possibly wrapped) model being trained or evaluated.
            inputs: Batch dict from the data collator; must contain
                ``"labels"``, which is popped so the forward pass runs
                without labels.
            return_outputs: If True, return ``(loss, outputs)``.
            num_items_in_batch: Accepted for Trainer API compatibility; unused.

        Returns:
            The scalar loss, or ``(loss, outputs)`` when ``return_outputs``.
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        # Weight i must belong to class id i, so order the weights by class id
        # explicitly. Use the logits' device: a model wrapped in
        # nn.DataParallel (multi-GPU) has no ``.device`` attribute.
        w = torch.tensor(
            [class_weight_dict[c] for c in sorted(class_weight_dict)],
            dtype=torch.float32,
            device=logits.device,
        )
        loss_fct = torch.nn.CrossEntropyLoss(weight=w)
        # Cast to float32 so the logits' dtype matches the weight tensor.
        loss = loss_fct(logits.float(), labels)
        return (loss, outputs) if return_outputs else loss

# 10. METRICS
def compute_metrics(eval_pred):
    """Compute classification metrics from a Trainer evaluation pass.

    Args:
        eval_pred: ``(predictions, label_ids)`` from the Trainer, where
            predictions are per-class logits. A tuple/list whose first
            element is the logits is also accepted.

    Returns:
        dict with ``accuracy`` and support-weighted ``f1``/``precision``/
        ``recall`` (``f1`` is the key used for best-model selection in the
        training arguments), plus ``f1_macro``. For two-class logits with both
        classes present in the labels it also returns ``roc_auc``, computed
        on the softmax probability of class 1.
    """
    logits, labels = eval_pred
    if isinstance(logits, (tuple, list)):
        # Extra model outputs may be returned alongside the logits; logits come first.
        logits = logits[0]
    logits = np.asarray(logits, dtype=np.float64)
    labels = np.asarray(labels)
    preds = np.argmax(logits, axis=-1)

    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average="weighted", zero_division=0)
    # Macro F1 weights both classes equally, so the majority class cannot dominate it.
    _, _, f1_macro, _ = precision_recall_fscore_support(labels, preds, average="macro", zero_division=0)
    metrics = {
        "accuracy": accuracy_score(labels, preds),
        "f1": f1,
        "precision": precision,
        "recall": recall,
        "f1_macro": f1_macro,
    }

    # ROC-AUC needs scores rather than hard predictions, and is undefined
    # when only one class occurs in the labels.
    if logits.ndim == 2 and logits.shape[1] == 2 and np.unique(labels).size == 2:
        shifted = logits - logits.max(axis=1, keepdims=True)  # numerically stable softmax
        probs = np.exp(shifted)
        probs /= probs.sum(axis=1, keepdims=True)
        metrics["roc_auc"] = float(roc_auc_score(labels, probs[:, 1]))
    return metrics

# 11. TRAINING ARGUMENTS
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=1,
    learning_rate=2e-5,
    weight_decay=0.01,
    eval_strategy="epoch",           # evaluate once per epoch...
    save_strategy="epoch",           # ...and checkpoint on the same schedule (needed for load_best_model_at_end)
    save_total_limit=2,              # limit how many checkpoints accumulate on Drive
    load_best_model_at_end=True,
    metric_for_best_model="f1",      # the "f1" key returned by compute_metrics
    fp16=torch.cuda.is_available(),  # mixed precision only when a GPU is present
    report_to="none",
)

# Newer transformers releases renamed Trainer's ``tokenizer`` argument to
# ``processing_class`` (the old name is deprecated and slated for removal in
# v5). Pass the tokenizer under whichever name the installed version accepts.
import inspect
_tokenizer_kwarg = (
    "processing_class"
    if "processing_class" in inspect.signature(Trainer.__init__).parameters
    else "tokenizer"
)

trainer = WeightedTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["validation"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
    **{_tokenizer_kwarg: tokenizer},
)

# 12. RUN TRAINING
print("\nüöÄ B·∫ÆT ƒê·∫¶U HU·∫§N LUY·ªÜN ALBERT...")
trainer.train()

# 13. EVALUATE ON THE HELD-OUT TEST SPLIT & SAVE
print("\nüéØ ƒê√ÅNH GI√Å TR√äN TEST SET:")
results = trainer.evaluate(tokenized["test"])
print(results)

# Export the model and the tokenizer side by side into one final folder.
final_path = os.path.join(OUTPUT_DIR, "final_albert_fnn")
for save_fn in (trainer.save_model, tokenizer.save_pretrained):
    save_fn(final_path)

# Ask for zero "checkpoint-*" folders to be kept; the exported model lives in
# a differently named sub-folder and is not a checkpoint.
manage_checkpoints(OUTPUT_DIR, keep_latest=0)
print(f"\n‚úÖ ƒê√£ l∆∞u model t·∫°i: {final_path}")