In [55]:
import gym
from gym import spaces
import numpy as np
from scipy.special import hyp2f1
import cmath
from tqdm import tqdm
import os

class StandardDeviationEnv(gym.Env):
    """
    Custom Environment for solving the crossing equation involving hypergeometric functions,
    supporting up to 4 states with individual d_i bounds and dynamic reward based on the number
    of states included in the reward.
    """
    metadata = {'render.modes': ['human']}

    def __init__(
        self,
        n_states: int = 4,  # Number of states (up to 4)
        n_states_rew: int = 4,  # Number of states to include in the reward
        min_d: list = None,  # Minimum values for each d_i
        max_d: list = None,  # Maximum values for each d_i
        d_lr: float = 0.1,  # Learning rate for d updates
        max_episode_steps: int = 50,
        num_samples: int = 50,  # Number of samples (each sample includes n_states z's)
        impr_coeff: float = 10.0,  # Coefficient for improvement term
        spins:list=[0.0, 2.0, 4.0, 0.0],
        delta_sigma=-0.4,
        centers:list=[.05 + .45j,.15 + .3j,.5,.5 + .4j],
        gauss_stds:list=[.1,.1,.1,.1]
    ):
        super(StandardDeviationEnv, self).__init__()

        assert 1 <= n_states <= 4, "n_states must be between 1 and 4"
        assert 1 <= n_states_rew <= n_states, "n_states_rew must be between 1 and n_states"

        self.n_states = n_states
        self.n_states_rew = n_states_rew
        self.d_lr = d_lr  # Learning rate for d1 to dn updates
        self.max_episode_steps = max_episode_steps
        self.num_samples = num_samples  # Number of samples, each with n_states z's
        self.impr_coeff = impr_coeff  # Improvement coefficient
        #TODO change to general n states
        self.centers=centers
        self.gauss_stds=gauss_stds

        
        # Initialize min_d and max_d as arrays
        if min_d is None:
            self.min_d = [1.0] * self.n_states
        else:
            assert len(min_d) == self.n_states, "min_d must be a list of length n_states"
            self.min_d = min_d

        if max_d is None:
            self.max_d = [4.0] * self.n_states
        else:
            assert len(max_d) == self.n_states, "max_d must be a list of length n_states"
            self.max_d = max_d

        # Initialize variables
        self.reward = 0.0
        self.prev_reward = None  # Initialize previous reward
        self.std_over_mean_sum = None

        # Initialize d1 to dn
        self.d = np.random.uniform(self.min_d, self.max_d, size=self.n_states).tolist()

        # Action and observation spaces
        # Actions: d1_step, d2_step, ..., dn_step
        action_low = np.array([-1.0] * self.n_states, dtype=np.float32)
        action_high = np.array([1.0] * self.n_states, dtype=np.float32)

        # Observations: d1, d2, ..., dn, mean_C1, mean_C2, ..., mean_C_rew, std_over_mean_sum
        obs_low = np.array(
            self.min_d + [-np.inf] * self.n_states_rew + [0.0],
            dtype=np.float32
        )
        obs_high = np.array(
            self.max_d + [np.inf] * self.n_states_rew + [np.inf],
            dtype=np.float32
)

        self.action_space = spaces.Box(low=action_low, high=action_high, dtype=np.float32)
        self.observation_space = spaces.Box(low=obs_low, high=obs_high, dtype=np.float32)

        # Count steps
        self.num_steps = 0

        # Fixed parameters (set these to your actual values)
        self.Delta_sigma = delta_sigma   # Δσ
        self.s = spins  # Spins for operators 1 to 4
       
        # Initialize state variables
        self.reset()

    # Sampling functions
    def zcomplex(self, r, theta):
        return 0.5 + r * cmath.exp(1j * (cmath.pi - theta))
    
    def random_point_in_z_plane(self,center, width):
        """
        Generate a random point in the complex z-plane with Gaussian distribution.

        Parameters:
        - center (complex): The center of the distribution in the complex plane.
        - width (float): The standard deviation (width) of the Gaussian distribution.

        Returns:
        - complex: A randomly generated point in the z-plane.
        """
        real_part = np.random.normal(loc=center.real, scale=width)
        imag_part = np.random.normal(loc=center.imag, scale=width)
        return complex(real_part, imag_part)

    def sample_z(self, state_idx):
        """
        Sample a single z complex number based on the specified distribution.
        Allows for different sampling strategies per state if desired.

        Args:
            state_idx (int): Index of the current state (0-based).

        Returns:
            complex: Sampled complex number z.
        """

 
        
        return self.random_point_in_z_plane(self.centers[state_idx % 4], self.gauss_stds[state_idx % 4])

    def kronecker_delta(self, h, hb, tol=1e-8):
        return 1.0 if abs(h - hb) < tol else 0.0

    def g(self, h, hb, z):
        delta = self.kronecker_delta(h, hb)
        denominator = delta + 1.0

        # Compute hypergeometric functions using scipy
        z_conj = np.conj(z)

        try:
            hg_h_z = hyp2f1(h, h, 2 * h, z)
            hg_hb_z_conj = hyp2f1(hb, hb, 2 * hb, z_conj)
            hg_h_z_conj = hyp2f1(h, h, 2 * h, z_conj)
            hg_hb_z = hyp2f1(hb, hb, 2 * hb, z)
        except Exception as e:
            print(f"Error computing hypergeometric functions for z={z}: {e}")
            return 0.0

        # Compute terms
        term1 = (z ** h) * (z_conj ** hb) * hg_h_z * hg_hb_z_conj
        term2 = (z_conj ** h) * (z ** hb) * hg_h_z_conj * hg_hb_z

        # Combine terms
        result = (term1 + term2) / denominator
        return result.real  # Assuming the result should be real

    def compute_equation(self, z_list, d_values):
        """
        Compute the crossing equations for a given list of z's and d's.

        Args:
            z_list (list of complex): List of z complex numbers.
            d_values (list of float): List of d_i values for each state.

        Returns:
            tuple: (A_matrix, D_vector) where A_matrix is [num_equations x n_states] and D_vector is [num_equations]
        """
        num_equations = len(z_list)  # Typically, num_equations = n_states
        n = self.n_states
        A = np.zeros((num_equations, n), dtype=np.float64)
        D = np.zeros(num_equations, dtype=np.float64)

        for eq_idx, z in enumerate(z_list):
            # Define h[i] and hb[i] for all states
            for i in range(n):
                Delta_i = d_values[i]
                s_i = self.s[i]
                h_i = (Delta_i + s_i) / 2
                hb_i = (Delta_i - s_i) / 2

                # Compute g functions for z and 1 - z
                g_z = self.g(h_i, hb_i, z)
                g_1_minus_z = self.g(h_i, hb_i, 1 - z)

                # Compute Abs[z - 1]^(2 Delta_sigma) and Abs[z]^(2 Delta_sigma)
                abs_z_minus_1 = abs(z - 1) ** (2 * self.Delta_sigma)
                abs_z = abs(z) ** (2 * self.Delta_sigma)

                # Compute coefficient for C[i] in this equation
                coefficient = abs_z_minus_1 * g_z - abs_z * g_1_minus_z

                A[eq_idx, i] = coefficient

            # Compute D for this equation
            D[eq_idx] = - (abs(z - 1) ** (2 * self.Delta_sigma) - abs(z) ** (2 * self.Delta_sigma))

        return A, D


    def compute_reward(self, d_values):
        """
        Calculate the reward based on the current d_values.

        Args:
            d_values (list of float): Current d_i values.

        Returns:
            float: Calculated reward.
        """
        cs_list = []  # To collect C estimates from each sample

        for _ in range(self.num_samples):
            # Sample z's for each equation
            z_samples = [self.sample_z(state_idx) for state_idx in range(self.n_states)]

            # Compute A and D for the current sample
            A, D = self.compute_equation(z_samples, d_values)

            # Solve the linear system A * C = D
            try:
                C = np.linalg.solve(A, D)
                cs_list.append(C)
            except np.linalg.LinAlgError as e:
                # Singular matrix or other numerical issues; skip this sample
                continue

        if len(cs_list) == 0:
            # No valid solutions found; assign a low reward
            self.cs = np.zeros(self.n_states_rew)
            self.std_over_mean_sum = 1e6
            return -1e6
        else:
            cs_array = np.array(cs_list)
            C_means = np.mean(cs_array, axis=0)
            C_stds = np.std(cs_array, axis=0, ddof=1)

            # Select the first n_states_rew states
            selected_means = C_means[:self.n_states_rew]
            selected_stds = C_stds[:self.n_states_rew]

            # Compute the standard deviation over mean sum
            epsilon = 1e-8
            selected_means_safe = np.where(np.abs(selected_means) < epsilon, epsilon, np.abs(selected_means))
            std_over_mean = selected_stds / selected_means_safe
            self.std_over_mean_sum = np.sum(std_over_mean)

            # Compute the reward
            # Reward = - sum(log(std_over_mean)) for the selected states
            # Ensure std_over_mean > 0 to avoid log(0)
            std_over_mean = np.clip(std_over_mean, 1e-8, None)
            reward = -np.sum(np.log(std_over_mean))

            self.cs = C_means[:self.n_states_rew]

            return reward


    def _get_obs(self):
        """
        Construct the observation array.

        Returns:
            np.array: Observation containing d1 to dn, mean_C1 to Cn, std_over_mean_sum
        """
        d_array = np.array(self.d, dtype=np.float32)
        C_mean_array = self.cs if hasattr(self, 'cs') else np.zeros(self.n_states_rew, dtype=np.float32)
        observation = np.concatenate([d_array, C_mean_array, [self.std_over_mean_sum]], axis=0).astype(np.float32)
        return observation

    def reset(self):
        self.num_steps = 0
        self.prev_reward = None  # Reset previous reward

        # Initialize d1 to dn randomly within bounds
        self.d = np.random.uniform(self.min_d, self.max_d, size=self.n_states).tolist()

        # Compute initial reward and observation
        self.reward = self.compute_reward(self.d)

        # Set prev_reward to the initial immediate reward
        self.prev_reward = self.reward

        observation = self._get_obs()
        return observation

    def step(self, action):
        self.num_steps += 1

        # Clip the action to ensure it's within the action space
        action = np.clip(action, self.action_space.low, self.action_space.high)
        d_new = self.d + self.d_lr * action  # Update rule

        # Check boundary conditions
        is_ok = self.d_bc(d_new)

        done = False

        if not is_ok:
            # Penalize the agent for invalid values
            self.reward = -10.0
            done = True
            observation = self._get_obs()
            return observation, self.reward, done, {}

        # Update variables
        self.d = d_new.tolist()

        # Compute the immediate reward
        current_reward = self.compute_reward(self.d)

        # Calculate the change in immediate reward
        if self.prev_reward is not None:
            delta_reward = current_reward - self.prev_reward
        else:
            delta_reward = 0.0

        # Compute the total reward
        self.reward = current_reward + self.impr_coeff * delta_reward

        # Update the previous reward
        self.prev_reward = current_reward

        # Check if episode should be truncated due to max steps
        if self.num_steps >= self.max_episode_steps:
            done = True

        # Check for NaNs in observation
        observation = self._get_obs()
        if np.isnan(observation).any():
            self.reward = -10.0
            done = True
            return observation, self.reward, done, {}

        return observation, self.reward, done, {}

    def d_bc(self, d_values):
        """
        Check if all d_i values are within the specified bounds.

        Args:
            d_values (list of float): Current d_i values.

        Returns:
            bool: True if all d_i are within bounds, False otherwise.
        """
        return all(self.min_d[i] <= d_values[i] <= self.max_d[i] for i in range(self.n_states))

    def render(self, mode='human'):
        """
        Render the environment.
        """
        print(f"Step: {self.num_steps}")
        for i in range(self.n_states):
            print(f"d{i+1}: {self.d[i]:.4f}")
        for i in range(self.n_states_rew):
            print(f"C{i+1}_mean: {self.cs[i]:.4f}")
        print(f"Std over Mean Sum: {self.std_over_mean_sum:.4f}")
        print(f"Reward: {self.reward:.4f}")
        print("-" * 30)

    def close(self):
        pass


