# Entropic-Guided Neuro-Symbolic Diffusion

In [None]:
!pip install torch transformers tree-sitter tree-sitter-python

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tree_sitter import Language, Parser
import tree_sitter_python

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')

In [None]:
class RotaryPositionalEmbeddings(nn.Module):
    def __init__(self, d_model, max_seq_len=8192):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, d_model, 2).float() / d_model))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, x, seq_len=None):
        if seq_len is None:
            seq_len = x.shape[1]
        t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb[None, :, None, :]

def apply_rotary_pos_emb(x, freqs):
    cos = freqs.cos()
    sin = freqs.sin()
    x_even = x[..., 0::2]
    x_odd = x[..., 1::2]
    cos_half = cos[..., :cos.shape[-1]//2]
    sin_half = sin[..., :sin.shape[-1]//2]
    x_even_rot = x_even * cos_half - x_odd * sin_half
    x_odd_rot = x_odd * cos_half + x_even * sin_half
    x_out = torch.stack([x_even_rot, x_odd_rot], dim=-1)
    return x_out.flatten(-2)

class BidirectionalSelfAttention(nn.Module):
    """Multi-head self-attention without a causal mask, with rotary encoding on q/k.

    Args:
        d_model: model width; must be divisible by ``n_head``.
        n_head: number of attention heads. ``d_model // n_head`` must be even,
            because the rotary encoding rotates feature pairs.
        max_len: forwarded to ``RotaryPositionalEmbeddings``.

    Raises:
        ValueError: if the head configuration cannot be split as described.

    ``forward(x)`` takes x of shape (B, T, d_model) and returns the same shape.
    """

    def __init__(self, d_model, n_head, max_len):
        super().__init__()
        if d_model % n_head != 0:
            raise ValueError(f'd_model ({d_model}) must be divisible by n_head ({n_head})')
        if (d_model // n_head) % 2 != 0:
            raise ValueError(f'head dim ({d_model // n_head}) must be even for rotary embeddings')
        self.d_model = d_model
        self.n_head = n_head
        self.d_head = d_model // n_head
        self.c_attn = nn.Linear(d_model, 3 * d_model)
        self.c_proj = nn.Linear(d_model, d_model)
        self.rope = RotaryPositionalEmbeddings(self.d_head, max_len)

    def forward(self, x):
        B, T, C = x.size()
        qkv = self.c_attn(x).split(self.d_model, dim=2)
        # Rotate while still in (B, T, n_head, d_head) layout: the rotary table is
        # (1, T, 1, d_head), so its position axis must line up with dim 1. Applying
        # it after the head transpose would broadcast positions against heads.
        q, k, v = (z.view(B, T, self.n_head, self.d_head) for z in qkv)
        freqs = self.rope(q, T)
        q = apply_rotary_pos_emb(q, freqs)
        k = apply_rotary_pos_emb(k, freqs)
        # scaled_dot_product_attention expects (B, n_head, T, d_head).
        q, k, v = (z.transpose(1, 2) for z in (q, k, v))
        y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)

class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer layer.

    Applies ``x + attn(ln1(x))`` and then ``h + mlp(ln2(h))``. The MLP expands to
    ``4 * d_model`` with a GELU in between. Output shape equals input shape.
    """

    def __init__(self, d_model, n_head, max_len):
        super().__init__()
        hidden = 4 * d_model
        # Submodules are created in this order so parameter initialization and
        # state_dict keys stay stable.
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = BidirectionalSelfAttention(d_model, n_head, max_len)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(nn.Linear(d_model, hidden), nn.GELU(), nn.Linear(hidden, d_model))

    def forward(self, x):
        h = x + self.attn(self.ln1(x))
        return h + self.mlp(self.ln2(h))

class DiffusionTransformer(nn.Module):
    """Bidirectional transformer that maps token ids to per-position vocabulary logits.

    Args:
        vocab_size: number of token ids; the last id is reserved as the mask token.
        d_model, n_head, n_layer, max_len: width, heads, depth, and the length
            hint passed to each ``TransformerBlock``.

    ``forward(x)`` takes integer ids, presumably shaped (batch, seq), since the
    attention layers unpack a 3-D hidden state. It returns logits with a trailing
    ``vocab_size`` axis.
    """

    def __init__(self, vocab_size, d_model=256, n_head=4, n_layer=4, max_len=512):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        layers = [TransformerBlock(d_model, n_head, max_len) for _ in range(n_layer)]
        self.blocks = nn.Sequential(*layers)
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)
        # Highest id doubles as the [MASK] placeholder used by the sampler.
        self.mask_token_id = vocab_size - 1

    def forward(self, x):
        hidden = self.blocks(self.token_embedding(x))
        return self.head(self.ln_f(hidden))

In [None]:
class SymbolicConstraint:
    """Python syntax checker that maps tree-sitter parse errors onto token positions."""

    def __init__(self):
        self.language = Language(tree_sitter_python.language())
        self.parser = Parser(self.language)

    def verify(self, code_bytes):
        """Parse ``code_bytes`` and collect syntax-error byte ranges.

        Returns:
            ``(is_valid, errors)``: ``errors`` lists the ``(start_byte, end_byte)``
            range of every ERROR node and every MISSING node in the tree.
            ``is_valid`` is True exactly when that list is empty.
        """
        tree = self.parser.parse(code_bytes)
        errors = []
        stack = [tree.root_node]
        while stack:
            node = stack.pop()
            if node.type == 'ERROR' or node.is_missing:
                errors.append(node.byte_range)
            stack.extend(node.children)
        return len(errors) == 0, errors

    @staticmethod
    def _token_spans(tokens, tokenizer_decode_fn):
        """Return the ``[start, end)`` UTF-8 byte span of each token, laid end to end."""
        spans = []
        offset = 0
        for tok in tokens:
            width = len(tokenizer_decode_fn([tok]).encode('utf-8'))
            spans.append((offset, offset + width))
            offset += width
        return spans

    @staticmethod
    def _overlapping_tokens(errors, token_spans):
        """Return sorted indices of tokens whose byte span overlaps any error range."""
        hits = set()
        for start_byte, end_byte in errors:
            if start_byte == end_byte:
                # MISSING nodes (and some ERROR nodes) are zero-width. An empty
                # range can never overlap a half-open token span. Widen it by one
                # byte on each side so the tokens touching the insertion point
                # get re-masked.
                start_byte, end_byte = max(start_byte - 1, 0), end_byte + 1
            for idx, (t_start, t_end) in enumerate(token_spans):
                if t_start < end_byte and t_end > start_byte:
                    hits.add(idx)
        return sorted(hits)

    def get_mask_from_errors(self, code_bytes, tokens, tokenizer_decode_fn):
        """Return indices of ``tokens`` touched by syntax errors, or None if valid.

        Args:
            code_bytes: UTF-8 source to check. Assumed to equal the concatenation
                of ``tokenizer_decode_fn([t])`` over ``tokens``, so byte offsets
                line up with token spans.
            tokens: token ids making up ``code_bytes``.
            tokenizer_decode_fn: maps a list of token ids to a string.

        Returns:
            None when the code parses cleanly. Otherwise a sorted list of token
            indices (possibly empty).
        """
        is_valid, errors = self.verify(code_bytes)
        if is_valid:
            return None
        spans = self._token_spans(tokens, tokenizer_decode_fn)
        return self._overlapping_tokens(errors, spans)

class EntropicSampler:
    """Masked-diffusion sampler guided by predictive entropy and a syntax constraint.

    At each step the model re-predicts every position by argmax, and the prompt is
    pinned. The prediction is checked by ``constraint``. A shrinking number of
    generated positions is then re-masked: highest predictive entropy first, with
    positions the constraint flags as erroneous forced to the front.

    Args:
        model: module with a ``mask_token_id`` attribute that maps (1, L) token
            ids to (1, L, vocab) logits.
        tokenizer: object with ``decode(list_of_ids) -> str``.
        constraint: object with ``get_mask_from_errors(code_bytes, tokens,
            decode_fn)`` returning token indices to force-remask, or None.
    """

    def __init__(self, model, tokenizer, constraint):
        self.model = model
        self.tokenizer = tokenizer
        self.constraint = constraint
        self.mask_id = model.mask_token_id

    def generate(self, prompt_ids, T_steps=20, output_len=20):
        """Fill ``output_len`` masked slots after ``prompt_ids`` over ``T_steps`` steps.

        Returns:
            A (1, len(prompt_ids) + output_len) long tensor of token ids on the
            model's device. The final step re-masks nothing, so no mask tokens
            remain when ``T_steps >= 1``.
        """
        # Follow the model's placement rather than a notebook-level global.
        param = next(self.model.parameters(), None)
        run_device = param.device if param is not None else torch.device('cpu')
        prompt_len = len(prompt_ids)
        prompt = torch.tensor(prompt_ids, dtype=torch.long, device=run_device)
        curr_seq = torch.tensor(
            prompt_ids + [self.mask_id] * output_len, dtype=torch.long, device=run_device
        ).unsqueeze(0)
        print('Starting generation...')
        for t in range(T_steps, 0, -1):
            with torch.no_grad():
                logits = self.model(curr_seq)
            # [MASK] is an input placeholder, never a valid prediction. Without
            # this, argmax could "commit" mask tokens into the output.
            logits = logits.clone()
            logits[..., self.mask_id] = float('-inf')
            x_pred = torch.argmax(logits, dim=-1)
            x_pred[:, :prompt_len] = prompt
            pred_tokens = x_pred[0].tolist()
            pred_str = self.tokenizer.decode(pred_tokens)
            error_indices = self.constraint.get_mask_from_errors(
                pred_str.encode('utf-8'), pred_tokens, self.tokenizer.decode
            )
            probs = F.softmax(logits, dim=-1)
            entropy = -(probs * (probs + 1e-10).log()).sum(dim=-1)
            # Syntax-error positions outside the prompt get top priority for re-masking.
            for idx in error_indices or ():
                if idx >= prompt_len:
                    entropy[0, idx] = float('inf')
            # Positions that stay masked *after* this step. The count decreases
            # linearly and reaches 0 on the last step (t == 1), so the final
            # sequence is fully decoded.
            num_to_mask = output_len * (t - 1) // T_steps
            if num_to_mask > 0:
                _, top_k = torch.topk(entropy[0, prompt_len:], k=num_to_mask)
                x_t_next = x_pred.clone()
                x_t_next[0, top_k + prompt_len] = self.mask_id
                curr_seq = x_t_next
            else:
                curr_seq = x_pred
            if t % 5 == 0:
                print(f'Step {t}: {pred_str[:40]}...')
        return curr_seq

In [None]:
class CharTokenizer:
    """Character-level tokenizer over a fixed alphabet plus a trailing ``<MASK>`` token.

    Ids are assigned in sorted character order, and ``<MASK>`` takes the last id.
    Note that ``encode`` maps characters outside the alphabet to id 0, which is
    ``'\\n'`` because it sorts first. ``decode`` renders unknown ids as ``'?'`` and
    the mask id as the multi-character string ``'<MASK>'``.
    """

    def __init__(self):
        alphabet = sorted(set('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_ ()[]{},.:;=\n+-*/<>'))
        self.vocab = dict(zip(alphabet, range(len(alphabet))))
        self.vocab['<MASK>'] = len(alphabet)
        self.id_to_char = {idx: sym for sym, idx in self.vocab.items()}
        self.vocab_size = len(self.vocab)

    def encode(self, s):
        lookup = self.vocab.get
        return [lookup(ch, 0) for ch in s]

    def decode(self, ids):
        return ''.join(self.id_to_char.get(i, '?') for i in ids)

# Build the tokenizer, a small diffusion model sized to its vocabulary, the
# tree-sitter syntax constraint, and the sampler that ties them together.
tokenizer = CharTokenizer()
model = DiffusionTransformer(
    vocab_size=tokenizer.vocab_size,
    d_model=128,
    n_head=4,
    n_layer=4,
).to(device)
constraint = SymbolicConstraint()
sampler = EntropicSampler(model, tokenizer, constraint)
print(f'Model: {sum(param.numel() for param in model.parameters())/1e6:.2f}M params')

In [None]:
# Ask the sampler to complete a one-line function body: 20 masked slots, 20 steps.
prompt = 'def add(a, b):\n    return '
prompt_ids = tokenizer.encode(prompt)
output = sampler.generate(prompt_ids, output_len=20, T_steps=20)
print(f'\nOutput:\n{tokenizer.decode(output.squeeze(0).tolist())}')