In [1]:
import pandas as pd

In [2]:
df = pd.read_csv("data.csv")

In [3]:
df

Unnamed: 0,Al,Ca,Fe,Mg,Mn,Ni,overpotential
0,0.123204,0.106073,0.528617,0.203194,0.003486,0.035425,1.7122
1,0.036817,0.044191,0.743975,0.107923,0.001843,0.065252,1.7164
2,0.067023,0.150251,0.470664,0.291139,0.007421,0.013501,1.7228
3,0.059822,0.088446,0.430872,0.396597,0.005039,0.019225,1.7250
4,0.060222,0.101457,0.405555,0.414243,0.005944,0.012578,1.7264
...,...,...,...,...,...,...,...
238,0.350899,0.204542,0.258098,0.177185,0.003356,0.005919,2.2541
239,0.212779,0.187240,0.330747,0.257557,0.005539,0.006137,2.2543
240,0.165565,0.185401,0.360911,0.275456,0.006418,0.006249,2.2612
241,0.065797,0.181515,0.424648,0.313278,0.008277,0.006484,2.2632


In [10]:
import torch

In [115]:

Metal = ['Al', 'Ca', 'Fe', 'Mg', 'Mn', 'Ni']
M_list = torch.tensor([
    [0.00149, 0.02, 0.0803, 2.759, 0.0779],
    [0.0055, 0.1374, 0.7697, 1.417, 1.142],
    [3.303, 1.379, 2.4, 0.97, 2.238],
    [0.00155, 5.167, 1.306, 0.744, 1.438],
    [0, 0.0288, 0.0521, 0.0132, 0.0541],
    [0.328, 0, 0, 0, 0]
], dtype=torch.float32).T

abcdelist = torch.tensor(
    [0.27105, 0.567, 0.5632, 0.935, 0.6885], dtype=torch.float32)

In [141]:
M_list

tensor([[1.4900e-03, 5.5000e-03, 3.3030e+00, 1.5500e-03, 0.0000e+00, 3.2800e-01],
        [2.0000e-02, 1.3740e-01, 1.3790e+00, 5.1670e+00, 2.8800e-02, 0.0000e+00],
        [8.0300e-02, 7.6970e-01, 2.4000e+00, 1.3060e+00, 5.2100e-02, 0.0000e+00],
        [2.7590e+00, 1.4170e+00, 9.7000e-01, 7.4400e-01, 1.3200e-02, 0.0000e+00],
        [7.7900e-02, 1.1420e+00, 2.2380e+00, 1.4380e+00, 5.4100e-02, 0.0000e+00]])

In [206]:
def reverse_transform(metal_ratios: torch.Tensor) -> torch.Tensor:
    """
    Given metal_ratios (for Fe, Mn, Ni, Ca, Mg, Al) as a tensor of shape [batch, 6],
    this function recovers the original meteorite ratios [A, B, C, D, E] using
    the method from the original GUI code.

    The steps are:
      1. Define the abundance matrix and compute the transfer_matrix as its transpose.
      2. Solve the system:
             transfer_matrix[0:5] * [A, B, C, D, E]^T = metal_ratios[:, 0:5]^T
         for each batch element.
      3. Normalize the solution so that the sum over A, B, C, D, E equals 1.
      4. Multiply elementwise by (10 * abcdelist).

    Args:
        metal_ratios: A torch.Tensor of shape [batch, 6] containing the metal proportions.

    Returns:
        A torch.Tensor of shape [batch, 5] containing the computed A, B, C, D, E values.
    """
    # Ensure metal_ratios is float type:
    metal_ratios = metal_ratios.to(torch.get_default_dtype())

    # Define abcdelist (scaling factors) as in the original code.
    abcdelist = torch.tensor([0.27105, 0.567, 0.5632, 0.935, 0.6885],
                             dtype=metal_ratios.dtype,
                             device=metal_ratios.device)

    # Define the abundance matrix as given.
    # Each row corresponds to a metal and there are 5 rows;
    # In the original code abundance is 5x6.
    abundance = torch.tensor([
        [3.303,    0,     0.328, 0.0055, 0.00155, 0.00149],
        [1.379, 0.0288,       0, 0.1374,   5.167,    0.02],
        [2.4,     0.0521,     0, 0.7697,   1.306,    0.0803],
        [0.97,    0.0132,     0, 1.417,    0.744,    2.759],
        [2.238,   0.0541,     0, 1.142,    1.438,    0.0779]
    ], dtype=metal_ratios.dtype, device=metal_ratios.device)

    # Compute the transfer_matrix as the transpose of abundance.
    # This results in a 6 x 5 matrix.
    transfer_matrix = abundance.T  # shape: (6,5)

    # We use only the first five metal values to solve for [A,B,C,D,E].
    # That is, we form a 5x5 system using:
    #     transfer_matrix[0:5, :]  (shape 5x5)
    # and the corresponding metal ratios:
    #     metal_ratios[:, 0:5]       (shape: [batch,5])
    select_idx = [2, 4, 5, 1, 3]

    # Gather the selected metal proportions.
    # metal_ratios_sel will have shape [batch, 5].
    metal_ratios_sel = metal_ratios[:, select_idx]
    metal_ratios = metal_ratios_sel

    T_mat = transfer_matrix[:5, :]  # shape: (5,5)

    batch = metal_ratios.shape[0]
    # Prepare an output tensor for meteorite ratios.
    meteorite_ratios = torch.empty(
        (batch, 5), dtype=metal_ratios.dtype, device=metal_ratios.device)

    # Solve the system for each batch element.
    # Since T_mat is the same for every batch element, we iterate over the batch.
    for i in range(batch):
        # Solve T_mat x = metal_ratios[i, 0:5]
        # Here, torch.linalg.solve expects the right-hand side to have shape (5,)
        sol = torch.linalg.solve(T_mat, metal_ratios[i, :5])
        meteorite_ratios[i] = sol

    # Normalize each solution so that the sum of [A,B,C,D,E] is 1.
    meteorite_ratios = meteorite_ratios / \
        meteorite_ratios.sum(dim=1, keepdim=True)

    # Multiply elementwise by 10 * abcdelist.
    # This step scales the solution.
    # we dont do it, because the We operate in the normalised cordinate system
    # meteorite_ratios = meteorite_ratios * (10 * abcdelist)

    return meteorite_ratios