In [98]:
import gym
from gym import spaces
import numpy as np
from scipy.special import hyp2f1
import cmath
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
import itertools
from concurrent.futures import ProcessPoolExecutor

def compute_reward_for_delta(std_env, deltas):
    return std_env.compute_reward(deltas)

class ZSamplingEnv(gym.Env):
    """
    Custom Gym environment for optimizing z-sampling in the standard deviation method.
    The agent selects shifts for 80 z-points to minimize the standard deviation of C's
    at target deltas {1, 2, 4, 6}.
    """
    metadata = {'render.modes': ['human']}

    def __init__(
        self,
        n_states_rew: int = 4,  # Number of states to include in the reward
        num_z_per_state: int = 20,  # Number of z's per state
        target_deltas: list = None,  # Fixed target deltas for sharpness
        delta_shift: float = 0.1,  # Shift applied to target deltas for baseline
        Delta_sigma: float = 1/8,  # Fixed Delta_sigma
        max_episode_steps: int = 100,  # Maximum steps per episode
        centers_init=[.25 + .45j,.15 + .3j,.45,.5 + .3j],
        gauss_stds_init=[.2,.2,.2,.2]
    ):
        super(ZSamplingEnv, self).__init__()

        self.std_env=env = StandardDeviationEnv(
            n_states=4,
            n_states_rew=n_states_rew,
            min_d=[-1.0, 1.5, 3.1, 5.5],  # Adjusted to include d1=-0.4 and d4=7.6
            max_d=[1.0, 2.5, 5.0, 7.0],    # Ensure that d1=-0.4 and d4=7.6 are within min_d and max_d
            spins=[0, 2, 0, 4],
            delta_sigma=1/8,
            centers=centers_init,
            gauss_stds=gauss_stds_init
        )
        self.centers_init=centers_init
        self.gauss_stds_init=gauss_stds_init
        self.state=np.concatenate([centers_init,gauss_stds_init])
        
        self.n_states_rew = n_states_rew
        self.num_z_per_state = num_z_per_state
       
        self.target_deltas = np.array(target_deltas if target_deltas is not None else [1.0, 2.0, 4.0, 6.0])
       
        self.delta_shift = delta_shift
        self.Delta_sigma = Delta_sigma
        self.max_steps = max_episode_steps
        self.current_step = 0
        self.shifted_deltas=self.generate_shifted_lists(self.target_deltas,self.delta_shift)
        # Action space: Shifts for each z-point (real and imaginary parts)
        # Total actions: 2 * total_z
        self.action_space = spaces.Box(
            low=-0.1, high=0.1,  # Assuming shifts are small
            shape=(2 * 4,),
            dtype=np.float32
        )

        # Observation space: Current z-points (real and imaginary parts)
        # Shape: (2 * total_z,)
        self.observation_space = spaces.Box(
            low=0.0, high=0.5,  # Since z_real and z_imag are in (0, 0.5)
            shape=(2 * 8,),
            dtype=np.float32
        )


    def reset(self):
        """
        Reset the environment to an initial state.
        """
        self.current_step = 0

        self.std_env=env = StandardDeviationEnv(
            n_states=4,
            n_states_rew=self.n_states_rew,
            min_d=[-1.0, 1.5, 3.1, 5.5],  # Adjusted to include d1=-0.4 and d4=7.6
            max_d=[1.0, 2.5, 5.0, 7.0],    # Ensure that d1=-0.4 and d4=7.6 are within min_d and max_d
            spins=[0, 2, 0, 4],
            delta_sigma=1/8,
            centers=self.centers_init,
            gauss_stds=self.gauss_stds_init
        )
        self.centers=self.centers_init,
        self.gauss_stds=self.gauss_stds_init

        self.state=np.concatenate([self.centers_init,self.gauss_stds_init])
    
        return self.state

    def clip_obs(self,array):
        """
        Clips the elements of an 8-component array based on the given rules:
        - First 4 elements (complex): Clip real and imaginary parts to [0, 1/2].
        - Last 4 elements (real): Clip to [0, 0.2].

        Parameters:
        - array (ndarray): An array with 8 components (first 4 complex, last 4 real).

        Returns:
        - ndarray: The clipped array.
        """
        if len(array) != 8:
            raise ValueError("Input array must have exactly 8 components.")
        
        # Clip first 4 elements (complex)
        clipped_first_4 = [complex(np.clip(c.real, 0, 0.5), np.clip(c.imag, 0, 0.5)) for c in array[:4]]
        
        # Clip last 4 elements (real)
        clipped_last_4 = np.clip(array[4:], 0, 0.2)
        
        # Combine the results
        return np.array(clipped_first_4 + list(clipped_last_4), dtype=array.dtype)

    def generate_shifted_lists(self,original_list,shift):
        """
        Generate all lists by shifting each element of the original list by -1, 0, or 1,
        excluding the original list (0,0,0,0 shift).
        
        Parameters:
        - original_list: List of 4 elements to be shifted.

        Returns:
        - List of shifted lists.
        """
        if len(original_list) != 4:
            raise ValueError("The original list must have exactly 4 elements.")
        
        # Generate all combinations of shifts (-1, 0, 1) for 4 elements
        shifts = list(itertools.product([-shift, 0, shift], repeat=4))
        
        # Exclude the original list shift (0, 0, 0, 0)
        shifts = [shift for shift in shifts if shift != (0, 0, 0, 0)]
        
        # Apply shifts to the original list
        shifted_lists = []
        for shift in shifts:
            shifted_list = [original + delta for original, delta in zip(original_list, shift)]
            shifted_lists.append(shifted_list)
        
        return shifted_lists

    def step(self, action):
        """
        Apply the agent's action to shift z-points and compute the reward.
        """
        self.current_step += 1

        # Apply action: shifts to z-points
        action = np.clip(action, self.action_space.low, self.action_space.high)

        self.state+=action
        self.state=self.clip_obs(self.state)
        
        self.std_env.centers=self.state[:4]
        self.std_env.gauss_stds=self.state[4:]

        # Compute reward for target deltas
        rew_target = self.std_env.compute_reward(self.target_deltas)

        # Compute reward for shifted deltas
        rew_shifted = 0
        
       # for deltas in self.shifted_deltas:
       #     rew_shifted+=self.std_env.compute_reward(deltas)
            
        with ProcessPoolExecutor() as executor:
            shifted_rewards = list(executor.map(
                lambda deltas: compute_reward_for_delta(self.std_env, deltas),
                self.shifted_deltas
            ))

        rew_shifted = sum(shifted_rewards)
        # Compute the reward as the ratio
        reward = rew_target / rew_shifted if rew_shifted != 0 else 0.0

        # Determine if the episode is done
        done = self.current_step >= self.max_steps

        # Optionally, add info for debugging
        info = {
            'rew_target': rew_target,
            'rew_shifted': rew_shifted
        }


        return self.state, reward, done, info

    def render(self, mode='human'):
        """
        Render the environment.
        """
        if mode == 'human':
            print(f"Step: {self.current_step}")

          
            print("-" * 50)

    def close(self):
        """
        Clean up resources.
        """
        pass


