In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import math
import psutil
import os

# =========================
# Modelo Base
# =========================
class TinyTransformer(nn.Module):
    def __init__(self, vocab_size=2000, hidden_dim=64):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=4,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=1)
        self.lm_head = nn.Linear(hidden_dim, vocab_size)

    def forward(self, input_ids):
        x = self.embed(input_ids)
        x = self.transformer(x)
        return self.lm_head(x)


# =========================
# Engram Vetorizado
# =========================
class TinyEngram(nn.Module):
    def __init__(self, hidden_dim=64, table_size=5000, ngram_size=2):
        super().__init__()
        self.ngram_size = ngram_size
        self.table_size = table_size

        self.memory = nn.Embedding(table_size, hidden_dim)
        self.key_proj = nn.Linear(hidden_dim, hidden_dim)
        self.value_proj = nn.Linear(hidden_dim, hidden_dim)

        self.register_buffer(
            "hash_weights",
            torch.randint(1, 100, (ngram_size,))
        )

    def forward(self, input_ids, hidden_states):
        B, T = input_ids.shape

        if T >= self.ngram_size:
            ngrams = input_ids.unfold(1, self.ngram_size, 1)
            pad = torch.zeros(
                B,
                self.ngram_size - 1,
                self.ngram_size,
                dtype=input_ids.dtype,
                device=input_ids.device
            )
            ngrams = torch.cat([pad, ngrams], dim=1)
        else:
            ngrams = torch.zeros(
                B,
                T,
                self.ngram_size,
                dtype=input_ids.dtype,
                device=input_ids.device
            )

        # Hash vetorizado simples
        hash_vals = (ngrams * self.hash_weights).sum(-1) % self.table_size

        mem_vec = self.memory(hash_vals)

        k = self.key_proj(mem_vec)
        v = self.value_proj(mem_vec)

        h_norm = F.normalize(hidden_states, dim=-1)
        k_norm = F.normalize(k, dim=-1)

        alpha = torch.sigmoid(
            (h_norm * k_norm).sum(-1, keepdim=True)
            / math.sqrt(hidden_states.size(-1))
        )

        return hidden_states + alpha * v


class TinyTransformerWithEngram(nn.Module):
    def __init__(self, vocab_size=2000, hidden_dim=64):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=4,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=1)

        self.engram = TinyEngram(hidden_dim=hidden_dim)
        self.lm_head = nn.Linear(hidden_dim, vocab_size)

    def forward(self, input_ids):
        x = self.embed(input_ids)
        x = self.transformer(x)
        x = self.engram(input_ids, x)
        return self.lm_head(x)


# =========================
# Benchmark de Tempo
# =========================
def benchmark_time(model, input_ids, runs=10):
    model.eval()
    start = time.time()
    with torch.no_grad():
        for _ in range(runs):
            _ = model(input_ids)
    return (time.time() - start) / runs


# =========================
# Medi√ß√£o de RAM
# =========================
def measure_ram_mb():
    process = psutil.Process(os.getpid())
    return process.memory_info().rss / (1024 ** 2)


# =========================
# Complexidade Te√≥rica
# =========================
def theoretical_transformer(seq_len, d):
    # Aten√ß√£o dominante: O(T¬≤ * d)
    return seq_len * seq_len * d

def theoretical_engram(seq_len, d):
    # Lookup + proje√ß√µes: O(T * d¬≤)
    return seq_len * d * d


# =========================
# Exemplo de Uso
# =========================
if __name__ == "__main__":

    vocab_size = 2000
    hidden_dim = 64
    batch_size = 4

    base_model = TinyTransformer(vocab_size, hidden_dim)
    engram_model = TinyTransformerWithEngram(vocab_size, hidden_dim)

    print("\n=== Teste de RAM ===")
    input_ids = torch.randint(0, vocab_size, (batch_size, 64))

    ram_before = measure_ram_mb()
    _ = base_model(input_ids)
    ram_base = measure_ram_mb()

    _ = engram_model(input_ids)
    ram_engram = measure_ram_mb()

    print("RAM Base (MB):", ram_base - ram_before)
    print("RAM Com Engram (MB):", ram_engram - ram_before)

    print("\n=== Tempo vs seq_len ===")
    seq_lengths = [16, 32, 64, 128]

    for sl in seq_lengths:
        input_ids = torch.randint(0, vocab_size, (batch_size, sl))
        t_base = benchmark_time(base_model, input_ids)
        t_engram = benchmark_time(engram_model, input_ids)

        print(f"seq_len={sl}")
        print(f" Base: {t_base:.6f}s")
        print(f" Engram: {t_engram:.6f}s")

    print("\n=== Complexidade Te√≥rica ===")
    for sl in seq_lengths:
        c_base = theoretical_transformer(sl, hidden_dim)
        c_engram = c_base + theoretical_engram(sl, hidden_dim)

        print(f"seq_len={sl}")
        print(f" Base complexity: {c_base}")
        print(f" Engram complexity: {c_engram}")


