In [1]:
import math
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import random
from torch.profiler import profile, record_function, ProfilerActivity

# Numba imports
from numba import cuda
# from numba import njit  # if you need CPU jitted helpers, otherwise optional

In [2]:
################################################################################
# 1) GLOBALS & SETUP
################################################################################
tol = 1e-15
max_terms = 4096
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

In [3]:
################################################################################
# 2) Numba/CUDA Kernels for Hypergeometric & Reward
################################################################################

@cuda.jit(device=True)
def _2F1_device(a, b, c, z_r, z_i):
    if c == 0.0:
        return 1.0, 0.0
    real_accum = 1.0
    imag_accum = 0.0
    term_r = 1.0
    term_i = 0.0
    for n in range(1, max_terms):
        denom = n*(c+n-1.0)
        poch = ((a+n-1.0)*(b+n-1.0))/denom
        zr = term_r*z_r - term_i*z_i
        zi = term_r*z_i + term_i*z_r
        term_r = poch*zr
        term_i = poch*zi
        if (abs(term_r) < tol) and (abs(term_i) < tol):
            break
        real_accum += term_r
        imag_accum += term_i
    return real_accum, imag_accum

@cuda.jit(device=True)
def compute_g_device(d_val, s_val, x, y):
    h = 0.5*(d_val + s_val)
    hb = 0.5*(d_val - s_val)

    fhz_r, fhz_i = _2F1_device(h,h,2*h,x,y)
    fhbz_b_r, fhbz_b_i = _2F1_device(hb,hb,2*hb,x,-y)
    fhz_b_r, fhz_b_i = _2F1_device(h,h,2*h,x,-y)
    fhb_z_r, fhb_z_i = _2F1_device(hb,hb,2*hb,x,y)

    r = math.sqrt(x*x + y*y)
    theta = math.atan2(y,x)
    d_ = h + hb
    s_ = h - hb
    r_pow_d = r**d_

    cos_s_th = math.cos(s_*theta)
    sin_s_th = math.sin(s_*theta)

    # F_h(z)*F_hb(z̅)
    step1_r = fhz_r*fhbz_b_r - fhz_i*fhbz_b_i
    step1_i = fhz_r*fhbz_b_i + fhz_i*fhbz_b_r
    tmp1_r = step1_r*cos_s_th - step1_i*sin_s_th
    tmp1_i = step1_r*sin_s_th + step1_i*cos_s_th
    T1_r = r_pow_d*tmp1_r
    T1_i = r_pow_d*tmp1_i

    # F_h(z̅)*F_hb(z)
    step2_r = fhz_b_r*fhb_z_r - fhz_b_i*fhb_z_i
    step2_i = fhz_b_r*fhb_z_i + fhz_b_i*fhb_z_r
    tmp2_r = step2_r*cos_s_th + step2_i*sin_s_th
    tmp2_i = -step2_r*sin_s_th + step2_i*cos_s_th
    T2_r = r_pow_d*tmp2_r
    T2_i = r_pow_d*tmp2_i

    g_r = T1_r + T2_r
    g_i = T1_i + T2_i
    return g_r, g_i

@cuda.jit
def compute_G_element(d_val, s_val, x, y, dphi):
    xp = 1.0 - x
    yp = -y
    r1 = math.sqrt((x-1.0)*(x-1.0) + y*y)
    r2 = math.sqrt(x*x + y*y)
    r1_pow = r1**(2*dphi)
    r2_pow = r2**(2*dphi)
    g1_r, _ = compute_g_device(d_val, s_val, x, y)
    g2_r, _ = compute_g_device(d_val, s_val, xp, yp)
    val_r = r1_pow*g1_r - r2_pow*g2_r
    if s_val==0:
        return val_r*0.5
    return val_r

@cuda.jit
def compute_g_delta_kernel(d_arr, s_arr, x_arr, y_arr, dphi, g_delta_matrix):
    i_g,i_z,i_state = cuda.grid(3)
    N_g, N_state = d_arr.shape
    N_z = x_arr.size
    stride_g,stride_z,stride_state = cuda.gridsize(3)
    for k in range(i_z,N_z,stride_z):
        for i in range(i_g,N_g,stride_g):
            for j in range(i_state,N_state,stride_state):
                x = x_arr[k]
                y = y_arr[k]
                d_val = d_arr[i,j]
                s_val = s_arr[j]
                g_delta_matrix[i,k,j] = compute_G_element(d_val, s_val, x, y, dphi)

@cuda.jit
def compute_W_v(d_max, x_arr, y_arr, dphi, W, v):
    idx = cuda.grid(1)
    N_z = x_arr.size
    if idx<N_z:
        x = x_arr[idx]
        y = y_arr[idx]
        r1 = math.sqrt((x-1.0)*(x-1.0) + y*y)
        r2 = math.sqrt(x*x + y*y)
        r1_pow = r1**(2*dphi)
        r2_pow = r2**(2*dphi)
        G0 = compute_G_element(d_max, 0.0, x, y, dphi)
        if abs(G0)<1e-12:
            W[idx] = 1e12
        else:
            W[idx] = 1.0/(G0*G0)
        v[idx] = r1_pow-r2_pow

################################################################################
# 3) Bridging Function (Option D)
################################################################################

