In [2]:
## First, check to see if lightning is installed, if not, install it.
import pip
try:
  __import__("lightning")
except ImportError:
  pip.main(['install', "lightning"])  

import torch ## torch let's us create tensors and also provides helper functions
import torch.nn as nn ## torch.nn gives us nn.Module(), nn.Embedding() and nn.Linear()
import torch.nn.functional as F # This gives us the softmax() and argmax()
from torch.optim import Adam ## We will use the Adam optimizer, which is, essentially, 
                             ## a slightly less stochastic version of stochastic gradient descent.
from torch.utils.data import TensorDataset, DataLoader ## We'll store our data in DataLoaders

import lightning as L ## Lightning makes it easier to write, optimize and scale our code



In [3]:
from typing import List
import random
from torch.nn.utils.rnn import pad_sequence

from typing import List
import random

def generate_synthetic_data(num_data: int, max_digits: int) -> List[List[int]]:
    """
    Generate synthetic data for training the language model.

    Args: 
        num_data (int): Number of data points to generate.
        max_digits (int): Maximum number of digits (not directly used; kept for compatibility).

    Returns:
        List[List[int]]: List of sequences, where each sequence is a list of token IDs.
    """
    # Define the token-to-ID mapping
    token_to_id = {
        '0': 0, '1': 1, '2': 2, '3': 3, '4': 4,
        '5': 5, '6': 6, '7': 7, '8': 8, '9': 9,
        '.': 10,
        'or': 11,
        '<': 12,
        '>': 13,
        '=': 14,
        '[BOS]': 15,
        '[EOS]': 16,
        '[SEP]': 17,
        '[PAD]': 18
    }

    def generate_number():
        # Integer part: 0 to 9999 ensures up to 4 digits
        int_part = str(random.randint(0, 999))
        # Decide whether to include a decimal part (50% chance)
        if random.random() < 0.7:
            # Decimal digits: 1 or 2
            dec_digits = random.randint(1, 2)
            dec_part = ''.join(str(random.randint(0, 9)) for _ in range(dec_digits))
            return int_part + '.' + dec_part
        return int_part

    sequences = []
    for _ in range(num_data):
        # Generate two numbers
        num1_str = generate_number()
        num2_str = generate_number()
        
        # Convert to floats for comparison
        num1_float = float(num1_str)
        num2_float = float(num2_str)
        
        # Determine the operator
        if num1_float < num2_float:
            operator = '<'
        elif num1_float > num2_float:
            operator = '>'
        else:
            operator = '='
        
        # Construct the sequence
        input_part = f"[BOS]{num1_str}or{num2_str}[SEP]"
        output_part = f"{num1_str}{operator}{num2_str}[EOS]"
        sequence_str = input_part + output_part
        
        # Tokenize the sequence
        sequence_tokens = []
        i = 0
        while i < len(sequence_str):
            if sequence_str[i:i+5] in ('[BOS]', '[EOS]', '[SEP]'):
                sequence_tokens.append(sequence_str[i:i+5])
                i += 5
            elif sequence_str[i:i+2] == 'or':
                sequence_tokens.append('or')
                i += 2
            else:
                sequence_tokens.append(sequence_str[i])
                i += 1
        
        # Map tokens to IDs
        sequence_ids = [token_to_id[token] for token in sequence_tokens]
        sequences.append(sequence_ids)

    return sequences
token_to_id = {
        '0': 0, '1': 1, '2': 2, '3': 3, '4': 4,
        '5': 5, '6': 6, '7': 7, '8': 8, '9': 9,
        '.': 10,
        'or': 11,
        '<': 12,
        '>': 13,
        '=': 14,
        '[BOS]': 15,
        '[EOS]': 16,
        '[SEP]': 17,
        '[PAD]': 18
    }

id_to_token = dict(map(reversed, token_to_id.items()))

data = generate_synthetic_data(num_data=100000, max_digits=2)

print(data[0])
print("Generated sequence (text):", "".join([id_to_token[token_id] for token_id in data[0]]))

# Convert to tensors
data_tensors = [torch.tensor(seq, dtype=torch.long) for seq in data]

