In [5]:
import numpy as np
import plotly.graph_objects as go
import pandas as pd


# Seed NumPy's global RNG so the Gumbel draws in this cell are reproducible.
np.random.seed(42)

# Display names of the three codewords and their unnormalised scores.
labels = [f'v{n}' for n in (1, 2, 3)]
logits = np.array([2.0, 1.0, 0.1])

def sample_gumbel(shape, eps=1e-20):
    """Draw Gumbel noise of the requested shape.

    Uniform samples ``U`` on [0, 1) are taken from NumPy's global random
    state and pushed through ``-log(-log(U))``.

    Args:
        shape: Output shape, forwarded to ``np.random.uniform``.
        eps: Small constant added inside both logarithms so neither one is
            evaluated at exactly zero.

    Returns:
        Array of the given shape holding the transformed samples.
    """
    uniform = np.random.uniform(low=0, high=1, size=shape)
    inner = -np.log(uniform + eps)
    return -np.log(inner + eps)

tau = 0.5      # temperature that every logit is divided by
samples = 10   # number of independent Gumbel noise draws

# The noise-free scaled logits are identical for every draw, so compute once.
clean_scaled = logits / tau

# Long-format records: per draw and codeword, one noisy row then one clean row.
logits_data = []
for draw in range(samples):
    noisy_scaled = (logits + sample_gumbel(logits.shape)) / tau
    for label, with_noise, without_noise in zip(labels, noisy_scaled, clean_scaled):
        logits_data.extend([
            {
                'Sample': draw,
                'Codeword': label,
                'Noisy Logit': with_noise,
                'Type': 'With Gumbel Noise',
            },
            {
                'Sample': draw,
                'Codeword': label,
                'Noisy Logit': without_noise,
                'Type': 'No Noise',
            },
        ])

df = pd.DataFrame(logits_data)

fig = go.Figure()

colors = ['#1f77b4', '#ff7f0e', '#2ca02c']

# (value in the 'Type' column, legend suffix, line dash, trace opacity)
variants = [
    ('With Gumbel Noise', 'w noise', 'solid', 1.0),
    ('No Noise', 'w/o noise', 'dot', 0.3),
]

# Two traces per codeword sharing one marker colour: the noisy series is
# drawn solid and opaque, the noise-free baseline dotted and faded.
for idx, codeword in enumerate(labels):
    color = colors[idx % len(colors)]
    codeword_rows = df[df['Codeword'] == codeword]

    for kind, suffix, dash, alpha in variants:
        subset = codeword_rows[codeword_rows['Type'] == kind]
        trace = go.Scatter(
            x=subset['Sample'],
            y=subset['Noisy Logit'],
            mode='lines+markers',
            name=f'{codeword} ({suffix})',
            line={'dash': dash},
            marker={'color': color},
            opacity=alpha,
        )
        fig.add_trace(trace)

layout_options = {
    'title': 'Comparison of Logits With and Without Gumbel Noise',
    'xaxis_title': 'Sample Index',
    'yaxis_title': '(logit + noise) / tau',
    'legend_title': 'Codewords',
}
fig.update_layout(**layout_options)

fig.show()


In [8]:
import numpy as np
import plotly.graph_objects as go

x_vals = np.linspace(-5, 5, 200)

# Hard decision between the logit x and a fixed competing logit of 0:
# 1 where x wins, 0 otherwise.
argmax_outputs = (x_vals > 0).astype(int).tolist()


def binary_softmax(x, tau):
    """Softmax probability of logit ``x`` against a competing logit of 0.

    Mathematically this is ``exp(x/tau) / (exp(x/tau) + 1)``. It is evaluated
    through ``exp(-|x/tau|)`` so the exponential can never overflow; the naive
    form turns into inf/inf = NaN once ``x/tau`` exceeds roughly 709.

    Args:
        x: Scalar or array of logits.
        tau: Positive temperature.

    Returns:
        NumPy array of probabilities in [0, 1] with the shape of ``x``.
    """
    z = np.asarray(x, dtype=float) / tau
    decay = np.exp(-np.abs(z))  # always in (0, 1]
    return np.where(z >= 0, 1.0 / (1.0 + decay), decay / (1.0 + decay))


