In [1]:
import torch
import math
from transformers import GPT2LMHeadModel
from Levenshtein import distance as lev_dist
# Crude stand-in for lev_dist (string-length difference only); disabled in favor of the Levenshtein import above.
#def lev_dist(s1, s2):
#    return abs(len(s1) - len(s2)) 

# --- Tokenizer Mapping (Must match training) ---
# Character-level vocabulary: every symbol that can appear in an "a+b=c" example.
# Ids are assigned by sorted character order.
chars = sorted("0123456789+=")
char_to_id = {symbol: index for index, symbol in enumerate(chars)}
id_to_char = dict(enumerate(chars))

def has_carry(a, b):
    """Report whether adding ``a`` and ``b`` column by column produces any carry.

    Args:
        a: First addend (its decimal ``str()`` form is inspected digit by digit).
        b: Second addend, handled the same way.

    Returns:
        True if at least one digit column overflows, otherwise False.
    """
    # Least-significant digit first, so both numbers line up by place value.
    rev_a = str(a)[::-1]
    rev_b = str(b)[::-1]
    width = max(len(rev_a), len(rev_b))
    columns = zip(rev_a.ljust(width, '0'), rev_b.ljust(width, '0'))
    # No carry-in needs tracking: the very first carry can only arise in a column
    # that received none, so testing the raw digit sums is sufficient.
    return any(int(digit_a) + int(digit_b) >= 10 for digit_a, digit_b in columns)

def decode_prediction(model, prompt, max_new=6):
    """Greedily continue ``prompt`` and return the leading run of generated digits.

    Args:
        model: Object exposing ``parameters()`` and ``generate(...)``; in this
            notebook it is the loaded GPT-2 language model.
        prompt: Text whose characters are all keys of ``char_to_id``.
        max_new: Upper bound on the number of newly generated tokens.

    Returns:
        The generated characters up to, but not including, the first non-digit
        token. May be an empty string.
    """
    device = next(model.parameters()).device
    prompt_ids = [char_to_id[ch] for ch in prompt]
    input_ids = torch.tensor([prompt_ids], device=device)

    # '=' serves as both the end-of-sequence and the padding token.
    stop_id = char_to_id['=']

    with torch.no_grad():
        sequences = model.generate(
            input_ids,
            # Single unpadded prompt: every position is attended to.
            attention_mask=torch.ones_like(input_ids),
            max_new_tokens=max_new,
            pad_token_id=stop_id,
            eos_token_id=stop_id,
            do_sample=False,  # greedy decoding
        )

    # The returned sequence starts with the prompt; keep only the continuation.
    prompt_len = input_ids.shape[1]
    digits = []
    for token_id in sequences[0][prompt_len:].tolist():
        ch = id_to_char[token_id]
        if ch not in "0123456789":
            break
        digits.append(ch)
    return "".join(digits)

