# Model PreTraining

<b>Purpose: </b> This notebook pre-trains the RLHF model on annotated gameplay so that it starts with some prior knowledge of the game.

In [3]:
import os
import json
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import mlflow
import mlflow.pytorch
import importlib_metadata
from google.cloud import storage
from dotenv import load_dotenv
from io import BytesIO

print("Packages imported!")

# Pick an accelerator: Apple MPS first, then CUDA, otherwise the CPU.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# NOTE(review): the detected device is unconditionally replaced with CPU on the
# next line, so the detection above currently has no effect. Presumably a
# deliberate override (e.g. an op unsupported on MPS) -- confirm before removing.
device = 'cpu'
print(f"Device: {device}")

# Pull settings such as GOOGLE_APPLICATION_CREDENTIALS from a local .env file
# into the process environment, then read the credentials path back.
load_dotenv()
gcs_credentials_path = os.getenv('GOOGLE_APPLICATION_CREDENTIALS')

# Cloud Storage client shared by the GCS-backed dataset below.
client = storage.Client()
print(f"GCS Client Initialized: {client}")


Packages imported!
Device: cpu




KeyboardInterrupt: 

In [2]:
# Controller input name -> integer class index used as the training label.
# The index is the position in this list: a=0, b=1, ..., none=8.
ACTION_MAPPING = {
    name: index
    for index, name in enumerate(
        ["a", "b", "x", "y", "up", "down", "left", "right", "none"]
    )
}

### Creating the Data Loader

Unlike a typical dataset, each item is not a single frame but a sequence of `seq_length` consecutive frames (with their actions and annotations). <br>
This gives the model memory, which is vital for a story-driven game like this.

