In [2]:
import torch
from model_arch import simplifiedV2

def test_model():
    # Set random seed for reproducibility
    torch.manual_seed(42)
    
    # Model parameters
    vocab_size = 1000
    embedding_dim = 256
    transformer_dim = 256
    num_layers = 4
    num_heads = 8
    batch_size = 2
    seq_length = 10
    
    # Initialize model
    model = simplifiedV2(
        vocab_size=vocab_size,
        embedding_dim=embedding_dim,
        transformer_dim=transformer_dim,
        num_layers=num_layers,
        num_heads=num_heads
    )
    
    # Create dummy input data
    input_tokens = torch.randint(0, vocab_size, (batch_size, seq_length))
    timesteps = torch.rand(batch_size)  # Random timesteps between 0 and 1
    
    # Create padding mask (example: last 2 tokens are padding)
    padding_mask = torch.ones_like(input_tokens, dtype=torch.bool)
    padding_mask[:, -2:] = False
    
    # Move everything to GPU if available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    input_tokens = input_tokens.to(device)
    timesteps = timesteps.to(device)
    padding_mask = padding_mask.to(device)
    
    # Set model to evaluation mode
    model.eval()
    
    # Forward pass
    with torch.no_grad():
        output = model(input_tokens, timesteps, padding_mask)
    
    # Print shapes and sample outputs
    print(f"Input shape: {input_tokens.shape}")
    print(f"Output shape: {output.shape}")
    print(f"\nSample input tokens:\n{input_tokens[0]}")
    print(f"\nSample output logits (first 5 values):\n{output[0, 0, :5]}")
    
    # Test that output dimensions are correct
    assert output.shape == (batch_size, seq_length, vocab_size), "Output shape mismatch"
    
    return output

if __name__ == "__main__":
    # Run the smoke test when executed as a script / notebook cell.
    try:
        output = test_model()
    except Exception as e:
        # Keep the short summary line, then re-raise so the full traceback is
        # shown and the failure is not silently swallowed (a script run exits
        # non-zero; a notebook "Run All" stops at this cell).
        print(f"Test failed with error: {str(e)}")
        raise
    else:
        print("\nModel test completed successfully!")

Input shape: torch.Size([2, 10])
Output shape: torch.Size([2, 10, 1000])

Sample input tokens:
tensor([725, 839, 772, 184, 760, 463, 911,  20, 932, 581], device='cuda:0')

Sample output logits (first 5 values):
tensor([ 0.1450, -0.8028, -0.0697,  0.7078, -0.0741], device='cuda:0')

Model test completed successfully!


  output = torch._nested_tensor_from_mask(
