# Early Exit

## Prerequisites

In [None]:
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer

In [None]:
import time

class EarlyExitLM(nn.Module):
    """Entropy-based early-exit wrapper around a GPT-2-style causal LM.

    The wrapped model must expose the Hugging Face GPT-2 attribute layout:
    ``transformer.h`` (sequence of blocks), ``transformer.wte`` /
    ``transformer.wpe`` (token / position embeddings), ``transformer.ln_f``
    (final LayerNorm) and ``lm_head``. ``transformer.drop`` is used when present.

    After each block (starting at ``min_layers``) the next-token distribution
    of the last real token of every row is computed; when its entropy, averaged
    over the batch, drops below ``threshold`` the remaining blocks are skipped.

    Args:
        base_model: GPT-2-style causal LM to wrap (its modules are shared, not copied).
        threshold: Exit as soon as the batch-mean entropy (in nats) is below this.
        min_layers: Minimum number of blocks to run before an exit is allowed.
    """

    def __init__(self, base_model: nn.Module, threshold: float = 1.0, min_layers: int = 1):
        super().__init__()
        self.base_model = base_model
        transformer = base_model.transformer
        # Shortcuts to the pieces of the stack that forward() runs by hand.
        self.blocks = transformer.h
        self.wte = transformer.wte
        self.wpe = transformer.wpe
        # GPT-2 applies a final LayerNorm before the LM head; without it both the
        # exit probe and the returned logits differ from the base model's.
        self.ln_f = transformer.ln_f
        drop = getattr(transformer, "drop", None)
        self.drop = drop if drop is not None else nn.Identity()
        self.lm_head = base_model.lm_head
        self.threshold = threshold
        self.min_layers = min_layers

    @staticmethod
    def _additive_attention_mask(attention_mask, seq_len, dtype, device):
        """Build a ``(batch or 1, 1, seq, seq)`` additive mask (0 = keep, dtype-min = drop).

        Combines the causal constraint with the optional 0/1 padding mask so it
        is correct whether or not the attention implementation applies its own
        causal mask. Every query may always attend to itself, so no row is fully
        masked (a fully masked row would turn the softmax into NaN).
        """
        allowed = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device).tril()
        allowed = allowed.unsqueeze(0)  # (1, seq, seq)
        if attention_mask is not None:
            keys = attention_mask.to(device=device, dtype=torch.bool)[:, None, :]  # (batch, 1, seq)
            allowed = allowed & keys
        allowed = allowed | torch.eye(seq_len, dtype=torch.bool, device=device)
        mask = torch.zeros(allowed.shape, dtype=dtype, device=device)
        mask.masked_fill_(~allowed, torch.finfo(dtype).min)
        return mask.unsqueeze(1)

    def _mean_entropy(self, last_hidden):
        """Batch-mean entropy (nats) of the next-token distribution for ``last_hidden`` (batch, dim)."""
        # Work in float32: in fp16 a softmax over the vocabulary underflows to
        # exact zeros, and the old log(p + 1e-12) then produced 0 * -inf = NaN.
        probs = torch.softmax(self.lm_head(self.ln_f(last_hidden)).float(), dim=-1)
        # entr(p) = -p * ln(p), defined as 0 at p == 0.
        return torch.special.entr(probs).sum(dim=-1).mean().item()

    def forward(self, input_ids, attention_mask=None):
        """Run the transformer blocks until the entropy exit criterion fires.

        Args:
            input_ids: Token ids of shape (batch, seq).
            attention_mask: Optional (batch, seq) tensor, 1 for real tokens and
                0 for padding (the tokenizer convention).

        Returns:
            ``(logits, exit_layer)``: logits of shape (batch, seq, vocab) computed
            from the hidden state at the exit layer, and the number of blocks run.
        """
        device = input_ids.device
        batch_size, seq_len = input_ids.shape[0], input_ids.shape[-1]
        positions = torch.arange(seq_len, device=device)
        hidden = self.drop(self.wte(input_ids) + self.wpe(positions))

        block_mask = self._additive_attention_mask(attention_mask, seq_len, hidden.dtype, device)

        # Index of the last real token per row; works for left and right padding
        # because (mask * position).argmax picks the highest unmasked position.
        if attention_mask is None:
            last_idx = torch.full((batch_size,), seq_len - 1, dtype=torch.long, device=device)
        else:
            last_idx = (attention_mask.to(device).long() * positions).argmax(dim=-1)
        rows = torch.arange(batch_size, device=device)

        num_layers = len(self.blocks)
        exit_layer = num_layers
        for layer, block in enumerate(self.blocks, start=1):
            out = block(hidden, attention_mask=block_mask)
            # HF blocks return a tuple (hidden, ...); tolerate a bare tensor too.
            hidden = out[0] if isinstance(out, tuple) else out

            # No probe before min_layers, and none on the last layer: its logits
            # are computed below regardless of the outcome.
            if layer < self.min_layers or layer == num_layers:
                continue
            if self._mean_entropy(hidden[rows, last_idx]) < self.threshold:
                exit_layer = layer
                break

        # Full-sequence projection happens once, from the exit layer's state.
        logits = self.lm_head(self.ln_f(hidden))
        return logits, exit_layer