In [None]:
def test_fn():
    for _ in range(10000000):
        hyp2f1(2, 2, 2 * 2, .4)

In [136]:
start_time = time.time()
test_fn()
test_fn()
test_fn()
test_fn()
print("--- %s seconds ---" % (time.time() - start_time))

--- 8.150000810623169 seconds ---


In [137]:
with ProcessPoolExecutor() as executor:
            shifted_rewards = list(executor.map(
                lambda : test_fn(),
                range(4)
            ))

0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.


KeyboardInterrupt: 

In [130]:
from concurrent.futures import ThreadPoolExecutor

In [135]:
start_time = time.time()
with ThreadPoolExecutor(max_workers=4) as executor:
    future = executor.submit(test_fn)
print("--- %s seconds ---" % (time.time() - start_time))    
    

--- 1.9942290782928467 seconds ---


In [None]:
with ProcessPoolExecutor() as executor:
            shifted_rewards = list(executor.map(
                lambda deltas: compute_reward_for_delta(self.std_env, deltas),
                self.shifted_deltas
            ))

In [101]:
env=ZSamplingEnv()

In [102]:
env.step([0.1,0.1,0,0,0,0,0,0])

  real_part = np.random.normal(loc=center.real, scale=width)
  imag_part = np.random.normal(loc=center.imag, scale=width)
