In [None]:
from transformers import TrainingArguments, Trainer, TrainerCallback
import matplotlib.pyplot as plt
from IPython.display import display, update_display, HTML

class ProgressVisualizationCallback(TrainerCallback):
    """Live-plot training loss and learning rate in a notebook during training.

    Every non-empty log dict received through ``on_log`` is copied into
    ``training_logs``. After every 10th stored entry a two-panel figure
    (loss on the left, learning rate on the right) is redrawn in place through
    the IPython display identified by ``output_id``.
    """

    def __init__(self):
        # Copies of the received log dicts, tagged with 'step' when known.
        self.training_logs = []
        # display_id shared by display() and update_display().
        self.output_id = 'progress_viz'
        # Created lazily on the first redraw.
        self.fig = None

    def on_train_begin(self, args, state, control, **kwargs):
        """Forget logs and the figure left over from any previous run."""
        self.training_logs = []
        # Dropping the old figure makes the next redraw call display() again,
        # so the plot appears under the cell that is actually training.
        self.fig = None
        # Don't print anything here to avoid interfering with the default progress display

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Store ``logs`` and redraw the plot on every 10th stored entry.

        Args:
            args: Unused.
            state: Trainer state; its ``global_step`` (when present) is
                recorded as the entry's 'step' unless ``logs`` already has one.
            control: Unused.
            logs: Mapping of metric name to value. Ignored when empty or None.
        """
        if not logs:
            return

        # Copy so that later changes to the caller's dict cannot alter history.
        entry = dict(logs)
        # The incoming dict is not guaranteed to carry a 'step' key; without
        # it the x-axis would fall back to the list index, which also counts
        # entries that have no training loss.
        current_step = getattr(state, 'global_step', None)
        if current_step is not None:
            entry.setdefault('step', current_step)
        self.training_logs.append(entry)

        # Plot every 10 logs to avoid slowing down training
        if len(self.training_logs) % 10 == 0:
            self.visualize_progress(state)

    def visualize_progress(self, state):
        """Redraw the loss and learning-rate curves from ``training_logs``.

        Args:
            state: Unused; kept so existing callers keep working.
        """
        # Each series gets its own x values: an entry may contain 'loss'
        # without 'learning_rate' (or the reverse), and plotting one series
        # against the other's steps would fail on mismatched lengths.
        loss_steps, loss = [], []
        lr_steps, lr = [], []
        for index, log in enumerate(self.training_logs):
            step = log.get('step', index)
            if 'loss' in log:
                loss_steps.append(step)
                loss.append(log['loss'])
            if 'learning_rate' in log:
                lr_steps.append(step)
                lr.append(log['learning_rate'])

        if self.fig is None:
            self.fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
            # Unregister the figure from pyplot so it is only shown through
            # our display handle; a figure left open is typically rendered a
            # second time by the inline backend when the cell finishes.
            plt.close(self.fig)
            display(self.fig, display_id=self.output_id)
        else:
            # Clear previous plot data
            for ax in self.fig.axes:
                ax.clear()
            ax1, ax2 = self.fig.axes

        # Loss plot
        ax1.plot(loss_steps, loss, label='Training Loss')
        ax1.set_xlabel('Step')
        ax1.set_ylabel('Loss')
        ax1.set_title('Training Loss')
        ax1.legend()
        ax1.grid(True)

        # Learning rate plot
        ax2.plot(lr_steps, lr, label='Learning Rate', color='green')
        ax2.set_xlabel('Step')
        ax2.set_ylabel('Learning Rate')
        ax2.set_title('Learning Rate Schedule')
        ax2.legend()
        ax2.grid(True)

        self.fig.tight_layout()
        update_display(self.fig, display_id=self.output_id)

        # Don't print status here to avoid interfering with the default progress display

# Setup training with proper logging
training_args = TrainingArguments(
    output_dir="test_trainer",
    logging_dir="./logs",
    logging_strategy="steps",
    logging_steps=10,  # Log every 10 steps
    # NOTE(review): newer transformers releases appear to rename this argument
    # to `eval_strategy` — confirm against the installed version.
    evaluation_strategy="steps",
    eval_steps=100,    # Evaluate every 100 steps
    save_strategy="steps",
    save_steps=500,    # Save model every 500 steps
    save_only_model=True,  # Only save model weights, not optimizer state
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=4,
    max_steps=1000,    # Set a specific number of steps
    report_to="none",  # Disable wandb/tensorboard to avoid conflicts
    learning_rate=2e-5,
    weight_decay=0.01,
)

# Create the trainer with our callback
progress_callback = ProgressVisualizationCallback()
trainer = Trainer(
    model=base_model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    compute_metrics=compute_metrics,
    callbacks=[progress_callback],  # Add the callback here
)

# Start training
print("Starting training...")
resource_monitor = ResourceMonitor(interval=60, log_path="training_log.csv", verbose=False)
resource_monitor.start(append_log=True)
try:
    trainer.train()
finally:
    # Stop the monitor even when training fails or is interrupted, so it does
    # not keep running after this cell has finished.
    resource_monitor.stop()
print("Training completed.")