In [428]:
from collections.abc import Callable

import torch
from joblib import Parallel, delayed

In [585]:
class Random_walk:
    """Monte-Carlo estimate of first-passage times of a 2D lattice random walk.

    The walker starts at the origin and at every step moves ``step_size``
    along +x, +y, -x or -y with equal probability. Each entry of ``bounds``
    describes one food distribution; a single walk continues until every
    distribution has been reached, recording the first-hit step count of each.

    Args:
        bounds: Functions mapping a 1D tensor of x-coordinates (length m) to an
            (m, 2) tensor of [lower, upper] y-limits. A position (x, y) counts
            as food when ``y <= lower`` or ``y >= upper``.
        n_workers: Number of joblib workers used by :meth:`simulate`.
        step_size: Length of one step (10, i.e. 10 cm, by default).
    """

    def __init__(self, bounds: list[Callable[[torch.Tensor], torch.Tensor]],
                 n_workers: int = 4, step_size: float = 10):
        self.bounds, self.nb = bounds, len(bounds)
        self.n_workers = n_workers
        self.step_size = step_size

    def _sample_steps(self, n: int = 100, generator=None) -> torch.Tensor:
        """Return the walker's offsets after each of ``n`` random steps.

        The result is an (n, 2) tensor of cumulative displacements (row k is
        the offset after k + 1 steps), relative to wherever the chunk starts.
        ``generator`` is an optional ``torch.Generator``; the global torch RNG
        is used when it is ``None``.
        """
        s = self.step_size
        moves = torch.tensor([[s, 0], [0, s], [-s, 0], [0, -s]])
        # Draw a 1D index vector. A (1, n) shape followed by squeeze() would
        # also drop the step axis when n == 1 and break the cumsum below.
        choice = torch.randint(4, (n,), generator=generator)
        return moves[choice].cumsum(dim=0)

    def _check_bounds(self, path: torch.Tensor,
                      found: torch.Tensor) -> torch.Tensor:
        """Locate, for each distribution not yet found, its first hit on ``path``.

        Args:
            path: (m, 2) tensor of consecutive (x, y) positions.
            found: Boolean tensor of length ``nb``; distributions already
                found are skipped.

        Returns:
            Int64 tensor of length ``nb`` with the 0-based index of the first
            position in ``path`` that reaches each distribution, or -1 when it
            is not reached (or was skipped because it is already found).
        """
        breaches = torch.full((self.nb,), -1, dtype=torch.long)
        xs, ys = path[:, 0], path[:, 1]
        for i in torch.argwhere(~found).ravel().tolist():
            limit = self.bounds[i](xs)
            hit = torch.argwhere((ys <= limit[:, 0]) | (ys >= limit[:, 1]))
            if hit.numel() > 0:
                breaches[i] = hit.min()
        return breaches

    def _walk(self, seed=None, chunk_size: int = 100) -> torch.Tensor:
        """Run one walk; return the number of steps to first reach each distribution.

        Steps are sampled in chunks of ``chunk_size`` until every distribution
        in ``self.bounds`` has been reached. ``seed`` (optional int) makes the
        walk reproducible by drawing from a private ``torch.Generator``; with
        ``None`` the global torch RNG is used.

        Note: the loop only ends once every distribution has been hit, so a
        boundary with a heavy-tailed (or infinite-mean) hitting time can make
        a single walk run for a very long time.

        Raises:
            ValueError: if ``chunk_size < 1`` (the loop would never advance).
        """
        if chunk_size < 1:
            raise ValueError("chunk_size must be a positive integer")
        generator = None if seed is None else torch.Generator().manual_seed(int(seed))

        steps = torch.zeros(self.nb, dtype=torch.long)
        found = torch.zeros(self.nb, dtype=torch.bool)
        position = torch.zeros(2)  # the walk starts at the origin
        taken = 0                  # steps sampled before the current chunk
        while not found.all():
            path = self._sample_steps(chunk_size, generator) + position
            first_hit = self._check_bounds(path, found)
            newfound = ~found & (first_hit > -1)
            # first_hit is a 0-based index into this chunk; +1 converts it to
            # a step count (index 0 is the first step of the chunk).
            steps[newfound] = first_hit[newfound] + taken + 1
            found |= newfound
            taken += chunk_size
            position = path[-1]
        return steps

    def simulate(self, iterations=1_000, seed=None) -> torch.Tensor:
        """Return the average number of steps to reach each food distribution.

        Runs ``iterations`` independent walks with joblib and averages the
        per-distribution step counts.

        Args:
            iterations: Number of walks (must be >= 1).
            seed: Optional int. When given, one seed per walk is derived from
                it, so the result does not depend on which worker runs which
                walk. ``None`` uses each worker's global torch RNG.

        Returns:
            Float tensor of length ``nb`` with the mean step count per
            distribution.

        Raises:
            ValueError: if ``iterations < 1``.
        """
        if iterations < 1:
            raise ValueError("iterations must be a positive integer")
        if seed is None:
            seeds = [None] * iterations
        else:
            gen = torch.Generator().manual_seed(int(seed))
            seeds = torch.randint(0, 2**62, (iterations,), generator=gen).tolist()
        steps = torch.vstack(Parallel(n_jobs=self.n_workers)(
            delayed(self._walk)(s) for s in seeds))
        return steps.float().mean(dim=0)

