In [None]:
import getpass
import os
import re
import time

import pandas as pd
import torch
import wandb
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from transformers import BertTokenizer, BertForSequenceClassification, BertConfig, Trainer, TrainingArguments, EarlyStoppingCallback

In [None]:
def clean_text(text):
    """Normalize raw text before tokenization.

    Steps, in order: lowercase, strip URLs, strip HTML-like tags, drop
    punctuation, drop digits, collapse runs of whitespace.

    Args:
        text: Value to clean. Non-string values are converted with ``str``.
            ``None`` and float NaN are treated as missing.

    Returns:
        str: The cleaned text, or ``""`` when the input is missing.
    """
    # A missing value would otherwise be stringified into the literal token
    # "none"/"nan" and be cleaned as if it were real text. `x != x` is true
    # only for NaN.
    if text is None or (isinstance(text, float) and text != text):
        return ""
    if not isinstance(text, str):
        text = str(text)

    text = text.lower()
    # "http\S+" already matches https URLs, so no separate alternative is needed.
    text = re.sub(r"http\S+|www\S+", "", text)
    text = re.sub(r"<.*?>", "", text)  # HTML-like tags, non-greedy
    text = re.sub(r"[^\w\s]", "", text)  # punctuation; "_" is a word char and is kept
    text = re.sub(r"\d+", "", text)  # digits
    return re.sub(r"\s+", " ", text).strip()

In [None]:
# Load the labelled corpus and add a normalized copy of the "text" column.
df = pd.read_csv("/content/English_profanity_words.csv")
df["clean_text"] = df["text"].apply(clean_text)

# Hold out 20% of the rows for evaluation; the fixed random_state makes the
# split repeatable across runs.
all_texts = df["clean_text"].tolist()
all_labels = df["is_offensive"].tolist()
train_texts, test_texts, train_labels, test_labels = train_test_split(
    all_texts, all_labels, test_size=0.2, random_state=42
)



In [None]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Both splits are tokenized with the same options: truncate to at most 128
# tokens and pad within each call.
tokenize_options = {"truncation": True, "padding": True, "max_length": 128}
train_encodings = tokenizer(train_texts, **tokenize_options)
test_encodings = tokenizer(test_texts, **tokenize_options)

In [None]:
class ProfanityDataset(Dataset):
    """Map-style dataset pairing tokenizer encodings with integer labels.

    Args:
        encodings: Mapping from field name to a per-sample indexable sequence
            (anything exposing ``.items()``).
        labels: Indexable sequence of labels, one per sample.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        """Number of samples, taken from the label sequence."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict of tensors, including a long ``"labels"`` entry."""
        sample = {}
        for field, values in self.encodings.items():
            sample[field] = torch.tensor(values[idx])
        sample["labels"] = torch.tensor(self.labels[idx], dtype=torch.long)
        return sample

In [None]:
# Wrap each split's encodings and labels so they can be indexed per sample.
train_dataset = ProfanityDataset(encodings=train_encodings, labels=train_labels)
test_dataset = ProfanityDataset(encodings=test_encodings, labels=test_labels)

In [None]:
# Start from the stock checkpoint configuration, overriding the number of
# output labels and the hidden-layer dropout probability.
# NOTE(review): labels are read from the "is_offensive" column, which reads
# like a binary flag, yet the head is built with 5 outputs — confirm the label set.
checkpoint = "bert-base-uncased"
config = BertConfig.from_pretrained(
    checkpoint,
    num_labels=5,
    hidden_dropout_prob=0.5,
)

# Load BERT with the correct classifier
model = BertForSequenceClassification.from_pretrained(checkpoint, config=config)

In [None]:
# Take the token from the HF_TOKEN environment variable, or prompt for it without
# echoing, instead of pasting it into the notebook where it would be saved with the file.
# NOTE(review): `login` is not imported in any visible cell — presumably
# `huggingface_hub.login`; confirm and add that import, otherwise this cell
# raises NameError on a fresh kernel.
login(token=os.environ.get("HF_TOKEN") or getpass.getpass("Hugging Face token: "))

