In [34]:
using ReinforcementLearning, Intervals, Flux

const RLBase = ReinforcementLearningBase

println("Done RL Import")

# Mutable state of a toy grid environment exposed through the RLBase interface.
Base.@kwdef mutable struct MyGrid <: AbstractEnv
    N::Int                  # grid side length; the grid holds N*N cells
    obs::Vector{Float32}    # flat observation: 3 leading entries followed by the flattened grid
    max_steps::Int    # number of steps after which the episode is flagged as done
    
    rewards::Float32        # running reward total; incremented on every step
    done::Bool # termination flag reported by `RLBase.is_terminated`
    t::Int                  # steps taken since the last reset
end

"""
    MyGrid(; n, max_steps = 10)

Build a fresh `MyGrid` whose grid has side length `n`.

The initial observation is a `Vector{Float32}` of length `n^2 + 3`: three ones
followed by the column-major flattening of an all-zero `n`×`n` grid. The reward
total and the step counter start at zero and the episode is not terminated.

# Keywords
- `n`: grid side length.
- `max_steps`: number of steps after which the environment reports termination
  (defaults to 10, the value that used to be hard-coded here).
"""
function MyGrid(; n, max_steps = 10)
    grid = zeros(Float32, n, n)
    header = ones(Float32, 3)
    obs = vcat(header, vec(grid))
    # Logged at debug level so constructing an environment does not spam stdout.
    @debug "Initial MyGrid observation" obs

    return MyGrid(n, obs, max_steps, 0.0f0, false, 0)
end

# The action set is the fixed index range 1:25; it does not depend on the
# state of `env`.
function RLBase.action_space(env::MyGrid)
    return Base.OneTo(25)
end

# Legal actions for `env`, i.e. the members of `action_space(env)` whose entry in
# `legal_action_space_mask(env)` is `true`.
#
# This must return action values rather than the Boolean mask itself: callers
# that sample from the legal action space would otherwise draw `true`/`false`
# instead of an action id.
function RLBase.legal_action_space(env::MyGrid)
    mask = legal_action_space_mask(env)
    # Filtering via `zip` works for any mask element type that holds Bools.
    return [a for (a, ok) in zip(action_space(env), mask) if ok]
end

# Boolean mask over `action_space(env)`: every action is legal except one that is
# drawn uniformly at random on each call.
#
# The mask is built as a concrete `Vector{Bool}` (instead of an untyped
# `Vector{Any}`) so it can be used directly for logical indexing, e.g.
# `values[mask]`.
function RLBase.legal_action_space_mask(env::MyGrid)
    n_actions = length(action_space(env))
    mask = fill(true, n_actions)

    # Forbid a single randomly chosen action.
    mask[Base.rand(1:n_actions)] = false
    return mask
end



# Step function: applies `action` to the observation, adds 1 to the reward total,
# advances the step counter and flags termination once `max_steps` is reached.
function (env::MyGrid)(action)
    # `action[1]` is the action itself for a scalar and the first element for an
    # indexable container.
    # NOTE(review): the index addresses the whole observation vector, including
    # its three leading entries — confirm no offset is intended.
    target = action[1]
    env.obs[target] = 6

    env.rewards += 1
    env.t += 1

    # Max steps check
    if env.t >= env.max_steps
        env.done = true
    end
end

# Current observation; returns the stored vector itself, not a copy.
function RLBase.state(env::MyGrid, ::Observation{Vector{Float32}})
    return env.obs
end
# Observation space: one closed integer interval [0, 10N] per observation entry
# (N^2 grid cells plus the 3 leading entries).
function RLBase.state_space(env::MyGrid, ::Observation{Vector{Float32}})
    n_entries = env.N^2 + 3
    bounds = Interval{Int64}(0, 10 * env.N)
    return Space(fill(bounds, n_entries))
end

# The episode is over once the step function has set the `done` flag.
function RLBase.is_terminated(env::MyGrid)
    return env.done
end

# Restore `env` to its initial condition: zero step counter and reward total,
# cleared termination flag, and a fresh observation made of three ones followed
# by an all-zero N×N grid.
#
# The observation is built directly as `Float32`, matching the field type and the
# constructor, instead of allocating `Float64` arrays and relying on implicit
# conversion. Returns the new observation.
function RLBase.reset!(env::MyGrid)
    # Reset counter, rewards and termination flag
    env.t = 0
    env.rewards = 0.0f0
    env.done = false

    # Reset state (a new vector is allocated; the previous one is left untouched)
    grid = zeros(Float32, env.N, env.N)
    header = ones(Float32, 3)
    env.obs = vcat(header, vec(grid))
    return env.obs
