In [None]:
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
import torch
from typing import List, Dict

class SummarizationEnsemble:
    """Weighted ensemble of pretrained seq2seq summarization models.

    Each member is a dict with keys 'name', 'model', 'tokenizer', 'lang'
    ('en', 'ru' or 'multi') and 'weight'. ``summarize`` runs every member whose
    'lang' equals the detected language of the input, plus all 'multi'
    members. It currently returns the summary of the highest-weighted member
    that succeeded.
    """

    def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
        """Load all ensemble members and move their models to *device*.

        Args:
            device: torch device string. The default is evaluated once, when the
                class is defined: 'cuda' if available, else 'cpu'.
        """
        self.device = device
        self.models = self._load_models()

    def _load_models(self) -> List[Dict]:
        """Load every member's model and tokenizer; models go to ``self.device``.

        Returns:
            A list of member dicts with keys 'name', 'model', 'tokenizer',
            'lang' and 'weight'.
        """
        # (checkpoint, language it serves, weight used to rank members in summarize)
        specs = [
            ('facebook/bart-large-cnn', 'en', 0.4),
            ('IlyaGusev/mbart_ru_sum_gazeta', 'ru', 0.4),
            # NOTE(review): google/mt5-base appears to be a pretrained-only
            # checkpoint rather than a summarization fine-tune. Confirm its
            # output is a usable summary before relying on it.
            ('google/mt5-base', 'multi', 0.2),
        ]
        return [
            {
                'name': name,
                'model': AutoModelForSeq2SeqLM.from_pretrained(name).to(self.device),
                'tokenizer': AutoTokenizer.from_pretrained(name),
                'lang': lang,
                'weight': weight,
            }
            for name, lang, weight in specs
        ]

    def detect_language(self, text: str) -> str:
        """Guess whether *text* is Russian or English by counting letters.

        Cyrillic letters are counted against ASCII Latin letters. The Cyrillic
        set is 'а'-'я', 'А'-'Я', plus 'ё'/'Ё', which lie outside those
        code-point ranges. Returns 'ru' only when Cyrillic letters strictly
        outnumber Latin ones; otherwise returns 'en', including for text with
        no letters. This is a simple heuristic and could be swapped for a real
        language detector.
        """
        ru_chars = sum(
            1 for c in text if 'а' <= c <= 'я' or 'А' <= c <= 'Я' or c in 'ёЁ'
        )
        en_chars = sum(1 for c in text if 'a' <= c <= 'z' or 'A' <= c <= 'Z')
        return 'ru' if ru_chars > en_chars else 'en'

    def summarize_single(self, text: str, model_info: Dict, max_length=130, min_length=30,
                         max_input_length: int = 1024) -> str:
        """Summarize *text* with a single ensemble member.

        Args:
            text: input document.
            model_info: member dict; its 'tokenizer' and 'model' entries are used.
            max_length: maximum length of the generated summary, in tokens.
            min_length: minimum length of the generated summary, in tokens.
            max_input_length: the input is truncated to this many tokens.

        Returns:
            The decoded summary, with special tokens removed.
        """
        tokenizer = model_info['tokenizer']
        model = model_info['model']

        inputs = tokenizer(
            text,
            max_length=max_input_length,
            truncation=True,
            return_tensors="pt"
        ).to(self.device)

        summary_ids = model.generate(
            inputs["input_ids"],
            # Pass the mask explicitly so generate() never has to infer padding.
            attention_mask=inputs.get("attention_mask"),
            max_length=max_length,
            min_length=min_length,
            length_penalty=2.0,
            num_beams=4,
            early_stopping=True
        )

        return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    def summarize(self, text: str, max_length=130, min_length=30) -> str:
        """Ensemble summarization of *text*.

        Runs every member that serves the detected language, plus the 'multi'
        members. A member that raises is reported with print() and skipped.

        Returns:
            The summary from the successful member with the largest weight.
            Ties go to the member listed first in ``self.models``.

        Raises:
            ValueError: if no member matched the language, or every matching
                member failed.
        """
        lang = self.detect_language(text)
        summaries = []
        weights = []

        for model_info in self.models:
            # Only members for this language, or multilingual ones.
            if model_info['lang'] not in (lang, 'multi'):
                continue
            try:
                summary = self.summarize_single(text, model_info, max_length, min_length)
            except Exception as e:
                # Deliberate best-effort: one failing member must not sink the ensemble.
                print(f"Error with {model_info['name']}: {str(e)}")
                continue
            summaries.append(summary)
            weights.append(model_info['weight'])

        if not summaries:
            raise ValueError("No suitable models found for the detected language")

        # Placeholder combination strategy: pick the highest-weighted summary.
        # More elaborate merging logic could go here.
        return summaries[weights.index(max(weights))]

# Дообучение модели

In [None]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer

def fine_tune_ensemble_member(model_name, train_dataset, eval_dataset, output_dir):
    """Fine-tune one seq2seq ensemble member and save it to ``output_dir/final``.

    Args:
        model_name: Hugging Face model id or local path to load.
        train_dataset: tokenized training dataset. Presumably each example
            carries ``input_ids``, ``attention_mask`` and ``labels``; TODO
            confirm against the preprocessing code.
        eval_dataset: tokenized evaluation dataset, or ``None`` to train
            without periodic evaluation.
        output_dir: directory for checkpoints (every 500 steps, at most 3 kept),
            logs (``output_dir/logs``) and the final model (``output_dir/final``).

    Returns:
        The fine-tuned ``(model, tokenizer)`` pair.
    """
    import inspect

    from transformers import DataCollatorForSeq2Seq

    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # `evaluation_strategy` was renamed to `eval_strategy`, and newer releases
    # no longer accept the old name. Use whichever this transformers version
    # supports.
    args_params = inspect.signature(Seq2SeqTrainingArguments).parameters
    strategy_key = 'eval_strategy' if 'eval_strategy' in args_params else 'evaluation_strategy'
    # Step-based evaluation requires an eval set, so disable it when none is given.
    strategy_kwargs = {strategy_key: "steps" if eval_dataset is not None else "no"}

    training_args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        predict_with_generate=True,
        eval_steps=500,
        save_steps=500,
        warmup_steps=500,
        max_steps=4000,
        logging_dir=f"{output_dir}/logs",
        logging_steps=100,
        learning_rate=3e-5,
        weight_decay=0.01,
        save_total_limit=3,
        # NOTE(review): T5/mT5 checkpoints are reported to diverge (NaN loss)
        # under fp16. Consider bf16 or fp32 when fine-tuning google/mt5-*.
        fp16=torch.cuda.is_available(),
        **strategy_kwargs,
    )

    # Newer Trainer versions take the tokenizer as `processing_class`
    # instead of `tokenizer`.
    trainer_params = inspect.signature(Seq2SeqTrainer.__init__).parameters
    tokenizer_key = 'processing_class' if 'processing_class' in trainer_params else 'tokenizer'

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        # Pads inputs and labels per batch, with labels padded using -100.
        # The default DataCollatorWithPadding leaves variable-length `labels`
        # unpadded.
        data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
        **{tokenizer_key: tokenizer},
    )

    trainer.train()
    trainer.save_model(f"{output_dir}/final")

    return model, tokenizer