def calculate_c_rew(d_values, s_values, x_values, y_values, dphi, d_max=9.0, N_lsq=20):
    """
    d_values, s_values, x_values, y_values are PyTorch Tensors on GPU.
    We do CPU bridging to feed them into Numba kernels.
    """
    # 1) Convert from PyTorch -> CPU NumPy
    d_np = d_values.detach().cpu().numpy()     # shape [N_deltas, N_state]
    s_np = s_values.detach().cpu().numpy()     # shape [N_state]
    x_np = x_values.detach().cpu().numpy()     # shape [N_z]
    y_np = y_values.detach().cpu().numpy()     # shape [N_z]

    N_deltas = d_np.shape[0]
    N_state  = d_np.shape[1]
    N_z = x_np.shape[0]
    assert N_z%N_lsq==0, "N_z should be integer multiple of N_lsq"
    N_stat = N_z//N_lsq

    # 2) Move arrays to Numba device
    d_device = cuda.to_device(d_np)
    s_device = cuda.to_device(s_np)
    x_device = cuda.to_device(x_np)
    y_device = cuda.to_device(y_np)

    # 3) g_delta_matrix => shape [N_deltas, N_z, N_state]
    g_delta_device = cuda.device_array((N_deltas, N_z, N_state), dtype=np.float64)

    threads_per_block = (4,8,4)
    blocks_per_grid_x = math.ceil(N_deltas / threads_per_block[0])
    blocks_per_grid_y = math.ceil(N_z / threads_per_block[1])
    blocks_per_grid_z = math.ceil(N_state / threads_per_block[2])

    compute_g_delta_kernel[(blocks_per_grid_x, blocks_per_grid_y, blocks_per_grid_z),
                           threads_per_block](
        d_device, s_device, x_device, y_device, dphi, g_delta_device
    )

    W_device = cuda.device_array(N_z, dtype=np.float64)
    v_device = cuda.device_array(N_z, dtype=np.float64)

    threadsperblock = 256
    blockspergrid = (N_z+threadsperblock-1)//threadsperblock
    compute_W_v[blockspergrid,threadsperblock](
        d_max, x_device, y_device, dphi, W_device, v_device
    )

    # 4) Copy results back to CPU
    G_host = g_delta_device.copy_to_host()  # shape [N_deltas, N_z, N_state]
    W_host = W_device.copy_to_host()        # shape [N_z]
    v_host = v_device.copy_to_host()        # shape [N_z]

    # 5) Convert to PyTorch Tensors on GPU
    G = torch.tensor(G_host, device=device, dtype=torch.float64)
    W_diag = torch.tensor(W_host, device=device, dtype=torch.float64)
    v0 = torch.tensor(v_host, device=device, dtype=torch.float64)

    # 6) Reshape & LSQ in PyTorch
    G = G.view(N_deltas, N_stat, N_lsq, N_state).permute(1,0,2,3)
    W_diag = W_diag.view(N_stat, N_lsq)
    v = v0.unsqueeze(0).expand(N_deltas, -1)  # [N_deltas, N_z]
    v = v.view(N_deltas, N_stat, N_lsq).permute(1,0,2)

    W_mat = torch.diag_embed(W_diag) # shape [N_stat, N_lsq, N_lsq]
    WG = torch.einsum('szz,sgzn->sgzn', W_mat, G)
    GT_WG = torch.einsum('sgzn,sgzm->sgnm', G, WG)
    WG_v = torch.einsum('sgzn,sgz->sgn', WG, v)

    c = -1.* torch.linalg.solve(GT_WG, WG_v)  # shape [N_stat, N_deltas, N_state]
    Gc = torch.einsum('sgzn,sgn->sgz', G, c)
    residual_vector = Gc + v
    W_residual = torch.einsum('sy,sgy->sgy', W_diag, residual_vector)
    residual = torch.einsum('sgz,sgz->sg', W_residual, residual_vector)

    return c, residual

################################################################################
# 4) Reward Wrappers
################################################################################

def least_sq_std_rew(d_values, zs, s_values, dSigma, N_lsq=20, n_states_rew=2):
    x_values = zs.real
    y_values = zs.imag
    cs,rews = calculate_c_rew(d_values, s_values, x_values, y_values, dSigma, N_lsq=N_lsq)
    c_mean = torch.mean(cs, dim=0)
    c_std  = torch.std(cs, dim=0)
    r_stat = c_std/(c_mean + 1e-12)
    r = -torch.sum(torch.log(torch.clamp(torch.abs(r_stat[:, :n_states_rew]), 1e-12)), dim=1)
    return r
    

In [4]:
################################################################################
# 5) z points
################################################################################

def rho(z):
    """
    Compute Rho[z] = z / (1 + sqrt(1 - z))^2
    """
    return z / (1 + np.sqrt(1 - z))**2

def lambda_z(z):
    """
    Compute Lambda[z] = Abs[Rho[z]] + Abs[Rho[1 - z]]
    """
    rho_z = rho(z)
    rho_1_minus_z = rho(1 - z)
    return np.abs(rho_z) + np.abs(rho_1_minus_z)

