In [None]:
import math
import copy
from functools import partial
import numpy as np
import scipy.integrate as si
import matplotlib.pyplot as plt
from scipy.stats import linregress

In [None]:
class Snake:
    """
    Planar "snake" centreline obtained by solving a boundary-value problem.

    The state on the normalised arc-length mesh ``s in [0, 1]`` is
    ``y[0] = theta(s)`` (tangent angle, radians) and ``y[1] = dtheta/ds``.
    The total length ``l`` is an unknown parameter fitted by
    ``scipy.integrate.solve_bvp``. The boundary conditions (see ``bc``) are
    theta(0) = pi/2, theta(1) = 0 and dtheta/ds(0) = 0.

    TODO: add plotting and boundary-condition specification methods.

    Parameters
    ----------
    B : float
        Coefficient dividing the non-local term in ``fun_k``. It also sets the
        length scale ``lg``. Presumably a bending stiffness; confirm.
    c : float
        Prefactor of the Gaussian-smoothed (non-local) term. Also scales the
        work term in ``cost``.
    l : float
        Initial guess for the length parameter. ``solve`` overwrites it.
    sigma : float
        Width of the Gaussian kernel in ``fun_k``.
    rho : float
        Coefficient entering ``lg = (B / (rho * |g|)) ** (1/3)``.
    mesh_points : int
        Number of nodes in the initial uniform mesh on [0, 1].
    """

    def __init__(self, B=1, c=1, l=1, sigma=0.1, rho=1, mesh_points=50):
        # Gravitational acceleration. It is negative, so -B / (rho * g) > 0 below.
        self.g = -9.8
        self.B = B
        self.c = c
        self.bt = self.B / self.c  # NOTE(review): not used anywhere in this class
        self.sigma = sigma
        self.rho = rho
        # Length scale (B / (rho * |g|))**(1/3). The ODE uses (l / lg)**3.
        self.lg = (-self.B / (self.rho * self.g)) ** (1 / 3)
        self.l = l
        self.x = np.linspace(0, 1, mesh_points)
        self.y = np.zeros((2, self.x.size))
        self.make_initial_guesses()

    def make_initial_guesses(self):
        """
        Seed ``self.y`` so that the boundary nodes match the conditions in ``bc``.

        Only boundary nodes are set. Interior nodes keep their current values,
        which are zeros straight after ``__init__``.
        """
        self.y[0, 0] = np.pi / 2  # theta(0) = pi/2
        self.y[0, -1] = 0  # theta(1) = 0
        self.y[1, 0] = 0  # dtheta/ds(0) = 0
        # Alternative smooth initial guess (disabled):
        # dy_ds = np.exp(-((self.x - 0.75) ** 2) / 0.010)
        # y_guess = (
        #     np.pi / 2 * si.cumulative_trapezoid(dy_ds, -self.x, initial=0) / si.trapezoid(dy_ds, self.x)
        #     + np.pi / 2
        # )  # want s=0 to be at pi/2 so need to do integral backwards
        # self.y[0, :] = y_guess
        # self.y[1, :] = dy_ds

    def fun_k(self, x, y, p):
        """
        Right-hand side of the first-order system, in the form ``solve_bvp`` expects.

        Parameters
        ----------
        x : ndarray, shape (m,)
            Mesh points (normalised arc length) at which to evaluate.
        y : ndarray, shape (2, m)
            ``y[0]`` is theta and ``y[1]`` is dtheta/ds at each mesh point.
        p : sequence of length 1
            ``p[0]`` is the current value of the length ``l``.

        Returns
        -------
        ndarray, shape (2, m)
            ``[dtheta/ds, d2theta/ds2]`` at each mesh point.
        """
        l = p[0]
        norm = self.c / (np.sqrt(2 * np.pi) * self.sigma)
        two_var = 2 * self.sigma ** 2
        # Gaussian-smoothed dtheta/ds. At each point s it is computed by
        # trapezoidal quadrature over the same mesh x that was passed in.
        # NOTE(review): solve_bvp also calls this at interval midpoints and
        # estimates its Jacobian point by point. This non-local term is
        # therefore only approximated there; confirm that convergence is acceptable.
        ma = np.array(
            [norm * si.trapezoid(y[1] * np.exp(-((x - s) ** 2) / two_var), x) for s in x]
        )
        grad_ma = np.gradient(ma, x, edge_order=2)

        dtheta_ds = y[1]
        # theta'' = -(1/B) d(ma)/ds - s (l/lg)^3 cos(theta)
        d2theta_ds2 = -grad_ma / self.B - x * (l / self.lg) ** 3 * np.cos(y[0])
        return np.vstack((dtheta_ds, d2theta_ds2))

    def bc(self, ya, yb, p):
        """
        Boundary residuals: theta(0) = pi/2, theta(1) = 0 and dtheta/ds(0) = 0.

        There are three residuals because there are two states plus one unknown
        parameter (``l``).
        """
        return np.array([ya[0] - np.pi / 2, yb[0], ya[1]])

    def solve(self, verbose=0, max_nodes=1000):
        """
        Solve the BVP and compute centreline coordinates.

        Sets ``self.sol`` (the ``solve_bvp`` result) and updates ``self.l`` to
        the fitted length. Also sets ``self.x_pos`` / ``self.y_pos``, shifted so
        that the s = 1 end sits at the origin. If the solver does not report
        success, it prints a message and still processes the returned result.

        Parameters
        ----------
        verbose : int
            Passed through to ``solve_bvp`` (0, 1 or 2).
        max_nodes : int or float
            Upper bound on mesh nodes. Cast to ``int``, since ``solve_bvp``
            documents an integer.
        """
        self.sol = si.solve_bvp(
            fun=self.fun_k,
            bc=self.bc,
            x=self.x,
            y=self.y,
            p=[self.l],
            max_nodes=int(max_nodes),
            verbose=verbose,
        )
        # Check this instance's result, not a notebook-level global.
        if self.sol.status != 0:
            print("The answer didn't converge!", self.sol.message)
        self.l = self.sol.p[0]
        theta = self.sol.y[0]
        # cumulative_trapezoid replaces cumtrapz, which was removed in SciPy 1.14.
        self.x_pos = si.cumulative_trapezoid(self.l * np.cos(theta), self.sol.x, initial=0)
        self.x_pos -= self.x_pos[-1]
        self.y_pos = si.cumulative_trapezoid(self.l * np.sin(theta), self.sol.x, initial=0)
        self.y_pos -= self.y_pos[-1]

    def cost(self, alpha=0.5):
        """
        Weighted trade-off between end height and a "work" term.

        Requires ``solve()`` to have been called first. Stores the two weighted
        parts in ``self.height_cost`` and ``self.work_cost``.

        Parameters
        ----------
        alpha : float
            Weight on the height term. The work term gets ``1 - alpha``.

        Returns
        -------
        float
            ``-alpha * y_pos[0] / l + (1 - alpha) * c / (2B) * integral (dtheta/ds)^2 ds``.
        """
        # Height of the s = 0 end relative to the s = 1 end, normalised by l.
        height_term = self.y_pos[0] / self.l
        self.height_cost = -alpha * height_term
        work_term = self.c / (2 * self.B) * si.trapezoid(self.sol.y[1] ** 2, self.sol.x)
        self.work_cost = (1 - alpha) * work_term
        return self.height_cost + self.work_cost

In [None]:
# Solve for the equilibrium shape with a Gaussian kernel width of 0.1.
# max_nodes is an integer cap on solve_bvp's mesh size.
snake = Snake(sigma=0.1)
snake.solve(verbose=2, max_nodes=5000)

In [None]:
# Centreline shape (the s = 1 end is at the origin before negation).
fig, ax = plt.subplots()
ax.plot(-snake.x_pos, -snake.y_pos)
ax.set_aspect("equal")  # keep the geometry undistorted
ax.set(xlabel="-x_pos", ylabel="-y_pos", title="Snake centreline");

In [None]:
# Tangent angle along the body.
fig, ax = plt.subplots()
ax.plot(snake.sol.x, snake.sol.y[0])
ax.set(xlabel="s (normalised arc length)", ylabel=r"$\theta$ (rad)", title="Tangent angle");

In [None]:
# Rate of change of the tangent angle along the body.
fig, ax = plt.subplots()
ax.plot(snake.sol.x, snake.sol.y[1])
ax.set(xlabel="s (normalised arc length)", ylabel=r"$d\theta/ds$", title=r"$d\theta/ds$ along the body");