In [None]:
import numpy as np
import os
import random
from tqdm.notebook import tqdm
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

In [None]:
def seed_everything(seed_value):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch (CPU and,
    when available, all CUDA devices), and puts cuDNN into its reproducible
    configuration.

    Args:
        seed_value: integer seed applied to every generator.
    """
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    # NOTE: hash randomisation is fixed at interpreter start-up, so this only
    # affects child processes spawned later, not the running kernel.
    os.environ['PYTHONHASHSEED'] = str(seed_value)

    if torch.cuda.is_available():
        # Seeds every visible GPU, including the current one.
        torch.cuda.manual_seed_all(seed_value)

    # Reproducible cuDNN setup: force deterministic kernels and turn OFF the
    # benchmark auto-tuner, which may pick a different algorithm on each run.
    # These are plain flags and harmless to set on a CPU-only machine.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Seed all RNGs once, before any dataset or model code in later cells runs.
seed = 42
seed_everything(seed)

In [None]:
class GeeseDataset(Dataset):
    """Dataset of (state, action, reward) samples with random flip augmentation.

    Args:
        states: indexable collection of per-sample NumPy state arrays. Each
            sample is flipped along its axes 1 and 2, so it needs at least
            three dimensions (presumably (channels, rows, cols) -- confirm
            against the code that generated the trajectories).
        actions: indexable collection of per-sample length-4 NumPy score
            vectors (e.g. one-hot). Entries 0/1 are swapped when axis 1 of the
            state is flipped and entries 2/3 when axis 2 is flipped
            (presumably NORTH/SOUTH and WEST/EAST -- confirm the encoding).
        rewards: indexable collection of per-sample rewards, returned as-is.
        augment: when True (the default, matching the original behaviour) each
            of the two flips is applied independently with probability 0.5.
            Pass False to get deterministic samples, e.g. for validation.
    """

    def __init__(self, states, actions, rewards, augment=True):
        self.states = states
        self.actions = actions
        self.rewards = rewards
        self.augment = augment

    def __len__(self):
        """Return the number of samples."""
        return len(self.states)

    def __getitem__(self, idx):
        """Return ``(state, action_index, reward)`` for sample ``idx``.

        ``action_index`` is the argmax of the (possibly permuted) action vector.
        """
        state, action, reward = self.states[idx], self.actions[idx], self.rewards[idx]

        if self.augment:
            # Mirror along axis 2 of the state; the action entries tied to
            # that axis (indices 2 and 3) trade places to stay consistent.
            if torch.rand(1) > 0.5:
                # .copy() turns the negative-stride view into a contiguous array.
                state = np.flip(state, axis=2).copy()
                action = action[[0, 1, 3, 2]]

            # Mirror along axis 1 of the state; action indices 0 and 1 swap.
            if torch.rand(1) > 0.5:
                state = np.flip(state, axis=1).copy()
                action = action[[1, 0, 2, 3]]

        action = action.argmax()

        return state, action, reward

In [None]:
# Neural Network for Hungry Geese
class TorusConv2d(nn.Module):
    """2D convolution with wrap-around (toroidal) padding and optional batch norm.

    Before convolving, the input is extended on each side of its last two
    dimensions with values copied from the opposite edge, so for odd kernel
    sizes the output keeps the input's spatial size.

    Args:
        input_dim: number of input channels.
        output_dim: number of output channels.
        kernel_size: ``(height, width)`` of the kernel, or a single int used
            for both dimensions.
        bn: when truthy, a ``BatchNorm2d`` is applied after the convolution.
    """

    def __init__(self, input_dim, output_dim, kernel_size, bn):
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        # Rows / columns to borrow from the opposite edge on each side.
        self.edge_size = (kernel_size[0] // 2, kernel_size[1] // 2)
        self.conv = nn.Conv2d(input_dim, output_dim, kernel_size=kernel_size)
        self.bn = nn.BatchNorm2d(output_dim) if bn else None

    def forward(self, x):
        """Apply wrap-around padding, the convolution and (optionally) batch norm.

        Args:
            x: 4-D tensor; dims 2 and 3 are the ones wrapped and convolved.

        Returns:
            The convolved (and optionally normalised) tensor.
        """
        pad_rows, pad_cols = self.edge_size
        h = x
        # A zero-width edge must be skipped explicitly: `t[..., -0:]` selects
        # the WHOLE axis rather than nothing, which would duplicate the input
        # along that axis for kernels of size 1.
        if pad_cols > 0:
            h = torch.cat([h[:, :, :, -pad_cols:], h, h[:, :, :, :pad_cols]], dim=3)
        if pad_rows > 0:
            h = torch.cat([h[:, :, -pad_rows:], h, h[:, :, :pad_rows]], dim=2)
        h = self.conv(h)
        if self.bn is not None:
            h = self.bn(h)
        return h

class GeeseNet(nn.Module):
    """Residual stack of torus convolutions with a policy head and a value head.

    The network takes a 17-channel input, runs it through a stem convolution
    and 12 residual 3x3 torus-convolution blocks (32 filters each, all with
    batch norm), then produces 4 policy logits and a tanh-squashed scalar
    value per sample.
    """

    def __init__(self):
        super().__init__()
        num_blocks = 12
        width = 32
        kernel = (3, 3)
        self.conv0 = TorusConv2d(17, width, kernel, True)
        residual_blocks = []
        for _ in range(num_blocks):
            residual_blocks.append(TorusConv2d(width, width, kernel, True))
        self.blocks = nn.ModuleList(residual_blocks)
        self.head_p = nn.Linear(width, 4, bias=False)
        self.head_v = nn.Linear(width * 2, 1, bias=False)

    def forward(self, x):
        """Return ``(policy_logits, value)`` for a batch of states.

        Args:
            x: input batch; its first channel is used as a spatial mask when
               pooling features for the heads (presumably it marks the
               player's head cell -- confirm against the state encoder).

        Returns:
            Tuple of policy logits with 4 entries per sample and a value in
            (-1, 1) per sample (dim 1 squeezed away).
        """
        feat = F.relu_(self.conv0(x))
        for layer in self.blocks:
            # Residual connection around each torus convolution.
            feat = F.relu_(feat + layer(feat))

        batch, channels = feat.size(0), feat.size(1)
        # Features summed over the cells selected by the first input channel.
        masked = feat * x[:, :1]
        head_feat = masked.view(batch, channels, -1).sum(-1)
        # Features averaged over every cell.
        mean_feat = feat.view(batch, channels, -1).mean(-1)

        policy = self.head_p(head_feat)
        value_in = torch.cat([head_feat, mean_feat], 1)
        value = torch.tanh(self.head_v(value_in))
        return policy, value.squeeze(dim=1)

In [None]:
def train_model(num_epochs, data_dir='../input/alphageese-trajectories',
                batch_size=2048, lr=1e-3, device=None):
    """Train a GeeseNet on saved trajectories and checkpoint it periodically.

    Loads ``states.npy``, ``actions.npy`` and ``rewards.npy`` from ``data_dir``,
    optimises the sum of the policy cross-entropy and the value MSE with AdamW,
    prints per-epoch metrics and writes ``alphageese_epoch{N}.pth`` to the
    current directory after every fifth epoch.

    Args:
        num_epochs: number of passes over the data.
        data_dir: directory holding the three ``.npy`` files.
        batch_size: mini-batch size; the last incomplete batch of each epoch
            is dropped.
        lr: AdamW learning rate.
        device: torch device (or device string) to train on. Defaults to CUDA
            when available, otherwise CPU.

    Raises:
        ValueError: if the dataset holds fewer than ``batch_size`` samples, in
            which case every epoch would be empty.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    dataloader = DataLoader(
        GeeseDataset(
            np.load(f'{data_dir}/states.npy'),
            np.load(f'{data_dir}/actions.npy'),
            np.load(f'{data_dir}/rewards.npy'),
        ),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
    )
    if len(dataloader) == 0:
        raise ValueError(
            'Dataset has {} samples, fewer than batch_size={}; with drop_last=True '
            'no batch would ever be produced.'.format(len(dataloader.dataset), batch_size))

    model = GeeseNet().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        model.train()
        epoch_ploss = 0.0
        epoch_vloss = 0.0
        epoch_correct = 0
        # Count what was actually processed: drop_last=True discards the final
        # partial batch, so len(dataset) would over-count and deflate metrics.
        samples_seen = 0

        for states, actions, rewards in tqdm(dataloader, leave=False):
            states = states.to(device).float()
            actions = actions.to(device).long()
            rewards = rewards.to(device).float()

            optimizer.zero_grad()

            policy, value = model(states)
            policy_loss = F.cross_entropy(policy, actions)
            value_loss = F.mse_loss(value, rewards)

            (policy_loss + value_loss).backward()
            optimizer.step()

            batch_len = len(policy)
            # Losses are batch means; weight by batch size to get epoch means.
            epoch_ploss += policy_loss.item() * batch_len
            epoch_vloss += value_loss.item() * batch_len
            epoch_correct += (policy.argmax(dim=1) == actions).sum().item()
            samples_seen += batch_len

        epoch_ploss = epoch_ploss / samples_seen
        epoch_vloss = epoch_vloss / samples_seen
        epoch_acc = epoch_correct / samples_seen

        print('Epoch {}/{} | Policy Loss: {:.4f} | Value Loss: {:.4f} | Acc: {:.4f}'.format(
            epoch + 1, num_epochs, epoch_ploss, epoch_vloss, epoch_acc))
        if (epoch + 1) % 5 == 0:
            torch.save(model.state_dict(), f'alphageese_epoch{epoch + 1}.pth')

In [None]:
# Run the full training: 30 epochs, saving a checkpoint after every fifth epoch.
train_model(num_epochs=30)