def discretize_region(lambda_0=.42, x_range=[.51,1.5], y_range=[0,1.16], resolution=50):
    """
    Discretize the complex z-plane into a grid and return all points where Lambda[z] < lambda_0.
    
    Parameters:
    - lambda_0: Threshold for Lambda[z]
    - x_range: Tuple (x_min, x_max) for the real part
    - y_range: Tuple (y_min, y_max) for the imaginary part
    - resolution: Number of points per axis in the grid
    
    Returns:
    - valid_points: List of complex points within the valid region
    - grid_points: 2D numpy array of all complex points in the grid
    - mask: Boolean mask indicating points where Lambda[z] < lambda_0
    """
    # Generate grid of complex points
    x = np.linspace(x_range[0], x_range[1], resolution)
    y = np.linspace(y_range[0], y_range[1], resolution)
    X, Y = np.meshgrid(x, y)
    Z = X + 1j * Y  # Create the complex grid points
    
    # Compute Lambda[z] for each point in the grid
    Lambda = lambda_z(Z)
    
    # Mask points where Lambda[z] < lambda_0
    mask = Lambda < lambda_0
    valid_points = Z[mask]  # Extract valid points from the grid
    
    return valid_points


def generate_random_points(lambda_0=.42, x_range=[.51, 1.5], y_range=[0, 1.16], num_points=200):
    """
    Generate exactly num_points random points in the region where Lambda[z] < lambda_0.
    
    Parameters:
    - lambda_0: Threshold for Lambda[z]
    - x_range: List [x_min, x_max] for the real part
    - y_range: List [y_min, y_max] for the imaginary part
    - num_points: Total number of random points to generate
    
    Returns:
    - valid_points: Array of exactly num_points complex points within the valid region
    """
    valid_points = []
    
    while len(valid_points) < num_points:
        # Generate a batch of random points
        real_part = np.random.uniform(x_range[0], x_range[1], num_points)
        imag_part = np.random.uniform(y_range[0], y_range[1], num_points)
        z_batch = real_part + 1j * imag_part
        
        # Check which points satisfy Lambda[z] < lambda_0
        mask = lambda_z(z_batch) < lambda_0
        valid_points.extend(z_batch[mask])  # Add valid points to the list
    
    # Ensure exactly num_points points by slicing
    return np.array(valid_points[:num_points])


In [5]:
################################################################################
# 6) SAC Implementation (All PyTorch GPU)
################################################################################

class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)
    
    def push(self, state, action, reward, next_state, done):
        # We store CPU or GPU data in a consistent format.
        # For minimal overhead, store on CPU as float32 np arrays:
        self.buffer.append((
            state.detach().cpu().numpy(),
            action.detach().cpu().numpy(),
            np.array([reward.detach().cpu().item()], dtype=np.float32),
            next_state.detach().cpu().numpy(),
            np.array([done], dtype=np.float32)
        ))
    
    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = map(np.stack, zip(*batch))
        return (
            torch.tensor(state, device=device, dtype=torch.float32),
            torch.tensor(action, device=device, dtype=torch.float32),
            torch.tensor(reward, device=device, dtype=torch.float32),
            torch.tensor(next_state, device=device, dtype=torch.float32),
            torch.tensor(done, device=device, dtype=torch.float32)
        )
    
    def __len__(self):
        return len(self.buffer)