In [603]:
# Food distributions. Each entry maps a 1D tensor of x-coordinates (length m)
# to an (m, 2) tensor whose columns are the [lower, upper] y-limits at those x.
# A walker at (x, y) has reached food when y <= lower or y >= upper, so a row
# of [0, 0] marks the entire vertical line at that x as food.


def _square_limits(x: torch.Tensor) -> torch.Tensor:
    """a) Food on the lines x = ±20 and y = ±20 around the origin."""
    # Strictly between the vertical lines the limits are y = -20 / y = 20;
    # at or beyond x = ±20 the limits collapse to [0, 0] (always food).
    inside = torch.logical_and(x > -20, x < 20)
    return torch.where(inside[:, None], torch.tensor([-20, 20]), torch.tensor(0))


def _diagonal_limits(x: torch.Tensor) -> torch.Tensor:
    """b) Food on the diagonal through (10, 0) and (0, 10), i.e. where x + y >= 10."""
    lower = -torch.tensor(torch.inf).repeat(x.size(0))  # no lower boundary
    return torch.vstack([lower, 10 - x]).T


def _ellipse_limits(x: torch.Tensor) -> torch.Tensor:
    """c) Food outside the ellipse ((x - 2.5)/30)^2 + ((y - 2.5)/40)^2 < 1."""
    half_height = 40 * torch.sqrt(1 - ((x - 2.5) / 30) ** 2)
    # Where |x - 2.5| > 30 the sqrt is NaN; nan_to_num(0) turns those rows
    # into [0, 0], i.e. the whole vertical line is food.
    return torch.vstack([2.5 - half_height, 2.5 + half_height]).T.nan_to_num(0)


food_distributions = [_square_limits, _diagonal_limits, _ellipse_limits]

# `n` walks per estimate, `bootstrap_n` independent repetitions of that estimate.
n, bootstrap_n = 100, 100
random_walk = Random_walk(food_distributions, n_workers=8)
# One row per repetition, one column per food distribution.
averages = torch.stack([random_walk.simulate(n) for _ in range(bootstrap_n)])

In [604]:
# For each food distribution: mean of the repeated n-walk averages, their
# spread (std across repetitions) and their median. torch.median returns the
# *lower* middle value for an even number of elements, so the conventional
# median is computed with quantile(0.5). Units are printed as seconds
# (presumably one step per second — confirm against the problem statement).
print("Average time to reach the food:")
for i, col in enumerate(averages.T):
    label = chr(ord("a") + i)
    print(f"  {label}) {col.mean():.3f} +- {col.std():.3f} seconds "
          f"(median of bootstrap averages: {torch.quantile(col, 0.5):.3f} seconds)")

Average time to reach the food:
  a) 4.520 +- 0.280 seconds (median of bootstrap averages: 4.490 seconds)
  b) 9337.585 +- 72063.906 seconds (median of bootstrap averages: 201.560 seconds)
  c) 13.930 +- 0.978 seconds (median of bootstrap averages: 13.920 seconds)
