In [1]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
import json

def load_jsonl(path, encoding="utf-8"):
    """Lazily yield one parsed JSON object per non-blank line of a JSONL file.

    Args:
        path: Path to a JSON Lines file (one JSON document per line).
        encoding: Text encoding used to open the file (default UTF-8).

    Yields:
        The object returned by ``json.loads`` for each non-blank line.

    Blank / whitespace-only lines (e.g. a stray empty line at the end of the
    file) are skipped instead of raising ``json.JSONDecodeError``.
    """
    with open(path, "r", encoding=encoding) as f:
        for line in f:
            if not line.strip():
                continue
            yield json.loads(line)

In [2]:
# ---- Configuration ----
# Location of the fine-tuned MiniLM checkpoint, tokens per chunk, and the
# device everything below runs on.
MODEL_DIR = "E:\\YuyangGPT\\models\\minilm-custom-eos"
CHUNK_SIZE = 128
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ---- Tokenizer and token-level encoder ----
# The encoder is not trained in this notebook, so it is moved to DEVICE and
# put in eval mode right away.
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = AutoModel.from_pretrained(MODEL_DIR)
model.to(DEVICE)
model.eval()

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30523, 384, padding_idx=0)
    (position_embeddings): Embedding(512, 384)
    (token_type_embeddings): Embedding(2, 384)
    (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-5): 6 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=384, out_features=384, bias=True)
            (key): Linear(in_features=384, out_features=384, bias=True)
            (value): Linear(in_features=384, out_features=384, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=384, out_features=384, bias=True)
            (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)


In [3]:
# Freeze every encoder weight; only the STM / LTM modules are optimized later.
for param in model.parameters():
    param.requires_grad_(False)

In [7]:
def chunk_tokens(input_ids, attention_mask, chunk_size=128):
    """Split a tokenized batch into consecutive windows along dimension 1.

    Args:
        input_ids: token-id tensor; windows are cut along its second dimension.
        attention_mask: tensor aligned with ``input_ids``, sliced identically.
        chunk_size: maximum window length; the last window keeps whatever is
            left over and may be shorter.

    Returns:
        A list of ``{"input_ids": ..., "attention_mask": ...}`` dicts, one per
        window, in sequence order.
    """
    total_len = input_ids.size(1)
    return [
        {
            "input_ids": input_ids[:, offset:offset + chunk_size],
            "attention_mask": attention_mask[:, offset:offset + chunk_size],
        }
        for offset in range(0, total_len, chunk_size)
        if input_ids[:, offset:offset + chunk_size].size(1) > 0
    ]

def encode_chunks(chunks, stm_model, encoder=None):
    """Summarize each token chunk into one vector with the short-term memory.

    Every chunk is run through the token encoder to get token-level hidden
    states, which ``stm_model`` then pools into a single summary vector.

    Args:
        chunks: list of ``{"input_ids", "attention_mask"}`` dicts as produced by
            ``chunk_tokens``; batch size 1 is assumed (see ``squeeze(0)``).
        stm_model: module called as
            ``stm_model(hidden_states=..., attention_mask=...)`` -> ``[1, d]``.
        encoder: token encoder whose output has ``.last_hidden_state``.
            Defaults to the notebook-global ``model``; pass it explicitly to
            avoid relying on hidden global state.

    Returns:
        Tensor ``[num_chunks, d]`` of stacked summaries.
    """
    if encoder is None:
        encoder = model  # notebook-global MiniLM encoder

    # Summaries are frozen features (no_grad below), so compute them without
    # dropout and restore the caller's train/eval mode afterwards.
    was_training = getattr(stm_model, "training", False)
    if was_training:
        stm_model.eval()

    summaries = []
    try:
        with torch.no_grad():
            for chunk in chunks:
                outputs = encoder(
                    input_ids=chunk["input_ids"],
                    attention_mask=chunk["attention_mask"]
                )
                h = outputs.last_hidden_state  # [1, T, d] token-level hidden states

                summary = stm_model(
                    hidden_states=h,
                    attention_mask=chunk["attention_mask"]
                )
                summaries.append(summary.squeeze(0))  # [d]
    finally:
        if was_training:
            stm_model.train()

    return torch.stack(summaries)  # [num_chunks, d]