[15, 3, 2, 5, 10, 6, 6, 11, 9, 4, 8, 10, 7, 17, 3, 2, 5, 10, 6, 6, 12, 9, 4, 8, 10, 7, 16]
Generated sequence (text): [BOS]325.66or948.7[SEP]325.66<948.7[EOS]


In [4]:
train_size = int(0.9 * len(data_tensors))  # 90% for training
val_size = len(data_tensors) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(data_tensors, [train_size, val_size])

def collate_fn(batch):
    return pad_sequence(batch, batch_first=True, padding_value=0)

train_dataloader = DataLoader(train_dataset, batch_size=1024, shuffle=True, collate_fn=collate_fn, num_workers=4, pin_memory=True)
val_dataloader = DataLoader(val_dataset, batch_size=1024, shuffle=False, collate_fn=collate_fn, num_workers=4, pin_memory=True)  

In [5]:
class PositionEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 50):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-torch.log(torch.tensor(10000.0)) / d_model))
        pe = torch.zeros(1, max_len, d_model)  # Shape: (1, max_len, d_model) for batch-first
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch_size, seq_len, d_model)
        x = x + self.pe[:, :x.size(1), :]  # Slice to seq_len, broadcast over batch
        return x

class Attention(nn.Module):
    def __init__(self, d_model: int):
        super().__init__()
        self.W_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_k = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_v = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.d_model = d_model

    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):
        q = self.W_q(encodings_for_q)
        k = self.W_k(encodings_for_k)
        v = self.W_v(encodings_for_v)
        
        # Scaled dot-product attention
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.d_model, dtype=torch.float32))

        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)

        attn_probs = torch.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_probs, v)
        return output


class DecoderLayer(nn.Module):
  def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
      super().__init__()
      self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
      self.norm1 = nn.LayerNorm(d_model)
      self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
          )
      self.norm2 = nn.LayerNorm(d_model)
      self.dropout = nn.Dropout(dropout)


  def forward(self, x, mask):
    # Self-attention with residual connection and layer normalization
    attn_output, _ = self.self_attn(x, x, x, attn_mask=mask)
    x = self.norm1(x + self.dropout(attn_output))

    # Feedforward network with residual and layer normalization.
    ff_output = self.ff(x)
    x = self.norm2(x + self.dropout(ff_output))

    return x
    