=== Teste de RAM ===
RAM Base (MB): 17.47265625
RAM Com Engram (MB): 29.578125

=== Tempo vs seq_len ===
seq_len=16
 Base: 0.000674s
 Engram: 0.000769s
seq_len=32
 Base: 0.000685s
 Engram: 0.000974s
seq_len=64
 Base: 0.001088s
 Engram: 0.001373s
seq_len=128
 Base: 0.001862s
 Engram: 0.002382s

=== Complexidade Te√≥rica ===
seq_len=16
 Base complexity: 16384
 Engram complexity: 81920
seq_len=32
 Base complexity: 65536
 Engram complexity: 196608
seq_len=64
 Base complexity: 262144
 Engram complexity: 524288
seq_len=128
 Base complexity: 1048576
 Engram complexity: 1572864



## üìå 1Ô∏è‚É£ Uso de RAM

* **Base:** 17.47 MB
* **Com Engram:** 29.58 MB

Diferen√ßa: **~12.1 MB a mais**
Aumento relativo: **~69%**

Isso indica que o m√≥dulo Engram adiciona uma estrutura de mem√≥ria significativa ‚Äî provavelmente buffers extras, cache de estados ou embeddings persistentes.

---

## üìå 2Ô∏è‚É£ Tempo de execu√ß√£o vs `seq_len`

Os tempos crescem de forma consistente com o aumento da sequ√™ncia.

### Base

| seq_len | Tempo (s) |
| ------- | --------- |
| 16 | 0.000674 |
| 32 | 0.000685 |
| 64 | 0.001088 |
| 128 | 0.001862 |

### Engram

| seq_len | Tempo (s) |
| ------- | --------- |
| 16 | 0.000769 |
| 32 | 0.000974 |
| 64 | 0.001373 |
| 128 | 0.002382 |

üìå Observa√ß√µes importantes:

* O **Engram √© consistentemente mais lento**, mas n√£o dramaticamente.
* O overhead aumenta conforme `seq_len` cresce.
* A diferen√ßa absoluta √© pequena (fra√ß√µes de milissegundo).
* A diferen√ßa relativa gira entre **15% e 30%** dependendo do tamanho.

Isso sugere que o Engram adiciona custo proporcional ao tamanho da sequ√™ncia, mas n√£o altera drasticamente a ordem de crescimento.

---

## üìå 3Ô∏è‚É£ Complexidade Te√≥rica

Base parece seguir:

[
O(n^2)
]

Exemplo:

* 16 ‚Üí 16384
* 32 ‚Üí 65536 (4√ó)
* 64 ‚Üí 262144 (4√ó)
* 128 ‚Üí 1048576 (4√ó)

Isso √© claramente crescimento quadr√°tico.

---

### Engram

Valores:

* 16 ‚Üí 81920
* 32 ‚Üí 196608
* 64 ‚Üí 524288
* 128 ‚Üí 1572864

O crescimento n√£o √© exatamente 4√ó a cada dobra. Parece algo como:

[
O(n^2 + n . m)
]

Ou seja, h√° um termo adicional linear multiplicado por algum fator fixo (provavelmente tamanho do banco de mem√≥ria).

---

# üéØ Conclus√£o T√©cnica

‚úî O Engram aumenta:

* Uso de RAM (~70%)
* Tempo de execu√ß√£o (~15‚Äì30%)
* Complexidade com termo adicional

‚úî Mas:

* N√£o muda a ordem dominante (continua aproximadamente quadr√°tico)
* O overhead √© relativamente controlado
* Escala de forma est√°vel

---

# üìä Interpreta√ß√£o pr√°tica

Se voc√™ estiver usando isso em:

* üîπ Edge devices ‚Üí pode ser pesado
* üîπ Servidor ‚Üí impacto pequeno
* üîπ Treinamento grande escala ‚Üí pode virar gargalo