end;

# Reports the stored reward total of `env`.
function RLBase.reward(env::MyGrid)
    return env.rewards
end

# ---- Trait declarations for MyGrid ----

# What the agent observes and which actions it is offered
RLBase.StateStyle(::MyGrid) = Observation{Vector{Float32}}()
RLBase.ActionStyle(::MyGrid) = FULL_ACTION_SET
RLBase.InformationStyle(::MyGrid) = PERFECT_INFORMATION

# Agent count and turn structure
RLBase.NumAgentStyle(::MyGrid) = SINGLE_AGENT
RLBase.DynamicStyle(::MyGrid) = SEQUENTIAL

# Reward, utility and randomness characteristics
RLBase.RewardStyle(::MyGrid) = STEP_REWARD
RLBase.UtilityStyle(::MyGrid) = GENERAL_SUM
RLBase.ChanceStyle(::MyGrid) = DETERMINISTIC

# Instantiate the environment on a 5×5 grid.
println("Preparing Environment")
N = 5
env = MyGrid(n = N)

Done RL Import
Preparing Environment
Float32[1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]


# MyGrid

## Traits

| Trait Type        |                          Value |
|:----------------- | ------------------------------:|
| NumAgentStyle     |                  SingleAgent() |
| DynamicStyle      |                   Sequential() |
| InformationStyle  |           PerfectInformation() |
| ChanceStyle       |                Deterministic() |
| RewardStyle       |                   StepReward() |
| UtilityStyle      |                   GeneralSum() |
| ActionStyle       |                FullActionSet() |
| StateStyle        | Observation{Vector{Float32}}() |
| DefaultStateStyle | Observation{Vector{Float32}}() |

## Is Environment Terminated?

No

## State Space

`Space{Vector{Interval{Int64, Closed, Closed}}}(Interval{Int64, Closed, Closed}[Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50)])`

## Action Space

`Base.OneTo(25)`

## Current State

```
Float32[1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```


In [37]:
# A2C agent set-up for the MyGrid environment.
# NOTE(review): the `run` call that follows this cell was recorded failing with
# "MethodError: objects of type Vector{Float32} are not callable"; the cause is
# not evident from this cell alone — investigate before relying on this set-up.
UPDATE_FREQ = 32   # learner update frequency; also used as the trajectory capacity
N_ENV = 1          # trailing (batch) dimension of every trajectory buffer
input_size = 28    # observation length; equals N^2 + 3 for the N = 5 grid

agent = Agent(
        policy = QBasedPolicy(
            learner = A2CLearner(
                approximator = ActorCritic(
                    # Actor: MLP from the observation to 25 outputs, matching the
                    # 25 entries of the environment's action space.
                    actor = Chain(
                        Dense(input_size, 4*input_size, relu),
                        Dense(4*input_size, 8*input_size, relu),
                        Dense(8*input_size, 6*input_size, relu),
                        Dense(6*input_size, 4*input_size, relu),
                        Dense(4*input_size, 2*input_size, relu),
                        Dense(2*input_size, 25),
                    ),

                    # Critic: same hidden layout as the actor, single scalar output.
                    critic = Chain(
                        Dense(input_size, 4*input_size, relu),
                        Dense(4*input_size, 8*input_size, relu),
                        Dense(8*input_size, 6*input_size, relu),
                        Dense(6*input_size, 4*input_size, relu),
                        Dense(4*input_size, 2*input_size, relu),
                        Dense(2*input_size, 1),
                    ),

                    optimizer = ADAM(1e-3),
                ),
                γ = 0.99f0,
                actor_loss_weight = 1.0f0,
                critic_loss_weight = 0.5f0,
                entropy_loss_weight = 0.001f0,
                update_freq = UPDATE_FREQ,
            ),

            # Alternative explorer, currently disabled:
            #explorer = BatchExplorer(GumbelSoftmaxExplorer()),
            explorer=EpsilonGreedyExplorer(
                kind=:exp,
                ϵ_stable=0.01,
                decay_steps=500,
            ),
        ),


	    # Buffer shapes: one column / entry per environment (N_ENV).
	    trajectory = CircularArraySARTTrajectory(;
            capacity = UPDATE_FREQ,
            state = Matrix{Float32} => (input_size, N_ENV),
            action = Vector{Int} => (N_ENV,),
            reward = Vector{Float32} => (N_ENV,),
            terminal = Vector{Bool} => (N_ENV,),
        ),
)



