# Model PreTraining

<b>Purpose:</b> This notebook gives the RLHF model a starting point of knowledge about the game by pre-training it on annotated gameplay.

In [None]:
import os
import json
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import mlflow
import mlflow.pytorch
import importlib_metadata
from google.cloud import storage
from dotenv import load_dotenv
from io import BytesIO

print("Packages imported!")

# Choose the compute device: Apple MPS first, then CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# NOTE(review): the auto-detected device above is unconditionally overridden here,
# so every cell in this notebook runs on CPU. Remove this line to use MPS/CUDA.
device = 'cpu'
print(f"Device: {device}")

# Pull variables from a local .env file into the process environment.
load_dotenv()
# Read for reference only; nothing in this cell uses the value directly.
gcs_credentials_path = os.getenv('GOOGLE_APPLICATION_CREDENTIALS')

# Module-level GCS client shared by the dataset classes below.
client = storage.Client()
print(f"GCS Client Initialized: {client}")


In [None]:
# Map each annotated button name (the JSON "action" value) to its class index.
# The index is the button's position in the tuple below, so the order matters:
# changing it changes the meaning of every trained model output.
ACTION_MAPPING = {
    button: index
    for index, button in enumerate(
        ("a", "b", "x", "y", "up", "down", "left", "right", "none")
    )
}

### Creating the Data Loader

The dataloader is unusual in that each data point is a sequence of several consecutive frames (with their actions) rather than a single frame. <br>
This gives the model memory of recent context, which is vital for a story-driven game like this.

In [None]:
class PokemonDatasetGCS(Dataset):
    """Sliding-window dataset of gameplay frames and actions stored in GCS.

    Item ``idx`` is the window of ``seq_length`` consecutive ``.jpg`` frames under
    ``states_prefix`` starting at sorted position ``idx``, paired with the ``.json``
    action files at the same sorted positions under ``actions_prefix``. Frames and
    actions are matched purely by sorted position, so the two prefixes are assumed
    to hold one-to-one, identically ordered files.

    ``__getitem__`` returns ``(states, actions)``, where ``states`` is the
    ``torch.stack`` of the processed frames and ``actions`` is a ``long`` tensor of
    action ids. It returns ``None`` when any file in the window cannot be downloaded,
    decoded, or mapped through ``ACTION_MAPPING``. The DataLoader therefore needs a
    ``collate_fn`` that drops ``None`` samples.

    Args:
        bucket_name: GCS bucket name, opened with the module-level ``client``.
        states_prefix: Blob prefix of the frame images.
        augmentation: Optional callable applied to each RGB PIL image before ``transform``.
        actions_prefix: Blob prefix of the per-frame action JSON files.
        seq_length: Number of consecutive frames/actions per item.
        transform: Optional callable applied after ``augmentation``. It must return
            tensors for ``torch.stack`` to work.
    """

    def __init__(self, bucket_name, states_prefix, augmentation, actions_prefix, seq_length, transform=None):
        self.bucket = client.bucket(bucket_name)               # GCS bucket
        self.states_prefix = states_prefix                     # Prefix for states (images) in GCS
        self.actions_prefix = actions_prefix                   # Prefix for actions (JSONs) in GCS
        self.transform = transform                             # Any image transformations
        self.augmentation = augmentation                       # Any image augmentations

        # Fetch the list of blobs (files) in each folder
        self.states = sorted([blob.name for blob in self.bucket.list_blobs(prefix=states_prefix) if blob.name.endswith('.jpg')])
        self.actions = sorted([blob.name for blob in self.bucket.list_blobs(prefix=actions_prefix) if blob.name.endswith('.json')])

        if len(self.states) != len(self.actions):
            # Pairing is positional, so a count mismatch usually means misaligned data.
            print(f"Warning: found {len(self.states)} images but {len(self.actions)} action files; "
                  f"only the first {min(len(self.states), len(self.actions))} pairs are used.")

        self.seq_length = seq_length                           # Desired sequence length

    def __len__(self):
        # A window starting at idx covers positions idx .. idx + seq_length - 1, so the
        # last valid start is n - seq_length, giving n - seq_length + 1 windows.
        n_pairs = min(len(self.states), len(self.actions))
        return max(0, n_pairs - self.seq_length + 1)

    def __getitem__(self, idx):
        state_seq = []
        action_seq = []

        for i in range(self.seq_length):
            # Resolve names up front so the error messages below cannot fail themselves.
            state_name = self.states[idx + i]
            action_name = self.actions[idx + i]

            # Load the frame. Any failure (network, decode, transform) skips the window.
            try:
                image_data = self.bucket.blob(state_name).download_as_bytes()
                image = Image.open(BytesIO(image_data)).convert('RGB')  # Ensure image is in RGB format

                if self.augmentation:
                    image = self.augmentation(image)
                if self.transform:
                    image = self.transform(image)
            except Exception as e:
                print(f"Error downloading image {state_name}: {e}")
                return None  # Skip this sample if there's an error
            state_seq.append(image)

            # Load the action label. Download/JSON/unknown-action failures skip the window.
            try:
                action_data = json.loads(self.bucket.blob(action_name).download_as_text())
                action = action_data.get('action', None)

                if action is None:
                    print(f"Warning: 'action' key missing in {action_name}")
                    return None  # Skip this sample if the action is missing
                action_seq.append(ACTION_MAPPING[action])
            except Exception as e:
                print(f"Error downloading or decoding action {action_name}: {e}")
                return None  # Skip this sample if there's an error

        state_seq_tensor = torch.stack(state_seq)
        action_seq_tensor = torch.tensor(action_seq, dtype=torch.long)

        return state_seq_tensor, action_seq_tensor


