In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

In [None]:
# Define constants
MODEL_DIR = "E:\\YuyangGPT\\models\\minilm-custom-eos"
CHUNK_SIZE = 128
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load tokenizer + embedding model
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = AutoModel.from_pretrained(MODEL_DIR).to(DEVICE)
model.eval()

In [None]:
for p in model.parameters():
    p.requires_grad = False

In [None]:
class ChunkTransformer(nn.Module):
    def __init__(self, hidden_size, num_layers=2, num_heads=4):
        super().__init__()

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=num_heads,
            batch_first=True
        )

        self.encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers
        )

        # Learnable summary token
        self.summary_token = nn.Parameter(
            torch.randn(1, 1, hidden_size)
        )

    def forward(self, hidden_states, attention_mask):
        """
        hidden_states: [B, T, d]
        attention_mask: [B, T]
        """

        B, T, d = hidden_states.shape

        # Expand summary token for batch
        summary = self.summary_token.expand(B, 1, d)

        # Prepend summary token
        x = torch.cat([summary, hidden_states], dim=1)  # [B, T+1, d]

        # Build attention mask (1 = keep, 0 = mask)
        summary_mask = torch.ones(B, 1, device=attention_mask.device)
        attn_mask = torch.cat([summary_mask, attention_mask], dim=1)

        # TransformerEncoder uses True = masked
        key_padding_mask = attn_mask == 0

        out = self.encoder(
            x,
            src_key_padding_mask=key_padding_mask
        )

        # Return summary token output
        return out[:, 0]  # [B, d]


In [None]:
class LongTermMemory(nn.Module):
    def __init__(
        self,
        hidden_size,
        num_latents=8,
        num_layers=4,
        num_heads=4
    ):
        super().__init__()

        # Learned latent tokens
        self.latents = nn.Parameter(
            torch.randn(1, num_latents, hidden_size)
        )

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=num_heads,
            batch_first=True
        )

        self.encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers
        )

    def forward(self, summaries, attention_mask=None):
        """
        summaries: [B, T, d]
        returns:   [B, K, d]
        """

        B = summaries.size(0)

        latents = self.latents.expand(B, -1, -1)

        # Concatenate: latents attend to summaries
        x = torch.cat([latents, summaries], dim=1)

        if attention_mask is not None:
            latent_mask = torch.ones(
                B,
                latents.size(1),
                device=attention_mask.device
            )
            attn_mask = torch.cat([latent_mask, attention_mask], dim=1)
        else:
            attn_mask = None

        out = self.encoder(x, src_key_padding_mask=(attn_mask == 0 if attn_mask is not None else None))

        # Return only latent outputs
        return out[:, :latents.size(1)]


In [None]:
class LTMPredictorHead(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()

        self.net = nn.Sequential(
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size)
        )

    def forward(self, latents):
        """
        latents: [B, K, d]
        returns: [B, K, d]
        """
        return self.net(latents)


In [None]:
stm = ChunkTransformer(
    hidden_size=model.config.hidden_size,
    num_layers=2,
    num_heads=4
).to(DEVICE)


In [None]:
def chunk_tokens(input_ids, attention_mask, chunk_size=128):
    chunks = []

    seq_len = input_ids.size(1)

    for start in range(0, seq_len, chunk_size):
        end = start + chunk_size

        chunk_ids = input_ids[:, start:end]
        chunk_mask = attention_mask[:, start:end]

        if chunk_ids.size(1) == 0:
            continue

        chunks.append({
            "input_ids": chunk_ids,
            "attention_mask": chunk_mask
        })

    return chunks

def encode_chunks(chunks, stm_model):
    summaries = []

    with torch.no_grad():
        for chunk in chunks:
            outputs = model(
                input_ids=chunk["input_ids"],
                attention_mask=chunk["attention_mask"]
            )

            # Token-level hidden states
            h = outputs.last_hidden_state  # [1, T, d]

            summary = stm_model(
                hidden_states=h,
                attention_mask=chunk["attention_mask"]
            )

            summaries.append(summary.squeeze(0))  # [d]

    return torch.stack(summaries)  # [num_chunks, d]


In [None]:
class STMTrainer(nn.Module):
    def __init__(self, stm, hidden_size, vocab_size):
        super().__init__()
        self.stm = stm
        self.lm_head = nn.Linear(hidden_size, vocab_size)

    def forward(self, hidden_states, attention_mask):
        """
        hidden_states: [B, T, d]  (from MiniLM)
        attention_mask: [B, T]
        """
        summary = self.stm(hidden_states, attention_mask)  # [B, d]
        logits = self.lm_head(summary)                      # [B, vocab]
        return logits