class GaussianPolicy(nn.Module):
    def __init__(self, num_inputs, num_actions, hidden_size=256, action_low=-1, action_high=1):
        super().__init__()
        self.linear1 = nn.Linear(num_inputs, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.mean = nn.Linear(hidden_size, num_actions)
        self.log_std = nn.Linear(hidden_size, num_actions)
        self.action_scale = (action_high - action_low)/2.0
        self.action_bias = (action_high + action_low)/2.0
    
    def forward(self, state):
        x = torch.relu(self.linear1(state))
        x = torch.relu(self.linear2(x))
        mean = self.mean(x)
        log_std = self.log_std(x).clamp(-20, 2)
        return mean, log_std
    
    def sample(self, state):
        mean, log_std = self.forward(state)
        std = log_std.exp()
        dist = torch.distributions.Normal(mean, std)
        x_t = dist.rsample()              # reparameterization trick
        y_t = torch.tanh(x_t)
        action = self.action_scale*y_t + self.action_bias
        log_prob = dist.log_prob(x_t)
        # Enforcing action bound
        log_prob -= torch.log(self.action_scale*(1 - y_t.pow(2)) + 1e-6)
        log_prob = log_prob.sum(dim=1, keepdim=True)
        return action, log_prob

class QNetwork(nn.Module):
    def __init__(self, num_inputs, num_actions, hidden_size=256):
        super().__init__()
        # Q1
        self.linear1 = nn.Linear(num_inputs + num_actions, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, 1)
        # Q2
        self.linear4 = nn.Linear(num_inputs + num_actions, hidden_size)
        self.linear5 = nn.Linear(hidden_size, hidden_size)
        self.linear6 = nn.Linear(hidden_size, 1)
    
    def forward(self, state, action):
        xu = torch.cat([state, action], dim=1)  # [batch, state_dim+action_dim]
        x1 = torch.relu(self.linear1(xu))
        x1 = torch.relu(self.linear2(x1))
        x1 = self.linear3(x1)
        x2 = torch.relu(self.linear4(xu))
        x2 = torch.relu(self.linear5(x2))
        x2 = self.linear6(x2)
        return x1, x2

class SACAgent:
    def __init__(self, num_state, num_action, args):
        self.gamma = args['gamma']
        self.tau = args['tau']
        self.alpha = args['alpha']
        self.automatic_entropy_tuning = args['automatic_entropy_tuning']

        self.critic = QNetwork(num_state, num_action).to(device)
        self.critic_target = QNetwork(num_state, num_action).to(device)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=args['lr'])

        self.policy = GaussianPolicy(num_state, num_action).to(device)
        self.policy_optimizer = optim.Adam(self.policy.parameters(), lr=args['lr'])

        if self.automatic_entropy_tuning:
            self.target_entropy = -float(num_action)
            self.log_alpha = torch.zeros(1, requires_grad=True, device=device)
            self.alpha_optimizer = optim.Adam([self.log_alpha], lr=args['lr'])

        self.replay_buffer = ReplayBuffer(args['replay_size'])
        self.batch_size = args['batch_size']
    
    def select_action(self, state, evaluate=False):
        # state => shape (state_dim,)
        with torch.no_grad():
            state = state.unsqueeze(0).to(device, dtype=torch.float32)
            if evaluate:
                mean, _ = self.policy.forward(state)
                y_t = torch.tanh(mean)
                action = self.policy.action_scale*y_t + self.policy.action_bias
            else:
                action, _ = self.policy.sample(state)
        return action[0]  # shape (num_action,)

    def update_parameters(self):
        if len(self.replay_buffer) < self.batch_size:
            return None, None, None
        state, action, reward, next_state, done = self.replay_buffer.sample(self.batch_size)
        done = done.view(-1,1)
        reward = reward.view(-1,1)

        # 1) Critic
        with torch.no_grad():
            next_a, next_log_prob = self.policy.sample(next_state)
            q1_next, q2_next = self.critic_target(next_state, next_a)
            min_q_next = torch.min(q1_next, q2_next) - self.alpha*next_log_prob
            target_q = reward + (1-done)*self.gamma*min_q_next

        q1, q2 = self.critic(state, action)
        qf1_loss = nn.MSELoss()(q1, target_q)
        qf2_loss = nn.MSELoss()(q2, target_q)
        qf_loss = qf1_loss + qf2_loss
        self.critic_optimizer.zero_grad()
        qf_loss.backward()
        self.critic_optimizer.step()

        # 2) Actor
        pi, log_pi = self.policy.sample(state)
        qf1_pi, qf2_pi = self.critic(state, pi)
        min_qf_pi = torch.min(qf1_pi, qf2_pi)
        policy_loss = (self.alpha*log_pi - min_qf_pi).mean()
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        # 3) Alpha
        alpha_loss = torch.tensor(0., device=device)
        if self.automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            self.alpha = self.log_alpha.exp()

        # 4) Soft update
        for tp, p in zip(self.critic_target.parameters(), self.critic.parameters()):
            tp.data.copy_(self.tau*p.data + (1-self.tau)*tp.data)
        
        return qf_loss.item(), policy_loss.item(), alpha_loss.item()

    def save_model(self, name="test"):
        torch.save(self.policy.state_dict(), f"sac_actor_{name}.pth")
        torch.save(self.critic.state_dict(), f"sac_critic_{name}.pth")
    
    def load_model(self, name):
        self.policy.load_state_dict(torch.load(f"sac_actor_{name}.pth"))
        self.critic.load_state_dict(torch.load(f"sac_critic_{name}.pth"))

In [6]:
################################################################################
# 7) TRAINING LOOP
################################################################################

def train():
    args = {
        'gamma': 0.99,
        'tau': 0.005,
        'lr': 3e-4,
        'alpha': 0.2,
        'automatic_entropy_tuning': True,
        'replay_size': 100000,
        'batch_size': 256,
    }
    model_name = "test3"
    num_state = 4
    num_action = 3
    reward_scale = 1000
    zeros = torch.zeros(num_state - num_action, device=device)

    max_steps = int(1e6)
    max_episode_len = 80
    start_steps = max_steps//50
    episode_reward = 0.
    episode = 0
    step_in_episode = 0

    max_delta=9
    
    rand_state = 2 * 0.25
    d_step_size = 2*0.04

    # Example spins
    spins = torch.tensor([0.0, 2.0, 4.0, 0.0], device=device)
    dSigma = -0.4
    # Initial "state" => shape(4,)
    init_state = torch.tensor([-0.4, 2., 3.6, 7.6], device=device)

    zs=torch.tensor(generate_random_points(lambda_0=.37), device=device)

    agent = SACAgent(num_state, num_action, args)
    state = init_state.clone()
    
    for global_step in range(max_steps):
        step_in_episode += 1

        # 1) Sample 3D action from policy
        action_3d = agent.select_action(state, evaluate=False)  # shape(3,)

        # 2) Next state logic => your "environment"
        #    E.g. next_state = state + [0, action_3d]
        full_action_4d = torch.cat([zeros, d_step_size*(action_3d-0.5)], dim=0)  # shape(4,)
        next_state = state + full_action_4d

        # 3) Calculate reward on GPU
        #    least_sq_std_rew expects [N_deltas, N_state], so we do [1,4]
        next_state_expanded = next_state.unsqueeze(0)  # shape(1,4)
        rew_tensor = reward_scale*least_sq_std_rew(next_state_expanded, zs, spins, dSigma, N_lsq=10, n_states_rew=2)

        # clamp if needed
        if torch.isnan(rew_tensor).any() or torch.isinf(rew_tensor).any():
            rew_tensor = torch.zeros_like(rew_tensor, device=device)
        reward_val = rew_tensor[0]

        done = (step_in_episode >= max_episode_len) or next_state[-1]>max_delta

        if not done:
            for i in range(num_state-1):
                if next_state[i] > next_state[i+1]:
                    done = True
                    reward_val *= reward_scale*max_episode_len
                    reward_val = -torch.abs(reward_val)
                    # print(reward_val)

        # 4) Store transition in replay:
        #    State shape(4), Action shape(3), Reward scalar, Next state(4)
        agent.replay_buffer.push(
            state.float(),     # 4D
            action_3d.float(),# 3D
            reward_val.float().unsqueeze(0),
            next_state.float(),
            float(done)
        )

        state = next_state
        episode_reward += reward_val.item()

        if done:
            print(f"Episode {episode}, Reward={episode_reward}")
            episode_reward = 0.
            episode += 1
            step_in_episode = 0
            state = init_state.clone()
            state*=1+rand_state*(torch.rand(num_state,device=device)-0.5)
            state[0]=init_state[0]

        # 5) SAC update
        if global_step > start_steps:
            with record_function("update_parameters"):
                qf_loss, policy_loss, alpha_loss = agent.update_parameters()

    agent.save_model(model_name)