In [None]:
class PokemonDatasetGCS(Dataset):
    """Sliding-window dataset of annotated gameplay stored in a GCS bucket.

    States (``.jpg``), actions (``.json``) and annotations (``.txt``) are listed
    under their prefixes, each listing sorted by blob name, and paired purely by
    position: the i-th state, action and annotation are assumed to describe the
    same frame. Item ``idx`` is the window of ``seq_length`` consecutive frames
    starting at position ``idx``.

    Each item is a tuple ``(states, actions, annotations)``:

    * ``states``: ``torch.stack`` of the (transformed) images, so ``transform``
      must produce tensors.
    * ``actions``: ``long`` tensor of length ``seq_length`` holding
      ``ACTION_MAPPING`` indices; unreadable, missing or unknown actions become
      ``MISSING_ACTION``.
    * ``annotations``: ``float32`` tensor with one entry per frame, each frame's
      rows truncated or padded to ``max_annotations`` with ``[-1, 0, 0, 0, 0]``
      rows (so annotation rows are expected to hold 5 values).

    Args:
        bucket_name: GCS bucket name, opened through the module-level ``client``.
        states_prefix: Blob-name prefix of the state images.
        actions_prefix: Blob-name prefix of the action JSON files.
        annotations_prefix: Blob-name prefix of the annotation text files.
        seq_length: Number of consecutive frames per item.
        transform: Optional callable applied to each RGB ``PIL.Image``.
        max_annotations: Annotation rows kept per frame (default 25).
    """

    # Label used for an action that could not be read. -100 is the default
    # ``ignore_index`` of ``nn.CrossEntropyLoss``, so a default-configured
    # criterion ignores these targets instead of crashing on ``None``.
    MISSING_ACTION = -100

    def __init__(self, bucket_name, states_prefix, actions_prefix, annotations_prefix, seq_length, transform=None, max_annotations=25):
        self.bucket = client.bucket(bucket_name)               # GCS bucket
        self.states_prefix = states_prefix                     # Prefix for states (images) in GCS
        self.actions_prefix = actions_prefix                   # Prefix for actions (JSONs) in GCS
        self.annotations_prefix = annotations_prefix           # Prefix for annotations (texts) in GCS
        self.transform = transform                             # Any image transformations
        self.seq_length = seq_length                           # Desired sequence length
        self.max_annotations = max_annotations                 # Annotation rows kept per frame

        # Fetch the sorted list of blob names under each prefix, filtered by file type
        self.states = self._list_blobs(states_prefix, '.jpg')
        self.actions = self._list_blobs(actions_prefix, '.json')
        self.annotations = self._list_blobs(annotations_prefix, '.txt')

    def _list_blobs(self, prefix, suffix):
        """Return the sorted names of blobs under ``prefix`` that end with ``suffix``."""
        return sorted(blob.name for blob in self.bucket.list_blobs(prefix=prefix) if blob.name.endswith(suffix))

    def __len__(self):
        # A window may start at any position that leaves seq_length entries in
        # every listing, i.e. positions 0 .. available - seq_length inclusive.
        available = min(len(self.states), len(self.actions), len(self.annotations))
        return max(0, available - self.seq_length + 1)

    def _load_action(self, blob_name):
        """Download one action JSON and return its label index, or ``MISSING_ACTION``."""
        action_blob = self.bucket.blob(blob_name)
        try:
            action_data = json.loads(action_blob.download_as_text())
        except json.JSONDecodeError as e:
            print(f"Warning: Failed to decode JSON in {action_blob.name}: {e}")
            return self.MISSING_ACTION

        action = action_data.get('action', None) if isinstance(action_data, dict) else None
        if action is None:
            print(f"Warning: 'action' key missing in {action_blob.name}")
            return self.MISSING_ACTION

        try:
            return ACTION_MAPPING[action]
        except (KeyError, TypeError):
            print(f"Warning: unknown action {action!r} in {action_blob.name}")
            return self.MISSING_ACTION

    @staticmethod
    def _parse_annotations(text):
        """Parse whitespace-separated float rows from ``text``, skipping blank lines."""
        # Blank lines would otherwise become empty rows and make the tensor ragged.
        return [[float(value) for value in line.split()] for line in text.splitlines() if line.strip()]

    @staticmethod
    def _pad_annotations(rows, max_annotations):
        """Truncate ``rows`` to ``max_annotations`` or pad with ``[-1, 0, 0, 0, 0]`` rows."""
        rows = rows[:max_annotations]
        return rows + [[-1.0, 0.0, 0.0, 0.0, 0.0] for _ in range(max_annotations - len(rows))]

    def __getitem__(self, idx):
        state_seq = []
        action_seq = []
        annotation_seq = []

        # Gather seq_length consecutive states, actions and annotations
        for i in range(idx, idx + self.seq_length):
            # Load image from GCS and ensure it is RGB
            image_data = self.bucket.blob(self.states[i]).download_as_bytes()
            image = Image.open(BytesIO(image_data)).convert('RGB')
            if self.transform:
                image = self.transform(image)
            state_seq.append(image)

            # Load action label (MISSING_ACTION when unreadable)
            action_seq.append(self._load_action(self.actions[i]))

            # Load, clean and pad this frame's annotations
            annotation_text = self.bucket.blob(self.annotations[i]).download_as_text()
            rows = self._parse_annotations(annotation_text)
            annotation_seq.append(self._pad_annotations(rows, self.max_annotations))

        state_seq_tensor = torch.stack(state_seq)
        action_seq_tensor = torch.tensor(action_seq, dtype=torch.long)
        annotation_seq_tensor = torch.tensor(annotation_seq, dtype=torch.float32)

        return state_seq_tensor, action_seq_tensor, annotation_seq_tensor

