In [2]:
import random
from collections import namedtuple, deque

import gymnasium as gym
import numpy as np
import simglucose

In [None]:
# One replay-buffer record. The h_0/c_0 fields look like recurrent hidden/cell
# states (TODO confirm); the training loop in this notebook only stores 0 in them.
Transition = namedtuple(
    'Transition',
    [
        'state',
        'action',
        'next_state',
        'reward',
        'h_0',
        'c_0',
        'h_0_next',
        'c_0_next',
    ],
)


class ReplayMemory:
    """Bounded store of `Transition` records with uniform random sampling.

    Once `capacity` records are held, appending a new one drops the oldest.
    """

    def __init__(self, capacity):
        # A deque with maxlen handles the eviction of old records by itself.
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        record = Transition(*args)
        self.memory.append(record)

    def sample(self, batch_size):
        """Return `batch_size` distinct stored records chosen at random."""
        return random.sample(self.memory, k=batch_size)

    def __len__(self):
        """Number of records currently stored."""
        stored = self.memory
        return len(stored)

In [None]:

from gymnasium.wrappers import FlattenObservation


def paper_reward_function(BG_last_hour):
    """Map the most recent glucose reading to a banded reward.

    Only the last element of `BG_last_hour` is used:

        70 <= G <= 180   ->  0.5
        180 < G <= 200   -> -0.9
        200 < G <= 250   -> -1.2
        250 < G <= 350   -> -1.5
        30 < G < 70      -> -1.8
        anything else    -> -2
    """
    glucose = BG_last_hour[-1]

    if 70 <= glucose <= 180:
        return 0.5
    if 30 < glucose < 70:
        return -1.8

    if glucose > 180:
        # Ascending upper bounds: the first band that contains the reading wins.
        for ceiling, penalty in ((200, -0.9), (250, -1.2), (350, -1.5)):
            if glucose <= ceiling:
                return penalty

    # Readings of 30 or below, above 350, or not comparable (e.g. NaN).
    return -2

# Register the simulator under the id that gym.make("simglucose-basal") looks up.
# NOTE(review): the id says "basal" while the entry-point class is named
# T1DSimEnvBolus -- confirm this is the intended environment.
gym.envs.register(
    id="simglucose-basal",
    entry_point="simglucose.envs:T1DSimEnvBolus",
    kwargs=dict(
        patient_name=["adult#001"],
        reward_fun=paper_reward_function,
        history_length=1,
        enable_meal=True,
    ),
)




In [None]:

def create_env():
    """Build the registered "simglucose-basal" environment wrapped in FlattenObservation."""
    return FlattenObservation(gym.make("simglucose-basal"))



env = create_env()


print(env.action_space)
print(env.observation_space)

# Smoke test: drive the environment with 100 random actions.
env.reset()
for _ in range(100):
    _, _, terminated, truncated, _ = env.step(env.action_space.sample())  # take a random action
    # Once an episode has ended the environment must be reset before the next
    # step; stepping a finished episode is undefined behaviour in Gymnasium.
    if terminated or truncated:
        env.reset()


### Train the DQN agent on the simglucose environment

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import namedtuple
from gymnasium.wrappers import FlattenObservation

In [3]:
# Record layout for one replay-memory entry. Only state, action, next_state and
# reward are read by DDQNAgent.update_policy; the remaining four fields are
# filled with 0 by the training loop in this notebook.
Transition = namedtuple(
    'Transition',
    [
        'state',
        'action',
        'next_state',
        'reward',
        'h_0',
        'c_0',
        'h_0_next',
        'c_0_next',
    ],
)

class DQNNetwork(nn.Module):
    """Fully connected Q-network: input_size -> 128 -> 64 -> output_size, ReLU between layers."""

    def __init__(self, input_size, output_size):
        super().__init__()
        # Layers are built in forward order so parameter initialisation order is unchanged.
        self.fc1 = nn.Linear(in_features=input_size, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=64)
        self.fc3 = nn.Linear(in_features=64, out_features=output_size)

    def forward(self, x):
        """Return the raw (unactivated) output of the last layer for input `x`."""
        hidden = self.fc1(x).relu()
        hidden = self.fc2(hidden).relu()
        return self.fc3(hidden)

