In [8]:
import numpy as np
from pathlib import Path

def load_and_compare_tensors():
    # Set up paths
    base_path = Path('/home/xidaren2/xshortfin/goldens')
    iree_input_path = base_path / 'iree_run/prefill_inputs'
    shark_input_path = base_path / 'sharktank/prefill_inputs'
    
    # Load all tensors
    iree_tensors = {f.name: np.load(f) for f in iree_input_path.glob('*.npy')}
    shark_tensors = {f.name: np.load(f) for f in shark_input_path.glob('*.npy')}
    
    print("Tensor Comparison Analysis")
    print("=" * 80)
    
    # Verify same tensor names exist in both
    all_tensors = sorted(set(iree_tensors.keys()) | set(shark_tensors.keys()))
    if set(iree_tensors.keys()) != set(shark_tensors.keys()):
        print("\n⚠️  Warning: Tensor name mismatch!")
        print(f"IREE only: {set(iree_tensors.keys()) - set(shark_tensors.keys())}")
        print(f"Shark only: {set(shark_tensors.keys()) - set(iree_tensors.keys())}")
    
    # Compare each tensor
    for tensor_name in all_tensors:
        print(f"\n{tensor_name}")
        print("-" * 80)
        
        iree_tensor = iree_tensors.get(tensor_name)
        shark_tensor = shark_tensors.get(tensor_name)
        
        if iree_tensor is None or shark_tensor is None:
            print("❌ Tensor missing in one implementation")
            continue
            
        # Basic properties comparison
        print(f"Shape: {'✅ Match' if iree_tensor.shape == shark_tensor.shape else '❌ Mismatch'}")
        print(f"  IREE:  {iree_tensor.shape}")
        print(f"  Shark: {shark_tensor.shape}")
        
        print(f"\nDtype: {'✅ Match' if iree_tensor.dtype == shark_tensor.dtype else '❌ Mismatch'}")
        print(f"  IREE:  {iree_tensor.dtype}")
        print(f"  Shark: {shark_tensor.dtype}")
        
        # For small tensors, show values
        if iree_tensor.size < 32:
            print("\nValues:")
            print(f"  IREE:  {iree_tensor.tolist()}")
            print(f"  Shark: {shark_tensor.tolist()}")
            print(f"  Match: {'✅' if np.array_equal(iree_tensor, shark_tensor) else '❌'}")
        
        # For larger tensors, show some statistics
        else:
            print("\nStatistics:")
            print(f"  IREE  - min: {iree_tensor.min():.6f}, max: {iree_tensor.max():.6f}, "
                  f"mean: {iree_tensor.mean():.6f}, std: {iree_tensor.std():.6f}")
            print(f"  Shark - min: {shark_tensor.min():.6f}, max: {shark_tensor.max():.6f}, "
                  f"mean: {shark_tensor.mean():.6f}, std: {shark_tensor.std():.6f}")
            
            # For numerical tensors, compute differences
            if np.issubdtype(iree_tensor.dtype, np.number) and np.issubdtype(shark_tensor.dtype, np.number):
                if iree_tensor.dtype != shark_tensor.dtype:
                    # Convert to higher precision for comparison
                    iree_tensor = iree_tensor.astype(np.float64)
                    shark_tensor = shark_tensor.astype(np.float64)
                
                abs_diff = np.abs(iree_tensor - shark_tensor)
                print("\nDifferences:")
                print(f"  Max absolute diff: {abs_diff.max():.6e}")
                print(f"  Mean absolute diff: {abs_diff.mean():.6e}")
                print(f"  Std of diff: {abs_diff.std():.6e}")

load_and_compare_tensors()

Tensor Comparison Analysis

cache_state_0.npy
--------------------------------------------------------------------------------
Shape: ✅ Match
  IREE:  (128, 2662400)
  Shark: (128, 2662400)

Dtype: ❌ Mismatch
  IREE:  float16
  Shark: float32

Statistics:
  IREE  - min: 0.000000, max: 0.000000, mean: 0.000000, std: 0.000000
  Shark - min: 0.000000, max: 0.000000, mean: 0.000000, std: 0.000000