0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to dis

KeyboardInterrupt: 

In [100]:
import time 
start_time = time.time()
for _ in range(10):
    env.step([0,0,0,0,0,0,0,0])
print("--- %s seconds ---" % (time.time() - start_time))

  real_part = np.random.normal(loc=center.real, scale=width)
  imag_part = np.random.normal(loc=center.imag, scale=width)
0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to dis

AttributeError: Can't pickle local object 'ZSamplingEnv.step.<locals>.<lambda>'

In [89]:
import time 
start_time = time.time()
for _ in range(10):
    env.step([0,0,0,0,0,0,0,0])
print("--- %s seconds ---" % (time.time() - start_time))

  real_part = np.random.normal(loc=center.real, scale=width)
  imag_part = np.random.normal(loc=center.imag, scale=width)


--- 21.714500904083252 seconds ---


In [83]:
env.step([0,0,0,0,0,0,0,0])

  real_part = np.random.normal(loc=center.real, scale=width)
  imag_part = np.random.normal(loc=center.imag, scale=width)


(array([0.25+0.45j, 0.15+0.3j , 0.45+0.j  , 0.5 +0.3j , 0.1 +0.j  ,
        0.1 +0.j  , 0.1 +0.j  , 0.1 +0.j  ]),
 0.01973247040557794,
 False,
 {'rew_target': 3.1080512305189094, 'rew_shifted': 157.50948394380111})

