The first part, where the preprocessing is done, is available here: [first part](https://www.kaggle.com/josephamigo/hungry-geese-is-a-nlp-problem-part-1)

Information about using TPUs on Kaggle can be found here: [TPU in kaggle](https://www.kaggle.com/docs/tpu)

What really made a difference here was using relative positional embeddings, as yuricat suggested, together with a very small learning rate.

Thanks to yuricat for his great suggestion!!

# TPU initialization and dataset loading

In [None]:
import tensorflow as tf
import os

In [None]:
# Detect the available hardware and build the matching distribution strategy.
try:
    # TPU detection. No parameters are necessary when the TPU_NAME environment
    # variable is set, which is always the case on Kaggle.
    TPU = tf.distribute.cluster_resolver.TPUClusterResolver()
except ValueError:
    # No TPU could be resolved: fall back to the default strategy below.
    TPU = None

if TPU:
    print(f"\n... RUNNING ON TPU - {TPU.master()}...")
    tf.config.experimental_connect_to_cluster(TPU)
    tf.tpu.experimental.initialize_tpu_system(TPU)
    # Use the stable class; `tf.distribute.experimental.TPUStrategy` is a
    # deprecated alias of it.
    strategy = tf.distribute.TPUStrategy(TPU)
else:
    print("\n... RUNNING ON CPU/GPU ...")
    # Yield the default distribution strategy in TensorFlow
    #   --> Works on CPU and single GPU.
    strategy = tf.distribute.get_strategy()

# What is a replica?
#    --> A single Cloud TPU device consists of FOUR chips, each of which has TWO TPU cores.
#    --> Therefore, for efficient utilization of Cloud TPU, a program should make use of each of the EIGHT (4x2) cores.
#    --> Each replica is essentially a copy of the training graph that is run on each core and
#        trains a mini-batch containing 1/8th of the overall batch size.
N_REPLICAS = strategy.num_replicas_in_sync

print(f"... # OF REPLICAS: {N_REPLICAS} ...\n")

print("\n... ACCELERATOR SETUP COMPLETED ...\n")

In [None]:
# Switch on XLA just-in-time compilation, announcing each stage on stdout.
print("\n... XLA OPTIMIZATIONS STARTING ...\n")
print("\n... CONFIGURE JIT (JUST IN TIME) COMPILATION ...\n")

# Enable XLA optimizations (the author reports ~10% speedup when using @tf.function calls).
tf.config.optimizer.set_jit(True)

print("\n... XLA OPTIMIZATIONS COMPLETED ...\n")

In [None]:
# Ask Kaggle for the GCS path of the dataset named below; the TFRecord glob
# further down is built from this directory.
from kaggle_datasets import KaggleDatasets
DATA_DIR = KaggleDatasets().get_gcs_path("hungry-geese-nlp-preprocess-tpu-ds")

In [None]:
import re

# Captures the integer N from file names of the form ``df_N.tfrec``.
# A raw string is required so that ``\.`` reaches the regex engine as an
# escaped literal dot instead of being an invalid string escape sequence.
class_regex = r"df_([0-9]+)\.tfrec"
class_prog = re.compile(class_regex)

In [None]:
def _shard_index(path):
    """Return the integer captured by ``class_prog`` in a TFRecord path."""
    return int(class_prog.findall(path)[0])


# Full paths of the individual TFRecord files, ordered numerically by the
# index embedded in the file name (not lexicographically).
_tfrec_candidates = tf.io.gfile.glob(os.path.join(DATA_DIR, "*.tfrec"))
TRAIN_TFREC_PATHS = sorted(_tfrec_candidates, key=_shard_index)
TRAIN_TFREC_PATHS

In [None]:
# Parsing schema for one serialized example: a scalar int64 'label' plus two
# fixed-length (77) int64 vectors, 'sentence' and 'positions'.
feature_description = dict(
    label=tf.io.FixedLenFeature([], tf.int64),
    sentence=tf.io.FixedLenFeature([77], tf.int64),
    positions=tf.io.FixedLenFeature([77], tf.int64),
)

In [None]:
import pickle
import bz2
def _parse_function(example):
    """Deserialize one record according to ``feature_description``.

    Args:
        example: A string tensor representing a `tf.train.Example`.

    Returns:
        The tuple ``(sentence, positions, label)`` read from the parsed
        features of the same names.
    """
    features = tf.io.parse_single_example(example, feature_description)
    return features['sentence'], features['positions'], features['label']

# Creation of the neural network

Since this is a classification task, we prepend a class token to the sequence.

In [None]:
from tensorflow.keras.layers import Embedding, MultiHeadAttention, LayerNormalization, Dropout, Dense

def point_wise_feed_forward_network(d_model, dff):
    """Build the two-layer feed-forward sub-network of an encoder layer.

    Args:
        d_model: Number of units of the second (output) Dense layer.
        dff: Number of units of the first, ReLU-activated Dense layer.

    Returns:
        A `tf.keras.Sequential` stacking the two Dense layers.
    """
    expand = Dense(dff, activation='relu')  # (batch_size, seq_len, dff)
    project = Dense(d_model)  # (batch_size, seq_len, d_model)
    return tf.keras.Sequential([expand, project])

class EncoderLayer(tf.keras.layers.Layer):
    """Transformer encoder layer: self-attention then feed-forward.

    Each of the two sub-layers is followed by dropout, a residual connection
    and layer normalization.

    Args:
        d_model: Width of the token representations flowing through the layer.
        num_heads: Number of attention heads.
        dff: Hidden width of the feed-forward sub-network.
        rate: Dropout rate applied after each sub-layer.
        training: Fixed at construction time and forwarded to the dropout
            layers on every call (True enables dropout).
    """

    def __init__(self, d_model, num_heads, dff, rate=0.1, training=True):
        super().__init__()
        self.training = training

        # Keras' MultiHeadAttention signature is (num_heads, key_dim, ...).
        # Passing (d_model, num_heads) positionally would build `d_model`
        # heads of size `num_heads`; keyword arguments give `num_heads` heads
        # whose sizes add up to (roughly) d_model.
        # NOTE: this determines the attention kernel shapes, so weights saved
        # with the swapped argument order cannot be loaded into this layer.
        self.mha = MultiHeadAttention(num_heads=num_heads,
                                      key_dim=d_model // num_heads)
        self.ffn = point_wise_feed_forward_network(d_model, dff)

        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)

        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)

    def call(self, x):
        """Apply the encoder layer to `x` and return a tensor of the same shape."""
        # Self-attention: query, value and key are all `x`.
        attn_output = self.mha(x, x, x)  # (batch_size, input_seq_len, d_model)
        attn_output = self.dropout1(attn_output, training=self.training)
        out1 = self.layernorm1(x + attn_output)  # (batch_size, input_seq_len, d_model)

        ffn_output = self.ffn(out1)  # (batch_size, input_seq_len, d_model)
        ffn_output = self.dropout2(ffn_output, training=self.training)
        out2 = self.layernorm2(out1 + ffn_output)  # (batch_size, input_seq_len, d_model)

        return out2

class Net(tf.keras.Model):
    """Transformer encoder that maps a tokenized board to 4 action logits.

    Token embeddings and (relative) position embeddings are concatenated along
    the feature axis, a learned class token is prepended along the sequence
    axis, and the encoder output at the class-token slot feeds a bias-free
    linear policy head.

    Args:
        num_layers: Number of stacked `EncoderLayer`s.
        d_model: Width of the encoder representations.
        num_heads: Number of attention heads per encoder layer.
        dff: Hidden width of each encoder feed-forward sub-network.
        rate: Dropout rate.
        training: Fixed at construction time; True enables dropout.
    """

    def __init__(self, num_layers, d_model, num_heads, dff, rate=0.1, training=True):
        super().__init__()
        self.training = training
        self.d_model = d_model
        self.num_layers = num_layers

        # The two embeddings are concatenated feature-wise, so their widths
        # must add up to d_model (128 + 128 for the default d_model=256).
        pos_dim = d_model // 2
        tok_dim = d_model - pos_dim
        self.emb = Embedding(input_dim=50, output_dim=tok_dim)
        self.pos_emb = Embedding(input_dim=77, output_dim=pos_dim)  # relative positional embedding
        self.cls_token_emb = Embedding(input_dim=1, output_dim=d_model)  # class token

        self.enc_layers = [EncoderLayer(d_model, num_heads, dff, rate, self.training)
                           for _ in range(num_layers)]

        self.dropout = Dropout(rate)
        self.policy_head = Dense(4, activation=None, use_bias=False)

    def call(self, sentence, positions):
        """Return the (batch_size, 4) action logits.

        Args:
            sentence: Integer token ids, shape (batch_size, seq_len).
            positions: Integer position ids, same shape as `sentence`.
        """
        h = self.emb(sentence)
        h_pos = self.pos_emb(positions)
        h = tf.concat((h_pos, h), 2)

        # Build the class-token ids from the incoming batch instead of from a
        # constant with a hard-coded batch size, so any batch size works.
        cls_ids = tf.zeros_like(sentence[:, :1])
        h_token = self.cls_token_emb(cls_ids)
        h = tf.concat((h_token, h), 1)

        h = self.dropout(h, training=self.training)

        for enc_layer in self.enc_layers:
            h = enc_layer(h)

        # The class token sits at sequence index 0.
        action_token_h = h[:, 0]

        return self.policy_head(action_token_h)

# Utility functions

In [None]:
# Batch sizing: each replica receives BATCH_SIZE_PER_REPLICA examples per step.
BATCH_SIZE_PER_REPLICA = 128
GLOBAL_BATCH_SIZE = strategy.num_replicas_in_sync * BATCH_SIZE_PER_REPLICA

# Steps per epoch.
# NOTE(review): 100000 * 24 is presumably examples-per-file times number of files - confirm.
train_steps = (100_000 * 24) // GLOBAL_BATCH_SIZE

BUFFER_SIZE = 10_000  # shuffle buffer size used by get_dataset
TRAIN_IMAGE_MODEL = True

prefetch = 50  # number of batches get_dataset prefetches

def get_dataset(_):
    """Build the endless, shuffled, per-replica-batched training dataset.

    Args:
        _: Argument supplied by the distribution strategy when it invokes this
            function; it is not used.

    Returns:
        A `tf.data.Dataset` yielding `(sentence, positions, label)` batches of
        exactly `BATCH_SIZE_PER_REPLICA` examples.
    """
    # Reference the constant directly rather than a module-level alias, so the
    # function does not depend on the order in which notebook cells were run.
    autotune = tf.data.experimental.AUTOTUNE

    records = tf.data.TFRecordDataset(TRAIN_TFREC_PATHS, num_parallel_reads=autotune)
    dataset = (
        records
        .repeat()
        .shuffle(BUFFER_SIZE)
        .map(_parse_function, num_parallel_calls=autotune)
        # The stream is infinite, so drop_remainder discards nothing; it only
        # makes the batch dimension static, which TPU/XLA compilation prefers.
        .batch(BATCH_SIZE_PER_REPLICA, drop_remainder=True)
        .prefetch(prefetch)
    )
    return dataset

In [None]:
with strategy.scope():
    optimizer = tf.keras.optimizers.Adam(0.000025)
    # Reduction is disabled so the loss stays per-example; loss_function does
    # the averaging against the global batch size itself.
    crossentropy = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True,
        reduction=tf.keras.losses.Reduction.NONE,
    )

    def loss_function(real, pred):
        """Per-example cross-entropy averaged over GLOBAL_BATCH_SIZE."""
        per_example = crossentropy(real, pred)
        return tf.nn.compute_average_loss(per_example, global_batch_size=GLOBAL_BATCH_SIZE)

    def accuracy_function(real, pred):
        """Fraction of rows whose argmax over axis 1 of `pred` equals `real`."""
        predicted = tf.argmax(pred, axis=1)
        hits = tf.cast(tf.equal(real, predicted), dtype=tf.float32)
        return tf.math.reduce_mean(hits)