In [None]:
# Authenticate with Weights & Biases (`wandb` is imported with the other
# dependencies in the first cell).
wandb.login()  # Logs into your wandb account

In [None]:
# Half-precision training is switched on only when CUDA is available.
use_fp16 = torch.cuda.is_available()

training_args = TrainingArguments(
    # --- checkpoints ---
    output_dir="./results",
    save_strategy="epoch",
    save_total_limit=2,
    load_best_model_at_end=True,
    # --- optimization ---
    num_train_epochs=2,
    learning_rate=2e-5,
    weight_decay=0.1,
    warmup_steps=500,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    fp16=use_fp16,
    # --- evaluation: once per epoch, same cadence as save_strategy ---
    # NOTE(review): newer transformers releases rename this argument to
    # `eval_strategy` — confirm against the installed version.
    evaluation_strategy="epoch",
    # --- logging ---
    logging_dir="./temp_logs",
    logging_steps=10,
    report_to=["wandb"],
)

# Compute Metrics
def compute_metrics(eval_pred):
    """Compute accuracy for one evaluation pass.

    Args:
        eval_pred: A pair that unpacks into ``(logits, labels)``. The
            predicted class is the argmax of ``logits`` along axis 1.

    Returns:
        dict: ``{"accuracy": score}`` as computed by ``accuracy_score``.
    """
    logits, labels = eval_pred
    predicted = torch.tensor(logits).argmax(dim=1)
    return {"accuracy": accuracy_score(labels, predicted)}

# Initialize Trainer with Early Stopping
# NOTE(review): with evaluation once per epoch and only 2 epochs, a patience
# of 2 presumably never triggers — confirm this is intended.
early_stopping = EarlyStoppingCallback(early_stopping_patience=2)

# NOTE(review): the held-out test split is also the eval_dataset that drives
# best-checkpoint selection, so the final score is not from unseen data.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping],
)

# Start wandb logging
wandb.init(project="bert-finetune", name="profanity-filter")

# Estimate Training Time
# Rough heuristic only: a fixed cost per step on GPU and a flat slowdown on CPU.
# Gradient accumulation and multi-device training are not taken into account.
SECONDS_PER_STEP = 0.3
CPU_SLOWDOWN = 4

num_samples = len(train_dataset)
batch_size = training_args.per_device_train_batch_size
# Ceiling division: a final partial batch still costs one step (Trainer keeps
# it unless dataloader_drop_last is set).
num_steps_per_epoch = (num_samples + batch_size - 1) // batch_size
total_steps = num_steps_per_epoch * training_args.num_train_epochs
# Slowdown multiplier relative to GPU: 1 on GPU, CPU_SLOWDOWN on CPU.
gpu_factor = 1 if torch.cuda.is_available() else CPU_SLOWDOWN

# Multiply, not divide: a CPU run must be estimated as slower than a GPU run.
estimated_time = total_steps * SECONDS_PER_STEP * gpu_factor
print(f"🚀 Estimated Training Time: ~{estimated_time:.2f} seconds (~{estimated_time/60:.2f} minutes)")

# Train Model & Track Actual Time
# perf_counter() is a monotonic, high-resolution clock, so the measured span is
# not affected by system clock adjustments during a long run.
start_time = time.perf_counter()
trainer.train()
end_time = time.perf_counter()

# Print Actual Training Time
actual_time = end_time - start_time
print(f"✅ Training Complete! Actual Training Time: {actual_time:.2f} seconds (~{actual_time/60:.2f} minutes)")

# Evaluate Model
# The "accuracy" value returned by compute_metrics is read back under the
# "eval_accuracy" key of the results dict.
results = trainer.evaluate()
print(f"✅ Final Accuracy: {results['eval_accuracy']:.4f}")