In [8]:
class ChunkTransformer(nn.Module):
    """Short-term memory (STM): pools a chunk of token hidden states into one vector.

    A learnable summary token is prepended to the chunk, the sequence is passed
    through a small ``nn.TransformerEncoder``, and the encoder output at the
    summary-token position is returned.
    """

    def __init__(self, hidden_size, num_layers=2, num_heads=4):
        super().__init__()

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=num_heads,
            batch_first=True
        )

        self.encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers
        )

        # Learnable summary token, [1, 1, hidden_size]; expanded per batch in forward.
        self.summary_token = nn.Parameter(
            torch.randn(1, 1, hidden_size)
        )

    def forward(self, hidden_states, attention_mask=None):
        """
        Args:
            hidden_states: [B, T, d] token hidden states.
            attention_mask: optional [B, T] mask; nonzero / True marks a real
                token, 0 / False marks padding. ``None`` attends to every position.

        Returns:
            [B, d] encoder output at the summary-token position.
        """
        B, T, d = hidden_states.shape

        # Prepend the summary token to every sequence in the batch.
        summary = self.summary_token.expand(B, 1, d)
        x = torch.cat([summary, hidden_states], dim=1)  # [B, T+1, d]

        if attention_mask is None:
            key_padding_mask = None
        else:
            # Build the mask in bool so any mask dtype (long / float / bool) works
            # without a mixed-dtype concatenation. The summary slot is never masked.
            keep = torch.cat(
                [
                    torch.ones(B, 1, dtype=torch.bool, device=attention_mask.device),
                    attention_mask.bool(),
                ],
                dim=1,
            )
            # nn.TransformerEncoder expects True = "ignore this key".
            key_padding_mask = ~keep

        out = self.encoder(
            x,
            src_key_padding_mask=key_padding_mask
        )

        # Return summary token output
        return out[:, 0]  # [B, d]


In [10]:
class STMTrainer(nn.Module):
    """Next-token prediction head stacked on a short-term-memory summarizer.

    The wrapped ``stm`` reduces ``[B, T, d]`` hidden states to a ``[B, d]``
    summary, and ``lm_head`` projects that summary to ``[B, vocab]`` logits.
    """

    def __init__(self, stm, hidden_size, vocab_size):
        super().__init__()
        self.stm = stm
        self.lm_head = nn.Linear(hidden_size, vocab_size)

    def forward(self, hidden_states, attention_mask):
        """Return vocabulary logits computed from the pooled chunk summary."""
        return self.lm_head(self.stm(hidden_states, attention_mask))


In [12]:
# Short-term memory (STM): pools one chunk of MiniLM hidden states into a single
# vector; its width has to match the encoder's hidden size.
stm = ChunkTransformer(
    model.config.hidden_size,
    num_layers=2,
    num_heads=4,
)
stm = stm.to(DEVICE)

In [13]:
# Wrap the STM with a vocabulary head; train_on_chunks trains it to predict the
# first token of the following chunk.
stm_trainer = STMTrainer(
    stm,
    model.config.hidden_size,
    len(tokenizer),
).to(DEVICE)


In [17]:
# OPTIONAL: restore previously trained STM and STM-trainer weights from disk.
# Raw string: a plain "E:\YuyangGPT\..." literal relies on invalid escape
# sequences (\Y, \m, \s), which Python 3.12+ reports as SyntaxWarning.
stm_dir = r"E:\YuyangGPT\models\stms\stm_3"

# ---- Load STM ----
# The checkpoints written by this notebook contain only tensors and plain
# config values, so refuse arbitrary pickled objects (weights_only=True).
stm_ckpt = torch.load(
    f"{stm_dir}/stm_checkpoint.pt",
    map_location=DEVICE,
    weights_only=True,
)

stm.load_state_dict(stm_ckpt["stm_state_dict"])

# ---- Load STM trainer ----
# stm_trainer.stm is the same module object as `stm`, so this load also
# (re)writes the STM weights, from the trainer checkpoint's "stm.*" entries.
trainer_state = torch.load(
    f"{stm_dir}/stm_trainer_checkpoint.pt",
    map_location=DEVICE,
    weights_only=True,
)

stm_trainer.load_state_dict(trainer_state)

