# PyTorch trainer scratchpad

Generate training data by playing random-vs-random Connect 4 games and saving their trajectories:
```
python rgi/main.py --game connect4 --player1 random --player2 random --num_games 100 --save_trajectories
```

In [1]:
import os
from typing import Any
from pprint import pprint
from collections import Counter

import torch
import torch.nn as nn
import torch.optim as optim

import numpy as np

# Prefer the GPU but fall back to CPU instead of asserting CUDA: later cells place
# tensors and models via `device`, so the notebook can still run without a GPU.
# NOTE(review): if a GPU is truly mandatory, re-add `assert device.type == "cuda"`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'device: {device}')

# Seed torch so weight initialisation and DataLoader shuffling are reproducible
# across Restart & Run All (torch.manual_seed also seeds CUDA devices).
SEED = 0
torch.manual_seed(SEED)

np.set_printoptions(linewidth=150)

device: cuda


# Load Trajectories

In [2]:
# Load Trajectories (same as before): read every saved trajectory file for the game.
from rgi.core import trajectory

game_name = "connect4"
trajectory_path_parts = ("..", "data", "trajectories", game_name, "*.trajectory.npy")
trajectories_glob = os.path.join(*trajectory_path_parts)
trajectories = trajectory.load_trajectories(trajectories_glob)

print('\n'.join([
    f'trajectories_glob: {trajectories_glob}',
    f'num_trajectories: {len(trajectories)}',
]))