class DecoderOnlyTransformer(L.LightningModule):
    def __init__(self, num_tokens: int = 19, d_model: int = 128, num_heads: int = 4, num_layers: int = 3, d_ff: int = 512, max_len: int = 100, dropout=0.1):
        super().__init__()
        L.seed_everything(seed=42)
        self.embedding = nn.Embedding(num_tokens, d_model)
        self.pos_encoding = PositionEncoding(d_model, max_len)
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.fc_out = nn.Linear(d_model, num_tokens)
        self.dropout = nn.Dropout(dropout)
        self.loss = nn.CrossEntropyLoss(ignore_index=18)
        self.d_model = d_model

    def forward(self, x, mask):
      x = self.embedding(x) * torch.sqrt(torch.tensor(self.d_model, dtype=torch.float32)) # Scale embeddings
      x = self.pos_encoding(x)
      x = self.dropout(x)
      for layer in self.decoder_layers:
          x = layer(x, mask)
      x = self.fc_out(x)
      return x


    def configure_optimizers(self):
        return Adam(self.parameters(), lr=3e-4)

    def training_step(self, batch, batch_idx):
        inputs = batch[:, :-1]
        labels = batch[:, 1:]
        mask = self.generate_square_subsequent_mask(inputs.size(1)).to(inputs.device)
        output = self.forward(inputs, mask)
        output = output.view(-1, output.size(-1))
        labels = labels.contiguous().view(-1)
        loss = self.loss(output, labels)
        self.log('train_loss', loss)
        # print("Batch:", batch_idx)
        # print("Loss: ", loss.item())
        return loss
    def validation_step(self, batch, batch_idx):
        inputs = batch[:, :-1]
        labels = batch[:, 1:]
        mask = self.generate_square_subsequent_mask(inputs.size(1)).to(inputs.device)
        output = self.forward(inputs, mask)
        output = output.view(-1, output.size(-1))
        labels_flat = labels.contiguous().view(-1)
        loss = self.loss(output, labels_flat)
        predictions = torch.argmax(output, dim=-1)
        correct_predictions = (predictions == labels_flat)
        non_padding_mask = (labels_flat != 0)
        accuracy = (correct_predictions & non_padding_mask).float().sum() / non_padding_mask.float().sum()
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('val_acc', accuracy, on_step=False, on_epoch=True, prog_bar=True)
        # print("Validation Loss: ", loss.item())
        # print("Validation Accuracy:", accuracy.item())
        return loss

    def generate_square_subsequent_mask(self, sz: int) -> torch.Tensor:
        """Generate a square mask for the sequence. The masked positions are filled with float('-inf').
        Unmasked positions are filled with float(0.0).
        """
        return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)

    def generate(self, input_sequence, max_length=20):
      self.eval()  # Set the model to evaluation mode
      with torch.no_grad():  # Disable gradient calculation
          input_tensor = torch.tensor(input_sequence, dtype=torch.long).unsqueeze(0).to(self.device) # Add batch dimension and move to device
          output = input_tensor
          for _ in range(max_length):
                mask = self.generate_square_subsequent_mask(output.size(1)).to(self.device)
                predictions = self.forward(output, mask)
                # Get the last predicted token
                next_token_prediction = predictions[:, -1, :]
                # Get the most likely next token
                next_token = torch.argmax(next_token_prediction, dim=-1)
                # Concatenate the next token to the output sequence
                output = torch.cat([output, next_token.unsqueeze(0)], dim=1)

                # If the next token is [EOS], stop generating
                if next_token.item() == 16: # [EOS] token ID
                    break
      return output.squeeze(0).tolist() # remove the batch dimension

In [None]:

num_tokens = 19
d_model = 32
num_heads = 4
num_layers = 8
d_ff = 1024
max_len = 50
dropout = 0.1

# --- Model Creation ---
model = DecoderOnlyTransformer(num_tokens, d_model, num_heads, num_layers, d_ff, max_len, dropout)

# --- Training (using Lightning) ---
trainer = L.Trainer(
    max_epochs=10,  # More epochs
    accelerator="gpu",
    devices="auto",
    gradient_clip_val=0.5, 
    # accumulate_grad_batches = 2 # Add gradient clipping
    # profiler="simple", # uncomment for profiling
)
trainer.fit(model, train_dataloader, val_dataloader)  
trainer.save_checkpoint("/kaggle/working/base.ckpt")

In [None]:
# checkpoint_path = "/kaggle/working/start2.ckpt" 

# num_tokens = 19
# d_model = 64
# num_heads = 8
# num_layers = 8
# d_ff = 2048
# max_len = 50
# dropout = 0.1

# model = DecoderOnlyTransformer.load_from_checkpoint(
#     checkpoint_path,
#     num_tokens=num_tokens,  # Pass hyperparameters to reconstruct the model
#     d_model=d_model,
#     num_heads=num_heads,
#     num_layers=num_layers,
#     d_ff=d_ff,
#     max_len=max_len,
#     dropout=dropout
# )

## New RL Method with log(p) and log(1-p)

