## MLP token selection algorithm

In [310]:
import torch
import torch.nn.functional as F
# Print tensors with 7 decimal places and a 200-character line width.
torch.set_printoptions(precision=7, linewidth=200)
# One-element integer tensor in [0, 100), drawn BEFORE the global seed is
# fixed on the next line. It is used later in this file to reseed the
# generator, so it must not be moved below `torch.manual_seed(0)`.
# NOTE(review): whether this value differs between runs depends on the
# generator state at this point (e.g. earlier notebook cells) — confirm.
rand = torch.randint(0, 100, (1,))
# Fix the global seed so the random draws that follow are reproducible.
torch.manual_seed(0)

def fmt(t):
    """Render a 1-D tensor as ``[a, b, ...]`` with 7 decimals per element."""
    rendered = [format(value, ".7f") for value in t.tolist()]
    return f"[{', '.join(rendered)}]"

def gumbel_softmax(logits, tau=1.0, hard=False, dim=-1):
    """Forward to ``torch.nn.functional.gumbel_softmax`` unchanged.

    Args:
        logits: Tensor handed to ``F.gumbel_softmax`` as its first argument.
        tau: Passed through as ``tau``.
        hard: Passed through as ``hard``.
        dim: Passed through as ``dim``.

    Returns:
        Whatever ``F.gumbel_softmax`` returns for these arguments.
    """
    options = {"tau": tau, "hard": hard, "dim": dim}
    return F.gumbel_softmax(logits, **options)

# Walkthrough of a soft token-selection mask built from router probabilities:
# 1. sample Gumbel-softmax probabilities over the L tokens,
# 2. take the arg-max token as the decision (boundary) position,
# 3. keep every token up to the boundary, weight the boundary token by its own
#    probability, and give the following token the probability mass that
#    remains after the boundary.

L = 10  # number of tokens
C = 50  # token embedding dimension

# Feature tokens x, one row per token: shape (L, C).
tokens = torch.randn(L, C)

# Router logits. A real router would be MLP(x); here random logits drawn
# around `mu` with spread `sigma` stand in for it.
mu = 1.0
sigma = 0.2
logits = mu + sigma * torch.randn(L)
# logits = torch.abs(logits)
print(f"logits: {fmt(logits)}")

# Reseed with the value drawn at the top of the file, so the Gumbel noise
# below is governed by `rand` rather than by the fixed seed used for
# `tokens` and `logits`.
torch.manual_seed(rand)

# Plain (noise-free) softmax, shown only as a reference for comparison.
p_soft_ = F.softmax(logits, dim=0)
print(f"softmax 뒤: {fmt(p_soft_)}")
large_pos = torch.argmax(p_soft_)
print(f"결정 지점 (0-base): {large_pos.item()}")

# Gumbel-softmax sample; its arg-max is the decision position used from here
# on (it replaces the reference arg-max computed above).
p_soft = gumbel_softmax(logits, tau=1, hard=False, dim=0)
print(f"Gumbel 뒤: {fmt(p_soft)}")
large_pos = torch.argmax(p_soft)
print(f"결정 지점 (0-base): {large_pos.item()}")

# Cumulative mask: cumsum_p[i] is the probability mass of tokens 0..i.
cumsum_p = torch.cumsum(p_soft, dim=0)
pos = cumsum_p[large_pos].item()
print(f"순차합: {fmt(cumsum_p)}")
print(f"결정 지점 값 1: {pos:.3f}")

# keep_soft[i] = 1 - (mass of tokens before i). Rolling by one shifts the
# cumulative sum to "mass strictly before i"; the wrapped-around first entry
# is meaningless, so it is reset to 1.0 (nothing precedes token 0).
keep_soft = 1.0 - torch.roll(cumsum_p, shifts=1, dims=0)
keep_soft[0] = 1.0
print(f"keep_soft: {fmt(keep_soft)}")
pos = keep_soft[large_pos].item()
print(f"결정 지점 값 2: {pos:.3f}")

# Straight-through trick (optional): the forward value is the hard 0/1 mask,
# while the non-detached `keep_soft` term is what carries gradient.
# keep_soft is non-increasing, so thresholding at the boundary's value keeps
# the tokens up to and including the boundary.
keep_hard = (keep_soft >= pos).float()
keep_mask = (keep_hard - keep_soft).detach() + keep_soft
print(f"keep_mask: {fmt(keep_mask)}")

# Soften the boundary: the decision token is weighted by its own probability,
# and the next token (when there is one) by the mass left after the boundary.
keep_mask[large_pos] = p_soft[large_pos]
# Guard against writing past the last token. Compare with L rather than a
# literal so the check stays correct when the token count changes.
if large_pos + 1 < L:
    keep_mask[large_pos + 1] = 1 - cumsum_p[large_pos]

print(f"keep_mask: {fmt(keep_mask)}")

# Masked tokens: broadcast the per-token weight across the embedding axis.
tokens_masked = tokens * keep_mask.unsqueeze(-1)



logits: [1.1678578, 1.0495843, 1.0413324, 1.1985563, 0.8202838, 0.7594303, 1.6300246, 0.9567003, 1.5219138, 0.9448810]
softmax 뒤: [0.1021033, 0.0907140, 0.0899685, 0.1052864, 0.0721258, 0.0678676, 0.1620901, 0.0826676, 0.1454804, 0.0816963]
결정 지점 (0-base): 6
Gumbel 뒤: [0.0718179, 0.1162387, 0.0288235, 0.0141581, 0.2481723, 0.0791423, 0.1583174, 0.1305682, 0.0721859, 0.0805758]
결정 지점 (0-base): 4
순차합: [0.0718179, 0.1880566, 0.2168801, 0.2310381, 0.4792104, 0.5583526, 0.7166700, 0.8472382, 0.9194242, 1.0000000]
결정 지점 값 1: 0.479
keep_soft: [1.0000000, 0.9281821, 0.8119434, 0.7831199, 0.7689619, 0.5207896, 0.4416474, 0.2833300, 0.1527618, 0.0805758]
결정 지점 값 2: 0.769
keep_mask: [1.0000000, 1.0000000, 1.0000000, 1.0000000, 1.0000000, 0.0000000, 0.0000000, 0.0000000, 0.0000000, 0.0000000]
keep_mask: [1.0000000, 1.0000000, 1.0000000, 1.0000000, 0.2481723, 0.5207896, 0.0000000, 0.0000000, 0.0000000, 0.0000000]
