In [None]:
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from evaluation.metrics import calculate_metrics

class Evaluator:
    """Run generation over an evaluation dataset and score it with configured metrics.

    Args:
        config: Object exposing at least ``batch_size`` (used for the
            ``DataLoader``) and ``metrics`` (forwarded to ``calculate_metrics``).
    """

    # Label value conventionally used to pad reference tensors so the loss
    # skips those positions (the default ``ignore_index`` of
    # ``torch.nn.CrossEntropyLoss``). Tokenizers cannot decode negative ids,
    # so these must be swapped for a real token before ``batch_decode``.
    _IGNORE_INDEX = -100

    def __init__(self, config):
        self.config = config

    def evaluate(self, model, eval_dataset):
        """Generate predictions for ``eval_dataset`` and compute metrics on them.

        Args:
            model: Object providing ``eval()``, ``generate(**inputs)`` and a
                ``tokenizer`` attribute with ``batch_decode`` (presumably a
                Hugging Face style model/wrapper — TODO confirm).
            eval_dataset: Dataset whose collated batches are mappings that
                contain a ``"labels"`` entry holding reference token ids; all
                other entries are passed to ``model.generate``.

        Returns:
            The result of ``calculate_metrics(predictions, references,
            self.config.metrics)``, where both lists hold decoded strings in
            dataset order.
        """
        eval_dataloader = DataLoader(eval_dataset, batch_size=self.config.batch_size)
        model.eval()
        tokenizer = model.tokenizer
        all_predictions = []
        all_references = []

        # Gradients are never needed during evaluation, for any part of the loop.
        with torch.no_grad():
            for batch in tqdm(eval_dataloader, desc="Evaluating"):
                labels = batch["labels"]
                # Keep the references out of generate(): passing them would let
                # the model see the targets (or trip kwarg validation).
                generation_inputs = {
                    key: value for key, value in batch.items() if key != "labels"
                }
                outputs = model.generate(**generation_inputs)
                all_predictions.extend(
                    tokenizer.batch_decode(outputs, skip_special_tokens=True)
                )
                all_references.extend(
                    tokenizer.batch_decode(
                        self._replace_ignore_index(labels, tokenizer),
                        skip_special_tokens=True,
                    )
                )

        return calculate_metrics(all_predictions, all_references, self.config.metrics)

    @classmethod
    def _replace_ignore_index(cls, labels, tokenizer):
        """Return ``labels`` with ignore-index entries replaced by the pad token id.

        Non-tensor labels, or tokenizers without a ``pad_token_id``, are
        returned unchanged. The input tensor is not modified in place.
        """
        pad_token_id = getattr(tokenizer, "pad_token_id", None)
        if pad_token_id is None or not isinstance(labels, torch.Tensor):
            return labels
        return labels.masked_fill(labels == cls._IGNORE_INDEX, pad_token_id)