<All keys matched successfully>

In [None]:
# Objective and optimizer for STM training. stm_trainer's parameters include the
# wrapped STM (stm_trainer.stm) as well as lm_head.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(params=stm_trainer.parameters(), lr=3e-4)


In [None]:
def train_on_chunks(chunks, model, stm_trainer, opt=None, criterion=None):
    """Train the STM head to predict the first token of each following chunk.

    For every consecutive pair (chunk i, chunk i+1), the frozen encoder embeds
    chunk i, ``stm_trainer`` turns that into vocabulary logits, and the loss
    against chunk i+1's first token is back-propagated with one optimizer step.

    Args:
        chunks: list of ``{"input_ids", "attention_mask"}`` dicts (batch size 1).
        model: frozen token encoder whose output has ``.last_hidden_state``.
        stm_trainer: module ``(hidden_states, attention_mask) -> [B, vocab]``.
        opt: optimizer to step; defaults to the notebook-global ``optimizer``.
        criterion: ``(logits, target) -> loss``; defaults to the notebook-global
            ``loss_fn``.

    Returns:
        Mean loss over the pairs actually trained on (0.0 if there were none).
        Pairs whose target position is padding are skipped.
    """
    if opt is None:
        opt = optimizer  # notebook-global optimizer from the setup cell
    if criterion is None:
        criterion = loss_fn  # notebook-global loss from the setup cell

    total_loss = 0.0
    n_steps = 0

    for cur, nxt in zip(chunks, chunks[1:]):
        # Don't teach the head to predict padding.
        if not nxt["attention_mask"][:, 0].bool().all():
            continue

        # The encoder is frozen: no need to record an autograd graph for it.
        with torch.no_grad():
            h = model(
                input_ids=cur["input_ids"],
                attention_mask=cur["attention_mask"]
            ).last_hidden_state  # [1, T, d]

        logits = stm_trainer(
            hidden_states=h,
            attention_mask=cur["attention_mask"]
        )  # [1, vocab]
        target = nxt["input_ids"][:, 0]  # [1]

        loss = criterion(logits, target)
        # Clear grads *before* backward so nothing stale (e.g. left by another
        # training routine) leaks into this update.
        opt.zero_grad()
        loss.backward()
        opt.step()

        total_loss += loss.item()
        n_steps += 1

    return total_loss / max(n_steps, 1)


In [None]:
# STM training hyper-parameters and paths.
epochs = 30
# Raw string: "E:\YuyangGPT\..." written without r-prefix relies on invalid
# escape sequences (\Y, \m, \s), a SyntaxWarning on Python 3.12+.
stm_dir = r"E:\YuyangGPT\models\stms\stm_3"
JSONL_PATH = "E:\\YuyangGPT\\dataset\\cleaned_data\\train_tokenized_discord_messages.jsonl"

In [None]:
# ---- STM training loop ----
# Set module modes explicitly so this cell is correct even if it is run after
# a cell that switched modules to eval mode.
model.eval()         # frozen MiniLM encoder: no dropout
stm_trainer.train()  # STM + lm_head are the parameters being optimized

for epoch in range(epochs):
    total_loss = 0.0
    n_docs = 0

    for sample in load_jsonl(JSONL_PATH):
        # Convert lists → tensors
        input_ids = torch.tensor(
            sample["input_ids"],
            dtype=torch.long,
            device=DEVICE
        ).unsqueeze(0)  # [1, T]

        attention_mask = torch.tensor(
            sample["attention_mask"],
            dtype=torch.long,
            device=DEVICE
        ).unsqueeze(0)  # [1, T]

        # Sequences that fit in one chunk have no "next chunk" to predict.
        if input_ids.size(1) <= CHUNK_SIZE:
            continue

        chunks = chunk_tokens(input_ids, attention_mask, CHUNK_SIZE)

        if len(chunks) < 2:
            continue

        loss = train_on_chunks(chunks, model, stm_trainer)

        total_loss += loss
        n_docs += 1

    avg_loss = total_loss / max(n_docs, 1)
    print(f"Epoch {epoch} | STM loss: {avg_loss:.4f}")

