In [105]:
import torch
from torch import nn
import math
import numpy as np

In [2]:
class grad_skip_softmax(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.sm = nn.Softmax()
    
    def forward(self, x):
        return self.sm(x)

    def backward(self, grad):
        # skip gradient through the softmax on backward pass
        return grad
        
class gru(nn.Module):
    # 'gated-recurrent-unit type gating' as seen in GTrXL paper
    def __init__(self, dim, b_g = 1) -> None:
        super().__init__()

        self.w_r = nn.Linear(dim, dim, bias = False)
        self.u_r = nn.Linear(dim, dim, bias = False)

        self.w_z = nn.Linear(dim, dim, bias = True)
        self.u_z = nn.Linear(dim, dim, bias = True)
        self.b_g = b_g # this is used to hack initial bias of the above to be below zero, such that gate is initialized close to identity
        self.w_g = nn.Linear(dim, dim, bias = False)
        self.u_g = nn.Linear(dim, dim, bias = False)

        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()

    def forward(self, x, y):
        r = self.sigmoid(self.w_r(y) + self.u_r(x))
        z = self.sigmoid(self.w_z(y) + self.u_z(x) - self.b_g) # when zero, gate passes identity of residual
        h_hat = self.tanh(self.w_g(y) + self.u_g(r * x))
        g = (1-z)*x + z * h_hat
        return g
        

class mlp(nn.Module):
    # 1d temporal convolution
    # no communication between tokens, uses same kernel for each token spot
    def __init__(self, embed_dim, internal_dim) -> None:
        super().__init__()
        self.block = nn.Sequential(
            nn.Linear(embed_dim, internal_dim),
            nn.ReLU(),
            nn.Linear(internal_dim, embed_dim)
        ) # no second relu at output of mlp

    def forward(self, input):
        return self.block(input)


class cross_attention(nn.Module):
    def __init__(self, embed_dimension, num_heads) -> None:
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=embed_dimension,
            num_heads=num_heads,
            )
    
    def forward(self, x, enc):

        return self.attention(x, enc, enc)[0]

class self_attention(nn.Module):
    def __init__(self, embed_dimension, num_heads) -> None:
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=embed_dimension,
            num_heads=num_heads,
            )
    
    def forward(self, x):

        return self.attention(x, x, x)[0]







In [44]:


class Smear_key(nn.Module):

    def __init__(self,
    sequence_length,
    heads
    ) -> None:
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, heads, sequence_length - 1, 1))
        self.sigmoid = nn.Sigmoid()

    def forward(self, k):
        itrp = self.sigmoid(self.alpha)
        smear = k[:,:,1:,:]*itrp + k[:,:,:-1,:]*(1-itrp)
        return torch.cat([k[:,:, 0:1, :], smear], dim = 2)

class decoder_mha(nn.Module):
    #Masked smeared self attention
    def __init__(self, model_dim, sequence_length, heads) -> None:
        super().__init__()
        self.mask = torch.triu(torch.ones(sequence_length, sequence_length) * float('-inf'), diagonal=1) # make batch, heads, seq,seq
        self.model_dim = model_dim
        self.sequence_length = sequence_length
        self.heads = heads
        self.key_dim = model_dim // heads
        self.W_q = nn.Linear(model_dim, model_dim, bias=False)
        self.W_k = nn.Linear(model_dim, model_dim, bias=False)
        self.W_v = nn.Linear(model_dim, model_dim, bias=False)
        self.output = nn.Linear(model_dim, model_dim, bias=True)
        self.ln = nn.LayerNorm(model_dim)
        self.smear = Smear_key(sequence_length, heads)

    def forward(self,x):
        # batch, sequence, model_dim
        q = self.W_q(x).view(-1, self.sequence_length, self.heads, self.key_dim).transpose(1,2)
        k = self.W_k(x).view(-1, self.sequence_length, self.heads, self.key_dim).transpose(1,2)
        v = self.W_v(x).view(-1, self.sequence_length, self.heads, self.key_dim).transpose(1,2)
        k = self.smear(k)
        # batch, heads, sequence, dim // heads
        key_dim = k.shape[-1:][0]
        scores = q @ k.transpose(2,3) / key_dim**.5
        scores += self.mask
        attn = torch.softmax(scores, dim = 3)
        mha = attn @ v
        mha = mha.transpose(1, 2).contiguous().view(-1, self.sequence_length, self.model_dim)
        out = self.output(mha)
        # batch, sequence, model_dim
        return out





