In [23]:
import torch
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms
from agents.base_agent import BaseAgent

class MNISTEnvironment:
    """
    A class to manage the MNIST dataset and provide data batches.

    Besides the train/test DataLoaders it keeps one "cursor" iterator per split,
    so get_train_batch / get_test_batch hand out successive batches and restart
    automatically at the end of an epoch. It also holds the agents created via
    create_agent.
    """

    # MNIST training-set mean / std; applied only when normalize=True.
    _MNIST_MEAN = (0.1307,)
    _MNIST_STD = (0.3081,)

    def __init__(self, batch_size=64, shuffle=True, download=True, data_path='./data', normalize=False):
        """
        Initializes the MNIST environment.

        Args:
            batch_size (int): The batch size for both DataLoaders.
            shuffle (bool): Whether to shuffle the training data (test data is never shuffled).
            download (bool): Whether to download the dataset if it's not present.
            data_path (str): The path to store the downloaded dataset.
            normalize (bool): If True, standardize pixels with the MNIST mean/std after
                ToTensor. The default False keeps the plain ToTensor output.
        """
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.data_path = data_path
        self.agents = []

        # Build the transform pipeline; normalization is opt-in.
        steps = [transforms.ToTensor()]
        if normalize:
            steps.append(transforms.Normalize(self._MNIST_MEAN, self._MNIST_STD))
        transform = transforms.Compose(steps)

        # Load the MNIST dataset
        self.train_dataset = datasets.MNIST(root=self.data_path, train=True, download=download, transform=transform)
        self.test_dataset = datasets.MNIST(root=self.data_path, train=False, download=download, transform=transform)

        # Create DataLoaders; the test loader never shuffles so evaluation order is stable.
        self.train_loader = DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=self.shuffle)
        self.test_loader = DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)
        self.train_iterator = iter(self.train_loader)
        self.test_iterator = iter(self.test_loader)

    @staticmethod
    def _next_or_restart(iterator, loader):
        """
        Return ``(batch, iterator)``: the next batch and the iterator to keep using.

        If ``iterator`` is exhausted (end of an epoch), a fresh iterator is created
        from ``loader`` and its first batch is returned together with it.
        """
        try:
            return next(iterator), iterator
        except StopIteration:
            fresh = iter(loader)
            return next(fresh), fresh

    def get_train_batch(self):
        """
        Returns the next batch of training data, restarting the epoch when exhausted.
        """
        batch, self.train_iterator = self._next_or_restart(self.train_iterator, self.train_loader)
        return batch

    def get_test_batch(self):
        """
        Returns the next batch of test data, restarting the epoch when exhausted.
        """
        batch, self.test_iterator = self._next_or_restart(self.test_iterator, self.test_loader)
        return batch

    def get_train_loader(self):
        """
        Returns the train dataloader
        """
        return self.train_loader

    def get_test_loader(self):
        """
        Returns the test dataloader
        """
        return self.test_loader

    def create_agent(self, input_size=28 * 28, output_size=10):
        """
        Create a BaseAgent, append it to ``self.agents`` and return it.

        Args:
            input_size (int): Input size passed to BaseAgent; defaults to a
                flattened 28x28 MNIST image.
            output_size (int): Output size passed to BaseAgent; defaults to the
                10 digit classes.

        Returns:
            The newly created agent.
        """
        agent = BaseAgent(input_size, output_size)
        self.agents.append(agent)
        return agent

    def list_agents(self):
        """
        Print one line per agent, in creation order.
        """
        for i in range(len(self.agents)):
            print(f"Agent {i} is created")


In [24]:

# Smoke-test the environment: pull one batch from each split and report shapes.
env = MNISTEnvironment()
train_batch_data, train_batch_labels = env.get_train_batch()
test_batch_data, test_batch_labels = env.get_test_batch()

shape_report = [
    ("Train batch data shape:", train_batch_data),
    ("Train batch labels shape:", train_batch_labels),
    ("Test batch data shape:", test_batch_data),
    ("Test batch labels shape:", test_batch_labels),
]
for caption, batch_tensor in shape_report:
    print(caption, batch_tensor.shape)

# The raw DataLoader is exposed as well; peek at its first batch only.
train_loader = env.get_train_loader()
for batch_data, batch_labels in train_loader:
    print("Example train loader batch shape: ", batch_data.shape)
    break

Train batch data shape: torch.Size([64, 1, 28, 28])
Train batch labels shape: torch.Size([64])
Test batch data shape: torch.Size([64, 1, 28, 28])
Test batch labels shape: torch.Size([64])
Example train loader batch shape:  torch.Size([64, 1, 28, 28])


In [None]:
# Build a fixed roster of agents. Resetting first keeps this cell idempotent:
# re-running it no longer keeps appending agents (the saved output listed 5
# agents from repeated runs of a cell that created only one per run).
NUM_AGENTS = 5
env.agents.clear()
for _ in range(NUM_AGENTS):
    env.create_agent()
env.list_agents()

Agent 0 is created
Agent 1 is created
Agent 2 is created
Agent 3 is created
Agent 4 is created


In [None]:
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn as nn
import torch
import numpy as np
from utils import config
from agents.base_agent import BaseAgent