In [None]:
# Persist the trained STM (with the hyper-parameters needed to rebuild it) and
# the full STM trainer (STM + lm_head).
os.makedirs(stm_dir, exist_ok=True)  # torch.save does not create missing directories

torch.save(
    {
        "stm_state_dict": stm.state_dict(),
        "stm_config": {
            "hidden_size": model.config.hidden_size,
            # Read back from the module rather than repeating the literals used
            # when `stm` was built, so the saved config cannot drift.
            "num_layers": stm.encoder.num_layers,
            "num_heads": stm.encoder.layers[0].self_attn.num_heads,
        }
    },
    f"{stm_dir}/stm_checkpoint.pt"
)
torch.save(
    stm_trainer.state_dict(),
    f"{stm_dir}/stm_trainer_checkpoint.pt"
)


In [None]:
# Outline for the long-term memory (LTM)
# 1. Pass the chunk summary vectors into the LTM to get K latent vectors of size d (the hidden-state size).
# 2. Concatenate the current chunk's token hidden states with the latent vectors and run them through
#    the frozen STM plus its decoding head (lm_head) to predict the next token.
# 3. Train only the LTM, on the loss between the predicted and the actual next token.
# Pseudocode:
# input_tokens -> encoder -> token hidden states -> stm_model -> summary vectors -> ltm_model -> latent vectors
#   -> concat with token hidden states -> stm_model + decoding head -> next-token prediction -> loss, backprop into LTM only

In [18]:
class LongTermMemory(nn.Module):
    """Compress a sequence of chunk summaries into ``num_latents`` latent vectors.

    Learned latent tokens are prepended to the summaries and the joint sequence
    goes through a TransformerEncoder. Attention is full self-attention, so
    latents and summaries all attend to one another. Only the outputs at the
    latent positions are returned.
    """

    def __init__(self, hidden_size, num_latents=8, num_layers=4, num_heads=4):
        super().__init__()

        # Learned latent tokens, [1, num_latents, hidden_size].
        self.latents = nn.Parameter(torch.randn(1, num_latents, hidden_size))

        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=hidden_size,
                nhead=num_heads,
                batch_first=True,
            ),
            num_layers=num_layers,
        )

    def forward(self, summaries, attention_mask=None):
        """
        Args:
            summaries: [B, T, d] chunk summary vectors.
            attention_mask: optional [B, T]; entries equal to 0 are treated as padding.

        Returns:
            [B, K, d] encoder outputs at the latent positions (K = num_latents).
        """
        batch = summaries.size(0)
        num_latents = self.latents.size(1)

        x = torch.cat([self.latents.expand(batch, -1, -1), summaries], dim=1)

        padding = None
        if attention_mask is not None:
            # Latent slots are always attended; only summary padding is ignored.
            latent_keep = torch.ones(batch, num_latents, device=attention_mask.device)
            padding = torch.cat([latent_keep, attention_mask], dim=1) == 0

        encoded = self.encoder(x, src_key_padding_mask=padding)
        return encoded[:, :num_latents]


In [29]:
def encode_chunks_with_hiddens(chunks, stm_model, encoder=None):
    """
    Encode chunks and return both their STM summaries and token hidden states.

    Args:
        chunks: list of ``{"input_ids", "attention_mask"}`` dicts (batch size 1,
            as assumed by ``squeeze(0)``).
        stm_model: module called as
            ``stm_model(hidden_states=..., attention_mask=...)`` -> ``[1, d]``.
        encoder: token encoder whose output has ``.last_hidden_state``.
            Defaults to the notebook-global ``model``.

    Returns:
        summaries: [num_chunks, d]
        chunk_hiddens: List of [1, T_i, d]
    """
    if encoder is None:
        encoder = model  # notebook-global MiniLM encoder

    # These are frozen features (computed under no_grad), so run the STM
    # without dropout and restore the caller's train/eval mode afterwards.
    was_training = getattr(stm_model, "training", False)
    if was_training:
        stm_model.eval()

    summaries = []
    chunk_hiddens = []
    try:
        with torch.no_grad():
            for chunk in chunks:
                outputs = encoder(
                    input_ids=chunk["input_ids"],
                    attention_mask=chunk["attention_mask"]
                )
                h = outputs.last_hidden_state  # [1, T, d] token-level hidden states

                # Get summary via STM
                summary = stm_model(
                    hidden_states=h,
                    attention_mask=chunk["attention_mask"]
                )

                summaries.append(summary.squeeze(0))  # [d]
                chunk_hiddens.append(h)  # [1, T, d]
    finally:
        if was_training:
            stm_model.train()

    summaries_stacked = torch.stack(summaries)  # [num_chunks, d]
    return summaries_stacked, chunk_hiddens