In [45]:
class encoder_layer(nn.Module):
    # transformer layer
    # not masked, no cross attention, no memory, for encoder
    def __init__(self,
    embed_dim,
    mlp_dim,
    attention_heads,
    sequence_length
    ) -> None:
        super().__init__()

        self.mha = self_attention(
            embed_dimension=embed_dim,
            num_heads=attention_heads
        )

        self.mlp = mlp(
            embed_dim= embed_dim,
            internal_dim=mlp_dim
        )

        self.gate1 = gru(
            dim = embed_dim
        )
        self.gate2 = gru(
            dim = embed_dim
        )

        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)

        self.activation = nn.ReLU()
    
    def forward(self, x):
        y = self.ln1(x)
        y = self.mha(y)
        x = self.gate1(x,self.activation(y))
        y = self.ln1(x)
        y = self.mlp(y)
        x = self.gate2(x, self.activation(y))
        
        return x

class decoder_layer(nn.Module):
    # transformer layer
    # masked, cross attention, smeared key
    def __init__(self,
    embed_dim,
    mlp_dim,
    attention_heads,
    sequence_lenth
    ) -> None:
        super().__init__()

        self.mha = decoder_mha(
            model_dim=embed_dim,
            sequence_length=sequence_lenth,
            heads=attention_heads
        ) #smeared key masked self attention

        self.cross_mha = cross_attention(
            embed_dimension = embed_dim,
            num_heads = attention_heads,
        )

        self.mlp = mlp(
            embed_dim = embed_dim,
            internal_dim = mlp_dim
        )

        self.gate1 = gru(
            dim = embed_dim
        )
        self.gate2 = gru(
            dim = embed_dim
        )
        self.gate3 = gru(
            dim = embed_dim
        )

        self.ln = nn.LayerNorm(embed_dim)

        self.activation = nn.ReLU()
        self.ln1 = nn.LayerNorm(embed_dim)
    
    def forward(self, x, enc):
        # masked self attention, smeared key
        y = self.ln1(x)
        y = self.mha(y)
        x = self.gate1(x,self.activation(y))

        # cross attention
        # consider output sequence length and 
        y = self.ln1(x)
        enc = self.ln1(enc)
        y = self.cross_mha(enc, x)
        x = self.gate2(x, self.activation(y))

        # position-wise multi layper perceptron
        y = self.ln1(x)
        y = self.mlp(y)
        x = self.gate2(x, self.activation(y))
        
        return x

In [None]:
# positional encoding class drawn largely from pytorch tutorial on thier website
class positional_encoding(nn.Module):
    def __init__(self,
    model_dim,
    sequence_length
    ) -> None:
        super().__init__()

        position = torch.arange(sequence_length).unsqueeze(1)
        freq = torch.exp(torch.arange(0, model_dim, 2) * (-math.log(10000.0) / model_dim))
        pe = torch.zeros(1, sequence_length, model_dim)
        pe[0, :, 0::2] = torch.sin(position * freq)
        pe[0, :, 1::2] = torch.cos(position * freq)
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        return x + self.pe[:x.size(1)]


In [64]:
class encoder(nn.Module):
    def __init__(self,
    layers,
    model_dim,
    mlp_dim,
    heads,
    sequence_length
    ) -> None:
        super().__init__()
        
        # no inductive biases on encoder here
        self.block = nn.Sequential()
        for x in range(layers):
            self.block.append(encoder_layer(
                embed_dim = model_dim,
                mlp_dim = mlp_dim,
                attention_heads = heads,
                sequence_length = sequence_length
            ))
            
    def forward(self, x):
        return self.block(x)

class decoder(nn.Module):
    def __init__(self,
    layers,
    model_dim,
    mlp_dim,
    heads,
    sequence_length
    ) -> None:
        super().__init__()

        self.pe = positional_encoding(
            model_dim=model_dim, 
            sequence_length=sequence_length
            )

        self.block = []

        for x in range(layers):
            self.block.append(
                decoder_layer(
                    embed_dim = model_dim,
                    mlp_dim= mlp_dim,
                    attention_heads= heads,
                    sequence_lenth = sequence_length
                )
            )
        
    def forward(self, x, y):
        # y is input from encoder
        x = self.pe(x)
        for layer in self.block:
            x = layer(x,y)
            
        return x

        

