In [1]:
from datasets import load_dataset

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributions as dist
from torch.utils.data import DataLoader

from transformers import AutoTokenizer, AutoModelForCausalLM

from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Set the device to MPS if available, otherwise CPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Initialize tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B", padding_side="left")
llm_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B").to(device)

# Set padding token
tokenizer.pad_token = tokenizer.eos_token
llm_model.resize_token_embeddings(len(tokenizer))

Embedding(128256, 2048)

## Generate experience dataset

In [22]:
import Levenshtein

from transformers import AutoModel, AutoTokenizer
import torch.nn.functional as F

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

# Function to compute BLEU score reward with smoothing
def bleu_reward(generated_answer, correct_answer):
    reference = [correct_answer.split()]  # BLEU expects a list of references
    candidate = generated_answer.split()
    
    # Use SmoothingFunction to handle cases with low n-gram overlap
    smoothing_function = SmoothingFunction().method1  # Method 1 is a common choice for smoothing
    
    score = sentence_bleu(reference, candidate, smoothing_function=smoothing_function)
    return score  # BLEU score with smoothing, between 0 and 1


# Load BERT or similar model for embedding
similarity_model = AutoModel.from_pretrained("google/flan-t5-base").to(device)
similarity_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")

# Function to compute semantic similarity reward
def semantic_similarity_reward(generated_answer, correct_answer):
    # Tokenize and embed both answers
    inputs_gen = similarity_tokenizer(generated_answer, return_tensors="pt", padding=True, truncation=True).to(device)
    inputs_corr = similarity_tokenizer(correct_answer, return_tensors="pt", padding=True, truncation=True).to(device)
    
    with torch.no_grad():
        embedding_gen = similarity_model(**inputs_gen).last_hidden_state.mean(dim=1)  # Mean pooling
        embedding_corr = similarity_model(**inputs_corr).last_hidden_state.mean(dim=1)

    # Compute cosine similarity
    similarity = F.cosine_similarity(embedding_gen, embedding_corr).item()
    return similarity  # Reward is between 0 and 1


# Function to compute Levenshtein similarity reward
def levenshtein_reward(generated_answer, correct_answer):
    distance = Levenshtein.distance(generated_answer, correct_answer)
    max_len = max(len(generated_answer), len(correct_answer))
    reward = 1 - (distance / max_len)  # Normalized to be between 0 and 1
    return reward

# Define a helper function to structure each experience with combined rewards
def create_experience(question, correct_answer, generated_answer):
    state = question
    action = generated_answer
    next_state = question + " -> " + generated_answer

    # Compute rewards
    semantic_reward = semantic_similarity_reward(generated_answer, correct_answer)
    bleu_reward_score = bleu_reward(generated_answer, correct_answer)
    levenshtein_reward_score = levenshtein_reward(generated_answer, correct_answer)

    # Combined reward (weighted average, you can adjust weights based on relevance)
    reward = (0.4 * semantic_reward) + (0.3 * bleu_reward_score) + (0.3 * levenshtein_reward_score)

    return (state, action, next_state, reward)

In [32]:
# Load the GSM8K dataset
gsm8k = load_dataset("gsm8k", "main", split="train")

# Define a batch size
batch_size = 8  # Adjust batch size based on available memory

# Prepare a DataLoader for batching
class GSM8KDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        example = self.data[idx]
        return example["question"], example["answer"]

dataset = GSM8KDataset(gsm8k)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

# Create a list to store experiences
experiences = []

# Process each batch of question-answer pairs
for batch in tqdm(dataloader, desc="Processing batches"):
    questions, correct_answers = batch

    # Ensure pad_token_id is set
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # Set pad_token to eos_token if not defined

    # Encode batch inputs with padding and attention mask
    inputs = tokenizer(list(questions), return_tensors="pt", padding=True, truncation=True)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    # Generate answers in batch mode
    with torch.no_grad():
        outputs = llm_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=50,
            pad_token_id=tokenizer.pad_token_id
        )

    # Decode each generated answer and create experiences
    generated_answers = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
    for question, correct_answer, generated_answer in zip(questions, correct_answers, generated_answers):
        experience = create_experience(question, correct_answer, generated_answer)
        experiences.append(experience)

    break

Processing batches:   0%|          | 0/935 [00:07<?, ?it/s]