def _latent_attention_mask(chunk_mask, num_latents):
    """Append ``num_latents`` always-attended slots to a [B, T] chunk attention mask."""
    latent_ones = torch.ones(
        chunk_mask.size(0), num_latents,
        dtype=chunk_mask.dtype, device=chunk_mask.device,
    )
    return torch.cat([chunk_mask, latent_ones], dim=1)


def train_ltm_step(
    chunks,
    summaries,  # [1, num_chunks, d]
    chunk_hiddens,  # List of [1, T_i, d]
    ltm,
    stm_trainer,
    optimizer,
    DEVICE,
    K,
    N=3  # Number of refinement iterations
):
    """
    Train the LTM on one document with iterative latent refinement.

    For every consecutive chunk pair (i, i+1):
    1. Get initial latents from the LTM using the summaries of chunks 0..i.
    2. For N iterations: concatenate chunk i's hidden states with the current
       latents, pool them with ``stm_trainer.stm``, and feed that single refined
       summary back through the LTM to produce new latents.
    3. Concatenate the final latents with chunk i's hidden states and predict
       the first token of chunk i+1 with ``stm_trainer``.

    Gradients are accumulated only into the tensors owned by ``optimizer``, so
    ``stm_trainer``'s parameters do not collect stale ``.grad`` values.

    Args:
        chunks: list of ``{"input_ids", "attention_mask"}`` dicts (batch size 1).
        summaries: [1, num_chunks, d] STM summaries of ``chunks``.
        chunk_hiddens: list of [1, T_i, d] encoder hidden states, one per chunk.
        ltm: module mapping [1, S, d] summaries to [1, K, d] latents.
        stm_trainer: STM + lm_head used as the (not-updated) decoder.
        optimizer: optimizer over the LTM parameters.
        DEVICE: device the chunk attention masks are moved to.
        K: number of latents (kept for interface compatibility; the actual
            count is read from the LTM output).
        N: number of refinement iterations.

    Returns:
        Mean loss over the pairs actually trained on (0.0 if none). Pairs whose
        target token is padding are skipped.
    """
    # Only accumulate gradients into tensors this optimizer will step.
    trainable = [
        p for group in optimizer.param_groups for p in group["params"] if p.requires_grad
    ]

    total_loss = 0.0
    n_steps = 0

    for i in range(len(chunks) - 1):
        next_chunk = chunks[i + 1]
        # Never train towards a padding token.
        if not next_chunk["attention_mask"][:, 0].bool().all():
            continue

        chunk_h = chunk_hiddens[i]  # [1, T_i, d]
        # Keep the chunk's padded positions masked, as train_on_chunks does when
        # training the STM; latents appended after it are always attended.
        chunk_mask = chunks[i]["attention_mask"].to(DEVICE)

        # Initial latents from the summaries of chunks 0..i
        latents = ltm(summaries[:, :i + 1])  # [1, K, d]

        # Iteratively refine latents N times
        for _ in range(N):
            combined = torch.cat([chunk_h, latents], dim=1)  # [1, T_i+K, d]
            refined_summary = stm_trainer.stm(
                combined, _latent_attention_mask(chunk_mask, latents.size(1))
            )  # [1, d]
            latents = ltm(refined_summary.unsqueeze(1))  # [1, K, d]

        # Final prediction from chunk_h + final latents
        final_combined = torch.cat([chunk_h, latents], dim=1)  # [1, T_i+K, d]
        logits = stm_trainer(
            final_combined, _latent_attention_mask(chunk_mask, latents.size(1))
        )  # [1, vocab]

        # Target: first token of the next chunk
        target = next_chunk["input_ids"][:, 0]  # [1]
        loss = F.cross_entropy(logits, target)

        optimizer.zero_grad()
        # `inputs=` restricts grad accumulation to the optimizer's tensors; fall
        # back to the default (all leaves) if it owns none that require grad.
        loss.backward(inputs=trainable or None)
        optimizer.step()

        total_loss += loss.item()
        n_steps += 1

    return total_loss / max(n_steps, 1)


