In [None]:
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.patches as mpatches
import matplotlib.lines as mlines

# --- Benchmark results: one entry per activation variant ---
activations = ["Standard GELU", "Cached GELU", "Tanh Approx", "Sigmoid Approx"]
avg_times = [0.009983, 0.009369, 0.009543, 0.009590]  # plotted as "Avg. Inference Time (s)"
accuracies = [92.43, 92.43, 92.43, 92.43]             # plotted as "Accuracy (%)"

# One x slot per activation; bars take 60% of each slot.
x = np.arange(len(activations))
width = 0.6

fig, ax1 = plt.subplots(figsize=(9, 6))

# --- Bar chart for inference time (left y-axis) ---
colors = ['#3498db', '#27ae60', '#f39c12', '#9b59b6']
bars = ax1.bar(x, avg_times, width, color=colors, edgecolor='black', linewidth=1)

# Emphasise the Cached GELU bar (index 1): brighter fill, thicker black outline.
bars[1].set(facecolor='#2ecc71', edgecolor='black', linewidth=2)

# Write each bar's value just above its top edge.
for bar in bars:
    y = bar.get_height()
    label_x = bar.get_x() + bar.get_width() / 2
    ax1.text(label_x, y + 0.00005, f"{y:.5f}",
             ha='center', va='bottom', fontsize=10)

# x-axis: one rotated, right-aligned label per activation.
ax1.set_xticks(x)
ax1.set_xticklabels(activations, rotation=20, ha='right', fontsize=11)

ax1.set_ylabel("Avg. Inference Time (s)", fontsize=12)
# Leave 0.001 of headroom above the tallest bar so its value label fits.
ax1.set_ylim(0, max(avg_times) + 0.001)
ax1.grid(axis='y', linestyle='--', alpha=0.5)

# --- Line plot for accuracy on a twin axis ---
# twinx() shares ax1's x positions but gives accuracy its own right-hand y scale.
ax2 = ax1.twinx()
line = ax2.plot(
    x, accuracies,
    color='firebrick', linestyle='-', marker='o', markersize=8, linewidth=2,
)[0]

# Label every accuracy point slightly above its marker.
acc_text_style = dict(ha='center', va='bottom', color='firebrick', fontsize=10)
for xi, acc in zip(x, accuracies):
    ax2.text(xi, acc + 0.02, f"{acc:.2f}%", **acc_text_style)

# Colour the right-hand axis to match the accuracy series.
ax2.set_ylabel("Accuracy (%)", color='firebrick', fontsize=12)
ax2.tick_params(axis='y', labelcolor='firebrick')
# Asymmetric padding: more headroom above the points leaves room for their labels.
ax2.set_ylim(min(accuracies) - 0.1, max(accuracies) + 0.5)

# --- Custom Legend ---
# Proxy artists: one patch per bar plus a line for the accuracy series.
# The hex colours repeat those applied to the bars and must be kept in sync.
bar_patch = mpatches.Patch(color='#3498db', label='Std. GELU Time')
# Use facecolor/edgecolor instead of `color`: `color` overrides both, so
# combining it with `edgecolor` triggers a warning and drops the black outline.
# linewidth=2 mirrors the highlighted Cached GELU bar.
cached_patch = mpatches.Patch(facecolor='#2ecc71', edgecolor='black', linewidth=2,
                              label='Cached GELU Time')
tanh_patch = mpatches.Patch(color='#f39c12', label='Tanh Approx Time')
sigmoid_patch = mpatches.Patch(color='#9b59b6', label='Sigmoid Approx Time')
acc_line = mlines.Line2D([], [], color='firebrick', marker='o', linestyle='-',
                         markersize=8, label='Accuracy (%)')

# Place legend below the plot. tight_layout() does not account for figure-level
# legends, so anchor the legend inside the figure (y=0) and reserve a strip at
# the bottom for it; it then stays visible when the figure is saved as well as
# when it is shown.
handles = [bar_patch, cached_patch, tanh_patch, sigmoid_patch, acc_line]
labels = [h.get_label() for h in handles]
fig.legend(handles, labels, loc='lower center', ncol=3,
           bbox_to_anchor=(0.5, 0.0), fontsize=10)

# --- Title & layout ---
# Set the title on ax1 explicitly rather than on whichever axes is current
# (twinx() makes ax2 current).
ax1.set_title("Activation Function Comparison: Inference Time & Accuracy",
              fontsize=14, pad=15)
# Bottom 12% of the figure is kept free for the two-row legend.
fig.tight_layout(rect=(0, 0.12, 1, 1))

plt.show()