# Hook passed to `run` below; constructed for N_ENV environments.
hook = TotalBatchRewardPerEpisode(N_ENV)


TotalBatchRewardPerEpisode([Float64[]], [0.0], true)

In [38]:
# Run the A2C agent on `env` with an 8-episode stop condition and the hook
# defined above. NOTE(review): the recorded output below shows this call failing.
run(agent, env, StopAfterEpisode(8), hook)

LoadError: MethodError: objects of type Vector{Float32} are not callable
Use square brackets [] for indexing an Array.

Old code for the complex-state version of the environment (ignore!)

In [35]:
# Older PPO agent set-up (kept for reference). It rebinds the globals
# UPDATE_FREQ, N_ENV, agent and hook.
# NOTE(review): the `run` call that follows this cell was recorded failing with
# a BoundsError; the cause is not evident from this cell alone.
UPDATE_FREQ = 32
N_ENV = 1

agent = Agent(
        policy = PPOPolicy(
            approximator = ActorCritic(
                    # Actor: split the 28-entry observation into a 5×5×1×1 image
                    # (entries 4:28) and a 3-entry vector (entries 1:3), process
                    # the two in parallel branches and concatenate the results.
                    actor = Chain(
                                x -> ( reshape(x[4:28], (5,5,1,1)), x[1:3]), #Parallel expects a tuple input - very difficult to pass a tuple as state
                                Parallel(vcat, 
                                        Chain(Conv((3,3), 1 => 1, relu; stride = 1, pad = 0),
                                        Flux.flatten),

                                        Chain(Dense(3, 6, relu), Dense(6, 3, relu), Dense(3, 1, relu)) ),
                                # presumably 9 flattened conv features + 1 dense
                                # feature = 10 inputs — TODO confirm
                                Dense(10, 40, relu),
                                Dense(40,80,relu),
                                Dense(80,50,relu),
                                Dense(50,25)),

                    # Critic: same two-branch front end, reduced to a single output.
                    critic = Chain(
                                x -> ( reshape(x[4:28], (5,5,1,1)), x[1:3]),
                                Parallel(vcat, 
                                        Chain(Conv((3,3), 1 => 1, relu; stride = 1, pad = 0),
                                        Flux.flatten),

                                        Chain(Dense(3, 6, relu), Dense(6, 3, relu), Dense(3, 1, relu)) ),
                                Dense(10, 40, relu),
                                Dense(40,80,relu),
                                Dense(80,50,relu),
                                Dense(50,25,relu),
                                Dense(25,5,relu),
                                Dense(5,1)),
            
                optimizer = ADAM(1e-3),
            ),
            γ = 0.99f0,
            λ = 0.95f0,
            clip_range = 0.1f0,
            max_grad_norm = 0.5f0,
            n_epochs = 4,
            n_microbatches = 4,
            actor_loss_weight = 1.0f0,
            critic_loss_weight = 0.5f0,
            entropy_loss_weight = 0.001f0,
            update_freq = UPDATE_FREQ,
        ),
        # Buffer shapes: one column / entry per environment (N_ENV); the state
        # width 28 is the hard-coded observation length.
        trajectory = PPOTrajectory(;
            capacity = UPDATE_FREQ, 
            state = Matrix{Float32} => (28, N_ENV),
            action = Vector{Int} => (N_ENV,),
            action_log_prob = Vector{Float32} => (N_ENV,),
            reward = Vector{Float32} => (N_ENV,),  
            terminal = Vector{Bool} => (N_ENV,),
        ),
    )


#stop_condition = StopAfterStep(30000, is_show_progress=true)
# Hook passed to `run` below; constructed for N_ENV environments.
hook = TotalBatchRewardPerEpisode(N_ENV)

TotalBatchRewardPerEpisode([Float64[]], [0.0], true)

In [36]:
# Run the PPO agent on `env` with an 8-episode stop condition and the hook
# defined above. NOTE(review): the recorded output below shows this call failing.
run(agent, env, StopAfterEpisode(8), hook)

[32mProgress:  25%|██████████▎                              |  ETA: 0:00:03[39m

LoadError: BoundsError: attempt to access 1-element view(::Matrix{Float32}, 1, :) with eltype Float32 at index [33]