# DRL-Based Cache Replacement Policy

In [None]:
from cache.Cache import Cache
from agents.DQNAgent import DQNAgent
from cache.DataLoader import DataLoaderZipf

## Cache Model Parameters

In [None]:
# Simulation data auto-generated by the Zipf data loader.
# NOTE(review): the meaning of the three positional arguments is defined by
# DataLoaderZipf and is not visible here -- confirm against its signature.
dataloader = DataLoaderZipf(
    5000,
    10000,
    1.3,
    num_progs=100,
)

# Cache size used for the experiment; typical values: 5, 10, 50, ...
cache_size = 50

# Various combinations of features, each one extending the previous tuple.
base_features = ('Base',)
base_UT_features = base_features + ('UT',)
base_UT_CT_features = base_UT_features + ('CT',)

# Parameters of our reward function.
our_reward = {
    'name': 'our',
    'alpha': 0.5,
    'psi': 10,
    'mu': 1,
    'beta': 0.3,
}

# Parameters of the reward function of Zhong et al.
zhong_reward = {
    'name': 'zhong',
    'short_reward': 1.0,
    'long_span': 100,
    'beta': 0.5,
}

## Create Cache Environment

In [None]:
# Build the cache environment from the data loader and the chosen cache size,
# selecting the feature set and reward parameters to use; skipping is disabled.
env = Cache(
    dataloader,
    cache_size,
    feature_selection=base_features,
    reward_params=our_reward,
    allow_skip=False,
)

## Setup DRL Agent

In [None]:
# Create the DQN agent, sized from the environment's action and feature counts.
# NOTE(review): the exact semantics of each keyword are defined by DQNAgent;
# the grouping comments below are inferred from the argument names only.
RL = DQNAgent(
    env.n_actions,
    env.n_features,
    learning_rate=0.01,
    reward_decay=0.9,
    # Epsilon greedy (every setting is given as a pair of values)
    e_greedy_min=(0.0, 0.1),
    e_greedy_max=(0.2, 0.8),
    e_greedy_init=(0.1, 0.5),
    e_greedy_increment=(0.005, 0.01),
    e_greedy_decrement=(0.005, 0.001),
    # Exploration control
    history_size=50,
    dynamic_e_greedy_iter=25,
    reward_threshold=3,
    explore_mentor='LRU',
    # Training schedule and sample storage
    replace_target_iter=100,
    memory_size=10000,
    batch_size=128,
    # Output options
    output_graph=False,
    verbose=0,
)

## Learning Procedure

In [None]:
# Train the agent for 100 episodes. `step` counts stored transitions across
# all episodes; it is never reset between episodes.
step = 0
for episode in range(100):
    # Every episode begins from a freshly reset environment.
    observation = env.reset()

    while True:
        # The agent picks an action for the current observation, and the
        # environment answers with the next observation and a reward.
        action = RL.choose_action(observation)
        observation_, reward = env.step(action)

        # Leave the inner loop once the environment reports the episode is
        # over; the transition produced by this last step is not stored.
        if env.hasDone():
            break

        RL.store_transition(observation, action, reward, observation_)

        # Run a learning update whenever the step counter exceeds 20 and is
        # a multiple of 5.
        if step > 20 and step % 5 == 0:
            RL.learn()

        # The next observation becomes the current one.
        observation = observation_

        # Print a progress report whenever the step counter is a multiple
        # of 100.
        if step % 100 == 0:
            mr = env.miss_rate()
            print(
                "Episode=%d, Step=%d: Accesses=%d, Misses=%d, MissRate=%f"
                % (episode, step, env.total_count, env.miss_count, mr)
            )

        step += 1