# One curve per temperature. A comprehension is used so the loop variable
# does not leak into the notebook's global namespace.
taus = np.linspace(0.02, 2.0, 20)
tau_lines = [binary_softmax(x_vals, tau_value) for tau_value in taus]

fig = go.Figure()

# Reference step function: the hard argmax decision.
argmax_trace = go.Scatter(
    x=x_vals,
    y=argmax_outputs,
    mode='lines',
    name='argmax(x, 0)',
    line={'dash': 'dot', 'color': 'red'},
)
fig.add_trace(argmax_trace)

# One purple curve per temperature; opacity runs from 0.95 for the first tau
# down to 0.2 for the last, so the curves can be told apart without a legend.
opacitys = np.linspace(0.95, 0.2, len(tau_lines))
for tau_value, curve, alpha in zip(taus, tau_lines, opacitys):
    curve_trace = go.Scatter(
        x=x_vals,
        y=curve,
        mode='lines',
        name=f'softmax(x, 0), tau={tau_value:.2f}',
        line={'color': "#a44ede"},
        opacity=alpha,
        showlegend=False,
    )
    fig.add_trace(curve_trace)

# Placeholder trace with no data points: it only carries a colorbar that maps
# the tau range onto the same opacity fade used for the curves above.
tau_low, tau_high = taus.min(), taus.max()
fade_scale = [
    [0, 'rgba(164, 78, 222, 0.95)'],
    [1, 'rgba(164, 78, 222, 0.2)'],
]
tau_colorbar = dict(
    title='tau',
    title_side='top',
    thickness=15,
    len=0.7,
    tickvals=[tau_low, tau_high],
    ticktext=[f'{tau_low:.2f}', f'{tau_high:.2f}'],
)
legend_stub = go.Scatter(
    x=[None],
    y=[None],
    mode='markers',
    marker=dict(
        colorscale=fade_scale,
        cmin=tau_low,
        cmax=tau_high,
        colorbar=tau_colorbar,
    ),
    showlegend=False,
)
fig.add_trace(legend_stub)

fig.update_layout(**{
    'title': 'Argmax vs Softmax Approximation (tau: 0.02 ~ 2.0)',
    'xaxis_title': 'Logit x',
    'yaxis_title': 'Probability of Choosing x',
})

fig.show()

In [None]:
import numpy as np
import pandas as pd
import plotly.express as px

# Number of codewords compared in this cell.
num_codewords = 3

# Display names 'codeword 1' .. 'codeword N' and evenly spaced logits in [1, 2].
labels = [f'codeword {n}' for n in range(1, num_codewords + 1)]
logits = np.linspace(1, 2.0, num_codewords)

# Seed NumPy's global RNG so the Gumbel noise drawn below is reproducible.
np.random.seed(42)
def sample_gumbel(shape, eps=1e-20):
    """Draw Gumbel noise of the requested shape.

    Uniform samples ``U`` on [0, 1) are taken from NumPy's global random
    state and pushed through ``-log(-log(U))``.

    Args:
        shape: Output shape, forwarded to ``np.random.uniform``.
        eps: Small constant added inside both logarithms so neither one is
            evaluated at exactly zero.

    Returns:
        Array of the given shape holding the transformed samples.
    """
    uniform = np.random.uniform(low=0, high=1, size=shape)
    inner = -np.log(uniform + eps)
    return -np.log(inner + eps)

# A single noise draw is shared by all temperatures, so the only thing that
# changes between groups of rows is tau.
gumbel_noise = sample_gumbel(logits.shape)

taus = [2.0, 1.0, 0.5, 0.1]

softmax_data = []