In [None]:
from tqdm import tqdm
class ImprovedRLFineTuner:
    def __init__(self, model, learning_rate=1e-5, gamma=0.99):
        """Initialize the improved RL fine-tuner."""
        self.model = model
        self.model.train()
        
        # Use a smaller learning rate for stability
        self.optimizer = Adam(self.model.parameters(), lr=learning_rate)
        self.gamma = gamma
        
        # Create token mappings for easier reference
        self.id_to_token = {
            0: '0', 1: '1', 2: '2', 3: '3', 4: '4',
            5: '5', 6: '6', 7: '7', 8: '8', 9: '9',
            10: '.', 11: 'or', 12: '<', 13: '>', 
            14: '=', 15: '[BOS]', 16: '[EOS]', 17: '[SEP]', 18: '[PAD]'
        }
        
        # Track metrics
        self.rewards_history = []
        self.accuracy_history = []
    
    def parse_numbers(self, sequence):
        """Extract the two numbers from a sequence for comparison."""
        # Convert sequence to string
        if isinstance(sequence, list) or isinstance(sequence, torch.Tensor):
            sequence = "".join([self.id_to_token.get(int(token), "") for token in sequence])
        
        # Find the position of 'or'
        or_pos = sequence.find('or')
        if or_pos == -1:
            return None, None
        
        # Find the position of [BOS] and [SEP]
        bos_pos = sequence.find('[BOS]')
        sep_pos = sequence.find('[SEP]')
        
        if bos_pos == -1 or sep_pos == -1:
            return None, None
        
        # Extract the numbers
        num1_str = sequence[bos_pos+5:or_pos]
        num2_str = sequence[or_pos+2:sep_pos]
        
        try:
            num1 = float(num1_str)
            num2 = float(num2_str)
            return num1, num2
        except ValueError:
            return None, None
    
    def compute_reward(self, generated_sequence, input_sequence):
        """Compute reward based on correct comparison operator."""
        # Parse numbers from input
        num1, num2 = self.parse_numbers(input_sequence)
        if num1 is None or num2 is None:
            return -1  # Invalid input
        
        # Determine correct operator
        correct_op = '<' if num1 < num2 else '>' if num1 > num2 else '='
        correct_op_id = 12 if correct_op == '<' else 13 if correct_op == '>' else 14
        
        # Find operator in generated sequence
        generated_op = None
        for token in generated_sequence:
            if token in [12, 13, 14]:  # < or > or =
                generated_op = token
                break
        
        # No operator found - large penalty
        if generated_op is None:
            return -1
        
        # Compare with correct operator
        if generated_op == correct_op_id:
            return 1.0  # Correct comparison
        else:
            return -1.0  # Incorrect comparison
    
    def train_one_epoch(self, dataloader):
        """Train for one epoch using both correct and incorrect examples with different weights."""
        total_reward = 0
        total_correct = 0
        total_samples = 0
        
        self.model.train()
        
        # Process batches
        for batch in tqdm(dataloader, desc="Training"):
            batch = batch.to(self.model.device)
            batch_size = batch.size(0)
            
            for i in range(batch_size):
                # Find the separator token ([SEP]) position
                sep_positions = (batch[i] == 17).nonzero(as_tuple=True)[0]
                if len(sep_positions) == 0:
                    continue  # Skip if no separator
                    
                sep_position = sep_positions[0].item()
                
                # Input is everything up to and including [SEP]
                input_seq = batch[i, :sep_position+1].unsqueeze(0)
                
                # 1. Generate the sequence with the current model
                with torch.no_grad():
                    generated = self.model.generate(input_seq[0].tolist())
                
                # 2. Compute reward
                reward = self.compute_reward(generated, batch[i].tolist())
                total_reward += reward
                
                # Check if prediction is correct
                is_correct = reward > 0
                if is_correct:
                    total_correct += 1
                total_samples += 1
                
                # 3. Create a modified target sequence with the correct operator
                num1, num2 = self.parse_numbers(batch[i].tolist())
                
                if num1 is not None and num2 is not None:
                    # Determine correct operator
                    correct_op_id = 12 if num1 < num2 else 13 if num1 > num2 else 14
                    
                    # Extract parts before and after the operator from the target
                    target_parts = []
                    op_found = False
                    
                    for token in batch[i].tolist():
                        if token in [12, 13, 14]:  # Found an operator
                            target_parts.append(correct_op_id)  # Replace with correct operator
                            op_found = True
                        else:
                            target_parts.append(token)
                    
                    if not op_found:
                        continue  # Skip if no operator in target
                    
                    # Create target tensor
                    target = torch.tensor(target_parts, device=self.model.device).unsqueeze(0)
                    
                    # 4. Compute loss with this target and update
                    # Create inputs for the model (everything except the last token)
                    inputs = target[:, :-1]
                    # Create targets (everything except the first token - shifted by 1)
                    targets = target[:, 1:]
                    
                    # Create appropriate mask
                    mask = self.model.generate_square_subsequent_mask(inputs.size(1)).to(self.model.device)
                    
                    # Forward pass
                    output = self.model(inputs, mask)
                    
                    # Custom loss computation
                    epsilon = 1e-8  # Small value to prevent log(0)
                    p = torch.softmax(output, dim=-1)  # [1, seq_len_inputs, vocab_size], probabilities
                    
                    # Get probabilities of correct tokens (targets)
                    p_ct = p.gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)  # [1, seq_len_inputs]
                    
                    # Compute -log(p) for correct tokens
                    neg_log_p_ct = -torch.log(p_ct + epsilon)  # [1, seq_len_inputs]
                    
                    # Get the model's actual predictions (highest probability tokens)
                    predicted_tokens = torch.argmax(p, dim=-1)  # [1, seq_len_inputs]
                    
                    # Create a mask where predictions don't match targets
                    incorrect_mask = (predicted_tokens != targets)  # [1, seq_len_inputs]
                    
                    # Get probabilities of predicted tokens
                    p_pred = p.gather(dim=-1, index=predicted_tokens.unsqueeze(-1)).squeeze(-1)  # [1, seq_len_inputs]
                    
                    # Compute -log(1 - p) for incorrect predictions only
                    neg_log_one_minus_p_incorrect = torch.zeros_like(p_ct)  # [1, seq_len_inputs]
                    neg_log_one_minus_p_incorrect[incorrect_mask] = -torch.log(1 - p_pred[incorrect_mask] + epsilon)
                    
                    # Total loss per position: -log(p) for correct + -log(1-p) for incorrect predictions
                    loss_per_position = neg_log_p_ct + neg_log_one_minus_p_incorrect  # [1, seq_len_inputs]
                    
                    # Exclude padding tokens (ID 18) from loss
                    valid_mask = (targets != 18)  # [1, seq_len_inputs]
                    if valid_mask.sum() == 0:
                        continue  # Skip if all tokens are padding
                    
                    # Average loss over valid positions
                    loss = loss_per_position[valid_mask].mean()
                    
                    # Apply different scaling based on whether prediction was correct
                    if is_correct:
                        # For correct predictions, use a smaller weight
                        loss = loss * 0.2  # 5x smaller weight than incorrect predictions
                    else:
                        # For incorrect predictions, use a larger weight
                        loss = loss * abs(reward)  # Full weight based on reward magnitude
                    
                    # Update
                    self.optimizer.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                    self.optimizer.step()
        
        # Calculate statistics
        avg_reward = total_reward / max(total_samples, 1)
        accuracy = total_correct / max(total_samples, 1)
        
        self.rewards_history.append(avg_reward)
        self.accuracy_history.append(accuracy)
        
        print(f"Epoch stats - Avg Reward: {avg_reward:.4f}, Accuracy: {accuracy:.4f}")
        return avg_reward, accuracy
    
    def evaluate(self, dataloader):
        """Evaluate the model on a validation set."""
        self.model.eval()
        correct = 0
        total = 0
        
        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Evaluating"):
                batch = batch.to(self.model.device)
                
                for i in range(batch.size(0)):
                    # Find separator position
                    sep_positions = (batch[i] == 17).nonzero(as_tuple=True)[0]
                    if len(sep_positions) == 0:
                        continue
                        
                    sep_position = sep_positions[0].item()
                    
                    # Extract input sequence
                    input_seq = batch[i, :sep_position+1].tolist()
                    
                    # Extract numbers for comparison
                    num1, num2 = self.parse_numbers(batch[i].tolist())
                    if num1 is None or num2 is None:
                        continue
                        
                    # Determine correct operator
                    correct_op_id = 12 if num1 < num2 else 13 if num1 > num2 else 14
                    
                    # Generate prediction
                    try:
                        generated = self.model.generate(input_seq)
                        
                        # Find generated operator
                        gen_op = None
                        for token in generated:
                            if token in [12, 13, 14]:  # < or > or =
                                gen_op = token
                                break
                                
                        if gen_op is not None:
                            total += 1
                            if gen_op == correct_op_id:
                                correct += 1
                                
                    except Exception as e:
                        print(f"Error during evaluation: {e}")
        
        accuracy = correct / max(total, 1)
        print(f"Evaluation Results: {correct}/{total} correct = {accuracy:.4f} accuracy")
        return accuracy
    
    def fine_tune(self, train_dataloader, val_dataloader, num_epochs=3):
        """Fine-tune the model for multiple epochs."""
        # First evaluate the model
        print("Initial evaluation:")
        initial_accuracy = self.evaluate(val_dataloader)
        
        best_accuracy = initial_accuracy
        best_model_state = self.model.state_dict().copy()
        
        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch+1}/{num_epochs}:")
            
            # Train for one epoch
            reward, train_acc = self.train_one_epoch(train_dataloader)
            
            # Evaluate
            val_accuracy = self.evaluate(val_dataloader)
            
            # Save best model
            if val_accuracy > best_accuracy:
                best_accuracy = val_accuracy
                best_model_state = self.model.state_dict().copy()
                print(f"New best model with accuracy: {best_accuracy:.4f}")
        
        # Restore best model
        self.model.load_state_dict(best_model_state)
        print(f"Fine-tuning complete. Best accuracy: {best_accuracy:.4f}")
        return best_accuracy
    
    def save_model(self, path):
        """Save the fine-tuned model."""
        torch.save(self.model.state_dict(), path)

