<a href="https://colab.research.google.com/github/riasat-sheikh/QM-codes/blob/main/TISE_shoot.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 1D Time-Independent Schrödinger Equation

*  **Course:** Introduction to Modern Physics (現代物理学序論), Prof. H. Suzuki (鈴木 博)
*  **Notebook Author:** Riasat Sheikh, TA
*  **Affiliation:** Theory of Elementary Particle Physics Lab, Kyushu University, JP  
*  **Contact:** riasat.sheikh@phys.kyushu-u.ac.jp

> **Note:** You can also access this notebook from the GitHub repository: https://github.com/riasat-sheikh/QM-codes


We aim to solve the time-independent Schrödinger equation (TISE) for a particle of mass $m$ in a 1D potential $V(x)$,

$$-\frac{\hbar^2}{2m} \frac{d^2\psi}{dx^2} + V(x)\psi(x) = E\psi(x).$$

To simplify the numerical integration, we work in units where $\hbar = m = 1$, so that $x$ and $E$ become pure numbers. Multiplying the TISE by $-2m/\hbar^2 = -2$ and rearranging gives

$$\frac{d^2\psi}{dx^2} = 2\big(V(x) - E\big)\psi(x).$$
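These units simplify familiar formulas. For a square well of depth $V_0$ and half-width $a$, the inside and outside wavenumbers $k = \sqrt{2mE}/\hbar$ and $\kappa = \sqrt{2m(V_0-E)}/\hbar$ become $k=\sqrt{2E}$ and $\kappa=\sqrt{2(V_0-E)}$.

The sketch below is an independent cross-check and is not part of the shooting method. It assumes the notebook's default finite square well ($V_0 = 10$, $a = 1/2$), which is defined in the potentials cell further down. It solves the standard matching conditions with `scipy.optimize.brentq`:

* $k\tan(ka) = \kappa$ for even states;
* $-k\cot(ka) = \kappa$ for odd states.

Both conditions are written in the dimensionless variable $z = ka$, with $z_0 = a\sqrt{2V_0}$. The result gives reference energies to aim for in the interactive plot.

```python
import numpy as np
from scipy.optimize import brentq

# Assumed parameters: the notebook's default finite square well,
# V(x) = 0 for |x| <= a and V0 outside (hbar = m = 1).
V0, a = 10.0, 0.5
z0 = a * np.sqrt(2.0 * V0)  # dimensionless well strength

def kappa_a(z):
    # kappa*a = sqrt(z0^2 - z^2); clipped at 0 to avoid sqrt of tiny negatives
    return np.sqrt(max(z0**2 - z**2, 0.0))

def even_condition(z):
    # Even states: k tan(ka) = kappa  ->  z tan z - kappa*a = 0
    return z * np.tan(z) - kappa_a(z)

def odd_condition(z):
    # Odd states: -k cot(ka) = kappa  ->  -z cot z - kappa*a = 0
    return -z / np.tan(z) - kappa_a(z)

eps = 1e-9
# Lowest even state: root with z in (0, min(z0, pi/2))
z_even = brentq(even_condition, eps, min(z0, np.pi / 2 - eps))
# Lowest odd state exists only if z0 > pi/2; root with z in (pi/2, min(z0, pi))
z_odd = brentq(odd_condition, np.pi / 2 + eps, min(z0, np.pi - eps)) if z0 > np.pi / 2 else None

# Convert back: k = z/a and E = k^2 / 2
E_even = 0.5 * (z_even / a) ** 2
E_odd = 0.5 * (z_odd / a) ** 2 if z_odd is not None else None
print(f"lowest even state: E = {E_even:.4f}")
print(f"lowest odd state:  E = {E_odd:.4f}" if E_odd is not None else "no odd bound state")
```

For these parameters $z_0 \approx 2.24$ lies between $\pi/2$ and $\pi$. The well therefore holds exactly one even and one odd bound state.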

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import VBox, BoundedFloatText, Dropdown, Output, Layout, FloatSlider, HBox
from IPython.display import display, Markdown
from scipy.integrate import solve_ivp

# Plotting styles
plt.rcParams.update({
    "font.family": "sans-serif",
    "mathtext.fontset": "cm",
    "figure.figsize": (9, 9),
    "axes.grid": True,
    "grid.alpha": 0.4,
    "axes.labelsize": 14,
    "axes.titlesize": 16,
    "legend.fontsize": 12,
    "font.size": 12
})

`solve_ivp` integrates systems of first-order equations, so we rewrite the second-order TISE as two coupled first-order equations. Defining $y_0^{} = \psi$ and $y_1^{} = \psi'$, the system becomes

$$\frac{\mathrm dy_0}{\mathrm dx} = y_1^{},$$
$$\frac{\mathrm dy_1}{\mathrm dx} = 2\big(V(x) - E\big)y_0^{}.$$
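Here is a quick check that this system is set up correctly. It assumes the harmonic oscillator $V(x) = x^2/2$ used later in the notebook, whose ground state in these units is $\psi_0 = e^{-x^2/2}$ with $E = 1/2$.

The sketch integrates the first-order system with `scipy.integrate.solve_ivp` from the even initial conditions $\psi(0)=1$, $\psi'(0)=0$. It uses tighter tolerances (`rtol`, `atol`) than SciPy's defaults, which is a choice made for this illustration. It then compares the result with the exact solution on $[0, 2]$.

```python
import numpy as np
from scipy.integrate import solve_ivp

def rhs(x, y, E):
    # y[0] = psi, y[1] = psi'; returns [dy0/dx, dy1/dx] for V(x) = x^2/2
    psi, dpsi = y
    V = 0.5 * x**2
    return [dpsi, 2.0 * (V - E) * psi]

E = 0.5  # exact ground-state energy of V = x^2/2 when hbar = m = 1
sol = solve_ivp(rhs, (0.0, 2.0), [1.0, 0.0], args=(E,),
                dense_output=True, rtol=1e-10, atol=1e-12)

x = np.linspace(0.0, 2.0, 21)
psi_num = sol.sol(x)[0]
psi_exact = np.exp(-x**2 / 2)  # analytic ground state with psi(0) = 1
max_err = np.max(np.abs(psi_num - psi_exact))
print(f"max |psi_num - psi_exact| on [0, 2]: {max_err:.1e}")
```

Only a short interval is used here. Any small integration error excites the growing solution, which behaves roughly like $e^{x^2/2}$. So even at the exact eigenvalue, the numerical $\psi$ eventually diverges at large $x$.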

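In the code cell that follows, `schrodinger_rhs` implements this system. The helper `count_nodes` counts the nodes of $\psi$ in the classically allowed region; these counts are used to label the state. The initial conditions for **Even** or **Odd** parity are applied later, in the shooting routine `shoot_symmetric`.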

> There is also an auto mode, which can guess the parity for you!
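The auto mode is implemented later in `update_plot`. It shoots with both parities and keeps the one whose tail $|\psi(L)|$ is smaller.

Below is a minimal standalone sketch of that rule. It assumes the harmonic oscillator $V = x^2/2$ and a shorter range $L = 4$ than the notebook's $L = 6$. The names `tail_value` and `auto_parity` are hypothetical helpers written for illustration.

The two residuals are not normalized alike: even states start from $\psi(0)=1$ and odd states from $\psi'(0)=1$. The comparison is therefore a heuristic. It works because, at an eigenvalue of one parity, the other parity's tail grows enormously.

```python
import numpy as np
from scipy.integrate import solve_ivp

def tail_value(E, parity, L=4.0):
    """psi(L) for V = x^2/2, shooting from x = 0 with parity-based initial conditions."""
    y0 = [1.0, 0.0] if parity == "even" else [0.0, 1.0]
    rhs = lambda x, y: [y[1], 2.0 * (0.5 * x**2 - E) * y[0]]
    sol = solve_ivp(rhs, (0.0, L), y0, rtol=1e-10, atol=1e-12)
    return sol.y[0, -1]  # last stored point is x = L

def auto_parity(E):
    # Same rule as the notebook's auto mode: keep the parity whose tail is closer to zero
    res = {p: tail_value(E, p) for p in ("even", "odd")}
    return min(res, key=lambda p: abs(res[p])), res

for E in (0.5, 1.5):
    best, res = auto_parity(E)
    print(E, best, {p: f"{r:.2e}" for p, r in res.items()})
```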

In [None]:
# --- 1. Potentials ---
potentials = {
    "Finite square well": lambda x, V0=10.0, a=1/2: np.where(np.abs(x) <= a, 0.0, V0),
    "Harmonic oscillator": lambda x: 0.5 * np.asarray(x) ** 2,
}

# --- 2. The ODE System ---
def schrodinger_rhs(x, y, energy, potential):
    """
    Returns [dpsi/dx, d2psi/dx2]
    """
    psi, dpsi = y
    return [dpsi, 2.0 * (potential(x) - energy) * psi]

# --- energy state counter ---
def count_nodes(psi, Vx, E):
    """
    Counts nodes only where E > V(x) to avoid ghost nodes in diverging tails.
    """
    # 1. Identify the classically allowed region
    allowed_mask = Vx < E

    # 2. If energy is too low (no allowed region), return 0
    if not np.any(allowed_mask):
        return 0

    # 3. Filter wavefunction to this region
    psi_safe = psi[allowed_mask]

    # 4. Remove exact zeros to prevent double-counting crossings
    signs = np.sign(psi_safe)
    signs = signs[signs != 0]

    # 5. Count sign changes
    if len(signs) < 2:
        return 0
    return np.sum(np.abs(np.diff(signs)) > 0)

## The Shooting Method
For symmetric potentials, i.e., $V(x) = V(-x)$, the eigenstates have definite **parity**. Instead of shooting from $-\infty$ to $+\infty$, we exploit symmetry and shoot from the center ($x=0$) outwards.
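For these bound states, parity and node count go together. Counting from $n = 0$, the $n$-th state has $n$ nodes, and the parity alternates starting with an even ground state. Even-parity states therefore have an even number of nodes, and odd-parity states an odd number.

For the harmonic oscillator the eigenstates are known in closed form: $\psi_n \propto H_n(x)\,e^{-x^2/2}$ with $E_n = n + 1/2$. This lets us check the rule directly. The sketch below uses NumPy's physicists' Hermite polynomials (`numpy.polynomial.hermite.hermval`). It also uses a node counter modeled on the notebook's `count_nodes`; `count_nodes_allowed` is a hypothetical helper defined only for this illustration.

```python
import numpy as np
from numpy.polynomial.hermite import hermval  # physicists' Hermite polynomials H_n

def count_nodes_allowed(psi, Vx, E):
    # Same idea as the notebook's count_nodes: sign changes only where V(x) < E
    signs = np.sign(psi[Vx < E])
    signs = signs[signs != 0]  # drop exact zeros, e.g. psi(0) = 0 for odd states
    return int(np.sum(signs[1:] != signs[:-1]))

x = np.linspace(-6.0, 6.0, 2401)  # symmetric grid
Vx = 0.5 * x**2
for n in range(5):
    coeffs = np.zeros(n + 1)
    coeffs[n] = 1.0  # selects H_n
    psi = hermval(x, coeffs) * np.exp(-x**2 / 2)  # unnormalized eigenstate
    E = n + 0.5
    parity = "even" if np.allclose(psi[::-1], psi) else "odd"
    print(n, parity, count_nodes_allowed(psi, Vx, E))
```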

We choose the initial conditions at $x=0$ according to the parity we are looking for:


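1.  **Even Parity ($\psi(x) = \psi(-x)$):**
    * Condition: $\psi(0) = 1, \quad \psi'(0) = 0$
    * (The derivative of an even function is odd, so the slope must vanish at the center.)

2.  **Odd Parity ($\psi(x) = -\psi(-x)$):**
    * Condition: $\psi(0) = 0, \quad \psi'(0) = 1$
    * (Setting $x=0$ in $\psi(x) = -\psi(-x)$ gives $\psi(0) = -\psi(0)$, so the value must vanish at the center.)

Because the TISE is linear, the nonzero starting values ($\psi(0)=1$ for even, $\psi'(0)=1$ for odd) only fix the overall normalization. They do not affect whether $E$ is an eigenvalue.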
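Why is mirroring legitimate? For an even potential, the initial conditions at $x=0$ fix the parity of the solution for *any* $E$, not only at eigenvalues.

The minimal sketch below assumes $V = x^2/2$ and a non-eigenvalue $E = 1$. It integrates outward in both directions, which works because `solve_ivp` accepts a decreasing `t_span` and `t_eval`. It then checks that $\psi(-x) = \pm\psi(x)$. The helper `integrate_from_center` is hypothetical.

```python
import numpy as np
from scipy.integrate import solve_ivp

def integrate_from_center(E, y0, x_end, n=201):
    # Integrate psi'' = 2(V - E) psi for V = x^2/2 from x = 0 to x_end (may be negative)
    rhs = lambda x, y: [y[1], 2.0 * (0.5 * x**2 - E) * y[0]]
    x = np.linspace(0.0, x_end, n)  # decreasing when x_end < 0, as solve_ivp requires
    sol = solve_ivp(rhs, (0.0, x_end), y0, t_eval=x, rtol=1e-10, atol=1e-12)
    return sol.y[0]

E = 1.0  # deliberately NOT an eigenvalue: the symmetry holds for any E
for parity, y0, sign in (("even", [1.0, 0.0], +1), ("odd", [0.0, 1.0], -1)):
    right = integrate_from_center(E, y0, +3.0)  # psi at x = 0 ... 3
    left = integrate_from_center(E, y0, -3.0)   # psi at x = 0 ... -3
    # Mirroring is valid if psi(-x) = sign * psi(x) at every sampled |x|
    print(parity, np.allclose(left, sign * right, rtol=1e-6, atol=1e-9))
```

This is why the shooting routine only needs to integrate over $x \ge 0$. One side effect of its concatenation is that $x = 0$ appears twice in the mirrored grid. This is harmless for plotting and for `count_nodes`, which drops exact zeros.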

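The code below integrates from $x=0$ to $x=L$ and then mirrors the result to show the full wavefunction. The value $\psi(L)$ at the far end serves as the residual: for a bound state it should be close to zero.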

In [None]:
# --- 3. The Shooting Routine ---
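def shoot_symmetric(energy, potential_key, parity, L=6.0):
    """
    Integrate the TISE from x = 0 to x = L with parity-based initial conditions,
    mirror the result onto [-L, L], and return (x, psi, V(x), psi(L)).
    """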
    pot = potentials[potential_key]

    # Apply Boundary Conditions at x=0
    if parity == "even":
        y0 = [1.0, 0.0]  # Even: Flat slope at center
    else:
        y0 = [0.0, 1.0]  # Odd: Zero value at center

    # Integrate from 0 to L
    sol = solve_ivp(
        schrodinger_rhs,
        (0.0, L),
        y0=y0,
        args=(energy, pot),
        dense_output=True,
        max_step=0.05,
    )

    # Generate arrays for plotting
    x_pos = np.linspace(0.0, L, 500)
    psi_pos = sol.sol(x_pos)[0]
    Vx_pos = pot(x_pos)

    # Mirror the results to show the full domain (-L to +L)
    x = np.concatenate((-x_pos[::-1], x_pos))
    Vx = np.concatenate((Vx_pos[::-1], Vx_pos))

    if parity == "even":
        psi = np.concatenate((psi_pos[::-1], psi_pos))
    else:
        psi = np.concatenate((-psi_pos[::-1], psi_pos)) # Anti-symmetric flip

    # The residual is the value of psi at the far boundary L.
    # For a bound state, this should be effectively 0.
    residual = psi[-1]

    return x, psi, Vx, residual

## Widgets and Plot Setup

Here we define the input widgets and `update_plot`, which re-runs the shooting and redraws the figure whenever a widget value changes.


In [None]:
# --- Widgets ---
E_input = BoundedFloatText(
    value=0.5, min=0.0, max=1.0E6, step=0.05,
    description="Energy (E):", layout=Layout(width="50%")
)

potential_dropdown = Dropdown(
    options=list(potentials.keys()),
    value="Harmonic oscillator",
    description="Potential:", layout=Layout(width="50%")
)

parity_dropdown = Dropdown(
    options=[
        ("Auto (Find best match)", "auto"),
        ("Even: ψ(0)=1", "even"),
        ("Odd:  ψ(0)=0", "odd")
    ],
    value="even",
    description="Parity:", layout=Layout(width="50%")
)

vis_scaler = FloatSlider(
    value=1.0,
    min=1.0,
    max=50.0,
    step=0.1,
    description='Scaling',
    disabled=False,
    continuous_update=False,
    orientation='vertical',
    readout=True,
    readout_format='.1f',
)

out = Output()

In [None]:
# --- Plotting Logic ---
def update_plot(change=None):
    with out:
        out.clear_output(wait=True)
        E = E_input.value
        pot_name = potential_dropdown.value
        parity_mode = parity_dropdown.value

        # Vertical scale factor for drawing psi, read from the Scaling slider
        vis_scale = vis_scaler.value

        # --- AUTO LOGIC ---
        if parity_mode == "auto":
            x_even, psi_even, Vx_even, res_even = shoot_symmetric(E, pot_name, "even")
            x_odd,  psi_odd,  Vx_odd,  res_odd  = shoot_symmetric(E, pot_name, "odd")

            if abs(res_even) < abs(res_odd):
                x, psi, Vx, residual = x_even, psi_even, Vx_even, res_even
                chosen_parity = "even"
            else:
                x, psi, Vx, residual = x_odd, psi_odd, Vx_odd, res_odd
                chosen_parity = "odd"

            title_parity_info = f"(Parity: {chosen_parity.capitalize()})"
        else:
            x, psi, Vx, residual = shoot_symmetric(E, pot_name, parity_mode)
            chosen_parity = parity_mode
            title_parity_info = ""

        n_nodes = count_nodes(psi, Vx, E)

        if abs(residual) < 0.5:
            psi_label = rf"$\psi(x)$ ($n={n_nodes}$)"
        else:
            psi_label = rf"$\psi(x)$"

        # --- PLOTTING ---
        fig, ax = plt.subplots(figsize=(9, 5))

        ax.plot(x, Vx, 'k-', lw=2, alpha=0.4, label=r"$V(x)$")
        ax.fill_between(x, Vx, color='gray', alpha=0.1)

        ax.axhline(E, color='tab:red', linestyle='--', lw=1.5, label=r"Energy $E$")

        # Use the variable defined from the slider above
        psi_to_plot = np.clip(psi * vis_scale + E, -1.0E5, 1.0E5)

        ax.plot(x, psi_to_plot, 'b-', lw=2, label=psi_label)
        ax.axhline(E, color='b', alpha=0.1)

        if chosen_parity == 'even':
            ax.plot(0, E + vis_scale, 'bo', label=r"BC: $\psi(0)=1$")
        else:
            ax.plot(0, E, 'bo', markerfacecolor='white', label=r"BC: $\psi(0)=0$")

        V_max = np.max(Vx)
        if E < V_max:
            ax.set_ylim(-V_max - 2, V_max + 2)
        else:
            ax.set_ylim(-E - 10, E + 10)

        if pot_name == "Finite square well":
            ax.set_xlim(-2, 2)
        else:
            ax.set_xlim(-6, 6)

        ax.set_xlabel("Position $x$")
        separator = " " if title_parity_info else ""
        ax.set_title(f"Residual: {residual:.2e} | {title_parity_info}{separator}")

        ax.legend(loc='upper left', bbox_to_anchor=(1.02, 1), borderaxespad=0, frameon=True)

        plt.tight_layout()
        plt.show()

        display(Markdown('<br>'))
        if abs(residual) < 0.1:
            display(Markdown(f"✅ **Eigenstate found!** The wavefunction dies off at the edges."))
        elif abs(residual) > 100:
            display(Markdown(f"❌ **Diverging.** The tail is exploding. Try changing $E$."))
        else:
            display(Markdown(f"⚠️ **Close?** The tail is non-zero. Tune $E$ slightly."))

# Link Widgets
E_input.observe(update_plot, names="value")
potential_dropdown.observe(update_plot, names="value")
parity_dropdown.observe(update_plot, names="value")
vis_scaler.observe(update_plot, names="value")
# Slider Layout
vis_scaler.layout = Layout(
    margin='0px 0px 0px 20px',  # Add 20px space to the left
    height='300px'              # Set a fixed height so it looks neat
)
# Inputs Layout
input_controls = VBox([potential_dropdown, parity_dropdown, E_input])
input_controls.layout = Layout(
    width='60%',                  # Narrower than the full cell width
    margin='0px 0px 20px 0px',    # Bottom gap between the inputs and the plot
    overflow='visible'
)

## Launch Interface

**Instructions:**
1.  **Select Parity:**
    * Choose **Even** ($\psi(0)=1$) to look for the ground state (no nodes, labeled $n=0$ in the plot) or other symmetric states.
    * Choose **Odd** ($\psi(0)=0$) to look for the first excited state ($n=1$) or other antisymmetric states.
    * Use **Auto** mode if you are stuck. It shoots with both parities and shows the one whose tail is closer to zero.
2.  **Tune Energy:** Change $E$ until the tail of the wavefunction (blue line) approaches the red dashed line at height $E$, i.e., until $\psi(x) \to 0$ far from the center.
    * If the tail shoots off to $+\infty$ or $-\infty$, you are not at an eigenvalue. For a fixed parity, the direction of divergence flips each time $E$ crosses an eigenvalue. Two energies whose tails diverge in opposite directions therefore bracket one.
    * **Goal:** Minimize the magnitude of the "Residual" shown in the plot title.
    * Use the **Scaling** slider to magnify $\psi(x)$ if the wave is hard to see.
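Tuning $E$ by hand is essentially a manual bisection on the sign of the tail. A minimal sketch of automating it is shown below, under these assumptions:

* the potential is the harmonic oscillator $V = x^2/2$;
* the range is $L = 5$;
* the root finder is `scipy.optimize.brentq` (Brent's method, a bracketing root finder, rather than plain bisection).

The helper `tail` is hypothetical. The brackets $[0.3, 0.7]$ (even) and $[1.3, 1.7]$ (odd) are chosen so that each contains exactly one eigenvalue of that parity.

```python
import numpy as np
from scipy.integrate import solve_ivp
from scipy.optimize import brentq

def tail(E, parity, L=5.0):
    """psi(L) for V = x^2/2 using the notebook's parity-based initial conditions."""
    y0 = [1.0, 0.0] if parity == "even" else [0.0, 1.0]
    rhs = lambda x, y: [y[1], 2.0 * (0.5 * x**2 - E) * y[0]]
    sol = solve_ivp(rhs, (0.0, L), y0, rtol=1e-10, atol=1e-12)
    return sol.y[0, -1]

# psi(L) changes sign as E crosses an eigenvalue of the chosen parity,
# so a sign-changing bracket [E_lo, E_hi] can be refined by a root finder.
E0 = brentq(tail, 0.3, 0.7, args=("even",), xtol=1e-10)
E1 = brentq(tail, 1.3, 1.7, args=("odd",), xtol=1e-10)
print(f"E0 = {E0:.6f}, E1 = {E1:.6f}")  # compare with E_n = n + 1/2
```

A bracketing method is a natural fit here because the residual is continuous in $E$ and changes sign at each same-parity eigenvalue. That sign change is exactly the cue the manual procedure relies on.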


> Tip: You can click the `...` button on the output cell and then select `View output fullscreen`.

In [None]:
# Create a horizontal container for the plot and the amplitude scaler
plot_area = HBox([out, vis_scaler], layout=Layout(align_items='center'))
# Display the main controls on top, and the plot area below
main_ui = VBox(
    [input_controls, plot_area],
    layout=Layout(
        width='100%',
        align_items='center',
        padding='100px 0px'
    )
)
display(main_ui)
update_plot()