In [None]:
class PokemonDataset(Dataset):
    """Sliding-window dataset of annotated gameplay stored in local directories.

    States (``.jpg``), actions (``.json``) and annotations (``.txt``) are listed
    from their directories, each listing sorted by filename, and paired purely
    by position: the i-th state, action and annotation are assumed to describe
    the same frame. Item ``idx`` is the window of ``seq_length`` consecutive
    frames starting at position ``idx``.

    Each item is a tuple ``(states, actions, annotations)``:

    * ``states``: ``torch.stack`` of the (transformed) images, so ``transform``
      must produce tensors.
    * ``actions``: ``long`` tensor of length ``seq_length`` holding
      ``ACTION_MAPPING`` indices; empty, unreadable, missing or unknown actions
      become ``MISSING_ACTION``.
    * ``annotations``: ``float32`` tensor with one entry per frame, each frame's
      rows truncated or padded to ``max_annotations`` with ``[-1, 0, 0, 0, 0]``
      rows (so annotation rows are expected to hold 5 values).

    Args:
        states_dir: Directory of state images.
        actions_dir: Directory of action JSON files.
        annotations_dir: Directory of annotation text files.
        seq_length: Number of consecutive frames per item.
        transform: Optional callable applied to each RGB ``PIL.Image``.
        max_annotations: Annotation rows kept per frame (default 25).
    """

    # Label used for an action that could not be read. -100 is the default
    # ``ignore_index`` of ``nn.CrossEntropyLoss``, so a default-configured
    # criterion ignores these targets instead of crashing on ``None``.
    MISSING_ACTION = -100

    def __init__(self, states_dir, actions_dir, annotations_dir, seq_length, transform=None, max_annotations=25):
        self.states_dir = states_dir                            # Reading in the states directory
        self.actions_dir = actions_dir                          # Reading in the actions directory
        self.annotations_dir = annotations_dir                  # Reading in the annotations directory
        self.transform = transform                              # Any image transformations
        self.seq_length = seq_length                            # Reading in the desired seq length
        self.max_annotations = max_annotations                  # Annotation rows kept per frame

        # Sorted filenames of each data type
        self.states = sorted(f for f in os.listdir(states_dir) if f.endswith('.jpg'))
        self.actions = sorted(f for f in os.listdir(actions_dir) if f.endswith('.json'))
        self.annotations = sorted(f for f in os.listdir(annotations_dir) if f.endswith('.txt'))

    def __len__(self):
        # A window may start at any position that leaves seq_length entries in
        # every listing, i.e. positions 0 .. available - seq_length inclusive.
        available = min(len(self.states), len(self.actions), len(self.annotations))
        return max(0, available - self.seq_length + 1)

    def _load_action(self, action_name):
        """Read one action JSON file and return its label index, or ``MISSING_ACTION``."""
        if os.stat(action_name).st_size == 0:
            print(f"Warning: {action_name} is empty. Skipping this file.")
            return self.MISSING_ACTION

        try:
            with open(action_name, 'r', encoding='utf-8') as f:
                action_data = json.load(f)
        except json.JSONDecodeError as e:
            print(f"Warning: Failed to decode JSON in {action_name}: {e}")
            return self.MISSING_ACTION

        action = action_data.get('action', None) if isinstance(action_data, dict) else None
        if action is None:
            print(f"Warning: 'action' key missing in {action_name}")
            return self.MISSING_ACTION

        try:
            return ACTION_MAPPING[action]
        except (KeyError, TypeError):
            print(f"Warning: unknown action {action!r} in {action_name}")
            return self.MISSING_ACTION

    @staticmethod
    def _parse_annotations(text):
        """Parse whitespace-separated float rows from ``text``, skipping blank lines."""
        # Blank lines would otherwise become empty rows and make the tensor ragged.
        return [[float(value) for value in line.split()] for line in text.splitlines() if line.strip()]

    @staticmethod
    def _pad_annotations(rows, max_annotations):
        """Truncate ``rows`` to ``max_annotations`` or pad with ``[-1, 0, 0, 0, 0]`` rows."""
        rows = rows[:max_annotations]
        return rows + [[-1.0, 0.0, 0.0, 0.0, 0.0] for _ in range(max_annotations - len(rows))]

    def __getitem__(self, idx):
        state_seq = []
        action_seq = []
        annotation_seq = []

        # Gather seq_length consecutive states, actions and annotations. Every
        # frame always contributes exactly one entry to each list so the three
        # tensors stay aligned even when an action file is bad.
        for i in range(idx, idx + self.seq_length):
            # Retrieve image; the context manager closes the underlying file
            img_name = os.path.join(self.states_dir, self.states[i])
            with Image.open(img_name) as img:
                image = img.convert('RGB')
            if self.transform:
                image = self.transform(image)
            state_seq.append(image)

            # Retrieve action label (MISSING_ACTION when empty or unreadable)
            action_name = os.path.join(self.actions_dir, self.actions[i])
            action_seq.append(self._load_action(action_name))

            # Retrieve, clean and pad this frame's annotations
            annotation_name = os.path.join(self.annotations_dir, self.annotations[i])
            with open(annotation_name, 'r', encoding='utf-8') as f:
                rows = self._parse_annotations(f.read())
            annotation_seq.append(self._pad_annotations(rows, self.max_annotations))

        state_seq_tensor = torch.stack(state_seq)
        action_seq_tensor = torch.tensor(action_seq, dtype=torch.long)
        annotation_seq_tensor = torch.tensor(annotation_seq, dtype=torch.float32)

        return state_seq_tensor, action_seq_tensor, annotation_seq_tensor

Initialise Dataloader

In [None]:
# Every frame is resized to 640x640 and converted to a tensor.
transform = transforms.Compose([transforms.Resize((640, 640)), transforms.ToTensor()])

# Number of consecutive frames in each dataset item.
seq_length = 3

# GCS bucket with the annotated gameplay, plus the folder prefix of each data type.
bucket_name = 'pokemonplatinumai-annotationimages'
states_prefix, actions_prefix, annotations_prefix = 'images/', 'actions/', 'labels/'