In [None]:
env.step([0,0,0,0,-0.1,-0.1,-0.1,-0.1])

  real_part = np.random.normal(loc=center.real, scale=width)
  imag_part = np.random.normal(loc=center.imag, scale=width)


(array([0.25+0.45j, 0.15+0.3j , 0.45+0.j  , 0.5 +0.3j , 0.1 +0.j  ,
        0.1 +0.j  , 0.1 +0.j  , 0.1 +0.j  ]),
 -0.0073491191149271705,
 False,
 {'rew_target': -0.5988429573121348, 'rew_shifted': 81.48499812661824})

In [56]:
import itertools

def generate_shifted_lists(original_list):
    """
    Generate all lists by shifting each element of the original list by -1, 0, or 1,
    excluding the original list (0,0,0,0 shift).
    
    Parameters:
    - original_list: List of 4 elements to be shifted.

    Returns:
    - List of shifted lists.
    """
    if len(original_list) != 4:
        raise ValueError("The original list must have exactly 4 elements.")
    
    # Generate all combinations of shifts (-1, 0, 1) for 4 elements
    shifts = list(itertools.product([-1, 0, 1], repeat=4))
    
    # Exclude the original list shift (0, 0, 0, 0)
    shifts = [shift for shift in shifts if shift != (0, 0, 0, 0)]
    
    # Apply shifts to the original list
    shifted_lists = []
    for shift in shifts:
        shifted_list = [original + delta for original, delta in zip(original_list, shift)]
        shifted_lists.append(shifted_list)
    
    return shifted_lists