In [207]:
metal_sampled = reverse_transform(torch.tensor(df[['Al', 'Ca', 'Fe', 'Mg', 'Mn', 'Ni']
                                                  ].to_numpy(), dtype=torch.float32))

In [210]:
metal_sampled  # order is ABCDE

tensor([[0.5000, 0.1000, 0.1000, 0.2000, 0.1000],
        [0.8000, 0.0500, 0.0500, 0.0500, 0.0500],
        [0.2000, 0.1000, 0.3000, 0.1000, 0.3000],
        ...,
        [0.1000, 0.1000, 0.1000, 0.3000, 0.4000],
        [0.1000, 0.1000, 0.1000, 0.1000, 0.6000],
        [0.1000, 0.1000, 0.1000, 0.2000, 0.5000]])

In [209]:
samples = [str(list(row.tolist())) for row in metal_sampled[:, 0:-1]]

In [200]:
# Create a DataFrame
df_out = pd.DataFrame({
    'samples': samples,
    'energies': df["overpotential"].to_list()
})

In [205]:
metal_sampled

tensor([[0.5000, 0.1000, 0.1000, 0.2000, 0.1000],
        [0.8000, 0.0500, 0.0500, 0.0500, 0.0500],
        [0.2000, 0.1000, 0.3000, 0.1000, 0.3000],
        ...,
        [0.1000, 0.1000, 0.1000, 0.3000, 0.4000],
        [0.1000, 0.1000, 0.1000, 0.1000, 0.6000],
        [0.1000, 0.1000, 0.1000, 0.2000, 0.5000]])

In [156]:
df_out

Unnamed: 0,samples,energies
0,"[0.4999999403953552, 0.10000000894069672, 0.10...",1.7122
1,"[0.7999998927116394, 0.05000000074505806, 0.05...",1.7164
2,"[0.19999998807907104, 0.09999999403953552, 0.3...",1.7228
3,"[0.30000001192092896, 0.30000001192092896, 0.1...",1.7250
4,"[0.19999994337558746, 0.29999998211860657, 0.3...",1.7264
...,...,...
238,"[0.10000009089708328, 0.050000034272670746, 0....",2.2541
239,"[0.1000000387430191, 0.10000003129243851, 0.09...",2.2543
240,"[0.10000003129243851, 0.10000002384185791, 0.0...",2.2612
241,"[0.09999994188547134, 0.09999997168779373, 0.1...",2.2632


In [158]:
df_out.to_csv("states_train_trimmed.csv")

In [325]:
import torch

# Define the conversion matrix M_list.
# Each metal's coefficients are provided as in your original code.
# We construct M_list with shape (5, 6) (rows correspond to A, B, C, D, E).
metal_names = ['Al', 'Ca', 'Fe', 'Mg', 'Mn', 'Ni']
Al_list = [0.00149, 0.02, 0.0803, 2.759, 0.0779]
Ca_list = [0.0055, 0.1374, 0.7697, 1.417, 1.142]
Fe_list = [3.303, 1.379, 2.4, 0.97, 2.238]
Mg_list = [0.00155, 5.167, 1.306, 0.744, 1.438]
Mn_list = [0, 0.0288, 0.0521, 0.0132, 0.0541]
Ni_list = [0.328, 0, 0, 0, 0]

M_list = torch.tensor([
    Al_list,
    Ca_list,
    Fe_list,
    Mg_list,
    Mn_list,
    Ni_list,
], dtype=torch.float).T  # shape (5, 6)

# Define abcdelist as a tensor of scaling factors for the five components.
# Replace these with your actual values.
# abcdelist = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0], dtype=torch.float)
abcdelist = torch.tensor(
    [0.27105, 0.567, 0.5632, 0.935, 0.6885], dtype=torch.float32)


