In [3]:
import torch

dtype = torch.float32
BATCH = 1
SEQ = 10
DIM = 2


def interval_pairs(z_sorted, betas):
    """Enumerate candidate (U, W) count pairs, one per breakpoint interval.

    For a threshold ``tau`` lying in the interval that starts at breakpoint
    ``beta`` (and ends at the next larger breakpoint):

    * ``W = count(z > beta)``     -- entries with ``z - tau > 0``
    * ``U = count(z > beta + 1)`` -- entries saturated at 1

    ``U`` deliberately uses a *strict* comparison: an entry equal to
    ``beta + 1`` is only saturated at ``tau == beta`` itself, not for any
    ``tau`` above it, and at that single point it contributes exactly 1
    whether it is treated as saturated or fractional.

    Args:
        z_sorted: 1D tensor of scores sorted in descending order.
        betas: 1D tensor of breakpoints (the values of ``z`` and ``z - 1``),
            sorted in descending order.

    Returns:
        List of ``(U, W)`` tuples of Python ints with ``W > U``, in the order
        of ``betas``, with consecutive duplicates removed.
    """
    size = z_sorted.numel()
    # searchsorted needs an ascending sequence; with right=True it returns
    # the number of entries <= value, so size - result == count(z > value).
    z_asc = torch.flip(z_sorted, dims=[0])
    n_pos = size - torch.searchsorted(z_asc, betas, right=True)
    n_top = size - torch.searchsorted(z_asc, betas + 1, right=True)

    found = []
    for top, pos in zip(n_top.tolist(), n_pos.tolist()):
        if pos <= top:  # need at least one fractional slot
            continue
        if found and found[-1] == (top, pos):  # tied breakpoints repeat a pair
            continue
        found.append((top, pos))
    return found


def closed_form_tau(z_cum, z_pad, pairs, k):
    """Return the first closed-form ``tau`` consistent with its (U, W) pair.

    For a pair ``(U, W)`` the constraint ``sum(clamp(z - tau, 0, 1)) == k``
    reads ``U + sum(z[U:W]) - (W - U) * tau == k``, which is solved for
    ``tau``.  The candidate is accepted when the sorted scores really split
    that way: ``z[:U] - tau >= 1``, ``0 <= z[U:W] - tau <= 1`` and
    ``z[W:] - tau <= 0``.  The checks are non-strict because the formula is
    still exact when a boundary entry sits at exactly 0 or 1.

    Args:
        z_cum: inclusive prefix sums of the descending-sorted scores.
        z_pad: the descending-sorted scores with a trailing ``-inf`` sentinel.
        pairs: candidate ``(U, W)`` counts, e.g. from ``interval_pairs``.
        k: target value of the clamped sum.

    Returns:
        A 0-dim tensor ``tau``, or ``None`` if no candidate passes the checks.
    """
    for n_top, n_pos in pairs:
        # sum over i in [U .. W-1] from zero-based inclusive prefix sums
        top_sum = z_cum[n_top - 1] if n_top > 0 else z_cum.new_zeros(())
        seg_sum = z_cum[n_pos - 1] - top_sum
        tau_cand = (seg_sum + (n_top - k)) / (n_pos - n_top)

        # counts -> 0-based indices; n_pos >= 1 always holds since n_pos > n_top
        left_ok = z_pad[n_pos - 1] >= tau_cand
        right_ok = tau_cand >= z_pad[n_pos]  # sentinel makes n_pos == m safe
        up_ok = n_top == 0 or z_pad[n_top - 1] >= tau_cand + 1
        down_ok = tau_cand + 1 >= z_pad[n_top]

        if left_ok and right_ok and up_ok and down_ok:
            return tau_cand
    return None


def bisect_tau(z_sorted, k, iters=60):
    """Approximate ``tau`` with ``sum(clamp(z - tau, 0, 1)) == k`` by bisection.

    The search bracket is ``[min(z) - 1, max(z)]``: at the lower end every
    entry clamps to 1, at the upper end every entry clamps to 0.

    Args:
        z_sorted: 1D tensor of scores (ordering does not matter here).
        k: target value of the clamped sum.
        iters: number of bisection steps.

    Returns:
        A 0-dim tensor with the dtype and device of ``z_sorted``.
    """
    lo = (z_sorted.min() - 1).item()
    hi = z_sorted.max().item()
    for _ in range(iters):
        mid = (lo + hi) / 2.0
        total = torch.clamp(z_sorted - mid, 0, 1).sum().item()
        if total > k:  # need larger tau to shrink the sum
            lo = mid
        else:
            hi = mid
    return torch.tensor((lo + hi) / 2.0, dtype=z_sorted.dtype, device=z_sorted.device)


# ... your setup ...
u = torch.randn(BATCH, SEQ, dtype=dtype)
z = u
k = 3
zminus = z - 1

# Merge z and z - 1 to get the beta breakpoints (for the *same* batch used below)
merged_tensor = torch.cat((z, zminus), dim=1)
merged_sorted_tensor, _ = torch.sort(merged_tensor, descending=True)
merged_sorted_tensor = merged_sorted_tensor[0]

# Sort z (descending) for the same batch
z, _ = torch.sort(z, descending=True)
z = z[0]                                  # 1D scores (length m)
m = z.numel()

# Cumulative sums and sentinel padding
z_cum = torch.cumsum(z, dim=0)            # S[t] = sum_{i=0..t} z[i]
z_pad = torch.cat([z, z.new_tensor([-float('inf')])])  # z[m] = -inf for right-guard

# (U_count, W_count) stored as *counts*, collected in descending-beta order
pairs = interval_pairs(z, merged_sorted_tensor)

# scan in descending beta and take the first self-consistent candidate
tau = closed_form_tau(z_cum, z_pad, pairs, k)

if tau is None:
    # Only reachable when floating-point round-off makes every candidate
    # narrowly fail its interval checks (or k is outside [0, m]).
    print('closed-form scan found no valid interval; falling back to bisection')
    tau = bisect_tau(z, k)

p = torch.clamp(z - tau, 0, 1)
print("sum p =", float(p.sum()))
# p is for the sorted z[0]; map it back to your original order if needed.

sum p = 3.0