In [None]:
@tf.function(experimental_relax_shapes=True)
def train_step(inputs):
    """Run one optimization step on a single replica's batch.

    Args:
        inputs: Tuple `(sentence, positions, target)`.

    Returns:
        Tuple `(loss, accuracy)` for this batch, as computed by
        `loss_function` and `accuracy_function`.
    """
    sentence, positions, target = inputs

    # Record the forward pass so the loss can be differentiated.
    with tf.GradientTape() as tape:
        logits = model(sentence, positions)
        loss = loss_function(target, logits)

    variables = model.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))

    return loss, accuracy_function(target, logits)

@tf.function(experimental_relax_shapes=True)
def distributed_train_step(inputs,):
    """Run `train_step` on every replica and aggregate the results.

    Args:
        inputs: One distributed batch, as yielded by the distributed dataset
            iterator.

    Returns:
        Tuple `(loss, accuracy)` of scalar tensors aggregated across replicas.
    """
    per_replica_losses, per_replica_accuracy = strategy.run(train_step, args=(inputs,))
    # Each per-replica loss is already divided by the GLOBAL batch size
    # (tf.nn.compute_average_loss), so the replicas must be SUMMED to obtain
    # the mean loss of the global batch; averaging them would under-report
    # the loss by a factor of num_replicas_in_sync.
    loss = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
    # Accuracy is a plain per-replica mean, so averaging over replicas is correct.
    accuracy = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_accuracy, axis=None)
    return loss, accuracy