def run_eval(model_path, test_file, is_reversed=True):
    """Evaluate a saved GPT-2 addition model on a file of ``a+b=answer`` lines.

    For every non-blank line, the prompt ``a+b=`` is decoded with
    ``decode_prediction`` and compared with the answer text stored in the file.
    A report is printed containing exact-match accuracy, position-wise digit
    accuracy, mean absolute error, mean Levenshtein distance, and exact-match
    accuracy split by whether the sum needs a carry (``has_carry``).

    Args:
        model_path: Location passed to ``GPT2LMHeadModel.from_pretrained``.
        test_file: Path of a text file holding one ``a+b=answer`` example per line.
        is_reversed: When True, the predicted digit string is reversed before it
            is parsed as a number for the MAE metric. The string-based metrics
            always compare the raw prediction with the file's answer text.

    Returns:
        None. The report is printed to stdout.

    Raises:
        ValueError: If a non-blank line does not have the ``a+b=answer`` shape.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = GPT2LMHeadModel.from_pretrained(model_path).to(device)
    model.eval()

    stats = {
        "total": 0, "em": 0, "mae": 0, "edit_dist": 0,
        "digit_correct": 0, "digit_total": 0,
        "carry_cases": 0, "carry_correct": 0,
        "no_carry_cases": 0, "no_carry_correct": 0
    }

    with open(test_file, 'r') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue

            prompt_part, expected_str = line.split('=')
            a_val, b_val = map(int, prompt_part.replace('+', ' ').split())

            # 1. Prediction
            pred_str = decode_prediction(model, prompt_part + "=")

            # 2. Numerical values (un-reverse the prediction for the math metric)
            true_num = a_val + b_val
            cleaned_pred = pred_str[::-1] if is_reversed else pred_str
            try:
                pred_num = int(cleaned_pred)
            except ValueError:
                # Empty (or otherwise unparsable) prediction: score it as 0.
                # Only ValueError is caught so Ctrl-C and real bugs still surface.
                pred_num = 0

            # --- Metrics calculation ---
            is_exact = pred_str == expected_str
            stats["total"] += 1
            stats["em"] += int(is_exact)

            # MAE (numerical distance)
            stats["mae"] += abs(true_num - pred_num)

            # Edit distance (string similarity)
            stats["edit_dist"] += lev_dist(pred_str, expected_str)

            # Digit-wise accuracy: compare position by position; positions present
            # in only one of the strings count as wrong. Matching is limited to
            # real characters so padding can never be mistaken for a hit.
            stats["digit_correct"] += sum(
                p == e for p, e in zip(pred_str, expected_str)
            )
            stats["digit_total"] += max(len(pred_str), len(expected_str))

            # Carry-conditional accuracy
            if has_carry(a_val, b_val):
                stats["carry_cases"] += 1
                stats["carry_correct"] += int(is_exact)
            else:
                stats["no_carry_cases"] += 1
                stats["no_carry_correct"] += int(is_exact)

    # Final report output
    print(f"\n===== RESULTS FOR {test_file} =====")
    if stats["total"] == 0:
        # Empty or all-blank file: every ratio below would divide by zero.
        print("No examples found; nothing to report.")
        return

    print(f"Exact Match Accuracy:     {stats['em']/stats['total']:.2%}")
    if stats['digit_total'] > 0:
        print(f"Digit-wise Accuracy:      {stats['digit_correct']/stats['digit_total']:.2%}")
    print(f"Mean Absolute Error (MAE): {stats['mae']/stats['total']:.2f}")
    print(f"Average Edit Distance:    {stats['edit_dist']/stats['total']:.2f}")
    print("-" * 40)
    if stats['no_carry_cases'] > 0:
        print(f"Easy (No Carry) Accuracy: {stats['no_carry_correct']/stats['no_carry_cases']:.2%}")
    if stats['carry_cases'] > 0:
        print(f"Hard (With Carry) Accuracy:{stats['carry_correct']/stats['carry_cases']:.2%}")

### Evaluation data strategy 1 (Uniformly Random Samples)

In [2]:
# Score the same checkpoint on both of its test files
# (file names suggest in-distribution vs. out-of-distribution splits — confirm against data generation).
for test_path in ("test_id.txt", "test_ood.txt"):
    run_eval("./final_model", test_path)


===== RESULTS FOR test_id.txt =====
Exact Match Accuracy:     100.00%
Digit-wise Accuracy:      100.00%
Mean Absolute Error (MAE): 0.00
Average Edit Distance:    0.00
----------------------------------------
Easy (No Carry) Accuracy: 100.00%
Hard (With Carry) Accuracy:100.00%

===== RESULTS FOR test_ood.txt =====
Exact Match Accuracy:     0.00%
Digit-wise Accuracy:      3.35%
Mean Absolute Error (MAE): 10880.65
Average Edit Distance:    3.73
----------------------------------------
Easy (No Carry) Accuracy: 0.00%
Hard (With Carry) Accuracy:0.00%


### Evaluation data strategy 2 (Mixed-Length Training)

In [3]:
# Score the mixed-length checkpoint on both of its test files
# (file names suggest in-distribution vs. out-of-distribution splits — confirm against data generation).
for test_path in ("test_id_mixed.txt", "test_ood_mixed.txt"):
    run_eval("./final_model_mixed", test_path)


===== RESULTS FOR test_id_mixed.txt =====
Exact Match Accuracy:     100.00%
Digit-wise Accuracy:      100.00%
Mean Absolute Error (MAE): 0.00
Average Edit Distance:    0.00
----------------------------------------
Easy (No Carry) Accuracy: 100.00%
Hard (With Carry) Accuracy:100.00%

===== RESULTS FOR test_ood_mixed.txt =====
Exact Match Accuracy:     0.00%
Digit-wise Accuracy:      0.00%
Mean Absolute Error (MAE): 10940.77
Average Edit Distance:    4.62
----------------------------------------
Easy (No Carry) Accuracy: 0.00%
Hard (With Carry) Accuracy:0.00%


### Evaluation data strategy 3 (Zero Padding)

In [4]:
# Score the zero-padded checkpoint on both of its test files
# (file names suggest in-distribution vs. out-of-distribution splits — confirm against data generation).
for test_path in ("test_id_padded.txt", "test_ood_padded.txt"):
    run_eval("./final_model_padded", test_path)


===== RESULTS FOR test_id_padded.txt =====
Exact Match Accuracy:     99.50%
Digit-wise Accuracy:      99.88%
Mean Absolute Error (MAE): 0.50
Average Edit Distance:    0.01
----------------------------------------
Easy (No Carry) Accuracy: 100.00%
Hard (With Carry) Accuracy:99.15%

===== RESULTS FOR test_ood_padded.txt =====
Exact Match Accuracy:     0.00%
Digit-wise Accuracy:      48.64%
Mean Absolute Error (MAE): 10287.83
Average Edit Distance:    2.55
----------------------------------------
Easy (No Carry) Accuracy: 0.00%
Hard (With Carry) Accuracy:0.00%
