In [8]:
# Math Research Assistant - MVP Template using Gemini API

import os

import sympy as sp
from google import genai
from google.genai import types
from sympy.parsing.sympy_parser import parse_expr

# Configure the Gemini client.
# The API key is read from the GEMINI_API_KEY environment variable so that it
# never appears in the notebook source or in saved outputs. A missing variable
# raises KeyError immediately instead of failing later on the first request.
# NOTE(review): a literal key used to be committed on this line; treat that key
# as compromised and rotate it.
client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])

def query_llm(
    prompt,
    model="gemini-2.0-flash",
    system_instruction="You are a math assistant helping researchers and students to solve problems",
):
    """Send ``prompt`` to Gemini and return the reply text.

    Parameters
    ----------
    prompt : str
        The user message forwarded as ``contents``.
    model : str, optional
        Gemini model name. Defaults to the model this notebook was written
        against.
    system_instruction : str, optional
        System prompt passed through ``GenerateContentConfig``.

    Returns
    -------
    str
        The text of the response, or ``""`` when the response carries no text.
    """
    response = client.models.generate_content(
        model=model,
        config=types.GenerateContentConfig(system_instruction=system_instruction),
        contents=prompt,
    )
    # `response.text` may be None when the reply has no text part (for example
    # a blocked or empty candidate). Callers in this notebook immediately call
    # string methods on the result, so normalise to an empty string.
    return response.text or ""

def analyze_function(expr_str):
    """Differentiate an expression twice and state the condition ``f'' >= 0``.

    Parameters
    ----------
    expr_str : str
        Expression in sympy/Python syntax, e.g. ``"x**2"``.

    Returns
    -------
    dict
        On success, the keys ``expression``, ``first_derivative``,
        ``second_derivative`` (sympy objects) and ``is_convex``.
        ``is_convex`` is the simplified relation ``f'' >= 0``: sympy's
        True/False when the sign is decidable outright (e.g. a constant second
        derivative), otherwise an unevaluated inequality in the variable, which
        is a condition to be solved and not a verdict.
        On any failure, ``{"error": <message>}``.
    """
    try:
        x = sp.symbols('x')
        # NOTE(security): parse_expr evaluates the string with eval(). In this
        # notebook the string can come from LLM output, so it must not be
        # treated as trusted input outside of local experimentation.
        expr = parse_expr(expr_str)

        # Differentiate with respect to the variable the expression actually
        # uses. Previously `x` was always used, so an expression in another
        # symbol (say `t**2`) produced zero derivatives and a spurious "convex"
        # result. `x` is still preferred when present, and is the fallback for
        # constant expressions.
        free = expr.free_symbols
        if x in free or not free:
            var = x
        else:
            var = min(free, key=lambda sym: sym.name)

        derivative = sp.diff(expr, var)
        second_derivative = sp.diff(derivative, var)
        is_convex = sp.simplify(second_derivative >= 0)
        return {
            "expression": expr,
            "first_derivative": derivative,
            "second_derivative": second_derivative,
            "is_convex": is_convex
        }
    except Exception as e:
        # Best-effort helper: report parse/differentiation problems to the
        # caller as data instead of raising.
        return {"error": str(e)}

if __name__ == "__main__":
    user_input = "Let f(x) = x**2. Show that it is convex."

    # Stage 1: have the LLM talk through the statement and attempt a proof.
    llm_prompt = f"Explain and attempt to prove the following: {user_input}"
    explanation = query_llm(llm_prompt)
    print("\nLLM Explanation:\n", explanation)

    # Stage 2: cross-check the same function symbolically and list every
    # field of the result, one per line.
    math_result = analyze_function("x**2")
    print("\nSymbolic Math Validation:")
    for key, value in math_result.items():
        print(f"{key}: {value}")



LLM Explanation:
 Okay, let's break down the concept of convexity and then prove that f(x) = x² is indeed a convex function.

**Understanding Convexity**

Intuitively, a convex function is one that "curves upwards." More formally, a function is convex if the line segment connecting any two points on its graph lies *above* or on the graph itself.

**Definition (Convex Function):**

A function  `f(x)` is convex on an interval `I` if for any `x₁, x₂ ∈ I` and any `t ∈ [0, 1]`, the following inequality holds:

`f(t * x₁ + (1 - t) * x₂) ≤ t * f(x₁) + (1 - t) * f(x₂)`

Let's break down what this inequality means:

