In [None]:
import os

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import get_linear_schedule_with_warmup

from training.lora_adapter import LoraAdapter

class Trainer:
    """Train a model with AdamW after passing it through a ``LoraAdapter``.

    The trainer owns the optimizer, applies ``LoraAdapter.adapt`` to the model
    at the start of ``train``, steps a linear learning-rate schedule once per
    batch, and writes one checkpoint file per epoch.
    """

    def __init__(self, config, model):
        """Store the config and model and build the optimizer and adapter.

        Args:
            config: Settings object read through attribute access. This class
                reads ``learning_rate``, ``lora_config``, ``batch_size``,
                ``num_epochs`` and ``checkpoint_dir`` from it.
            model: Module whose ``parameters()`` seed the AdamW optimizer.
        """
        self.config = config
        self.model = model
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=config.learning_rate)
        self.lora_adapter = LoraAdapter(config.lora_config)

    def train(self, train_dataset):
        """Adapt the model, then run ``config.num_epochs`` epochs over the data.

        Every batch is unpacked as keyword arguments into the model and the
        ``loss`` attribute of the result is back-propagated. After each epoch
        the average loss is printed and a checkpoint is saved.

        Args:
            train_dataset: Dataset handed to ``torch.utils.data.DataLoader``
                with ``batch_size=config.batch_size`` and shuffling enabled.

        Raises:
            ValueError: If the data loader yields no batches, since no
                average loss could be computed for an epoch.
        """
        train_dataloader = DataLoader(train_dataset, batch_size=self.config.batch_size, shuffle=True)
        num_batches = len(train_dataloader)
        if num_batches == 0:
            # Fail early with a clear message instead of a ZeroDivisionError
            # when the per-epoch average is computed.
            raise ValueError("train_dataset produced no batches; nothing to train on")

        self.model = self.lora_adapter.adapt(self.model)
        # The optimizer was built in __init__ from the pre-adaptation model,
        # so parameters created by the adapter must be registered explicitly.
        # NOTE(review): assumes adapt() may add new parameters -- confirm
        # against LoraAdapter; if it adds none, this call is a no-op.
        self._sync_optimizer_params()

        # Build the scheduler only after the optimizer's param groups are
        # final, so the schedule covers every group.
        num_training_steps = num_batches * self.config.num_epochs
        scheduler = get_linear_schedule_with_warmup(
            self.optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
        )

        for epoch in range(self.config.num_epochs):
            self.model.train()
            total_loss = 0
            for batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{self.config.num_epochs}"):
                self.optimizer.zero_grad()
                outputs = self.model(**batch)
                loss = outputs.loss
                loss.backward()
                # Clip the total gradient norm to 1.0 before the update.
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                self.optimizer.step()
                # The schedule advances once per batch, not once per epoch.
                scheduler.step()
                total_loss += loss.item()

            avg_loss = total_loss / num_batches
            print(f"Epoch {epoch+1}/{self.config.num_epochs}, Average Loss: {avg_loss:.4f}")

            self._save_checkpoint(epoch)

    def _sync_optimizer_params(self):
        """Register model parameters that the optimizer does not track yet.

        Parameters already present in any param group are left untouched, so
        existing optimizer state is preserved and repeated calls are no-ops.
        """
        tracked = {id(p) for group in self.optimizer.param_groups for p in group["params"]}
        missing = [p for p in self.model.parameters() if id(p) not in tracked]
        if missing:
            # New groups inherit the optimizer's default hyper-parameters.
            self.optimizer.add_param_group({"params": missing})

    def _save_checkpoint(self, epoch):
        """Write model and optimizer state for one epoch to the checkpoint dir.

        The file is named ``model_epoch_{epoch+1}.pt`` (1-based), while the
        ``'epoch'`` entry stored inside it is the 0-based index passed in.

        Args:
            epoch: Zero-based index of the epoch that just finished.
        """
        checkpoint_path = f"{self.config.checkpoint_dir}/model_epoch_{epoch+1}.pt"
        # torch.save does not create missing directories.
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, checkpoint_path)