if __name__ == "__main__":
    train()




Episode 0, Reward=-437821065.0089559
Episode 1, Reward=-275577200.75779366
Episode 2, Reward=8310.230030156461
Episode 3, Reward=-450612315.5968969
Episode 4, Reward=-427353506.96124643
Episode 5, Reward=-499409248.4681259
Episode 6, Reward=-652829373.5479805
Episode 7, Reward=-386993463.43673736
Episode 8, Reward=-682195049.4147943
Episode 9, Reward=-449560249.01335806
Episode 10, Reward=-387175148.2879974
Episode 11, Reward=-471006782.1656338
Episode 12, Reward=-375895379.74053574
Episode 13, Reward=-386469092.4498616
Episode 14, Reward=-405854377.32937586
Episode 15, Reward=-505243278.41771454
Episode 16, Reward=-205947818.423298
Episode 17, Reward=-846002211.8373456
Episode 18, Reward=-253072881.01744583
Episode 19, Reward=-405299039.0822533
Episode 20, Reward=13384.878141981118
Episode 21, Reward=-208279627.63762912
Episode 22, Reward=-534871666.7814691
Episode 23, Reward=-472386318.6601708
Episode 24, Reward=-671684298.0672535
Episode 25, Reward=-217909716.46822503
Episode 26, Re

In [7]:
def predict():
    args = {
        'gamma': 0.99,
        'tau': 0.005,
        'lr': 3e-4,
        'alpha': 0.2,
        'automatic_entropy_tuning': True,
        'replay_size': 100000,
        'batch_size': 128,
    }
    model_name = "test3"
    num_state = 4
    num_action = 3
    reward_scale = 1000
    zeros = torch.zeros(num_state - num_action, device=device)

    max_episode_len = 50
    episode_reward = 0.

    rand_state = 2 * 0.25

    # Example spins
    spins = torch.tensor([0.0, 2.0, 4.0, 0.0], device=device)
    dSigma = -0.4
    # Initial "state" => shape(4,)
    init_state = torch.tensor([-0.4, 2., 3.6, 7.6], device=device)

    zs=torch.tensor(generate_random_points(lambda_0=.37), device=device)

    agent = SACAgent(num_state, num_action, args)
    if len(model_name)>0:
        agent.load_model(model_name)
    state = init_state.clone()

    num_sample = 100
    max_episode_len = 100
    episode_reward = 0.

    traj=[]

    for _ in range(num_sample):
        episode_reward = 0.
        state = init_state.clone()
        state*=1+rand_state*(torch.rand(num_state,device=device)-0.5)
        state[0]=init_state[0]

        temp = [state.detach().cpu().tolist()]

        for _ in range(max_episode_len):
            # 1) Sample 3D action from policy
            action_3d = agent.select_action(state, evaluate=False)  # shape(3,)

            # 2) Next state logic => your "environment"
            #    E.g. next_state = state + [0, action_3d]
            full_action_4d = torch.cat([zeros, 0.2*(action_3d-0.5)], dim=0)  # shape(4,)
            next_state = state + full_action_4d

            # 3) Calculate reward on GPU
            #    least_sq_std_rew expects [N_deltas, N_state], so we do [1,4]
            next_state_expanded = next_state.unsqueeze(0)  # shape(1,4)
            rew_tensor = reward_scale*least_sq_std_rew(next_state_expanded, zs, spins, dSigma, N_lsq=10, n_states_rew=2)

            # clamp if needed
            if torch.isnan(rew_tensor).any() or torch.isinf(rew_tensor).any():
                rew_tensor = torch.zeros_like(rew_tensor, device=device)
            reward_val = rew_tensor[0]

            state = next_state
            episode_reward += reward_val.item()
            temp.append(state.detach().cpu().tolist())
        traj.append(temp)
    return traj

if __name__ == "__main__":
    traj=predict()
    print(traj)