In [None]:
print(experiences[0][0])
print(experiences[0][1])


Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?
Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? A) 16 B) 24 C) 32 D) 48 E) 64
Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch.distributions as dist

# Set the device to MPS if available, otherwise CPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Initialize tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("gpt2")
llm_model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

# Set padding token
tokenizer.pad_token = tokenizer.eos_token
llm_model.resize_token_embeddings(len(tokenizer))

# Define the World Model with RNN layer to aggregate sequence information
class WorldModel(nn.Module):
    def __init__(self, llm_model, state_dim=768):
        super(WorldModel, self).__init__()
        self.llm_model = llm_model
        self.projection = nn.Linear(llm_model.config.vocab_size, state_dim)  # Project vocab_size to state_dim
        self.rnn = nn.GRU(state_dim, state_dim, batch_first=True)

    def forward(self, input_ids, attention_mask):
        # Ensure input tensors are on the correct device
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        
        outputs = self.llm_model(input_ids=input_ids, attention_mask=attention_mask)
        
        # Project the logits to the state_dim for compatibility with the GRU
        projected_logits = self.projection(outputs.logits)  # Shape: [batch_size, sequence_length, state_dim]
        
        # Pass the projected logits through the RNN to obtain the state
        _, state = self.rnn(projected_logits)
        return state.squeeze(0)  # Shape: [batch_size, state_dim]

# Define the Actor and Critic Models
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Actor, self).__init__()
        self.fc = nn.Linear(state_dim, action_dim)

    def forward(self, state):
        return torch.softmax(self.fc(state), dim=-1)

class Critic(nn.Module):
    def __init__(self, state_dim):
        super(Critic, self).__init__()
        self.fc = nn.Linear(state_dim, 1)

    def forward(self, state):
        return self.fc(state)

# Instantiate Models
world_model = WorldModel(llm_model).to(device)
actor = Actor(state_dim=768, action_dim=tokenizer.vocab_size).to(device)
critic = Critic(state_dim=768).to(device)

# Set up optimizers
actor_optimizer = optim.Adam(actor.parameters(), lr=3e-4)
critic_optimizer = optim.Adam(critic.parameters(), lr=3e-4)

def sample_action(action_probs, epsilon=1e-10):
    # Ensure action_probs are valid probabilities by adding epsilon and normalizing
    action_probs = torch.clamp(action_probs, min=epsilon)  # Ensure no zero values
    action_probs = action_probs / action_probs.sum(dim=-1, keepdim=True)  # Normalize to sum to 1
    
    # Sample from the action probabilities
    action_dist = dist.Categorical(action_probs)
    action = action_dist.sample()
    return action



# Imagined Trajectories for Sequence-to-Sequence
def imagine_trajectories(world_model, actor, tokenizer, initial_input_ids, attention_mask, trajectory_length=5):
    trajectories = []
    current_input_ids = initial_input_ids.to(device)
    current_attention_mask = attention_mask.to(device)

    for _ in range(trajectory_length):
        state = world_model(current_input_ids, current_attention_mask)
        action_probs = actor(state)
        action = sample_action(action_probs)
        trajectories.append(action)
        action = action.unsqueeze(-1)
        current_input_ids = torch.cat([current_input_ids, action], dim=1)
        current_attention_mask = torch.cat([current_attention_mask, torch.ones_like(action)], dim=1)

    return trajectories


# Decoding function for Seq2Seq
def decode_trajectories(trajectories, tokenizer):
    decoded_texts = []
    for trajectory in trajectories:
        trajectory_tokens = trajectory.tolist()
        decoded_text = tokenizer.decode(trajectory_tokens, skip_special_tokens=True)
        decoded_texts.append(decoded_text)
    return decoded_texts



In [None]:
from tqdm import tqdm

num_epochs = 10

# Reward Function for Seq2Seq
def compute_reward(predicted_seq, target_seq):
    lev_distance = levenshtein_distance(predicted_seq, target_seq)
    max_len = max(len(target_seq), len(predicted_seq))
    return 1.0 - (lev_distance / max_len)