In [None]:
# positional encoding class drawn largely from tutorial on pytorch website
class positional_encoding(nn.Module):
    def __init__(self,
    model_dim,
    sequence_length
    ) -> None:
        super().__init__()

        position = torch.arange(sequence_length).unsqueeze(1)
        freq = torch.exp(torch.arange(0, model_dim, 2) * (-math.log(10000.0) / model_dim))
        pe = torch.zeros(1, sequence_length, model_dim)
        pe[0, :, 0::2] = torch.sin(position * freq)
        pe[0, :, 1::2] = torch.cos(position * freq)
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        return x + self.pe[:x.size(1)]


In [65]:
class RLformer(nn.Module):
    #modify to do buffer and positional encoding

    def __init__(self,
    model_dim,
    mlp_dim,
    attn_heads,
    sequence_length,
    enc_layers,
    dec_layers,
    action_dim
    ) -> None:
        super().__init__()

        self.encoder = encoder(
            layers=enc_layers,
            model_dim=model_dim,
            mlp_dim=mlp_dim,
            heads=attn_heads,
            sequence_length = sequence_length
        )

        self.decoder = decoder(
            layers=dec_layers,
            model_dim= model_dim,
            mlp_dim=mlp_dim,
            heads=attn_heads,
            sequence_length=sequence_length,
        )

        self.actor = nn.Sequential(
            nn.Linear(model_dim, action_dim),
            nn.ReLU(),
            grad_skip_softmax() # To do neural replicator dynamics
        )

        self.critic = nn.Sequential(
            nn.Linear(model_dim, mlp_dim),
            nn.ReLU(),
            nn.Linear(mlp_dim, 1)
        )
        

    def forward(self, enc_input, dec_input):
        enc = self.encoder(enc_input)
        dec = self.decoder(dec_input, enc)
        policy = self.actor(dec)
        value = self.critic(dec)
        return policy, value