In [None]:
K = 2  # number of latents
N = 3  # number of refinement iterations
epochs = 10
JSONL_PATH = "E:\\YuyangGPT\\dataset\\cleaned_data\\train_tokenized_discord_messages.jsonl" # Tokenized training data
# A document is usable once it has one (context chunk, target chunk) pair.
# This threshold does not depend on K, the number of latent vectors.
MIN_LTM_CHUNKS = 2

ltm = LongTermMemory(
    hidden_size=model.config.hidden_size,
    num_latents=K,
    num_layers=4,
    num_heads=4
).to(DEVICE)

ltm_optimizer = torch.optim.AdamW(
    ltm.parameters(),
    lr=3e-4
)

# Only the LTM is optimized here. The MiniLM encoder and the STM + lm_head serve
# as a fixed feature extractor / decoder, so run them without dropout.
model.eval()
stm_trainer.eval()  # also covers `stm`, which is stm_trainer.stm
ltm.train()

for epoch in range(epochs):
    total_loss = 0.0
    n_docs = 0

    for sample in load_jsonl(JSONL_PATH):
        # Convert lists → tensors
        input_ids = torch.tensor(
            sample["input_ids"],
            dtype=torch.long,
            device=DEVICE
        ).unsqueeze(0)  # [1, T]

        attention_mask = torch.tensor(
            sample["attention_mask"],
            dtype=torch.long,
            device=DEVICE
        ).unsqueeze(0)  # [1, T]

        # Sequences that fit in one chunk have no "next chunk" to predict.
        if input_ids.size(1) <= CHUNK_SIZE:
            continue

        chunks = chunk_tokens(input_ids, attention_mask, CHUNK_SIZE)

        if len(chunks) < MIN_LTM_CHUNKS:
            continue

        # Encode chunks to get summaries and hidden states
        summaries, chunk_hiddens = encode_chunks_with_hiddens(chunks, stm)
        summaries_batched = summaries.unsqueeze(0)  # [1, num_chunks, d]

        # Train LTM on this document with iterative refinement
        loss = train_ltm_step(
            chunks,
            summaries_batched,
            chunk_hiddens,
            ltm,
            stm_trainer,
            ltm_optimizer,
            DEVICE,
            K,
            N  # Pass refinement iterations
        )

        total_loss += loss
        n_docs += 1

    avg_loss = total_loss / max(n_docs, 1)
    print(f"Epoch {epoch} | LTM loss: {avg_loss:.4f}")

In [None]:
# Save the trained LTM together with the hyper-parameters needed to rebuild it.
# Raw string: without the r-prefix, "E:\YuyangGPT\..." relies on invalid escape
# sequences (\Y, \m, \l), a SyntaxWarning on Python 3.12+.
ltm_dir = r"E:\YuyangGPT\models\ltms\ltm_1"
os.makedirs(ltm_dir, exist_ok=True)  # torch.save does not create missing directories

# No `ltm_predictor` object exists in this notebook. The head that maps
# [chunk hidden states ; LTM latents] to next-token logits is `stm_trainer`
# (see train_ltm_step and the generation cells), so its weights are stored as
# the "predictor".
predictor_state = stm_trainer.state_dict()

torch.save(
    {
        "ltm_state_dict": ltm.state_dict(),
        "ltm_config": {
            "hidden_size": model.config.hidden_size,
            # Read back from the module so the saved config matches the weights.
            "num_latents": ltm.latents.size(1),
            "num_layers": ltm.encoder.num_layers,
            "num_heads": ltm.encoder.layers[0].self_attn.num_heads,
        },
        "predictor_state_dict": predictor_state
    },
    f"{ltm_dir}/ltm_checkpoint.pt"
)

torch.save(
    predictor_state,
    f"{ltm_dir}/ltm_predictor_checkpoint.pt"
)

In [75]:
# Number of tokens to generate: one in the STEP 0 cell, then N - 1 in the loop.
# NOTE: this reuses the name N, which the LTM training cell uses for the number
# of refinement iterations.
N = 10