In [None]:
stm_trainer = STMTrainer(
    stm=stm,
    hidden_size=model.config.hidden_size,
    vocab_size=len(tokenizer)
).to(DEVICE)


In [None]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(stm_trainer.parameters(), lr=3e-4)


In [None]:
def train_on_chunks(chunks, model, stm_trainer):
    total_loss = 0.0

    for i in range(len(chunks) - 1):
        # Encode current chunk
        outputs = model(
            input_ids=chunks[i]["input_ids"],
            attention_mask=chunks[i]["attention_mask"]
        )

        h = outputs.last_hidden_state  # [1, T, d]

        # Predict next chunk's first token
        logits = stm_trainer(
            hidden_states=h,
            attention_mask=chunks[i]["attention_mask"]
        )  # [1, vocab]

        target = chunks[i + 1]["input_ids"][:, 0]  # [1]

        loss = loss_fn(logits, target)
        loss.backward()
        for name, p in stm_trainer.named_parameters():
            if p.grad is not None:
                print(name, p.grad.norm())
                break
        assert loss.requires_grad
        assert stm_trainer.lm_head.weight.grad is not None

        optimizer.step()
        optimizer.zero_grad()

        total_loss += loss.item()

    return total_loss / (len(chunks) - 1)


In [None]:
epochs = 30
stm_dir = "E:\YuyangGPT\models\stms\stm_3"
JSONL_PATH = "E:\\YuyangGPT\\dataset\\cleaned_data\\train_tokenized_discord_messages.jsonl"



In [None]:
import json

def load_jsonl(path):
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            yield json.loads(line)

In [None]:

for epoch in range(epochs):
    total_loss = 0.0
    n_docs = 0

    for sample in load_jsonl(JSONL_PATH):
        # Convert lists → tensors
        input_ids = torch.tensor(
            sample["input_ids"],
            dtype=torch.long,
            device=DEVICE
        ).unsqueeze(0)  # [1, T]

        attention_mask = torch.tensor(
            sample["attention_mask"],
            dtype=torch.long,
            device=DEVICE
        ).unsqueeze(0)  # [1, T]

        # Skip short sequences
        if input_ids.size(1) <= CHUNK_SIZE:
            continue

        chunks = chunk_tokens(input_ids, attention_mask, CHUNK_SIZE)

        if len(chunks) < 2:
            continue

        loss = train_on_chunks(chunks, model, stm_trainer)

        total_loss += loss
        n_docs += 1

    avg_loss = total_loss / max(n_docs, 1)
    print(f"Epoch {epoch} | STM loss: {avg_loss:.4f}")

In [None]:
torch.save(
    {
        "stm_state_dict": stm.state_dict(),
        "stm_config": {
            "hidden_size": model.config.hidden_size,
            "num_layers": 2,
            "num_heads": 4
        }
    },
    f"{stm_dir}/stm_checkpoint.pt"
)
torch.save(
    stm_trainer.state_dict(),
    f"{stm_dir}/stm_trainer_checkpoint.pt"
)


In [None]:
def make_ltm_targets(summaries, K):
    """
    summaries: [B, T, d]
    """
    context = summaries[:, :-K]        # [B, T-K, d]
    targets = summaries[:, -K:]         # [B, K, d]
    return context, targets

def train_ltm_step(
    summaries,
    ltm,
    predictor,
    optimizer,
    K
):
    """
    summaries: [B, T, d]
    """

    context, targets = make_ltm_targets(summaries, K)

    latents = ltm(context)              # [B, K, d]
    preds = predictor(latents)          # [B, K, d]

    loss = F.mse_loss(preds, targets)
    print("context", context.shape, context.dtype, context.device)
    print("targets", targets.shape, targets.dtype, targets.device)
    print("latents requires_grad:", latents.requires_grad)
    print("preds   requires_grad:", preds.requires_grad)
    print("loss    requires_grad:", loss.requires_grad, "loss:", loss.item())

    # sanity: are your targets all zeros?
    print("targets abs mean:", targets.abs().mean().item(), "std:", targets.std().item())
    print("preds   abs mean:", preds.abs().mean().item(), "std:", preds.std().item())
    print("mse raw:", ((preds - targets) ** 2).mean().item())   

    optimizer.zero_grad()
    loss.backward()
    for name, p in predictor.named_parameters():
        if p.grad is not None:
            print(name, p.grad.norm())
            break
    optimizer.step()

    return loss.item()

stm.load_state_dict(torch.load(f"{stm_dir}/stm_checkpoint.pt")["stm_state_dict"])
stm_trainer.load_state_dict(torch.load(f"{stm_dir}/stm_trainer_checkpoint.pt"))

In [None]:
K = 2 # number of latents
ltm = LongTermMemory(
    hidden_size=model.config.hidden_size,
    num_latents=K,
    num_layers=4,
    num_heads=4
).to(DEVICE)

