diff --git a/examples/models/llama/tests/test_ring_attention.py b/examples/models/llama/tests/test_ring_attention.py
index df0d0733033..ae440e00e47 100644
--- a/examples/models/llama/tests/test_ring_attention.py
+++ b/examples/models/llama/tests/test_ring_attention.py
@@ -163,10 +163,17 @@ def test_single_token_processing(
             )
 
             # Check that outputs are the same
-            self.assertTrue(
-                torch.allclose(baseline_out, ring_out, rtol=1e-7, atol=1e-7),
-                f"Outputs differ at position {pos}",
-            )
+            if kv_cache_type == KVCacheType.REGULAR:
+                self.assertTrue(
+                    torch.allclose(baseline_out, ring_out, rtol=1e-7, atol=1e-7),
+                    f"Outputs differ at position {pos}",
+                )
+            else:
+                # For quantized KV cache we need a bigger tolerance
+                self.assertTrue(
+                    torch.allclose(baseline_out, ring_out, rtol=1e-6, atol=1e-6),
+                    f"Outputs differ at position {pos}",
+                )
 
     def test_single_token_processing_quantized(self):
         """Test single token processing with QuantizedKVCache."""