# Grokking on Algorithmic Tasks

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import random
from typing import List, Optional
from pysat.solvers import Glucose4  # or another SAT solver from python-sat

In [2]:
def generate_random_3sat_instance(
    n_vars: int,
    ratio: float = 4.2,
    seed: Optional[int] = None
):
    """
    Generate a random 3-SAT instance with n_vars variables and k clauses, 
    where k = int(ratio * n_vars). Each clause has exactly 3 literals.
    
    Returns:
        - clauses: a list of lists, each sub-list is [lit1, lit2, lit3],
          where each lit is an integer in {1..n_vars, -1..-n_vars}.
          A positive integer p means variable p is not negated,
          a negative integer -p means variable p is negated.
    """
    if seed is not None:
        random.seed(seed)

    n_clauses = int(ratio * n_vars)
    clauses = []
    for _ in range(n_clauses):
        # Randomly choose 3 distinct variables out of n_vars
        vars_in_clause = random.sample(range(1, n_vars + 1), 3)
        # Randomly negate some of them
        clause = [v if random.random() < 0.5 else -v for v in vars_in_clause]
        clauses.append(clause)
    return clauses


def solve_3sat(clauses: List[List[int]], n_vars: int):
    """
    Solve the given 3-SAT instance (clauses) with n_vars using a SAT solver.
    Returns a list of length n_vars with True/False assignments, or None if unsat.
    """
    solver = Glucose4()
    for clause in clauses:
        solver.add_clause(clause)
    sat = solver.solve()
    if not sat:
        solver.delete()
        return None
    model = solver.get_model()  # list of ints with sign indicating T/F
    solver.delete()

    # model might have length >= n_vars (depending on the solver),
    # but we only care about the first n_vars in absolute value:
    assignment = [False] * n_vars
    for lit in model:
        idx = abs(lit) - 1
        if 0 <= idx < n_vars:
            assignment[idx] = (lit > 0)
    return assignment


class ThreeSatDataset(Dataset):
    """
    Creates a dataset of random 3-SAT instances along with a known satisfying assignment.
    Each item is:
        input_representation, assignment
    Where:
        - input_representation: a token-like encoding of the 3-SAT instance
        - assignment: a list of length n_vars in {0,1}, representing F or T.
    """

    def __init__(
        self,
        num_samples: int,
        n_vars: int,
        ratio: float = 4.2,
        max_tries: int = 1000,
        skip_unsat: bool = True
    ):
        super().__init__()
        self.data = []
        self.n_vars = n_vars

        count = 0
        while len(self.data) < num_samples and count < max_tries:
            clauses = generate_random_3sat_instance(n_vars, ratio)
            solution = solve_3sat(clauses, n_vars)
            if solution is not None:
                # We have a satisfiable instance
                input_repr = self._encode_instance(clauses, n_vars)
                assignment = [1 if x else 0 for x in solution]
                self.data.append((input_repr, assignment))
            else:
                if not skip_unsat:
                    # We'll store unsat instance with a placeholder assignment
                    input_repr = self._encode_instance(clauses, n_vars)
                    assignment = [0] * n_vars  # or any fallback
                    self.data.append((input_repr, assignment))
            count += 1

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

    def _encode_instance(self, clauses: List[List[int]], n_vars: int):
        """
        Encode the 3-SAT instance as a list of integer tokens (or any token representation).
        Example simplistic encoding scheme:
            - We'll start with a special token for "start of instance", e.g. 0
            - For each clause, we add tokens [ var1, var2, var3, <sep> ]
              where var is +x or -x mapped to distinct integer IDs.
            - End with a <end> token.
        The mapping is arbitrary as long as it's consistent.
        
        Return a list of ints that the Transformer can embed.
        """
        tokens = [0]  # start token
        # We'll map literal i (±1..±n_vars) to int: +x -> x, -x -> n_vars + |x|
        # That is, pos i in [1..n_vars] = i, neg i = n_vars + |i|.
        for clause in clauses:
            for lit in clause:
                if lit > 0:
                    tokens.append(lit)  # e.g. 1..n_vars
                else:
                    tokens.append(n_vars + abs(lit))  # e.g. n_vars+1..2*n_vars
            tokens.append(2 * n_vars + 1)  # a <sep> token
        tokens.append(2 * n_vars + 2)  # <end> token
        return torch.tensor(tokens, dtype=torch.long)