# Neuro-evolution baseline: evolve a population of BaseAgent networks on one
# fixed MNIST training batch. Each generation keeps the top-k agents by accuracy
# and refills the population with Gaussian-mutated copies of them. Helper
# functions are defined before the loop so the cell runs top-to-bottom on a
# fresh kernel.

# Seed the torch RNG (DataLoader shuffling, mutation noise, presumably BaseAgent
# weight init) and the numpy RNG (parent selection) so re-runs are reproducible.
SEED = getattr(config, "SEED", 0)  # fall back to 0 when utils.config defines no SEED
torch.manual_seed(SEED)
np.random.seed(SEED)

# Load MNIST: ToTensor scales pixels to [0, 1]; Normalize((0.5,), (0.5,)) maps them to [-1, 1].
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root="./data", train=False, transform=transform, download=True)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.SUBSET_SIZE, shuffle=True)
# A single batch holding the entire test split.
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=len(test_dataset), shuffle=False)

# One shuffled batch of config.SUBSET_SIZE images serves as the whole training set.
X_train, y_train = next(iter(train_loader))
X_train = X_train.view(-1, 28 * 28)  # flatten (N, 1, 28, 28) images to (N, 784)

X_test, y_test = next(iter(test_loader))
X_test = X_test.view(-1, 28 * 28)
# TODO(open question): specialise BaseAgent into a dedicated "brain" subclass?


def evaluate(network, X, y):
    """
    Return the classification accuracy of ``network`` on ``X`` as a Python float.

    The predicted class is the argmax over dim 1 of the network output; gradients
    are disabled because evaluation never back-propagates.
    """
    with torch.no_grad():
        predictions = torch.argmax(network(X), dim=1)
    return (predictions == y).float().mean().item()


def mutate(network, mutation_rate=0.05):
    """
    Return a mutated copy of ``network``; the parent itself is left untouched.

    Args:
        network: parent agent whose state_dict is copied into a fresh BaseAgent.
        mutation_rate: standard deviation of the Gaussian noise added to weights.

    Returns:
        A new BaseAgent with noise added to every parameter of rank > 1
        (weight matrices); 1-D parameters such as biases are copied unchanged.
    """
    child = BaseAgent(config.INPUT_SIZE, config.OUTPUT_SIZE)
    child.load_state_dict(network.state_dict())  # start from the parent's weights
    with torch.no_grad():
        for param in child.parameters():
            if param.dim() > 1:
                param.add_(torch.randn_like(param) * mutation_rate)
    return child


# Evolution loop
population = [BaseAgent(config.INPUT_SIZE, config.OUTPUT_SIZE) for _ in range(config.POPULATION_SIZE)]

for gen in range(config.GENERATIONS):
    # Fitness = accuracy on the fixed training batch.
    fitness_scores = [evaluate(agent, X_train, y_train) for agent in population]

    # Select the top-k networks (highest accuracy first).
    sorted_indices = np.argsort(fitness_scores)[::-1]
    best_networks = [population[i] for i in sorted_indices[:config.TOP_K]]

    print(f"Generation {gen+1}, Best Accuracy: {fitness_scores[sorted_indices[0]]:.4f}")

    # New population: keep the elites unchanged, fill the rest with mutated copies.
    new_population = best_networks[:]
    while len(new_population) < config.POPULATION_SIZE:
        # Pick a parent by index rather than np.random.choice(best_networks), which
        # would first have numpy coerce the list of modules into an array.
        parent = best_networks[np.random.randint(len(best_networks))]
        new_population.append(mutate(parent, config.MUTATION_RATE))

    population = new_population  # Replace old population


# Final evaluation on test set: the best elite of the last evaluated generation.
best_nn = best_networks[0]
test_accuracy = evaluate(best_nn, X_test, y_test)
print(f"Final Test Accuracy: {test_accuracy:.4f}")

Generation 1, Best Accuracy: 0.1390
Generation 2, Best Accuracy: 0.1670
Generation 3, Best Accuracy: 0.1670
Generation 4, Best Accuracy: 0.1940
Generation 5, Best Accuracy: 0.1940
Generation 6, Best Accuracy: 0.1940
Generation 7, Best Accuracy: 0.1940
Generation 8, Best Accuracy: 0.1940
Generation 9, Best Accuracy: 0.1940
Generation 10, Best Accuracy: 0.2020
Generation 11, Best Accuracy: 0.2020
Generation 12, Best Accuracy: 0.2330
Generation 13, Best Accuracy: 0.2330
Generation 14, Best Accuracy: 0.2330
Generation 15, Best Accuracy: 0.2330
Generation 16, Best Accuracy: 0.2330
Generation 17, Best Accuracy: 0.2330
Generation 18, Best Accuracy: 0.2330
Generation 19, Best Accuracy: 0.2330
Generation 20, Best Accuracy: 0.2330
Generation 21, Best Accuracy: 0.2330
Generation 22, Best Accuracy: 0.2330
Generation 23, Best Accuracy: 0.2330
Generation 24, Best Accuracy: 0.2330
Generation 25, Best Accuracy: 0.2330
Generation 26, Best Accuracy: 0.2330
Generation 27, Best Accuracy: 0.2330
Generation