In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class LongTermMemory(nn.Module):
    """Compress a sequence of summary vectors into a fixed set of latent slots.

    A learned bank of ``num_latents`` tokens is prepended to the input
    sequence, and the concatenation is run through a stack of
    ``nn.TransformerEncoderLayer`` blocks (joint self-attention over latents
    and summaries). Only the outputs at the latent positions are returned, so
    the output length is always ``num_latents`` regardless of the input
    length T.

    Args:
        hidden_size: Model width d (PyTorch requires it to be divisible by
            ``num_heads``).
        num_latents: Number of learned latent tokens K.
        num_layers: Number of encoder layers.
        num_heads: Attention heads per layer.
        dim_feedforward: Width of each layer's feed-forward sublayer.
            Defaults to 2048, PyTorch's own default.
        dropout: Dropout probability inside each layer. Defaults to 0.1,
            PyTorch's own default.
    """

    def __init__(
        self,
        hidden_size,
        num_latents=8,
        num_layers=4,
        num_heads=4,
        dim_feedforward=2048,
        dropout=0.1,
    ):
        super().__init__()

        # Learned latent tokens, broadcast over the batch in forward().
        # Created before the encoder so seeded initialisation order is stable.
        self.latents = nn.Parameter(
            torch.randn(1, num_latents, hidden_size)
        )

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=num_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )

        self.encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers,
        )

    def forward(self, summaries, attention_mask=None):
        """Encode ``summaries`` into the latent slots.

        Args:
            summaries: Tensor [B, T, d].
            attention_mask: Optional [B, T] mask over ``summaries``. Any
                nonzero / True entry marks a position to keep; 0 / False
                marks padding. Bool, integer and float masks are accepted.

        Returns:
            Tensor [B, K, d] with the encoder outputs at the K latent positions.

        Raises:
            ValueError: if ``attention_mask`` is not of shape [B, T].
        """
        batch_size = summaries.size(0)
        num_latents = self.latents.size(1)

        latents = self.latents.expand(batch_size, -1, -1)

        # Concatenate so latents can attend to summaries (and vice versa).
        x = torch.cat([latents, summaries], dim=1)  # [B, K + T, d]

        key_padding_mask = None
        if attention_mask is not None:
            expected = (batch_size, summaries.size(1))
            if tuple(attention_mask.shape) != expected:
                raise ValueError(
                    f"attention_mask must have shape {expected} "
                    f"(batch, summaries length), got {tuple(attention_mask.shape)}"
                )
            # nn.TransformerEncoder's src_key_padding_mask uses True = "ignore".
            # Casting to bool maps every nonzero entry to "keep", i.e. the same
            # convention as `mask == 0` marking padding.
            summary_padding = ~attention_mask.to(device=x.device, dtype=torch.bool)
            # Latent slots are never masked, so every row keeps at least K
            # valid keys even when all summaries in that row are padding.
            latent_padding = torch.zeros(
                batch_size, num_latents, dtype=torch.bool, device=x.device
            )
            key_padding_mask = torch.cat([latent_padding, summary_padding], dim=1)

        out = self.encoder(x, src_key_padding_mask=key_padding_mask)

        # Return only latent outputs.
        return out[:, :num_latents]


In [3]:
class LTMPredictorHead(nn.Module):
    """Position-wise MLP mapping each LTM latent to a predicted summary.

    Pipeline per position: LayerNorm -> Linear -> GELU -> Linear, with input
    and output both of width ``hidden_size``.
    """

    def __init__(self, hidden_size):
        super().__init__()

        width = hidden_size
        stages = [
            nn.LayerNorm(width),
            nn.Linear(width, width),
            nn.GELU(),
            nn.Linear(width, width),
        ]
        # A single Sequential named `net` keeps parameter names
        # (net.0.*, net.1.*, net.3.*) compatible with existing checkpoints.
        self.net = nn.Sequential(*stages)

    def forward(self, latents):
        """Apply the MLP to every slot: [B, K, d] in, [B, K, d] out."""
        predicted = self.net(latents)
        return predicted


