In [1]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from mpl_toolkits.mplot3d import Axes3D
from IPython.display import HTML

# -----------------------------
# 1. Synthetic data: y = 4 + 3x + standard-normal noise
# -----------------------------
np.random.seed(42)
X = 2 * np.random.rand(50)
Y = 4 + 3 * X + np.random.randn(50)

# -----------------------------
# 2. Gradient descent on J(a, b) = (1 / 2n) * sum((y - (a*x + b))**2)
# -----------------------------
def compute_cost(a, b):
    """Return the half mean-squared error J(a, b) of the line y = a*x + b on (X, Y).

    This is the same formula used for the cost surface, so each recorded
    (a, b, cost) point is a point of that surface.
    """
    return np.sum((Y - (a * X + b)) ** 2) / (2 * len(X))


a, b = 0.0, 0.0
# The update below uses the exact gradient of J (factor 1/n), so the descent
# follows the gradient of the function that is plotted. A step of lr=0.2 on
# it is identical to lr=0.1 on the plain-MSE gradient (factor 2/n).
lr = 0.2
epochs = 30

n = len(X)
a_hist, b_hist = [a], [b]
cost_hist = [compute_cost(a, b)]

for _ in range(epochs):
    y_pred = a * X + b
    residual = Y - y_pred
    # Partial derivatives of J: dJ/da = -(1/n) * sum(x * r), dJ/db = -(1/n) * sum(r).
    da = -np.sum(X * residual) / n
    db = -np.sum(residual) / n
    # Simultaneous update: both derivatives were taken at the old (a, b).
    a -= lr * da
    b -= lr * db
    a_hist.append(a)
    b_hist.append(b)
    cost_hist.append(compute_cost(a, b))

# -----------------------------
# 3. Cost surface J(a, b) on a grid enclosing the descent path
# -----------------------------
def _padded_linspace(values, num=50, margin=0.2):
    """Return `num` evenly spaced points covering min(values)..max(values).

    The range is widened by `margin` times its span on each side. Padding by a
    fraction of the span (rather than multiplying the endpoints by constants)
    widens the range outward regardless of the values' sign, so a start at
    exactly 0 or a negative parameter still gets a margin.
    """
    lo, hi = min(values), max(values)
    span = hi - lo
    # A single repeated value has no span; fall back to a unit-scale pad.
    pad = margin * span if span > 0 else margin * max(abs(lo), 1.0)
    return np.linspace(lo - pad, hi + pad, num)


a_vals = _padded_linspace(a_hist)
b_vals = _padded_linspace(b_hist)
A, B = np.meshgrid(a_vals, b_vals)
# Evaluate J for every grid cell at once: adding a trailing sample axis to the
# 2-D grids lets them broadcast against the 1-D sample arrays X and Y.
residuals = Y - (A[..., np.newaxis] * X + B[..., np.newaxis])
Z = np.sum(residuals ** 2, axis=-1) / (2 * len(X))

# -----------------------------
# 4. 3D figure: cost surface plus the animated descent artists
# -----------------------------
fig = plt.figure(figsize=(10, 6))
ax = fig.add_subplot(1, 1, 1, projection='3d')

# Semi-transparent surface so the descent path stays visible through it.
surf = ax.plot_surface(A, B, Z, cmap='viridis', alpha=0.6)

# Initially empty artists that init()/update() fill in frame by frame:
# a red dot at the current (a, b, cost) and a faint dashed trail behind it.
param_point = ax.plot([], [], [], 'ro', markersize=6)[0]
path_line = ax.plot([], [], [], 'r--', alpha=0.3)[0]

ax.set(
    xlabel='a (slope)',
    ylabel='b (intercept)',
    zlabel='Cost',
    title='3D Cost Surface with Gradient Descent',
)

# -----------------------------
# 5. Animation callbacks: init and update
# -----------------------------
def init():
    """Clear the parameter marker and the trail before the first frame.

    Returns the (param_point, path_line) tuple of artists being animated.
    """
    animated = (param_point, path_line)
    for artist in animated:
        artist.set_data([], [])
        artist.set_3d_properties([])
    return animated

def update(i):
    """Draw frame *i*: move the marker to step i and extend the trail up to it.

    The camera azimuth is also advanced by 12 degrees per frame at a fixed
    30-degree elevation. Returns the (param_point, path_line) artists.
    """
    upto = i + 1
    a_now, b_now, cost_now = a_hist[i], b_hist[i], cost_hist[i]

    # Current parameters, lifted to their recorded cost.
    param_point.set_data([a_now], [b_now])
    param_point.set_3d_properties([cost_now])

    # Every step recorded so far, joined by the dashed trail.
    path_line.set_data(a_hist[:upto], b_hist[:upto])
    path_line.set_3d_properties(cost_hist[:upto])

    # Orbit the camera around the surface.
    ax.view_init(elev=30, azim=12 * i)
    return param_point, path_line

# -----------------------------
# 6. Build the animation and render it inline as HTML/JavaScript
# -----------------------------
anim = FuncAnimation(
    fig,
    update,
    frames=len(a_hist),  # one frame per recorded step, including the start point
    init_func=init,
    # Blitting would redraw only the returned artists, but update() also
    # rotates the camera, so the whole canvas is redrawn each frame.
    blit=False,
    interval=400,  # delay between frames, in milliseconds
)
# Close the animated figure explicitly (not just "the current figure"),
# presumably so the notebook does not also display it as a static image.
plt.close(fig)
# Left as the cell's last expression so the notebook renders the player.
HTML(anim.to_jshtml())