In [None]:
class PokemonDataset(Dataset):
    """Sliding-window dataset of gameplay frames and actions read from local folders.

    Item ``idx`` is the window of ``seq_length`` consecutive ``.jpg`` frames from
    ``states_dir`` starting at sorted position ``idx``, paired with the ``.json``
    action files at the same sorted positions in ``actions_dir``. Frames and actions
    are matched purely by sorted position, so both folders are assumed to contain
    one-to-one, identically ordered files.

    ``__getitem__`` returns ``(states, actions)``, where ``states`` is the
    ``torch.stack`` of the transformed frames and ``actions`` is a ``long`` tensor of
    action ids. It returns ``None`` when any action file in the window is empty, is not
    valid JSON, lacks an ``'action'`` key, or names an action missing from
    ``ACTION_MAPPING``. Use a ``collate_fn`` that drops ``None`` samples.

    Args:
        states_dir: Folder containing the frame images.
        actions_dir: Folder containing the per-frame action JSON files.
        seq_length: Number of consecutive frames/actions per item.
        transform: Optional callable applied to each RGB PIL image. It must return
            tensors for ``torch.stack`` to work.
    """

    def __init__(self, states_dir, actions_dir, seq_length, transform=None):
        self.states_dir = states_dir                            # Reading in the states directory
        self.actions_dir = actions_dir                          # Reading in the actions directory
        self.transform = transform                              # Any image transformations

        # Filter files to only include valid images / action files
        self.states = sorted([f for f in os.listdir(states_dir) if f.endswith('.jpg')])
        self.actions = sorted([f for f in os.listdir(actions_dir) if f.endswith('.json')])

        if len(self.states) != len(self.actions):
            # Pairing is positional, so a count mismatch usually means misaligned data.
            print(f"Warning: found {len(self.states)} images but {len(self.actions)} action files; "
                  f"only the first {min(len(self.states), len(self.actions))} pairs are used.")

        self.seq_length = seq_length                            # Reading in the desired seq length

    def __len__(self):
        # A window starting at idx covers positions idx .. idx + seq_length - 1, so the
        # last valid start is n - seq_length, giving n - seq_length + 1 windows.
        n_pairs = min(len(self.states), len(self.actions))
        return max(0, n_pairs - self.seq_length + 1)

    def _load_action(self, action_path):
        """Return the integer action id stored in ``action_path``, or None if it is unusable."""
        if os.stat(action_path).st_size == 0:
            print(f"Warning: {action_path} is empty. Skipping this file.")
            return None

        try:
            with open(action_path, 'r') as f:
                action_data = json.load(f)
        except json.JSONDecodeError as e:
            print(f"Warning: Failed to decode JSON in {action_path}: {e}")
            return None

        action = action_data.get('action', None)
        if action is None:
            print(f"Warning: 'action' key missing in {action_path}")
            return None
        if action not in ACTION_MAPPING:
            print(f"Warning: unknown action {action!r} in {action_path}")
            return None
        return ACTION_MAPPING[action]

    def __getitem__(self, idx):
        state_seq = []
        action_seq = []

        for i in range(self.seq_length):
            # Load the frame; `with` closes the underlying file once converted.
            img_path = os.path.join(self.states_dir, self.states[idx + i])
            with Image.open(img_path) as img:
                image = img.convert('RGB')  # Ensure image is in RGB format
            if self.transform:
                image = self.transform(image)
            state_seq.append(image)

            action_id = self._load_action(os.path.join(self.actions_dir, self.actions[idx + i]))
            if action_id is None:
                # A window with a missing label cannot become a LongTensor; skip the sample.
                return None
            action_seq.append(action_id)

        state_seq_tensor = torch.stack(state_seq)
        action_seq_tensor = torch.tensor(action_seq, dtype=torch.long)

        return state_seq_tensor, action_seq_tensor