[[[-0.4000000059604645, 2.1211998462677, 3.5713624954223633, 6.847171783447266], [-0.4000000059604645, 2.2077081203460693, 3.6687498092651367, 6.936207294464111], [-0.4000000059604645, 2.254722833633423, 3.766906499862671, 7.023483753204346], [-0.4000000059604645, 2.1132915019989014, 3.8650476932525635, 7.11212158203125], [-0.4000000059604645, 2.1937029361724854, 3.9636473655700684, 7.2044758796691895], [-0.4000000059604645, 2.254909038543701, 4.0609846115112305, 7.279606342315674], [-0.4000000059604645, 2.2618043422698975, 4.157670974731445, 7.378492832183838], [-0.4000000059604645, 2.2668304443359375, 4.25408935546875, 7.4715118408203125], [-0.4000000059604645, 2.3382761478424072, 4.353107452392578, 7.524835109710693], [-0.4000000059604645, 2.4304447174072266, 4.450499534606934, 7.483363151550293], [-0.4000000059604645, 2.478384494781494, 4.548092842102051, 7.581724643707275], [-0.4000000059604645, 2.4345099925994873, 4.6435370445251465, 7.660626411437988], [-0.4000000059604645, 2.52

In [8]:
for i in range(len(traj)):
    print(traj[i][0])
    print(traj[i][-1])
    print('----------------------------')

[-0.4000000059604645, 2.1211998462677, 3.5713624954223633, 6.847171783447266]
[-0.4000000059604645, -0.27832239866256714, 11.616085052490234, 9.168302536010742]
----------------------------
[-0.4000000059604645, 2.1039748191833496, 3.0154223442077637, 8.62626838684082]
[-0.4000000059604645, -0.4218192398548126, 10.91402816772461, 8.928770065307617]
----------------------------
[-0.4000000059604645, 1.7834502458572388, 3.232863187789917, 6.5470967292785645]
[-0.4000000059604645, -0.20407135784626007, 11.660879135131836, 8.700546264648438]
----------------------------
[-0.4000000059604645, 2.472268581390381, 4.037909984588623, 9.264799118041992]
[-0.4000000059604645, -0.40520909428596497, 12.090240478515625, 8.622880935668945]
----------------------------
[-0.4000000059604645, 1.5702309608459473, 3.693993091583252, 9.335162162780762]
[-0.4000000059604645, -0.2826486825942993, 11.437080383300781, 8.963704109191895]
----------------------------
[-0.4000000059604645, 1.9030828475952148, 3.9

## Abandoned

In [9]:
# ################################################################################
# # 2) Numba/CUDA Kernels for Hypergeometric & Reward
# ################################################################################

# @cuda.jit(device=True)
# def _2F1_device(a, b, c, z_r, z_i):
#     if c == 0.0:
#         return 1.0, 0.0
#     real_accum = 1.0
#     imag_accum = 0.0
#     term_r = 1.0
#     term_i = 0.0
#     for n in range(1, max_terms):
#         denom = n*(c+n-1.0)
#         poch = ((a+n-1.0)*(b+n-1.0))/denom
#         zr = term_r*z_r - term_i*z_i
#         zi = term_r*z_i + term_i*z_r
#         term_r = poch*zr
#         term_i = poch*zi
#         if (abs(term_r)<tol) and (abs(term_i)<tol):
#             break
#         real_accum += term_r
#         imag_accum += term_i
#     return real_accum, imag_accum

# @cuda.jit(device=True)
# def compute_g_device(d_val, s_val, x, y):
#     # h = (Delta + s)/2, hb = (Delta - s)/2
#     h = 0.5*(d_val + s_val)
#     hb = 0.5*(d_val - s_val)

#     # Compute hypergeometric parts:
#     z_r, z_i = x, y
#     xp = 1.0 - x
#     yp = -y

#     # We'll directly compute necessary pieces in kernel
#     # See original snippet for reference
#     # 2F1(h,h;2h;z)
#     fhz_r, fhz_i = _2F1_device(h,h,2*h,z_r,z_i)
#     # 2F1(hb,hb;2hb;z*)
#     fhbz_b_r, fhbz_b_i = _2F1_device(hb,hb,2*hb,z_r,-z_i)
#     # 2F1(h,h;2h;z*)
#     fhz_b_r, fhz_b_i = _2F1_device(h,h,2*h,z_r,-z_i)
#     # 2F1(hb,hb;2hb;z)
#     fhb_z_r, fhb_z_i = _2F1_device(hb,hb,2*hb,z_r,z_i)

#     # Construct g similarly as done in initial code:
#     # g = (z^h z̅^{hb} * F_h(z)*F_hb(z̅) + z^{hb} z̅^{h}*F_h(z̅)*F_hb(z)) / (1 + δ_{h,hb})
#     # For simplicity, we assume (1+KroneckerDelta)=1 since h,hb from continuous spectrum
#     # Approximate powers:
#     r = math.sqrt(x*x + y*y)
#     theta = math.atan2(y,x)
#     # z^h z̅^{hb}:
#     # z^h = r^h * e^{i h theta}, z̅^{hb} = r^{hb} * e^{-i hb theta}
#     # combined: z^h z̅^{hb} = r^{h+hb} e^{i(h - hb)theta} = r^{(h+hb)}e^{i s theta}, s=(h-hb)
#     d = h+hb
#     s = h-hb
#     r_pow_d = r**d
#     cos_s_th = math.cos(s*theta)
#     sin_s_th = math.sin(s*theta)
#     # z^h z̅^{hb} = r^d(cos(sθ) + i sin(sθ))
#     # Similarly, z^{hb} z̅^h = r^d(cos(-sθ) + i sin(-sθ)) = r^d(cos(sθ) - i sin(sθ))

#     # F_h(z)*F_hb(z̅)
#     step1_r = fhz_r*fhbz_b_r - fhz_i*fhbz_b_i
#     step1_i = fhz_r*fhbz_b_i + fhz_i*fhbz_b_r

#     # multiply by z^h z̅^{hb}:
#     # (cos(sθ)+i sin(sθ))*(step1_r + i step1_i)
#     tmp1_r = step1_r*cos_s_th - step1_i*sin_s_th
#     tmp1_i = step1_r*sin_s_th + step1_i*cos_s_th
#     T1_r = r_pow_d*tmp1_r
#     T1_i = r_pow_d*tmp1_i

#     # F_h(z̅)*F_hb(z)
#     step2_r = fhz_b_r*fhb_z_r - fhz_b_i*fhb_z_i
#     step2_i = fhz_b_r*fhb_z_i + fhz_b_i*fhb_z_r

#     # z^{hb} z̅^{h} = r^d(cos(-sθ)+i sin(-sθ)) = r^d(cos(sθ)-i sin(sθ))
#     # multiply step2 by (cos(sθ)- i sin(sθ))
#     tmp2_r = step2_r*cos_s_th + step2_i*sin_s_th
#     tmp2_i = -step2_r*sin_s_th + step2_i*cos_s_th
#     T2_r = r_pow_d*tmp2_r
#     T2_i = r_pow_d*tmp2_i

#     g_r = T1_r + T2_r
#     g_i = T1_i + T2_i
#     return g_r, g_i

# @cuda.jit
# def compute_G_element(d_val, s_val, x, y, dphi):
#     xp = 1.0 - x
#     yp = -y
#     # Compute factors |z-1|^{2 dSigma} and |z|^{2 dSigma}
#     zm1_r = x - 1.0
#     zm1_i = y
#     # |z-1|^2 = (zm1_r^2+zm1_i^2)
#     r1 = math.sqrt(zm1_r*zm1_r + zm1_i*zm1_i)
#     # |z| = sqrt(x^2+y^2)
#     r2 = math.sqrt(x*x + y*y)

#     # (|z-1|^(2 dSigma)*g1 - |z|^(2 dSigma)*g2)
#     r1_pow = r1**(2*dphi)
#     r2_pow = r2**(2*dphi)    
#     # g(h,hb,z)
#     g1_r, g1_i = compute_g_device(d_val, s_val, x, y)
    
#     # g(h,hb,1-z)
#     g2_r, g2_i = compute_g_device(d_val, s_val, xp, yp)
    
#     val_r = r1_pow*g1_r - r2_pow*g2_r
#     # Imag parts negligible for final eq (expected real), ignoring imaginary since original eq focuses on real
#     if s_val==0:
#         return val_r/2
    
#     return val_r

# @cuda.jit
# def compute_g_delta_kernel(d_arr, s_arr, x_arr, y_arr, dphi, g_delta_matrix):
#     i_g,i_z , i_state= cuda.grid(3)
#     N_z = x_arr.size
#     N_g, N_state = d_arr.shape  # Dimensions of d_arr (M = outer size, n_states = inner size)
#     #total_elements = N_z * N_g * N_state
#     stride_g,stride_z, stride_state = cuda.gridsize(3)  # Grid dimensions (stride for each axis)
#     for k in range(i_z,N_z,stride_z):
#         for i in range(i_g,N_g,stride_g):
#             for j in range(i_state,N_state,stride_state):
#                     x = x_arr[k]
#                     y = y_arr[k]
                        
#                     d_val = d_arr[i,j]
#                     s_val = s_arr[j]
                             
#                     g_delta_matrix[i,k,j] = compute_G_element(d_val, s_val, x, y, dphi)

# @cuda.jit
# def compute_W_v(d_max, x_arr, y_arr, dphi,W,v):
#     N_z = x_arr.size
#     k=cuda.grid(1)
#     if k<N_z:
#         x = x_arr[k]
#         y = y_arr[k]
#         xp = 1.0 - x
#         yp = -y
#         # Compute factors |z-1|^{2 dSigma} and |z|^{2 dSigma}
#         zm1_r = x - 1.0
#         zm1_i = y
#         # |z-1|^2 = (zm1_r^2+zm1_i^2)
#         r1 = math.sqrt(zm1_r*zm1_r + zm1_i*zm1_i)
#         # |z| = sqrt(x^2+y^2)
#         r2 = math.sqrt(x*x + y*y)
#         r1_pow = r1**(2*dphi)
#         r2_pow = r2**(2*dphi)
#         G0=compute_G_element(d_max, 0, x, y, dphi) 
#         W[k]=1.0/(G0*G0)
#         v[k]= r1_pow - r2_pow

# ################################################################################
# # 3) GPU-Based Calculation of c, residual, then PyTorch post-processing
# ################################################################################

# def calculate_c_rew(d_values, s_values, x_values, y_values, dphi,d_max=9.0,N_lsq=20):
#     N_z = len(x_values)
#     N_deltas = len(d_values)
#     assert N_z%N_lsq==0, "N_z should be integer times N_lsq"
#     N_stat=N_z//N_lsq #how many times to calculate lsq for std stats

#     N_state=len(d_values[0])

#     host_array = np.zeros(( N_deltas,N_z, N_state), dtype=float)
#     g_delta_device = cuda.to_device(host_array)

#     d_device = d_values
#     s_device = s_values
#     x_device = x_values
#     y_device = y_values

#     threads_per_block = (4,8,4)
#     blocks_per_grid_x = math.ceil(N_deltas / threads_per_block[0])
#     blocks_per_grid_y = math.ceil(N_z / threads_per_block[1])
#     blocks_per_grid_z = math.ceil(N_state / threads_per_block[2])

#     compute_g_delta_kernel[(blocks_per_grid_x,blocks_per_grid_y,blocks_per_grid_z), threads_per_block](
#         d_device, s_device, x_device, y_device, dphi, g_delta_device
#     )
    
#     W_device = cuda.device_array(N_z, dtype=float)
#     v_device = cuda.device_array(N_z, dtype=float)
#     threadsperblock = 32
#     blockspergrid = (N_z+ (threadsperblock - 1)) // threadsperblock
#     compute_W_v[blockspergrid,threadsperblock](d_max, x_device, y_device, dphi,W_device,v_device)

#     G=torch.tensor(g_delta_device, device='cuda') # (N_deltas,N_z, N_state)
#     G=G.view((N_deltas,N_stat,N_lsq,N_state)).permute(1,0,2,3)

#     W_diag=torch.tensor(W_device, device='cuda')
#     #W N_z=N_batch*N_lsq_sq
#     W_diag=W_diag.view((N_stat,N_lsq))

#     v0=torch.tensor(v_device, device='cuda')#N_z
#     v= v0.unsqueeze(0).repeat(N_deltas,1)  # N_deltas,N_z
#     v=v.view((N_deltas,N_stat,N_lsq)).permute(1,0,2) #N_stat, N_deltas, N_lsq

#     # Step 1: Apply W to G
#     W=torch.diag_embed(W_diag) #N_stat,N_lsq,N_lsq

#     WG = torch.einsum('szz,sgzn->sgzn', W, G)

#     # WG: (N_deltas, N_z, N_state)
#     # G: (N_deltas, N_z, N_state)
#     # We sum over z: 'gzn,gzm->gnm'
#     GT_WG = torch.einsum('sgzn,sgzm->sgnm', G, WG)  
#     WG_v = torch.einsum('sgzn,sgz->sgn', WG, v)

#     c =-1.* torch.linalg.solve(GT_WG, WG_v)  # Solve for each batch
#     Gc = torch.einsum('sgzn,sgn->sgz', G, c)
#     # Step 2: Compute residual vector (G.c - v)
#     residual_vector = Gc + v  # Residual vector: [N_stat,N_deltas, N_lsq]

#     # Step 3: Compute weighted residuals 
#     W_residual =  torch.einsum('sy,sgy->sgy', W_diag,residual_vector )    # Weighted residual vector: [N_g, N_z]

#     # Step 4: Compute the final residual for each batch
#     residual = torch.einsum('sgz,sgz->sg', W_residual, residual_vector)  # Result: [N_stat,N_deltas]

#     return c,residual #[N_stat,N_deltas, N_state] and [N_stat,N_deltas]



# def least_sq_std_rew(d_values,zs, s_values, dSigma,N_lsq=20,n_states_rew=2):
#     #assert d_values.dim()==1,"for multiple deltas use least_sq_std_rew_z(..) instead"
#     x_values = zs.real
#     y_values = zs.imag

#     # 2) Construct G matrix using GPU
#     # d_values = deltas, s_values = spins
#     # According to definition: h=(Δ+s)/2, hb=(Δ-s)/2
#     # The kernel expects d_val=Δ, s_val=s
#     cs,rews = calculate_c_rew(d_values, s_values, x_values, y_values, dSigma,N_lsq=N_lsq) #[N_stat,N_deltas, N_state] and [N_stat,N_deltas]
#     c_mean=torch.mean(cs,0) #[N_deltas, N_state] 
#     c_std=torch.std(cs,0) #[N_deltas, N_state]
#     r_stat=c_std/c_mean #[N_deltas, N_state]
#     r=-torch.sum(torch.log(torch.abs(r_stat[:,0:n_states_rew])),dim=1) #[N_deltas, ]

#     return r

# def least_sq_std_rew_z(d_values,zs, s_values, dSigma,N_lsq=20,n_physical=1,n_states_rew=2):
#     x_values = zs.real
#     y_values = zs.imag

#     # 2) Construct G matrix using GPU
#     # d_values = deltas, s_values = spins
#     # According to definition: h=(Δ+s)/2, hb=(Δ-s)/2
#     # The kernel expects d_val=Δ, s_val=s
    
#     cs,rews = calculate_c_rew(d_values, s_values, x_values, y_values, dSigma,N_lsq=N_lsq) #[N_stat,N_deltas, N_state] and [N_stat,N_deltas]
#     c_mean=torch.mean(cs,0) #[N_deltas, N_state] 
#     c_std=torch.std(cs,0) #[N_deltas, N_state]
#     r_stat=c_std/c_mean #[N_deltas, N_state]
#     r=-torch.sum(torch.log(r_stat[:,0:n_states_rew]),dim=1) #[N_deltas, ]
        
#     return sum(r[0:n_physical])/sum(r[n_physical:])        

# def least_sq_rew_z(d_values,zs, s_values, dSigma,N_lsq=20,n_physical=1):
#     x_values = zs.real
#     y_values = zs.imag

#     # 2) Construct G matrix using GPU
#     # d_values = deltas, s_values = spins
#     # According to definition: h=(Δ+s)/2, hb=(Δ-s)/2
#     # The kernel expects d_val=Δ, s_val=s
    
#     cs,rews = calculate_c_rew(d_values, s_values, x_values, y_values, dSigma,N_lsq=N_lsq)
    
#     r = -torch.log(rews) 
        
#     return sum(r[0:n_physical])/sum(r[n_physical:])
