# Testing Parallel Batch Learning

Версия вторая: параллельные игры с уменьшающимся списком
- вместо одной среды теперь будет список сред (игр) длины `play_batch_size`, с каждой из них собрать список состояний и подать на вход `agent.act` как батч; обновить список после завершения всех игр
- плюс: баланс опыта и обучения
- плюс: быстро (из-за `agent.act`)
- минус: `agent.act` получает батчи разного размера если завершилась только часть игр
- плюс: не нужно переписывать `Trainer.train`

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
! cp -a -n /content/drive/MyDrive/wordle-rl/. /content/

In [3]:
! mkdir /content/drive/MyDrive/wordle-rl/tests

In [4]:
! pip install cpprb

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting cpprb
  Downloading cpprb-10.7.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.5/2.5 MB[0m [31m34.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: cpprb
Successfully installed cpprb-10.7.1


In [1]:
%load_ext autoreload
%autoreload 2

from functools import partial
from collections import defaultdict
import pickle

from wordle.wordlenp import Wordle
from environment.environment import Environment, StateYesNo, StateVocabulary
from environment.action import ActionVocabulary, ActionLetters, ActionCombLetters
from dqn.agent import Agent
from dqn.train import Trainer
from replay_buffer.cpprb import PrioritizedReplayBuffer, ReplayBuffer

import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style='whitegrid')
import torch
import numpy as np
np.random.seed(0)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## Датасеты

In [2]:
word_list = Wordle._load_vocabulary('wordle/guesses.txt', astype=np.array)

def make_data(n_answers, n_guesses):
    guesses = np.random.choice(word_list, size=n_guesses, replace=False)
    answers = np.random.choice(guesses, size=n_answers, replace=False)
    return answers, guesses

### 10 answers, 100 guesses

In [3]:
answers_10_100, guesses_10_100 = make_data(10, 100)
print(answers_10_100)

['abmho' 'cites' 'aware' 'blays' 'acker' 'rawin' 'anile' 'eorls' 'feers'
 'sadza']


### 100 answers, 100 guesses

In [4]:
answers_100_100, guesses_100_100 = make_data(100, 100)
print(answers_100_100)

['lilac' 'orles' 'twirl' 'unled' 'sings' 'grind' 'sheaf' 'benny' 'slews'
 'karst' 'rimus' 'lossy' 'joker' 'leash' 'scopa' 'viols' 'giron' 'raiks'
 'lummy' 'renig' 'tinds' 'infos' 'logon' 'drill' 'gudes' 'ammon' 'bhoot'
 'hurry' 'noils' 'coven' 'beryl' 'margs' 'sorbo' 'momes' 'scald' 'potch'
 'flows' 'torus' 'prill' 'scuts' 'brith' 'tamin' 'sewar' 'joram' 'aldol'
 'hazel' 'texes' 'sibbs' 'truth' 'spoil' 'hames' 'actin' 'maces' 'rayas'
 'thuya' 'sugan' 'felly' 'newsy' 'bolos' 'mimeo' 'chems' 'dicty' 'liefs'
 'scuff' 'burps' 'abyes' 'zones' 'cuspy' 'kerve' 'haith' 'amino' 'zygal'
 'kokum' 'zambo' 'icier' 'piers' 'sambo' 'laden' 'barge' 'solei' 'mauts'
 'groat' 'pearl' 'curse' 'jujus' 'troop' 'bilge' 'sibyl' 'gassy' 'elain'
 'daube' 'feyly' 'duals' 'hoper' 'hains' 'beige' 'poove' 'miffy' 'lesbo'
 'dawds']


### 50 answers, 200 guesses

In [5]:
answers_50_200, guesses_50_200 = make_data(50, 200)
print(answers_50_200)

['toran' 'burka' 'umpie' 'brock' 'civic' 'beige' 'hoiks' 'biffo' 'nagas'
 'sheol' 'malls' 'matzo' 'peeve' 'deshi' 'mooli' 'scaud' 'ameba' 'wadds'
 'bayts' 'glees' 'kaput' 'bitos' 'comae' 'dosed' 'rabis' 'neats' 'tutti'
 'stays' 'smoky' 'chase' 'resaw' 'simas' 'sowne' 'rorid' 'rebec' 'deawy'
 'hinny' 'sores' 'cerge' 'yogas' 'fouet' 'wheel' 'sowfs' 'talus' 'yabas'
 'topee' 'sabin' 'unbox' 'dyers' 'qophs']


### 300 answers, 300 guesses

In [6]:
answers_300_300, guesses_300_300 = make_data(300, 300)
print(answers_300_300)

['hoied' 'aarti' 'dyers' 'ingot' 'hasty' 'tices' 'tache' 'deoxy' 'sutta'
 'score' 'fiefs' 'sling' 'ready' 'wests' 'dobes' 'tells' 'bitts' 'roper'
 'veena' 'hewer' 'chats' 'jarta' 'ceric' 'olent' 'feare' 'bodge' 'fleer'
 'prats' 'spiny' 'tryps' 'welts' 'march' 'jelly' 'furor' 'okapi' 'penni'
 'curve' 'altar' 'aboon' 'salut' 'amido' 'razor' 'vouch' 'prill' 'ganev'
 'dukka' 'noxal' 'briar' 'wills' 'trigo' 'dusts' 'meter' 'veale' 'hafiz'
 'dosha' 'wilis' 'ohmic' 'silds' 'giust' 'blimy' 'zerda' 'mucus' 'abrin'
 'nandu' 'larva' 'cruds' 'kaies' 'hussy' 'dolce' 'birch' 'madam' 'chase'
 'onely' 'molys' 'scape' 'sauce' 'amate' 'mohur' 'vagal' 'royne' 'spics'
 'azuki' 'acres' 'shops' 'sicks' 'sunup' 'cosey' 'louse' 'aweel' 'skosh'
 'twoer' 'loves' 'wents' 'reest' 'winna' 'rosed' 'mbira' 'rangy' 'omega'
 'moira' 'typey' 'romal' 'bachs' 'floss' 'scath' 'roast' 'moola' 'moles'
 'witch' 'rabbi' 'chest' 'aulos' 'yokes' 'aspen' 'sepic' 'lirot' 'lemon'
 'musts' 'drouk' 'kudzu' 'yacka' 'sonny' 'hived' 'l

### 100 answers, 2000 guesses

In [7]:
answers_100_2000, guesses_100_2000 = make_data(100, 2000)
print(answers_100_2000)

['fumed' 'rival' 'zizit' 'wests' 'scrog' 'gryce' 'seils' 'sugar' 'glims'
 'pharm' 'pyets' 'kilts' 'bings' 'emyde' 'duads' 'shahs' 'spaza' 'spore'
 'redub' 'sidas' 'choco' 'woofs' 'sukuk' 'spumy' 'mirex' 'sluse' 'tween'
 'nodus' 'wands' 'unled' 'rates' 'toxin' 'lotes' 'wacko' 'ariot' 'baisa'
 'sways' 'roads' 'poked' 'popsy' 'gonif' 'vutty' 'bicep' 'parky' 'braid'
 'ports' 'spyal' 'match' 'spook' 'scowp' 'sdein' 'lovie' 'torii' 'souks'
 'vibey' 'genny' 'clues' 'decaf' 'diced' 'delay' 'mirly' 'flogs' 'lotos'
 'whine' 'seems' 'jerry' 'scram' 'gosse' 'roped' 'pipis' 'spank' 'seder'
 'doorn' 'evict' 'buteo' 'ponks' 'miffy' 'potin' 'rathe' 'papaw' 'local'
 'tolus' 'apode' 'jouks' 'decad' 'temes' 'wafts' 'liter' 'kagos' 'piste'
 'ogmic' 'fyles' 'brace' 'adage' 'hepar' 'bales' 'molal' 'eject' 'seles'
 'commy']


## Plotting Utility

In [8]:
def plot_results(tasks_results, figname):
    n_tasks = len(tasks_results)
    _, ax = plt.subplots(1, n_tasks, figsize=(4*n_tasks, 5))
    
    color = ['b','g','r','c','m','y','k','w']

    # over tasks
    for i, (task_name, results) in enumerate(tasks_results.items()):
        # over methods
        for c, (method_name, res) in enumerate(results.items()):
            train_timers, train_win_rates, test_timers, test_win_rates = res
            ax[i].plot(train_timers, train_win_rates, label=method_name+' (train)', c=color[c], alpha=0.2)
            ax[i].plot(test_timers, test_win_rates, label=method_name+' (test)', c=color[c])
        ax[i].set_xlabel('time, s')
        ax[i].set_ylabel('win rate, s')
        ax[i].legend()
        ax[i].set_title(task_name)
    plt.savefig(figname + '.svg', bbox_inches='tight')
    plt.show()

## Sampling

In [9]:
soft_rewards = {'B':1, 'Y':2, 'G':3, 'win':20, 'lose':-10, 'step':-4}
hard_rewards = {'B':0, 'Y':0, 'G':0, 'win':10, 'lose':-10, 'step':-2}
step_rewards = {'B':0, 'Y':1, 'G':1, 'win':10, 'lose':-10, 'step':-5}
char_rewards = {'B':0, 'Y':1, 'G':1, 'win':10, 'lose':-10, 'step':-2}

In [10]:
tasks_results = defaultdict(dict)

### Uniform

In [11]:
def experiment(answers, guesses, n_batches, n_batches_warm, play_batch_size=8):
    env_list = []
    for _ in range(play_batch_size):
        env = Environment(
            rewards=step_rewards,
            wordle=Wordle(vocabulary=guesses, answers=answers),
            state_instance=StateYesNo()
        )
        env_list.append(env)

    agent = Agent(
        state_size=env_list[0].state.size,
        action_instance=ActionCombLetters(k=1, vocabulary=guesses),
        # here is what we are experimenting with
        replay_buffer=ReplayBuffer(state_size=env_list[0].state.size),
    )

    trainer = Trainer(
        env_list, agent,
        n_batches=n_batches,
        n_batches_warm=n_batches_warm
    )
    
    res = trainer.train(eps_decay=0.9995, nickname=f'uniform-{len(answers)}-{len(guesses)}')
    return res

In [12]:
res = experiment(
    answers_10_100, guesses_10_100,
    n_batches=500, n_batches_warm=15,
)

tasks_results['10/100']['uniform'] = res


Batch   62	Time: 35 s	Agent Eps: 0.97	Train Win Rate: 6.45%	Test Win Rate: 40.00%	Test Mean Steps: 2.25

Batch  124	Time: 68 s	Agent Eps: 0.94	Train Win Rate: 27.42%	Test Win Rate: 80.00%	Test Mean Steps: 2.25

Batch  186	Time: 100 s	Agent Eps: 0.91	Train Win Rate: 35.48%	Test Win Rate: 90.00%	Test Mean Steps: 2.33

Batch  248	Time: 128 s	Agent Eps: 0.88	Train Win Rate: 48.39%	Test Win Rate: 90.00%	Test Mean Steps: 2.00

Batch  310	Time: 159 s	Agent Eps: 0.86	Train Win Rate: 54.84%	Test Win Rate: 100.00%	Test Mean Steps: 2.00

Batch  372	Time: 186 s	Agent Eps: 0.83	Train Win Rate: 53.23%	Test Win Rate: 100.00%	Test Mean Steps: 2.20

Batch  434	Time: 222 s	Agent Eps: 0.80	Train Win Rate: 58.06%	Test Win Rate: 100.00%	Test Mean Steps: 2.00

Batch  496	Time: 247 s	Agent Eps: 0.78	Train Win Rate: 82.26%	Test Win Rate: 100.00%	Test Mean Steps: 2.00

Saving checkpoint... Saved to uniform-10-100-1.pth
