# Trainer: PyTorch scratchpad

Generate training data by playing 100 random-vs-random Connect 4 games and saving their trajectories:
```
python rgi/main.py --game connect4 --player1 random --player2 random --num_games 100 --save_trajectories
```

In [2]:
import os
from typing import Any
from pprint import pprint
from collections import Counter

import torch
import torch.nn as nn
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'device: {device}')
# The line above already falls back to CPU, so warn rather than assert: the notebook
# still runs (slowly) on machines without a GPU instead of failing in the first cell.
if device.type != "cuda":
    print("WARNING: CUDA is not available; running on CPU.")

# Seed torch's global RNG so model initialisation and DataLoader shuffling are reproducible.
SEED = 0
torch.manual_seed(SEED)

import numpy as np
np.set_printoptions(linewidth=150)

device: cuda


In [3]:
%load_ext line_profiler
%load_ext memory_profiler

The line_profiler extension is already loaded. To reload it, use:
  %reload_ext line_profiler


# Load Trajectories

In [4]:
# Load every saved trajectory file for this game. The files are presumably produced by
# the `rgi/main.py ... --save_trajectories` command shown at the top of this notebook.
from rgi.core import trajectory

game_name = "connect4"
trajectories_glob = os.path.join(
    "..", "data", "trajectories", game_name, "*.trajectory.npy",
)
trajectories = trajectory.load_trajectories(trajectories_glob)

print(f'trajectories_glob: {trajectories_glob}')
print(f'num_trajectories: {len(trajectories)}')

