In [15]:
import random

import numpy as np
import pandas as pd

import torch
import torch.nn as nn

import matplotlib.pyplot as plt

from pyrl.environments import DiceRolling, BetDiceRolling
from pyrl.agents import Agent
from pyrl.exp import Experiment

from pyrl.tools import State, Action, DiscreteState, DiscreteAction
from pyrl.tabular import Q
from pyrl.tabular.algorithms import q_learning


from typing import List

In [2]:
def random_policy(state: DiscreteState, available_actions: List[DiscreteAction]) -> DiscreteAction:
    return random.choice(available_actions)

### Dice rolling

In [9]:
environment = DiceRolling()

exp = Experiment(environment)

In [10]:
agent = Agent(random_policy)

history = exp.explore(agent, 10000)

100%|██████████| 10000/10000 [00:00<00:00, 252106.99it/s]


In [11]:
S, A = environment.spaces()

q = Q(S, A)

optimal_q, optimal_policy = q_learning(q, history)

In [12]:
agent = Agent(optimal_policy)

history = exp.explore(agent, 10000)

100%|██████████| 10000/10000 [00:00<00:00, 51062.68it/s]


In [13]:
rewards = [transition[2] for transition in history if transition[-1].terminal]
np.mean(rewards)

4.248

In [28]:
ix = list(q.values.keys())
ix = pd.MultiIndex.from_tuples(ix)

values = list(q.values.values())

pd.DataFrame(index=ix, data=values)

Unnamed: 0,Unnamed: 1,0
S0,"Action ""Бросить""",2.908324
S0,"Action ""Не бросать""",0.684311
S1,"Action ""Бросить""",2.760661
S1,"Action ""Не бросать""",1.402558
S2,"Action ""Бросить""",2.667749
S2,"Action ""Не бросать""",1.92626
S3,"Action ""Бросить""",2.493393
S3,"Action ""Не бросать""",2.28832
S4,"Action ""Бросить""",2.784586
S4,"Action ""Не бросать""",2.967899
