In [1]:
using ReinforcementLearning, Intervals, Flux

const RLBase = ReinforcementLearningBase

println("Done RL Import")

Base.@kwdef mutable struct MyGrid <: AbstractEnv
    N::Int
    obs::Vector{Float32}
    max_steps::Int    
    
    rewards::Float32
    done::Bool 
    t::Int
end

function MyGrid(; n)
    img = zeros(Float32, n,n)
    vect = ones(Float32, 3)
    obs = vcat(vect,vec(img))
    println(obs)
    
	MyGrid(n, obs, 10, 0, false, 0)
end

RLBase.action_space(env::MyGrid) = Base.OneTo(25)

RLBase.legal_action_space(env::MyGrid) = legal_action_space_mask(env)

function RLBase.legal_action_space_mask(env::MyGrid)
	mask = Vector(undef, 25)
	fill!(mask, true)

    mask_index = Base.rand((1:25))
	mask[mask_index] = false
	return mask
end



function (env::MyGrid)(action) #This function takes the action, makes the environment step, and gives reward
    
    env.obs[action[1]] = 6 #action[1] avoids some strange error I don't understand
    
    env.rewards+=1
    
    env.t+=1
    #Max steps check
    if env.t >= env.max_steps
        env.done = true
    end
end

RLBase.state(env::MyGrid, ::Observation{Vector{Float32}}) = env.obs
RLBase.state_space(env::MyGrid, ::Observation{Vector{Float32}}) = Space(fill( Interval{Int64}(0, (10*env.N) ) , (env.N^2+3)))

RLBase.is_terminated(env::MyGrid) = env.done

function RLBase.reset!(env::MyGrid)
    #Reset Counter and rewards (and termination variable)
    env.t = 0
    env.rewards = 0
    env.done = false    
    #Reset State
    img = zeros(env.N,env.N)
    vect = ones(3)
    env.obs = vcat(vect,vec(img))
end;

RLBase.reward(env::MyGrid) = env.rewards

RLBase.NumAgentStyle(::MyGrid) = SINGLE_AGENT
RLBase.DynamicStyle(::MyGrid) = SEQUENTIAL
RLBase.ActionStyle(::MyGrid) = FULL_ACTION_SET
RLBase.InformationStyle(::MyGrid) = PERFECT_INFORMATION

RLBase.StateStyle(::MyGrid) = Observation{Vector{Float32}}()

RLBase.RewardStyle(::MyGrid) = STEP_REWARD
RLBase.UtilityStyle(::MyGrid) = GENERAL_SUM
RLBase.ChanceStyle(::MyGrid) = DETERMINISTIC

println("Preparing Environment")
N=5
env = MyGrid(; n=N)

Done RL Import
Preparing Environment
Float32[1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]


┌ Info: Precompiling ReinforcementLearning [158674fc-8238-5cab-b5ba-03dfc80d1318]
└ @ Base loading.jl:1423


# MyGrid

## Traits

| Trait Type        |                          Value |
|:----------------- | ------------------------------:|
| NumAgentStyle     |                  SingleAgent() |
| DynamicStyle      |                   Sequential() |
| InformationStyle  |           PerfectInformation() |
| ChanceStyle       |                Deterministic() |
| RewardStyle       |                   StepReward() |
| UtilityStyle      |                   GeneralSum() |
| ActionStyle       |                FullActionSet() |
| StateStyle        | Observation{Vector{Float32}}() |
| DefaultStateStyle | Observation{Vector{Float32}}() |

## Is Environment Terminated?

No

## State Space

`Space{Vector{Interval{Int64, Closed, Closed}}}(Interval{Int64, Closed, Closed}[Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50), Interval{Int64, Closed, Closed}(0, 50)])`

## Action Space

`Base.OneTo(25)`

## Current State

```
Float32[1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```


In [2]:
UPDATE_FREQ = 32
N_ENV = 1
input_size = 28