Differences:
  Max absolute diff: 0.000000e+00
  Mean absolute diff: 0.000000e+00
  Std of diff: 0.000000e+00

seq_block_ids.npy
--------------------------------------------------------------------------------
Shape: ✅ Match
  IREE:  (1, 1)
  Shark: (1, 1)

Dtype: ✅ Match
  IREE:  int64
  Shark: int64

Values:
  IREE:  [[127]]
  Shark: [[127]]
  Match: ✅

seq_lens.npy
--------------------------------------------------------------------------------
Shape: ✅ Match
  IREE:  (1,)
  Shark: (1,)

Dtype: ✅ Match
  IREE:  int64
  Shark: int64

Values:
  IREE:  [11]
  Shark: [11]
  Match: ✅

tokens.npy
--

In [9]:
import numpy as np
from pathlib import Path
from scipy.special import softmax, log_softmax

# Load logits
base_path = Path('/home/xidaren2/xshortfin/goldens')
iree_path = base_path / 'iree_run'
shark_path = base_path / 'sharktank'

iree_logits = np.load(iree_path / 'prefill_outputs/logits.npy')
shark_logits = np.load(shark_path / 'prefill_outputs/logits.npy')

# Print shapes for verification
print(f"IREE logits shape: {iree_logits.shape}")
print(f"Shark logits shape: {shark_logits.shape}")

def compute_metrics(iree_slice, shark_slice, k=10):
    """Compute comparison metrics between IREE and Shark logits."""
    # Top-k analysis
    iree_top_k = np.argsort(iree_slice)[-k:][::-1]
    shark_top_k = np.argsort(shark_slice)[-k:][::-1]
    
    # Probability distributions
    iree_probs = softmax(iree_slice)
    shark_probs = softmax(shark_slice)
    iree_logprobs = log_softmax(iree_slice)
    shark_logprobs = log_softmax(shark_slice)
    
    # Cross entropy calculations
    ce_iree_to_shark = -np.sum(shark_probs * iree_logprobs)
    ce_shark_to_iree = -np.sum(iree_probs * shark_logprobs)
    
    return {
        'iree_top_k': iree_top_k,
        'shark_top_k': shark_top_k,
        'top_1_match': iree_top_k[0] == shark_top_k[0],
        'iree_top_1': iree_top_k[0],
        'shark_top_1': shark_top_k[0],
        'iree_top_1_prob': iree_probs[iree_top_k[0]],
        'shark_top_1_prob': shark_probs[shark_top_k[0]],
        'ce_iree_to_shark': ce_iree_to_shark,
        'ce_shark_to_iree': ce_shark_to_iree,
        'symmetric_ce': (ce_iree_to_shark + ce_shark_to_iree) / 2,
        'different_tokens_in_topk': len(set(iree_top_k) ^ set(shark_top_k)),
        'iree_probs': iree_probs,
        'shark_probs': shark_probs
    }

# Analysis for a single position (can be modified/rerun for different positions)
pos = 0  # Change this to analyze different positions
iree_slice = iree_logits[0, pos, :]
shark_slice = shark_logits[0, pos, :]
metrics = compute_metrics(iree_slice, shark_slice)

# Print basic comparison
print(f"\nPosition {pos} Analysis:")
print(f"Top-1 Match: {metrics['top_1_match']}")
print(f"IREE Top-1 Token: {metrics['iree_top_1']} (prob: {metrics['iree_top_1_prob']:.4f})")
print(f"Shark Top-1 Token: {metrics['shark_top_1']} (prob: {metrics['shark_top_1_prob']:.4f})")
print(f"Cross Entropy (IREE→Shark): {metrics['ce_iree_to_shark']:.4f}")
print(f"Cross Entropy (Shark→IREE): {metrics['ce_shark_to_iree']:.4f}")
print(f"Symmetric Cross Entropy: {metrics['symmetric_ce']:.4f}")

# Print top-10 comparison
print("\nTop 10 tokens comparison:")
print(f"{'Rank':^5} | {'IREE Token':^10} | {'IREE Prob':^10} | {'Shark Token':^10} | {'Shark Prob':^10}")
print("-" * 60)