trajectories_glob: ../data/trajectories/connect4/*.trajectory.npy
num_trajectories: 1100


In [3]:
# Define trajectory NamedTuple and unroll_trajectory function (same as before)
from typing import NamedTuple
import numpy as np
class TrajectoryStep(NamedTuple):
    """One (state, action, next_state, reward) transition taken from a game trajectory."""
    move_index: int            # index i of this step within its trajectory (0-based)
    state: torch.Tensor        # float32 copy of t.states[i]
    action: torch.Tensor       # int32 copy of t.actions[i]
    next_state: torch.Tensor   # float32 copy of t.states[i + 1]
    reward: torch.Tensor       # float32 0-d tensor: fixup_reward(t.final_rewards[0])

def fixup_reward(x):
    """Affinely rescale a reward: -1 -> 0, 0 -> 0.5, 1 -> 1."""
    return (x + 1) / 2

def unroll_trajectory(encoded_trajectories: "list[trajectory.EncodedTrajectory]"):
    """Yield a TrajectoryStep for each consecutive (states[i], states[i + 1]) pair.

    Args:
        encoded_trajectories: iterable of trajectory objects exposing `length`,
            `states`, `actions` and `final_rewards`.

    Yields:
        TrajectoryStep for i in range(t.length - 1) of each trajectory t. Every step of a
        trajectory carries the same reward, fixup_reward(t.final_rewards[0]).
        NOTE(review): index 0 is presumably player 1's outcome, so all rewards are from
        player 1's perspective regardless of whose move it is -- confirm this is intended.
    """
    for t in encoded_trajectories:
        if t.length < 2:
            continue  # no transitions; also avoids touching final_rewards of empty games
        # The outcome is identical for every step of a trajectory, so compute it once.
        step_reward = np.array(fixup_reward(t.final_rewards[0]), dtype=np.float32)
        for i in range(t.length - 1):
            # np.array copies into fresh, writeable buffers (the stored arrays may be
            # read-only), so torch.from_numpy can wrap them without a second copy.
            state = np.array(t.states[i], dtype=np.float32)
            next_state = np.array(t.states[i + 1], dtype=np.float32)
            action = np.array(t.actions[i], dtype=np.int32)

            yield TrajectoryStep(i,
                                 torch.from_numpy(state),
                                 torch.from_numpy(action),
                                 torch.from_numpy(next_state),
                                 # Fresh tensor per step so steps never share storage.
                                 torch.tensor(step_reward, dtype=torch.float32))

# Materialise every transition from every loaded trajectory.
all_trajectory_steps = [step for step in unroll_trajectory(trajectories)]
print(f'num_trajectory_steps: {len(all_trajectory_steps)}')

num_trajectory_steps: 22835


In [None]:
# Profile the data-loading (trajectory unrolling) code.
import cProfile
import pstats

# Profile's context manager enables profiling on entry and disables it on exit.
with cProfile.Profile() as profiler:
    all_trajectory_steps = list(unroll_trajectory(trajectories))

stats = pstats.Stats(profiler)
stats.sort_stats('cumulative')
stats.print_stats(100)

In [8]:
# Prepare batches: stack each per-step field into one tensor covering every step.
state_batch, action_batch, reward_batch = (
    torch.stack([getattr(step, field) for step in all_trajectory_steps])
    for field in ('state', 'action', 'reward')
)
batch_full = dict(state=state_batch, action=action_batch, reward=reward_batch)

print('\n'.join(
    f'{name + " shape:":<14}{batch_full[name].shape}' for name in ('state', 'action', 'reward')
))


state shape:  torch.Size([22835, 43])
action shape: torch.Size([22835])
reward shape: torch.Size([22835])


# Save/load in PyTorch format

In [10]:
# Convert to tensors: one stacked tensor per TrajectoryStep field.
states, actions, next_states, rewards = (
    torch.stack([getattr(step, field) for step in all_trajectory_steps])
    for field in ('state', 'action', 'next_state', 'reward')
)

# Save processed data so later sessions can skip unrolling.
processed_data = {
    'states': states,
    'actions': actions,
    'next_states': next_states,
    'rewards': rewards,
}
torch.save(processed_data, 'processed_trajectories.pt')

In [12]:
# Reload the processed tensors (weights_only=True: only tensors/containers are unpickled).
loaded_data = torch.load('processed_trajectories.pt', weights_only=True)
reloaded_states, reloaded_actions, reloaded_next_states, reloaded_rewards = (
    loaded_data[key] for key in ('states', 'actions', 'next_states', 'rewards')
)

In [13]:
from torch.utils.data import Dataset

class TrajectoryDataset(Dataset):
    """Map-style dataset over parallel per-step tensors.

    Item ``idx`` is the tuple ``(state, action, next_state, reward)`` taken from
    position ``idx`` of each stored tensor.
    """

    def __init__(self, states, actions, next_states, rewards):
        self.states = states
        self.actions = actions
        self.next_states = next_states
        self.rewards = rewards

    def __len__(self):
        return len(self.states)

    def __getitem__(self, idx):
        columns = (self.states, self.actions, self.next_states, self.rewards)
        return tuple(column[idx] for column in columns)

from torch.utils.data import DataLoader

# Create dataset and wrap it in a shuffling DataLoader for batching.
dataset = TrajectoryDataset(states=states, actions=actions, next_states=next_states, rewards=rewards)
dataloader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)

# Define Models

In [5]:
class Connect4StateEmbedder(nn.Module):
    """CNN mapping encoded Connect 4 states to fixed-size embeddings.

    Input is a (B, 43) batch of encoded states. All but the last entry are reshaped
    into a single-channel 6x7 grid; the last entry is discarded.
    NOTE(review): that last entry is 1 on the empty board and looks like a
    side-to-move flag -- dropping it means the model must infer whose turn it is
    from the pieces alone. Confirm this is intended.
    Output is (B, embedding_dim).
    """

    def __init__(self, embedding_dim=64, hidden_dim=256):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim

        # 3x3 kernels with padding=1 keep the 6x7 spatial size, hence 64 * 6 * 7 below.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.flatten = nn.Flatten()
        self.linear1 = nn.Linear(64 * 6 * 7, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, embedding_dim)

    def _state_to_array(self, encoded_state_batch):
        board_cells = encoded_state_batch[:, :-1]
        return board_cells.reshape(-1, 1, 6, 7)

    def forward(self, encoded_state_batch):
        features = self._state_to_array(encoded_state_batch)
        for conv in (self.conv1, self.conv2):
            features = torch.relu(conv(features))
        hidden = torch.relu(self.linear1(self.flatten(features)))
        return self.linear2(hidden)

In [6]:
def test_state_embedder():
    """Smoke-test Connect4StateEmbedder on the first two training states."""
    embedder = Connect4StateEmbedder().to(device)
    two_states = state_batch[:2].to(device)
    state_embeddings = embedder(two_states)
    for label, value in (("Sample states shape:", two_states.shape),
                         ("State embeddings shape:", state_embeddings.shape)):
        print(label, value)
    print("First few values of embeddings:")
    print(state_embeddings[:, :5])

test_state_embedder()

Sample states shape: torch.Size([2, 43])
State embeddings shape: torch.Size([2, 64])
First few values of embeddings:
tensor([[ 0.0276,  0.0408,  0.0303, -0.0146, -0.0464],
        [ 0.0230,  0.0380,  0.0300, -0.0109, -0.0419]], device='cuda:0',
       grad_fn=<SliceBackward0>)


In [7]:
class Connect4ActionEmbedder(nn.Module):
    """Learned embedding table for actions.

    Actions arrive 1-based (1..num_actions) and are shifted to 0-based table rows.
    """

    def __init__(self, embedding_dim=64, num_actions=7):
        super().__init__()
        self.embedding_dim, self.num_actions = embedding_dim, num_actions
        self.embedding = nn.Embedding(num_embeddings=num_actions, embedding_dim=embedding_dim)

    def forward(self, action):
        zero_based = action - 1
        return self.embedding(zero_based)

    def all_action_embeddings(self):
        """Return the whole (num_actions, embedding_dim) table; row k belongs to action k + 1."""
        return self.embedding.weight

In [8]:
# After defining Connect4ActionEmbedder
def test_action_embedder():
    """Smoke-test Connect4ActionEmbedder on one training action, then dump the table."""
    embedder = Connect4ActionEmbedder().to(device)
    first_action = action_batch[:1].to(device)
    action_vectors = embedder(first_action)
    print("Action embeddings shape:", action_vectors.shape)
    print("Action embeddings:")
    print(action_vectors)
    print("All action embeddings:")
    table = embedder.all_action_embeddings()
    print(table[:, :5])

test_action_embedder()

Action embeddings shape: torch.Size([1, 64])
Action embeddings:
tensor([[ 0.4400, -0.3875,  1.5108, -0.6908, -0.4308,  0.2181,  1.7117,  1.6171,
          0.4790,  1.6018,  0.1761,  0.5400,  0.4086,  0.6683, -0.5456,  0.3425,
         -0.2177, -0.1389,  0.2651, -0.6244,  2.1062, -0.2677,  0.4823, -0.2823,
          0.2660, -0.0082,  1.6424, -1.8429,  1.4423,  0.6335,  0.3741, -0.4514,
          2.2383,  0.1583,  0.2759, -0.7155,  0.1515,  1.5209, -0.6984, -1.1295,
         -0.5943,  1.7510, -0.5253,  0.8532,  0.6825, -0.1523, -0.5436,  0.7105,
          0.2030, -0.2270, -0.6880, -0.7623, -0.4580,  0.3234, -0.0067, -0.9274,
         -0.2421,  0.0236, -1.3549, -0.1059,  0.7561,  1.5237,  0.2885, -0.1838]],
       device='cuda:0', grad_fn=<EmbeddingBackward0>)
All action embeddings:
tensor([[ 0.0677, -1.0600, -0.0900,  0.8930, -1.0142],
        [-1.3473,  0.3974,  2.2329,  1.0015, -0.7086],
        [ 0.1829, -0.3556, -1.3620, -0.9318,  0.6255],
        [-0.6195, -0.8287, -1.2327, -0.5364,

In [9]:
class PredictionModel(nn.Module):
    """Policy and reward heads on top of a state embedder and an action embedder.

    Action logits are dot products between each state embedding and every action
    embedding; the reward is a linear readout of the state embedding.
    """

    def __init__(self, state_embedder, action_embedder, embedding_dim=64, num_actions=7):
        super().__init__()
        self.state_embedder = state_embedder
        self.action_embedder = action_embedder
        self.embedding_dim = embedding_dim
        self.num_actions = num_actions
        self.reward_head = nn.Linear(embedding_dim, 1)

    def action_logits(self, state_batch):
        # (B, D) @ (D, A) -> (B, A): one score per action for every state.
        action_table = self.action_embedder.all_action_embeddings()
        return self.state_embedder(state_batch) @ action_table.T

    def action_probs(self, state_batch):
        return self.action_logits(state_batch).softmax(dim=-1)

    def reward_pred(self, state_batch):
        # Returns shape (B, 1).
        return self.reward_head(self.state_embedder(state_batch))

In [10]:
# After defining PredictionModel
def test_prediction_model():
    """Smoke-test all PredictionModel heads on the first two training states."""
    model = PredictionModel(Connect4StateEmbedder().to(device),
                            Connect4ActionEmbedder().to(device)).to(device)

    inputs = state_batch[:2].to(device)
    logits = model.action_logits(inputs)
    probs = model.action_probs(inputs)
    rewards = model.reward_pred(inputs)

    print("Action logits shape:", logits.shape)
    sections = (("Action logits:", logits),
                ("\nAction probabilities:", probs),
                ("\nReward predictions:", rewards))
    for header, value in sections:
        print(header)
        print(value)

test_prediction_model()

Action logits shape: torch.Size([2, 7])
Action logits:
tensor([[-0.6928, -0.2426,  0.8710,  0.0857, -0.4242, -0.4686, -0.2346],
        [-0.6637, -0.2830,  0.8562,  0.0995, -0.4559, -0.4721, -0.2352]],
       device='cuda:0', grad_fn=<MmBackward0>)

Action probabilities:
tensor([[0.0732, 0.1148, 0.3496, 0.1594, 0.0957, 0.0916, 0.1157],
        [0.0760, 0.1112, 0.3475, 0.1630, 0.0936, 0.0921, 0.1167]],
       device='cuda:0', grad_fn=<SoftmaxBackward0>)

Reward predictions:
tensor([[0.0074],
        [0.0081]], device='cuda:0', grad_fn=<AddmmBackward0>)


# Define Loss Function

In [16]:
# Loss function
def loss_fn(prediction_model, batch, l2_weight=1e-4):
    """Combined action (cross-entropy) + reward (MSE) + L2 loss for one batch.

    Args:
        prediction_model: model exposing `action_logits(states)` -> (B, num_actions),
            `reward_pred(states)` -> (B, 1) and `parameters()`.
        batch: dict with 'state' (model input), 'action' (1-based action ids) and
            'reward' ((B,) float targets).
        l2_weight: coefficient applied to the sum of squared parameters.

    Returns:
        (total_loss, (action_logits, reward_pred)) with reward_pred of shape (B,).
    """
    action_logits = prediction_model.action_logits(batch['state'])
    # Actions are 1-based; cross_entropy needs 0-based int64 class indices.
    action_labels = (batch['action'] - 1).long()
    action_data_loss = nn.functional.cross_entropy(action_logits, action_labels)

    # squeeze(-1), not squeeze(): a batch of size 1 must stay shape (1,) to match the
    # targets instead of collapsing to a 0-d tensor and relying on broadcasting.
    reward_pred = prediction_model.reward_pred(batch['state']).squeeze(-1)
    reward_labels = batch['reward']
    reward_data_loss = nn.functional.mse_loss(reward_pred, reward_labels)

    # Plain L2 penalty over every parameter (weights, biases and embeddings alike).
    l2_loss = sum((p ** 2).sum() for p in prediction_model.parameters())

    total_loss = action_data_loss + reward_data_loss + l2_weight * l2_loss
    return total_loss, (action_logits, reward_pred)

# Training function
def train_step(prediction_model, optimizer, batch):
    """Run one optimisation step on `batch`.

    Returns:
        (loss as a Python float, action logits, reward predictions).
    """
    prediction_model.train()
    optimizer.zero_grad()
    loss, (logits, reward_pred) = loss_fn(prediction_model, batch)
    loss.backward()
    optimizer.step()
    return loss.item(), logits, reward_pred

In [None]:
import torch.utils.data as data

# Prepare dataset
# NOTE(review): this re-defines `TrajectoryDataset` with a 3-field signature and shadows
# the 4-field version defined in an earlier cell.
class TrajectoryDataset(data.Dataset):
    """Map-style dataset; item ``idx`` is ``(state, action, reward)`` at position ``idx``."""

    def __init__(self, states, actions, rewards):
        self.states, self.actions, self.rewards = states, actions, rewards

    def __len__(self):
        return len(self.states)

    def __getitem__(self, idx):
        return tuple(column[idx] for column in (self.states, self.actions, self.rewards))

# Create dataset and dataloader (batch size: adjust based on available GPU memory).
batch_size = 64
dataset = TrajectoryDataset(state_batch, action_batch, reward_batch)
dataloader = data.DataLoader(dataset, shuffle=True, batch_size=batch_size)

# Main training loop
def _print_sample_predictions(prediction_model, states, rewards, num_samples=4):
    """Print target reward, predicted reward and action probabilities for a few examples."""
    prediction_model.eval()
    with torch.no_grad():  # diagnostics only; no autograd graph needed
        for j in range(min(num_samples, len(states))):
            state = states[j].unsqueeze(0)
            action_probs = prediction_model.action_probs(state)
            reward_pred = prediction_model.reward_pred(state).item()
            reward_true = rewards[j].item()
            print(j, f'r={reward_true:.4f} p={reward_pred:.4f}', action_probs.cpu().numpy())
    print()
    prediction_model.train()

def train_model(print_logits=False, num_epochs=10, lr=0.0005, train_loader=None):
    """Build a fresh PredictionModel and train it with Adam.

    Args:
        print_logits: after each epoch, print reward targets/predictions and action
            probabilities for the first few examples of the last batch seen.
        num_epochs: number of passes over the data.
        lr: Adam learning rate (default is the previously hard-coded value).
        train_loader: sized iterable of (states, actions, rewards) batches; defaults
            to the notebook-level `dataloader`.

    Returns:
        The trained PredictionModel.
    """
    loader = dataloader if train_loader is None else train_loader
    state_embedder = Connect4StateEmbedder().to(device)
    action_embedder = Connect4ActionEmbedder().to(device)
    prediction_model = PredictionModel(state_embedder, action_embedder).to(device)
    optimizer = optim.Adam(prediction_model.parameters(), lr=lr)
    num_batches = len(loader)

    for epoch in range(num_epochs):
        total_loss = 0.0
        states = rewards = None
        for i, (states, actions, rewards) in enumerate(loader):
            states, actions, rewards = states.to(device), actions.to(device), rewards.to(device)
            batch = {'state': states, 'action': actions, 'reward': rewards}

            # train_step: train mode, zero_grad, loss_fn, backward, optimizer step.
            loss_value, _, _ = train_step(prediction_model, optimizer, batch)
            total_loss += loss_value

            if (i + 1) % 100 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{num_batches}], Loss: {loss_value:.4f}')

        # Mean of per-batch losses; max() guards an empty loader.
        avg_loss = total_loss / max(num_batches, 1)
        print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}')

        if print_logits and states is not None:
            _print_sample_predictions(prediction_model, states, rewards)

    return prediction_model

In [20]:
# Train the model, printing per-epoch losses and a few sample predictions.
prediction_model = train_model(num_epochs=10, print_logits=True)

Epoch [1/10], Step [100/357], Loss: 2.2429
Epoch [1/10], Step [200/357], Loss: 2.2272
Epoch [1/10], Step [300/357], Loss: 2.2235
Epoch [1/10], Average Loss: 2.2546
0 r=1.0000 p=0.6442 [[0.11464751 0.13339584 0.12322022 0.15932459 0.16929626 0.16126208 0.13885349]]
1 r=1.0000 p=0.5785 [[0.10394179 0.12067673 0.12600508 0.17570962 0.16703825 0.1559565  0.15067205]]
2 r=0.0000 p=0.5148 [[0.14041291 0.13038336 0.12101714 0.15320817 0.15778473 0.1736262  0.12356753]]
3 r=1.0000 p=0.5808 [[0.09953909 0.13835198 0.13823558 0.17376362 0.14622344 0.16698591 0.1369004 ]]

Epoch [2/10], Step [100/357], Loss: 2.2195
Epoch [2/10], Step [200/357], Loss: 2.2243
Epoch [2/10], Step [300/357], Loss: 2.2278
Epoch [2/10], Average Loss: 2.2082
0 r=1.0000 p=0.6067 [[0.17434004 0.15488389 0.16499685 0.13464737 0.01110733 0.22378612 0.13623843]]
1 r=1.0000 p=1.0207 [[0.17595111 0.18077828 0.17725362 0.01211655 0.11814594 0.17312568 0.16262877]]
2 r=0.0000 p=0.7469 [[0.12633634 0.13760516 0.19996881 0.14245827

In [21]:
# Save the model weights only (state_dict); the architecture is rebuilt before loading.
torch.save(obj=prediction_model.state_dict(), f='connect4_prediction_model.pth')

In [22]:
# Load the model: rebuild the architecture, restore the saved weights, switch to eval mode.
# weights_only=True (as already used for the trajectory data) avoids arbitrary unpickling;
# map_location + .to(device) put the model on the same device as the inputs it will see.
loaded_model = PredictionModel(Connect4StateEmbedder(), Connect4ActionEmbedder())
loaded_model.load_state_dict(
    torch.load('connect4_prediction_model.pth', map_location=device, weights_only=True)
)
loaded_model = loaded_model.to(device)
loaded_model.eval()

PredictionModel(
  (state_embedder): Connect4StateEmbedder(
    (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (linear1): Linear(in_features=2688, out_features=256, bias=True)
    (linear2): Linear(in_features=256, out_features=64, bias=True)
  )
  (action_embedder): Connect4ActionEmbedder(
    (embedding): Embedding(7, 64)
  )
  (reward_head): Linear(in_features=64, out_features=1, bias=True)
)

In [29]:
# Test the loaded model: compare the in-memory model with a copy reloaded from disk
# on the initial (empty) board; both printed lines should match.
from rgi.games import connect4
import numpy as np

game = connect4.Connect4Game()
serializer = connect4.Connect4Serializer()

print('Move 0:')
s_0 = game.initial_state()
state_array = serializer.state_to_jax_array(game, s_0)
# np.array makes a fresh writeable float32 copy (the serializer's array presumably is not).
writeable_state = np.array(state_array, dtype=np.float32)
j_0 = torch.tensor(writeable_state, dtype=torch.float32).unsqueeze(0).to(device)

with torch.no_grad():  # inference only
    print(j_0, prediction_model.reward_pred(j_0).item(), prediction_model.action_probs(j_0).cpu().detach().numpy())

# Load the model onto the same device as the input. weights_only=True restricts
# unpickling to tensors/containers, which is all a state_dict needs.
loaded_model = PredictionModel(Connect4StateEmbedder(), Connect4ActionEmbedder())
loaded_model.load_state_dict(
    torch.load('connect4_prediction_model.pth', map_location=device, weights_only=True)
)
loaded_model = loaded_model.to(device)
loaded_model.eval()

with torch.no_grad():
    print(j_0, loaded_model.reward_pred(j_0).item(), loaded_model.action_probs(j_0).cpu().detach().numpy())

Move 0:
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 1.]], device='cuda:0') 0.5605407953262329 [[0.1638865  0.15204294 0.12782744 0.1432951  0.12901308 0.1448501  0.13908486]]
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 1.]], device='cuda:0') 0.5605407953262329 [[0.1638865  0.15204294 0.12782744 0.1432951  0.12901308 0.1448501  0.13908486]]