agent = Agent(
        policy = QBasedPolicy(
            learner = A2CLearner(
                approximator = ActorCritic(
                    actor = Chain(
                        Dense(input_size, 4*input_size, relu),
                        Dense(4*input_size, 8*input_size, relu),
                        Dense(8*input_size, 6*input_size, relu),
                        Dense(6*input_size, 4*input_size, relu),
                        Dense(4*input_size, 2*input_size, relu),
                        Dense(2*input_size, 25),
                    ),

                    critic = Chain(
                        Dense(input_size, 4*input_size, relu),
                        Dense(4*input_size, 8*input_size, relu),
                        Dense(8*input_size, 6*input_size, relu),
                        Dense(6*input_size, 4*input_size, relu),
                        Dense(4*input_size, 2*input_size, relu),
                        Dense(2*input_size, 1),
                    ),

                    optimizer = ADAM(1e-3),
                ),
                γ = 0.99f0,
                actor_loss_weight = 1.0f0,
                critic_loss_weight = 0.5f0,
                entropy_loss_weight = 0.001f0,
                update_freq = UPDATE_FREQ,
            ),

            #explorer = BatchExplorer(GumbelSoftmaxExplorer()),
            explorer=EpsilonGreedyExplorer(
                kind=:exp,
                ϵ_stable=0.01,
                decay_steps=500,
            ),
        ),


	    trajectory = CircularArraySARTTrajectory(;
            capacity = UPDATE_FREQ,
            state = Matrix{Float32} => (input_size, N_ENV),
            action = Vector{Int} => (N_ENV,),
            reward = Vector{Float32} => (N_ENV,),
            terminal = Vector{Bool} => (N_ENV,),
        ),
)



hook = TotalBatchRewardPerEpisode(N_ENV)


TotalBatchRewardPerEpisode([Float64[]], [0.0], true)

In [3]:
run(agent, env, StopAfterEpisode(8), hook)