Initialise Dataloader

In [None]:
from torch.utils.data.dataloader import default_collate

# Chosen transformation: resize every frame to 640x640 and convert it to a tensor
transform = transforms.Compose([
    transforms.Resize((640, 640)),
    transforms.ToTensor()
])

# Chosen augmentations, applied to the PIL image before `transform`
augmentation = transforms.Compose([
    transforms.ColorJitter(brightness=0.1, contrast=0.1), # Random colour augmentation
    transforms.RandomRotation(degrees=30),                # Random rotation
])

seq_length = 3
# Set the GCS bucket and folder prefixes
bucket_name = 'pokemonplatinumai-annotationimages'
states_prefix = 'phase-1/images/'
actions_prefix = 'phase-1/actions/'


def skip_none_collate(batch):
    """Collate a batch after dropping samples the dataset returned as ``None``.

    The datasets return ``None`` for windows with unreadable data, and
    ``default_collate`` raises on ``None``. If every sample in the batch failed,
    this returns ``None`` so the training loop's ``batch is None`` check can skip it.
    """
    batch = [sample for sample in batch if sample is not None]
    if not batch:
        return None
    return default_collate(batch)


dataset = PokemonDatasetGCS(
    bucket_name=bucket_name,
    states_prefix=states_prefix,
    actions_prefix=actions_prefix,
    seq_length=seq_length,
    transform=transform,
    augmentation=augmentation
)
# Windows are served in sorted blob order (shuffle=False); failed samples are dropped.
dataloader = DataLoader(dataset, batch_size=2, shuffle=False, collate_fn=skip_none_collate)

print("DataLoader loaded!")

### Model Training

Initialise the model

In [None]:
from models.PokemonModelLSTM import PokemonModelLSTM

# Setting Hyperparameters
# One output class per entry of ACTION_MAPPING: [A, B, X, Y, Up, Down, Left, Right, None]
# (Start, Select, L and R are excluded to reduce model complexity). Deriving the count
# keeps the model head in sync with the label mapping used by the datasets.
num_actions = len(ACTION_MAPPING)
# NOTE(review): presumably the flattened per-frame feature size inside PokemonModelLSTM;
# confirm against the model definition if the transform's image size changes.
input_size = 32 * 160 * 160
hidden_size = 128
num_layers = 2
num_epochs = 20
learning_rate = 0.003

# Initialising model, optimiser and loss
model = PokemonModelLSTM(input_size, hidden_size, num_layers, num_actions).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()


In [None]:
# Resume from a checkpoint if one exists.
# The training cells in this notebook save dicts with 'model_state_dict' and
# 'optimizer_state_dict'; a bare model state_dict is also accepted.
# NOTE(review): no visible cell writes this exact path (they write
# *_epoch_N.pth / *_final.pth); confirm which file is meant to be resumed.
resume_path = "models/pokemon_model_lstm.pth"
if os.path.exists(resume_path):
    # map_location lets a checkpoint saved on another device load on the current one.
    saved_state = torch.load(resume_path, map_location=device)
    if isinstance(saved_state, dict) and 'model_state_dict' in saved_state:
        model.load_state_dict(saved_state['model_state_dict'])
        if 'optimizer_state_dict' in saved_state:
            optimizer.load_state_dict(saved_state['optimizer_state_dict'])
    else:
        model.load_state_dict(saved_state)

### Phase 1

In [None]:
# Phase 1 training loop: predict the last action of each window from its frames.
os.makedirs("models/phase1", exist_ok=True)  # torch.save does not create parent directories

avg_loss = float('nan')  # Stays NaN if there are no epochs / no usable batches
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0  # Sum of batch losses for the current epoch
    num_batches = 0   # Batches actually trained on (skipped ones excluded)

    for i, batch in enumerate(dataloader):
        # A None batch means every sample in it failed to load
        if batch is None:
            print(f"Skipping batch {i} due to missing data.")
            continue
        print(f"Processing epoch {epoch}, batch {i}")

        state_seq, action_seq = batch
        state_seq, action_seq = state_seq.to(device), action_seq.to(device)

        optimizer.zero_grad()

        # Forward pass
        output = model(state_seq)

        # Calculate loss against the last action in each window
        action_seq = action_seq[:, -1]
        loss = criterion(output, action_seq)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    # Average over trained batches only; NaN if every batch was skipped
    avg_loss = epoch_loss / num_batches if num_batches else float('nan')

    print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss}")

    # Save the model, optimizer state, and epoch-average loss after each epoch
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': avg_loss,
    }
    model_path = f"models/phase1/pokemon_model_lstm_epoch_{epoch+1}.pth"
    torch.save(checkpoint, model_path)