In [3]:
class TransformerBlock(nn.Module):
    """
    A single Transformer block with multi-head attention and a GELU feedforward.
    """
    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_heads, batch_first=True)
        self.attn_ln = nn.RMSNorm(d_model)

        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model)
        )
        self.ff_ln = nn.RMSNorm(d_model)

    def forward(self, x, mask=None):
        # x shape: (batch, seq_len, d_model)
        attn_out, _ = self.attn(x, x, x, attn_mask=mask)
        x = self.attn_ln(x + attn_out)
        ff_out = self.ff(x)
        x = self.ff_ln(x + ff_out)
        return x


class TransformerModel(nn.Module):
    """
    Minimal Transformer-based model:
      - An embedding layer
      - N Transformer blocks
      - A readout that produces a sequence of length n_vars (the T/F assignment)
    """
    def __init__(
        self,
        vocab_size,
        d_model=128,
        n_heads=4,
        d_ff=512,
        num_layers=4,
        n_vars=10
    ):
        super().__init__()
        self.n_vars = n_vars
        self.d_model = d_model
        self.embed = nn.Embedding(vocab_size, d_model)

        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff)
            for _ in range(num_layers)
        ])
        self.final_ln = nn.RMSNorm(d_model)

        # We will produce an output for each variable in {0,1}:
        # i.e. a classification "head" that returns a T/F assignment per variable.
        # We'll do this by reading the final hidden state from the [start] token 
        # (or something else) and projecting into n_vars*2, or produce an entire sequence.
        #
        # For simplicity, let's produce a single vector (batch, d_model) from the
        # [start] token, and then map that to (batch, n_vars, 2).
        self.classifier = nn.Linear(d_model, n_vars * 2)

    def forward(self, x, mask=None):
        """
        x: (batch, seq_len)
        Return shape: (batch, n_vars, 2)
        """
        emb = self.embed(x)  # (batch, seq_len, d_model)
        hidden = emb
        for block in self.blocks:
            hidden = block(hidden, mask=mask)
        hidden = self.final_ln(hidden)

        # Let's index the first token [start] = hidden[:,0,:]
        start_token_hidden = hidden[:, 0, :]  # (batch, d_model)
        logits = self.classifier(start_token_hidden)  # (batch, n_vars*2)
        logits = logits.view(-1, self.n_vars, 2)      # (batch, n_vars, 2)
        return logits

In [4]:
def collate_fn(batch):
    """
    Collate function for DataLoader.
    We have variable-length input sequences, so we'll pad them.
    """
    inputs = [b[0] for b in batch]
    targets = [b[1] for b in batch]

    # Pad inputs
    input_lens = [len(i) for i in inputs]
    max_len = max(input_lens)
    padded_inputs = []
    for inp in inputs:
        pad_len = max_len - len(inp)
        padded_inp = torch.cat([inp, torch.zeros(pad_len, dtype=torch.long)], dim=0)
        padded_inputs.append(padded_inp)
    padded_inputs = torch.stack(padded_inputs, dim=0)

    # Targets are all length n_vars
    targets = torch.tensor(targets, dtype=torch.long)

    return padded_inputs, targets