def convert_batch(input_batch: torch.Tensor):
    """
    Convert a batch of ABCDE values to metal compositions.

    Each input vector should be of the form [A, B, C, D, E].
    If the last component (E) is less than 0.1, it is set to 0.1 and the row is renormalized.
    The function also computes an outline which is 0.1 - original_E when E was low.

    Args:
        input_batch (torch.Tensor): Tensor of shape (batch_size, 5) containing the [A, B, C, D, E] values.
        abcdelist (torch.Tensor): 1D tensor of shape (5,) used for scaling.
        M_list (torch.Tensor): Conversion matrix of shape (5, 6).

    Returns:
        converted (torch.Tensor): Tensor of shape (batch_size, 6) with normalized metal compositions.
        outline (torch.Tensor): Tensor of shape (batch_size,) with the outline adjustments.
    """
    # Ensure a copy of the input to avoid modifying the original tensor.
    x = input_batch  # shape: (batch_size, 5)
    x = x

    # Compute outline for entries where E < 0.1.
    # Save the original E values.
    original_E = x[:, 4]

    # Create outline tensor (for each sample: outline = max(0, 0.1 - original_E))
    outline = 0.1 - original_E

    # For rows where E < 0.1, set E to 0.1.
    mask = x[:, 4] < 0.1
    if mask.any():
        x[mask, 4] = 0.1
        # Renormalize each row so that the sum becomes 1.
        row_sums = x[mask].sum(dim=1, keepdim=True)
        x[mask] = x[mask] / row_sums

    # Now, perform the conversion:
    # 1. Divide each vector elementwise by abcdelist.
    #    Make sure abcdelist has shape (1,5) to allow broadcasting.
    # skipping this step because we aer working with normalised ABCDE
    x_scaled = x
    # / abcdelist.view(1, -1)

    # 2. Multiply by the conversion matrix.
    #    The multiplication: (batch_size, 5) @ (5, 6) results in (batch_size, 6).
    converted = torch.matmul(x_scaled, M_list)

    # 3. Normalize so that each row sums to 1.
    converted = converted / converted.sum(dim=1, keepdim=True)

    return converted, outline.view(-1)


# Example usage:
if __name__ == '__main__':
    # Create an example batch with three samples.
    # Each sample is a vector [A, B, C, D, E].
    batch_data = torch.tensor([
        [0.2, 0.3, 0.1, 0.15, 0.25],   # E = 0.25 (>= 0.1, so no outline)
        [0.1, 0.2, 0.1, 0.3, 0.25],     # E = 0.25 (>= 0.1)
        [0.3, 0.2, 0.1, 0.05, 0.25]     # E = 0.25 (>= 0.1)
    ], dtype=torch.float)

    # Example with one sample having E < 0.1:
    # Here, the sum of A+B+C+D might be high so that E = 1 - (A+B+C+D) < 0.1,
    # but since we receive E directly we assume the user has computed it already.
    # For demonstration, let's modify one sample:
    # Set E below threshold so that outline adjustment happens.
    batch_data[0, 4] = 0.05

    metals, outline = convert_batch(batch_data)
    print("Converted metal values:\n", metals)
    print("Outline adjustments:\n", outline)

Converted metal values:
 tensor([[0.0950, 0.0972, 0.3669, 0.4220, 0.0046, 0.0143],
        [0.1659, 0.1575, 0.3276, 0.3373, 0.0055, 0.0063],
        [0.0383, 0.1043, 0.4772, 0.3523, 0.0057, 0.0222]])
Outline adjustments:
 tensor([ 0.0500, -0.1500, -0.1500])


In [None]:
torch.unsqueeze

In [321]:
metals, outline = convert_batch(metal_sampled)

In [324]:
outline