# Save final model with all relevant states
final_checkpoint = {
    'epoch': num_epochs,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': avg_loss,
}

final_model_path = "models/pokemon_model_lstm_final.pth"
torch.save(final_checkpoint, final_model_path)

print("Model checkpoints saved!")

### Testing the model

In [None]:
from google.cloud import storage
from dotenv import load_dotenv
from io import BytesIO
import os
from torchvision import transforms
from PIL import Image
from RLHF_Scripts.modular_scripts.rlhf_utils import ACTION_MAP_DIALOGUE, REVERSED_ACTION_MAPPING

# Load environment variables from .env file
load_dotenv()
gcs_credentials_path = os.getenv('GOOGLE_APPLICATION_CREDENTIALS')

# Initialize GCS Client (a single client is enough for this cell)
client = storage.Client()
print(f"GCS Client Initialized: {client}")

# Define the GCS bucket and file path
bucket_name = 'pokemonplatinumai-annotationimages'
file_path = 'phase-1/images/2024-09-01_21-41-17_png.rf.8afcea2c3ddb01cfa1caa95e8777b066.jpg'

# Same deterministic preprocessing as the training `transform` (no augmentation at inference)
transform = transforms.Compose([
    transforms.Resize((640,640)),  # Resize image
    transforms.ToTensor(),  # Convert image to tensor
])

bucket = client.get_bucket(bucket_name)
blob = bucket.blob(file_path)

# Download the image as a byte stream and open it. The training datasets convert every
# frame to RGB; do the same so the tensor has the channel layout the model was trained on.
image_data = blob.download_as_bytes()
image = Image.open(BytesIO(image_data)).convert('RGB')

# Apply transformations
image = transform(image)
# (C, H, W) -> (1, 1, C, H, W): a batch of one window holding a single frame.
# NOTE(review): training used windows of seq_length frames; confirm a 1-frame window is intended.
image = image.unsqueeze(0).unsqueeze(0)


In [None]:
import torch
from models.PokemonModelLSTM import PokemonModelLSTM

# Hyperparameters: the architecture values must match those of the checkpoint,
# since load_state_dict below requires matching parameter shapes.
num_actions = 9  # Total number of actions
input_size = 32 * 160 * 160
hidden_size = 128
num_layers = 2
num_epochs = 20
learning_rate = 0.001

# Build the network on CPU and restore the weights from a Phase 1 epoch checkpoint.
checkpoint_path = "models/phase1/pokemon_model_lstm_epoch_5.pth"
model = PokemonModelLSTM(input_size, hidden_size, num_layers, num_actions)
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # Inference mode (e.g. disables dropout)
print('Model loaded')

# Run the prepared image tensor through the model without tracking gradients.
with torch.no_grad():
    logits = model(image)
predicted_action = logits.argmax(dim=1)  # Highest-scoring action index per batch item

# Translate the class index into its button name, then into the dialogue action.
action_key = REVERSED_ACTION_MAPPING[predicted_action.item()]
action = ACTION_MAP_DIALOGUE[action_key]
print(f'Prediction made: {action}')

In [None]:
# Initialize MLflow experiment
mlflow.set_experiment("PokemonModelLSTM_3")

# Train the current `model` on `device`. Other cells rebind `model` (the inference
# cell loads a fresh CPU copy) without rebuilding `optimizer`. An optimizer holding a
# different model's parameters would update that model and leave this one untrained,
# so rebuild it whenever the parameter sets differ.
model = model.to(device)
model_param_ids = {id(p) for p in model.parameters()}
optimizer_param_ids = {id(p) for group in optimizer.param_groups for p in group['params']}
if model_param_ids != optimizer_param_ids:
    print("Optimizer does not hold the current model's parameters; creating a new Adam optimizer.")
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

os.makedirs("models", exist_ok=True)  # torch.save does not create parent directories

