# Roadmap of DreamerV3 main code loop

Goal: understand how data flows through the code from beginning to end


- load the config file
    - likely a YAML file wrapped in an attridict, so entries can be accessed as if they were class attributes (e.g., config.seed instead of config["seed"])
- set seeds
    - if several libraries each have their own random number generator, seed every one of them
    - e.g., random -> random.seed, numpy -> np.random.seed, torch -> torch.manual_seed, etc.
- set up folders for evaluations, checkpoints, plots, metrics, etc.
- initialize the environment, likely using gymnasium, and get its properties (observation and action spaces)
    - check the formal definition of each space to make sure the data will be handled properly
    - e.g., check whether the actions are continuous or discrete (gym.spaces.Box vs gym.spaces.Discrete)
    - NOTE: this will need careful review when deciding what data to work with, to make sure it is handled properly and that the encoder/decoder supports it
    - the original DreamerV3 implementation (danijar) has extensive logic in its encoder/decoder for handling many kinds of inputs
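
The setup steps above (config, seeds, folders) could look roughly like this. It is a minimal sketch, not code from any Dreamer repository: the helper names (to_namespace, load_config, set_seed, make_run_folders), the folder names, and the config keys are hypothetical, and types.SimpleNamespace stands in for attridict so that everything except load_config (which needs PyYAML's yaml.safe_load) runs with the standard library.

```python
import os
import random
import types


def to_namespace(obj):
    """Recursively convert nested dicts/lists (e.g. the output of yaml.safe_load)
    into SimpleNamespace objects so entries can be read as attributes."""
    if isinstance(obj, dict):
        return types.SimpleNamespace(**{key: to_namespace(value) for key, value in obj.items()})
    if isinstance(obj, list):
        return [to_namespace(value) for value in obj]
    return obj


def load_config(path):
    """Load a YAML config file into nested namespaces (requires PyYAML)."""
    import yaml  # imported here so the other helpers work without PyYAML

    with open(path) as f:
        return to_namespace(yaml.safe_load(f))


def set_seed(seed):
    """Seed every random number generator that is installed and in use."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass  # NumPy not installed: nothing to seed
    try:
        import torch
        torch.manual_seed(seed)  # seeds the generators on all devices (CPU and CUDA)
    except ImportError:
        pass  # PyTorch not installed: nothing to seed


def make_run_folders(root, names=("checkpoints", "evaluations", "metrics", "plots")):
    """Create one sub-folder per output type under root and return their paths."""
    paths = {}
    for name in names:
        paths[name] = os.path.join(root, name)
        os.makedirs(paths[name], exist_ok=True)  # no error if it already exists
    return paths


# Hypothetical config contents, i.e. what yaml.safe_load might return for a small file
config = to_namespace({"seed": 0, "worldModel": {"batchSize": 16}})
set_seed(config.seed)
print(config.worldModel.batchSize)  # 16
```

The environment has its own randomness too: in gymnasium it is seeded through env.reset(seed=...), and env.action_space.seed(...) seeds sampling from the action space.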

```python
import gymnasium as gym

env = gym.make("CarRacing-v3")  # Box2D environment; needs gymnasium[box2d]

print(env.action_space)
# Box([-1.  0.  0.], 1.0, (3,), float32)
# low   = [-1.0, 0.0, 0.0]
# high  = [ 1.0, 1.0, 1.0]  (printed as the scalar 1.0 because all upper bounds are equal)
# shape = (3,) -> continuous actions [steering, gas, brake]
#   steering ∈ [-1, +1], gas ∈ [0, 1], brake ∈ [0, 1]

print(env.observation_space)
# Box(0, 255, (96, 96, 3), uint8)
# low = 0, high = 255, shape = (96, 96, 3) -> a 96x96 RGB image

# These spaces are the formal definitions (ranges) of the allowed values for
# every element in that space, e.g. for an observation obs:
# obs[i, j, c] ∈ [0, 255]  for every pixel (i, j) and colour channel c
# dtype(obs) = uint8
# obs.shape == (96, 96, 3)  (an observation cannot have any other shape)
```


- initialize the core Dreamer agent from the env properties and the config (this is where the architecture is decided)
    - NOTE: the core architecture is not gone through here
- initialize Dreamer's interaction with the environment for data collection
    - loop for as many episodes of data as wanted (an episode ends each time 'done' is returned), tracking the scores outside the loop
        - initialize the states (action, latent, and hidden) with zeros and a batch size of 1
        - reset the environment to get a new observation (.reset() returns (observation, info); if the env is wrapped in a gym.ObservationWrapper, .reset() applies its .observation() automatically)
        - encode the new observation to get the initial starting point
        - loop over environment steps, breaking when 'done' appears
            - call the GRU on the previous hidden state, latent state, and action to get the new hidden state
            - using the hidden state + encoded observation, get the stochastic latent from the posterior network
            - get the action from the actor using the hidden and latent states
            - take the next step with that action (.step() in the env; in gymnasium it returns (observation, reward, terminated, truncated, info), and done = terminated or truncated)
            - append to the buffer to collect the data
                - store the current observation, the action taken, the reward, the next observation, and whether the episode is 'done'
            - set the observation to the one produced by the action the actor chose (observation = next_observation)
            - keep track of metrics, e.g., the score (sum of rewards), steps taken, total episodes run, and total env steps taken
    - return the average score over the episodes run
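
The collection loop above can be sketched without any deep-learning framework. This is a minimal sketch under stated assumptions: ToyEnv and StubAgent are hypothetical stand-ins (a real run would use a gymnasium env and the Dreamer networks), the env follows gymnasium's reset()/step() return format, states are plain lists instead of tensors, and the buffer is a list of (obs, action, reward, next_obs, done) tuples.

```python
import random


class ToyEnv:
    """Hypothetical stand-in that follows gymnasium's reset()/step() return format."""

    def __init__(self, episode_length=5):
        self.episode_length = episode_length
        self.t = 0

    def reset(self, seed=None):
        self.t = 0
        return [0.0], {}  # (observation, info)

    def step(self, action):
        self.t += 1
        terminated = self.t >= self.episode_length
        return [float(self.t)], 1.0, terminated, False, {}  # (obs, reward, terminated, truncated, info)


class StubAgent:
    """Hypothetical stand-in for the encoder, GRU, posterior and actor networks."""

    def encode(self, obs):
        return list(obs)

    def recurrent(self, hidden, latent, action):  # h_t from h_{t-1}, z_{t-1}, a_{t-1}
        return [0.5 * h + 0.1 * (sum(latent) + sum(action)) for h in hidden]

    def posterior(self, hidden, embed):  # z_t sampled given h_t and the encoded obs
        return [h + embed[0] + random.gauss(0.0, 0.1) for h in hidden]

    def actor(self, hidden, latent):  # a_t chosen from h_t and z_t, kept in [-1, 1]
        return [max(-1.0, min(1.0, sum(latent) / len(latent)))]


def collect_episodes(env, agent, buffer, num_episodes, hidden_size=4, latent_size=4, action_size=1):
    scores = []
    for _ in range(num_episodes):
        # zero initial states; real code would add a batch dimension of 1,
        # e.g. torch.zeros(1, hidden_size)
        hidden = [0.0] * hidden_size
        latent = [0.0] * latent_size
        action = [0.0] * action_size
        obs, _ = env.reset()
        embed = agent.encode(obs)
        score, done = 0.0, False
        while not done:
            hidden = agent.recurrent(hidden, latent, action)
            latent = agent.posterior(hidden, embed)
            action = agent.actor(hidden, latent)
            next_obs, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            buffer.append((obs, action, reward, next_obs, done))
            obs = next_obs
            embed = agent.encode(obs)
            score += reward
        scores.append(score)
    return sum(scores) / len(scores)  # average score over the episodes


buffer = []
avg_score = collect_episodes(ToyEnv(), StubAgent(), buffer, num_episodes=3)
print(avg_score, len(buffer))  # 5.0 15
```

Passing the previous latent and action into the recurrent step, and encoding each new observation before the next iteration, mirrors the order of the bullets above; only the network internals are faked.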

### Now that the model is initialized and the first burst of data is collected, training begins

- determine how many total gradient steps and how many replays are wanted
    - i.e., the total number of gradient updates, and how many updates are run on the current data before more is added to the buffer
- initialize the outer loop over the total number of iterations (gradient steps // replays)
    - initialize the inner loop over the replays (how many times to sample from the current buffer before adding more)
        - sample a batch from the buffer with shape (batch size, sequence length), where the sequence length is the number of consecutive time steps (actions taken) per sample
        - run world model training, which returns the model states
            - encode the observations from the sampled data
            - loop over the time steps of the sequence
                - compute the current recurrent state from the previous action (t-1), the previous latent state, and the previous recurrent state
                - get the logits from the prior network (it takes only the recurrent state; no observation is needed as input)
                - get the logits and the stochastic latent from the posterior network (requires the current recurrent state and the current encoded observation)
                - append all of these results to their respective lists
                - set the newly calculated states as the previous ones (e.g., prevRecurrent = recurrent)
            - calculate the decoder (reconstruction) loss
                - decode the hidden states and latent states into predicted observations
                - treat each decoded observation as the mean of a normal distribution
                - take the log-probability of the true observations under this distribution; the loss is its negative
                    - if the true observation is close to the predicted mean, the loss is low
                - average the per-sample losses over the batch to get one scalar value representing the overall loss
            - calculate the reward loss (the reward is a scalar indicating how good or bad the action was)
                - feed the hidden + stochastic states through the reward network
                - get a normal distribution out of it, specifically its mean and std, and take the negative log-probability of the true reward
                - compute the mean of all loss calculations
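            - calculate the loss for the prior and posterior networks
                - requires computing the KL divergence between the posterior and prior distributions
                    - compute two KL terms, each with a stop-gradient on one side, so that gradients flow properly: the prior loss KL(sg(posterior) || prior) trains only the prior, making it follow the posterior; the posterior loss KL(posterior || sg(prior)) trains only the posterior
                    - clamp each term from below at 'freeNats' (max(KL, freeNats)), so KL values under the threshold give no gradient; this stops the KL from being pushed to 0, where the posterior would match the prior and ignore the observation (posterior collapse)
                - get the total KL loss (mean of the posterior loss + prior loss)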
            - using all losses (decoder, reward, and KL), get the world model loss (sum of all losses)
                - if continuation is on (i.e., if we want the continue predictor to learn when the episode terminates, i.e., 'done')
                    - compute the continuationLoss and add it to the world model loss
            - zero grad, backward pass, gradient clipping, optimizer step
            - report metrics (world model loss, decoder loss, reward loss, KL loss)
            - return the metrics and the hidden + stochastic states
        - get the behavior metrics from behavior training
            - NOTE: behavior training not gone through here yet
        - save checkpoints, metrics and/or evaluate if wanted
    - add more data to the buffer before running the next inner loop
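
The outer/inner loop structure and the buffer sampling step can be sketched as a small skeleton. Everything here is hypothetical: sample_sequences, train, and the callback arguments are made-up names, the world-model and behavior updates are passed in as plain functions, and sampling draws contiguous windows from a flat list of transitions (a real buffer may also need to handle windows that cross episode boundaries).

```python
import random


def sample_sequences(buffer, batch_size, sequence_length):
    """Sample batch_size windows of sequence_length consecutive transitions."""
    if len(buffer) < sequence_length:
        raise ValueError("not enough data in the buffer for one sequence")
    starts = [random.randint(0, len(buffer) - sequence_length) for _ in range(batch_size)]
    return [buffer[start:start + sequence_length] for start in starts]


def train(buffer, gradient_steps, replays, batch_size, sequence_length,
          train_world_model, train_behavior, collect_data, on_update=None):
    """Outer loop: gradient_steps // replays iterations.
    Inner loop: `replays` updates on the current buffer, then more data is collected."""
    step = 0
    num_iterations = gradient_steps // replays
    for _ in range(num_iterations):
        for _ in range(replays):
            batch = sample_sequences(buffer, batch_size, sequence_length)
            world_model_metrics, states = train_world_model(batch)
            behavior_metrics = train_behavior(states)
            step += 1
            if on_update is not None:  # e.g. save checkpoints/metrics or evaluate
                on_update(step, world_model_metrics, behavior_metrics)
        collect_data(buffer)  # add more data before the next inner loop


# Usage with hypothetical stand-ins that only count how often they are called
calls = {"world_model": 0, "behavior": 0, "collect": 0}


def fake_world_model(batch):
    calls["world_model"] += 1
    return {"worldModelLoss": 0.0}, batch  # (metrics, states)


def fake_behavior(states):
    calls["behavior"] += 1
    return {"actorLoss": 0.0}


def fake_collect(buf):
    calls["collect"] += 1
    buf.extend(range(len(buf), len(buf) + 10))  # pretend 10 new transitions arrived


data = list(range(50))
train(data, gradient_steps=12, replays=4, batch_size=2, sequence_length=8,
      train_world_model=fake_world_model, train_behavior=fake_behavior, collect_data=fake_collect)
print(calls)  # {'world_model': 12, 'behavior': 12, 'collect': 3}
```

With 12 gradient steps and 4 replays there are 3 outer iterations, so the buffer is topped up 3 times while 12 updates are run in total.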
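
The decoder and reward losses described in the world model training steps are negative log-likelihoods under a normal distribution. The sketch below only computes that value with the standard library; gaussian_nll and batch_nll are hypothetical helpers, and the unit std for the decoder is an assumption made for illustration. In a PyTorch implementation the same quantity would typically come from torch.distributions.Normal(mean, std).log_prob(target), negated and averaged, so that gradients flow.

```python
import math


def gaussian_nll(target, mean, std):
    """-log N(target | mean, std^2) for one scalar element."""
    return 0.5 * math.log(2.0 * math.pi * std ** 2) + (target - mean) ** 2 / (2.0 * std ** 2)


def batch_nll(targets, means, stds):
    """Sum the NLL over the elements of each sample, then average over the batch."""
    per_sample = [
        sum(gaussian_nll(t, m, s) for t, m in zip(target, mean))
        for target, mean, s in zip(targets, means, stds)
    ]
    return sum(per_sample) / len(per_sample)


# Decoder: each sample is a flattened observation; a unit std is assumed here.
true_obs = [[0.2, 0.5, 0.9], [0.1, 0.4, 0.3]]
close = batch_nll(true_obs, [[0.2, 0.5, 0.8], [0.1, 0.4, 0.3]], stds=[1.0, 1.0])
far = batch_nll(true_obs, [[0.9, -0.5, 0.0], [0.8, 0.9, -0.4]], stds=[1.0, 1.0])
print(close < far)  # True

# Reward: one scalar per sample, with a mean and std predicted by the reward network.
reward_loss = batch_nll([[1.0], [0.0]], [[0.8], [0.1]], stds=[0.5, 0.5])
```

Predictions near the true values give a lower loss, which is the "if true obs close to the predicted mean, then loss is low" behaviour from the notes.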
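
The two KL terms and the 'freeNats' clamp from the world model training steps can be made concrete for categorical latents given as logits. This sketch only computes the loss values with the standard library; the stop-gradient that decides which network each term trains cannot be expressed without autograd, so it is described in comments (in PyTorch it would be .detach()). The names kl_losses, free_nats, prior_weight, and posterior_weight are hypothetical, each sample has a single categorical variable for simplicity (a real model has several, whose KLs would be summed), and the default weights 0.5 and 0.1 with a 1-nat threshold are the values reported in the DreamerV3 paper, which may differ from the code these notes follow.

```python
import math


def softmax(logits):
    top = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - top) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def categorical_kl(p_logits, q_logits):
    """KL(p || q) for two categorical distributions given as logits."""
    p, q = softmax(p_logits), softmax(q_logits)
    return sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q) if pi > 0.0)


def kl_losses(posterior_logits, prior_logits, free_nats=1.0, prior_weight=0.5, posterior_weight=0.1):
    """Return (prior_loss, posterior_loss, kl_loss), each averaged over the batch."""
    prior_terms, posterior_terms = [], []
    for post, prior in zip(posterior_logits, prior_logits):
        kl = categorical_kl(post, prior)
        # With autograd this KL is computed twice, with the stop-gradient (sg) on opposite sides:
        #   prior loss     = KL(sg(posterior) || prior): only the prior is trained (it follows the posterior)
        #   posterior loss = KL(posterior || sg(prior)): only the posterior is trained
        # Both share the same value here. Clamping from below at free_nats means a KL
        # under the threshold produces no gradient.
        prior_terms.append(prior_weight * max(kl, free_nats))
        posterior_terms.append(posterior_weight * max(kl, free_nats))
    prior_loss = sum(prior_terms) / len(prior_terms)
    posterior_loss = sum(posterior_terms) / len(posterior_terms)
    return prior_loss, posterior_loss, prior_loss + posterior_loss


# Identical distributions: KL = 0, so both terms sit at the free-nats floor.
floor_losses = kl_losses([[0.0, 0.0]], [[0.0, 0.0]])
```

Because of the clamp, the loss stops rewarding the posterior for moving closer to the prior once their KL drops below 1 nat, which is what keeps the latent from collapsing onto the prior.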