# Run improved fine-tuning
def run_improved_finetuning():
    # Start with a fresh copy of the original pre-trained model
    model = DecoderOnlyTransformer.load_from_checkpoint(
        "/kaggle/working/base.ckpt",
       num_tokens=19, 
        d_model=32, 
        num_heads=4, 
        num_layers=8, 
        d_ff=1024, 
        max_len=50, 
        dropout=0.1
    )
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    
    # Create datasets with a variety of comparison examples
    print("Generating training data...")
    train_data = generate_synthetic_data(num_data=5000, max_digits=3)
    train_tensors = [torch.tensor(seq, dtype=torch.long) for seq in train_data]
    train_dataloader = DataLoader(train_tensors, batch_size=64, shuffle=True, collate_fn=collate_fn)
    
    print("Generating validation data...")
    val_data = generate_synthetic_data(num_data=500, max_digits=3)
    val_tensors = [torch.tensor(seq, dtype=torch.long) for seq in val_data]
    val_dataloader = DataLoader(val_tensors, batch_size=64, shuffle=False, collate_fn=collate_fn)
    
    # Create fine-tuner
    finetuner = ImprovedRLFineTuner(model, learning_rate=5e-6)
    
    # Fine-tune
    finetuner.fine_tune(train_dataloader, val_dataloader, num_epochs=3)
    
    # Save model
    finetuner.save_model("/kaggle/working/true_rl.pt")
    
    # Test examples
    test_examples = [
        # [BOS]212.91or211.19[SEP]
        [15, 2, 1, 2, 10, 9, 1, 11, 2, 1, 1, 10, 1, 9, 17],
        
        # [BOS]8.1or2.95[SEP]
        [15, 8, 10, 1, 11, 2, 10, 9, 5, 17],
        
        # [BOS]125or115[SEP]
        [15, 1, 2, 5, 11, 1, 1, 5, 17],
        
        # [BOS]21.01or21.02[SEP]
        [15, 2, 1, 10, 0, 1, 11, 2, 1, 10, 0, 2, 17]
    ]
    
    # Test the model on examples
    print("\nTesting the improved fine-tuned model:")
    for i, input_sequence in enumerate(test_examples):
        generated_sequence = model.generate(input_sequence)
        
        # Print results
        id_to_token = finetuner.id_to_token
        input_text = "".join([id_to_token[token_id] for token_id in input_sequence])
        generated_text = "".join([id_to_token[token_id] for token_id in generated_sequence])
        
        # Parse numbers for expected result
        num1, num2 = finetuner.parse_numbers(input_sequence)
        if num1 is not None and num2 is not None:
            op = '<' if num1 < num2 else '>' if num1 > num2 else '='
            expected = f"{num1}{op}{num2}"
        else:
            expected = "Unable to parse"
        
        print(f"Example {i+1}:")
        print(f"Input: {input_text}")
        print(f"Generated: {generated_text}")
        print(f"Expected: {expected}")
        print("-" * 50)
    
    return model

