In [1]:
import torch

In [2]:
from model import MemoryRNN
from data_loader import generate_task_data
from train import train_model
# Reaching this line means the project-local imports above resolved without error.
print("Successfully imported all modules.")

Successfully imported all modules.


In [3]:
# --- 2. CONFIGURATION: Defining the Experiment Parameters ---

# Reproducibility: fix the RNG seed before the model is built or any data is
# generated, so that Restart Kernel -> Run All repeats the same run.
SEED = 42
# Seeds PyTorch's global generator only.
# NOTE(review): if data_loader / train draw from `random` or NumPy instead of
# torch, those generators need to be seeded as well -- confirm in those modules.
torch.manual_seed(SEED)

INPUT_SIZE = 3         # Vocabulary size (Pad, A, B)
HIDDEN_SIZE = 16       # The model's memory capacity
OUTPUT_SIZE = 3        # Number of possible answers the model can give
SEQUENCE_LENGTH = 10   # How long the model needs to remember the signal
# Epoch counts tried so far (apparently): 250, 600, 900, 1300
NUM_EPOCHS = 1300      # How many training cycles to run
BATCH_SIZE = 128       # How many examples to show the model at once
LEARNING_RATE = 0.005  # How quickly the model learns

In [4]:
# --- 3. INSTANTIATION: Building the Model from Our Blueprint ---
# Construct a concrete MemoryRNN (the class imported from model.py), sized by
# the INPUT_SIZE / HIDDEN_SIZE / OUTPUT_SIZE constants from the configuration
# cell, then print it so the layer layout can be checked by eye.
print("\nCreating the MemoryRNN model...")
rnn_model = MemoryRNN(input_size=INPUT_SIZE, hidden_size=HIDDEN_SIZE, output_size=OUTPUT_SIZE)
print(rnn_model)


Creating the MemoryRNN model...
MemoryRNN(
  (rnn): RNN(3, 16, batch_first=True)
  (fc): Linear(in_features=16, out_features=3, bias=True)
)


In [5]:
# --- 4. EXECUTION: Running the Training Process ---
# Hand the model built above, together with the hyperparameters from the
# configuration cell, to train_model (imported from train.py). Whatever it
# returns is kept as `trained_model` and is what gets examined afterwards.
trained_model = train_model(
    model=rnn_model, num_epochs=NUM_EPOCHS, batch_size=BATCH_SIZE,
    sequence_length=SEQUENCE_LENGTH, learning_rate=LEARNING_RATE,
)

# --- 5. EXAMINATION: Testing the Trained Model ---
# The final exam for the model: verify that it has actually learned the task.
#
# One test sample is not enough evidence. With two possible signals, a model
# that is merely guessing still gets a single sample right about half the time
# (a training loss hovering around ln(2) ~= 0.693 is what such guessing looks
# like). So a whole batch of fresh samples is scored, and success is reported
# only when the accuracy clears a threshold.

NUM_TEST_SAMPLES = 1000    # Fresh examples used for the accuracy estimate
MIN_TEST_ACCURACY = 0.95   # Accuracy required to call the task learned (tunable)

print("\n--- Testing the Trained Model on New, Unseen Examples ---")

# Evaluation mode is good practice: it turns off training-specific layers
# such as Dropout.
trained_model.eval()

# Inference needs no gradients; skipping them saves memory and computation.
with torch.no_grad():
    # Generate a fresh batch of test samples.
    test_inputs, test_labels = generate_task_data(NUM_TEST_SAMPLES, SEQUENCE_LENGTH)

    # Raw model outputs (logits). The class axis is dim 1, as in the original
    # single-sample check.
    test_logits = trained_model(test_inputs)

    # Predicted class = index of the highest logit. Both sides are flattened so
    # the comparison is element-wise even if labels arrive as (batch, 1).
    predicted_idx = test_logits.argmax(dim=1).reshape(-1)
    true_idx = test_labels.reshape(-1)
    test_accuracy = (predicted_idx == true_idx).float().mean().item()

    # Show the first sample as a concrete illustration. The model has
    # OUTPUT_SIZE outputs, so it can predict a class with no entry in
    # signal_map (e.g. 0); fall back to the raw index instead of a KeyError.
    signal_map = {1: 'A', 2: 'B'}
    first_true = true_idx[0].item()
    first_pred = predicted_idx[0].item()
    correct_signal = signal_map.get(first_true, f"<class {first_true}>")
    predicted_signal = signal_map.get(first_pred, f"<class {first_pred}>")

    print(f"The task was to remember the signal: '{correct_signal}'")
    print(f"The model's final prediction was:     '{predicted_signal}'")
    print(f"Accuracy on {NUM_TEST_SAMPLES} unseen examples: {test_accuracy:.1%}")

    if test_accuracy >= MIN_TEST_ACCURACY:
        print("\n[SUCCESS]: The model has learned to integrate information over time.")
    else:
        print("\n[FAILURE]: The model did not learn the task. Consider increasing NUM_EPOCHS.")

--- Starting Training ---
Epoch [25/1300], Loss: 0.7320
Epoch [50/1300], Loss: 0.7002
Epoch [75/1300], Loss: 0.6974
Epoch [100/1300], Loss: 0.7146
Epoch [125/1300], Loss: 0.7046
Epoch [150/1300], Loss: 0.6963
Epoch [175/1300], Loss: 0.6952
Epoch [200/1300], Loss: 0.6922
Epoch [225/1300], Loss: 0.7018
Epoch [250/1300], Loss: 0.6853
Epoch [275/1300], Loss: 0.6941
Epoch [300/1300], Loss: 0.6939
Epoch [325/1300], Loss: 0.6936
Epoch [350/1300], Loss: 0.6929
Epoch [375/1300], Loss: 0.6933
Epoch [400/1300], Loss: 0.6895
Epoch [425/1300], Loss: 0.6937
Epoch [450/1300], Loss: 0.6940
Epoch [475/1300], Loss: 0.6932
Epoch [500/1300], Loss: 0.6926
Epoch [525/1300], Loss: 0.7014
Epoch [550/1300], Loss: 0.6875
Epoch [575/1300], Loss: 0.6962
Epoch [600/1300], Loss: 0.6932
Epoch [625/1300], Loss: 0.6930
Epoch [650/1300], Loss: 0.6934
Epoch [675/1300], Loss: 0.6905
Epoch [700/1300], Loss: 0.6919
Epoch [725/1300], Loss: 0.6930
Epoch [750/1300], Loss: 0.6937
Epoch [775/1300], Loss: 0.6931
Epoch [800/1300]