tensor([ 0.0000e+00,  0.0000e+00, -2.0000e-01, -3.2783e-07,  0.0000e+00,
        -2.0000e-01,  0.0000e+00, -2.0000e-01, -7.4506e-09, -3.0000e-01,
        -1.0000e-01, -1.2708e-02, -1.0000e-01,  0.0000e+00, -1.0000e-01,
        -2.0000e-01, -2.0000e-01, -1.0000e-01,  0.0000e+00, -5.2154e-08,
        -4.7520e-01, -1.3976e-01, -1.0000e-01,  1.4901e-08,  0.0000e+00,
        -1.0341e-01, -1.8648e-01, -2.0000e-01, -9.9999e-02, -5.7369e-07,
        -4.0000e-01, -1.0000e-01, -2.7672e-02, -1.0000e-01, -9.9999e-02,
        -5.2154e-08, -5.5134e-07, -3.0000e-01, -1.0000e-01, -1.2676e-01,
        -1.0000e-01, -1.3269e-01, -7.4506e-09, -1.4156e-07, -1.0000e-01,
        -1.0000e-01, -1.0219e-02, -1.0000e-01, -7.4506e-07, -1.6459e-02,
        -9.9999e-02, -3.1903e-01, -3.0000e-01,  0.0000e+00, -4.4703e-08,
        -1.0000e-01, -1.0000e-01, -2.8577e-01, -1.0000e-01, -4.4703e-08,
         0.0000e+00, -3.0000e-01,  0.0000e+00, -6.0891e-02, -2.0000e-01,
        -1.0000e-01, -1.0000e-01, -7.4506e-09, -3.0

In [323]:
sum(outline)/243

tensor(-0.1089)

In [314]:
metals

tensor([[0.1232, 0.1061, 0.5286, 0.2032, 0.0035, 0.0354],
        [0.0355, 0.0561, 0.7253, 0.1196, 0.0024, 0.0611],
        [0.0670, 0.1503, 0.4707, 0.2911, 0.0074, 0.0135],
        ...,
        [0.1656, 0.1854, 0.3609, 0.2755, 0.0064, 0.0062],
        [0.0658, 0.1815, 0.4246, 0.3133, 0.0083, 0.0065],
        [0.1166, 0.1835, 0.3922, 0.2940, 0.0073, 0.0064]])

In [318]:
sum(metals[0])

tensor(1.)

In [317]:
df

Unnamed: 0,Al,Ca,Fe,Mg,Mn,Ni,overpotential
0,0.123204,0.106073,0.528617,0.203194,0.003486,0.035425,1.7122
1,0.036817,0.044191,0.743975,0.107923,0.001843,0.065252,1.7164
2,0.067023,0.150251,0.470664,0.291139,0.007421,0.013501,1.7228
3,0.059822,0.088446,0.430872,0.396597,0.005039,0.019225,1.7250
4,0.060222,0.101457,0.405555,0.414243,0.005944,0.012578,1.7264
...,...,...,...,...,...,...,...
238,0.350899,0.204542,0.258098,0.177185,0.003356,0.005919,2.2541
239,0.212779,0.187240,0.330747,0.257557,0.005539,0.006137,2.2543
240,0.165565,0.185401,0.360911,0.275456,0.006418,0.006249,2.2612
241,0.065797,0.181515,0.424648,0.313278,0.008277,0.006484,2.2632


In [293]:
metals, outline = convert_batch(metal_sampled)

In [294]:
metals

tensor([[0.2627, 0.1901, 0.3025, 0.2336, 0.0046, 0.0064],
        [0.1263, 0.1576, 0.4535, 0.2378, 0.0057, 0.0190],
        [0.1351, 0.1888, 0.3756, 0.2904, 0.0076, 0.0025],
        ...,
        [0.2578, 0.2145, 0.2863, 0.2348, 0.0057, 0.0008],
        [0.1152, 0.2064, 0.3744, 0.2946, 0.0083, 0.0010],
        [0.1943, 0.2109, 0.3256, 0.2615, 0.0068, 0.0009]])

In [161]:
df

Unnamed: 0,Al,Ca,Fe,Mg,Mn,Ni,overpotential
0,0.123204,0.106073,0.528617,0.203194,0.003486,0.035425,1.7122
1,0.036817,0.044191,0.743975,0.107923,0.001843,0.065252,1.7164
2,0.067023,0.150251,0.470664,0.291139,0.007421,0.013501,1.7228
3,0.059822,0.088446,0.430872,0.396597,0.005039,0.019225,1.7250
4,0.060222,0.101457,0.405555,0.414243,0.005944,0.012578,1.7264
...,...,...,...,...,...,...,...
238,0.350899,0.204542,0.258098,0.177185,0.003356,0.005919,2.2541
239,0.212779,0.187240,0.330747,0.257557,0.005539,0.006137,2.2543
240,0.165565,0.185401,0.360911,0.275456,0.006418,0.006249,2.2612
241,0.065797,0.181515,0.424648,0.313278,0.008277,0.006484,2.2632


In [127]:
outline

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 

In [None]:
class MVP_Masked(ContinuousCube):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    """    def __getattr__(self, name):
            if name == "fit_kde":
                raise AttributeError(
                    f"'{self.__class__.__name__}' object has no attribute 'fit_kde'")
            return super().__getattr__(name)"""

    def states2policy(
        self, states: Union[List, TensorType["batch", "state_dim"]]
    ) -> TensorType["batch", "state_dim"]:
        """
        Prepares a batch of states in "environment format" for the policy model: clips
        the states into [0, 1] and maps them to [-1.0, 1.0]

        Args
        ----
        states : list or tensor
            A batch of states in environment format, either as a list of states or as a
            single tensor.

        Returns
        -------
        A tensor containing all the states in the batch.
        """
        states = tfloat(states, device=self.device, float_type=self.float)
        return 2.0 * torch.clip(states, min=0.0, max=1.0)

    def states2proxy(
        self, states: Union[List, TensorType["batch", "state_dim"]]
    ) -> TensorType["batch", "state_dim"]:
        """
        Prepares a batch of states in "environment format" for a proxy: clips the
        states into [0, 1] and maps them to [CELL_MIN, CELL_MAX]

        Args
        ----
        states : list or tensor
            A batch of states in environment format, either as a list of states or as a
            single tensor.

        Returns
        -------
        A tensor containing all the states in the batch.
        """
        # Compute the sum along the state dimension for each batch element
        states = tfloat(states, device=self.device, float_type=self.float)
        states = torch.clip(states, min=0.0, max=1.0)
        states = 0.5 * states + 0.1
        state_sum = torch.sum(
            states, dim=1, keepdim=True)  # Shape: [batch, 1]
        outline = 1-state_sum
        state_e = 1-state_sum
        outline = torch.clip(outline, min=-0.6*self.n_dim, max=0.0)
        state_e = torch.clip(state_e, min=0.1, max=0.6)

        # Concatenate the states with their sums along the last dimension
        # Shape: [batch, state_dim + 1]
        states = torch.cat([states, state_e], dim=1)
        state_norm = states / torch.sum(
            states, dim=1, keepdim=True)

        metal_prop = self._compute_metal_proportions(state_norm)

        states = torch.cat(
            [metal_prop, outline], dim=1)

        return states

    def _compute_metal_proportions(self, input_vector_ABCDE, return_original=False):
        # Combine inputs into a tensor

        # Normalize the input vector with abcdelist
        normalized_vector = input_vector_ABCDE / abcdelist

        # Transform using M_list to get metal proportions
        metal_proportions = (normalized_vector @ M_list)

        # Normalize to ensure proportions sum to 1
        metal_proportions /= metal_proportions.sum()

        if return_original == False:

            return metal_proportions

    # Reverse function: Compute A, B, C, D, E from metal proportions

    def _compute_abcde(self, metal_proportions: torch.Tensor) -> torch.Tensor:
        """
        Computes A, B, C, D, E from metal proportions for a batch of inputs.

        Args
        ----
        metal_proportions : torch.Tensor
            A tensor of shape [batch, 6] containing the metal proportions.

        Returns
        -------
        torch.Tensor
            A tensor of shape [batch, 5] containing the computed A, B, C, D, E values.
        """
        # Ensure input is a tensor and properly shaped
        metal_proportions = metal_proportions.clone().detach()  # [batch, 6]

        # Solve the linear system for each batch element
        # torch.linalg.solve only supports matrices, so we need batched matrix operations
        M_list_T = M_list.T.unsqueeze(0)  # [1, 6, 5] to support batch
        metal_proportions = metal_proportions.unsqueeze(-1)  # [batch, 6, 1]

        # Batched lstsq to solve for normalized_vector
        # Solve M_list_T @ normalized_vector = metal_proportions for normalized_vector
        lstsq_result = torch.linalg.lstsq(M_list_T.expand(metal_proportions.size(0), -1, -1),
                                          metal_proportions)

        # Extract solutions and remove unnecessary dimensions
        normalized_vector = lstsq_result.solution.squeeze(-1)  # [batch, 5]

        # Denormalize to get A, B, C, D, E
        abcde = normalized_vector * abcdelist  # [batch, 5]

        return abcde

    def state2readable(self, state: List) -> str:
        """
        Converts a state (a list of positions) into a human-readable string
        representing a state.
        """

        return str(state).replace("(", "[").replace(")", "]").replace(",", ""), str(self.states2proxy([state]))

    def get_mask_invalid_actions_forward(
        self,
        state: Optional[List] = None,
        done: Optional[bool] = None,
    ) -> List:
        state = self._get_state(state)
        done = self._get_done(done)
        # If done, then all actions are “invalid”
        if done:
            return [True] * self.mask_dim

        mask = [False] * self.mask_dim_base + self.ignored_dims

        # If the state is the source state, EOS is invalid.
        if self._get_effective_dims(state) == self._get_effective_dims(self.source):
            mask[2] = True
        else:
            # Here is our additional rule:
            # Allow EOS only if the sum of the state coordinates is (close to) 1.
            # We use a tolerance since floating point arithmetic is not exact.
            tol = 1e-6
            s = sum(self._get_effective_dims(state))
            if abs(s - 1.0) > tol:
                mask[2] = True  # mark EOS as invalid if sum is not (almost) 1
            # Otherwise, EOS remains valid.

        # If any dimension is above 1 - min_incr, then continuous actions are invalid.
        if any([s > 1 - self.min_incr for s in self._get_effective_dims(state)]):
            mask[0] = True

        return mask

    def get_mask_invalid_actions_backward(self, state=None, done=None, parents_a=None):
        state = self._get_state(state)
        done = self._get_done(done)
        mask = [True] * self.mask_dim_base + self.ignored_dims

        # If the state is the source state, then the entire mask is True.
        if self._get_effective_dims(state) == self._get_effective_dims(self.source):
            return mask

        # If done, only valid action is EOS.
        if done:
            # Additionally, check the EOS condition here if needed
            tol = 1e-6
            s = sum(self._get_effective_dims(state))
            if abs(s - 1.0) <= tol:
                mask[2] = False  # EOS valid if sum is nearly 1.
            else:
                mask[2] = True
            return mask

        if any([s < self.min_incr for s in self._get_effective_dims(state)]):
            mask[1] = False
            return mask

        mask[0] = False
        return mask

In [126]:
df

Unnamed: 0,Al,Ca,Fe,Mg,Mn,Ni,overpotential
0,0.123204,0.106073,0.528617,0.203194,0.003486,0.035425,1.7122
1,0.036817,0.044191,0.743975,0.107923,0.001843,0.065252,1.7164
2,0.067023,0.150251,0.470664,0.291139,0.007421,0.013501,1.7228
3,0.059822,0.088446,0.430872,0.396597,0.005039,0.019225,1.7250
4,0.060222,0.101457,0.405555,0.414243,0.005944,0.012578,1.7264
...,...,...,...,...,...,...,...
238,0.350899,0.204542,0.258098,0.177185,0.003356,0.005919,2.2541
239,0.212779,0.187240,0.330747,0.257557,0.005539,0.006137,2.2543
240,0.165565,0.185401,0.360911,0.275456,0.006418,0.006249,2.2612
241,0.065797,0.181515,0.424648,0.313278,0.008277,0.006484,2.2632


In [122]:
convert_torch(metal_sampled)

TypeError: convert_torch() missing 3 required positional arguments: 'B', 'C', and 'D'

In [None]:
Metal = ['Al', 'Ca', 'Fe', 'Mg', 'Mn', 'Ni']
Al_list = [0.00149, 0.02, 0.0803, 2.759, 0.0779]
Ca_list = [0.0055, 0.1374, 0.7697, 1.417, 1.142]
Fe_list = [3.303, 1.379, 2.4, 0.97, 2.238]
Mg_list = [0.00155, 5.167, 1.306, 0.744, 1.438]
Mn_list = [0, 0.0288, 0.0521, 0.0132, 0.0541]
Ni_list = [0.328, 0, 0, 0, 0]

M_list = pd.DataFrame([Al_list, Ca_list, Fe_list, Mg_list, Mn_list, Ni_list],
                      columns=['A', 'B', 'C', 'D', 'E'],
                      index=Metal).T

In [None]:
def forward_transform(abcde: torch.Tensor) -> torch.Tensor:
    """
    Forward transformation from meteorite ratios [A, B, C, D, E] to metal ratios.
    The transformation follows these steps for each batch element:
      1. Compute E = 1 - (A+B+C+D).
      2. If E < 0.1, replace E with 0.1 and re-normalize [A, B, C, D, 0.1].
      3. Scale the vector by dividing elementwise by `abcdelist`.
      4. Multiply by the transfer matrix (abundance.T) to obtain 6 metal values.
      5. Normalize the metal ratios so that they sum to 1.

    Args:
        abcde: Tensor of shape [batch, 5] representing A, B, C, D, E.

    Returns:
        metal_ratios: Tensor of shape [batch, 6] with normalized metal proportions.
    """
    # Ensure input is a floating-point tensor.
    abcde = abcde.to(torch.get_default_dtype())

    # Split into A, B, C, D, and provided E.
    # However, note that in the forward transform the original code computes E = 1 - (A+B+C+D)
    # (ignoring the passed E) and then optionally replaces it.
    A = abcde[:, 0:1]  # shape: [batch, 1]
    B = abcde[:, 1:2]
    C = abcde[:, 2:3]
    D = abcde[:, 3:4]

    # Compute E from A,B,C,D.
    computed_E = 1 - (A + B + C + D)

    # Create the candidate vector x = [A, B, C, D, E*] where E* may be replaced.
    # We will decide on each batch element whether to clip E or not.
    # Initialize with computed_E.
    E_final = computed_E.clone()

    # We'll gather A,B,C,D together.
    base = torch.cat([A, B, C, D], dim=1)  # shape: [batch, 4]

    # Create a flag where computed_E < 0.1.
    clip_mask = computed_E < 0.1  # shape: [batch, 1], boolean

    # When clipping, the vector becomes [A, B, C, D, 0.1].
    # When not clipping, it is [A, B, C, D, computed_E].
    # We'll form the complete x.
    # First, set the candidate E value: where clip_mask is True, use 0.1.
    E_candidate = torch.where(clip_mask, torch.tensor(
        0.1, dtype=abcde.dtype, device=abcde.device), computed_E)

    # Stack to form x.
    x = torch.cat([base, E_candidate], dim=1)  # shape: [batch, 5]

    # If clipping occurred, re-normalize the vector so that its sum is 1.
    # For each element, if clip_mask is True then divide that row by its sum.
    row_sum = x.sum(dim=1, keepdim=True)
    # For rows that were clipped, we re-normalize.
    # (For not-clipped rows the sum should already be 1; but normalizing again is harmless.)
    x = x / row_sum

    # Now, scale x by dividing elementwise by abcdelist.
    # Make abcdelist a 1x5 vector to broadcast correctly.
    abcdelist_ = abcdelist.unsqueeze(0)  # shape: [1, 5]
    x_scaled = x / abcdelist_

    # Multiply by the transfer matrix.
    # transfer_matrix is shape [6,5], so we compute:
    #   metal_raw = x_scaled @ transfer_matrix^T   (or equivalently, (x_scaled / abcdelist) @ transfer_matrix^T)
    # Because the original forward used:
    #   x = (x/abcdelist) @ M_list.values
    # where M_list.values is transfer_matrix.
    metal_raw = torch.matmul(x_scaled, transfer_matrix.T)  # shape: [batch, 6]

    # Finally, normalize the metal ratios so that each row sums to 1.
    metal_ratios = metal_raw / metal_raw.sum(dim=1, keepdim=True)

    return metal_ratios

In [110]:
metal_sampled

tensor([[1.3552, 0.5670, 0.5632, 1.8700, 0.6885],
        [2.1684, 0.2835, 0.2816, 0.4675, 0.3442],
        [0.5421, 0.5670, 1.6896, 0.9350, 2.0655],
        ...,
        [0.2711, 0.5670, 0.5632, 2.8050, 2.7540],
        [0.2710, 0.5670, 0.5632, 0.9350, 4.1310],
        [0.2711, 0.5670, 0.5632, 1.8700, 3.4425]])

In [None]:
metal_sampled_norm = metal_sampled / \
    torch.sum(metal_sampled, dim=1, keepdim=True)

In [None]:
_compute_metal_proportions

In [74]:
sum(metal_sampled[100])

tensor(1.)

In [105]:
metal_sampled

tensor([[1.3552, 0.5670, 0.5632, 1.8700, 0.6885],
        [2.1684, 0.2835, 0.2816, 0.4675, 0.3442],
        [0.5421, 0.5670, 1.6896, 0.9350, 2.0655],
        ...,
        [0.2711, 0.5670, 0.5632, 2.8050, 2.7540],
        [0.2710, 0.5670, 0.5632, 0.9350, 4.1310],
        [0.2711, 0.5670, 0.5632, 1.8700, 3.4425]])

In [95]:
metal_sampled_norm

tensor([[0.2687, 0.1124, 0.1117, 0.3707, 0.1365],
        [0.6116, 0.0800, 0.0794, 0.1319, 0.0971],
        [0.0935, 0.0978, 0.2914, 0.1612, 0.3562],
        ...,
        [0.0389, 0.0815, 0.0809, 0.4030, 0.3957],
        [0.0419, 0.0877, 0.0871, 0.1446, 0.6388],
        [0.0404, 0.0845, 0.0839, 0.2785, 0.5127]])

In [96]:
metal_sampled_norm

tensor([[0.2687, 0.1124, 0.1117, 0.3707, 0.1365],
        [0.6116, 0.0800, 0.0794, 0.1319, 0.0971],
        [0.0935, 0.0978, 0.2914, 0.1612, 0.3562],
        ...,
        [0.0389, 0.0815, 0.0809, 0.4030, 0.3957],
        [0.0419, 0.0877, 0.0871, 0.1446, 0.6388],
        [0.0404, 0.0845, 0.0839, 0.2785, 0.5127]])

In [97]:
ww = _compute_metal_proportions(metal_sampled_norm)

In [76]:
df

Unnamed: 0,Al,Ca,Fe,Mg,Mn,Ni,overpotential
0,0.123204,0.106073,0.528617,0.203194,0.003486,0.035425,1.7122
1,0.036817,0.044191,0.743975,0.107923,0.001843,0.065252,1.7164
2,0.067023,0.150251,0.470664,0.291139,0.007421,0.013501,1.7228
3,0.059822,0.088446,0.430872,0.396597,0.005039,0.019225,1.7250
4,0.060222,0.101457,0.405555,0.414243,0.005944,0.012578,1.7264
...,...,...,...,...,...,...,...
238,0.350899,0.204542,0.258098,0.177185,0.003356,0.005919,2.2541
239,0.212779,0.187240,0.330747,0.257557,0.005539,0.006137,2.2543
240,0.165565,0.185401,0.360911,0.275456,0.006418,0.006249,2.2612
241,0.065797,0.181515,0.424648,0.313278,0.008277,0.006484,2.2632


In [98]:
ww

tensor([[5.3805e-04, 4.6324e-04, 2.3085e-03, 8.8738e-04, 1.5226e-05, 1.5470e-04],
        [1.9870e-04, 2.3849e-04, 4.0153e-03, 5.8246e-04, 9.9447e-06, 3.5218e-04],
        [2.6720e-04, 5.9899e-04, 1.8764e-03, 1.1607e-03, 2.9586e-05, 5.3823e-05],
        ...,
        [5.9408e-04, 6.6526e-04, 1.2951e-03, 9.8840e-04, 2.3031e-05, 2.2427e-05],
        [2.4487e-04, 6.7552e-04, 1.5803e-03, 1.1659e-03, 3.0805e-05, 2.4128e-05],
        [4.2588e-04, 6.7020e-04, 1.4325e-03, 1.0739e-03, 2.6775e-05, 2.3250e-05]])

In [100]:
sum(ww[0])

tensor(0.0044)

In [None]:
E = 1 - (A + B + C + D)
outline = 0
if E < 0.1:
    x = np.array([A, B, C, D, 0.1])
    x /= x.sum()
    outline = 0.1 - E
else:
    x = np.array([A, B, C, D, E])
x = (x / abcdelist) @ M_list.values
x = x / x.sum()

In [77]:
sum(ww[0])

tensor(0.0047)

In [None]:
Metal = ['Al', 'Ca', 'Fe', 'Mg', 'Mn', 'Ni']
Al_list = [0.00149, 0.02, 0.0803, 2.759, 0.0779]
Ca_list = [0.0055, 0.1374, 0.7697, 1.417, 1.142]
Fe_list = [3.303, 1.379, 2.4, 0.97, 2.238]
Mg_list = [0.00155, 5.167, 1.306, 0.744, 1.438]
Mn_list = [0, 0.0288, 0.0521, 0.0132, 0.0541]
Ni_list = [0.328, 0, 0, 0, 0]

M_list = pd.DataFrame([Al_list, Ca_list, Fe_list, Mg_list, Mn_list, Ni_list],
                      columns=['A', 'B', 'C', 'D', 'E'],
                      index=Metal).T

In [113]:
# Define the abcdelist (scaling factors).
abcdelist = torch.tensor(
    [0.27105, 0.567, 0.5632, 0.935, 0.6885], dtype=torch.get_default_dtype())

# Define the abundance matrix as provided in the original code.
# (Shape is 5x6; rows correspond to different meteorite components, columns to metals.)
abundance = torch.tensor([
    [3.303,    0.0,     0.328,   0.0055,  0.00155, 0.00149],
    [1.379,  0.0288,       0.0,  0.1374,    5.167,   0.02],
    [2.4,     0.0521,      0.0,  0.7697,    1.306,   0.0803],
    [0.97,    0.0132,      0.0,  1.417,     0.744,   2.759],
    [2.238,   0.0541,      0.0,  1.142,     1.438,   0.0779]
], dtype=torch.get_default_dtype())

# The transfer matrix is the transpose of abundance.
transfer_matrix = abundance.T  # shape: (6,5)

In [104]:
def forward_transform(abcde: torch.Tensor) -> torch.Tensor:
    """
    Forward transformation from meteorite ratios [A, B, C, D, E] to metal ratios.
    The transformation follows these steps for each batch element:
      1. Compute E = 1 - (A+B+C+D).
      2. If E < 0.1, replace E with 0.1 and re-normalize [A, B, C, D, 0.1].
      3. Scale the vector by dividing elementwise by `abcdelist`.
      4. Multiply by the transfer matrix (abundance.T) to obtain 6 metal values.
      5. Normalize the metal ratios so that they sum to 1.

    Args:
        abcde: Tensor of shape [batch, 5] representing A, B, C, D, E.

    Returns:
        metal_ratios: Tensor of shape [batch, 6] with normalized metal proportions.
    """
    # Ensure input is a floating-point tensor.
    abcde = abcde.to(torch.get_default_dtype())

    # Split into A, B, C, D, and provided E.
    # However, note that in the forward transform the original code computes E = 1 - (A+B+C+D)
    # (ignoring the passed E) and then optionally replaces it.
    A = abcde[:, 0:1]  # shape: [batch, 1]
    B = abcde[:, 1:2]
    C = abcde[:, 2:3]
    D = abcde[:, 3:4]

    # Compute E from A,B,C,D.
    computed_E = 1 - (A + B + C + D)

    # Create the candidate vector x = [A, B, C, D, E*] where E* may be replaced.
    # We will decide on each batch element whether to clip E or not.
    # Initialize with computed_E.
    E_final = computed_E.clone()

    # We'll gather A,B,C,D together.
    base = torch.cat([A, B, C, D], dim=1)  # shape: [batch, 4]

    # Create a flag where computed_E < 0.1.
    clip_mask = computed_E < 0.1  # shape: [batch, 1], boolean

    # When clipping, the vector becomes [A, B, C, D, 0.1].
    # When not clipping, it is [A, B, C, D, computed_E].
    # We'll form the complete x.
    # First, set the candidate E value: where clip_mask is True, use 0.1.
    E_candidate = torch.where(clip_mask, torch.tensor(
        0.1, dtype=abcde.dtype, device=abcde.device), computed_E)

    # Stack to form x.
    x = torch.cat([base, E_candidate], dim=1)  # shape: [batch, 5]

    # If clipping occurred, re-normalize the vector so that its sum is 1.
    # For each element, if clip_mask is True then divide that row by its sum.
    row_sum = x.sum(dim=1, keepdim=True)
    # For rows that were clipped, we re-normalize.
    # (For not-clipped rows the sum should already be 1; but normalizing again is harmless.)
    x = x / row_sum

    # Now, scale x by dividing elementwise by abcdelist.
    # Make abcdelist a 1x5 vector to broadcast correctly.
    abcdelist_ = abcdelist.unsqueeze(0)  # shape: [1, 5]
    x_scaled = x / abcdelist_

    # Multiply by the transfer matrix.
    # transfer_matrix is shape [6,5], so we compute:
    #   metal_raw = x_scaled @ transfer_matrix^T   (or equivalently, (x_scaled / abcdelist) @ transfer_matrix^T)
    # Because the original forward used:
    #   x = (x/abcdelist) @ M_list.values
    # where M_list.values is transfer_matrix.
    metal_raw = torch.matmul(x_scaled, transfer_matrix.T)  # shape: [batch, 6]

    # Finally, normalize the metal ratios so that each row sums to 1.
    metal_ratios = metal_raw / metal_raw.sum(dim=1, keepdim=True)

    return metal_ratios

In [111]:
forward_transform(metal_sampled)

tensor([[0.5363, 0.0027, 0.0390, 0.0935, 0.1944, 0.1340],
        [0.7573, 0.0014, 0.0682, 0.0357, 0.0996, 0.0378],
        [0.4783, 0.0060, 0.0190, 0.1173, 0.2914, 0.0881],
        ...,
        [0.3088, 0.0038, 0.0098, 0.1595, 0.2669, 0.2511],
        [0.3878, 0.0047, 0.0152, 0.1155, 0.3438, 0.1330],
        [0.3398, 0.0042, 0.0119, 0.1422, 0.2971, 0.2047]])

In [112]:
df

Unnamed: 0,Al,Ca,Fe,Mg,Mn,Ni,overpotential
0,0.123204,0.106073,0.528617,0.203194,0.003486,0.035425,1.7122
1,0.036817,0.044191,0.743975,0.107923,0.001843,0.065252,1.7164
2,0.067023,0.150251,0.470664,0.291139,0.007421,0.013501,1.7228
3,0.059822,0.088446,0.430872,0.396597,0.005039,0.019225,1.7250
4,0.060222,0.101457,0.405555,0.414243,0.005944,0.012578,1.7264
...,...,...,...,...,...,...,...
238,0.350899,0.204542,0.258098,0.177185,0.003356,0.005919,2.2541
239,0.212779,0.187240,0.330747,0.257557,0.005539,0.006137,2.2543
240,0.165565,0.185401,0.360911,0.275456,0.006418,0.006249,2.2612
241,0.065797,0.181515,0.424648,0.313278,0.008277,0.006484,2.2632