# Run the improved fine-tuning
improved_model = run_improved_finetuning()

INFO: Seed set to 42


Generating training data...
Generating validation data...
Initial evaluation:


Evaluating: 100%|██████████| 8/8 [00:28<00:00,  3.53s/it]


Evaluation Results: 301/500 correct = 0.6020 accuracy

Epoch 1/3:


Training:  57%|█████▋    | 45/79 [03:34<02:41,  4.74s/it]

### RL

In [None]:
import torch
import torch.nn.functional as F
from tqdm import tqdm

def train_one_epoch(self, dataloader):
    """
    Train the model for one epoch using policy gradient (REINFORCE).
    
    Args:
        dataloader: DataLoader providing batches of input sequences.
    
    Returns:
        avg_reward (float): Average reward over the epoch.
        accuracy (float): Fraction of correct predictions.
    """
    self.model.train()
    total_reward = 0.0
    total_correct = 0
    total_samples = 0
    
    for batch in tqdm(dataloader, desc="Training"):
        batch = batch.to(self.model.device)
        batch_size = batch.size(0)
        loss = 0.0
        
        for i in range(batch_size):
            # Extract input sequence up to [SEP]
            sep_positions = (batch[i] == 17).nonzero(as_tuple=True)[0]
            if len(sep_positions) == 0:
                continue  # Skip if [SEP] is not found
            sep_position = sep_positions[0].item()
            input_seq = batch[i, :sep_position + 1].clone()  # e.g., [BOS] num1 or num2 [SEP]
            current_seq = input_seq.clone()
            
            log_probs = []  # Store log probabilities of sampled tokens
            generated = []  # Store generated tokens
            max_generation_len = 10  # Maximum tokens to generate after [SEP]
            
            # Generate sequence autoregressively by sampling
            while len(generated) < max_generation_len and (len(generated) == 0 or generated[-1] != 16):
                # Create causal mask
                mask = self.model.generate_square_subsequent_mask(current_seq.size(0)).to(self.model.device)
                
                # Forward pass to get logits for the next token
                output = self.model(current_seq.unsqueeze(0), mask)
                logits = output[0, -1, :]  # Logits for the last position
                
                # Convert logits to log probabilities
                log_probs_t = F.log_softmax(logits, dim=-1)
                probs = log_probs_t.exp()
                
                # Sample a token from the policy
                token = torch.multinomial(probs, num_samples=1).item()
                
                # Record the log probability of the sampled token
                selected_log_prob = log_probs_t[token]
                log_probs.append(selected_log_prob)
                
                # Append token to generated sequence
                generated.append(token)
                
                # Update current sequence for the next iteration
                current_seq = torch.cat([current_seq, torch.tensor([token], device=self.model.device)], dim=0)
            
            # Construct the full generated sequence
            full_generated_seq = input_seq.tolist() + generated
            
            # Compute reward using the existing method
            R = self.compute_reward(full_generated_seq, full_generated_seq)
            total_reward += R
            if R > 0:
                total_correct += 1
            total_samples += 1
            
            # Compute policy gradient loss for this sequence
            if log_probs:
                total_log_prob = sum(log_probs)  # Sum log probs over the sequence
                loss_i = -total_log_prob * R     # Negative because we maximize reward
                loss += loss_i
        
        # Average loss over the batch and update model
        if total_samples > 0:
            loss = loss / total_samples
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)  # Gradient clipping
            self.optimizer.step()
    
    # Compute epoch statistics
    avg_reward = total_reward / max(total_samples, 1)
    accuracy = total_correct / max(total_samples, 1)
    
    self.rewards_history.append(avg_reward)
    self.accuracy_history.append(accuracy)
    
    print(f"Epoch stats - Avg Reward: {avg_reward:.4f}, Accuracy: {accuracy:.4f}")
    return avg_reward, accuracy