# Start MLflow run
with mlflow.start_run():
    # Log parameters. The learning rate is read from the optimizer so the logged value
    # is the one actually used, even if `learning_rate` was redefined in a later cell.
    mlflow.log_param("num_actions", num_actions)
    mlflow.log_param("input_size", input_size)
    mlflow.log_param("hidden_size", hidden_size)
    mlflow.log_param("num_layers", num_layers)
    mlflow.log_param("learning_rate", optimizer.param_groups[0]['lr'])
    mlflow.log_param("num_epochs", num_epochs)

    avg_loss = float('nan')  # Stays NaN if there are no epochs / no usable batches
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0  # Sum of batch losses for the current epoch
        num_batches = 0   # Batches actually trained on (skipped ones excluded)
        for i, batch in enumerate(dataloader):
            # A None batch means every sample in it failed to load
            if batch is None:
                print(f"Skipping batch {i} due to missing data.")
                continue
            print(f"Currently processing Epoch {epoch}, batch {i}")

            state_seq, action_seq = batch
            state_seq, action_seq = state_seq.to(device), action_seq.to(device)
            optimizer.zero_grad()

            # Forward pass
            output = model(state_seq)

            # Calculate loss against the last action in each window
            action_seq = action_seq[:, -1]
            loss = criterion(output, action_seq)

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1

        # Save the model to MLflow
        print("Saving model to MLflow")
        mlflow.pytorch.log_model(model, f"model_epoch_{epoch+1}")

        # Average over trained batches only; NaN if every batch was skipped
        avg_loss = epoch_loss / num_batches if num_batches else float('nan')
        try:
            mlflow.log_metric("loss", avg_loss, step=epoch)
        except Exception as e:
            print(f"Error: {e}")

        print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss}")

        # Save the model, optimizer state, and epoch-average loss after each epoch
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
        }
        model_path = f"models/pokemon_model_lstm_epoch_{epoch+1}.pth"
        torch.save(checkpoint, model_path)

    # Save final model with all relevant states
    final_checkpoint = {
        'epoch': num_epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': avg_loss,
    }

    final_model_path = "models/pokemon_model_lstm_final.pth"
    torch.save(final_checkpoint, final_model_path)
    mlflow.pytorch.log_model(model, "final_model")

print("Model and metrics logged with MLflow!")

In [None]:
# Minimal stand-in network used only to smoke-test MLflow model logging.
class TestModel(torch.nn.Module):
    """A single linear layer mapping 10 input features to 2 outputs."""

    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)


# NOTE(review): these rebind the notebook-wide `model` and `optimizer`; rerun the
# model-initialisation cell before training again.
model = TestModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


def test_mlflow_logging():
    """Write a dummy checkpoint to disk, then try logging `model` to MLflow in a fresh run."""
    with mlflow.start_run():
        model_path = "models/test_model.pth"
        torch.save(
            dict(
                epoch=1,
                model_state_dict=model.state_dict(),
                optimizer_state_dict=optimizer.state_dict(),
                loss=0.5,  # Placeholder loss value
            ),
            model_path,
        )

        # Report, rather than raise, any MLflow failure
        try:
            mlflow.pytorch.log_model(model, "test_model")
        except Exception as e:
            print(f"Error during MLflow logging: {e}")
        else:
            print("MLflow model logging successful.")


# Run the smoke test immediately when the cell executes
test_mlflow_logging()


In [None]:
def check_versions():
    """Print the installed torch version and smoke-test MLflow model logging.

    Both checks report problems by printing instead of raising.

    Side effects: sets MLflow's tracking URI to the local ``"mlruns"`` directory and
    makes ``"VerificationExperiment"`` the active experiment. These are process-wide
    MLflow settings, so later MLflow calls in the session also use them.
    """
    try:
        # Version as recorded in the installed distribution's metadata
        torch_version = importlib_metadata.version("torch")
        print(f"importlib_metadata found torch version: {torch_version}")
    except importlib_metadata.PackageNotFoundError as e:
        # importlib_metadata.version raises PackageNotFoundError (not KeyError)
        # when no distribution with that name is installed.
        print(f"Error: Could not find version for 'torch': {e}")
    except Exception as e:
        print(f"Unexpected error occurred: {e}")

    # Also check torch version from the package itself
    print(f"torch.__version__: {torch.__version__}")

    try:
        # Simulate a log model call with MLflow to verify functionality
        model = torch.nn.Linear(2, 2)  # Create a simple dummy model for testing
        mlflow.set_tracking_uri("mlruns")
        mlflow.set_experiment("VerificationExperiment")

        with mlflow.start_run():
            mlflow.pytorch.log_model(
                model,
                artifact_path="dummy_model",
                pip_requirements=[f"torch=={torch.__version__}", "cloudpickle==2.0.0"]
            )
            print("MLflow logging successful.")
    except Exception as e:
        print(f"MLflow encountered an error: {e}")

# Run the verification
check_versions()