trajectories_glob: ../data/trajectories/connect4/*.trajectory.npy
num_trajectories: 1100


In [20]:
# Define trajectory NamedTuple and unroll_trajectory function (same as before)
from typing import NamedTuple
import numpy as np
class TrajectoryStep(NamedTuple):
    """One transition (state --action--> next_state) unrolled from a trajectory, plus its reward label."""

    move_index: int            # index of this step within its trajectory
    state: torch.Tensor        # encoded state before the action
    action: torch.Tensor       # action taken from `state`
    next_state: torch.Tensor   # encoded state after the action
    reward: torch.Tensor       # reward label attached to this step

def fixup_reward(x):
    """Rescale a reward linearly: -1 -> 0, 0 -> 0.5, 1 -> 1 (i.e. (x + 1) / 2)."""
    return 0.5 * (x + 1)

@torch.jit.script
def fixup_reward_jit(x):
    """TorchScript version of fixup_reward for tensors: elementwise (x + 1) / 2."""
    return (x + 1) * 0.5

def unroll_trajectory_fast(encoded_trajectories: list[trajectory.EncodedTrajectory]):
    """Yield a TrajectoryStep for every transition in each encoded trajectory.

    Each trajectory's states and actions are converted to numpy and wrapped as torch
    tensors once; every step then indexes into those tensors instead of building new
    tensors per step. The reward label is ``fixup_reward(t.final_rewards[0])``
    (i.e. ``(x + 1) / 2``) and is identical for every step of a trajectory.

    Args:
        encoded_trajectories: trajectories exposing ``states``, ``actions``,
            ``final_rewards`` and ``length``.

    Yields:
        TrajectoryStep(move_index, state, action, next_state, reward) for
        ``move_index`` in ``range(t.length - 1)``. ``state``/``next_state`` are float32
        views into the trajectory's state tensor, ``action`` is an int32 tensor indexed
        from the action tensor, and ``reward`` is a float32 scalar tensor. One reward
        tensor object is shared by all steps of a trajectory; treat it as read-only.
    """
    for t in encoded_trajectories:
        # np.array copies into fresh, writeable host arrays that torch.from_numpy can wrap.
        states = torch.from_numpy(np.array(t.states, dtype=np.float32))
        actions = torch.from_numpy(np.array(t.actions, dtype=np.int32))
        # Convert to numpy first so fixup_reward operates on a plain numpy value.
        reward = torch.tensor(fixup_reward(np.asarray(t.final_rewards[0])), dtype=torch.float32)

        for i in range(t.length - 1):
            yield TrajectoryStep(i, states[i], actions[i], states[i + 1], reward)

# all_trajectory_steps_fast = list(unroll_trajectory_fast(trajectories))
# print(f'num_trajectory_steps: {len(all_trajectory_steps_fast)}')


In [21]:
def do_stuff(n=3):
    """Profiling harness: fully unroll all loaded trajectories `n` times.

    Each pass materializes the generator with ``list`` so every step is really
    produced; the result is discarded.
    """
    for _ in range(n):
        # Consume the generator completely (avoids shadowing the builtin `iter`).
        list(unroll_trajectory_fast(trajectories))
        

%lprun -f unroll_trajectory_fast do_stuff(1)
# %lprun -f do_stuff -f unroll_trajectory_fast do_stuff(1)
# %memit do_stuff(1)

xxxxxxxxxxxxxxx


TypeError: the read only flag is not supported, should always be False

In [27]:
# %load_ext memory_profiler
# # %memit do_stuff(1)
# # %mprun -f do_stuff do_stuff(1)

# import memory_profiler
# # %mprun -f do_stuff memory_profiler.profile(do_stuff)(1)

# @memory_profiler.profile
# def do_stuff(n=3):
#     for _ in range(n):
#         iter = unroll_trajectory_fast(trajectories)
#         xx = list(iter)
#         ll = len(xx)

#  %mprun -f do_stuff do_stuff(1)
# # do_stuff(1)

The memory_profiler extension is already loaded. To reload it, use:
  %reload_ext memory_profiler
ERROR: Could not find file /tmp/ipykernel_661/3115989386.py
ERROR: Could not find file /tmp/ipykernel_661/1455663398.py
xxxxxxxxxxxxxxx



Filename: /workspaces/rgi/.venv/lib/python3.12/site-packages/memory_profiler.py

Line #    Mem usage    Increment  Occurrences   Line Contents
  1185   1148.7 MiB   1148.7 MiB           2               @wraps(wrapped=func)
  1186                                                     def wrapper(*args, **kwargs):
  1187   1148.7 MiB      0.0 MiB           2                   prof = get_prof()
  1188   1148.7 MiB      0.0 MiB           2                   val = prof(func)(*args, **kwargs)
  1189   1148.7 MiB      0.0 MiB           2                   show_results_bound(prof)
  1190   1148.7 MiB      0.0 MiB           2                   return val

In [None]:
# Profile data loading code.
import cProfile
import pstats

# The context manager disables the profiler on exit even if unrolling raises, so the
# kernel is not left profiling every later cell.
with cProfile.Profile() as profiler:
    # Materialize every step once; later cells read `all_trajectory_steps`.
    all_trajectory_steps = list(unroll_trajectory_fast(trajectories))

# pip install snakeviz
# snakeviz notebooks/profile_results.prof
# profiler.dump_stats('profile_results.prof')  # 38.6s
profiler.dump_stats('profile_results_2.prof')

stats = pstats.Stats(profiler).sort_stats('cumulative')
stats.print_stats(100)

In [None]:
# Prepare batches: stack each per-step field into one batched tensor.
batch_full = {
    key: torch.stack([getattr(step, key) for step in all_trajectory_steps])
    for key in ('state', 'action', 'reward')
}
state_batch = batch_full['state']
action_batch = batch_full['action']
reward_batch = batch_full['reward']

print(f'state shape:  {batch_full["state"].shape}')
print(f'action shape: {batch_full["action"].shape}')
print(f'reward shape: {batch_full["reward"].shape}')


# Save/load the processed tensors in PyTorch format

In [None]:
# Stack each TrajectoryStep field across all steps, then persist the four tensors together.
states, actions, next_states, rewards = (
    torch.stack([getattr(step, field) for step in all_trajectory_steps])
    for field in ('state', 'action', 'next_state', 'reward')
)

torch.save(
    dict(states=states, actions=actions, next_states=next_states, rewards=rewards),
    'processed_trajectories.pt',
)

In [None]:
# Reload the saved tensors (weights_only=True: plain tensors/containers only).
loaded_data = torch.load('processed_trajectories.pt', weights_only=True)
reloaded_states, reloaded_actions, reloaded_next_states, reloaded_rewards = (
    loaded_data[key] for key in ('states', 'actions', 'next_states', 'rewards')
)

In [None]:
from torch.utils.data import Dataset

class TrajectoryDataset(Dataset):
    """Map-style dataset over four parallel per-step sequences.

    Item ``idx`` is ``(state, action, next_state, reward)``; the length is the
    number of states.
    """

    def __init__(self, states, actions, next_states, rewards):
        self.states, self.actions = states, actions
        self.next_states, self.rewards = next_states, rewards

    def __len__(self):
        return len(self.states)

    def __getitem__(self, idx):
        columns = (self.states, self.actions, self.next_states, self.rewards)
        return tuple(column[idx] for column in columns)

# Wrap the stacked tensors in a Dataset and let a shuffling DataLoader batch them.
from torch.utils.data import DataLoader

dataset = TrajectoryDataset(states, actions, next_states, rewards)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Define Models

In [None]:
class Connect4StateEmbedder(nn.Module):
    """Embed encoded Connect 4 states with a small CNN followed by a 2-layer MLP.

    The last entry of each encoded state row is dropped and the remaining values are
    reshaped into a single-channel 6x7 grid before the convolutions.
    """

    def __init__(self, embedding_dim=64, hidden_dim=256):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim

        # Submodules are created in this order and keep these attribute names so that
        # parameter initialisation and saved state_dict keys stay the same.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.flatten = nn.Flatten()
        # 3x3 kernels with padding=1 keep the 6x7 spatial size, hence 64 * 6 * 7 features.
        self.linear1 = nn.Linear(64 * 6 * 7, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, embedding_dim)

    def _state_to_array(self, encoded_state_batch):
        # Drop the trailing entry of every row, then view the rest as (N, 1, 6, 7).
        grid_values = encoded_state_batch[:, :-1]
        return torch.reshape(grid_values, (-1, 1, 6, 7))

    def forward(self, encoded_state_batch):
        h = self._state_to_array(encoded_state_batch)
        for conv in (self.conv1, self.conv2):
            h = torch.relu(conv(h))
        h = torch.relu(self.linear1(self.flatten(h)))
        return self.linear2(h)

In [None]:
def test_state_embedder():
    """Smoke-test Connect4StateEmbedder on the first two stacked states and print the results."""
    embedder = Connect4StateEmbedder().to(device)
    first_two = state_batch[:2].to(device)
    embedded = embedder(first_two)
    print("Sample states shape:", first_two.shape)
    print("State embeddings shape:", embedded.shape)
    print("First few values of embeddings:")
    print(embedded[:, :5])

test_state_embedder()

In [None]:
class Connect4ActionEmbedder(nn.Module):
    """Learned embedding table with one row per action."""

    def __init__(self, embedding_dim=64, num_actions=7):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_actions = num_actions
        self.embedding = nn.Embedding(num_actions, embedding_dim)

    def forward(self, action):
        # Actions are shifted down by one before lookup (they appear to be 1-based ids).
        row_index = action - 1
        return self.embedding(row_index)

    def all_action_embeddings(self):
        """Return the full embedding weight Parameter (one row per action)."""
        return self.embedding.weight

In [None]:
# After defining Connect4ActionEmbedder
def test_action_embedder():
    """Smoke-test Connect4ActionEmbedder on the first stacked action and print the table."""
    embedder = Connect4ActionEmbedder().to(device)
    first_action = action_batch[:1].to(device)
    looked_up = embedder(first_action)
    print("Action embeddings shape:", looked_up.shape)
    print("Action embeddings:")
    print(looked_up)
    print("All action embeddings:")
    print(embedder.all_action_embeddings()[:, :5])

test_action_embedder()

In [None]:
class PredictionModel(nn.Module):
    """Action-logit and reward heads on top of a state embedder and an action embedder.

    Action logits are dot products between each state's embedding and every row of
    ``action_embedder.all_action_embeddings()``; the reward prediction is a linear
    read-out of the state embedding.
    """

    def __init__(self, state_embedder, action_embedder, embedding_dim=64, num_actions=7):
        super().__init__()
        self.state_embedder = state_embedder
        self.action_embedder = action_embedder
        self.embedding_dim = embedding_dim
        self.num_actions = num_actions  # stored only; not read by the methods below
        self.reward_head = nn.Linear(embedding_dim, 1)

    def action_logits(self, state_batch):
        """Score every action embedding against each state embedding."""
        embedded_states = self.state_embedder(state_batch)
        action_table = self.action_embedder.all_action_embeddings()
        return embedded_states @ action_table.T

    def action_probs(self, state_batch):
        """Softmax of action_logits over the last dimension."""
        return self.action_logits(state_batch).softmax(dim=-1)

    def reward_pred(self, state_batch):
        """Linear reward read-out of the state embedding."""
        return self.reward_head(self.state_embedder(state_batch))

In [None]:
# After defining PredictionModel
def test_prediction_model():
    """Smoke-test PredictionModel: print logits, probabilities and reward predictions for two states."""
    model = PredictionModel(
        Connect4StateEmbedder().to(device),
        Connect4ActionEmbedder().to(device),
    ).to(device)

    two_states = state_batch[:2].to(device)

    logits = model.action_logits(two_states)
    probs = model.action_probs(two_states)
    predicted_rewards = model.reward_pred(two_states)

    print("Action logits shape:", logits.shape)
    print("Action logits:")
    print(logits)
    print("\nAction probabilities:")
    print(probs)
    print("\nReward predictions:")
    print(predicted_rewards)

test_prediction_model()

# Define Loss Function

In [None]:
# Loss function
def loss_fn(prediction_model, batch, l2_weight=1e-4):
    """Combined action cross-entropy + reward MSE + L2 penalty for one batch.

    Args:
        prediction_model: model exposing ``action_logits(states)``,
            ``reward_pred(states)`` and ``parameters()``.
        batch: dict with ``'state'``, ``'action'`` and ``'reward'`` tensors. Actions
            are shifted down by one to form class indices (they appear to be 1-based).
        l2_weight: coefficient of the sum-of-squares penalty over all parameters.

    Returns:
        ``(total_loss, (action_logits, reward_pred))``. ``reward_pred`` has its last
        dimension squeezed (shape ``(batch,)`` if ``reward_pred`` returns ``(batch, 1)``).
    """
    action_logits = prediction_model.action_logits(batch['state'])
    action_labels = (batch['action'] - 1).long()  # cross_entropy needs int64 class indices
    action_data_loss = nn.functional.cross_entropy(action_logits, action_labels)

    # squeeze(-1), not squeeze(): with a batch of one, squeeze() would collapse the
    # prediction to a 0-d tensor that mse_loss then broadcasts against (1,) labels.
    reward_pred = prediction_model.reward_pred(batch['state']).squeeze(-1)
    reward_labels = batch['reward']
    reward_data_loss = nn.functional.mse_loss(reward_pred, reward_labels)

    l2_loss = sum((p ** 2).sum() for p in prediction_model.parameters())

    total_loss = action_data_loss + reward_data_loss + l2_weight * l2_loss
    return total_loss, (action_logits, reward_pred)

# Training function
def train_step(prediction_model, optimizer, batch):
    """Run one optimisation step on `batch`.

    Returns:
        ``(loss_value, action_logits, reward_pred)``, with the loss as a Python float.
    """
    prediction_model.train()
    optimizer.zero_grad()
    loss, (logits, reward_pred) = loss_fn(prediction_model, batch)
    loss.backward()
    optimizer.step()
    return loss.item(), logits, reward_pred

In [None]:
import torch.utils.data as data

# Prepare dataset
# NOTE(review): this redefines TrajectoryDataset from the earlier "Save/Load" section
# with a 3-field (state, action, reward) signature; the earlier 4-field class is shadowed.
class TrajectoryDataset(data.Dataset):
    """Map-style dataset yielding ``(state, action, reward)`` for each index."""

    def __init__(self, states, actions, rewards):
        self.states, self.actions, self.rewards = states, actions, rewards

    def __len__(self):
        return len(self.states)

    def __getitem__(self, idx):
        return tuple(column[idx] for column in (self.states, self.actions, self.rewards))

# Create dataset and dataloader
dataset = TrajectoryDataset(state_batch, action_batch, reward_batch)
batch_size = 64  # Adjust this value based on your GPU memory
dataloader = data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Main training loop
def train_model(print_logits=False, num_epochs=10, loader=None, lr=0.0005,
                log_every=100, num_examples=4):
    """Train a freshly initialised PredictionModel and return it.

    Args:
        print_logits: after each epoch, print true vs. predicted reward and the action
            probabilities for up to ``num_examples`` states from the epoch's last batch.
        num_epochs: number of passes over the data.
        loader: iterable of ``(states, actions, rewards)`` batches with ``len()``;
            defaults to the module-level ``dataloader`` (looked up at call time).
        lr: Adam learning rate.
        log_every: print the current batch loss every this many steps.
        num_examples: how many states to show when ``print_logits`` is set.

    Returns:
        The trained PredictionModel, on ``device``.
    """
    if loader is None:
        loader = dataloader
    state_embedder = Connect4StateEmbedder().to(device)
    action_embedder = Connect4ActionEmbedder().to(device)
    prediction_model = PredictionModel(state_embedder, action_embedder).to(device)
    optimizer = optim.Adam(prediction_model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        total_loss = 0.0
        num_batches = 0
        for i, (states, actions, rewards) in enumerate(loader):
            states, actions, rewards = states.to(device), actions.to(device), rewards.to(device)
            batch = {'state': states, 'action': actions, 'reward': rewards}

            # zero_grad / loss / backward / step via the shared helper.
            loss, _, _ = train_step(prediction_model, optimizer, batch)
            total_loss += loss
            num_batches += 1

            if (i + 1) % log_every == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(loader)}], Loss: {loss:.4f}')

        # max(..., 1) avoids ZeroDivisionError on an empty loader.
        avg_loss = total_loss / max(num_batches, 1)
        print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}')

        if print_logits and num_batches:
            # The last shuffled batch may have fewer than `num_examples` rows, so clamp the
            # range; no_grad avoids building an autograd graph just for display.
            with torch.no_grad():
                for j in range(min(num_examples, len(states))):
                    state = states[j].unsqueeze(0)
                    action_probs = prediction_model.action_probs(state)
                    reward_pred = prediction_model.reward_pred(state).item()
                    reward_true = rewards[j].item()
                    print(j, f'r={reward_true:.4f} p={reward_pred:.4f}', action_probs.cpu().numpy())
            print()

    return prediction_model

In [None]:
# Train the model
# Ten epochs, printing sample predictions after each one.
prediction_model = train_model(num_epochs=10, print_logits=True)

In [None]:
# Save the model
# Persist only the parameters (state_dict), not the pickled module object.
torch.save(obj=prediction_model.state_dict(), f='connect4_prediction_model.pth')

In [None]:
# Load the model
# map_location remaps tensors (which may have been saved from a CUDA model) onto
# `device`, so this also works on CPU-only machines; weights_only=True restricts
# unpickling to tensors and plain containers.
loaded_model = PredictionModel(Connect4StateEmbedder(), Connect4ActionEmbedder())
loaded_model.load_state_dict(
    torch.load('connect4_prediction_model.pth', map_location=device, weights_only=True)
)
loaded_model = loaded_model.to(device)  # keep the model on the same device as its inputs
loaded_model.eval()

In [None]:
# Test the loaded model
from rgi.games import connect4
import numpy as np

game = connect4.Connect4Game()
serializer = connect4.Connect4Serializer()


def _predict(model, encoded_state):
    """Return (reward prediction, action probabilities as numpy) for one batched state, without autograd."""
    with torch.no_grad():
        reward = model.reward_pred(encoded_state).item()
        probs = model.action_probs(encoded_state).cpu().numpy()
    return reward, probs


print('Move 0:')
s_0 = game.initial_state()
state_array = serializer.state_to_jax_array(game, s_0)
writeable_state = np.array(state_array, dtype=np.float32)  # Create a writeable copy
j_0 = torch.tensor(writeable_state, dtype=torch.float32).unsqueeze(0).to(device)

reward_trained, probs_trained = _predict(prediction_model, j_0)
print(j_0, reward_trained, probs_trained)

# Reload the saved weights directly onto `device` and compare with the in-memory model.
loaded_model = PredictionModel(Connect4StateEmbedder(), Connect4ActionEmbedder())
loaded_model.load_state_dict(
    torch.load('connect4_prediction_model.pth', map_location=device, weights_only=True)
)
loaded_model = loaded_model.to(device)
loaded_model.eval()

reward_loaded, probs_loaded = _predict(loaded_model, j_0)
print(j_0, reward_loaded, probs_loaded)

# The reloaded weights should reproduce the trained model's outputs.
assert np.isclose(reward_loaded, reward_trained, atol=1e-5)
assert np.allclose(probs_loaded, probs_trained, atol=1e-5)