In [4]:
def make_ltm_targets(summaries, K):
    """Split a summary sequence into LTM context and prediction targets.

    The last ``K`` time steps become the targets; everything before them is
    the context the LTM is allowed to see.

    Args:
        summaries: Tensor of shape [B, T, d].
        K: Number of trailing steps to hold out, with 0 <= K <= T.

    Returns:
        ``(context, targets)`` with shapes [B, T-K, d] and [B, K, d]. Both are
        slicing views of ``summaries``.

    Raises:
        ValueError: if K is outside [0, T].
    """
    T = summaries.size(1)
    if not 0 <= K <= T:
        raise ValueError(
            f"K must be in [0, {T}] for a sequence of length {T}, got {K}"
        )
    # Use an explicit split index: the negative-slice form
    # `summaries[:, :-K]` / `summaries[:, -K:]` silently yields an empty
    # context and the whole sequence as targets when K == 0.
    split = T - K
    context = summaries[:, :split]      # [B, T-K, d]
    targets = summaries[:, split:]      # [B, K, d]
    return context, targets


In [5]:
def train_ltm_step(
    summaries,
    ltm,
    predictor,
    optimizer,
    K
):
    """Run one optimisation step of the LTM "predict the last K summaries" task.

    The last ``K`` summaries are held out as targets (via
    ``make_ltm_targets``). ``ltm`` encodes the remaining context into
    latents, and ``predictor`` maps each latent to a predicted summary. The
    loss is the MSE between predictions and targets.

    Args:
        summaries: Tensor [B, T, d].
        ltm: Module taking the context [B, T-K, d] and returning latents.
        predictor: Module mapping latents to predictions.
        optimizer: Optimizer stepped after backprop. Presumably it covers the
            ``ltm`` and ``predictor`` parameters; confirm at the call site.
        K: Number of held-out target steps. The predictions must have the same
            shape as the targets, [B, K, d], so ``ltm`` must emit K latents.

    Returns:
        The loss as a Python float.

    Raises:
        ValueError: if the predictions' shape differs from the targets' shape.
    """
    context, targets = make_ltm_targets(summaries, K)

    latents = ltm(context)
    preds = predictor(latents)

    # F.mse_loss broadcasts mismatched shapes with only a warning, so e.g. a
    # single latent vs. K targets would silently optimise the wrong objective.
    if preds.shape != targets.shape:
        raise ValueError(
            f"predictions have shape {tuple(preds.shape)} but targets have "
            f"shape {tuple(targets.shape)}; the LTM must emit exactly K={K} latents"
        )

    # NOTE(review): if `summaries` is produced by a trainable encoder,
    # gradients also flow through `targets`; detach them if that is unintended.
    loss = F.mse_loss(preds, targets)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()


In [6]:
import json

def load_jsonl(path):
    """Lazily yield one parsed JSON value per non-blank line of a JSONL file.

    Args:
        path: Path to a UTF-8 encoded JSON Lines file. A leading UTF-8 BOM is
            tolerated.

    Yields:
        The value decoded from each non-blank line, in file order.

    Raises:
        json.JSONDecodeError: if a line is not valid JSON. The message is
            prefixed with ``path:line_number`` to locate the bad record.
    """
    # "utf-8-sig" decodes plain UTF-8 identically but strips a leading BOM,
    # which json.loads would otherwise reject.
    with open(path, "r", encoding="utf-8-sig") as f:
        for lineno, line in enumerate(f, start=1):
            # Blank or whitespace-only lines (e.g. a trailing empty line) are
            # not records; json.loads would raise on them.
            if not line.strip():
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as err:
                # Keep the exception type so existing handlers still match.
                raise json.JSONDecodeError(
                    f"{path}:{lineno}: {err.msg}", err.doc, err.pos
                ) from err
            yield record