def main():
    # Hyperparameters
    N_VARS = 5      # number of variables per sample
    RATIO = 3.0      # clauses-to-variables ratio
    TRAIN_SAMPLES = 200000
    VAL_SAMPLES = 20000
    BATCH_SIZE = 128
    EPOCHS = 1000
    LR = 1e-3

    # We define the vocabulary size:
    #   - For each variable index i in 1..n_vars, we have a token i for the positive literal
    #   - For each negative literal, we have n_vars + i
    #   - We add 1 special <sep> token, 1 <end> token, and 1 <start> token
    # So in total: 2*n_vars + 2 or 2*n_vars + a handful of special tokens
    # We'll define them as:
    vocab_size = 2*N_VARS + 3  # 3 for [start], <sep>, <end> (some margin)

    # Create dataset
    train_ds = ThreeSatDataset(num_samples=TRAIN_SAMPLES, n_vars=N_VARS, ratio=RATIO)
    val_ds = ThreeSatDataset(num_samples=VAL_SAMPLES, n_vars=N_VARS, ratio=RATIO)

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    print("Using device:", device)

    # Create model
    model = TransformerModel(
        vocab_size=vocab_size,
        d_model=512,
        n_heads=8,
        d_ff=2048,
        num_layers=8,
        n_vars=N_VARS
    ).to(device)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=LR)

    train_losses = []
    val_losses = []

    # Training loop
    for epoch in range(EPOCHS):
        model.train()
        total_loss = 0.0
        
        for batch_inp, batch_tgt in train_loader:
            batch_inp = batch_inp.to(device)
            batch_tgt = batch_tgt.to(device)  # shape: (batch, n_vars)

            optimizer.zero_grad()
            logits = model(batch_inp)        # (batch, n_vars, 2)
            # Reshape for cross-entropy:
            # We have (batch, n_vars, 2) and want (batch*n_vars, 2) vs (batch*n_vars)
            logits_flat = logits.view(-1, 2)
            targets_flat = batch_tgt.view(-1)
            loss = criterion(logits_flat, targets_flat)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(train_loader)
        train_losses.append(avg_loss)

        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch_inp, batch_tgt in val_loader:
                batch_inp = batch_inp.to(device)
                batch_tgt = batch_tgt.to(device)

                logits = model(batch_inp)
                logits_flat = logits.view(-1, 2)
                targets_flat = batch_tgt.view(-1)
                loss = criterion(logits_flat, targets_flat)
                val_loss += loss.item()
        val_loss /= len(val_loader)
        val_losses.append(val_loss)

        print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {avg_loss:.4f} | Val Loss: {val_loss:.4f}")

        if (epoch + 1) % 10 == 0:
            model.eval()
            test_clauses = generate_random_3sat_instance(N_VARS, RATIO)
            test_solution = solve_3sat(test_clauses, N_VARS)
            encoded = train_ds._encode_instance(test_clauses, N_VARS).unsqueeze(0).to(device)
            with torch.no_grad():
                pred_logits = model(encoded)
                pred = pred_logits.argmax(dim=-1).squeeze(0).cpu().tolist()
            
            print("\nSample test case:")
            print("Clauses:", test_clauses)
            print("SAT solver solution:", [1 if x else 0 for x in test_solution])
            print("Model prediction:", pred)
            print("Match:", pred == [1 if x else 0 for x in test_solution], "\n")

if __name__ == "__main__":
    main()

Using device: mps
Epoch 1/1000 | Train Loss: 1.4204 | Val Loss: 0.7608
Percentage of all-zero predictions: 0.00%
Epoch 2/1000 | Train Loss: 0.6980 | Val Loss: 0.6524
Percentage of all-zero predictions: 0.00%
Epoch 3/1000 | Train Loss: 0.6404 | Val Loss: 0.6207
Percentage of all-zero predictions: 0.00%
Epoch 4/1000 | Train Loss: 0.6349 | Val Loss: 0.6192
Percentage of all-zero predictions: 100.00%
Epoch 5/1000 | Train Loss: 0.6338 | Val Loss: 0.6199
Percentage of all-zero predictions: 0.00%
Epoch 6/1000 | Train Loss: 0.6319 | Val Loss: 0.6191
Percentage of all-zero predictions: 100.00%
Epoch 7/1000 | Train Loss: 0.6379 | Val Loss: 0.6202
Percentage of all-zero predictions: 0.00%
Epoch 8/1000 | Train Loss: 0.6326 | Val Loss: 0.6182
Percentage of all-zero predictions: 0.00%
Epoch 9/1000 | Train Loss: 0.6283 | Val Loss: 0.6209
Percentage of all-zero predictions: 0.00%
Epoch 10/1000 | Train Loss: 0.6287 | Val Loss: 0.6144
Percentage of all-zero predictions: 100.00%

Sample test case:
Clause