dataset = PokemonDatasetGCS(
    bucket_name=bucket_name,
    states_prefix=states_prefix,
    actions_prefix=actions_prefix,
    annotations_prefix=annotations_prefix,
    seq_length=seq_length,
    transform=transform,
)
# Batches are served in order (no shuffling), two windows at a time.
dataloader = DataLoader(dataset, batch_size=2, shuffle=False)

print("DataLoader loaded!")

#### Checking Files

Check that the files behind each batch are loaded and parsed correctly.

In [None]:
# Loop through the dataloader: confirm each batch's tensors can be moved to the
# device, then validate every action JSON file behind the samples in the batch.
checked_action_files = set()  # windows overlap, so each file is checked only once
for i, (state_seq, action_seq, annotation_seq) in enumerate(dataloader):
    try:
        # Move tensors to device
        state_seq, action_seq, annotation_seq = state_seq.to(device), action_seq.to(device), annotation_seq.to(device)
        print(f"Processing batch {i}")
    except Exception as e:
        print(f"Error processing batch {i}: {e}")

    # Map batch positions back to dataset items. This relies on the DataLoader
    # iterating in order (the dataloader cell uses shuffle=False): sample `idx`
    # of batch `i` is dataset item i * batch_size + idx, and that item's window
    # spans seq_length consecutive action files starting at the same position.
    for idx in range(len(action_seq)):
        start = i * dataloader.batch_size + idx
        for action_file_name in dataset.actions[start:start + dataset.seq_length]:
            if action_file_name in checked_action_files:
                continue
            checked_action_files.add(action_file_name)

            action_blob = dataset.bucket.blob(action_file_name)  # Get the blob from GCS
            try:
                # Check if the action file is empty
                action_content = action_blob.download_as_text()
                if not action_content:
                    raise ValueError(f"Action file {action_file_name} is empty.")

                # Parse the JSON content and ensure 'action' is present
                action_data = json.loads(action_content)
                if 'action' not in action_data:
                    raise KeyError(f"Key 'action' missing in {action_file_name}")

            except json.JSONDecodeError:
                print(f"Invalid JSON format in file: {action_file_name}")
            except Exception as e:
                print(f"Error processing file {action_file_name}: {e}")


### Model Training

Initialise the model

In [None]:
from models.PokemonModelLSTM import PokemonModelLSTM

# Setting hyperparameters
# One output class per ACTION_MAPPING entry: [A, B, X, Y, Up, Down, Left, Right, None]
# (Start, Select, L and R are excluded to reduce model complexity). Deriving the
# count keeps the classifier head in sync with the label mapping.
num_actions = len(ACTION_MAPPING)
# Feature size handed to PokemonModelLSTM; presumably a 32-channel 160x160 CNN
# feature map for the 640x640 input frames -- TODO confirm against the model.
input_size = 32 * 160 * 160
hidden_size = 128
num_layers = 2
num_epochs = 20
learning_rate = 0.001

# Initialising model, optimizer and loss
model = PokemonModelLSTM(input_size, hidden_size, num_layers, num_actions).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()


In [None]:
# Resume from a saved checkpoint if one exists. Two formats are accepted: a bare
# model state_dict, or a checkpoint dict with 'model_state_dict' and optionally
# 'optimizer_state_dict' (the format the training loop's torch.save writes).
checkpoint_path = "models/pokemon_model_lstm.pth"
if os.path.exists(checkpoint_path):
    # map_location lets a checkpoint saved on another device load onto this one.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
        if 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        model.load_state_dict(checkpoint)

In [None]:
# Initialize MLflow experiment
mlflow.set_experiment("PokemonModelLSTM_Pretrain_3")

# Fail fast rather than opening an empty MLflow run and dividing by zero below.
if len(dataloader) == 0:
    raise ValueError("DataLoader is empty; nothing to train on.")

# Checkpoints are written under models/; create it so torch.save cannot fail
# after a full epoch of training.
os.makedirs("models", exist_ok=True)

