In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn


In [11]:
def generate_exponential_function(intercept: float, base: float):
    """
    Builds a shifted exponential curve of the form:
    y = intercept + (base ** x)

    Note that ``intercept`` is *added* to the power term rather than
    multiplied with it, so the curve evaluates to ``intercept + 1`` at
    ``x = 0``.

    Parameters:
    intercept (float): The additive offset applied to ``base ** x``.
    base (float): The base of the exponential.

    Returns:
    function: A one-argument function that computes y for a given x.
    """
    # The closure captures `intercept` and `base`; only `x` varies per call.
    return lambda x: intercept + base ** x


def compile_punishment_vals(
    num_resources=5,
    num_steps=10,
    exponentialness=0.12,
    intercept_increase_speed=2,
    resource_multipliers=None,
):
    """Generates a table of punishment probabilities from a family of exponential curves.

    For every step ``i`` in ``range(num_steps)`` a curve is built with
    ``generate_exponential_function(1 + i / intercept_increase_speed,
    exponentialness * i + 2)`` and evaluated at ``k = 0 .. num_resources - 1``.
    The resulting table is divided by its largest entry (truncated to an
    integer), scaled column-wise by ``resource_multipliers`` and finally
    clipped to ``[0, 1]``.

    Args:
        num_resources (int): Number of resources to consider (columns).
        num_steps (int): Number of steps in the exponential growth (rows).
        exponentialness (float): Parameter controlling the steepness of the
            curve; the base of step ``i`` is ``exponentialness * i + 2``.
        intercept_increase_speed (float): Parameter controlling the speed of
            the increase of the intercept; the intercept of step ``i`` is
            ``1 + i / intercept_increase_speed``.
        resource_multipliers (sequence of float, optional): One scaling
            factor per resource column. When omitted, the built-in profile
            ``(1, 1.1, 1.25, 1.45, 1.95)`` is used; if ``num_resources`` is
            not 5 that profile is linearly resampled to ``num_resources``
            evenly spaced points so the endpoints (1 and 1.95) are kept.

    Returns:
        np.ndarray: Array of shape ``(num_steps, num_resources)`` clipped to
        the range ``[0, 1]``.

    Raises:
        ValueError: If ``num_steps`` or ``num_resources`` is smaller than 1,
            or if ``resource_multipliers`` does not contain exactly
            ``num_resources`` values.
    """
    if num_steps < 1 or num_resources < 1:
        raise ValueError(
            "num_steps and num_resources must both be at least 1, got "
            f"num_steps={num_steps}, num_resources={num_resources}"
        )

    default_multipliers = (1, 1.1, 1.25, 1.45, 1.95)
    if resource_multipliers is None:
        if num_resources == len(default_multipliers):
            # Use the exact original constants so the default output is unchanged.
            multipliers = np.array(default_multipliers, dtype=float)
        else:
            # Stretch/shrink the built-in profile to the requested column count.
            multipliers = np.interp(
                np.linspace(0.0, 1.0, num_resources),
                np.linspace(0.0, 1.0, len(default_multipliers)),
                default_multipliers,
            )
    else:
        multipliers = np.asarray(resource_multipliers, dtype=float)
        if multipliers.shape != (num_resources,):
            raise ValueError(
                f"resource_multipliers must contain exactly {num_resources} "
                f"values, got shape {multipliers.shape}"
            )

    # One row per step: each step gets its own curve, evaluated at every
    # resource index.
    rows = []
    for i in range(num_steps):
        exp_func = generate_exponential_function(
            1 + (i / intercept_increase_speed), exponentialness * i + 2
        )
        rows.append([exp_func(k) for k in range(num_resources)])
    vals = np.array(rows)

    # The maximum is truncated to an integer before normalising, so the
    # largest normalised entries can slightly exceed 1; the final clip
    # brings them back into range.
    max_val = int(np.max(vals))

    punishment_probs = vals / max_val
    punishment_probs = punishment_probs * multipliers
    punishment_probs = np.clip(punishment_probs, 0.0, 1.0)

    return punishment_probs

In [12]:
# Build the punishment-probability table using the function's default arguments.
x = compile_punishment_vals()

In [16]:
# Bare expression: the notebook renders the generated table as this cell's output.
x

array([[0.02105263, 0.03473684, 0.06578947, 0.13736842, 0.34894737],
       [0.02631579, 0.04191579, 0.07887368, 0.16832406, 0.44541349],
       [0.03157895, 0.04909474, 0.09233684, 0.20207542, 0.55782952],
       [0.03684211, 0.05627368, 0.10617895, 0.23878075, 0.68805122],
       [0.04210526, 0.06345263, 0.1204    , 0.2785983 , 0.83803652],
       [0.04736842, 0.07063158, 0.135     , 0.32168632, 1.        ],
       [0.05263158, 0.07781053, 0.14997895, 0.36820305, 1.        ],
       [0.05789474, 0.08498947, 0.16533684, 0.41830675, 1.        ],
       [0.06315789, 0.09216842, 0.18107368, 0.47215565, 1.        ],
       [0.06842105, 0.09934737, 0.19718947, 0.52990803, 1.        ]])

In [10]:
# Hand-specified punishment-probability table: 10 rows (labelled s = 0..9,
# presumably the step index -- confirm) by 5 columns. Column 0 climbs from
# 0.50 to 0.95 in increments of 0.05; the remaining columns never exceed 0.20.
predefined_punishment_probs = np.array(
    [
        [0.50, 0.00, 0.00, 0.00, 0.00],  # s = 0
        [0.55, 0.05, 0.00, 0.00, 0.00],  # s = 1
        [0.60, 0.10, 0.00, 0.00, 0.00],  # s = 2
        [0.65, 0.10, 0.05, 0.00, 0.00],  # s = 3
        [0.70, 0.10, 0.10, 0.00, 0.00],  # s = 4
        [0.75, 0.10, 0.10, 0.05, 0.00],  # s = 5
        [0.80, 0.10, 0.10, 0.10, 0.00],  # s = 6
        [0.85, 0.15, 0.10, 0.10, 0.05],  # s = 7
        [0.90, 0.15, 0.15, 0.10, 0.10],  # s = 8
        [0.95, 0.20, 0.15, 0.15, 0.10],  # s = 9
    ]
)
# Sanity check of the table dimensions (rows, columns).
print(predefined_punishment_probs.shape)

(10, 5)
