In [52]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

class PolicyNetwork(nn.Module):
    """Three-layer fully connected network producing action probabilities.

    The input is a float tensor whose last dimension has ``state_dim``
    entries (the callers in this file pass a ``[state, time]`` pair). The
    output has the same leading dimensions and ``action_dim`` entries in the
    last dimension, which sum to 1.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, 64)  # State and time as input
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, action_dim)

    def forward(self, state_time):
        """Return action probabilities for one input vector or a batch.

        Args:
            state_time: float tensor of shape ``(state_dim,)`` or
                ``(batch, state_dim)``.

        Returns:
            Tensor of shape ``(action_dim,)`` or ``(batch, action_dim)``.
        """
        x = torch.relu(self.fc1(state_time))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        # Normalise over the action axis (the last one). For a 1-D input this
        # is the same as dim=0; for a batch it keeps every row a distribution
        # instead of normalising across samples.
        return torch.softmax(x, dim=-1)

def get_action(state_time, policy_net):
    """Sample an action index from the policy's output distribution.

    Args:
        state_time: sequence convertible to a float array (the caller in this
            file passes ``[state, time]``).
        policy_net: callable mapping a float32 tensor to a tensor of action
            probabilities.

    Returns:
        An integer action index drawn with ``np.random.choice`` according to
        the probabilities returned by ``policy_net``.
    """
    with torch.no_grad():
        inputs = torch.tensor(np.array(state_time), dtype=torch.float32)
        action_probs = policy_net(inputs).detach().numpy().ravel()
    # The number of actions is taken from the network output itself rather
    # than from a module-level global, so the function works for any policy
    # and cannot disagree with the network's output size.
    return np.random.choice(action_probs.size, p=action_probs)

def get_state(curr_state, drift, days, action):
    """
    Return the state that follows ``curr_state``.

    If ``action`` is 1 the result is ``2 * days * drift + 1``, i.e. one past
    the largest value the random walk can reach (presumably a marker for
    "episode finished" -- confirm against the caller). Otherwise a uniformly
    drawn integer step in ``[-drift, drift]`` is added to ``curr_state`` and
    the result is clamped to ``[0, 2 * days * drift]``.
    """
    upper_bound = 2 * days * drift
    if action == 1:
        return upper_bound + 1

    # randint's upper limit is exclusive, hence drift + 1.
    step = np.random.randint(-drift, drift + 1, dtype=int)
    candidate = curr_state + step
    if candidate > upper_bound:
        candidate = upper_bound
    return candidate if candidate > 0 else 0

def get_reward(curr_state, action, days, drift, start_price, strike_price):
    """
    Return the immediate reward for taking ``action`` in ``curr_state``.

    Action 0 yields no reward. Any other action yields
    ``curr_state - days * drift + start_price - strike_price``; the result
    may be negative. (The ``curr_state - days * drift + start_price`` part
    looks like the state index mapped back to a price -- confirm.)
    """
    if action != 0:
        implied_price = curr_state - days * drift + start_price
        return implied_price - strike_price
    return 0

class TRPOAgent:
    """Agent that owns a ``PolicyNetwork`` and the pieces of a policy update.

    NOTE(review): ``update_policy`` currently only evaluates and prints the
    action probabilities for the collected samples; no optimisation step is
    taken yet (see the TODO inside it).
    """

    def __init__(self, state_dim, action_dim, kl_constraint=0.01, max_kl_step=0.01, gamma=0.99, delta=0.01):
        """Store the hyper-parameters and build the network and optimiser.

        Args:
            state_dim: size of the network input vector.
            action_dim: number of discrete actions.
            kl_constraint: KL threshold intended for the policy update.
            max_kl_step: stored but not used by the code in this class.
            gamma: discount factor used by ``compute_advantage``.
            delta: stored but not used by the code in this class.
        """
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.kl_constraint = kl_constraint
        self.max_kl_step = max_kl_step
        self.gamma = gamma
        self.delta = delta

        self.policy_net = PolicyNetwork(state_dim, action_dim)
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=0.001)

    def _action_probs(self, state_times):
        """Evaluate the policy on each sample; return one row per sample.

        The network is fed one input vector at a time, which is how
        ``update_policy`` has always evaluated it. Entries may be tensors or
        plain sequences such as ``(state, time)`` tuples.
        """
        rows = []
        for state_time in state_times:
            if isinstance(state_time, torch.Tensor):
                inputs = state_time.to(torch.float32)
            else:
                inputs = torch.tensor(np.array(state_time), dtype=torch.float32)
            rows.append(self.policy_net(inputs))
        return torch.stack(rows)

    def compute_advantage(self, rewards):
        """Return the discounted reward-to-go for every step.

        Element ``t`` of the result is ``sum_k gamma**k * rewards[t + k]``.
        No baseline is subtracted, so these are discounted returns used in
        place of advantages.

        Args:
            rewards: sequence of per-step rewards in time order.

        Returns:
            A list with the same length as ``rewards``.
        """
        returns = []
        running = 0.0
        for reward in reversed(rewards):
            running = self.gamma * running + reward
            returns.append(running)
        # Built back-to-front with O(1) appends, then flipped once; this
        # avoids the quadratic cost of inserting at the head of the list.
        returns.reverse()
        return returns

    def surrogate_loss(self, old_probs, state_times, actions, advantages):
        """Return the mean importance-weighted advantage.

        Args:
            old_probs: tensor with one row of action probabilities per sample,
                produced by the policy that collected the data.
            state_times: per-sample network inputs (tensors or sequences).
            actions: action index taken in each sample; a 1-D or ``(N, 1)``
                tensor or a plain sequence of ints.
            advantages: per-sample advantage; 1-D or ``(N, 1)``.

        Returns:
            Scalar tensor ``mean(pi_new(a|s) / pi_old(a|s) * advantage)``.
        """
        new_probs = self._action_probs(state_times)
        # gather needs an index with the same rank as the probabilities, so
        # the actions are forced into a column.
        action_index = torch.as_tensor(actions, dtype=torch.int64).reshape(-1, 1)
        old_taken = old_probs.gather(1, action_index).squeeze(1)
        new_taken = new_probs.gather(1, action_index).squeeze(1)
        # Flattened so the product below is element-wise; multiplying an
        # (N, 1) ratio by an (N,) advantage would broadcast to (N, N).
        weights = torch.as_tensor(advantages, dtype=torch.float32).reshape(-1)

        ratio = new_taken / (old_taken + 1e-10)
        return torch.mean(ratio * weights)

    def update_policy(self, samples):
        """Evaluate the current policy on one episode of samples and print it.

        Args:
            samples: list of ``(state_time, action, reward, next_state_time)``
                tuples. An empty list is a no-op.

        Returns:
            None.
        """
        if not samples:
            return None
        state_times, actions, rewards, _next_state_times = zip(*samples)
        # Computed for the optimisation step described in the TODO below.
        advantages = self.compute_advantage(np.array(rewards))
        # The reference probabilities are constants for the update, so they
        # are evaluated without building an autograd graph.
        with torch.no_grad():
            old_probs = self._action_probs(state_times)
        print(old_probs)
        # TODO: the optimisation step is not enabled yet. The intended loop
        # (at most 10 iterations) is:
        #   1. policy_loss = -self.surrogate_loss(old_probs, state_times,
        #                                         actions, advantages)
        #   2. self.optimizer.zero_grad(); policy_loss.backward();
        #      self.optimizer.step()
        #   3. new_probs = self._action_probs(state_times)
        #   4. kl_div = mean(new_probs * log((new_probs + 1e-10)
        #                                    / (old_probs + 1e-10)))
        #   5. stop as soon as kl_div < self.kl_constraint
        # and finally return kl_div.item().
        return None

# Environment settings: horizon, per-step drift bound and the two prices
# that enter the reward.
days = 5
drift = 5
start_price = 500
strike_price = 510

# Agent / training settings. num_episodes and max_steps_per_episode are
# defined here but not read by the loop below.
state_dim = 2
action_dim = 2
num_episodes = 1000
max_steps_per_episode = days

agent = TRPOAgent(state_dim, action_dim)

start_states = []
for _ in range(1):  # Number of episodes
    # Random start: state in [0, 2 * days * drift], time in [0, days).
    state = np.random.randint(0, 2 * days * drift + 1)
    time = np.random.randint(0, days)
    start_states.append((state, time))

    samples = []
    done = False
    while not done:
        action = get_action([state, time], agent.policy_net)
        next_state = get_state(state, drift, days, action)
        reward = get_reward(state, action, days, drift, start_price, strike_price)
        next_time = time + 1

        samples.append(((state, time), action, reward, (next_state, next_time)))

        # The episode ends after action 1 is taken, or after the step taken
        # at time == days has been recorded.
        done = bool(time == days or action == 1)
        state, time = next_state, next_time
    agent.update_policy(samples)
    # print(f"Episode: {_}, KL Divergence: {kl_div}")


tensor([[0.6714, 0.3286],
        [0.7154, 0.2846]], grad_fn=<StackBackward0>)