In [59]:
class DDQNAgent:
    """Double-DQN agent with epsilon-greedy acting, a replay list and a target network.

    Args:
        state_size: Length of the flat observation vector fed to the Q-network.
        action_size: Number of discrete actions (width of the Q-network output).
        replay_memory_capacity: Maximum number of transitions kept for replay.
        batch_size: Number of transitions sampled for each gradient step.
        gamma: Discount factor applied to the bootstrapped next-state value.
    """

    def __init__(self, state_size, action_size, replay_memory_capacity=10000, batch_size=32, gamma=0.99):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Kept on the agent so exploration does not depend on a global `env`.
        self.state_size = state_size
        self.action_size = action_size

        # Online (policy) network and its periodically synchronised target copy.
        self.policy_net = DQNNetwork(state_size, action_size).to(self.device)
        self.target_net = DQNNetwork(state_size, action_size).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()

        # Only the policy network is optimised; the target network is copied from it.
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=1e-3)

        # Replay memory
        self.replay_memory = []
        self.replay_memory_capacity = replay_memory_capacity
        self.batch_size = batch_size
        self.gamma = gamma

        # Counter for target network update
        self.target_update_counter = 0
        self.target_update_frequency = 100  # Update target network every N steps

    def select_action(self, state, epsilon):
        """Pick an action epsilon-greedily.

        Args:
            state: Flat observation (NumPy array or array-like of numbers).
            epsilon: Probability of choosing a uniformly random action.

        Returns:
            The chosen action index as a Python int.
        """
        if np.random.rand() < epsilon:
            # Explore: uniform over this agent's own action count.
            return int(np.random.randint(0, self.action_size))

        # Exploit: greedy action under the policy network.
        with torch.no_grad():
            state_tensor = torch.as_tensor(
                np.asarray(state), dtype=torch.float32, device=self.device
            )
            q_values = self.policy_net(state_tensor)
            return q_values.argmax().item()

    def store_transition(self, transition):
        """Add a transition; once the memory is full, overwrite a random slot."""
        if len(self.replay_memory) < self.replay_memory_capacity:
            self.replay_memory.append(transition)
        else:
            # Eviction is random rather than oldest-first.
            self.replay_memory[random.randrange(self.replay_memory_capacity)] = transition

    def update_policy(self):
        """Run one optimisation step on a random minibatch from the replay memory.

        Does nothing until at least `batch_size` transitions are stored. Uses the
        Double-DQN target: the policy network chooses the next action and the
        target network supplies that action's value.
        """
        if len(self.replay_memory) < self.batch_size:
            return

        transitions = random.sample(self.replay_memory, self.batch_size)
        batch = Transition(*zip(*transitions))

        # Stack through NumPy first: building a tensor directly from a tuple of
        # arrays converts element by element and is very slow.
        state_batch = torch.as_tensor(
            np.asarray(batch.state), dtype=torch.float32, device=self.device
        )
        action_batch = torch.as_tensor(
            np.asarray(batch.action), dtype=torch.int64, device=self.device
        )
        reward_batch = torch.as_tensor(
            np.asarray(batch.reward), dtype=torch.float32, device=self.device
        )
        next_state_batch = torch.as_tensor(
            np.asarray(batch.next_state), dtype=torch.float32, device=self.device
        )

        # Q(s, a) for the actions that were actually taken.
        q_values = self.policy_net(state_batch).gather(1, action_batch.unsqueeze(1))

        # Double-DQN target: select with the policy net, evaluate with the target net.
        with torch.no_grad():
            best_next_actions = self.policy_net(next_state_batch).argmax(dim=1, keepdim=True)
            next_q_values = self.target_net(next_state_batch).gather(1, best_next_actions).squeeze(1)

        # NOTE(review): Transition has no `done` field, so terminal transitions are
        # bootstrapped like any other; masking them would need a schema change.
        expected_q_values = reward_batch + self.gamma * next_q_values

        # Compute Huber loss
        loss = nn.SmoothL1Loss()(q_values, expected_q_values.unsqueeze(1))

        # Optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Update target network periodically
        self.target_update_counter += 1
        if self.target_update_counter % self.target_update_frequency == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())