# Prompt fed to the generation demo.
text = """
Hello! I am YuyangGPT, a custom language model developed by Yuyang Hu, to mimic his style of writing and to increase his ability to be lazy
"""

In [76]:
# Tokenize the text
inputs = tokenizer(
    text,
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=512
).to(DEVICE)

# Initialize inputs
current_input_ids = inputs["input_ids"]
current_attention_mask = inputs["attention_mask"]


In [None]:
# ---- STEP 0: predict the first new token from the prompt ----
# Inference only: disable dropout in every module involved (stm_trainer.eval()
# also covers `stm`, which is stm_trainer.stm) and skip autograd bookkeeping.
model.eval()
stm_trainer.eval()
ltm.eval()

chunks = chunk_tokens(inputs["input_ids"], inputs["attention_mask"], chunk_size=CHUNK_SIZE)

# Encode chunks to get summaries and hidden states
summaries, chunk_hiddens = encode_chunks_with_hiddens(chunks, stm)
summaries_batched = summaries.unsqueeze(0)  # [1, num_chunks, d]

with torch.no_grad():
    # Pass summaries through LTM to generate latents
    ltm_latents = ltm(summaries_batched)  # [1, K, d]

    print("LTM Latents shape:", ltm_latents.shape)

    # NOTE(review): train_ltm_step refines the latents N times through the STM
    # before predicting; here the un-refined LTM output is used directly, so
    # inference does not follow the training procedure. TODO confirm intent.

    # Concatenate latents with the last chunk's hidden states
    last_chunk_hidden = chunk_hiddens[-1]  # [1, T, d]
    last_chunk_mask = chunks[-1]["attention_mask"]  # [1, T]
    combined = torch.cat([last_chunk_hidden, ltm_latents], dim=1)  # [1, T+K, d]

    # Keep the chunk's own padding masked; the latents are always attended.
    latent_ones = torch.ones(
        last_chunk_mask.size(0), ltm_latents.size(1),
        dtype=last_chunk_mask.dtype, device=last_chunk_mask.device,
    )
    attention_mask = torch.cat([last_chunk_mask, latent_ones], dim=1)

    # Greedy next-token prediction through stm_trainer
    ltm_logits = stm_trainer(combined, attention_mask)  # [1, vocab]
    ltm_tokens = torch.argmax(ltm_logits, dim=-1, keepdim=True)  # [1, 1]

decoded_text = tokenizer.decode(
    ltm_tokens[0],
    skip_special_tokens=True
)

print("LTM generated tokens:", decoded_text)

# Append LTM tokens to current input
current_input_ids = torch.cat(
    [inputs["input_ids"], ltm_tokens],
    dim=1
)  # [1, T + 1]

current_attention_mask = torch.ones_like(current_input_ids)
print(f"New input length: {current_input_ids.shape[1]}")
print(f"Full text: {tokenizer.decode(current_input_ids[0], skip_special_tokens=True)}")


