In [1]:
from pprint import pprint
import numpy as np

### Q4), Q5), Q6)

What is the gradient of the following function with respect to x and y?
\begin{equation}
f(x, y) = x^2 y + y^3 \sin(x)
\end{equation}


### Answer)

The gradient of a function $f(x, y)$, denoted $\nabla f$, is the vector of its partial derivatives. At any point it points in the direction of steepest increase of $f$, and its magnitude is the rate of increase in that direction.
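To make the "direction of steepest increase" claim concrete, the sketch below samples unit directions $u$ around the point $(2, 3)$ and computes the directional derivative $\nabla f \cdot u$ for each one, using the analytic gradient derived in this section. The helper name `grad_f` and the sampling resolution (3600 angles) are illustrative assumptions, not part of the original notebook.

```python
import numpy as np

def grad_f(x: float, y: float) -> np.ndarray:
    # Analytic gradient of f(x, y) = x^2 y + y^3 sin(x) (hypothetical helper)
    return np.array([2 * x * y + y**3 * np.cos(x), x**2 + 3 * y**2 * np.sin(x)])

g = grad_f(2.0, 3.0)

# Unit directions u = (cos t, sin t) sampled evenly around the circle
angles = np.linspace(0.0, 2 * np.pi, 3600, endpoint=False)
directions = np.stack([np.cos(angles), np.sin(angles)], axis=1)

# Directional derivative D_u f = grad f . u for every sampled direction
slopes = directions @ g

best = directions[np.argmax(slopes)]
print("best sampled direction:", best)
print("normalized gradient:   ", g / np.linalg.norm(g))
print("max slope vs |grad f|: ", slopes.max(), np.linalg.norm(g))
```

The steepest sampled direction should line up with $\nabla f / \lVert \nabla f \rVert$, and the largest slope should approach $\lVert \nabla f \rVert$ from below.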

The partial derivative of f with respect to x is:

\begin{equation}
\frac{\partial f}{\partial x} = 2xy + y^3 \cos(x)
\end{equation}
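Term by term, $y$ is held constant, so $y$ and $y^3$ act as constant factors:

\begin{equation}
\frac{\partial}{\partial x}\left(x^2 y\right) = y \frac{d}{dx}\left(x^2\right) = 2xy,
\qquad
\frac{\partial}{\partial x}\left(y^3 \sin(x)\right) = y^3 \frac{d}{dx}\sin(x) = y^3 \cos(x)
\end{equation}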

The partial derivative of f with respect to y is:

\begin{equation}
\frac{\partial f}{\partial y} = x^2 + 3y^2 \sin(x)
\end{equation}
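Here $x$ is held constant, so $x^2$ and $\sin(x)$ act as constant factors:

\begin{equation}
\frac{\partial}{\partial y}\left(x^2 y\right) = x^2 \frac{d}{dy}(y) = x^2,
\qquad
\frac{\partial}{\partial y}\left(y^3 \sin(x)\right) = \sin(x) \frac{d}{dy}\left(y^3\right) = 3y^2 \sin(x)
\end{equation}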

The gradient of f is:

\begin{equation}
\nabla f = \left(\frac{\partial f}{\partial x}, \frac{\partial f}{\partial y}\right) = \left(2xy + y^3 \cos(x), x^2 + 3y^2 \sin(x)\right)
\end{equation}
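A quick way to sanity-check a hand-derived gradient is to compare it against central finite differences, for example $\partial f/\partial x \approx \frac{f(x+h, y) - f(x-h, y)}{2h}$. The sketch below does this at a few points; the step size $h = 10^{-5}$, the test points, and the helper names `analytic_grad` and `central_diff_grad` are assumptions chosen for illustration.

```python
import numpy as np

def f(x: float, y: float) -> float:
    return x**2 * y + y**3 * np.sin(x)

def analytic_grad(x: float, y: float) -> np.ndarray:
    # (2xy + y^3 cos x, x^2 + 3y^2 sin x), as derived above
    return np.array([2 * x * y + y**3 * np.cos(x), x**2 + 3 * y**2 * np.sin(x)])

def central_diff_grad(x: float, y: float, h: float = 1e-5) -> np.ndarray:
    # Second-order accurate central difference in each coordinate
    dfdx = (f(x + h, y) - f(x - h, y)) / (2 * h)
    dfdy = (f(x, y + h) - f(x, y - h)) / (2 * h)
    return np.array([dfdx, dfdy])

points = [(2.0, 3.0), (-1.0, 0.5), (0.3, -2.0)]
for px, py in points:
    a = analytic_grad(px, py)
    n = central_diff_grad(px, py)
    print((px, py), a, n, "max abs error:", np.max(np.abs(a - n)))
```

With $h = 10^{-5}$ in float64, the truncation and rounding errors are both far below $10^{-6}$ at these points, so agreement to roughly that level indicates the analytic formulas are right.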

In [2]:
import jax
import jax.numpy as jnp


def f(x: float, y: float) -> float:
    # f(x, y) = x^2 y + y^3 sin(x); jnp.sin keeps f traceable by jax.grad
    return x ** 2 * y + y ** 3 * jnp.sin(x)


class MyGradient:
    """Hand-derived gradient of f, evaluated with NumPy in float64."""

    @staticmethod
    def df_dx(x: float, y: float) -> float:
        return 2 * x * y + y ** 3 * np.cos(x)

    @staticmethod
    def df_dy(x: float, y: float) -> float:
        return x ** 2 + 3 * y ** 2 * np.sin(x)

    def grad(self, x: float, y: float) -> tuple[float, float]:
        return self.df_dx(x, y), self.df_dy(x, y)

x = 2.0
y = 3.0

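pprint(MyGradient().grad(x, y))
# argnums=(0, 1) differentiates with respect to both x and y;
# jax.grad requires float (not int) inputs
pprint(jax.grad(f, argnums=(0, 1))(x, y))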

(0.7640354132271554, 28.551030524293406)
Metal device set to: Apple M1

(Array(0.7640343, dtype=float32, weak_type=True),
 Array(28.55103, dtype=float32, weak_type=True))
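The two results agree to about six significant digits. The small gap in $\partial f/\partial x$ is consistent with JAX's default single-precision (float32) arithmetic: at $(2, 3)$ the value is $12 + 27\cos(2)$, a difference of two numbers near 12 whose result is below 1, so float32 rounding error on the order of $10^{-6}$ becomes visible. The sketch below evaluates the same formula in float32 and float64 with NumPy to show the size of that effect; it illustrates the mechanism and is not expected to reproduce JAX's exact float32 digits. (JAX itself can be switched to double precision with `jax.config.update("jax_enable_x64", True)`.)

```python
import numpy as np

def df_dx(dtype):
    # Evaluate 2xy + y^3 cos(x) at (2, 3), keeping every operand in `dtype`
    x, y, two = dtype(2.0), dtype(3.0), dtype(2.0)
    return two * x * y + y * y * y * np.cos(x)

single = df_dx(np.float32)
double = df_dx(np.float64)

print("float32:", single, single.dtype)
print("float64:", double)
print("abs diff:", abs(float(single) - float(double)))
```

Constants are cast to `dtype` explicitly so that NumPy's scalar promotion rules cannot silently upcast the float32 computation to float64.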


In [3]:
# Cross-check the partial derivatives symbolically with SymPy
import sympy as sp

# Define symbols
x, y = sp.symbols('x y')

# Define the function
f = x**2 * y + y**3 * sp.sin(x)

# Calculate partial derivatives
df_dx = sp.diff(f, x)
df_dy = sp.diff(f, y)

# Display the results
print("Partial derivative with respect to x:")
print(df_dx)

print("\nPartial derivative with respect to y:")
print(df_dy)

Partial derivative with respect to x:
2*x*y + y**3*cos(x)

Partial derivative with respect to y:
x**2 + 3*y**2*sin(x)
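The symbolic partials can also be compiled into a numeric function with `sp.lambdify`, giving a third, independent check against `MyGradient` and the JAX result at $(2, 3)$. Building the gradient as a list and using the `"math"` module backend (so the output is plain Python floats) are choices made for this sketch.

```python
import sympy as sp

x, y = sp.symbols('x y')
f = x**2 * y + y**3 * sp.sin(x)

# Symbolic gradient as a list [df/dx, df/dy]
grad_exprs = [sp.diff(f, var) for var in (x, y)]

# Compile the symbolic gradient into a plain-Python numeric function
grad_fn = sp.lambdify((x, y), grad_exprs, modules="math")

print(grad_fn(2.0, 3.0))
```

The printed pair should match the float64 values from `MyGradient` in the earlier cell.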
