# In [None]:
import inspect
import sys

# Error handling for imports
try:
    from datasets import load_dataset
    from transformers import (BertTokenizer, BertForSequenceClassification,
                              Trainer, TrainingArguments)
    import torch
    import numpy as np
    from sklearn.metrics import accuracy_score, f1_score
except ImportError as e:
    print(f"Error importing required libraries: {e}")
    print("Please install missing packages: pip install datasets transformers torch scikit-learn")
    sys.exit(1)

# Select the compute device: CUDA when torch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# 1. Load the IMDb dataset
try:
    dataset = load_dataset("imdb")
except Exception as e:
    print(f"Failed to load IMDb dataset: {e}")
    # Re-raise rather than calling sys.exit(): the original exception and its
    # traceback stay visible, and execution of this cell stops here.
    raise


# 2. Initialize tokenizer
try:
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
except Exception as e:
    print(f"Failed to load tokenizer: {e}")
    # Re-raise instead of sys.exit(1): this keeps the original traceback and
    # avoids raising SystemExit inside a notebook cell.
    raise

# 3. Preprocess: Tokenization function
def tokenize_function(examples):
    """Run the module-level ``tokenizer`` over the ``"text"`` entry of ``examples``.

    The tokenizer is asked to truncate and to pad with the ``"max_length"``
    strategy, using a maximum length of 256.

    Args:
        examples: Mapping that provides a ``"text"`` key. It is passed in by
            ``dataset.map(..., batched=True)``, so presumably a batch of
            strings -- confirm against the dataset schema.

    Returns:
        Whatever the tokenizer call returns for the given text.
    """
    tokenizer_options = {
        "padding": "max_length",
        "truncation": True,
        "max_length": 256,
    }
    return tokenizer(examples["text"], **tokenizer_options)

# Tokenize every split in batches.
try:
    tokenized_datasets = dataset.map(tokenize_function, batched=True)
except Exception as e:
    print(f"Error during tokenization: {e}")
    # Re-raise instead of sys.exit(1): this keeps the original traceback and
    # avoids raising SystemExit inside a notebook cell.
    raise

# 4. Prepare datasets for PyTorch
# Drop the raw "text" column, rename "label" to "labels", then switch the
# output format to "torch".
tokenized_datasets = (
    tokenized_datasets
    .remove_columns(["text"])
    .rename_column("label", "labels")
)
tokenized_datasets.set_format("torch")

test_dataset = tokenized_datasets["test"]
# Subset for speed: shuffle with a fixed seed, then keep the first 5000 rows.
train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(5000))

# 5. Load pre-trained BERT model
try:
    model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
    model.to(device)
except Exception as e:
    print(f"Failed to load BERT model: {e}")
    # Re-raise instead of sys.exit(1): this keeps the original traceback and
    # avoids raising SystemExit inside a notebook cell.
    raise

# 6. Define metrics
def compute_metrics(eval_pred):
    """Compute accuracy and F1 for one evaluation pass.

    Args:
        eval_pred: Object that unpacks into ``(logits, labels)``.

    Returns:
        dict: ``{"accuracy": ..., "f1": ...}``. The predicted class is the
        argmax of the logits over the last axis; both scores come from the
        scikit-learn helpers called with their default arguments.
    """
    scores, gold = eval_pred
    predicted = np.argmax(scores, axis=-1)
    return {
        "accuracy": accuracy_score(gold, predicted),
        "f1": f1_score(gold, predicted),
    }

# 7. Training arguments
# Newer transformers releases renamed `evaluation_strategy` to `eval_strategy`
# and reject the old keyword, while older releases only know the old name.
# Look at the installed TrainingArguments signature and pass whichever keyword
# it accepts, so the script runs on both.
_eval_strategy_kwarg = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(TrainingArguments.__init__).parameters
    else "evaluation_strategy"
)

training_args = TrainingArguments(
    output_dir="./results",
    # Save on the same schedule as evaluation; load_best_model_at_end needs a
    # checkpoint for every evaluated point.
    save_strategy="epoch",
    num_train_epochs=2,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    logging_dir="./logs",
    logging_steps=100,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",  # key produced by compute_metrics
    save_total_limit=2,
    report_to="none",
    **{_eval_strategy_kwarg: "epoch"},
)

# 8. Trainer
# Bundle the training configuration, the model, the metric callback and the
# two data splits into a single Trainer instance.
trainer = Trainer(
    args=training_args,
    model=model,
    compute_metrics=compute_metrics,
    eval_dataset=test_dataset,
    train_dataset=train_dataset,
)

# 9. Train the model
try:
    trainer.train()
except Exception as e:
    print(f"Error during training: {e}")
    # Re-raise instead of sys.exit(1): this keeps the original traceback and
    # avoids raising SystemExit inside a notebook cell.
    raise

# 10. Evaluate the model
# Best-effort step: a failure here is reported but not re-raised, so the
# statements after this block still run.
try:
    eval_results = trainer.evaluate()
    print("Evaluation results:", eval_results)
except Exception as e:
    # Only the message is printed; `eval_results` stays unset if evaluate() failed.
    print(f"Error during evaluation: {e}")

# 11. Save the fine-tuned model
# Model and tokenizer are written to the same directory.
try:
    model.save_pretrained("./sentiment_model")
    tokenizer.save_pretrained("./sentiment_model")
    print("Model saved to ./sentiment_model")
except Exception as e:
    print(f"Error saving model: {e}")
    # Re-raise: the following step reloads from ./sentiment_model, and
    # continuing after a failed save could silently pick up files left there
    # by an earlier run.
    raise

# 12. Load and test inference
# Reload model and tokenizer from disk to check that the saved files are usable.
try:
    loaded_model = BertForSequenceClassification.from_pretrained("./sentiment_model")
    loaded_tokenizer = BertTokenizer.from_pretrained("./sentiment_model")
    loaded_model.to(device)
except Exception as e:
    print(f"Error loading saved model: {e}")
    # Re-raise instead of sys.exit(1): this keeps the original traceback and
    # avoids raising SystemExit inside a notebook cell.
    raise

def predict_sentiment(text):
    """Classify ``text`` with the reloaded model and tokenizer.

    Uses the module-level ``loaded_tokenizer``, ``loaded_model`` and
    ``device``.

    Args:
        text: Input handed to the tokenizer.

    Returns:
        tuple: ``(label, probs)``. ``label`` is ``"positive"`` when the argmax
        class index is 1 and ``"negative"`` otherwise (assumes index 1 is the
        positive class -- confirm against the training labels). ``probs`` is
        the softmax output as a NumPy array. If anything raises, the error is
        printed and ``(None, None)`` is returned instead.
    """
    try:
        encoded = loaded_tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=256,
        )
        # Move every encoded tensor to the device the model lives on.
        batch = {name: tensor.to(device) for name, tensor in encoded.items()}
        with torch.no_grad():
            logits = loaded_model(**batch).logits
            probs = torch.softmax(logits, dim=1)
            pred = torch.argmax(probs, dim=1).item()
        if pred == 1:
            label = "positive"
        else:
            label = "negative"
        return label, probs.cpu().numpy()
    except Exception as e:
        print(f"Error during inference: {e}")
        return None, None

# Example inference: classify one hard-coded review and print the outcome.
sample_text = "This movie was absolutely wonderful!"
label, probs = predict_sentiment(sample_text)
print(
    f"Sample text: {sample_text}\n"
    f"Predicted sentiment: {label}\n"
    f"Probabilities: {probs}"
)