In [1]:
import torch

from rl_training import QLearningAI

pygame 2.5.2 (SDL 2.28.2, Python 3.10.12)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [3]:
def test_batch_size(model, input_shape, max_batch_size=1024, step=32):
    """
    Tests the maximum batch size that will fit in memory.
    
    :param model: Model for testing.
    :param input_shape: Input tensor size without considering the batch size.
    :param max_batch_size: Maximum batch size for testing.
    :param step: Step for increasing the batch size.
    :return: Maximum batch size that fits in memory.
    """
    model.eval()  # Switch the model to evaluation mode
    with torch.no_grad():  # Disable gradient computation
        for batch_size in range(step, max_batch_size + step, step):
            try:
                # Create a random input tensor
                input_tensor = torch.randn((batch_size, *input_shape)).to('cuda')
                print(input_tensor.shape)
                
                # Pass the tensor through the model
                _ = model(input_tensor)
                
                print(f"Batch size {batch_size} passed!")
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print(f"Batch size {batch_size} caused out of memory error!")
                    return batch_size - step
                else:
                    raise e
    return max_batch_size

# Usage example: build the agent, let it create its networks, then probe its online model.
ai = QLearningAI(None, None)
ai.init(-1, None, None)  # NOTE(review): argument meanings are defined in rl_training — confirm
max_batch = test_batch_size(
    ai.online_model,
    input_shape=(17, 16, 5),  # per-sample observation shape, batch dimension excluded
)
print(f"Maximum batch size that fits in memory: {max_batch}")


torch.Size([32, 17, 16, 5])
Batch size 32 passed!
torch.Size([64, 17, 16, 5])
Batch size 64 passed!
torch.Size([96, 17, 16, 5])
Batch size 96 passed!
torch.Size([128, 17, 16, 5])
Batch size 128 passed!
torch.Size([160, 17, 16, 5])
Batch size 160 passed!
torch.Size([192, 17, 16, 5])
Batch size 192 passed!
torch.Size([224, 17, 16, 5])
Batch size 224 passed!
torch.Size([256, 17, 16, 5])
Batch size 256 passed!
torch.Size([288, 17, 16, 5])
Batch size 288 passed!
torch.Size([320, 17, 16, 5])
Batch size 320 passed!
torch.Size([352, 17, 16, 5])
Batch size 352 passed!
torch.Size([384, 17, 16, 5])
Batch size 384 passed!
torch.Size([416, 17, 16, 5])
Batch size 416 passed!
torch.Size([448, 17, 16, 5])
Batch size 448 passed!
torch.Size([480, 17, 16, 5])
Batch size 480 passed!
torch.Size([512, 17, 16, 5])
Batch size 512 passed!
torch.Size([544, 17, 16, 5])
Batch size 544 passed!
torch.Size([576, 17, 16, 5])
Batch size 576 passed!
torch.Size([608, 17, 16, 5])
Batch size 608 passed!
torch.Size([640, 1

In [4]:
import torch

# Sample batch of 20 random observations.
# NOTE(review): the flip comments below assume an (N, C, H, W) layout, i.e. dim 2 is height
# and dim 3 is width — confirm against how observations are actually encoded.
tensor = torch.rand((20, 17, 16, 5))

# Horizontal flip: reverse dim 3 (width under the assumed layout).
horizontal_flipped = torch.flip(tensor, [3])

# Vertical flip: reverse dim 2 (height under the assumed layout).
vertical_flipped = torch.flip(tensor, [2])

# Concatenate (not stack) the original and flipped copies along the batch dimension,
# so the batch grows 3x while every sample keeps its shape.
augmented_tensor = torch.cat((tensor, horizontal_flipped, vertical_flipped), 0)

# Enforce the expected result instead of only describing it in a comment.
assert augmented_tensor.shape == (3 * tensor.shape[0], *tensor.shape[1:])
assert all(
    torch.equal(part, source)
    for part, source in zip(
        augmented_tensor.chunk(3, dim=0), (tensor, horizontal_flipped, vertical_flipped)
    )
)
# A flip reverses its axis: index 0 of the flipped copy is the original's last index.
assert torch.equal(horizontal_flipped[..., 0], tensor[..., -1])
assert torch.equal(vertical_flipped[:, :, 0], tensor[:, :, -1])

print(augmented_tensor.shape)  # torch.Size([60, 17, 16, 5])


torch.Size([60, 17, 16, 5])