for tau in taus:
    noisy_logits = (logits + gumbel_noise) / tau
    # Subtract the maximum before exponentiating: the softmax value is
    # unchanged, but exp() can no longer overflow to inf (and the ratio to
    # NaN) when a small tau blows the scaled logits up.
    exp_shifted = np.exp(noisy_logits - np.max(noisy_logits))
    probs = exp_shifted / np.sum(exp_shifted)
    for i in range(num_codewords):
        softmax_data.append({
            'Tau': tau,
            'Codeword': labels[i],
            'Probability': probs[i]
        })

# Build the long-format DataFrame (one row per tau/codeword pair).
df = pd.DataFrame(softmax_data)

# One strip per codeword, with the points grouped and coloured by temperature.
fig = px.strip(
    df,
    x='Codeword',
    y='Probability',
    color='Tau',
    stripmode='group',
    title='Gumbel-Softmax with different Tau',
    labels={'Probability': 'Selection Probability'},
    category_orders={'Codeword': labels}
)

# px.strip renders its points as box traces (points only), not scatter
# traces. The previous selector of type 'scatter' matched no trace, so the
# jitter and marker settings were silently ignored; target 'box' instead.
fig.update_traces(
    jitter=0.4,
    marker=dict(size=10, opacity=0.8),
    selector=dict(type='box')
)

fig.update_layout(
    xaxis_title='Codewords',
    yaxis_title='Probability',
    legend_title='tau',
    # Pad the [0, 1] probability range slightly so edge markers are not clipped.
    yaxis=dict(range=[-0.05, 1.05]),
    width=800,
    height=300
)

fig.show()


In [3]:
import torch
import torch.nn as nn

# Input laid out as (batch, channels, time_steps); the tensor below is (1, 3, 3).
# Seed torch first so the randomly initialised conv weights are reproducible.
torch.manual_seed(39)
x = torch.tensor([[[1.0, 0.0, 1.0],
                   [0.0, 1.0, 0.0],
                   [1.0, 0.0, 1.0]]])
print(x)

# Pointwise convolution: kernel_size=1, mapping the input channels to 6 output
# channels. in_channels is read from x so the layer cannot drift out of sync
# with the input tensor.
pointwise_conv = nn.Conv1d(in_channels=x.shape[1], out_channels=6, kernel_size=1)


output = pointwise_conv(x)
print(output)
print(output.shape)


tensor([[[1., 0., 1.],
         [0., 1., 0.],
         [1., 0., 1.]]])
tensor([[[-0.3052,  0.4857, -0.3052],
         [ 0.6070, -0.2642,  0.6070],
         [-0.4318, -0.1112, -0.4318],
         [ 0.4235,  0.0670,  0.4235],
         [ 0.8403,  0.1467,  0.8403],
         [-0.3708,  0.7868, -0.3708]]], grad_fn=<ConvolutionBackward0>)
torch.Size([1, 6, 3])


In [None]:
import torch
import torch.nn as nn

# Pointwise conv (kernel_size=1): 64 input channels -> 128 output channels.
x = torch.randn(2, 64, 100)
pointwise_conv = nn.Conv1d(64, 128, kernel_size=1)
out = pointwise_conv(x)
print(out.shape)

# Conv with groups equal to the channel count (128); kernel_size=3 with
# padding=1 is applied to the pointwise output.
depthwise_conv = nn.Conv1d(128, 128, kernel_size=3, groups=128, padding=1)
out2 = depthwise_conv(out)
print(out2.shape)

# GLU gating along dim=1 (the channel axis), applied to a fresh random tensor
# rather than to out2.
glu = nn.GLU(dim=1)
glu_input = torch.randn(2, 128, 100)
print(glu_input.shape)
out3 = glu(glu_input)
print(out3.shape)


torch.Size([2, 128, 100])
torch.Size([2, 128, 100])
torch.Size([2, 128, 100])
torch.Size([2, 64, 100])