# Example usage
original_list = [5, 10, 15, 20]
shifted_lists = generate_shifted_lists(original_list)

# Output the results
print(f"Original list: {original_list}")
print(f"Generated {len(shifted_lists)} shifted lists:")
for lst in shifted_lists[:10]:  # Show first 10 for brevity
    print(lst)


Original list: [5, 10, 15, 20]
Generated 80 shifted lists:
[4, 9, 14, 19]
[4, 9, 14, 20]
[4, 9, 14, 21]
[4, 9, 15, 19]
[4, 9, 15, 20]
[4, 9, 15, 21]
[4, 9, 16, 19]
[4, 9, 16, 20]
[4, 9, 16, 21]
[4, 10, 14, 19]


In [53]:
env.z_points=np.array([(0.33991855072593896+0.009120811605038992j),
 (0.4479317435927501+0.3332602036873727j),
 (0.3313606153070749+0.08893805255512283j),
 (0.4870245345441848+0.345396985284971j),
 (0.34452530696896383+0.07593746572536138j),
 (0.4270729300211247+0.30044956363579733j),
 (0.4660427022841916+0.0037328107392406657j),
 (0.4428095378988895+0.32037957671224626j),
 (0.4028300155799515+0.00466332970672254j),
 (0.40751577782174975+0.2948520555545184j),
 (0.39517938578318423+0.024581274166913976j),
 (0.44255155055450124+0.33493599261864726j),
 (0.42721021098863105+0.006446516721407273j),
 (0.3308539759324897+0.3063645404089619j),
 (0.4095919273443319+0.015210026544854044j),
 (0.3712956380701379+0.27884266117321316j),
 (0.3156754491343662+0.05124163467558722j),
 (0.313301240232927+0.2957471833496594j),
 (0.3699094546466398+0.060109643683488236j),
 (0.4273689526197979+0.33273630665676224j),
 (0.3167865846915863+0.04841714188641773j),
 (0.3220057239680192+0.2969247277816467j),
 (0.40266114196865904+0.05364610671754945j),
 (0.3959493628092905+0.29522075998087055j),
 (0.31981736861273474+0.0810209032864796j),
 (0.43813327564065074+0.3211556718838836j),
 (0.41942550289240826+0.04050604690717442j),
 (0.4679049907735872+0.3059818701669683j),
 (0.38341030970500534+0.03234258755490473j),
 (0.3829326309906609+0.32051502784364055j),
 (0.37430300207633516+0.04737101338427306j),
 (0.35090190675311506+0.2759751381937153j),
 (0.3341002310804143+0.012733262247345865j),
 (0.4368995160163938+0.3372675244064487j),
 (0.36394039112435644+0.05537611054319957j),
 (0.39154348839342096+0.28338870511096986j),
 (0.44470706192419546+0.025716328094016222j),
 (0.4686990093078302+0.3341282209119506j),
 (0.3946135758880598+0.007222503988180198j),
 (0.3374641392509291+0.2629339081162081j),
 (0.47060690165572744+0.005185003292286693j),
 (0.45235848764493114+0.3053001735956744j),
 (0.32571607462712465+0.09241124671702446j),
 (0.35536503034788114+0.30898017399058153j),
 (0.41571001112955924+0.00851437193732401j),
 (0.44820507007227284+0.3119632129140621j),
 (0.3755198411307309+0.004843946157810831j),
 (0.4235039204547405+0.34086135515970456j),
 (0.3581797066602756+0.04164400556904397j),
 (0.42794461430862263+0.3039436046500811j),
 (0.3618014281594154+0.0377887109076519j),
 (0.4688914343479966+0.30828678943395516j),
 (0.3346793524314787+0.08856375551404054j),
 (0.3376098696599371+0.25856110226555373j),
 (0.3402969773388954+0.09678456469445448j),
 (0.41061801539395715+0.28886816572792573j),
 (0.3998994169418176+0.0230411257043737j),
 (0.4427413609292173+0.30371978467128563j),
 (0.31022328208769+0.04531363614185657j),
 (0.4087667036412457+0.29908764453037795j),
 (0.3673305049183452+0.06249101015277798j),
 (0.38566152819705124+0.320499024067667j),
 (0.3375797959533162+0.06914650942268076j),
 (0.34263744466040674+0.2583440224099935j),
 (0.3459240945454646+0.0387067964522405j),
 (0.3865025070898438+0.3240083399875981j),
 (0.40867162114782407+0.0004514291040405452j),
 (0.3498941354478917+0.27851807025531344j),
 (0.31813890431945036+0.05777444877314896j),
 (0.3469331229034194+0.2702774637679896j),
 (0.4363516745337823+0.033959656786335356j),
 (0.344448567938032+0.28295752227140886j),
 (0.3735018803887609+0.054211566792017975j),
 (0.3659251012403769+0.2958053894415227j),
 (0.31048095565784223+0.04265645579070334j),
 (0.4781798937029102+0.31083447774910566j),
 (0.3655234124003276+0.017315950208415868j),
 (0.43662947160088067+0.3064694051162609j),
 (0.313694059067655+0.002677786916775582j),
 (0.4152295082215325+0.2909471842625235j)])