In [None]:
# Alias for tf.data's AUTOTUNE constant; get_dataset passes it as the
# parallelism setting of its read and map stages.
AUTO = tf.data.experimental.AUTOTUNE

# Training

In [None]:
import time
EPOCHS = 2
with strategy.scope():
    # The model must be created under the strategy scope so that its
    # variables are placed according to the distribution strategy.
    model = Net(num_layers=6, d_model=256, num_heads=8, dff=2048)

    # `distribute_datasets_from_function` is the stable name of the deprecated
    # `experimental_distribute_datasets_from_function`.
    dataset = strategy.distribute_datasets_from_function(get_dataset)
    train_iterator = iter(dataset)

    # Mean loss of every finished epoch, kept for plotting later.
    loss_plot = []

    for epoch in range(EPOCHS):
        start = time.time()
        total_loss = 0.0
        print(f'Epoch {epoch + 1}/{EPOCHS}')
        for batch in range(train_steps):
            inputs = next(train_iterator)
            loss, accuracy = distributed_train_step(inputs)
            loss = loss.numpy()
            accuracy = accuracy.numpy()
            total_loss += loss

            # Progress line, rewritten in place after every step ('\r').
            print(f'{batch+1}/{train_steps} Total Loss: {total_loss/(batch+1):.4f}  Batch Loss: {loss:.4f}  Batch Acc: {accuracy:.4f}',end='\r')

        epoch_loss = total_loss / train_steps
        loss_plot.append(epoch_loss)

        print(f'Epoch {epoch + 1} Loss {epoch_loss:.6f}')
        print(f'Time taken for 1 epoch {time.time() - start} sec\n')