*   `x₁` and `x₂` are two arbitrary points in the interval `I` (the domain where we're checking for convexity).
*   `t` is a number between 0 and 1.  `t * x₁ + (1 - t) * x₂` represents a point on the line segment connecting `x₁` and `x₂`.  Think of `t` as a weighting factor. When `t = 0`, the point is `x₂`; when `t = 1`, the point is `x₁`. When `t = 0.5`, it's the midpoint.
*   `f(

In [11]:
# Sanity check: run the symbolic analysis on x**2 directly, without the LLM step.
analyze_function("x**2")

{'expression': x**2,
 'first_derivative': 2*x,
 'second_derivative': 2,
 'is_convex': True}

In [16]:
# Keep the natural-language problem statement as a plain string in `expr`.
expr = "Let f(x) = x**2. Show that it is convex."

In [13]:
# Display the string currently bound to `expr`.
# NOTE(review): the recorded output is not the value assigned in the preceding
# cell and the execution counts are out of order, so this output comes from
# older kernel state. Re-run from a fresh kernel before relying on it.
expr

'x**2. Show that it is convex.'

In [48]:
def extract_expression(query):
    """Ask the LLM to rewrite the main expression of ``query`` in sympy syntax.

    The raw reply is printed for inspection. The return value is the first
    line of the reply that looks like an expression: blank lines, Markdown
    code-fence lines and import statements are skipped, and surrounding
    backticks are removed.

    Parameters
    ----------
    query : str
        Natural-language problem statement containing a function.

    Returns
    -------
    str
        The candidate expression string, or ``""`` when the reply contains no
        usable line.
    """
    prompt = (
        "Convert the main mathematical expression from this problem into valid Python sympy syntax "
        "(using ** for powers, symbols like x, y, etc.):\n\n"
        f"{query}\n\n"
        "strictly return the main function expression in sympy and no addition of '''python ''' or such strings, no explanation."
    )
    response = query_llm(prompt)
    print(response)

    # The model does not always obey the "expression only" instruction, so
    # blindly taking line 0 can return a fence marker or an import statement,
    # and raises IndexError on an empty reply.
    for raw_line in (response or "").splitlines():
        candidate = raw_line.strip().strip("`").strip()
        if not candidate or candidate.lower() in {"python", "py"}:
            # Blank line, or the opening/closing line of a code fence.
            continue
        if candidate.startswith(("import ", "from ")):
            continue
        return candidate
    return ""


In [51]:
# Ask the LLM for the sympy form of the function named in the question;
# extract_expression also prints the raw reply, which is the output shown below.
expr =extract_expression("Is the function f(x) = 1/x^3 convex on all real numbers?")

1/x**3



In [40]:
def analyze_function(expr_str):
    """Differentiate an expression twice and report the results as LaTeX.

    Parameters
    ----------
    expr_str : str
        Expression in sympy/Python syntax, e.g. ``"1/x**3"``.

    Returns
    -------
    dict
        On success, LaTeX strings under the keys ``expression``,
        ``first_derivative``, ``second_derivative`` and ``is_convex``.
        ``is_convex`` renders the simplified relation ``f'' >= 0``; unless
        sympy can decide it outright, this is an inequality in the variable
        (a condition to be solved), not a yes/no verdict.
        On any failure, ``{"error": <message>}``.
    """
    try:
        # NOTE(security): parse_expr evaluates the string with eval(). In this
        # notebook the string comes from LLM output, so it must not be treated
        # as trusted input outside of local experimentation.
        expr = parse_expr(expr_str)

        # Choose the differentiation variable from the expression itself. The
        # previous conditional selected `x` on both branches, so an expression
        # in `y` or `z` alone was differentiated with respect to an absent
        # symbol and came back with zero derivatives. `x` is preferred when
        # present and remains the fallback for constant expressions.
        default_var = sp.symbols('x')
        free = expr.free_symbols
        if default_var in free or not free:
            var = default_var
        else:
            var = min(free, key=lambda sym: sym.name)

        derivative = sp.diff(expr, var)
        second_derivative = sp.diff(derivative, var)
        is_convex = sp.simplify(second_derivative >= 0)
        return {
            "expression": sp.latex(expr),
            "first_derivative": sp.latex(derivative),
            "second_derivative": sp.latex(second_derivative),
            "is_convex": sp.latex(is_convex)
        }
    except Exception as e:
        # Best-effort helper: report parse/differentiation problems to the
        # caller as data instead of raising.
        return {"error": str(e)}

In [41]:
# Analyse the current expression string; `res` receives either the result dict
# or an {"error": ...} dict, since analyze_function does not raise.
res = analyze_function(expr)

In [43]:
# Display the expression string currently bound to `expr`.
expr

'log(x**2 + 1)'

In [30]:
# Display `expr` again.
# NOTE(review): the value recorded for this run is not a math expression, which
# suggests that taking the first line of the LLM reply can pick up boilerplate.
# Confirm against extract_expression.
expr

'import sympy'

In [52]:
# Run the symbolic analysis on the current `expr` and show the result dict
# (LaTeX strings in the recorded run).
analyze_function(expr)

{'expression': '\\frac{1}{x^{3}}',
 'first_derivative': '- \\frac{3}{x^{4}}',
 'second_derivative': '\\frac{12}{x^{5}}',
 'is_convex': '\\frac{12}{x^{5}} \\geq 0'}

In [55]:
# Plain symbol `x` for the interactive experiments that follow.
x = sp.symbols('x')

In [57]:
# Ask sympy for which x the inequality -1/x**2 >= 0 holds; the recorded answer
# False means it has no solution.
sp.reduce_inequalities([-1/x**2>=0], x)

False

In [58]:
from sympy import sympify, symbols

# Source text of the polynomial, written in ordinary Python/sympy syntax.
string_equation = "x**2 + 3*x - 1/2"

# Symbol to substitute for; it compares equal to the `x` that sympify creates
# while parsing the text.
x = symbols('x')

# Parse the text into a sympy expression and show it.
equation = sympify(string_equation)
print(equation)

# Evaluate the expression at x = 2 and show the result.
print(equation.subs({x: 2}))

x**2 + 3*x - 1/2
19/2


In [60]:
# Round-trip check: turn the sympy expression back into a plain string.
str(equation)

'x**2 + 3*x - 1/2'

In [68]:
# Wrap the expression in an equation object of the form `equation = 0`.
sp.Eq(equation,0)

Eq(x**2 + 3*x - 1/2, 0)