In [54]:
env.step(np.zeros(160))

(array([0.33991855, 0.44793174, 0.3313606 , 0.48702455, 0.3445253 ,
        0.42707294, 0.4660427 , 0.44280955, 0.40283   , 0.40751576,
        0.3951794 , 0.44255155, 0.4272102 , 0.33085397, 0.4095919 ,
        0.37129563, 0.31567544, 0.31330124, 0.36990947, 0.42736894,
        0.3167866 , 0.32200572, 0.40266114, 0.39594936, 0.31981736,
        0.43813327, 0.41942552, 0.46790498, 0.3834103 , 0.38293263,
        0.374303  , 0.3509019 , 0.33410022, 0.4368995 , 0.3639404 ,
        0.39154348, 0.44470707, 0.468699  , 0.39461356, 0.33746415,
        0.4706069 , 0.45235848, 0.32571608, 0.35536504, 0.41571   ,
        0.44820508, 0.37551984, 0.42350394, 0.35817972, 0.4279446 ,
        0.36180142, 0.46889144, 0.33467937, 0.33760986, 0.34029698,
        0.410618  , 0.39989942, 0.44274136, 0.31022328, 0.40876672,
        0.3673305 , 0.38566154, 0.3375798 , 0.34263745, 0.3459241 ,
        0.3865025 , 0.40867162, 0.34989414, 0.3181389 , 0.34693313,
        0.4363517 , 0.34444857, 0.37350187, 0.36

In [6]:
np.zeros(80)

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [6]:
np.random.rand(80) [0*4:(0+1)*4]

array([0.2343206 , 0.53453008, 0.29740833, 0.97596133])

In [None]:

from stable_baselines3 import SAC
from stable_baselines3.common.env_checker import check_env
import matplotlib.pyplot as plt




# Create the RL model
model = SAC(
    'MlpPolicy',
    env,
    verbose=1,
    tensorboard_log="./sac_standard_deviation_tensorboard/",
    # You can adjust hyperparameters here
    learning_rate=3e-4,
    buffer_size=1000000,
    learning_starts=10000,
    batch_size=256,
    tau=0.005,
    gamma=0.99,
    train_freq=1,
    gradient_steps=1,
    ent_coef='auto',
    target_update_interval=1,
)

# Train the agent
# Adjust the total_timesteps based on computational resources
model.learn(total_timesteps=1000000)

# Save the trained model
model.save("sac_z_sampling")

# To load the model later:
# model = PPO.load("ppo_standard_deviation")