def optimize_model_with_early_exit(
    base_model: nn.Module,
    threshold: float = 1.0,
    min_layers: int = 1
) -> nn.Module:
    """Wrap a GPT-2-style causal LM in ``EarlyExitLM`` with the given settings.

    Despite the generic name, this does not accept an arbitrary causal LM:
    ``EarlyExitLM`` reads the Hugging Face GPT-2 attribute layout
    (``transformer.h``, ``transformer.wte``, ``transformer.wpe``, ``lm_head``, ...).
    Models with a different layout fail with ``AttributeError`` during construction.

    Args:
        base_model: GPT-2-style causal LM to wrap; its modules are shared, not copied.
        threshold: Entropy threshold below which inference exits early.
        min_layers: Minimum number of transformer blocks to run before exiting.

    Returns:
        The wrapping module. Its forward returns ``(logits, exit_layer)`` rather
        than a Hugging Face output object.
    """
    return EarlyExitLM(base_model, threshold=threshold, min_layers=min_layers)


if __name__ == "__main__":
    # 1) Load the model. Use fp16 only on CUDA: on CPU many half-precision
    #    kernels are slow or unavailable, so fall back to fp32 there.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float16 if device.type == "cuda" else torch.float32
    # No device_map="auto": the model is placed explicitly with .to(device), and
    # moving an accelerate-dispatched model afterwards can conflict with its hooks.
    model = AutoModelForCausalLM.from_pretrained(
        "openai-community/gpt2",
        torch_dtype=dtype,
        attn_implementation="sdpa",
    )
    tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
    model = model.to(device).eval()

    # 2) Wrap the model with early exit
    early_model = optimize_model_with_early_exit(model, threshold=2.5, min_layers=2).eval()

    # 3) Benchmark on a simple prompt
    text = "The quick brown fox jumps over"
    inputs = tokenizer(text, return_tensors="pt").to(device)

    def timed(fn, warmup=1):
        """Return ``(fn(), seconds)`` for one call made after ``warmup`` untimed calls.

        CUDA kernels launch asynchronously, so the device is synchronized around
        the timed call; otherwise only launch overhead would be measured.
        """
        with torch.no_grad():
            for _ in range(warmup):
                fn()
            if device.type == "cuda":
                torch.cuda.synchronize()
            start = time.perf_counter()
            result = fn()
            if device.type == "cuda":
                torch.cuda.synchronize()
            return result, time.perf_counter() - start

    # baseline
    _, baseline_s = timed(lambda: model(**inputs).logits)
    print("Baseline:", baseline_s)

    # early-exit
    (logits, exit_layer), early_s = timed(
        lambda: early_model(inputs.input_ids, attention_mask=inputs.attention_mask)
    )
    print("Early-exit:", early_s, "Exited at layer", exit_layer)