In [60]:
def paper_reward_function(BG_last_hour):
    """Map the most recent glucose reading to a banded reward.

    Only the last element of `BG_last_hour` is used:

        70 <= G <= 180   ->  0.5
        180 < G <= 200   -> -0.9
        200 < G <= 250   -> -1.2
        250 < G <= 350   -> -1.5
        30 < G < 70      -> -1.8
        anything else    -> -2
    """
    glucose = BG_last_hour[-1]

    if 70 <= glucose <= 180:
        return 0.5
    if 30 < glucose < 70:
        return -1.8

    if glucose > 180:
        # Ascending upper bounds: the first band that contains the reading wins.
        for ceiling, penalty in ((200, -0.9), (250, -1.2), (350, -1.5)):
            if glucose <= ceiling:
                return penalty

    # Readings of 30 or below, above 350, or not comparable (e.g. NaN).
    return -2

# Register the simulator under the id that create_env() passes to gym.make.
# Re-running this registration for an id that already exists produces the
# "Overriding environment ... already in registry" warning shown in this cell's output.
gym.envs.register(
    id="simglucose-basal",
    entry_point="simglucose.envs:T1DSimEnvBolus",
    kwargs=dict(
        patient_name=["adult#001"],
        reward_fun=paper_reward_function,
        history_length=1,
        enable_meal=True,
    ),
)

def create_env():
    """Build the registered "simglucose-basal" environment wrapped in FlattenObservation."""
    return FlattenObservation(gym.make("simglucose-basal"))


# Module-level environment used by the agent-setup and training cells below.
env = create_env()

  logger.warn(f"Overriding environment {new_spec.id} already in registry.")


In [61]:
# DDQN agent setup
# Network input width: length of the flattened observation vector.
state_size = env.observation_space.shape[0]
# Network output width: number of actions (`.n` assumes a Discrete action space).
action_size = env.action_space.n
print('state_size', state_size)
print('action_size', action_size)
agent = DDQNAgent(state_size=state_size, action_size=action_size)

state_size 3
action_size 7


In [63]:
# Training loop
num_episodes = 1000
epsilon = 0.1  # Exploration rate

for episode in range(num_episodes):
    # Gymnasium's reset() returns (observation, info); only the observation is
    # the agent's state, the info dict is not a network input.
    state, _ = env.reset()
    total_reward = 0

    while True:
        # Select action using epsilon-greedy strategy
        action = agent.select_action(state, epsilon)

        # Take the selected action in the environment
        next_state, reward, terminated, truncated, _ = env.step(action)
        # An episode ends on either flag; ignoring `truncated` would keep
        # stepping an environment that has already finished.
        done = terminated or truncated

        # Store the transition in the replay memory (the four trailing fields
        # are unused placeholders here)
        agent.store_transition(Transition(state, action, next_state, reward, 0, 0, 0, 0))

        # Update the policy network
        agent.update_policy()

        # Update the current state
        state = next_state
        total_reward += reward

        if done:
            break

    print(f"Episode: {episode + 1}, Total Reward: {total_reward}")

# After training, you can use the policy network to act in the environment
state, _ = env.reset()
while True:
    action = agent.select_action(state, epsilon=0.1)
    # step() returns five values: obs, reward, terminated, truncated, info
    next_state, _, terminated, truncated, _ = env.step(action)
    state = next_state
    if terminated or truncated:
        break


Episode: 1, Total Reward: -125.99999999999982
Episode: 2, Total Reward: 52.20000000000003
Episode: 3, Total Reward: -35.99999999999997
Episode: 4, Total Reward: -14.299999999999947
Episode: 5, Total Reward: 4.7000000000001325
Episode: 6, Total Reward: -5.199999999999901
Episode: 7, Total Reward: -60.19999999999983
Episode: 8, Total Reward: -10.399999999999949
Episode: 9, Total Reward: -42.699999999999925
Episode: 10, Total Reward: 96.89999999999964
Episode: 11, Total Reward: -82.29999999999981
Episode: 12, Total Reward: -72.69999999999989
Episode: 13, Total Reward: -64.09999999999984
Episode: 14, Total Reward: -31.799999999999933
Episode: 15, Total Reward: -16.09999999999984
Episode: 16, Total Reward: -21.599999999999937
Episode: 17, Total Reward: -40.79999999999986
Episode: 18, Total Reward: -245.80000000000067
Episode: 19, Total Reward: -63.59999999999984
Episode: 20, Total Reward: 170.19999999999982
Episode: 21, Total Reward: 63.39999999999981
Episode: 22, Total Reward: -12.09999999