ltm_predictor = LTMPredictorHead(
    hidden_size=model.config.hidden_size
).to(DEVICE)

ltm_optimizer = torch.optim.AdamW(
    list(ltm.parameters()) + list(ltm_predictor.parameters()),
    lr=3e-4
)

for epoch in range(epochs):
    total_loss = 0.0
    n_docs = 0

    for sample in load_jsonl(JSONL_PATH):
        # Convert lists → tensors
        input_ids = torch.tensor(
            sample["input_ids"],
            dtype=torch.long,
            device=DEVICE
        ).unsqueeze(0)  # [1, T]
        # print("input_ids", input_ids.shape, input_ids.dtype, input_ids.device)

        attention_mask = torch.tensor(
            sample["attention_mask"],
            dtype=torch.long,
            device=DEVICE
        ).unsqueeze(0)  # [1, T]
        # print("attention_mask", attention_mask.shape, attention_mask.dtype, attention_mask.device)

        # Skip short sequences
        if input_ids.size(1) <= CHUNK_SIZE:
            continue

        chunks = chunk_tokens(input_ids, attention_mask, CHUNK_SIZE)

        if len(chunks) < K + 1:
            print("Skipping sample, not enough chunks:", len(chunks))
            continue
        summaries = encode_chunks(chunks, stm).unsqueeze(0)  # [1, num_chunks, d]
        print("summaries", summaries.shape, summaries.dtype, summaries.device)
        loss = train_ltm_step(
            summaries,
            ltm,
            ltm_predictor,
            ltm_optimizer,
            K
        )
        total_loss += loss
        n_docs += 1

    avg_loss = total_loss / max(n_docs, 1)
    print(f"Epoch {epoch} | LTM loss: {avg_loss:.4f}")

In [None]:
ltm_dir = "E:\YuyangGPT\models\ltms\ltm_1"
torch.save(
    {
        "ltm_state_dict": ltm.state_dict(),
        "ltm_config": {
            "hidden_size": model.config.hidden_size,
            "num_latents": K,
            "num_layers": 4,
            "num_heads": 4
        },
        "predictor_state_dict": ltm_predictor.state_dict()
    },
    f"{ltm_dir}/ltm_checkpoint.pt"
)

torch.save(
    ltm_predictor.state_dict(),
    f"{ltm_dir}/ltm_predictor_checkpoint.pt"
)

In [None]:
text = """
This is a test of YuyangGPT v1
"""

In [None]:
# Tokenize the text
inputs = tokenizer(
    text,
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=512
).to(DEVICE)


In [46]:
chunks = chunk_tokens(inputs["input_ids"], inputs["attention_mask"], chunk_size=128)

summaries = encode_chunks(chunks, stm)

# Add batch dimension for LTM
summaries_batched = summaries.unsqueeze(0)  # [1, num_chunks, d]

In [50]:
N = 3          # number of memory refinement steps
CHUNK_SIZE = 128

# Initialize inputs
current_input_ids = inputs["input_ids"]
current_attention_mask = inputs["attention_mask"]

for step in range(N):
    print(f"\n===== ITERATION {step + 1} =====")

    # ---- Embed tokens ----
    with torch.no_grad():
        outputs = model(
            input_ids=current_input_ids,
            attention_mask=current_attention_mask
        )
    token_embeddings = outputs.last_hidden_state  # [1, T, d]

    # ---- Chunk embeddings ----
    chunks = chunk_tokens(
        token_embeddings,
        current_attention_mask,
        CHUNK_SIZE
    )

    # ---- STM pass ----
    summaries = encode_chunks(chunks, stm)
    summaries_batched = summaries.unsqueeze(0)  # [1, C, d]

    print("Summaries shape:", summaries_batched.shape)

    # ---- LTM pass ----
    with torch.no_grad():
        ltm_latents = ltm(summaries_batched)  # [1, K, d]

    print("LTM latents shape:", ltm_latents.shape)

    # ---- Decode LTM latents into tokens ----
    ltm_logits = stm_trainer.lm_head(ltm_latents)   # [1, K, vocab]
    ltm_tokens = torch.argmax(ltm_logits, dim=-1)  # [1, K]

    decoded_text = tokenizer.decode(
        ltm_tokens[0],
        skip_special_tokens=True
    )

    print("LTM decoded text:", decoded_text)

    # ---- Append tokens to input ----
    current_input_ids = torch.cat(
        [current_input_ids, ltm_tokens],
        dim=1
    )  # [1, T + K]

    current_attention_mask = torch.ones_like(current_input_ids)

    print("New input length:", current_input_ids.shape[1])



===== ITERATION 1 =====


ValueError: too many values to unpack (expected 2)