In [None]:
class Agent(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.model = RLformer(
            model_dim =,
            mlp_dim =,
            attn_heads =,
            sequence_length =,
            enc_layers =,
            dec_layers =,
            action_dim =,
        )

        self.register_buffer('buffername', tensor=)
    
    def init_player(self, player, hand):
        # initialize this players hand through encoder and store in buffer, after tokenizing

    def forward(self, player, obs_flat, reward_flat):
        #takes flattened inputs in list form, not tokenized

        return action, value
        

In [None]:
import torch
import random


class poker_env():
    '''
    Texas no-limit holdem environment.
    '''

    def __init__(self, n_players) -> None:

        self.n_players = n_players

        self.stacks = [0] * n_players
        for player in range(n_players):
            self.stacks[player] = 200

        self.button = 0  # button starts at player 0 WLOG

        self.deck = []
        for suit in ["hearts", "diamonds", "spades", "clubs"]:
            for rank in range(2, 15):
                self.deck += [suit, rank]

    def new_hand(self):
        self.community_cards = []
        self.hands = []
        self.deck_position = 0
        self.button = (self.button + 1) % self.n_players
        self.in_turn = (self.button + 1) % self.n_players
        self.behind = [0] * self.n_players
        self.in_hand = [True] * self.n_players
        self.took_action = [
                               False] * self.n_players  # tracks whether players have taken action in a specific round of betting
        self.pot = 0
        self.current_bet = 0
        self.stage = 0  # 0: pre-flop, 1: flop, 2: turn, 3: river
        self.deck_position = 0

        # deal cards, pass to agents
        random.shuffle(self.deck)
        for i in range(self.n_players):
            self.hands += self.get_next_cards(2)

        # big blind is 2, small blind is 1
        small_blind = {'player': self.in_turn, 'type': 'bet', 'value': 1}
        rewards_1, observations_1 = self.take_action(small_blind)

        big_blind_player = self.in_turn
        big_blind = {'player': big_blind_player, 'type': 'bet', 'value': 2}
        rewards_2, observations_2 = self.take_action(big_blind)
        self.took_action[big_blind_player] = False

        rewards_1 += rewards_2
        observations_1 += observations_2

        return rewards_1, observations_1

    def get_hand(self, player):
        if len(self.hands) == 0:
            return None
        return self.hands[player]

    def take_action(self, action):
        '''
        Only function that is externally called in training
        Takes an action, returns a rewards tensor which has an element for each player, and a list of observations. 
        Observations are all public information -- does not include dealt hands
        Moves game state to next point where action input is required
        Rewards implementation currently changing -- very fucked up rn
        '''
        rewards = [torch.zeros(self.n_players)]

        observations = [action]  # first observation returned is always the action being taken
        player = action['player']
        type = action['type']  # action type is one of {bet, call, fold}
        value = action['value']

        self.took_action[player] = True

        if type == 'bet':
            # move money from player to pot
            self.stacks[player] -= value
            self.pot += value
            self.current_bet += value  # bets are valued independently and are NOT measured by cumulative sum -- current_bet tracks that

            # reward is negative of amount bet
            rewards[0][player] = -value

            # other players are now behind the bet
            for x in range(self.n_players):
                self.behind[x] += value

            # player who just bet cannot be behind
            self.behind[player] = 0

        if type == 'call':
            # need to catch up to current bet
            call_size = self.behind[player]
            self.behind[player] = 0

            # move money from player to pot
            self.stacks[player] -= call_size
            self.pot += call_size
            self.current_bet += call_size  # bets are valued independently and are NOT measured by cumulative sum -- current_bet tracks that

            # reward is negative of amount bet
            rewards[0][player] = -1 * call_size

        if type == 'fold':
            # player becomes inactive
            self.in_hand[player] = False

        # if everyone is square or folded, advance to next game stage
        square_check = True
        for p in range(self.n_players):
            if (self.in_hand[p] and self.behind[p] != 0) or not self.took_action[
                p]:  # Big blind option handled via took_action
                square_check = False

        if square_check:
            # advance stage, and any other subcalls that come with that
            advance_stage_rewards, advance_stage_observations = self.advance_stage()
            rewards += advance_stage_rewards
            observations += advance_stage_observations

        else:
            # advance to next player
            self.in_turn = (self.in_turn + 1) % self.n_players

        return rewards, observations

    def advance_stage(self):
        # this is called anytime that there is no player who is: 1. in the hand, 2. behind the bet, and 3. has not taken action
        advance_stage_rewards = [torch.zeros(self.n_players)]
        advance_stage_observations = []

        # payout if only one player is left
        if sum(self.in_hand) == 1:
            for p in range(self.n_players):
                if self.in_hand[p]:
                    # payout!
                    advance_stage_rewards[0][p] += self.pot
                    advance_stage_observations += {'player': p, 'type': 'win', 'value': self.pot}
            new_hand_rewards, new_hand_observations = self.new_hand()  # move on to next hand
            advance_stage_rewards += new_hand_rewards
            advance_stage_observations += new_hand_observations

        # advance stage if not river
        elif self.stage != 3:
            self.stage += 1
            for p in range(self.n_players):
                if self.in_hand[
                    p]:  # this keeps took_action true for players who have folded to save a conditional above
                    self.took_action[p] = False
            advance_stage_rewards, advance_stage_observations = self.card_reveal()

        # compare hands and payout, then deal new hand
        else:
            winners = self.determine_showdown_winners()
            for p in winners:
                advance_stage_rewards[0][p] += self.pot / len(winners)
                advance_stage_observations += {'player': p, 'type': 'win', 'value': self.pot / len(winners)}

            new_hand_rewards, new_hand_observations = self.new_hand()  # move on to next hand
            advance_stage_rewards += new_hand_rewards
            advance_stage_observations += new_hand_observations

        return advance_stage_rewards, advance_stage_observations

    def card_reveal(self):

        if self.stage == 0:
            # revealing the flop
            card_rewards = [torch.zeros(self.n_players)] * 3  # card reveals have reward zero
            card_observations = self.get_next_cards(3)
            self.community_cards.extend(card_observations)
        else:
            # one card to be revealed
            card_rewards = torch.zeros(self.n_players)
            card_observations = self.get_next_cards(1)
            self.community_cards.extend(card_observations)
        self.in_turn = (self.button + 1) % self.n_players

        return card_rewards, card_observations

    def get_next_cards(self, num_cards):
        if num_cards == 1:
            card = self.deck[self.deck_position]
            self.deck_position += 1
            return card
        elif num_cards == 2:
            cards = [self.deck[self.deck_position], self.deck[self.deck_position + 1]]
            self.deck_position += 2
            return cards
        elif num_cards == 3:
            cards = [self.deck[self.deck_position], self.deck[self.deck_position + 1],
                     self.deck[self.deck_position + 2]]
            self.deck_position += 3
            return cards
        return None

    def determine_showdown_winners(self):
        scores = [0] * self.n_players
        for p in range(self.n_players):
            if not self.in_hand[p]:
                continue

            cards = self.community_cards + self.hands[p]

            rank_count = [0] * 13
            suit_count = {"hearts": 0, "diamonds": 0, "spades": 0, "clubs": 0}

            for card in cards:
                suit_count[card[0]] += 1
                rank_count[card[1]] += 1

            # find rank with highest count and rank with second highest count
            first_count = 0
            second_count = 0
            first_rank = 0
            second_rank = 0
            straight_count = 0
            straight_high = 0
            for rank in range(2, 15):
                current_count = rank_count[rank]
                if current_count > first_count:
                    second_count = first_count
                    second_rank = first_rank
                    first_count = current_count
                    first_rank = rank
                elif current_count == first_count:
                    if rank > first_rank:
                        second_count = first_count
                        second_rank = first_rank
                        first_count = current_count
                        first_rank = rank
                    elif current_count == second_count:
                        second_rank = rank
                    else:
                        second_count = current_count
                        second_rank = rank
                elif current_count > second_count:
                    second_count = rank
                    second_rank = rank
                elif current_count == second_count:
                    second_rank = rank

                if current_count == 0:
                    continue

                if rank == 2:
                    if rank_count[14] > 0:
                        straight_count = 2
                else:
                    if rank_count[rank - 1] > 0:
                        straight_count += 1
                    else:
                        straight_count = 1

                    if straight_count >= 5:
                        straight_high = rank

            # check for flush
            flush_high = 0
            flush_suit = ""
            for suit in suit_count:
                if suit_count[suit] >= 5:
                    flush_suit = suit
                    for card in cards:
                        if card[0] == suit:
                            flush_high = max(flush_high, card[1])

            # check for straight flush
            if flush_high != 0 and straight_high != 0:
                sf_high = 0
                sf_count = 0
                suit_ranks = [0] * 13
                for card in cards:
                    if card[0] == flush_suit:
                        suit_ranks[card[1]] = 1

                for rank in range(2, 15):
                    if rank == 2:
                        if suit_ranks[14]:
                            sf_count = 2
                        else:
                            if suit_ranks[rank - 1]:
                                sf_count += 1
                            else:
                                sf_count = 1

                            if sf_count >= 5:
                                sf_high = rank

                if sf_high != 0:
                    scores[p] = 27 + (0.2 * sf_high)
                    continue

            # quads
            if first_count == 4:
                scores[p] = 24 + (0.2 * first_rank)
                continue

            # full house
            if first_count == 3 and second_count >= 2:
                scores[p] = 21 + 0.2 * first_rank + 0.01 * second_rank
                continue

            # flush
            if flush_high != 0:
                scores[p] = 18 + (0.2 * flush_high)
                continue

            # straight
            if straight_high != 0:
                scores[p] = 15 + (0.2 * straight_high)
                continue

            # sort ranks now that high cards matter
            ranks = []
            for card in cards:
                ranks += card[1]
            ranks.sort()

            # trips
            if first_count == 3:
                high = 0
                second_high = 0
                pos = 7
                while high == 0 and second_high == 0:
                    if ranks[pos] == first_rank:
                        pass
                    elif high == 0:
                        high = ranks[pos]
                    else:
                        second_high = ranks[pos]
                    pos -= 1

                scores[p] = 12 + 0.2 * high + 0.01 * second_high
                continue

            # two pair
            if first_count == 2 and second_count == 2:
                high = 0
                pos = 7
                while high == 0:
                    if ranks[pos] == first_rank or ranks[pos] == second_rank:
                        pass
                    else:
                        high = ranks[pos]

                    pos -= 1

                scores[p] = 9 + 0.2 * first_rank + 0.01 * second_rank + 0.002 * high
                continue

            # pair
            if first_count == 2:
                high = 0
                second_high = 0
                third_high = 0
                pos = 7
                while high == 0 and second_high == 0 and third_high == 0:
                    if ranks[pos] == first_rank or ranks[pos] == second_rank:
                        pass
                    elif high == 0:
                        high = ranks[pos]
                    elif second_high == 0:
                        second_high = ranks[pos]
                    else:
                        third_high = ranks[pos]
                    pos -= 1
                scores[p] = 6 + 0.2 * first_rank + 0.01 * high + 0.002 * second_high + 0.0005 * third_high
                continue

            # high card
            scores[p] = 0.2 * ranks[7] + 0.01 * ranks[6] + 0.002 * ranks[5] + 0.0005 * ranks[4] + 0.00001 * ranks[3]
            continue

        max_score = 0
        winners = []
        for p in scores:
            if scores[p] > max_score:
                winners = [p]
                max_score = scores[p]
            elif scores[p] == max_score:
                winners += p

        return winners

In [None]:
from itertools import chain

class actor_critic():
    #Needs to be able to run hand, return loss with grad enabled
    def __init__(self, 
    max_sequence: int = 200, 
    n_players: int = 2,
    gamma: float = .8,
    n_actions: int = 10, # random placeholder value
    ) -> None:
        self.gamma = gamma
        self.env = poker_env(n_players = n_players)
        self.agent = Agent()

        self.observations = [] #this will be a list of lists, each is the list of observations in a hand
        self.obs_flat = list(chain(*self.observations))
        
        self.rewards = []
        self.rewards_flat = list(chain(*self.rewards))

        self.values = []
        self.val_flat = list(chain(*self.values))

        self.action_log_probabilies = []
        self.alp_flat = list(chain(*self.action_log_probabilies))

        self.max_sequence = max_sequence

        self.n_players = n_players

        self.n_actions = n_actions

    def init_hands(self):
        # get all hands
        # run encoder for each of players
        for player in range(self.n_players):
            hand = self.env.get_hand(player)
            self.agent.init_player(player, hand)

    
    def chop_seq(self):
        #if length of observations is above a certain size, chop it back down to under sequence length by removing oldest hand
        #return flattened version to give to model on next run
        if len(self.observations) > self.max_sequence:
            self.observations = self.observations[1:]
            self.obs_flat = list(chain(*self.observations))

            self.rewards = self.rewards[1:]
            self.rewards_flat = list(chain(*self.rewards_flat))

            self.values = self.values[1:]
            self.val_flat = list(chain(*self.values))

            self.action_log_probabilies = self.action_log_probabilies[1:]
            self.alp_flat = list(chain(*self.action_log_probabilies))

        else:
            self.obs_flat = list(chain(*self.observations))
            self.rewards_flat = list(chain(*self.rewards_flat))
            self.val_flat = list(chain(*self.values))
            self.alp_flat = list(chain(*self.action_log_probabilies))

    def play_hand(self):
        # makes agent play one hand
        
        # deal cards
        rewards, observations = self.env.new_hand() # start a new hand
        self.init_hands() # pre load all of the hands

        # init lists for this hand
        self.observations += [observations] 
        self.rewards += [rewards]

        self.chop_seq() # prepare for input to model
        
        hand_over = False
        while not hand_over:                

            # get values and policy -- should be in list form over sequence length
            values, policy_dists = self.agent(self.obs_flat, self.rewards_flat)
            value = values[-1].detach().numpy()[0,0] # get last value estimate
            dist = policy_dists[-1].detach().numpy() # get last distribution

            # randomly sample an action
            action = np.random.choice(self.n_actions, p=np.squeeze(dist))

            # UNFINISHED: Need to detokenize actions HERE

            alp = torch.log(policy_dist.squeeze(0)[action])
            reward, obs, hand_over = self.env.take_action(action) # need to change environment to return hand_over boolean

            # add new information from this step
            self.rewards[-1].append(reward)
            self.observations[-1].append(obs)
            self.values[-1].append(value)
            self.action_log_probabilies.append(alp)
            
            # prepare for next action
            self.chop_seq()
        
        V_T, _ = self.agent(self.obs_flat, self.rewards_flat)
        
        # process gradients and return loss:
        return self.get_loss(V_T)

    def get_loss(self, values, rewards, V_T):

        Qs = []
        Q_t = V_T
        for t in reversed(range(len(rewards))):
            Q_t = rewards[t] + self.gamma * Q_t
            Qs[t] = Q_t
        
        Qs = torch.FloatTensor(Qs)
        values = torch.FloatTensor(self.val_flat)
        alps = torch.stack(self.alp_flat)
        advantages = Qs - values

        
        actor_loss = (-alps * advantages).mean() # loss function for policy going into softmax on backpass
        critic_loss = 0.5 * advantages.pow(2).mean() # autogressive critic loss
        loss = actor_loss + critic_loss # no entropy in this since that would deviate from deepnash
        return loss
    

        

In [104]:
reward = [0, 2,3,0,5,0,8]
reward[1:]

[2, 3, 0, 5, 0, 8]

In [98]:
reversed(range(len(rewards)))

<range_iterator at 0x1209af3f0>