# Saving model locally and loading it

In [None]:
with strategy.scope():    
    # Ask for the write to be performed by the local host ('/job:localhost').
    save_locally = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
    # Only the weights are written, to './model.h5'.
    # NOTE(review): the '.h5' suffix presumably selects Keras' HDF5 weights
    # format rather than the TensorFlow "SavedModel" format, in which case
    # `options` may have no effect - confirm.
    model.save_weights('./model.h5', options=save_locally)

In [None]:
# Inference copy of the network: training=False, so dropout is disabled.
model_inf = Net(num_layers=6, d_model=256, num_heads=8, dff=2048, training=False)

# One forward pass on an all-zero (1, 77) batch so the layers create their
# weights before a checkpoint is loaded into them.
_dummy_sentence = tf.zeros((1, 77), dtype=tf.int64)
_dummy_positions = tf.zeros((1, 77), dtype=tf.int64)
model_inf(_dummy_sentence, _dummy_positions)

In [None]:
# Restore the weights written to './model.h5' by the save cell into the
# inference model.
model_inf.load_weights('./model.h5')

# Testing !

In [None]:
import numpy as np
from kaggle_environments.envs.hungry_geese.hungry_geese import Action
def preprocess_map_obs(obs, previous_obs=None, p=None):
    """Encode an observation as token ids plus head-relative positions.

    Every goose segment and every food item becomes one token. Its position
    is the board cell index shifted (modulo 77) so that the head of goose `p`
    sits at position 0. Both sequences are padded to length 77: `sentence`
    with token 0, `positions` with the still-unused positions in ascending
    order.

    Args:
        obs: Per-agent observation list; `obs[0]['observation']` holds the
            'geese' and 'food' cell lists and `obs[i]['action']` the name of
            agent i's last action.
        previous_obs: Unused; kept for interface compatibility.
        p: Index of the goose whose point of view is encoded (defaults to 0).

    Returns:
        Tuple `(sentence, positions)` of two lists of 77 ints.

    Raises:
        IndexError: If goose `p` has no segments.
    """
    if p is None:
        p = 0

    board_size = 77
    geese = obs[0]['observation']['geese']
    food = obs[0]['observation']['food']

    # All positions are expressed relative to the head of goose `p`.
    center = geese[p][0]

    sentence = []
    positions = []
    for real_player_index, agent in enumerate(obs):
        goose = geese[real_player_index]
        if not goose:
            # A goose without segments contributes no tokens.
            continue

        # Player slot as seen from goose `p` (p itself is slot 0).
        player_index = (real_player_index - p) % 4
        index_player = 3 * 4 * player_index
        # Constant for the whole goose, so look it up once.
        last_action = Action[agent['action']].value
        tail = len(goose) - 1

        for segment, board_position in enumerate(goose):
            # 0 = head, 2 = tail, 1 = anything in between; a single-segment
            # goose counts as a head.
            if segment == 0:
                body_part = 0
            elif segment == tail:
                body_part = 2
            else:
                body_part = 1

            sentence.append(index_player + 4 * body_part + last_action + 1)
            positions.append((board_position - center) % board_size)

    # NOTE(review): food uses token 49. If Action values run from 1 to 4, the
    # tail token of player slot 3 with action value 4 is also 49 - confirm
    # against the part-1 preprocessing before changing either.
    food_token = 47 + 1 + 1
    for board_position in food:
        sentence.append(food_token)
        positions.append((board_position - center) % board_size)

    # Pad deterministically: unused positions in ascending order (set
    # iteration order is an implementation detail), token 0 for the words.
    used = set(positions)
    positions.extend(cell for cell in range(board_size) if cell not in used)
    sentence.extend([0] * (board_size - len(sentence)))

    return sentence, positions

In [None]:
from kaggle_environments import make
env = make("hungry_geese")
obs = env.reset(4)

# Maps the index of the largest logit to the action name sent to the environment.
action_mapping = ['NORTH', 'SOUTH', 'WEST', 'EAST']
previous_obs = None
while not env.done:
    actions = []
    for p in range(4):
        if obs[p]["status"] == "DONE":
            # Finished agents still need a placeholder entry in the action list.
            actions.append("NORTH")
            continue
        sentence, positions = preprocess_map_obs(obs, previous_obs=previous_obs, p=p)

        sentence = tf.convert_to_tensor(sentence, dtype=tf.int64)
        positions = tf.convert_to_tensor(positions, dtype=tf.int64)
        # Invoke the model through __call__ rather than .call() so the Keras
        # call machinery is not bypassed; a batch axis of size 1 is added.
        preds = model_inf(tf.expand_dims(sentence, 0), tf.expand_dims(positions, 0))
        pred = tf.math.argmax(preds, 1).numpy()[0]

        actions.append(action_mapping[pred])

    previous_obs = obs
    obs = env.step(actions)

In [None]:
# Display the finished episode inline in the notebook.
env.render(mode="ipython", width=800, height=700)