In [None]:
!pip install python-fasthtml

In [34]:
# Stop the server started by a previous run of the app cell so port 8000 is
# freed before it is restarted. On a fresh kernel `server` does not exist yet
# (it is created by the app cell), so skip instead of raising NameError.
if "server" in globals():
    server.stop()

In [35]:
import os
import re

from fasthtml.common import *
from fasthtml.jupyter import *
import torch

# The FastHTML application: the `@app.route` handlers below register on it and
# `JupyUvi` serves it from inside the notebook.
app = FastHTML()
# NOTE(review): `count` is not referenced anywhere in this notebook view —
# confirm no other cell uses it before removing.
count = [0]

def _close(a, b, default=False):
    gtype = a.dtype
    
    if gtype in [torch.uint8, torch.int32, torch.int64]:
        if a.shape == b.shape: return torch.equal(a,b)
        return False
    
    if not default:
        if gtype == torch.float32:
            atol, rtol = 1e-6, 1e-5
        elif gtype == torch.bfloat16:
            atol, rtol = 1e-3, 1e-2
        else:
            atol, rtol = 1e-4, 1e-3
    else:
        atol, rtol = 1e-8, 1e-5
    return torch.allclose(a, b, rtol=rtol, atol=atol)
    
@app.route('/')
def home():
    """Render the landing page.

    The page has two PyTorch-version pickers and a "Compare" button. The
    button uses htmx to POST both selections to ``/compare`` and swaps the
    returned fragment into the empty ``#comparison`` div.
    """
    pytorch_versions = ["1.13.1", "2.0.0", "2.0.1", "2.1.0"]

    def version_picker(field, placeholder):
        # The placeholder (empty value) is preselected; the field name doubles
        # as the element id so hx_include can reference it.
        choices = (Option(v, value=v) for v in pytorch_versions)
        return Select(
            Option(placeholder, value="", selected=True),
            *choices,
            name=field,
            id=field,
        )

    picker_a = version_picker("version_a", "Select PyTorch Version A")
    picker_b = version_picker("version_b", "Select PyTorch Version B")

    compare_button = Button(
        "Compare",
        hx_post="/compare",
        hx_target="#comparison",
        hx_include="#version_a, #version_b",
    )
    results_slot = Div(id="comparison")

    return Div(
        H1("Index Artifact Comparison"),
        picker_a,
        picker_b,
        compare_button,
        results_slot,
    )

def _tensor_parts(obj):
    """Return a loaded .pt payload as a list of parts (tuples/lists are unpacked)."""
    return list(obj) if isinstance(obj, (tuple, list)) else [obj]


@app.route('/compare')
def compare(version_a: str = "", version_b: str = ""):
    """Render a comparison of ConditionalQA index artifacts built under two PyTorch versions.

    Reads ``20250915-0.2.22.main.torch.<version>-1/indexing/ConditionalQA/``
    for each version and reports:

    * how many file names in A also exist in B,
    * how many tensor shapes match across the ``.pt`` files,
    * which ``.pt`` files differ in value (per ``_close``) or are missing from B.

    Args:
        version_a: PyTorch version from the ``version_a`` dropdown.
        version_b: PyTorch version from the ``version_b`` dropdown.

    Returns:
        A ``Div`` fragment that htmx swaps into ``#comparison``. If a version
        is invalid or a directory is missing, the fragment holds an error
        message instead.
    """
    # False -> use the dtype-specific tolerances in _close. (This was the
    # string "False", which is truthy and silently forced the strict defaults.)
    use_default_tol = False

    # The versions come straight from the request and are spliced into a
    # filesystem path: accept only dotted numeric versions. This rejects "../"
    # tricks and the empty placeholder option.
    for version in (version_a, version_b):
        if not re.fullmatch(r"\d+(?:\.\d+)*", version):
            return Div(P(f"Invalid or missing PyTorch version: {version!r}", style="color: red;"))

    path_a = f"20250915-0.2.22.main.torch.{version_a}-1/indexing/ConditionalQA/"
    path_b = f"20250915-0.2.22.main.torch.{version_b}-1/indexing/ConditionalQA/"
    for path in (path_a, path_b):
        if not os.path.isdir(path):
            return Div(P(f"Artifact directory not found: {path}", style="color: red;"))

    # os.listdir order is arbitrary, so match names by membership, not position.
    a = sorted(os.listdir(path_a))
    b_names = set(os.listdir(path_b))
    f_match = sum(f in b_names for f in a)

    a_pts = [f for f in a if f.endswith(".pt")]

    shape_mismatches = 0
    shape_count = 0
    mismatches = []
    missing = []

    # Single pass: each artifact is loaded once for both shape and value checks.
    for f in a_pts:
        if f not in b_names:
            missing.append(f)
            continue
        # map_location="cpu" lets artifacts saved from GPU tensors load on a
        # CPU-only host. torch.load unpickles, which is acceptable only because
        # these are locally produced artifacts (paths validated above).
        a_parts = _tensor_parts(torch.load(path_a + f, map_location="cpu"))
        b_parts = _tensor_parts(torch.load(path_b + f, map_location="cpu"))

        # NOTE(review): assumes every part is a tensor (has .shape) — confirm
        # for all artifact types.
        shape_count += len(a_parts)
        values_match = len(a_parts) == len(b_parts)
        for i, ta in enumerate(a_parts):
            if i >= len(b_parts) or ta.shape != b_parts[i].shape:
                shape_mismatches += 1
                values_match = False
            elif not _close(ta, b_parts[i], default=use_default_tol):
                values_match = False
        if not values_match:
            mismatches.append(f)

    if mismatches or missing:
        details = [P(f, style="color: red;") for f in mismatches]
        details += [P(f"{f} (missing in {version_b})", style="color: red;") for f in missing]
    else:
        details = [P("All tensors match!", style="color: green;")]

    return Div(
        H2(f"Comparing PyTorch {version_a} vs {version_b}"),
        P(f"{f_match}/{len(a)} file names match"),
        P(f"{shape_count - shape_mismatches}/{shape_count} Tensor Shapes Match"),
        H3("Tensor Value Mismatches"),
        *details,
    )

# Launch the app from inside the notebook (stop it with `server.stop()`).
# Port 8000 is the one Modal forwards to the tunnel URL; binding to 0.0.0.0
# listens on all interfaces, presumably so Modal's forwarder can reach it.
server = JupyUvi(app, host='0.0.0.0', port=8000)