for i in range(10):
    iree_token = metrics['iree_top_k'][i]
    shark_token = metrics['shark_top_k'][i]
    iree_prob = metrics['iree_probs'][iree_token]
    shark_prob = metrics['shark_probs'][shark_token]
    print(f"{i+1:^5} | {iree_token:^10} | {iree_prob:^10.4f} | {shark_token:^10} | {shark_prob:^10.4f}")

IREE logits shape: (1, 16, 32000)
Shark logits shape: (1, 16, 32000)

Position 0 Analysis:
Top-1 Match: True
IREE Top-1 Token: 29532 (prob: 0.2710)
Shark Top-1 Token: 29532 (prob: 0.2686)
Cross Entropy (IREE→Shark): 2.2219
Cross Entropy (Shark→IREE): 2.2169
Symmetric Cross Entropy: 2.2194

Top 10 tokens comparison:
Rank  | IREE Token | IREE Prob  | Shark Token | Shark Prob
------------------------------------------------------------
  1   |   29532    |   0.2710   |   29532    |   0.2686  
  2   |   29536    |   0.2507   |   29536    |   0.2533  
  3   |   29529    |   0.1241   |   29529    |   0.1236  
  4   |   29556    |   0.0826   |   29556    |   0.0833  
  5   |   29562    |   0.0555   |   29562    |   0.0557  
  6   |   29561    |   0.0530   |   29561    |   0.0513  
  7   |   29570    |   0.0350   |   29570    |   0.0351  
  8   |     13     |   0.0299   |     13     |   0.0299  
  9   |   29574    |   0.0293   |   29574    |   0.0293  
 10   |   29571    |   0.0254   |   29571

In [10]:
# Analyze positions 0 to 15 in a loop
results = []
for pos in range(16):
    iree_slice = iree_logits[0, pos, :]
    shark_slice = shark_logits[0, pos, :]
    metrics = compute_metrics(iree_slice, shark_slice)
    results.append(metrics)

# Print summary table
print(f"{'Position':^8} | {'Top-1 Match':^11} | {'IREE Top-1':^10} | {'Shark Top-1':^11} | {'CE (I→S)':^10} | {'CE (S→I)':^10} | {'Sym CE':^10}")
print("-" * 82)

for pos, metrics in enumerate(results):
    print(f"{pos:^8} | {str(metrics['top_1_match']):^11} | {metrics['iree_top_1']:^10} | "
          f"{metrics['shark_top_1']:^11} | {metrics['ce_iree_to_shark']:^10.4f} | "
          f"{metrics['ce_shark_to_iree']:^10.4f} | {metrics['symmetric_ce']:^10.4f}")

Position | Top-1 Match | IREE Top-1 | Shark Top-1 |  CE (I→S)  |  CE (S→I)  |   Sym CE  
----------------------------------------------------------------------------------
   0     |    True     |   29532    |    29532    |   2.2219   |   2.2169   |   2.2194  
   1     |    True     |   29567    |    29567    |   3.3613   |   3.3616   |   3.3615  
   2     |    True     |   29532    |    29532    |   2.1337   |   2.1323   |   2.1330  
   3     |    True     |   29500    |    29500    |   3.2956   |   3.3074   |   3.3015  
   4     |    True     |   29556    |    29556    |   1.0110   |   1.0065   |   1.0087  
   5     |    True     |   29500    |    29500    |   1.7236   |   1.7242   |   1.7239  
   6     |    True     |   29562    |    29562    |   0.2077   |   0.2080   |   0.2079  
   7     |    True     |   29500    |    29500    |   1.6375   |   1.6412   |   1.6393  
   8     |    True     |   29561    |    29561    |   0.1707   |   0.1695   |   0.1701  
   9     |    True     |   

In [11]:
iree_slice = iree_logits[0, 10, :]
shark_slice = shark_logits[0, 10, :]

iree_slice.argsort()

array([25790, 19701, 30583, ..., 29529, 29532, 29570])

In [12]:
shark_slice.argsort()

array([25790, 19701, 30583, ..., 29529, 29532, 29570])