[32mProgress:  25%|██████████▎                              |  ETA: 0:00:26[39m

values: Float32[-0.0038848114, -0.032028243, 0.017632825, -0.10160909, 0.036231592, -0.00055729167, 0.036761574, -0.0018544598, -0.020733155, 0.031485848, 0.0010244212, -0.049003795, 0.05508644, -0.07202485, -0.020319697, 0.022446275, -0.020134708, 0.04001557, -0.021725155, 0.11125428, -0.022910202, -0.004275955, -0.013713145, -0.009959723, 0.031443264] and mask: Any[true, true, false, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true] 
values: Float32[0.14228946, -0.11814363, -0.0008185971, -0.2064806, 0.1050148, -0.044275798, 0.12893558, -0.090099104, -0.11713499, -0.067914784, -0.057469904, -0.017017415, 0.10286204, -0.08789596, 0.028185278, -0.094691284, 0.11457257, 0.20240201, -0.08254326, 0.23744449, 0.03208153, -0.035807293, -0.047383066, -0.0770152, -0.03638997] and mask: Any[true, true, true, true, true, true, false, true, true, true, true, true, true, true, true, true, true, true, true, true, tru

[32mProgress:  50%|████████████████████▌                    |  ETA: 0:01:02[39m

, -0.009959723, 0.031443264] and mask: Any[true, true, true, true, true, true, false, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true] 
values: Float32[-0.038415164, -0.044313733, -0.0020383564, -0.46849522, 0.15079787, -0.084743224, 0.35165018, 0.046136394, 0.031342324, 0.036070082, -0.0605635, -0.17892462, 0.19221704, -0.1866836, -0.11397059, -0.08862956, 0.17589338, 0.29528287, -0.13907135, 0.3665968, -0.013234453, 0.10649144, 0.0070157903, 0.1757529, -0.04392179] and mask: Any[true, true, true, true, false, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true] 
values: Float32[0.13072672, -0.16501029, 0.024486931, -0.56495017, 0.16458873, -0.15651307, 0.31735715, -0.06400016, 0.053914215, 0.10384582, 0.07348574, -0.20582813, 0.32547018, -0.20267855, 0.10385047, -0.20534365, 0.13624178, 0.19607197, -0.14817892, 0.48519084, 0.005649657, -0.050482407, -0.0218576

LoadError: MethodError: objects of type Vector{Float32} are not callable
Use square brackets [] for indexing an Array.



In [21]:
values = [-0.06736562, -0.14842498, -0.03438756, -0.2348835, 0.19270504, -0.06881056, 0.06740674, -0.13019112, 0.09400712, 0.10236878, -0.114360444, -0.24241893, 0.17329161, -0.14359772, -0.08873647, -0.14410512, 0.04276317, 0.0986805, -0.17130606, 0.33509204, -0.094900094, -0.05302921, -0.07780004, 0.046892297, 0.07059874]
mask = [true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, false, true, true, true, true, true]

vals_len = length(values)

mask_len = length(mask)


println("values length: ", vals_len, " and mask length: ", mask_len)

findmax(values, mask)[2]


values length: 25 and mask length: 25


5

In [None]:
#This is a similar bug test for my "production" version of the environment that I use in my research. Action space is much larger (64)
#Seems like it should be the same error!

values = [-0.018166097470102248, -0.05312329404601712, -0.11821760730890844, 0.01130529832199928, 0.25495336620659376, 0.03688280274304784, 0.24376921960025238, 0.09219057380660353, 0.13532862570075035, -0.006394795989184387, 0.06804521432895054, -0.10223188563086141, 0.09034899740849847, -0.014090588857831439, 0.10504844247607419, -0.1010042722516204, -0.06139492012262019, 0.04547940613530445, 0.14835828847573806, 0.17224788209978356, 0.022583950763336302, 0.09815816625439103, -0.1086762831646083, -0.04875057012926273, -0.004055656945551531, 0.0771577468229819, 0.28446659449818534, 0.05801666527334296, 0.012560477373359635, -0.020215410519104898, -0.17811680727261842, 0.0659528204220787, -0.1398748086613198, 0.0683758730845624, 0.1397017022419469, -0.09369770003150565, 0.005993454421915061, 0.0015033004316932636, -0.06735809777500537, 0.03484625372778246, -0.15758421574125123, 0.07659801068910996, -0.12459759486470504, -0.0419593543471892, -0.00940457689693222, -0.2681972368143283, 0.14750588128489334, 0.030406348586110007, -0.03335642103175048, 0.10261564895656809, -0.0047844669243396094, 0.0925932127099294, 0.03755944363783806, -0.07972708307755658, -0.09741959496801679, -0.03881579249499037, -0.07526886357099384, -0.1716122010795189, -0.01426435240576198, -0.11978894015708387, 0.009270593322248836, -0.12199733236932064, 0.11713120363075379, -0.03448852026738303]

mask = [true, false, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, false, true, true, true, false, true, true, true, false, true, true, true, true, true, true, true, false, true, true, false, true, false, true, true, true, false, true, true, true]

vals_len = length(values)

mask_len = length(mask)


println("values length: ", vals_len, " and mask length: ", mask_len)

findmax(values, mask)[2]


values length: 64 and mask length: 64


27

Old Code for complex state version of environment (Ignore!)

In [35]:
UPDATE_FREQ = 32
N_ENV = 1

agent = Agent(
        policy = PPOPolicy(
            approximator = ActorCritic(
                    actor = Chain(
                                x -> ( reshape(x[4:28], (5,5,1,1)), x[1:3]), #Parallel expects a tuple input - very difficult to pass a tuple as state
                                Parallel(vcat, 
                                        Chain(Conv((3,3), 1 => 1, relu; stride = 1, pad = 0),
                                        Flux.flatten),

                                        Chain(Dense(3, 6, relu), Dense(6, 3, relu), Dense(3, 1, relu)) ),
                                Dense(10, 40, relu),
                                Dense(40,80,relu),
                                Dense(80,50,relu),
                                Dense(50,25)),

                    critic = Chain(
                                x -> ( reshape(x[4:28], (5,5,1,1)), x[1:3]),
                                Parallel(vcat, 
                                        Chain(Conv((3,3), 1 => 1, relu; stride = 1, pad = 0),
                                        Flux.flatten),

                                        Chain(Dense(3, 6, relu), Dense(6, 3, relu), Dense(3, 1, relu)) ),
                                Dense(10, 40, relu),
                                Dense(40,80,relu),
                                Dense(80,50,relu),
                                Dense(50,25,relu),
                                Dense(25,5,relu),
                                Dense(5,1)),
            
                optimizer = ADAM(1e-3),
            ),
            γ = 0.99f0,
            λ = 0.95f0,
            clip_range = 0.1f0,
            max_grad_norm = 0.5f0,
            n_epochs = 4,
            n_microbatches = 4,
            actor_loss_weight = 1.0f0,
            critic_loss_weight = 0.5f0,
            entropy_loss_weight = 0.001f0,
            update_freq = UPDATE_FREQ,
        ),
        trajectory = PPOTrajectory(;
            capacity = UPDATE_FREQ, 
            state = Matrix{Float32} => (28, N_ENV),
            action = Vector{Int} => (N_ENV,),
            action_log_prob = Vector{Float32} => (N_ENV,),
            reward = Vector{Float32} => (N_ENV,),  
            terminal = Vector{Bool} => (N_ENV,),
        ),
    )


#stop_condition = StopAfterStep(30000, is_show_progress=true)
hook = TotalBatchRewardPerEpisode(N_ENV)

TotalBatchRewardPerEpisode([Float64[]], [0.0], true)

In [36]:
run(agent, env, StopAfterEpisode(8), hook)

[32mProgress:  25%|██████████▎                              |  ETA: 0:00:03[39m

LoadError: BoundsError: attempt to access 1-element view(::Matrix{Float32}, 1, :) with eltype Float32 at index [33]