Author: Steve Bischoff

This notebook contains code defining the first-order language, randomly constructing an interpretation, building a training dataset of true well-formed formulae, training a decoder-only transformer on next-word prediction, and tracking ambiguity statistics throughout training. The results of only one round of the experiment are printed here, though the code can easily be modified to run more rounds.

The code for the model itself is adapted with minor changes from: https://github.com/karpathy/ng-video-lecture/blob/master/gpt.py

## Contents

 1. [Language Preliminaries](#language)
 2. [WFF Generation Functions](#generation)
 3. [WFF Evaluation Functions](#evaluation) 
 4. [Functions to Gather Statistics](#stats)
 5. [Model](#model)
 6. [Experiment](#experiment)

In [1]:
## Imports
# built-ins
import copy
import pickle
import random
import string
# packages
import numpy as np
# PyTorch
import torch
import torch.nn as nn
from torch.nn import functional as F

## 1. Language Preliminaries <a name="language"></a>

As discussed in the paper, the language is a first-order language containing connectives, one-place predicates, and constants. It (rather arbitrarily) contains 9 predicates {F, ..., N}. 

We model an ambiguous technical word using the predicate F, which is ambiguous between a "vernacular" meaning and a "technical" meaning. The vernacular meaning is generally broader, so we hard-code F to be true of constants {a, ..., g} on the vernacular meaning but false of them on the technical meaning. Since technical meanings often do extend beyond vernacular meanings, we make F true of constant z on the technical meaning but not on the vernacular. For every other constant, F is randomly assigned a single truth value that is shared by both meanings.
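
As a minimal sketch of the hard-coded part of this interpretation (the names `vernacular_F` and `technical_F` are illustrative and do not appear in the notebook), the two readings can be written out as dictionaries and compared:

```python
import string

# Illustrative only: constants whose truth value for F is hard-coded.
hard_true_vernacular = list("abcdefg")   # F is true of these on the vernacular reading only
hard_true_technical = ["z"]              # F is true of this on the technical reading only

vernacular_F = {c: 1 for c in hard_true_vernacular}
vernacular_F.update({c: 0 for c in hard_true_technical})
# On the hard-coded constants the technical reading is the mirror image.
technical_F = {c: 1 - v for c, v in vernacular_F.items()}

# Constants on which the two readings disagree.
disagree = [c for c in string.ascii_lowercase
            if c in vernacular_F and vernacular_F[c] != technical_F[c]]
print(disagree)
```

Every constant outside this list receives one random truth value that both readings share, so the ambiguity only matters for these eight constants.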

In [2]:
# Define language base
connectives = ['~', '&', '>', 'V'] # negation, conjunction, conditional, disjunction
names = [i for i in string.ascii_lowercase]
unambiguous_predicates = ['G', 'H', 'I', 'J', 'K', 'L', 'M', 'N']
ambiguous_predicates = ['F']
hidden_predicates = ['X']
# F is actually ambiguous between "vernacular" F and "technical" X.
# Manually set truth values of some ambiguous monads
true_ambiguous_monads = ['Fa', 'Fb', 'Fc', 'Fd', 'Fe', 'Ff', 'Fg']
false_ambiguous_monads = ['Fz']
ambiguous_monads = true_ambiguous_monads + false_ambiguous_monads

In [3]:
pred_mult = 6 # odds of the dominant (vernacular) reading of an ambiguous predicate are (pred_mult-1):1
# The weights are chosen so that F appears roughly as often as each of the other predicates
predicates = unambiguous_predicates + ambiguous_predicates + hidden_predicates
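
A rough worked example of these odds, assuming the sampling weights that generate_monads() assigns further down (the variable names here are illustrative):

```python
pred_mult = 6

# Weights as assigned in generate_monads(): an ordinary monad such as 'Ga'
# gets pred_mult; a hard-coded ambiguous monad such as 'Fa' gets
# pred_mult - 1 on its vernacular reading; its hidden twin 'Xa' gets 1.
weight_ordinary = pred_mult
weight_vernacular = pred_mult - 1
weight_technical = 1

# 'Xa' is rewritten to 'Fa' in the training data, so the surface form 'Fa'
# is drawn with the combined weight of both readings.
weight_surface_F = weight_vernacular + weight_technical
vernacular_share = weight_vernacular / weight_surface_F

print(weight_surface_F == weight_ordinary)
print(vernacular_share)
```

Under these assumptions a hard-coded monad like 'Fa' is sampled as often as any ordinary monad, and five out of every six of its occurrences carry the vernacular reading.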

'!' is the start character and ' ' is the stop character (it is also used to pad shorter WFFs to a common length).

In [4]:
chars = ['!'] + [i for i in predicates if i not in hidden_predicates] + names + connectives + [' ']
vocab_size = len(chars)
print(chars, vocab_size)

stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda x: [[stoi[c] for c in s] for s in x] # encoder: take a list of strings, output a list of lists of integers
decode = lambda x: ''.join([itos[i] for i in x]) # decoder: take a list of integers, output a string

['!', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'F', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '~', '&', '>', 'V', ' '] 41
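
A small round-trip sketch of the encoding, rebuilding the vocabulary in the order printed above. The example WFF and the padding length of 8 are hypothetical choices for illustration:

```python
import string

# Rebuild the vocabulary in the same order as the notebook's `chars`.
chars = ['!'] + list('GHIJKLMNF') + list(string.ascii_lowercase) + list('~&>V') + [' ']
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}

wff = '&Fa~Gb'
# Start character, the WFF, then stop characters up to a hypothetical max length of 8.
padded = '!' + wff + ' ' * (8 - len(wff))
ids = [stoi[c] for c in padded]
print(ids)

# Decoding and stripping the start character and padding recovers the WFF.
restored = ''.join(itos[i] for i in ids)[1:].strip()
print(restored)
```

Because '!' sits at index 0, a context of all zeros (as used when sampling from the model) is a batch of start characters.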


## 2. WFF Generation Functions<a name="generation"></a>

generate_monads() constructs an interpretation for the language by randomly assigning each predicate-constant pair a truth value (1/3 probability of true), except for the hard-coded truth values mentioned above.

In [5]:
def generate_monads():
    all_monads = []
    true_monads = []
    false_monads = []
    selection_odds = []
    for pred in unambiguous_predicates:  
        for name in names:  
            pos = pred+name
            if random.choice([0, 0, 1]) == 1:
                true_monads += [pos]
            else:
                false_monads += [pos]
            all_monads += [pos]
            selection_odds += [pred_mult]
    for pred in ambiguous_predicates:    
        for name in names:  
            pos = pred+name
            if pos in true_ambiguous_monads: # hard-coded "vernacular" meanings of ambiguous predicates
                true_monads += [pos]
                all_monads += [pos]
                selection_odds += [pred_mult - 1]
            elif pos in false_ambiguous_monads: # hard-coded "vernacular" meanings of ambiguous predicates
                false_monads += [pos]
                all_monads += [pos]
                selection_odds += [pred_mult - 1]
            elif random.choice([0, 0, 1]) == 1:
                true_monads += [pos]
                all_monads += [pos]
                selection_odds += [pred_mult]
            else:
                false_monads += [pos]
                all_monads += [pos]
                selection_odds += [pred_mult]
    for i in range(len(hidden_predicates)): # hard-coded "technical" meanings of ambiguous predicates
        hidden = hidden_predicates[i]
        ambig = ambiguous_predicates[i]
        for name in names:
            pos_ambig = ambig+name
            pos_hidden = hidden+name
            if pos_ambig in true_ambiguous_monads:
                false_monads += [pos_hidden]
                all_monads += [pos_hidden]
                selection_odds += [1]
            elif pos_ambig in false_ambiguous_monads:
                true_monads += [pos_hidden]
                all_monads += [pos_hidden]
                selection_odds += [1]
                
    return all_monads, true_monads, false_monads, selection_odds

generate_prefix() randomly generates WFFs in prefix notation with anywhere from 0 to *max_depth* connectives. The generated WFFs may be either true or false; generate_true_wffs() then keeps only the true ones.

In [6]:
def generate_prefix(max_depth, monads):
    depth = 0
    monad_idx = 0
    count = 1
    exp = ''
    while depth < max_depth:
        prob = 0.04 * (depth + 2)
        if np.random.uniform() < prob:
            exp += monads[monad_idx]
            monad_idx += 1
            count -= 1
            break
        if count == 1: # connective forced
            conn = random.choice(connectives)
            exp += conn
            if conn != '~':
                count += 1
            depth += 1
        else: # randomly add connective or monad
            if random.choice([0, 1]) == 0: # connective
                conn = random.choice(connectives)
                exp += conn
                if conn != '~':
                    count += 1
                depth += 1
            else: # monad
                exp += monads[monad_idx]
                monad_idx += 1
                count -= 1
    while count > 0:
        exp += monads[monad_idx]
        monad_idx += 1
        count -= 1
    return exp
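
The generated strings are in prefix notation, which can be hard to read. As a reading aid, here is a small sketch of a converter to infix form. The helper `prefix_to_infix` is hypothetical, is not part of the notebook, and assumes its input is already well formed:

```python
def prefix_to_infix(expression):
    """Render a prefix-notation WFF in infix form for readability."""
    binary = {'&', '>', 'V'}
    stack = []
    i = len(expression) - 1
    # Scan right to left so operands are on the stack before their connective.
    while i >= 0:
        symbol = expression[i]
        if symbol.islower():                 # a name: pair it with its predicate
            stack.append(expression[i - 1] + symbol)
            i -= 2
        elif symbol == '~':
            stack.append('~' + stack.pop())
            i -= 1
        elif symbol in binary:
            left = stack.pop()
            right = stack.pop()
            stack.append('(' + left + ' ' + symbol + ' ' + right + ')')
            i -= 1
        else:
            raise ValueError('unexpected symbol: ' + symbol)
    return stack[0]

print(prefix_to_infix('Fa'))         # Fa
print(prefix_to_infix('&Fa~Gb'))     # (Fa & ~Gb)
print(prefix_to_infix('>VFaGb~Hc'))  # ((Fa V Gb) > ~Hc)
```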

In [7]:
def generate_true_wffs(n, max_depth):
    true_wffs = []
    n_monads = max_depth + 1 # max per wff
    monad_array_size = n*n_monads
    monad_array = np.random.choice(all_monads, size=monad_array_size, p=selection_probs)
    monad_idx = 0
    while len(true_wffs) < n:
        monad_array_slice = monad_array[monad_idx:monad_idx+n_monads]
        wff = generate_prefix(max_depth, monad_array_slice)
        tv = evaluate_prefix(wff)
        
        if tv == 1:
            wff = wff.replace('X', 'F') # ambiguity
            true_wffs.append(wff)
            
        monad_idx += n_monads
        if monad_idx >= monad_array_size:
            monad_array = np.random.choice(all_monads, size=monad_array_size, p=selection_probs)
            monad_idx = 0

    return true_wffs

## 3. WFF Evaluation Functions<a name="evaluation"></a>

In [8]:
def check_wff(expression):
    """Evaluate whether a string is a WFF"""
    count = 1
    empty_predicate = False
    expression_len = len(expression)
    for i in range(expression_len):
        symbol = expression[i]
        if symbol in predicates:
            if empty_predicate:
                return False
            temp_predicate = symbol
            empty_predicate = True
        elif symbol in names:
            if not empty_predicate:
                return False
            else:
                empty_predicate = False
                temp_exp = temp_predicate + symbol
                count -= 1
        elif empty_predicate:
            return False
        elif symbol in [' ', '!']:
            return False
        elif symbol == '~':
            pass
        else: # symbol in binary connectives
            count += 1
        
        if count == 0:
            if i+1 < expression_len:
                return False
        
    if count == 0:
        return True
    else:
        return False

In [9]:
def evaluate_prefix(expression): # assumes wff
    """Evaluate the truth value of an unambiguous wff"""
    stack = []
    for i in reversed(expression):
        if i in names:
            temp_name = i
        elif i in predicates:
            stack.append(tv_dict[i + temp_name])
        elif i in connectives:
            if i == '~': # one-place connective
                tv = stack.pop()
                stack.append(1-tv)
            else:
                tv1 = stack.pop()
                tv2 = stack.pop()
                if i == '&': # two-place connective
                    stack.append(tv1*tv2)
                elif i == 'V':
                    stack.append(max(tv1, tv2))
                elif i == '>':
                    if tv1 == 0 or tv2 == 1:
                        stack.append(1)
                    else:
                        stack.append(0)  
    assert len(stack) == 1, 'Stack too big'
    return stack[0]


def evaluate_ambiguous_prefix(expression): # assumes wff
    """Evaluate the truth value of an ambiguous wff"""
    # check whether we need to evaluate an ambiguous expression
    is_ambiguous = False
    for monad in ambiguous_monads:
        if monad in expression:
            is_ambiguous = True
            break
            
    if is_ambiguous:
        stacks = [[]]
        for i in reversed(expression):
            if i in names:
                temp_name = i
            elif i in predicates:
                pos = i + temp_name
                if pos in ambiguous_monads:
                    stacks_temp = []
                    for stack in stacks:
                        stacks_temp += [stack + [1], stack + [0]]
                    stacks = stacks_temp
                else:
                    for stack in stacks:
                        stack.append(tv_dict[pos])
            elif i in connectives:
                for stack in stacks:
                    if i == '~': # one-place connective
                        tv = stack.pop()
                        stack.append(1-tv)
                    else: # two-place connective
                        tv1 = stack.pop()
                        tv2 = stack.pop()
                        if i == '&':
                            stack.append(tv1*tv2)
                        elif i == 'V':
                            stack.append(max(tv1, tv2))
                        elif i == '>':
                            if tv1 == 0 or tv2 == 1:
                                stack.append(1)
                            else:
                                stack.append(0) 
    else: # unambiguous
        stacks = [[evaluate_prefix(expression)]]
    return stacks
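
As a cross-check on the branching logic, here is a self-contained sketch that enumerates the same readings with `itertools.product`. It treats each occurrence of an ambiguous monad independently, as the branching stacks above do, but returns a flat list rather than a list of stacks. The helper names and the small truth table are hypothetical:

```python
from itertools import product

def evaluate(expression, lookup):
    """Evaluate a prefix WFF; lookup(monad, position) supplies monad truth values."""
    stack = []
    name = None
    for pos in range(len(expression) - 1, -1, -1):   # scan right to left
        symbol = expression[pos]
        if symbol.islower():
            name = symbol
        elif symbol == '~':
            stack.append(1 - stack.pop())
        elif symbol in '&>V':
            left, right = stack.pop(), stack.pop()
            if symbol == '&':
                stack.append(left * right)
            elif symbol == 'V':
                stack.append(max(left, right))
            else:  # '>' is false only for a true antecedent and a false consequent
                stack.append(1 if left == 0 or right == 1 else 0)
        else:      # predicate letter: look up the completed monad
            stack.append(lookup(symbol + name, pos))
    return stack[0]

def all_readings(expression, tv, ambiguous):
    """Truth values under every 1/0 choice for each ambiguous monad occurrence."""
    # Occurrence positions, ordered right to left to mirror the scan order.
    slots = [p for p in range(len(expression) - 2, -1, -1)
             if expression[p:p + 2] in ambiguous]
    results = []
    for values in product([1, 0], repeat=len(slots)):
        chosen = dict(zip(slots, values))
        results.append(evaluate(
            expression,
            lambda monad, pos: chosen[pos] if pos in chosen else tv[monad]))
    return results

tv = {'Ga': 1, 'Gb': 0}          # hypothetical unambiguous truth values
ambiguous = ['Fa', 'Fz']
print(all_readings('&FaGa', tv, ambiguous))    # [1, 0]
print(all_readings('VFa~Fa', tv, ambiguous))   # [1, 0, 1, 1]
print(all_readings('&Ga~Gb', tv, ambiguous))   # [1]
```

The first entry corresponds to every ambiguous occurrence being set to 1 and the last to every occurrence being set to 0, which is the layout the statistics functions rely on.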

## 4. Functions to Gather Statistics<a name="stats"></a>

In [10]:
def update_ambiguity_stats(tv, n_true_wffs, n_ambiguous_wffs):
    """Update the counts in n_ambiguous_wffs in place for one WFF and return the updated n_true_wffs."""
    if len(tv) == 1: # unambiguous
        n_true_wffs += tv[0][0]
    else:
        tvs = [i[0] for i in tv]
        if tvs[0] == 1:                        
            n_true_wffs += 1
            if tvs[-1] == 1: # true on both meanings
                n_ambiguous_wffs[0] += 1
            else:
                n_ambiguous_wffs[2] += 1 # true only on 1st meaning                     
        elif tvs[-1] == 1: # in this case, tvs[0] == 0
            n_true_wffs += 1
            n_ambiguous_wffs[3] += 1
        elif sum(tvs) == 0: # false on all meaning combinations
            n_ambiguous_wffs[1] += 1
        else: # true only on mixed meanings
            n_true_wffs += 1
            n_ambiguous_wffs[4] += 1 
            
    return n_true_wffs
            
def true_wff_stats(true_wffs, iters): 
    """Compute ambiguity statistics over the first `iters` WFFs of the training sample."""
    n_true_wffs = 0 # number of WFFs that are true on at least one reading
    # true on both meanings, false on all meanings, true only on 1st meaning, true only on 2nd meaning, true only on a mix
    n_ambiguous_wffs = [0, 0, 0, 0, 0]
    for i in range(iters):
        gen = true_wffs[i]
        tv = evaluate_ambiguous_prefix(gen) # list of one-element stacks, one per reading
        n_true_wffs = update_ambiguity_stats(tv, n_true_wffs, n_ambiguous_wffs)
         
    return n_true_wffs/iters, [i/iters for i in n_ambiguous_wffs]
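
The five categories can be summarised with a small stand-alone sketch. It follows the same index layout as `n_ambiguous_wffs` and applies only to ambiguous WFFs, that is, lists with more than one truth value. The names `LABELS` and `classify` are illustrative:

```python
LABELS = ['true on both', 'false on all', 'true only on 1st',
          'true only on 2nd', 'true only on a mix']

def classify(tvs):
    """Return the index into LABELS for the truth values of one ambiguous WFF.

    `tvs` holds the WFF's truth value under each combination of readings,
    with the two uniform readings in the first and last positions.
    """
    first, last = tvs[0], tvs[-1]
    if first == 1 and last == 1:
        return 0
    if first == 1:
        return 2
    if last == 1:
        return 3
    if sum(tvs) == 0:
        return 1
    return 4

print(LABELS[classify([1, 0])])        # true only on 1st
print(LABELS[classify([0, 1, 1, 0])])  # true only on a mix
```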

In [11]:
def model_generation_stats(m, iters, batch_size=100):
    n_wffs, n_true_wffs = 0, 0 
    # true on both meanings, false on all meanings, true only on 1st meaning, true only on 2nd meaning, true only on a mix
    n_ambiguous_wffs = [0, 0, 0, 0, 0]
    
    context = torch.zeros((batch_size, 1), dtype=torch.long, device=device)
    m.eval() # set evaluation mode
    for _ in range(iters//batch_size):
        for gen in m.generate(context, max_new_tokens=block_size):
            gen = decode(gen.tolist()[1:]).strip()
            is_wff = check_wff(gen)
            if is_wff:
                n_wffs += 1
                try:
                    tv = evaluate_ambiguous_prefix(gen) # list of one-element stacks, one per reading
                except Exception: # sometimes generates ' '
                    continue               
                n_true_wffs = update_ambiguity_stats(tv, n_true_wffs, n_ambiguous_wffs)     
            else:
                pass
            
    return n_wffs/iters, n_true_wffs/iters, [i/iters for i in n_ambiguous_wffs]

## 5. Model<a name="model"></a>

In [12]:
class Head(nn.Module):
    """ one head of self-attention """
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # input of size (batch, time-step, channels)
        # output of size (batch, time-step, head size)
        B,T,C = x.shape
        k = self.key(x)   # (B,T,hs)
        q = self.query(x) # (B,T,hs)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,hs)
        out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out

class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """ Transformer block: communication followed by computation """
    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

class GPTLanguageModel(nn.Module):

    def __init__(self):
        super().__init__()
        # token and position embedding tables
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) # final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)

        # weight initialization carried over from the original gpt.py
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.blocks(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx
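
To illustrate what the lower-triangular mask in `Head` does, here is a dependency-free sketch of a causal softmax over a small score matrix. It is a plain-Python illustration with made-up scores, not the PyTorch code path used by the model:

```python
import math

def causal_softmax(scores):
    """Row-wise softmax where position t may only attend to positions <= t."""
    weights = []
    for t, row in enumerate(scores):
        visible = row[:t + 1]                        # drop future positions
        peak = max(visible)                          # subtract the max for stability
        exps = [math.exp(s - peak) for s in visible]
        total = sum(exps)
        # Future positions get weight 0, as exp(-inf) would give after masking.
        weights.append([e / total for e in exps] + [0.0] * (len(row) - t - 1))
    return weights

scores = [[0.0, 5.0, 5.0],
          [1.0, 1.0, 5.0],
          [0.0, 0.0, 0.0]]
weights = causal_softmax(scores)
for row in weights:
    print([round(w, 3) for w in row])
```

The large scores above the diagonal have no effect on the result, which is why each generated symbol can depend only on the symbols before it.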

In [13]:
# data loading
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data), (batch_size,))
    x = torch.stack([data[i][:-1] for i in ix])
    y = torch.stack([data[i][1:] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y
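
The input/target shift in get_batch() can be seen on a single padded row. This sketch uses a hypothetical training string rather than real data:

```python
# One padded training string: start character, WFF, stop characters.
row = list('!&Fa~Gb  ')
x = row[:-1]   # model input
y = row[1:]    # target: the same sequence shifted left by one position

for inp, tgt in zip(x, y):
    print(repr(inp), '->', repr(tgt))
```

Each row has length max_length + 1, so both x and y have length max_length, which is the value assigned to block_size in the experiment cell.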

In [14]:
@torch.no_grad()
def estimate_loss():
    out = {}
    m.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = m(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    m.train()
    return out

## 6. Experiment<a name="experiment"></a>

In [15]:
# hyperparameters
batch_size = 64 # how many independent sequences will we process in parallel?
max_iters = 250000
eval_interval = max_iters//25
learning_rate = 3e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 288 # 384
n_head = 6
n_layer = 2
dropout = 0.2

stats_iters = 100000 # number of sentences to generate when gathering model stats
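
A few quantities follow directly from these settings. The arithmetic below is only a sketch; the names `head_size`, `stats_batch` and `stats_batches` are used here for illustration:

```python
max_iters = 250000
eval_interval = max_iters // 25
n_embd, n_head = 288, 6
stats_iters = 100000

head_size = n_embd // n_head                 # per-head width computed in Block
stats_batch = stats_iters // 50              # batch_size passed to model_generation_stats
stats_batches = stats_iters // stats_batch   # number of generation batches per checkpoint

print(eval_interval, head_size, stats_batch, stats_batches)   # 10000 48 2000 50
```

The interval of 10000 matches the spacing of the "Iter" lines in the printed output, and n_embd divides evenly by n_head, so the concatenated heads recover the full embedding width.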

In [16]:
for seed in range(19, 20):
    print('Seed', seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    all_monads, true_monads, false_monads, selection_odds = generate_monads()

    tv_dict = {exp: 1 for exp in true_monads}
    for exp in false_monads:
        tv_dict[exp] = 0

    odds_sum = sum(selection_odds)
    selection_probs = [odds/odds_sum for odds in selection_odds]
    selection_idx = [i for i in range(len(selection_probs))]

    true_wffs = []
    for i in range(10):
        true_wffs += generate_true_wffs(1000000, 3) 
    max_length = max([len(i) for i in true_wffs])
    block_size = max_length # what is the maximum context length for predictions?

    seed_true_stats = true_wff_stats(true_wffs, 1000000) 
    print(seed_true_stats)

    true_wffs = ['!' + wff + ' '*(max_length - len(wff)) for wff in true_wffs]
    text = true_wffs

    data = torch.tensor(encode(text), dtype=torch.long)
    n = int(0.9*len(data)) # first 90% will be train, rest val
    train_data = data[:n]
    val_data = data[n:]

    model = GPTLanguageModel()
    m = model.to(device)
    # create a PyTorch optimizer
    optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate)

    seed_gen_stats = []
    for iteration in range(max_iters):
        # every once in a while, gather generation statistics from the current model
        if iteration % eval_interval == 0 or iteration == max_iters - 1:
            if iteration > 0:
                iteration_stats = (iteration, model_generation_stats(m, stats_iters, batch_size=stats_iters//50))
                print('Seed {}. Iter {}: {}'.format(seed, iteration, iteration_stats[1]))
                seed_gen_stats.append(iteration_stats)
            m.train() # set training mode

        # sample a batch of data
        xb, yb = get_batch('train')

        # evaluate the loss
        logits, loss = m(xb, yb)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
    
    stats_dict = {'true': seed_true_stats, 'generated': seed_gen_stats}
    
    # the 'stats' directory must already exist; open() does not create it
    with open('stats/training_stats_{}.pkl'.format(seed), 'wb') as f:
        pickle.dump(stats_dict, f)

Seed 19
(1.0, [0.061138, 0.0, 0.031741, 0.005855, 7e-05])
Seed 19. Iter 10000: (0.99678, 0.80827, [0.05594, 0.00648, 0.02146, 0.01184, 0.00015])
Seed 19. Iter 20000: (0.99846, 0.88683, [0.06085, 0.00431, 0.02363, 0.00905, 0.0001])
Seed 19. Iter 30000: (0.99868, 0.9035, [0.06732, 0.00352, 0.02603, 0.00921, 9e-05])
Seed 19. Iter 40000: (0.99897, 0.90868, [0.0628, 0.00309, 0.02603, 0.00837, 0.0001])
Seed 19. Iter 50000: (0.99909, 0.92669, [0.06997, 0.00242, 0.02872, 0.00892, 0.00011])
Seed 19. Iter 60000: (0.99912, 0.9217, [0.06602, 0.00304, 0.02586, 0.0092, 0.00012])
Seed 19. Iter 70000: (0.99906, 0.94192, [0.07222, 0.00207, 0.0276, 0.00814, 8e-05])
Seed 19. Iter 80000: (0.9993, 0.96835, [0.06285, 0.00108, 0.02682, 0.00598, 5e-05])
Seed 19. Iter 90000: (0.99939, 0.97802, [0.06247, 0.0007, 0.02732, 0.005, 6e-05])
Seed 19. Iter 100000: (0.99931, 0.9849, [0.06577, 0.00059, 0.02867, 0.00517, 0.00015])
Seed 19. Iter 110000: (0.99934, 0.98804, [0.06427, 0.0004, 0.02661, 0.00555, 7e-05])
Seed 1