LTM Latents raw tensors:  tensor([[[ 6.1346e-03,  2.0783e-01, -1.4621e-02,  1.2475e-02, -4.2895e-02,
           2.4536e-02, -3.1636e-02, -3.5248e-02,  3.1469e-02, -7.7356e-02,
           1.0164e-01,  5.6215e-03,  7.1762e-02,  6.4854e-02,  3.2184e-02,
          -1.0968e-01,  2.3746e-02,  2.4515e-01, -2.3751e-01, -1.6380e-02,
          -9.6214e-02,  9.9117e-03, -4.5560e-02,  2.1287e-01, -2.6197e-03,
           8.4560e-02,  1.1555e-02,  3.6881e-02,  5.1513e-02, -2.3701e-02,
          -2.5928e-02, -4.1905e-02,  1.5209e-01, -3.0275e-02, -1.8244e-04,
          -6.3817e-03,  7.6464e-02, -1.4012e-02,  4.5082e+00, -1.0293e-01,
           8.2999e-02, -9.9391e-05, -1.9487e-01,  5.5413e-02, -2.7280e+00,
          -6.3853e-02, -1.3144e-01,  2.7343e+00,  1.4637e-02,  5.4035e-01,
           1.8240e-02, -4.2589e-03, -1.1037e-01,  1.2690e-02,  6.7131e-02,
           3.9765e-02, -1.2319e+01,  2.9649e-02, -6.6016e-02,  6.5117e-02,
          -2.5257e-02, -1.6618e-02,  6.2852e-01,  4.8459e-02, -3.8034e-01,

In [None]:
# ---- STEPS 1..N-1: keep generating, one greedy token per step ----
# Inference only: make sure dropout is off even if this cell is run on its own
# (stm_trainer.eval() also covers `stm`, which is stm_trainer.stm).
model.eval()
stm_trainer.eval()
ltm.eval()

with torch.no_grad():
    for step in range(1, N):
        chunks = chunk_tokens(current_input_ids, current_attention_mask, chunk_size=CHUNK_SIZE)

        # Encode chunks to get summaries and hidden states
        summaries, chunk_hiddens = encode_chunks_with_hiddens(chunks, stm)
        summaries_batched = summaries.unsqueeze(0)  # [1, num_chunks, d]

        print("Summaries shape:", summaries_batched.shape)

        # Pass summaries through LTM to generate latents
        ltm_latents = ltm(summaries_batched)  # [1, K, d]

        print("LTM latents shape:", ltm_latents.shape)

        # NOTE(review): unlike train_ltm_step, no N-step latent refinement is
        # applied here before prediction. TODO confirm this is intended.

        # Concatenate latents with the last chunk's hidden states
        last_chunk_hidden = chunk_hiddens[-1]  # [1, T, d]
        last_chunk_mask = chunks[-1]["attention_mask"]  # [1, T]
        combined = torch.cat([last_chunk_hidden, ltm_latents], dim=1)  # [1, T+K, d]

        # Keep the chunk's own padding masked; the latents are always attended.
        latent_ones = torch.ones(
            last_chunk_mask.size(0), ltm_latents.size(1),
            dtype=last_chunk_mask.dtype, device=last_chunk_mask.device,
        )
        attention_mask = torch.cat([last_chunk_mask, latent_ones], dim=1)

        # Predict next token
        ltm_logits = stm_trainer(combined, attention_mask)  # [1, vocab]
        ltm_tokens = torch.argmax(ltm_logits, dim=-1, keepdim=True)  # [1, 1]

        decoded_text = tokenizer.decode(
            ltm_tokens[0],
            skip_special_tokens=True
        )

        print("LTM generated tokens:", decoded_text)

        # Append new token to current input
        current_input_ids = torch.cat(
            [current_input_ids, ltm_tokens],
            dim=1
        )  # [1, T + 1]

        current_attention_mask = torch.ones_like(current_input_ids)

        print(f"New input length: {current_input_ids.shape[1]}")
        print(f"Full text: {tokenizer.decode(current_input_ids[0], skip_special_tokens=True)}")

print("\n===== FINAL OUTPUT =====")
print(f"Final sequence length: {current_input_ids.shape[1]}")
print(f"Final text: {tokenizer.decode(current_input_ids[0], skip_special_tokens=True)}")


Summaries shape: torch.Size([1, 1, 384])
LTM latents shape: torch.Size([1, 2, 384])
LTM generated tokens: / /
New input length: 15
Full text: this is a test of yuyanggpt v1 / /
Summaries shape: torch.Size([1, 1, 384])
LTM latents shape: torch.Size([1, 2, 384])
LTM generated tokens: / /
New input length: 15
Full text: this is a test of yuyanggpt v1 / /
Summaries shape: torch.Size([1, 1, 384])
LTM latents shape: torch.Size([1, 2, 384])
LTM generated tokens: / /
New input length: 15
Full text: this is a test of yuyanggpt v1 / /
Summaries shape: torch.Size([1, 1, 384])
LTM latents shape: torch.Size([1, 2, 384])
LTM generated tokens: / /
New input length: 15
Full text: this is a test of yuyanggpt v1 / /
Summaries shape: torch.Size([1, 1, 384])
LTM latents shape: torch.Size([1, 2, 384])
LTM generated tokens: / /
New input length: 15
Full text: this is a test of yuyanggpt v1 / /
Summaries shape: torch.Size([1, 1, 384])
LTM latents shape: torch.Size([1, 2, 384])
LTM generated tokens: / /
New i