# Rock Paper Scissors - LSTM

Work In Progress - Looking for feedback on how to improve this design to actually be competitive

---

This implements an LSTM for Rock Paper Scissors in PyTorch.

Input is encoded multiple ways:
```
x = torch.cat([
    noise.flatten(),    # tensor (3,)  random noise for probabilistic moves
    step.flatten(),     # tensor (7,)  round(log(step)) one-hot encoded - to predict warmup periods
    stats.flatten(),    # tensor (2,3) with move frequency percentages
    history.flatten(),  # tensor (2,window=10,3) one-hot encoded timeseries history
])
```

Then fed through the following network:
```
RpsLSTM(
  (lstm): LSTM(76, 128, num_layers=3, batch_first=True, dropout=0.25)
  (dense_1): Linear(in_features=204, out_features=128, bias=True)
  (dense_2): Linear(in_features=128, out_features=128, bias=True)
  (out_probs): Linear(in_features=128, out_features=3, bias=True)
  (out_hash): Linear(in_features=128, out_features=128, bias=True)
  (activation): Softsign()
  (softmax): Softmax(dim=2)
)
```

There are two loss functions:
0. loss_probs() - softmax probability vs EV score of opponent move prediction
1. loss_hash() - Categorical Cross Entropy loss for agent identity prediction using Locality-sensitive hashing


This model can successfully defeat the simplest of agents such as:
- Rock 
- Paper 
- Scissors 
- Sequential

However, it struggles to get beyond a draw against more complex agents:
- anti_rotn
- multi_stage_decision_tree
- iocaine_powder 
- greenberg

# NNBase Class

This is a generic base class to handle saving/loading the model from file, plus other utility functions.

In [None]:
# %%writefile NNBase.py
# Source: https://www.kaggle.com/jamesmcguigan/rock-paper-scissors-lstm/
# Source: https://github.com/JamesMcGuigan/ai-games/blob/master/games/rock-paper-scissors/neural_network/NNBase.py

from __future__ import annotations

import os
import re
from abc import ABCMeta
from typing import TypeVar

import humanize
import torch
import torch.nn as nn



# noinspection PyTypeChecker
T = TypeVar('T', bound='NNBase')
class NNBase(nn.Module, metaclass=ABCMeta):
    """
    Base class for PyTorch models
    Handles: save/autoload, freeze/unfreeze, weight initialization and debugging
    """
    def __init__(self):
        super().__init__()
        self.loaded  = False  # can't call self.load() in constructor, as weights/layers have not been defined yet
        self.device  = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


    def __call__(self, *args, **kwargs) -> torch.Tensor:
        if not self.loaded: self.load()  # autoload on first function call
        return super().__call__(*args, **kwargs)



    ### Initialization

    def weights_init(self, layer):
        ### Default initialization seems to work best, at least for Z shaped ReLU1 - see GameOfLifeHardcodedReLU1_21.py
        if isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d)):
            ### kaiming_normal_ corrects for mean and std of the relu function
            ### xavier_normal_ works better for ReLU6 and Z shaped activations
            if isinstance(self.activation, (nn.ReLU, nn.LeakyReLU, nn.PReLU)):
                nn.init.kaiming_normal_(layer.weight)
                # nn.init.xavier_normal_(layer.weight)
                if layer.bias is not None:
                    # small positive bias so that all nodes are initialized
                    nn.init.constant_(layer.bias, 0.1)
        else:
            # Use default initialization
            pass

    ### Freeze / Unfreeze

    def freeze(self: T) -> T:
        if not self.loaded: self.load()
        for name, parameter in self.named_parameters():
            parameter.requires_grad = False
        return self

    def unfreeze(self: T) -> T:
        if not self.loaded: self.load()
        for name, parameter in self.named_parameters():
            parameter.requires_grad = True
        return self



    ### Load / Save Functionality

    @property
    def filename(self) -> str:
        if os.environ.get('KAGGLE_KERNEL_RUN_TYPE'):
            return f'./{self.__class__.__name__}.pth'
        else:
            filename = os.path.join( os.path.dirname(__file__), 'models', f'{self.__class__.__name__}.pth' )
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            return filename

    # DOCS: https://pytorch.org/tutorials/beginner/saving_loading_models.html
    def save(self: T, verbose=True) -> T:
        os.makedirs(os.path.dirname(self.filename), exist_ok=True)
        torch.save(self.state_dict(), self.filename)
        if verbose: print(f'{self.__class__.__name__}.savefile(): {self.filename} = {humanize.naturalsize(os.path.getsize(self.filename))}')
        return self


    def load(self: T, load_weights=True, verbose=True) -> T:
        if load_weights and os.path.exists(self.filename):
            try:
                self.load_state_dict(torch.load(self.filename, map_location=self.device))  # map_location: load GPU-saved weights on CPU too
                if verbose: print(f'{self.__class__.__name__}.load(): {self.filename} = {humanize.naturalsize(os.path.getsize(self.filename))}')
            except Exception as exception:
                # Ignore errors caused by model size mismatch
                if verbose: print(f'{self.__class__.__name__}.load(): model has changed dimensions, reinitializing weights\n')
                self.apply(self.weights_init)
        else:
            if verbose:
                if load_weights: print(f'{self.__class__.__name__}.load(): model file not found, reinitializing weights\n')
                # else:          print(f'{self.__class__.__name__}.load(): reinitializing weights\n')
            self.apply(self.weights_init)

        self.loaded = True    # prevent any infinite if self.loaded loops
        self.to(self.device)  # ensure all weights, either loaded or untrained are moved to GPU
        # self.eval()           # default to production mode - disable dropout
        # self.freeze()         # default to production mode - disable training
        return self



    ### Debugging

    def print_params(self):
        print(self.__class__.__name__)
        print(self)
        for name, parameter in sorted(self.named_parameters(), key=lambda pair: pair[0].split('.')[0] ):
            print(name)
            print(re.sub(r'\n( *\n)+', '\n', str(parameter.data.cpu().numpy())))  # remove extraneous newlines
            print()

In [None]:
# %run NNBase.py

# LSTM Agent

This is the main neural network model class, implemented in PyTorch.

Input is encoded multiple ways:
```
x = torch.cat([
    noise.flatten(),    # tensor (3,)  random noise for probabilistic moves
    step.flatten(),     # tensor (7,)  round(log(step)) one-hot encoded - to predict warmup periods
    stats.flatten(),    # tensor (2,3) with move frequency percentages
    history.flatten(),  # tensor (2,window=10,3) one-hot encoded timeseries history
])
```
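The encoding above can be sketched as a standalone function to show where the LSTM's `input_size` of 3 + 7 + 6 + 60 = 76 comes from. This is a simplified, hypothetical re-implementation of `RpsLSTM.cast_inputs()`: the name `encode_inputs` and its list-based arguments (moves as 0/1/2, most recent first) are assumptions for illustration, not the notebook's API.

```python
import math

import torch


def encode_inputs(actions, opponents, window=10):
    """Hypothetical standalone version of RpsLSTM.cast_inputs(); moves are 0/1/2, latest first."""
    noise = torch.rand(3)                                # (3,)  random noise

    step     = len(actions)
    log_step = round(math.log(step + 1))                 # +1 avoids log(0)
    size     = round(math.log(1000))                     # 7 buckets
    step_hot = torch.zeros(size)
    step_hot[log_step % size] = 1.0                      # (7,)  one-hot log(step)

    stats = torch.zeros(2, 3)                            # (2,3) move counts per player
    for player, moves in enumerate([actions, opponents]):
        for move in moves:
            stats[player, move] += 1.0
    if step > 0:
        stats = stats / step                             # normalize to frequencies

    history = torch.zeros(2, window, 3)                  # (2,window,3) one-hot history
    for player, moves in enumerate([actions, opponents]):
        for t, move in enumerate(moves[:window]):
            history[player, t, move] = 1.0

    x = torch.cat([noise, step_hot, stats.flatten(), history.flatten()])
    return x.reshape(1, 1, -1)                           # (batch, seq_len, input_size)
```

For example, `encode_inputs([1, 0], [2, 2]).shape` is `torch.Size([1, 1, 76])`, matching `LSTM(76, 128, ...)` in the printout.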

Then fed through the following network:
```
RpsLSTM(
  (lstm): LSTM(76, 128, num_layers=3, batch_first=True, dropout=0.25)
  (dense_1): Linear(in_features=204, out_features=128, bias=True)
  (dense_2): Linear(in_features=128, out_features=128, bias=True)
  (out_probs): Linear(in_features=128, out_features=3, bias=True)
  (out_hash): Linear(in_features=128, out_features=128, bias=True)
  (activation): Softsign()
  (softmax): Softmax(dim=2)
)
```

It defines two custom loss functions: 
- `loss_probs()` is for correctly predicting the next opponent move based on 1/0.5/0 EV scores
- `loss_hash()` is for predicting who our current opponent is
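As a sanity check on the 1/0.5/0 EV scores, the sketch below re-derives them from the game rules, assuming (as `cast_action()` does) that the model plays `(prediction + 1) % 3` against its predicted opponent move. The helper names `outcome` and `ev_scores` are hypothetical; `loss_probs` mirrors the `probs * (1 - ev)` form of the method in `RpsLSTM`.

```python
import torch


def outcome(action: int, opponent: int) -> float:
    """1.0 = win, 0.5 = draw, 0.0 = loss (0 = rock, 1 = paper, 2 = scissors)."""
    if action == opponent:
        return 0.5
    return 1.0 if (action - opponent) % 3 == 1 else 0.0


def ev_scores(opponent: int) -> torch.Tensor:
    """EV of predicting each opponent move, given we then play (prediction + 1) % 3."""
    return torch.tensor([outcome((prediction + 1) % 3, opponent) for prediction in range(3)])


def loss_probs(probs: torch.Tensor, opponent: int) -> torch.Tensor:
    """Probability mass placed on predictions that fail to win."""
    return torch.sum(probs.flatten() * (1.0 - ev_scores(opponent)))
```

For an opponent playing rock, `ev_scores(0)` is `[1.0, 0.0, 0.5]`: predicting `opponent + 1` leads to a loss and predicting `opponent + 2` to a draw, so the position of the 0.5 and 0.0 entries matters.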

[Locality-sensitive hashing](https://en.wikipedia.org/wiki/Locality-sensitive_hashing) is used within `loss_hash()`.
This is done by creating an xxhash (seed=0) of the agent's str label,
then using modulo to place it into a fixed-size one-hot-encoded bucket.
(Strictly speaking this is closer to the hashing trick than to true LSH, as hashing the name does not place similar agents into nearby buckets.)
Whilst this information is not directly used for making moves,
the idea is to train the model to have an internal representation
of who our opponent is, such that it can better select an appropriate strategy.
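A minimal sketch of the bucket encoding and a categorical cross entropy over the buckets. To keep it dependency-free it swaps `xxhash` for `zlib.crc32` from the standard library; this is an assumption, not what `RpsLSTM` uses, so the bucket indices will differ, but the mechanics are the same. The function names are hypothetical.

```python
import zlib

import torch
import torch.nn.functional as F


def agent_bucket(agent_name: str, hash_size: int = 128) -> int:
    """Map an agent's name to one of hash_size buckets (zlib.crc32 stands in for xxhash here)."""
    return zlib.crc32(agent_name.encode("utf-8")) % hash_size


def loss_hash(hash_probs: torch.Tensor, agent_name: str) -> torch.Tensor:
    """Categorical cross entropy between softmax bucket probabilities and the agent's bucket."""
    probs     = hash_probs.flatten()
    target    = torch.tensor([agent_bucket(agent_name, probs.numel())])
    log_probs = torch.log(probs.clamp_min(1e-7))         # clamp avoids log(0) = -inf
    return F.nll_loss(log_probs.unsqueeze(0), target)
```

A uniform prediction over 128 buckets scores `log(128)`, roughly 4.85, and the loss falls towards zero as probability concentrates on the agent's bucket.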

[Softsign](https://sefiks.com/2017/11/10/softsign-as-a-neural-networks-activation-function/)
`f(x) = x / (1 + |x|)` is used as the activation function. It has a similar S-shape to `tanh()`,
but approaches its asymptotes of ±1 polynomially rather than exponentially, so it saturates more gradually.
It is less well known, but has been cited in [recent papers](https://paperswithcode.com/method/softsign-activation).
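To make the comparison with `tanh()` concrete, this sketch checks that `nn.Softsign` matches the formula and compares how quickly the two functions saturate. The `grad_at` helper is a hypothetical name for illustration.

```python
import torch


def grad_at(fn, value: float) -> float:
    """Derivative of fn at a single point, computed via autograd."""
    x = torch.tensor([value], requires_grad=True)
    fn(x).sum().backward()
    return x.grad.item()


x        = torch.linspace(-5.0, 5.0, steps=11)
softsign = torch.nn.Softsign()(x)
manual   = x / (1 + x.abs())                     # f(x) = x / (1 + |x|)
tanh     = torch.tanh(x)

# At x = 3: softsign = 0.75 with slope 1/(1+3)^2 = 1/16, tanh is ~0.995 with slope < 0.01
soft_grad = grad_at(torch.nn.Softsign(), 3.0)
tanh_grad = grad_at(torch.tanh, 3.0)
```

The slower saturation means Softsign still passes a usable gradient for inputs where `tanh()` is already nearly flat.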

I have not done a thorough exploration of the best activation function to use here,
however it did fix a weird bug caused by [PReLU](https://pytorch.org/docs/stable/generated/torch.nn.PReLU.html)
that resulted in the LSTM weights converging to NaN after a thousand epochs.

I also concatenate the original input to the LSTM output before passing it to the dense layers.
I am unsure if I should be doing this, or rely purely on this information being saved into the LSTM embedding.
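The concatenation acts like a skip connection: the raw 76-value input bypasses the LSTM and is joined to its 128-value output, which is why `dense_1` has 204 input features. Below is a minimal shape check under the default hyperparameters; it is a sketch, not the notebook's code. It passes the previous `(h, c)` state to `nn.LSTM` explicitly: if that pair is omitted, PyTorch starts from zeros on every call, so the LSTM can only carry information between moves if the returned state is fed back in.

```python
import torch
import torch.nn as nn

input_size, hidden_size, num_layers = 76, 128, 3
lstm    = nn.LSTM(input_size, hidden_size, num_layers=num_layers, batch_first=True, dropout=0.25)
dense_1 = nn.Linear(input_size + hidden_size, hidden_size)      # 76 + 128 = 204 inputs

inputs = torch.rand(1, 1, input_size)                            # (batch, seq_len, input_size)
hidden = (torch.zeros(num_layers, 1, hidden_size),               # h_0: (num_layers, batch, hidden)
          torch.zeros(num_layers, 1, hidden_size))               # c_0
out, hidden = lstm(inputs, hidden)                                # out: (1, 1, 128)
skip = torch.cat([out, inputs], dim=2)                            # (1, 1, 204)
x    = nn.Softsign()(dense_1(skip))                               # (1, 1, 128)
```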

In [None]:
!pip3 install -q xxhash 

In [None]:
# %%writefile RpsLSTM.py
# Source: https://www.kaggle.com/jamesmcguigan/rock-paper-scissors-lstm/
# Source: https://github.com/JamesMcGuigan/ai-games/blob/master/games/rock-paper-scissors/neural_network/RpsLSTM.py

import math

import torch
import torch.nn as nn
import xxhash

# from neural_network.NNBase import NNBase


# DOCS: https://pytorch.org/tutorials/beginner/nlp/sequence_models_tutorial.html
# DOCS: https://blog.floydhub.com/long-short-term-memory-from-zero-to-hero-with-pytorch/
# DOCS: https://github.com/MagaliDrumare/How-to-learn-PyTorch-NN-CNN-RNN-LSTM/blob/master/10-LSTM.ipynb
class RpsLSTM(NNBase):

    def __init__(self, hidden_size=128, hash_size=128, num_layers=3, dropout=0.25, window=10):
        """
        :param hidden_size: size of LSTM embedding
        :param hash_size:   size of hash for guessing opponent
        :param num_layers:  number of LSTM layers
        :param dropout:     dropout parameter for LSTM
        :param window:      maximum history length passed in as input data
        """
        super().__init__()
        self.device      = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
        self.hidden_size = hidden_size
        self.num_layers  = num_layers
        self.batch_size  = 1
        self.dropout     = dropout
        self.window      = window
        self.hash_size   = hash_size
        self.input_size  = self.cast_inputs(0,0).shape[-1]

        self.lstm  = nn.LSTM(
            input_size  = self.input_size,
            hidden_size = self.hidden_size,
            num_layers  = self.num_layers,
            dropout     = self.dropout,
            batch_first = True,
        )
        self.dense_1   = nn.Linear(self.input_size + hidden_size, hidden_size)
        self.dense_2   = nn.Linear(hidden_size, hidden_size)
        self.out_probs = nn.Linear(hidden_size, 3)
        self.out_hash  = nn.Linear(hidden_size, self.hash_size)

        self.activation = nn.Softsign()  # BUGFIX: PReLU was potentially causing NaNs in model weights
        self.softmax    = nn.Softmax(dim=2)
        self.reset()  # call before self.cast_inputs()
        self.to(self.device)



    ### Lifecycle

    def reset(self):
        self.history = [ [], [] ]  # action, opponent
        self.stats   = torch.zeros((2,3),    dtype=torch.float).to(self.device)
        self.hidden  = (
            torch.zeros(self.num_layers, self.batch_size, self.hidden_size),
            torch.zeros(self.num_layers, self.batch_size, self.hidden_size)
        )



    ### Training

    @staticmethod
    def reward(action: int, opponent: int) -> float:
        if (action - 1) % 3 == opponent % 3: return  1.0  # win
        if (action - 0) % 3 == opponent % 3: return  0.5  # draw
        if (action + 1) % 3 == opponent % 3: return  0.0  # loss
        return 0.0


    def loss_probs(self, probs: torch.Tensor, opponent: int) -> torch.Tensor:
        """
        Loss based on softmax probability vs EV score of opponent move prediction
        """
        ev = torch.zeros((3,), dtype=torch.float).to(self.device)
        ev[(opponent + 0) % 3] = 1.0   # expect rock, play paper + opponent rock     = win
        ev[(opponent + 1) % 3] = 0.0   # expect rock, play paper + opponent scissors = loss
        ev[(opponent + 2) % 3] = 0.5   # expect rock, play paper + opponent paper    = draw
        losses = probs * (1-ev)
        loss   = torch.sum( losses )
        # loss   = -torch.sum(torch.log(1-losses))  # cross entropy loss
        return loss


    def loss_hash(self, hash_id: torch.Tensor, agent_name: str) -> torch.Tensor:
        """
        Categorical Cross Entropy loss for agent identity prediction using Locality-sensitive hashing
        """
        hash_id    = hash_id.flatten().clamp(1e-7, 1 - 1e-7)  # BUGFIX: prevent log(0) = -inf / NaN
        agent_hash = xxhash.xxh32(agent_name, seed=0).intdigest() % self.hash_size
        agent_hot  = self.one_hot_encode(agent_hash, size=self.hash_size).flatten()
        loss       = -torch.sum( agent_hot * torch.log(hash_id) )  # categorical cross entropy
        return loss



    ### Casting

    def one_hot_encode(self, number: int = None, size: int = 3) -> torch.Tensor:
        """ One hot encoding of action and opponent action """
        x = torch.zeros((size,), dtype=torch.float).to(self.device)
        if number is not None:
            x[ int(number) % size ] = 1.0
        return x


    def encode_actions(self, action: int, opponent: int) -> torch.Tensor:
        return torch.stack([
            self.one_hot_encode(action),
            self.one_hot_encode(opponent),
        ])


    def encode_history(self) -> torch.Tensor:
        """
        self.history as a one hot encoded tensor
        history is created via .insert() thus latest move will always be in position 0
        self.window can be used to restrict the maximum size of history data passed into the model
        """
        x = torch.zeros((2, self.window, 3), dtype=torch.float).to(self.device)
        for player in [0,1]:
            window = min( self.window, len(self.history[player]) )
            for step in range(window):
                action = self.history[player][step]
                if action is None: continue
                x[ player, step, action % 3 ] = 1.0
        return x


    def encode_stats(self):
        """ Normalized percentage frequency from self.stats """
        step  = torch.sum(self.stats[0])
        stats = ( self.stats / torch.sum(self.stats[0])
                  if step.item() > 0
                  else self.stats )
        return stats


    def encode_step(self):
        """
        Encode the step (current turn number) as one hot encoded version of the logarithm
        This is mostly for detecting random warmup periods
        { round(math.log(n)): n for n in range(1,1001)  } = { 0: 1, 1: 4, 2: 12, 3: 33, 4: 90, 5: 244, 6: 665, 7: 1000}
        """
        step     = torch.sum(self.stats[0]).reshape(1)
        log_step = torch.log(step+1).round().int().item()  # +1 avoids log(0) = -inf
        hot_step = self.one_hot_encode(log_step, size=round(math.log(1000)))
        return hot_step


    @staticmethod
    def cast_action(probs: torch.Tensor) -> int:
        expected = torch.argmax(probs, dim=2).detach().item()
        action   = int(expected + 1) % 3
        return action


    def cast_inputs(self, action: int, opponent: int) -> torch.Tensor:
        """
        Generate the input tensors for the LSTM
        Assumes that self.update_state(action, opponent) has been called beforehand
        action + opponent are now encoded as part of the history
        """
        if not hasattr(self, 'stats'): self.reset()

        noise   = torch.rand((3,)).to(self.device)
        step    = self.encode_step()
        stats   = self.encode_stats()
        history = self.encode_history()

        x = torch.cat([
            noise.flatten(),    # tensor (3,)            random noise for probabilistic moves
            step.flatten(),     # tensor (7,)            round(log(step)) one-hot encoded - to predict warmup periods
            stats.flatten(),    # tensor (2,3)           with move frequency percentages
            history.flatten(),  # tensor (2,window=10,3) one-hot encoded timeseries history
        ])
        x = torch.reshape(x, (1,1,-1))    # (batch, seq_len, input_size) as batch_first=True
        return x



    ### Play

    def update_state(self, action: int, opponent: int):
        """
        self.stats records total count for each action
            which will later be normalized as percentage frequency
        self.history records move history
            [0] index always being the most recent move
        """
        if action   is not None: self.stats[0][action]   += 1.0
        if opponent is not None: self.stats[1][opponent] += 1.0
        if action   is not None: self.history[0].insert(0, action)
        if opponent is not None: self.history[1].insert(0, opponent)


    def forward(self, action: int, opponent: int):
        self.update_state(action, opponent)
        inputs    = self.cast_inputs(action, opponent)
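        hidden    = tuple( h.to(self.device) for h in self.hidden )
        x, hidden = self.lstm(inputs, hidden)  # BUGFIX: feed the previous hidden state back into the LSTM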
        x         = torch.cat([ x, inputs ], dim=2)
        x         = self.activation( self.dense_1(x)   )
        x         = self.activation( self.dense_2(x)   )
        probs     = self.softmax(    self.out_probs(x) )
        hash_id   = self.softmax(    self.out_hash(x)  )
        action    = self.cast_action(probs)

        # BUGFIX: occasionally LSTM would return NaNs after 850+ epochs
        #         this possibly caused by PReLU, now replaced with Softsign
        if any([ torch.any(torch.isnan(layer)) for layer in [ x, *hidden] ]):
            raise ValueError(f'{self.__class__.__name__}.forward() - LSTM returned nan')

        self.hidden = hidden
        return action, probs, hash_id


# Opponents

We start with a series of really simple agents: 
- Rock
- Paper
- Scissors
- Sequential

Plus the white belt agents from the [RPS Dojo](https://www.kaggle.com/chankhavu/rps-dojo)
- Reactionary
- Counter Reactionary
- Mirror
- Mirror Shift
- Statistical

These act as a baseline showing that the LSTM neural network is at least capable of learning simple logic.

Next up are a variety of more complex agents, implementing a range of different strategies:
- [Anti-Rotn](https://www.kaggle.com/jamesmcguigan/rock-paper-scissors-anti-rotn)
- [Multi Stage Decision Tree](https://www.kaggle.com/jamesmcguigan/rock-paper-scissors-multi-stage-decision-tree)
- [Iocaine Powder](https://www.kaggle.com/jamesmcguigan/rps-roshambo-comp-iocaine-powder)
- [Greenberg](https://www.kaggle.com/jamesmcguigan/rock-paper-scissors-greenberg)
- [Statistical Prediction](https://www.kaggle.com/jamesmcguigan/rock-paper-scissors-statistical-prediction)


In [None]:
!cat ../input/rock-paper-scissors-anti-rotn/anti_rotn.py                  | perl -p -e 's/warmup=\d+/warmup=1/' | tee anti_rotn.py > /dev/null
!cat ../input/rock-paper-scissors-multi-stage-decision-tree/submission.py | perl -p -e 's/warmup_period=\d+/warmup_period=1/; s/random_freq=[\d.]+/random_freq=0/' | tee decision_tree.py > /dev/null
!cat ../input/rps-roshambo-comp-iocaine-powder/submission.py              | tee iocaine.py                > /dev/null
!cat ../input/rock-paper-scissors-greenberg/greenberg.py                  | tee greenberg.py              > /dev/null
!cat ../input/rock-paper-scissors-statistical-prediction/submission.py    | tee statistical_prediction.py > /dev/null

In [None]:
import random
from kaggle_environments.envs.rps.utils import get_score

def rock_agent(observation, configuration):
    return 0

def paper_agent(observation, configuration):
    return 1

def scissors_agent(observation, configuration):
    return 2

def sequential_agent(observation, configuration):
    return observation.step % configuration.signs


# White Belt Agents from https://www.kaggle.com/chankhavu/rps-dojo
last_react_action = None
def reactionary(observation, configuration):
    global last_react_action
    if observation.step == 0:
        last_react_action = random.randrange(0, configuration.signs)
    elif get_score(last_react_action, observation.lastOpponentAction) <= 1:
        last_react_action = (observation.lastOpponentAction + 1) % configuration.signs
    return last_react_action


last_counter_action = None
def counter_reactionary(observation, configuration):
    global last_counter_action
    if observation.step == 0:
        last_counter_action = random.randrange(0, configuration.signs)
    elif get_score(last_counter_action, observation.lastOpponentAction) == 1:
        last_counter_action = (last_counter_action + 2) % configuration.signs
    else:
        last_counter_action = (observation.lastOpponentAction + 1) % configuration.signs
    return last_counter_action


def mirror_opponent_agent(observation, configuration):
    if observation.step > 0:
        return observation.lastOpponentAction
    else:
        return 0
    
    
def mirror_shift_opponent_agent_1(observation, configuration):
    if observation.step > 0:
        return (observation.lastOpponentAction + 1) % 3
    else:
        return 0

    
def mirror_shift_opponent_agent_2(observation, configuration):
    if observation.step > 0:
        return (observation.lastOpponentAction + 2) % 3
    else:
        return 0

    
action_histogram = {}
def statistical(observation, configuration):
    global action_histogram
    if observation.step == 0:
        action_histogram = {}
        return
    action = observation.lastOpponentAction
    if action not in action_histogram:
        action_histogram[action] = 0
    action_histogram[action] += 1
    mode_action = None
    mode_action_count = None
    for k, v in action_histogram.items():
        if mode_action_count is None or v > mode_action_count:
            mode_action = k
            mode_action_count = v
            continue

    return (mode_action + 1) % configuration.signs

# RPS Trainer

This is a training loop abstracted to play against arbitrary kaggle_environment agents.

The RL interface for kaggle_environments is:
```
env         = make("rps", { "episodeSteps": steps }, debug=False)
trainer     = env.train(random.sample([None, agent], 2))  # random player order
observation = trainer.reset()
done        = False
while not done:
    action = model( observation.lastOpponentAction )
    observation, reward, done, info = trainer.step(action)
```

You can read more about this in the [docs](https://github.com/Kaggle/kaggle-environments)

In theory this training loop code could be easily repurposed for other kaggle competitions.

---

This training loop takes a dictionary of opponent agents, and simulates a 100 step = 50 round match against each one.
`loss.backward()` is only called at the end of each epoch, after a match against every agent, as it's the only way I have figured out
how to solve the `RuntimeError: Trying to backward through the graph a second time` exception.
If anybody knows a better way of doing this, please let me know.
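A common alternative is truncated backpropagation through time: call `backward()` as often as you like, but `detach()` the LSTM's `(h, c)` state afterwards, so the next loss does not try to backpropagate into a graph that has already been freed. This is a general sketch with toy data (the layer sizes, the `head` layer and the random targets are assumptions), not a drop-in change to `rps_trainer()`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm      = nn.LSTM(input_size=4, hidden_size=8, batch_first=True)
head      = nn.Linear(8, 3)
optimizer = torch.optim.RMSprop(list(lstm.parameters()) + list(head.parameters()), lr=1e-3)

hidden = (torch.zeros(1, 1, 8), torch.zeros(1, 1, 8))
losses = []
for step in range(20):
    x      = torch.rand(1, 1, 4)                   # stand-in for cast_inputs()
    target = torch.randint(0, 3, (1,))             # stand-in for the opponent's move
    out, hidden = lstm(x, hidden)
    loss = nn.functional.cross_entropy(head(out[:, -1]), target)
    optimizer.zero_grad()
    loss.backward()                                # backward on every step ...
    optimizer.step()
    hidden = tuple(h.detach() for h in hidden)     # ... then cut the graph, so the next
                                                   # backward() stops at this boundary
    losses.append(loss.item())
```

The trade-off is that gradients no longer flow across the detach points, so the window between `backward()` calls bounds how far back the LSTM can learn dependencies, even though the hidden state itself is still carried forward.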

I did have a bit of code to probabilistically select agents based on their accuracy percentage.
The idea was to skip the rock/paper/scissors agents once they had reached 100% accuracy,
and focus most of the training time on the agents that needed the most training.
However, when retraining from scratch it was causing weird effects in my output statistics
(unsure if this was just a bug in my logging code). It has been disabled for now.
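A sketch of the disabled selection logic, written as a standalone function so it can be tested with a seeded random number generator. The name `select_agents` and the `floor` and `rng` parameters are hypothetical. It keeps the rule of the commented-out code: an agent with running accuracy `a` is trained with probability `1.1 - a`, capped at 1.

```python
import random


def select_agents(agents: dict, accuracies: dict, floor: float = 0.1, rng=random) -> dict:
    """Keep each agent with probability (1 - accuracy + floor), capped at 1."""
    return {
        name: agent
        for name, agent in agents.items()
        if rng.random() < min(1.0, 1.0 - accuracies[name] + floor)
    }
```

With accuracies of 1.0 for `'r'` and 0.3 for `'tree'`, the rock agent is selected in roughly 10% of epochs and the decision tree in roughly 80%. Note that the number of selected agents then varies from epoch to epoch, which changes the shape of `scores` and `losses` and therefore what the logged averages are computed over.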

`RMSprop` seems to work better than other optimizers such as `Adam` or `Adadelta`
for this reinforcement learning task. I don't have any statistics on this,
so can't say if my observations were simply based on coincidence.

`CosineAnnealingWarmRestarts` was added in because several of the [DeepMind](https://deepmind.com/)
papers mention that they use a cosine based scheduling system for their models.
I am not 100% sure this is the same scheduler, or what the correct settings should be.
I am unsure if this is helping or hurting my model, and if I should leave it in or not.
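One detail worth checking is how the scheduler is stepped. `CosineAnnealingWarmRestarts.step()` takes an optional epoch number, not a loss value (passing a metric is the `ReduceLROnPlateau` convention), so it would normally be called with no argument once per epoch. The sketch below records the resulting learning rate using a dummy parameter; `lr_schedule` is a hypothetical helper.

```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts


def lr_schedule(epochs: int = 250, lr: float = 1e-4, T_0: int = 100) -> list:
    """Learning rate seen at the start of each epoch under CosineAnnealingWarmRestarts."""
    param     = torch.nn.Parameter(torch.zeros(1))        # dummy parameter
    optimizer = torch.optim.RMSprop([param], lr=lr)
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=T_0)
    lrs = []
    for epoch in range(epochs):
        lrs.append(optimizer.param_groups[0]["lr"])
        optimizer.step()      # normally preceded by loss.backward()
        scheduler.step()      # no argument: advance by one epoch
    return lrs
```

With `T_0=100` the learning rate halves by epoch 50, decays towards zero by epoch 99 and then jumps back to `lr` at epoch 100, so a 200-epoch run stepped this way would see two full cosine cycles.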

For the sake of making it easier to read the directional trend of the logfile numbers,
I have implemented a very basic running average by taking the mean of the
previous average and the current value. This is an exponential moving average
with a weight of 1/2 on each new value (hence the resemblance to summing an infinite series),
and is a simple one-liner to smooth the curves
and make it easier to observe the directional trend of loss and accuracy.
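A small sketch comparing the trainer's one-liner with the textbook exponential moving average form; both function names are hypothetical. Each new value gets weight 1/2, the one before it 1/4, then 1/8, and so on.

```python
def running_average(values, alpha=0.5):
    """Exponential moving average: avg = alpha * value + (1 - alpha) * avg."""
    averages = []
    for value in values:
        avg = value if not averages else alpha * value + (1 - alpha) * averages[-1]
        averages.append(avg)
    return averages


def notebook_average(values):
    """The trainer's form: (avg + value) / 2, using the first value as-is."""
    averages = []
    avg = 0.0
    for value in values:
        avg = (avg + value) / (2 if avg else 1)
        averages.append(avg)
    return averages
```

Both give `[1.0, 0.5, 0.25, 0.625]` for the input `[1.0, 0.0, 0.0, 1.0]`. One quirk of the `(2 if avg else 1)` form is that an average of exactly zero is treated as "not yet started", so the next value replaces it rather than being averaged in.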

Learning rate is currently set to `1e-4`. One observation from earlier in the development cycle 
was that if the learning rate was set too high, such as `1e-1`, then the model would fail to train
even against the simplest Rock/Paper/Scissors agents. I am unsure if I have set it too low,
or how exactly this interacts with `CosineAnnealingWarmRestarts`. 
It generally seems safer to be too small than too large.


A few technical points to remember:
- `model.train()` and `model.eval()` can be used to switch between training and production modes. 
    - Dropout is disabled in production (`eval()`) mode; `eval()` also switches batch normalization layers to use their running statistics, although this model has none.
- `model.load()` and `model.save()` are custom methods of my `NNBase` class
- `except KeyboardInterrupt: pass` ensures the model gets saved on Ctrl-C exit
- `if __name__ == '__main__':` prevents the training loop from running if we import `rps_trainer()` from a separate file
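A quick illustration of what `train()` and `eval()` change for dropout, using a standalone `nn.Dropout(p=0.25)` to match the LSTM's `dropout=0.25`. Inside `nn.LSTM` the dropout is applied between the stacked layers, and `model.eval()` switches it off in the same way because the mode propagates to all submodules.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dropout = nn.Dropout(p=0.25)
x = torch.ones(1000)

dropout.train()
train_out = dropout(x)   # roughly 25% zeros, survivors scaled by 1 / (1 - 0.25)

dropout.eval()
eval_out = dropout(x)    # identity: no zeros, no scaling
```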

---

Losses and accuracy are displayed based on running averages
```
   200 | losses = 0.419607 0.569539 | 100 r  99 p  98 s  99 seq  65 rotn  32 tree  35 iocaine  38 greenberg  47 stats
   210 | losses = 0.420502 0.617317 | 100 r  99 p  98 s  99 seq  70 rotn  30 tree  46 iocaine  40 greenberg  54 stats
   220 | losses = 0.411737 0.618371 | 100 r  99 p  98 s  99 seq  54 rotn  32 tree  45 iocaine  40 greenberg  51 stats
   230 | losses = 0.448281 0.667498 | 100 r  99 p  98 s  99 seq  52 rotn  27 tree  45 iocaine  42 greenberg  51 stats
```

The first number is the training epoch, in which a 100-step match is played against each agent.

There are two loss functions:
0. loss_probs() - softmax probability vs EV score of opponent move prediction
1. loss_hash() - Categorical Cross Entropy loss for agent identity prediction using Locality-sensitive hashing

The last set of numbers are accuracy percentage scores in actual gameplay. 
- 100 = 100% accuracy winning on every round
- 50  = 50% is a statistical draw, either via draw or win/loss every other round
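The mapping from round results to these percentages can be sketched as follows, assuming each round's reward is +1 for a win, 0 for a draw and -1 for a loss, as the `(reward + 1.0) / 2.0` term in `rps_trainer()` implies. `match_accuracy` is a hypothetical helper.

```python
def match_accuracy(rewards):
    """Map per-round rewards (+1 win, 0 draw, -1 loss) to a 0..1 score, as in scores += (reward + 1) / 2."""
    return sum((reward + 1.0) / 2.0 for reward in rewards) / len(rewards)
```

All wins gives 1.0, while all draws and alternating win/loss both give 0.5, which is why 50 in the log cannot distinguish between those two cases.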

In [None]:
#!/usr/bin/env python3
# %%writefile rps_trainer.py
# Source: https://www.kaggle.com/jamesmcguigan/rock-paper-scissors-lstm/
# Source: https://github.com/JamesMcGuigan/ai-games/blob/master/games/rock-paper-scissors/neural_network/rps_trainer.py

import random
import re
import sys
import time
import gc
import textwrap
from typing import Dict


import torch
from kaggle_environments import make
from torch.autograd import Variable
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

# from neural_network.RpsLSTM import RpsLSTM
# from roshambo_competition.anti_rotn import anti_rotn
# from roshambo_competition.greenberg import greenberg_agent
# from roshambo_competition.iocaine_powder import iocaine_agent
# from simple.paper import paper_agent
# from simple.rock import rock_agent
# from simple.scissors import scissors_agent
# from simple.sequential import sequential_agent
# from statistical.statistical_prediction import statistical_prediction_agent
# from tree.multi_stage_decision_tree import decision_tree_agent


        
def rps_trainer(model, agents: Dict, steps=100, epochs=1000, lr=1e-3, log_freq=10, timeout=0):
    time_start = time.perf_counter()
    wrapper    = textwrap.TextWrapper(width=78, subsequent_indent=' '*4, initial_indent=' '*4)
    try:
        env   = make("rps", { "episodeSteps": steps }, debug=False)
        optimizer = torch.optim.RMSprop(model.parameters(), lr=lr)
        scheduler = None
        scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=100)  # TODO: review choice of scheduler

        accuracies     = { agent_name: 0.0 for agent_name in agents.keys()}
        running_losses = torch.zeros((1,)).to(model.device)
        for epoch in range(epochs):
            
            ### skip high-accuracy agents more often, but train at least 10% of the time
            ### this seems to cause more problems than it solves
            # selected_agents = {
            #     agent_name: agent
            #     for (agent_name, agent) in agents.items()
            #     if random.random() + 0.1 >= accuracies[agent_name]
            # }
            selected_agents = agents
            if len(selected_agents) == 0: continue

            # BUGFIX: plain tensors, not requires_grad leaves, so the in-place += below also works on CPU
            scores = torch.zeros((   len(selected_agents),), device=model.device)
            losses = torch.zeros((2, len(selected_agents),), device=model.device)

            for agent_index, (agent_name, agent) in enumerate(selected_agents.items()):
                trainer     = env.train(random.sample([None, agent], 2))  # random player order
                observation = trainer.reset()

                model.reset()
                optimizer.zero_grad()
                gc.collect()

                action     = None
                opponent   = None
                for step in range(1,sys.maxsize):
                    action, probs, hash_id = model.forward(action=action, opponent=opponent)

                    observation, reward, done, info = trainer.step(action)
                    opponent = observation.lastOpponentAction

                    losses[0][agent_index] += model.loss_probs(probs, opponent)
                    losses[1][agent_index] += model.loss_hash(hash_id, agent_name)
                    scores[agent_index]    += (reward + 1.0) / 2.0
                    if done: break

                losses[0][agent_index] /= step  # NOTE: steps = 2 * step
                losses[1][agent_index] /= step
                scores[agent_index]    /= step
                accuracies[agent_name]  = ( (accuracies[agent_name] + scores[agent_index].item())
                                            / (2 if sum(accuracies.values()) else 1) )

            # print(env.render(mode='ansi'))
            running_losses = ( (running_losses + torch.mean(losses, dim=1).detach())  # detach: logging only, don't keep old graphs alive
                               / (2 if torch.sum(running_losses) else 1) )
            loss = torch.mean(losses)
            loss.backward()
            optimizer.step()
            if scheduler is not None: scheduler.step()  # CosineAnnealingWarmRestarts.step() takes an epoch, not a loss

            if epoch % log_freq == 0:
                accuracy_log = " ".join([
                    ('\n'+' '*6 if n % 8 == 0 else '') + 
                    f'{round(value * 100):3d} {name},'
                    for n, (name, value) in enumerate(accuracies.items())
                ])
                message = f'{epoch:4d} | losses = {running_losses[0].item():.6f} {running_losses[1].item():.6f} | {accuracy_log}' 
                print(message)
                if torch.mean(scores).item() >= (1 - 2/steps): break  # allowed first 2 moves wrong

            if timeout and time.perf_counter() > time_start + timeout:  break
        
    except KeyboardInterrupt: pass
    except Exception as exception: 
        print('Exception', exception)
        pass

# Training

In [None]:
!cp ../input/rock-paper-scissors-lstm/RpsLSTM.pth ./

In [None]:
%%time
# Source: https://github.com/JamesMcGuigan/ai-games/blob/master/games/rock-paper-scissors/neural_network/rps_trainer.py

if __name__ == '__main__':
    agents = {
        'r':          rock_agent,
        'p':          paper_agent,
        's':          scissors_agent,
        'seq':        sequential_agent,
        'react':      reactionary,
        'react+1':    counter_reactionary,
        'mirror':     mirror_opponent_agent,
        'mirror+1':   mirror_shift_opponent_agent_1,
        'mirror+2':   mirror_shift_opponent_agent_2,        
        'stat':       statistical,
        'stat_pred':  "statistical_prediction.py",
        'rotn':       "anti_rotn.py",
        'tree':       "decision_tree.py",
        'iocaine':    "iocaine.py",
        'greenberg':  "greenberg.py",
    }

    model = RpsLSTM(hidden_size=128, num_layers=3, dropout=0.25).train()
    print(model)
    model.load()
    rps_trainer(model, agents, steps=100, lr=1e-4, epochs=200, timeout=1*60*60)
    model.save()


# Questions

This model can successfully defeat the simplest of agents such as:
- Rock 
- Paper 
- Scissors 
- Sequential

However, it struggles to get beyond a draw against more complex agents:
- anti_rotn
- multi_stage_decision_tree
- iocaine_powder 
- greenberg

I am unsure exactly what I am doing wrong here.
- Is Rock Paper Scissors a suitable use case for an LSTM network?
- Am I training for long enough? The model seemed to plateau at a draw against the advanced agents.
- Are these more advanced agents too complex for a neural network to reverse engineer?
    - The simple agents can be trained with `hidden_size=16`
    - Do I need a much larger embedding size?
- Is my model too small or too large?
    - Are 3 LSTM layers better than 1?
    - Should I use a pyramid of 3 dense layers rather than a square of 2?
    - The largest possible model that fits within 100MB is `hidden_size=1024` with pyramid-shaped dense layers, but this is very slow to train
- I concatenate the original input to the LSTM output before passing it to the dense layers
    - I am unsure whether I should instead rely on the LSTM embedding alone to encode everything that needs to be remembered
- Am I missing any obvious layers, such as batch normalization?
- Is there any way to train an LSTM in batch mode with reinforcement learning?
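
The model-size questions (number of LSTM layers, dense-layer shape, the 100MB limit) can be checked with arithmetic before training anything. The sketch below is a hypothetical pure-Python helper, not part of the notebook, that counts parameters using the standard `torch.nn.LSTM` layout (`weight_ih`, `weight_hh`, `bias_ih`, `bias_hh` per layer) and `torch.nn.Linear` with bias. It assumes the layer shapes printed for `RpsLSTM` at the top of this notebook (square dense layers), a fixed 128-way `out_hash`, and float32 weights, and it ignores the small serialization overhead of a `.pth` file.

```python
# Back-of-envelope size estimate for an RpsLSTM-shaped network.
# ASSUMPTIONS: LSTM -> dense_1 on [lstm_output, raw_input] -> dense_2 ->
# out_probs / out_hash, out_hash fixed at 128 outputs, float32 weights.


def lstm_params(input_size: int, hidden_size: int, num_layers: int) -> int:
    """Parameters of a unidirectional torch.nn.LSTM without projections."""
    total = 0
    for layer in range(num_layers):
        layer_input = input_size if layer == 0 else hidden_size
        gates = 4 * hidden_size  # input, forget, cell and output gates
        # weight_ih + weight_hh + bias_ih + bias_hh
        total += gates * layer_input + gates * hidden_size + 2 * gates
    return total


def linear_params(in_features: int, out_features: int) -> int:
    """Parameters of torch.nn.Linear with bias=True."""
    return in_features * out_features + out_features


def rps_lstm_params(hidden_size: int = 128, num_layers: int = 3,
                    input_size: int = 76, hash_size: int = 128) -> int:
    """Total parameters for the assumed RpsLSTM layer shapes."""
    return (
        lstm_params(input_size, hidden_size, num_layers)
        + linear_params(hidden_size + input_size, hidden_size)  # dense_1: LSTM output + raw input
        + linear_params(hidden_size, hidden_size)                # dense_2
        + linear_params(hidden_size, 3)                          # out_probs
        + linear_params(hidden_size, hash_size)                  # out_hash
    )


def size_mb(n_params: int, bytes_per_param: int = 4) -> float:
    """Parameter storage in megabytes (10**6 bytes)."""
    return n_params * bytes_per_param / 1e6


if __name__ == '__main__':
    for hidden_size, num_layers in [(16, 3), (128, 1), (128, 3), (1024, 3)]:
        n = rps_lstm_params(hidden_size, num_layers)
        print(f'hidden_size={hidden_size:4d} num_layers={num_layers} '
              f'params={n:,} ~{size_mb(n):.1f} MB')
```

Under these assumptions the current 128-unit, 3-layer model has 429,315 parameters (about 1.7 MB), while `hidden_size=1024` reaches roughly 23.6M parameters (about 94 MB); a pyramid-shaped head would change only the dense-layer terms. Almost all of the parameters sit in the LSTM, so adding LSTM layers or widening them grows the model far faster than reshaping the dense head.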


Thank you for any advice or feedback on how to improve this work.

# Further Reading

This notebook is part of a series exploring Rock Paper Scissors:

Predetermined
- [PI Bot](https://www.kaggle.com/jamesmcguigan/rock-paper-scissors-pi-bot)
- [Anti-PI Bot](https://www.kaggle.com/jamesmcguigan/rock-paper-scissors-anti-pi-bot)
- [Anti-Anti-PI Bot](https://www.kaggle.com/jamesmcguigan/rock-paper-scissors-anti-anti-pi-bot)
- [De Bruijn Sequence](https://www.kaggle.com/jamesmcguigan/rock-paper-scissors-de-bruijn-sequence)

RNG
- [Random Agent](https://www.kaggle.com/jamesmcguigan/rock-paper-scissors-random-agent)
- [Random Seed Search](https://www.kaggle.com/jamesmcguigan/rock-paper-scissors-random-seed-search)
- [RNG Statistics](https://www.kaggle.com/jamesmcguigan/rock-paper-scissors-rng-statistics)

Opponent Response
- [Anti-Rotn](https://www.kaggle.com/jamesmcguigan/rock-paper-scissors-anti-rotn)
- [Sequential Strategies](https://www.kaggle.com/jamesmcguigan/rock-paper-scissors-sequential-strategies)

Statistical 
- [Weighted Random Agent](https://www.kaggle.com/jamesmcguigan/rock-paper-scissors-weighted-random-agent)
- [Anti-Rotn Weighted Random](https://www.kaggle.com/jamesmcguigan/rock-paper-scissors-anti-rotn-weighted-random)
- [Statistical Prediction](https://www.kaggle.com/jamesmcguigan/rock-paper-scissors-statistical-prediction)

Memory Patterns
- [Naive Bayes](https://www.kaggle.com/jamesmcguigan/rock-paper-scissors-naive-bayes)
- [Memory Patterns](https://www.kaggle.com/jamesmcguigan/rock-paper-scissors-memory-patterns)

Decision Tree
- [XGBoost](https://www.kaggle.com/jamesmcguigan/rock-paper-scissors-xgboost)
- [Multi Stage Decision Tree](https://www.kaggle.com/jamesmcguigan/rock-paper-scissors-multi-stage-decision-tree)
- [Decision Tree Ensemble](https://www.kaggle.com/jamesmcguigan/rock-paper-scissors-decision-tree-ensemble)

Neural Networks
- [LSTM](https://www.kaggle.com/jamesmcguigan/rock-paper-scissors-lstm)

Ensemble
- [Multi Armed Stats Bandit](https://www.kaggle.com/jamesmcguigan/rock-paper-scissors-multi-armed-stats-bandit)

RoShamBo Competition Winners
- [Iocaine Powder](https://www.kaggle.com/jamesmcguigan/rps-roshambo-comp-iocaine-powder)
- [Greenberg](https://www.kaggle.com/jamesmcguigan/rock-paper-scissors-greenberg)