# Start MLflow run
with mlflow.start_run():
    # Log parameters
    mlflow.log_param("num_actions", num_actions)
    mlflow.log_param("input_size", input_size)
    mlflow.log_param("hidden_size", hidden_size)
    mlflow.log_param("num_layers", num_layers)
    mlflow.log_param("learning_rate", learning_rate)
    mlflow.log_param("num_epochs", num_epochs)

    avg_loss = float('nan')  # Recorded in the final checkpoint; stays NaN if num_epochs == 0

    # Training loop
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0        # Sum of finite batch losses for the current epoch
        num_batches = 0         # Batches that contributed to epoch_loss
        for i, (state_seq, action_seq, annotation_seq) in enumerate(dataloader):
            print(f"Currently processing Epoch {epoch}, batch {i}")
            state_seq, action_seq, annotation_seq = state_seq.to(device), action_seq.to(device), annotation_seq.to(device)
            optimizer.zero_grad()

            # Forward pass
            output = model(state_seq, annotations=annotation_seq)

            # The target is the last action in each sequence of the batch
            target_actions = action_seq[:, -1]
            loss = criterion(output, target_actions)

            # A non-finite loss (e.g. every target in the batch ignored by the
            # criterion) would poison the weights; skip this update instead.
            if not torch.isfinite(loss):
                print(f"Warning: non-finite loss at epoch {epoch}, batch {i}; skipping update")
                continue

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1

        # Save the model to MLflow
        print("Saving model to MLflow")
        mlflow.pytorch.log_model(model, f"model_epoch_{epoch+1}")

        # Log the average loss over the batches that were actually trained on
        avg_loss = epoch_loss / num_batches if num_batches else float('nan')
        mlflow.log_metric("loss", avg_loss, step=epoch)

        print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss}")

        # Save the model, optimizer state and the epoch's average loss
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
        }
        model_path = f"models/pokemon_model_lstm_epoch_{epoch+1}.pth"
        torch.save(checkpoint, model_path)

    # Save final model with all relevant states
    final_checkpoint = {
        'epoch': num_epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': avg_loss,
    }

    final_model_path = "models/pokemon_model_lstm_final.pth"
    torch.save(final_checkpoint, final_model_path)
    mlflow.pytorch.log_model(model, "final_model")

print("Model and metrics logged with MLflow!")

In [None]:
import torch
import mlflow
import mlflow.pytorch

# Minimal stand-in model used only to smoke-test MLflow model logging.
class TestModel(torch.nn.Module):
    """Single linear layer mapping 10 input features to 2 outputs."""

    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)


# Test the logging step without running the full training loop
def test_mlflow_logging():
    """Save a dummy checkpoint and log a throwaway model to MLflow.

    The test model and optimizer are local so running this cell does not
    overwrite the notebook's trained ``model`` / ``optimizer`` globals.
    """
    test_model = TestModel()
    test_optimizer = torch.optim.Adam(test_model.parameters(), lr=0.001)

    with mlflow.start_run():
        # Dummy checkpoint mirroring the structure the training loop saves
        test_checkpoint = {
            'epoch': 1,
            'model_state_dict': test_model.state_dict(),
            'optimizer_state_dict': test_optimizer.state_dict(),
            'loss': 0.5,  # Dummy loss value
        }

        # Ensure the target directory exists before saving
        os.makedirs("models", exist_ok=True)
        model_path = "models/test_model.pth"
        torch.save(test_checkpoint, model_path)

        try:
            # Test logging the model with MLflow
            mlflow.pytorch.log_model(test_model, "test_model")
            print("MLflow model logging successful.")
        except Exception as e:
            print(f"Error during MLflow logging: {e}")


# Run the test function
test_mlflow_logging()


In [None]:
def check_versions():
    """Print the installed torch version and verify MLflow can log a PyTorch model.

    Side effects: sets the MLflow tracking URI to ``mlruns`` and the active
    experiment to ``VerificationExperiment`` (both remain set afterwards), and
    logs a tiny ``nn.Linear`` model as ``dummy_model``. Errors are printed, not raised.
    """
    try:
        # Check torch version using importlib_metadata
        torch_version = importlib_metadata.version("torch")
        print(f"importlib_metadata found torch version: {torch_version}")
    except importlib_metadata.PackageNotFoundError as e:
        # version() signals a missing distribution with PackageNotFoundError, not KeyError
        print(f"Error: Could not find version for 'torch': {e}")
    except Exception as e:
        print(f"Unexpected error occurred: {e}")

    # Also check torch version from the package itself
    print(f"torch.__version__: {torch.__version__}")

    try:
        # Simulate a log model call with MLflow to verify functionality
        model = torch.nn.Linear(2, 2)  # Simple dummy model for testing
        mlflow.set_tracking_uri("mlruns")
        mlflow.set_experiment("VerificationExperiment")

        with mlflow.start_run():
            mlflow.pytorch.log_model(
                model,
                artifact_path="dummy_model",
                pip_requirements=[f"torch=={torch.__version__}", "cloudpickle==2.0.0"]
            )
            print("MLflow logging successful.")
    except Exception as e:
        print(f"MLflow encountered an error: {e}")


# Run the verification
check_versions()