# Training loop with tqdm progress bars
for epoch in range(num_epochs):
    epoch_progress = tqdm(gsm8k, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")
    
    for example in epoch_progress:
        prompt = example["question"]  # Get the math problem
        answer = example["answer"]    # Get the correct solution

        # Encode prompt and move to device
        inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(device)
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]

        # Generate real trajectory based on the prompt
        real_state = world_model(input_ids, attention_mask)
        action_probs = actor(real_state)
        predicted_ids = torch.argmax(action_probs, dim=-1)  # Greedy decoding for real experience
        predicted_text = tokenizer.decode(predicted_ids.tolist(), skip_special_tokens=True)

        # Calculate reward for the real trajectory
        reward = compute_reward(predicted_text, answer)

        # Imagined trajectories
        imagined_trajectories = imagine_trajectories(world_model, actor, tokenizer, input_ids, attention_mask)
        decoded_imagined_trajectories = decode_trajectories(imagined_trajectories, tokenizer)

        # Calculate imagined rewards and use for Actor-Critic updates
        for imagined_text in decoded_imagined_trajectories:
            imagined_reward = compute_reward(imagined_text, answer)
            imagined_reward_tensor = torch.tensor([imagined_reward], device=device, requires_grad=True)

            # Critic update
            critic_value = critic(real_state.detach())  # Detach real_state to prevent retaining the graph
            critic_loss = nn.MSELoss()(critic_value, imagined_reward_tensor)
            critic_optimizer.zero_grad()
            critic_loss.backward(retain_graph=False)  # Ensure graph is not retained
            critic_optimizer.step()

            # Actor update
            actor_loss = -imagined_reward_tensor.mean()  # Maximize reward by negating
            actor_optimizer.zero_grad()
            actor_loss.backward(retain_graph=False)  # Ensure graph is not retained
            actor_optimizer.step()

        # Update progress bar description with loss values
        epoch_progress.set_postfix({
            "Reward": reward,
            "Actor Loss": actor_loss.item(),
            "Critic Loss": critic_loss.item()
        })

    print(f"Epoch {epoch+1} complete.")


In [None]:
def generate_sequence(model, tokenizer, prompt, max_length=50, top_p=0.9):
    """
    Generate a meaningful text sequence based on a given prompt using MPS.
    
    Args:
        model: The language model used for generation (e.g., world_model or actor).
        tokenizer: Tokenizer corresponding to the model.
        prompt: Initial text prompt to start generation.
        max_length: Maximum length of the generated sequence.
        top_p: Nucleus sampling threshold for diversity.
        
    Returns:
        generated_text: The complete generated text.
    """
    # Set device to MPS if available, otherwise CPU
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    
    # Encode the initial prompt and move to device
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    attention_mask = torch.ones_like(input_ids).to(device)
    
    # Start with empty generated sequence
    generated_sequence = input_ids.clone()

    # Use torch.no_grad() to save memory during inference
    with torch.no_grad():
        for _ in range(max_length):
            # Generate model outputs for the current input sequence
            outputs = model(input_ids=generated_sequence, attention_mask=attention_mask)

            # Extract the logits for the last generated token and move to CPU for processing
            next_token_logits = outputs.logits[:, -1, :].detach().cpu()

            # Apply nucleus (top-p) sampling to filter tokens for diversity
            sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
            cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

            # Create a mask to remove tokens with cumulative probability above top_p
            sorted_indices_to_keep = cumulative_probs <= top_p
            valid_indices = sorted_indices[sorted_indices_to_keep]

            # Sample from the filtered logits
            if valid_indices.size(0) > 0:
                # Sample only from valid indices
                sampled_index = torch.multinomial(torch.softmax(sorted_logits[sorted_indices_to_keep], dim=-1), num_samples=1)
                next_token = valid_indices[sampled_index].to(device)  # Move back to MPS
            else:
                # If no valid indices, fall back to argmax
                next_token = torch.argmax(next_token_logits).unsqueeze(0).to(device)

            # Ensure next_token has compatible dimensions
            next_token = next_token.view(1, 1)

            # Append the new token to the generated sequence
            generated_sequence = torch.cat([generated_sequence, next_token], dim=-1)
            attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)

            # Stop if the end-of-sequence token is generated
            if next_token.item() == tokenizer.eos_token_id:
                break

    # Decode the generated sequence into text
    generated_text = tokenizer.decode(generated_sequence[0], skip_special_tokens=True)
    
    return generated_text

In [None]:
# Define a prompt and generate a sequence
prompt = "2+2="
generated_text = generate_sequence(model=llm_model, tokenizer=tokenizer, prompt=prompt, max_length=5, top_p=0)

print("Generated Text:", generated_text)