In [None]:
from datasets import load_dataset

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributions as dist
from torch.utils.data import DataLoader, Dataset

from transformers import AutoTokenizer, AutoModelForCausalLM

from tqdm import tqdm

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

In [None]:
batch_size = 8
experiences = []

In [None]:
# Load GSM8K dataset
gsm8k = load_dataset("gsm8k", "main", split="train")

# Initialize tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B", padding_side="left")
llm_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B").to(device)

# Set padding token
tokenizer.pad_token = tokenizer.eos_token
llm_model.resize_token_embeddings(len(tokenizer))

# Define experience creation in batch
def create_experience_batch(questions, correct_answers, generated_answers):
    batch_experiences = []
    for question, correct_answer, generated_answer in zip(questions, correct_answers, generated_answers):
        state = question
        action = tokenizer.encode(generated_answer, add_special_tokens=True)
        next_state = question + " -> " + generated_answer
        reward = 1.0 if generated_answer.strip() == correct_answer.strip() else 0.0
        batch_experiences.append((state, action, next_state, reward))
    return batch_experiences

# Generate experiences in batches
for i in tqdm(range(0, len(gsm8k), batch_size), desc="Generating experiences"):
    batch = gsm8k.select(range(i, min(i + batch_size, len(gsm8k))))
    questions = batch["question"]
    correct_answers = batch["answer"]

    # Tokenize batch of questions
    inputs = tokenizer(questions, return_tensors="pt", padding=True, truncation=True).to(device)
    
    # Generate answers in batch
    with torch.no_grad():
        outputs = llm_model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=50,
            pad_token_id=tokenizer.pad_token_id
        )
    
    # Decode generated answers and create experiences
    generated_answers = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
    experiences.extend(create_experience_batch(questions, correct_answers, generated_answers))

In [77]:
class FixedLengthExperienceDataset(Dataset):
    def __init__(self, experiences, max_length=64):
        self.experiences = experiences
        self.max_length = max_length

    def __len__(self):
        return len(self.experiences)

    def __getitem__(self, idx):
        state, action, next_state, reward = self.experiences[idx]
        
        # Tokenize state and next_state with consistent max_length and padding
        state_ids = tokenizer(
            state, 
            return_tensors="pt", 
            padding="max_length", 
            truncation=True, 
            max_length=self.max_length
        )["input_ids"].squeeze(0)
        
        next_state_ids = tokenizer(
            next_state, 
            return_tensors="pt", 
            padding="max_length", 
            truncation=True, 
            max_length=self.max_length
        )["input_ids"].squeeze(0)
        
        # Ensure action and reward are in tensor format
        action_tensor = torch.tensor(action, dtype=torch.long)  # Treat action as a list of token IDs
        reward_tensor = torch.tensor(reward, dtype=torch.float32)
        
        return state_ids, action_tensor, next_state_ids, reward_tensor


# Initialize dataset and DataLoader
dataset = FixedLengthExperienceDataset(experiences, max_length=max_length)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

# Q-network
class QNetwork(nn.Module):
    def __init__(self, vocab_size, state_dim=768, hidden_dim=256, action_dim=tokenizer.vocab_size):
        super(QNetwork, self).__init__()
        self.embedding = nn.Embedding(vocab_size, state_dim)
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, action_dim)

    def forward(self, state):
        x = self.embedding(state).mean(dim=1)  # Mean pooling to handle sequence length
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

# Initialize Q-networks and optimizer
q_network = QNetwork(vocab_size=tokenizer.vocab_size).to(device)
target_q_network = QNetwork(vocab_size=tokenizer.vocab_size).to(device)
target_q_network.load_state_dict(q_network.state_dict())
optimizer = optim.Adam(q_network.parameters(), lr=3e-4)
gamma = 0.99
alpha = 0.1
num_epochs = 5

# Conservative Q-Learning (CQL) Loss
def cql_loss(q_network, target_q_network, states, actions, rewards, next_states, gamma=0.99, alpha=0.1):
    # Ensure actions is a 2D tensor with shape [batch_size, 1]
    if actions.dim() == 1:
        actions = actions.unsqueeze(-1)  # Shape becomes [batch_size, 1]
        
    # Get Q-values for the chosen actions and ensure they have shape [batch_size]
    q_values_all = q_network(states)  # Shape: [batch_size, action_dim]
    q_values = q_values_all.gather(1, actions).squeeze(-1)  # Gather along the action dimension, then squeeze

    # Get max Q-value for the next states
    with torch.no_grad():
        next_q_values = target_q_network(next_states).max(1)[0]  # Shape: [batch_size]

    # Compute the target Q-values with the same shape as q_values
    target_q = (rewards + gamma * next_q_values).view(-1)  # Shape: [batch_size]

    # Ensure q_values and target_q have matching shapes
    assert q_values.shape == target_q.shape, f"Shape mismatch: q_values {q_values.shape}, target_q {target_q.shape}"

    # Calculate Bellman loss using MSE
    bellman_loss = nn.MSELoss()(q_values, target_q)

    # Calculate conservative loss
    logsumexp_q = torch.logsumexp(q_values_all, dim=1)  # Shape: [batch_size]
    conservative_loss = alpha * (logsumexp_q - q_values).mean()

    return bellman_loss + conservative_loss


# Training loop
for epoch in range(num_epochs):
    for states, actions, next_states, rewards in dataloader:
        states = states.to(device)
        actions = actions.to(device)
        next_states = next_states.to(device)
        rewards = rewards.to(device)

        # Compute CQL loss and optimize Q-network
        loss = cql_loss(q_network, target_q_network, states, actions, rewards, next_states, gamma, alpha)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Periodically update target network
    target_q_network.load_state_dict(q_network.state_dict())
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}")

AssertionError: Shape mismatch: q_values torch.Size([1, 87]), target_q torch.Size([1])