<a href="https://colab.research.google.com/github/wty0511/IC-TIR-Lol/blob/master/model_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Interpretable Contextual Team-aware Item Recommendation: Application in Multiplayer Online Battle Arena Games
*Andres Villa, Vladimir Araujo, Francisca Cattan*

# Introduction

This notebook contains the code of the proposed model. It is composed of 11 main stages:

1. Connect to gDrive
2. Dataset and Transformations
3. Model
4. Logger and Checkpointer
5. Metrics
6. Training and evaluation loop
7. Config file
8. Training and evaluation executor
9. Obtain the role and id of each champion in each match
10. Load the attention weights
11. Draw the attention map

This notebook can be run in its entirety. The final cell executes the training and validation of the model.

>[Main Model - Project Title](#scrollTo=uoLSVVIBCwLm)

>[Introduction](#scrollTo=t8_YV_PIDR97)

>[Install all the dependencies](#scrollTo=etkQTYydGkFM)

>[Import the dependencies](#scrollTo=S0YvGjijGxET)

>[Connect to gDrive](#scrollTo=pfDyM4E7G4L2)

>[Dataset and Transformations](#scrollTo=h9MDWroJSkhM)

>[Model](#scrollTo=UIm1_KUCUNB0)

>>[Transformer encoder modified to obtain the attention weights](#scrollTo=qr3TZbrnUg2H)

>>[Auxiliary Task Classes](#scrollTo=pwRy106QU6sH)

>>[Main Class of the proposed model](#scrollTo=9AlU_u42VG8A)

>[Logger and Checkpointer](#scrollTo=rwYoKWcsVqex)

>[Metrics](#scrollTo=5ktMqAUMWeEz)

>[Training and evaluation loop](#scrollTo=WDA0GHysW4vX)

>[Config file](#scrollTo=CyRfaqN8XvYi)

>[Training and evaluation executor](#scrollTo=IVtKoVTcYDS1)

>[T-test](#scrollTo=D2TPs5U3vv7m)

>[Obtain the role and id of each champion in each match](#scrollTo=sFtaUCU5T8fl)

>[Load the attention weights](#scrollTo=VINfHm76U1vz)

>[Draw the attention map](#scrollTo=SvhQCEzcU6x_)



# Install all the dependencies

Install all the libraries necessary to run the model.

In [None]:
!nvidia-smi

Tue Jul 28 06:40:42 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.51.05    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   43C    P0    28W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
!pip install git+git://github.com/williamFalcon/pytorch-lightning.git@master --upgrade

Collecting git+git://github.com/williamFalcon/pytorch-lightning.git@master
  Cloning git://github.com/williamFalcon/pytorch-lightning.git (to revision master) to /tmp/pip-req-build-j84m_1zy
  Running command git clone -q git://github.com/williamFalcon/pytorch-lightning.git /tmp/pip-req-build-j84m_1zy
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
    Preparing wheel metadata ... done
Collecting future>=0.17.1
  Downloading https://files.pythonhosted.org/packages/45/0b/38b06fd9b92dc2b68d58b75f900e97884c45bedd2ff83203d933cf5851c9/future-0.18.2.tar.gz (829kB)
     |████████████████████████████████| 829kB 2.9MB/s 
Collecting PyYAML>=5.1
  Downloading https://files.pythonhosted.org/packages/64/c2/b80047c7ac2478f9501676c988a5411ed5572f35d1beff9cae07d321512c/PyYAML-5.3.1.tar.gz (269kB)
     |████████████████████████████████| 276kB 15.9MB/s 
Building wheels for collected packages: pytorch-lightning
 

In [None]:
!pip install comet_ml==3.0.2

Collecting comet_ml==3.0.2
  Downloading https://files.pythonhosted.org/packages/99/c6/fac88f43f2aa61a09fee4ffb769c73fe93fe7de75764246e70967d31da09/comet_ml-3.0.2-py3-none-any.whl (170kB)
     |████████████████████████████████| 174kB 2.9MB/s 
Collecting websocket-client>=0.55.0
  Downloading https://files.pythonhosted.org/packages/4c/5f/f61b420143ed1c8dc69f9eaec5ff1ac36109d52c80de49d66e0c36c3dfdf/websocket_client-0.57.0-py2.py3-none-any.whl (200kB)
     |████████████████████████████████| 204kB 8.9MB/s 
Collecting everett[ini]>=1.0.1; python_version >= "3.0"
  Downloading https://files.pythonhosted.org/packages/12/34/de70a3d913411e40ce84966f085b5da0c6df741e28c86721114dd290aaa0/everett-1.0.2-py2.py3-none-any.whl
Collecting wurlitzer>=1.0.2
  Downloading https://files.pythonhosted.org/packages/0c/1e/52f4effa64a447c4ec0fb71222799e2ac32c55b4b6c1725fccdf6123146e/wurlitzer-2.0.1-py2.py3-none-any.whl
Collecting comet-git-pure>=0.19.11
  Downloading https://f

In [None]:
!pip install omegaconf

Collecting omegaconf
  Downloading https://files.pythonhosted.org/packages/3d/95/ebd73361f9c6e94bd0f3b19ffe31c24e833834c022f1c0328ac71b2d6c90/omegaconf-2.0.0-py3-none-any.whl
Installing collected packages: omegaconf
Successfully installed omegaconf-2.0.0


In [None]:
!pip install adabound

Collecting adabound
  Downloading https://files.pythonhosted.org/packages/cd/44/0c2c414effb3d9750d780b230dbb67ea48ddc5d9a6d7a9b7e6fcc6bdcff9/adabound-0.0.5-py3-none-any.whl
Installing collected packages: adabound
Successfully installed adabound-0.0.5


In [None]:
!pip install ml_metrics

Collecting ml_metrics
  Downloading https://files.pythonhosted.org/packages/c1/e7/c31a2dd37045a0c904bee31c2dbed903d4f125a6ce980b91bae0c961abb8/ml_metrics-0.1.4.tar.gz
Building wheels for collected packages: ml-metrics
  Building wheel for ml-metrics (setup.py) ... done
  Created wheel for ml-metrics: filename=ml_metrics-0.1.4-cp36-none-any.whl size=7850 sha256=4c96be29d4d35a67a4b7cfad5648a20405ba81588e3a54aafe24d966ed354a94
  Stored in directory: /root/.cache/pip/wheels/b3/61/2d/776be7b8a4f14c5db48c8e5451451cabc58dc6aa7ee3801163
Successfully built ml-metrics
Installing collected packages: ml-metrics
Successfully installed ml-metrics-0.1.4


# Import the dependencies

Import all the libraries necessary to run the model.

In [None]:
from comet_ml import Experiment as CometExperiment
from comet_ml import ExistingExperiment as CometExistingExperiment
from google.colab import drive
import torch
import copy
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
import pandas as pd
import time

# from tqdm.notebook import trange, tqdm
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.logging import LightningLoggerBase
from pytorch_lightning.loggers import CometLogger

import os
import pytorch_lightning as pl
import pickle
import adabound
import ml_metrics as metrics
import random
import itertools
from torchvision import transforms


# Connect to gDrive

Mount Google Drive in the notebook. Drive is used to load and save data such as the dataset, checkpoints, and attention weights.

In [None]:
drive.mount('/content/gdrive/')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/gdrive/


# Dataset and Transformations

This section loads the k partitions obtained with k-fold cross-validation.

In [None]:
train_path = '/content/gdrive/My Drive/Proyecto_RecSys/dataset/train_splits.pkl'
test_path = '/content/gdrive/My Drive/Proyecto_RecSys/dataset/test_splits.pkl'
champion_path = '/content/gdrive/My Drive/Proyecto_RecSys/dataset/champion_types.pkl'

In [None]:
#@title Load the partition lists
with open(train_path, 'rb') as handle:
    list_trainset = pickle.load(handle)

with open(test_path, 'rb') as handle:
    list_testset = pickle.load(handle)

with open(champion_path, 'rb') as handle:
    champion_types = pickle.load(handle)

In [None]:
def get_partition(id_split, list_splits = list_trainset):
    df = list_splits[id_split]
    null_registers = df.loc[(df.item1 == 0) & (df.item2 == 0) & (df.item3 == 0) & (df.item4 == 0) & (df.item5 == 0) & (df.item6 == 0)]
    match_to_del = list(set(null_registers['matchid']))
    df = df[~df.matchid.isin(match_to_del)]
    return df
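
As a rough illustration of the filtering rule in `get_partition`, the sketch below applies it to a toy table in plain Python: any match that contains at least one participant with all six item slots empty is dropped entirely. The toy rows, the single `items` list per row (instead of the `item1` to `item6` columns), and the helper name `drop_matches_with_empty_builds` are assumptions made for illustration.

```python
# Toy rows mimicking the dataset (values are made up for illustration).
rows = [
    {"matchid": 1, "items": [3006, 0, 0, 0, 0, 0]},
    {"matchid": 1, "items": [0, 0, 0, 0, 0, 0]},      # empty build, so match 1 is dropped
    {"matchid": 2, "items": [1001, 3031, 0, 0, 0, 0]},
    {"matchid": 2, "items": [3047, 0, 0, 0, 0, 0]},
]


def drop_matches_with_empty_builds(rows):
    """Remove every row of a match in which some participant has no items."""
    bad_matches = {r["matchid"] for r in rows if all(i == 0 for i in r["items"])}
    return [r for r in rows if r["matchid"] not in bad_matches]


kept = drop_matches_with_empty_builds(rows)
print(sorted({r["matchid"] for r in kept}))
```

Dropping the whole match, rather than only the offending participant, keeps every remaining match at its full set of participants.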

These transformations swap the order of the two teams and randomly shuffle the champions within each team.

In [None]:
class RandomSort_Team(object):
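    """Swap the order of the two teams in a sample.

    The five participants of team 2 are moved in front of those of team 1,
    and the `win` label is flipped accordingly. `__call__` returns a batch
    that stacks each incoming sample together with its swapped copy.
    """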
    
    def get_random_sample(self, sample):
        x, y = sample

        ids_teams_1 = [x for x in range(5)]
        ids_teams_2 = [x for x in range(5,10)]

        ids_team_t = [ids_teams_1, ids_teams_2]

        ids_teams = [1, 0]
        #ids_teams = [x for x in range(2)]
        #random.shuffle(ids_teams)

        ids_team_t = [ids_team_t[i] for i in ids_teams]
        
        ids_team_t = list(itertools.chain.from_iterable(ids_team_t))

        x['champions'] = x['champions'][ids_team_t]
        x['role'] = x['role'][ids_team_t]
        x['type'] = x['type'][ids_team_t,:]

        y['items'] = y['items'][ids_team_t,:]

        if ids_teams == [1, 0]:
            y['win'] = torch.tensor(1) - y['win']
        
        return x, y

    def __call__(self, sample_list):
        list_x_champions = []
        list_x_role = []
        list_x_type = []
        list_y_items = []
        list_y_win = []
        x_old, y_old = sample_list
        if isinstance(x_old, (list)) and isinstance(y_old, (list)):
            for i in range(len(x_old)):
                list_x_champions.append(x_old[i]['champions'])
                list_x_role.append(x_old[i]['role'])
                list_x_type.append(x_old[i]['type'])
                list_y_items.append(y_old[i]['items'])
                list_y_win.append(y_old[i]['win'])
                sample = x_old[i], y_old[i]
                x, y = self.get_random_sample(sample)
                list_x_champions.append(x['champions'])
                list_x_role.append(x['role'])
                list_x_type.append(x['type'])
                list_y_items.append(y['items'])
                list_y_win.append(y['win'])
        else:
            list_x_champions.append(x_old['champions'])
            list_x_role.append(x_old['role'])
            list_x_type.append(x_old['type'])
            list_y_items.append(y_old['items'])
            list_y_win.append(y_old['win'])
            sample = x_old, y_old
            x, y = self.get_random_sample(sample)
            list_x_champions.append(x['champions'])
            list_x_role.append(x['role'])
            list_x_type.append(x['type'])
            list_y_items.append(y['items'])
            list_y_win.append(y['win'])
        new_x = {
            'champions': torch.stack(list_x_champions, dim=0),
            'role': torch.stack(list_x_role, dim=0),
            'type': torch.stack(list_x_type, dim=0)
        }
        new_y = {
            'items': torch.stack(list_y_items, dim=0),
            'win': torch.stack(list_y_win, dim=0)
        }
        return new_x, new_y


In [None]:
class RandomSort_Part(object):
    """Randomly shuffle the champions within each team of a sample.

    Returns two lists (inputs and targets) with two entries each, intended
    to hold the original sample and its shuffled copy.
    """
        

    def __call__(self, sample):

        list_t_x = []
        list_t_y = []
        x, y = sample

        list_t_x.append(x)
        list_t_y.append(y)

        ids_team_1 = [x for x in range(5)]
        ids_team_2 = [x for x in range(5,10)]
        random.shuffle(ids_team_1)
        random.shuffle(ids_team_2)

        ids_match = ids_team_1
        ids_match.extend(ids_team_2)
        
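        # Build new dicts instead of mutating x and y in place; otherwise the
        # "original" entries appended above would alias the shuffled ones.
        x_shuffled = {
            'champions': x['champions'][ids_match],
            'role': x['role'][ids_match],
            'type': x['type'][ids_match,:]
        }
        y_shuffled = {
            'items': y['items'][ids_match,:],
            'win': y['win']
        }

        list_t_x.append(x_shuffled)
        list_t_y.append(y_shuffled)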

        return list_t_x, list_t_y
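
The index manipulation behind the two transformations can be sketched without tensors. This is a minimal illustration under the assumption that positions 0 to 4 belong to team 1 and positions 5 to 9 to team 2; the helper names `swap_teams` and `shuffle_within_teams` are hypothetical and do not appear in the notebook.

```python
import random


def swap_teams(order):
    """Move the second team (positions 5-9) in front of the first (0-4)."""
    return order[5:] + order[:5]


def shuffle_within_teams(order, rng):
    """Shuffle each team's five positions independently."""
    team_1, team_2 = order[:5], order[5:]   # slices are copies, the input stays intact
    rng.shuffle(team_1)
    rng.shuffle(team_2)
    return team_1 + team_2


rng = random.Random(0)            # seeded only to make the demo repeatable
positions = list(range(10))
swapped = swap_teams(positions)
shuffled = shuffle_within_teams(positions, rng)

win = 0                           # index of the winning team before the swap
win_after_swap = 1 - win          # the label flips when the teams trade places
print(swapped, shuffled, win_after_swap)
```

Indexing a tensor with such a list of positions, as the classes above do, reorders the participants while keeping each one's champion, role, type, and items aligned.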

In [None]:
class LolDataset(Dataset):
  def __init__(self, data, transform=None):
  # load the dataset
    #self.matches = self._load_matches(path)
    self.matches = data
    # check whether the .pkl with the dictionaries exists

    # else:
    # extract information from the dataframe
    self.champions = set(self.matches['championid'])
    self.roles = set(self.matches['position-role'])
    self.matches_id = list(set(self.matches['matchid']))
    # Series.append returns a new Series rather than modifying in place,
    # so gather the six item columns with pd.concat instead
    item_columns = ['item1', 'item2', 'item3', 'item4', 'item5', 'item6']
    self.items = pd.concat([self.matches[col] for col in item_columns])
    items = set(self.items)
    self.items = {i for i in items if i != 0}
    self.champion_types = champion_types
    list_champion_types = []
    for k,v in champion_types.items():
      list_champion_types.extend(v)
  
    self.set_champ_type = set(list_champion_types)

    # create the token2id and id2token dictionaries
    self.champions_token2id, self.champions_id2token = self._token_dict(self.champions)
    self.roles_token2id, self.roles_id2token = self._token_dict(self.roles)
    self.items_token2id, self.items_id2token = self._token_dict(self.items)
    self.types_token2id, self.types_id2token = self._token_dict(self.set_champ_type)

    self.transform = transform

  def _load_matches(self, path):
    data_matches = pd.read_csv(path) 
    return data_matches

  def _token_dict(self, data):
    token2id = {}
    id2token = {}
    for i, j in enumerate(data):
      token2id.update({j:i})
      id2token.update({i:j})

    return token2id, id2token

  def _tokens2ids(self, token2id, tokens):
    ids = []
    for token in tokens:
      ids.append(token2id[token])
      
    return ids

  def _tokens2ids_items(self, token2id, tokens):
    #items_vecs = []
    item_vec = np.zeros((len(token2id)))
    for token in tokens:
      if token in token2id: 
        item_vec[token2id[token]] = 1
      #items_vecs.append(item_vec)
      
    return item_vec

  def _build_dict(self, match):
    # get the champions of the match, in order
    champion_tokens = list(match['championid'])
    champions_ids = self._tokens2ids(self.champions_token2id, champion_tokens)

    # get the items of the match, in order
    #items_tokens = match['championid']
    #items_ids = self._tokens2ids(self.items_token2id, items_tokens)
    # get the roles of the match, in order
    role_tokens = list(match['position-role'])
    role_ids = self._tokens2ids(self.roles_token2id, role_tokens)
    list_win = list(match['win'])[4:6]
    
    list_win = np.array(list_win)
    num_win = np.argsort(list_win)
    num_win = num_win[len(num_win)-1]

    list_part_items = []
    list_types = []
    items_list = ['item1','item2','item3','item4','item5','item6']
    for id_champ in champion_tokens:
      champ_atr = match[match.championid == id_champ]
      items = champ_atr[items_list]
      items_tokens = list(items.iloc[0, :])
      items_ids = self._tokens2ids_items(self.items_token2id, items_tokens)
      list_part_items.append(items_ids)

      type_champ = self.champion_types[id_champ]
      type_ids = self._tokens2ids(self.types_token2id, type_champ)
      list_types.append(type_ids)

    # build five 0s followed by five 1s
    #team_ids = 
    x = {
        'champions': torch.from_numpy(np.array(champions_ids)),
        'role': torch.from_numpy(np.array(role_ids)),
        'type': torch.from_numpy(np.array(list_types))
    }
    y= {
        'items': torch.from_numpy(np.array(list_part_items)),
        'win': torch.from_numpy(np.array(num_win))
    }
    
    return x, y

  def __getitem__(self, idx): 
    # idx identifies the match in this case
    # the function should return the information of each match
    # look up the match for idx in the data and return the dictionaries with its attributes
    id_match = self.matches_id[idx]
    match = self.matches[(self.matches.matchid == id_match)]
    x, y = self._build_dict(match) # built from the dataframe rows of the match selected by idx
    if self.transform:
        sample = x, y
        x, y = self.transform(sample)
    return x, y # the item itself: the match with all its features

  def __len__(self):
    return len(self.matches_id)
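
The two encodings used by `LolDataset` (a token-to-id vocabulary and a multi-hot item vector per participant) can be sketched in plain Python. The item ids below are made up, and sorting the tokens before numbering them is an assumption added here to make the ids reproducible; the notebook enumerates a set directly.

```python
def build_vocab(tokens):
    """Map each distinct token to a contiguous id and back."""
    token2id = {tok: i for i, tok in enumerate(sorted(set(tokens)))}
    id2token = {i: tok for tok, i in token2id.items()}
    return token2id, id2token


def multi_hot(tokens, token2id):
    """Return a 0/1 vector with ones at the ids of the known tokens."""
    vec = [0] * len(token2id)
    for tok in tokens:
        if tok in token2id:       # 0 (an empty slot) is not in the vocabulary
            vec[token2id[tok]] = 1
    return vec


item_ids = [3006, 1001, 3031, 3047]       # made-up item ids
token2id, id2token = build_vocab(item_ids)
build = [3031, 1001, 0, 0, 0, 0]          # six slots, four of them empty
target = multi_hot(build, token2id)
print(target)
```

Each participant's target is therefore a fixed-length vector over the item vocabulary, which is what the recommender's per-item logits are compared against.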


# Model

## Transformer encoder modified to obtain the attention weights

In [None]:
class TransformerEncoder(nn.Module):
    """TransformerEncoder is a stack of N encoder layers

    Args:
        encoder_layer: an instance of the TransformerEncoderLayer() class (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
        norm: the layer normalization component (optional).

    Examples::
        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
        >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(10, 32, 512)
        >>> out, att_weights = transformer_encoder(src)
    """

    def __init__(self, encoder_layer, num_layers, norm=None):
        super(TransformerEncoder, self).__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src, mask=None, src_key_padding_mask=None):
        """Pass the input through the encoder layers in turn.

        Args:
            src: the sequence to the encoder (required).
            mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        output = src
        att_weights = []

        for i in range(self.num_layers):
            output, attn_output_weights = self.layers[i](output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
            
            att_weights.append(attn_output_weights)

        if self.norm is not None:
            output = self.norm(output)

        return output, att_weights



In [None]:
def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


def _get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu
    else:
        raise RuntimeError("activation should be relu/gelu, not %s." % activation)

In [None]:
class TransformerEncoderLayer(nn.Module):
    """TransformerEncoderLayer is made up of self-attn and feedforward network.
    This standard encoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of intermediate layer, relu or gelu (default=relu).

    Examples::
        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out, attn_output_weights = encoder_layer(src)
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        src2, attn_output_weights = self.self_attn(src, src, src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        if hasattr(self, "activation"):
            src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        else:  # for backward compatibility
            src2 = self.linear2(self.dropout(F.relu(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src, attn_output_weights
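
To give a sense of what the returned attention weights represent, here is a minimal single-head sketch of scaled dot-product attention in plain Python. It is a simplification: `nn.MultiheadAttention` also applies learned projections and uses several heads, which are omitted here, and the embeddings are made-up values.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of scores."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention_weights(queries, keys):
    """Single-head scaled dot-product attention weights, one row per query."""
    d = len(keys[0])
    scores = [
        [sum(q_i * k_i for q_i, k_i in zip(q, k)) / math.sqrt(d) for k in keys]
        for q in queries
    ]
    return [softmax(row) for row in scores]


# Three toy "participants" with 2-dimensional embeddings (made-up values).
emb = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
weights = attention_weights(emb, emb)
for row in weights:
    print([round(w, 3) for w in row])
```

Each row sums to 1 and says how much one participant attends to every participant in the match, which is the quantity the attention maps later in the notebook visualize.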

## Auxiliary Task Classes

In [None]:
def getItems(gt_items, table_emb, num_items, emb_dim):
    list_match = []
    device = gt_items.device
    for i in range(gt_items.size(0)):
        match = gt_items[i,:,:]
        list_part_item = []
        for j in range(gt_items.size(1)):
            participant_items = match[j,:]
            sum_k = torch.sum(participant_items, dim = 0).item()
            if int(sum_k) > 0:
                _, pos_items = torch.topk(participant_items, k = int(sum_k), dim = 0)
                items_emb = table_emb(pos_items)
                items_emb = torch.mean(items_emb, dim = 0)
                list_part_item.append(items_emb)
            else:
                list_part_item.append(torch.zeros(emb_dim).to(device))
        team_item_emb = torch.stack(list_part_item)
        list_match.append(team_item_emb)
    return torch.stack(list_match)
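
The pooling that `getItems` performs for one participant can be sketched as follows: the embeddings of the items flagged in the multi-hot vector are averaged, and a zero vector is used when no item is flagged. The embedding table values and the helper name `mean_item_embedding` are assumptions for illustration.

```python
def mean_item_embedding(multi_hot, table):
    """Average the embedding rows of the flagged items (zeros if none is flagged)."""
    idx = [i for i, flag in enumerate(multi_hot) if flag]
    dim = len(table[0])
    if not idx:
        return [0.0] * dim
    return [sum(table[i][d] for i in idx) / len(idx) for d in range(dim)]


# Made-up embedding table: three items, two dimensions each.
table = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
pooled = mean_item_embedding([1, 0, 1], table)
empty = mean_item_embedding([0, 0, 0], table)
print(pooled, empty)
```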

In [None]:
class WinEncoder(nn.Module):
    def __init__(self, model_dim, n_items):
        super(WinEncoder, self).__init__()
        self.proj_win = nn.Linear(4*model_dim, 2)
        self.embeddings_table_items = nn.Embedding(num_embeddings = n_items, embedding_dim = model_dim)
        self.n_items = n_items
        self.model_dim = model_dim
        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        self.proj_win.bias.data.zero_()
        self.proj_win.weight.data.uniform_(-initrange, initrange)
        self.embeddings_table_items.weight.data.uniform_(-initrange, initrange)

    def forward(self, att_match, item_list):
        # att_match size (Batch, Seq, Emb)
        # item_list size (Batch, Seq, Num_items), multi-hot over the item vocabulary
        att_item_team_1, att_item_team_2 = torch.chunk(att_match, 2, dim=1)
        items_team_1, items_team_2 = torch.chunk(item_list, 2, dim=1)

        items_team_1 = getItems(items_team_1, self.embeddings_table_items, self.n_items,self.model_dim)
        items_team_2 = getItems(items_team_2, self.embeddings_table_items, self.n_items,self.model_dim)

        att_item_team_1 = torch.mean(att_item_team_1, dim=1)
        att_item_team_1 = F.relu(att_item_team_1)
        att_item_team_1 = (att_item_team_1 / att_item_team_1.max())
        items_team_1 = torch.mean(items_team_1, dim=1)
        items_team_1 = F.relu(items_team_1)
        items_team_1 = (items_team_1 / items_team_1.max())

        att_item_team_2 = torch.mean(att_item_team_2, dim=1)
        att_item_team_2 = F.relu(att_item_team_2)
        att_item_team_2 = (att_item_team_2 / att_item_team_2.max())
        items_team_2 = torch.mean(items_team_2, dim=1)
        items_team_2 = F.relu(items_team_2)
        items_team_2 = (items_team_2 / items_team_2.max())

        att_item_team_1 = torch.cat((att_item_team_1, items_team_1), 1)
        att_item_team_2 = torch.cat((att_item_team_2, items_team_2), 1)
        proj_win_team = torch.cat((att_item_team_1, att_item_team_2), 1)
        win_emb = self.proj_win(F.relu(proj_win_team))

        return win_emb

In [None]:
def getTensorPredItem(items_logits):
  pred_items = torch.zeros(items_logits.size())
  for i in range(items_logits.size(0)):
    for j in range(items_logits.size(1)):
      _,pos_items = torch.topk(items_logits[i,j,:],k = 6,dim=0)
      pred_items[i,j,pos_items] = 1
  return pred_items
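
The conversion from logits to a predicted build, as done by `getTensorPredItem` for every participant, amounts to marking the k highest-scoring items. A plain-Python sketch for a single participant, with made-up scores and a hypothetical helper name `top_k_multi_hot`:

```python
def top_k_multi_hot(scores, k=6):
    """Return a 0/1 list marking the k highest-scoring positions."""
    top = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    chosen = set(top)
    return [1 if i in chosen else 0 for i in range(len(scores))]


logits = [0.1, 2.3, -1.0, 0.7, 1.5]   # made-up scores for five items
picked = top_k_multi_hot(logits, k=2)
print(picked)
```

The notebook uses k = 6, matching the six item slots of a participant.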

## Main Class of the proposed model

In [None]:
class TransformerLolRecommender(nn.Module):

    def __init__(self, n_role, n_champions, embeddings_size, nhead, n_items, n_type, nlayers = 1, nhid = 2048, dropout=0.5, aux_task = False, 
                 learnable_team_emb = False):
        super(TransformerLolRecommender, self).__init__()

        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        
        self.embeddings_table_role = nn.Embedding(num_embeddings = n_role, embedding_dim = embeddings_size)
        
        self.embeddings_table_champion = nn.Embedding(num_embeddings = n_champions, embedding_dim = embeddings_size)

        self.embeddings_table_type = nn.Embedding(num_embeddings = n_type, embedding_dim = embeddings_size, padding_idx=0)
        
        self.learnable_team_emb = learnable_team_emb
        if learnable_team_emb:
            self.team_encoder = nn.Embedding(num_embeddings = 2, embedding_dim = embeddings_size)
        else:
            self.team_encoder = self.get_team_encoding(embeddings_size, 10)
        
        encoder_layers = TransformerEncoderLayer(embeddings_size, nhead, nhid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        
        self.recommender = nn.Linear(embeddings_size, n_items)
        self.pred_champ = nn.Linear(embeddings_size, n_champions)

        self.aux_task = aux_task

        if self.aux_task: 
            self.win_encoder = WinEncoder(embeddings_size, n_items)

        self.init_weights()
    
    def get_learnable_team_emb(self, num_batch):
        emb_team_0 = self.team_encoder(torch.LongTensor([0]).to(self.device))
        emb_team_0 = emb_team_0.expand(5, emb_team_0.size(1))
        emb_team_1 = self.team_encoder(torch.LongTensor([1]).to(self.device))
        emb_team_1 = emb_team_1.expand(5, emb_team_1.size(1))
        emb_team = torch.cat([emb_team_0, emb_team_1], dim = 0)
        emb_team = emb_team.unsqueeze(0).expand(num_batch, emb_team.size(0), emb_team.size(1))
        return emb_team

    
    def get_team_encoding(self, embedding_dim, num_champions = 10):
        team_encoding = torch.zeros(num_champions, embedding_dim)
        team_encoding[5:,:] = 1
        return team_encoding.to(self.device)

    def init_weights(self):
        initrange = 0.1
        
        self.embeddings_table_role.weight.data.uniform_(-initrange, initrange)
        self.embeddings_table_champion.weight.data.uniform_(-initrange, initrange)
        self.embeddings_table_type.weight.data.uniform_(-initrange, initrange)
        
        self.recommender.bias.data.zero_()
        self.recommender.weight.data.uniform_(-initrange, initrange)

        self.pred_champ.bias.data.zero_()
        self.pred_champ.weight.data.uniform_(-initrange, initrange)

        if self.learnable_team_emb:
            self.team_encoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, role, champion_id, types, items, win, enable_teacher_f):

        role_participants = self.embeddings_table_role(role)
        id_participants = self.embeddings_table_champion(champion_id)
        type_champ = self.embeddings_table_type(types)
        type_champ = torch.sum(type_champ, dim =2)
        batch_size = role_participants.size(0)
        if self.learnable_team_emb:
            team_participants = self.get_learnable_team_emb(batch_size)
        else:
            size_team_emb = self.team_encoder.size()
            team_participants = self.team_encoder.unsqueeze(0).expand(batch_size, size_team_emb[0], size_team_emb[1])

        sel_champions = []
        pos_champions = []
        for i in range(win.size(0)):
            id_el = random.randint(0,4)
            pos_champions.append(id_el)
            if win[i] != 0:
                id_el = id_el + 5
            sel_champion = champion_id[i,id_el]
            id_participants[i,id_el,:] = 0
            sel_champions.append(sel_champion)

        sel_champions = torch.stack(sel_champions)
        # pos_champions = torch.stack(pos_champions)

        participants = role_participants + id_participants + team_participants + type_champ
        # size (Seq, Batch, Emb)
        participants = participants.permute(1,0,2)
        # size (Seq, Batch, Emb)
        output, att_weights = self.transformer_encoder(participants)
        # size (Batch, Seq, Emb)
        output = output.permute(1,0,2)
        logits_items = self.recommender(output)

        output_obj = {
            'logits_items': logits_items,
            'att_weights': att_weights,
            'outputs': output,
            'sel_champions': sel_champions,
            'pos_champions': pos_champions
        }

        if self.aux_task:
            if enable_teacher_f: 
                items_used = items
            else:
                items_used = getTensorPredItem(logits_items).to(self.device)
            logits_win = self.win_encoder(output, items_used)
            output_obj['logits_win'] = logits_win

        return output_obj
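
In `forward`, one champion embedding per match is zeroed out before the encoder runs, and the slot appears to be drawn from the winning team (`win` is 0 for the first team and 1 for the second). The selection logic can be sketched on its own; the helper name `pick_masked_slot` and the labels below are assumptions for illustration.

```python
import random


def pick_masked_slot(win, rng):
    """Choose one of the five slots of the winning team (0: slots 0-4, 1: slots 5-9)."""
    pos = rng.randint(0, 4)                  # position inside the team
    slot = pos + 5 if win != 0 else pos      # offset into the second team if it won
    return pos, slot


rng = random.Random(0)                       # seeded only to make the demo repeatable
wins = [0, 1, 1]                             # made-up winning-team labels for three matches
picks = [pick_masked_slot(w, rng) for w in wins]
print(picks)
```

The within-team position and the id of the hidden champion are returned in `pos_champions` and `sel_champions`, presumably so that a training loop can use `pred_champ` to recover the hidden champion.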

# Logger and Checkpointer

These classes and methods log relevant information about the model and its metrics to Comet. They also allow saving a checkpoint at each epoch.

In [None]:
def load_defaults(defaults_file):
    return OmegaConf.load(defaults_file)


def load_config_file(config_file):
    if not config_file:
        return OmegaConf.create()
    return OmegaConf.load(config_file)


def load_config(config_file, defaults_file):
    defaults = load_defaults(defaults_file)
    config = OmegaConf.merge(defaults, load_config_file(config_file))
    config.merge_with_cli()
    return config


def build_config(args):
    return load_config(args.config_file, args.defaults_file)


def config_to_dict(cfg):
    return dict(cfg)


def config_to_comet(cfg):
    def _config_to_comet(cfg, local_dict, parent_str):
        for key, value in cfg.items():
            full_key = "{}.{}".format(parent_str, key)
            if isinstance(value, (dict, DictConfig)):
                _config_to_comet(value, local_dict, full_key)
            else:
                local_dict[full_key] = value

    local_dict = {}
    for key, value in cfg.items():
        if isinstance(value, (dict, DictConfig)):
            _config_to_comet(value, local_dict, key)
        else:
            local_dict[key] = value
    return local_dict
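
The flattening that `config_to_comet` performs can be illustrated with ordinary dictionaries: nested keys are joined with dots so that every hyperparameter becomes a single flat entry. This is a sketch on plain `dict` objects with made-up values, not on an OmegaConf `DictConfig`, and the helper name `flatten` is hypothetical.

```python
def flatten(cfg, parent=""):
    """Flatten nested dicts into a single dict with dot-separated keys."""
    flat = {}
    for key, value in cfg.items():
        full_key = f"{parent}.{key}" if parent else key
        if isinstance(value, dict):
            flat.update(flatten(value, full_key))
        else:
            flat[full_key] = value
    return flat


config = {"exp_name": "demo", "model": {"nhead": 2, "optim": {"lr": 0.001}}}  # made-up values
flat = flatten(config)
print(flat)
```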

In [None]:
def get_checkpointer(save_path, metric_name='val_acc'):
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    return ModelCheckpoint(
        filepath=save_path,
        verbose=True,
        monitor=metric_name,
        mode='max',
    )


# class CometLogger(LightningLoggerBase):
#     # Thank you @ceyzaguirre4
#     def __init__(self, config, *args, **kwargs):
#         super().__init__()
#         self.comet_exp = CometExperiment(*args, **kwargs)
#         self.comet_exp.set_name(config['exp_name'])
#         self.comet_exp.log_parameters(config)
#         self.config = config

#     @rank_zero_only
#     def log_hyperparams(self, params):
#         self.comet_exp.log_parameters(config_to_comet(params))

#     @rank_zero_only
#     def log_metrics(self, metrics, step):
#         self.comet_exp.log_metrics(metrics)

#     @rank_zero_only
#     def finalize(self, status):
#         self.comet_exp.end()
    
#     def version(self):
#         return self.config['exp']


# Metrics

In [None]:
def recall_at_k(output, target, k = 6):
    output_k, ind_k = torch.topk(output, k, dim = 1)
    sum_recall = 0
    num_part = output_k.size(0)
    relevants = target.sum(dim = 1)
    list_recall = []
    for i in range(num_part):
      target_k = target[i, ind_k[i,:]]
      intersection = target_k.sum(dim = 0)
      recall_n = intersection/relevants[i]
      list_recall.append(recall_n)
      sum_recall+=recall_n
    
    recall_avg = sum_recall/num_part
    return recall_avg, num_part, list_recall


In [None]:
def precision_at_k(r, k):
    """Score is precision @ k

    Relevance is binary (nonzero is relevant).

    >>> r = [0, 0, 1]
    >>> precision_at_k(r, 1)
    0.0
    >>> precision_at_k(r, 2)
    0.0
    >>> precision_at_k(r, 3)
    0.33333333333333331
    >>> precision_at_k(r, 4)
    Traceback (most recent call last):
        File "<stdin>", line 1, in ?
    ValueError: Relevance score length < k


    Args:
        r: Relevance scores (list or numpy) in rank order
            (first element is the first item)

    Returns:
        Precision @ k

    Raises:
        ValueError: len(r) must be >= k
    """
    assert k >= 1
    r = np.asarray(r)[:k] != 0
    if r.size != k:
        raise ValueError('Relevance score length < k')
    return np.mean(r)


def average_precision(r):
    """Score is average precision (area under PR curve)

    Relevance is binary (nonzero is relevant).

    >>> r = [1, 1, 0, 1, 0, 1, 0, 0, 0, 1]
    >>> delta_r = 1. / sum(r)
    >>> sum([sum(r[:x + 1]) / (x + 1.) * delta_r for x, y in enumerate(r) if y])
    0.7833333333333333
    >>> average_precision(r)
    0.78333333333333333

    Args:
        r: Relevance scores (list or numpy) in rank order
            (first element is the first item)

    Returns:
        Average precision
    """
    r = np.asarray(r) != 0
    out = [precision_at_k(r, k + 1) for k in range(r.size) if r[k]]
    if not out:
        return 0.
    return np.mean(out)

In [None]:
def map_at(output, target, k=6):
    sum_ap = 0
    num_part = output.size(0)
    list_map = []
    for i in range(num_part):
      out_p = output[i,:]
      target_p = target[i,:]
      output_k, ind_k = torch.topk(out_p, k, dim = 0)
      list_rel = target_p[ind_k].tolist()
      ap_at = average_precision(list_rel)
      list_map.append(ap_at) 
      sum_ap += ap_at
    return sum_ap/num_part, list_map
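
As a worked example of what `map_at` computes for a single participant, the sketch below evaluates average precision on a short relevance list in pure Python. The helper name `average_precision_list` is hypothetical. Like `average_precision` above, it averages the precision at each rank where a relevant item appears, so the score is normalised by the number of relevant items found in the list rather than by all relevant items.

```python
def average_precision_list(relevance):
    """Average of precision@rank over the ranks that hold a relevant item."""
    hits = 0
    precisions = []
    for rank, rel in enumerate(relevance, start=1):
        if rel:
            hits += 1
            precisions.append(hits / rank)
    return sum(precisions) / len(precisions) if precisions else 0.0


# Relevance of the top-3 recommended items: first and third are correct.
ap_top3 = average_precision_list([1, 0, 1])  # (1/1 + 2/3) / 2

# Same list as in the docstring of average_precision.
ap_doc = average_precision_list([1, 1, 0, 1, 0, 1, 0, 0, 0, 1])
print(ap_top3, ap_doc)
```

`map_at` then averages this per-participant value over the batch.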

In [None]:
def calc_precision_multiclass(output, target, k = 6):
    output_k, ind_k = torch.topk(output, k, dim = 1)
    sum_prec = 0
    num_part = output_k.size(0)
    list_prec = []
    for i in range(num_part):
      target_k = target[i, ind_k[i,:]]
      intersection = target_k.sum(dim = 0)
      preci_n = intersection/k
      list_prec.append(preci_n)
      sum_prec+=preci_n
    
    prec_avg = sum_prec/num_part
    return prec_avg, num_part, list_prec
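
The difference between `recall_at_k` and `calc_precision_multiclass` is only the denominator: both count how many of the top-k scored items are relevant, then divide by the number of relevant items (recall) or by k (precision). The following pure-Python sketch shows this for one participant. The helper names and the sample scores are hypothetical, and it assumes the target contains at least one relevant item, since otherwise the recall denominator is zero.

```python
def top_k_indices(scores, k):
    """Indices of the k largest scores, highest first."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]


def recall_and_precision_at_k(scores, target, k):
    """Return (recall@k, precision@k) for one binary target vector."""
    hits = sum(target[i] for i in top_k_indices(scores, k))
    return hits / sum(target), hits / k


# Hypothetical item scores and multi-hot ground truth for one participant.
scores = [0.9, 0.1, 0.8, 0.3, 0.7, 0.2]
target = [1, 0, 0, 1, 1, 0]

recall, precision = recall_and_precision_at_k(scores, target, k=2)
print(recall, precision)
```

Here the two best-scored items are indices 0 and 2, of which only index 0 is relevant, so one of three relevant items is recovered and one of two recommendations is correct.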

In [None]:
def f1_score(recall, precision):
  f1 = 2 * ((precision * recall) / (precision + recall))
  return f1
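
`f1_score` is the harmonic mean of precision and recall. When both are zero, the expression divides zero by zero. The sketch below is one possible guarded variant for plain Python floats; the name `safe_f1` and the choice to return 0.0 in that case are assumptions, not part of the original code.

```python
def safe_f1(recall, precision):
    """Harmonic mean of precision and recall, defined as 0.0 when both are 0."""
    denom = precision + recall
    if denom == 0:
        return 0.0
    return 2 * precision * recall / denom


print(safe_f1(0.5, 1 / 3))
print(safe_f1(0.0, 0.0))
```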

In [None]:
class AverageMeter(object):
    """Computes and stores the average and current value
    Taken from PyTorch's examples.imagenet.main
    """
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
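
`AverageMeter` keeps a weighted running mean: each `update(val, n)` adds `val * n` to the sum and `n` to the count, so batches of different sizes contribute proportionally. The compact class below, with the hypothetical name `RunningAverage`, reproduces that arithmetic so the example is self-contained.

```python
class RunningAverage:
    """Weighted running mean, following the same update rule as AverageMeter."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value, n=1):
        self.total += value * n
        self.count += n

    @property
    def avg(self):
        return self.total / self.count if self.count else 0.0


meter = RunningAverage()
meter.update(0.5, n=100)  # batch of 100 samples with mean precision 0.5
meter.update(0.8, n=50)   # batch of 50 samples with mean precision 0.8
print(meter.avg)          # (0.5 * 100 + 0.8 * 50) / 150
```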

In [None]:
def set_seed(seed, slow=False):
    import random

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    if slow:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

In [None]:
def get_winners(att_vec, gt_item, win_vec, pos_champions, outputs_log):
    list_att = []
    list_gt = []
    list_cham = []
    for i in range(att_vec.size(0)):
        win = win_vec[i]
        pos = pos_champions[i]
        if win == 0:
            a = list(range(0,5))
            del a[pos]
            att_vec_match = att_vec[i,a,:]
            gt_match = gt_item[i,a, :]
            list_cham.append(outputs_log[i,pos, :])
            list_att.append(att_vec_match)
            list_gt.append(gt_match)          
        else:
            a = list(range(5,10))
            del a[pos]
            att_vec_match = att_vec[i,a,:]
            gt_match = gt_item[i, a, :]
            list_cham.append(outputs_log[i,pos + 5, :])
            list_att.append(att_vec_match)
            list_gt.append(gt_match)

    att_winners = torch.stack(list_att, dim=0)
    gt_winners = torch.stack(list_gt, dim=0)
    att_cham = torch.stack(list_cham, dim=0)
    return att_winners, gt_winners, att_cham
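
The index logic inside `get_winners` can be read as follows: `win` picks one team (slots 0-4 or 5-9 of the match), and `pos` is the slot within that team that is held out and used for the champion prediction, while the remaining four slots are kept for item recommendation. Below is a minimal sketch of that selection with plain lists. The function name `teammate_indices` and the `team_size` parameter are hypothetical.

```python
def teammate_indices(win, pos, team_size=5):
    """Return (kept slots, held-out slot) for the team selected by `win`."""
    offset = 0 if win == 0 else team_size
    indices = list(range(offset, offset + team_size))
    held_out = indices.pop(pos)  # same effect as `del a[pos]` plus `pos + offset`
    return indices, held_out


print(teammate_indices(win=0, pos=2))
print(teammate_indices(win=1, pos=0))
```

In `get_winners` the kept indices select rows of `att_vec` and `gt_item`, and the held-out index selects the row of `outputs_log`.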
    



In [None]:
def save_att_weights(list_att, path_save_att):
  with open(path_save_att, 'wb') as handle:
    pickle.dump(list_att, handle)
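
Files written by `save_att_weights` can be read back with `pickle.load`. The following sketch shows a round trip through a temporary directory. The payload values and the file name are made up for illustration, and the real files hold tensors rather than plain lists.

```python
import os
import pickle
import tempfile

# Hypothetical payload with the same keys used in validation_epoch_end.
payload = {
    "list_att_weights": [[0.1, 0.9], [0.4, 0.6]],
    "list_logits_items": [[1.5, -0.3]],
}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "att_weights_1.pkl")
    # Write in binary mode, as save_att_weights does.
    with open(path, "wb") as handle:
        pickle.dump(payload, handle)
    # Read the object back in binary mode.
    with open(path, "rb") as handle:
        restored = pickle.load(handle)

print(restored.keys())
```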

# Training and evaluation loop

The training and evaluation loops are based on [PyTorch Lightning](https://github.com/williamFalcon/pytorch-lightning).

In [None]:
import argparse

In [None]:
class Struct:
    def __init__(self, **entries):
        self.__dict__.update(entries)
        #self.elems = entries.items()
    
    def items(self):
        return self.__dict__.items()

In [None]:
class LolRecAttModel(pl.LightningModule):

    def __init__(self, cfg):
        super(LolRecAttModel, self).__init__()
        
        if type(cfg) is argparse.Namespace:
          cfg = vars(cfg)
        self.conf = cfg
        self.hparams = cfg
        self.index_split = self.conf['index_split']
        self.optim = self.conf['optim']
        set_seed(seed = self.conf['seed'])
        train_dataset = self.train_dataset()
        self.batch_size = self.conf['batch_size']
        self.iter_max_train = len(train_dataset)//self.batch_size
        num_roles = len(train_dataset.roles)
        num_champions = len(train_dataset.champions)
        n_items = len(train_dataset.items)
        n_types = len(train_dataset.set_champ_type)
        self.model = TransformerLolRecommender(n_role=num_roles, n_champions=num_champions, embeddings_size=self.conf['embeddings_size'], nhead=self.conf['nhead'], n_items=n_items, n_type=n_types,
                                               nlayers = self.conf['nlayers'], nhid = self.conf['nhid'], dropout=self.conf['dropout'], aux_task = self.conf['win_task'], 
                                               learnable_team_emb = self.conf['learnable_team_emb'])
        self.loss = nn.BCEWithLogitsLoss()
        self.loss_aux = nn.CrossEntropyLoss()
        self.train_loss = AverageMeter()
        self.train_prec = AverageMeter()
        self.iter_epoch = 0
        isExist = os.path.exists(path_save) 
        if isExist:
          dirs = os.listdir(path_save)
          self.iter_epoch = len(dirs)

        self.aux_task = self.conf['win_task']
        
        if self.aux_task:
            self.second_loss = nn.CrossEntropyLoss()
            self.train_acc_win = AverageMeter()
            self.train_main_loss = AverageMeter()
            self.train_win_loss = AverageMeter()
            self.alpha = self.conf['alpha']
            self.beta = self.conf['beta']
            self.epoch_to_win = self.conf['init_epoch']

    def check_epoch(self, num_iter):
      if num_iter == 0:
        self.train_loss = AverageMeter()
        self.train_prec = AverageMeter()
        if self.aux_task:
            self.train_acc_win = AverageMeter()
            self.train_main_loss = AverageMeter()
            self.train_win_loss = AverageMeter()
        self.iter_epoch+=1

    def custom_print(self, batch,  loss, start_time, prec, acc = 0, log_interval = 100, loss_win =0, epoch=1):
      if batch % log_interval == 0:
            elapsed = time.time() - start_time
            elapsed = elapsed*log_interval if batch > 0 else elapsed
            if self.aux_task and self.iter_epoch >= self.epoch_to_win:
                print('| epoch {:3d} | {:5d}/{:5d} batches | '
                      'ms/batch {:5.2f} | '
                      'loss {:5.6f} | loss win {:5.6f} | precision {:5.6f} | Accuracy (win) {:5.6f}'.format(
                        self.iter_epoch, batch, self.iter_max_train,
                        elapsed, loss, loss_win, prec, acc))
            else:
                print('| epoch {:3d} | {:5d}/{:5d} batches | '
                      'ms/batch {:5.2f} | '
                      'loss {:5.6f} | precision {:5.6f}'.format(
                        self.iter_epoch, batch, self.iter_max_train,
                        elapsed, loss, prec)) 

    def forward(self, x, items, win, teacher_forcing):
        role = x['role']
        champions = x['champions']
        types = x['type']
        out = self.model(role, champions, types, items, win, teacher_forcing)
        return out
        #return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        # REQUIRED
        self.check_epoch(batch_nb)
        start_time = time.time()
        x, y = batch
        if len(x['role'].size()) == 3:
          x['role'] = x['role'].reshape(x['role'].size(0)*x['role'].size(1), x['role'].size(2))
          x['champions'] = x['champions'].reshape(x['champions'].size(0)*x['champions'].size(1), x['champions'].size(2))
          x['type'] = x['type'].reshape(x['type'].size(0)*x['type'].size(1), x['type'].size(2), x['type'].size(3))
          y['items'] = y['items'].reshape(y['items'].size(0)*y['items'].size(1), y['items'].size(2), y['items'].size(3))
          y['win'] = y['win'].reshape(y['win'].size(0)*y['win'].size(1))
        y_hat = self.forward(x, y['items'], y['win'], self.conf['teacher_forcing'])
        
        # Main task
        logits_items = y_hat['logits_items']
        gt_items = y['items']
        sel_champions = y_hat['sel_champions']
        pos_champions = y_hat['pos_champions']
        outputs_log = y_hat['outputs']
        logits_items, gt_items, att_cham = get_winners(logits_items, gt_items, y['win'], pos_champions, outputs_log)
        
        out = logits_items.reshape(logits_items.size(0)*logits_items.size(1), logits_items.size(2))
        out_aux = self.model.pred_champ(att_cham)

        gt = gt_items.reshape(gt_items.size(0)*gt_items.size(1), gt_items.size(2))
        loss = self.loss(out, gt)
        loss_aux = self.loss_aux(out_aux, sel_champions)

        prec, num, _ = calc_precision_multiclass(out, gt, k=6)
        self.train_prec.update(prec, num)

        tensor_avg_prec = torch.tensor([self.train_prec.avg], device=loss.device)
        tensorboard_logs = {'train_loss': loss, 'train_loss_aux': loss_aux, 'train_prec_avg': tensor_avg_prec}

        if self.aux_task and self.iter_epoch >= self.epoch_to_win:

            #Second Task
            out_win = y_hat['logits_win']
            
            gt_win = y['win'].reshape(-1)

            _, preds_win = torch.max(out_win, 1)

            loss_win = self.second_loss(out_win, gt_win)
            loss_total = self.alpha*loss + self.beta*loss_win
            self.train_loss.update(self.alpha*loss.item(), out.size(0))
            self.train_loss.update(self.beta*loss_win.item(), out_win.size(0))

            train_acc = torch.sum(preds_win == gt_win).item()/out_win.size(0)
            self.train_acc_win.update(train_acc, out_win.size(0))
            self.train_main_loss.update(loss.item(), out.size(0))
            self.train_win_loss.update(loss_win.item(), out_win.size(0))

            tensor_avg_acc = torch.tensor([self.train_acc_win.avg], device=loss.device)
            tensorboard_logs['train_acc_win_avg'] = tensor_avg_acc
            tensorboard_logs['train_win_loss'] = loss_win
            tensorboard_logs['train_main_loss'] = loss
            tensorboard_logs['train_win_loss_avg'] = torch.tensor([self.train_win_loss.avg], device=loss.device)
            tensorboard_logs['train_main_loss_avg'] = torch.tensor([self.train_main_loss.avg], device=loss.device)
            tensorboard_logs['train_loss'] = loss_total

            # self.custom_print(batch_nb, self.train_main_loss.avg, start_time, self.train_prec.avg, self.train_acc_win.avg, 100, self.train_win_loss.avg)
        else:
            loss_total = loss + 0.2*loss_aux
            self.train_loss.update(loss.item(), out.size(0))
            tensorboard_logs['total_loss_train'] = loss_total
            # self.custom_print(batch_nb, self.train_loss.avg, start_time, self.train_prec.avg, 0, 100)
        
        tensor_avg_loss = torch.tensor([self.train_loss.avg], device=loss.device)
        tensorboard_logs['train_loss_avg'] = tensor_avg_loss
        return {'loss': loss_total, 'progress_bar': tensorboard_logs, 'avg_loss':  tensor_avg_loss, 'avg_prec':tensor_avg_prec ,'log': tensorboard_logs}

    def validation_step(self, batch, batch_nb):
        x, y = batch
        y_hat = self.forward(x, y['items'], y['win'], False)
        att_weights = y_hat['att_weights']

        #Main Task
        logits_items = y_hat['logits_items']
        gt_items = y['items']
        sel_champions = y_hat['sel_champions']
        pos_champions = y_hat['pos_champions']
        outputs_log = y_hat['outputs']

        logits_items, gt_items, att_cham = get_winners(logits_items, gt_items, y['win'], pos_champions, outputs_log)
        out = logits_items.reshape(logits_items.size(0)*logits_items.size(1), logits_items.size(2))
        out_aux = self.model.pred_champ(att_cham)
      
        gt = gt_items.reshape(gt_items.size(0)*gt_items.size(1), gt_items.size(2))

        loss = self.loss(out, gt)
        loss_aux = self.loss_aux(out_aux, sel_champions)

        prec, num, list_prec = calc_precision_multiclass(out, gt, k=6)
        prec1, num, list_prec1 = calc_precision_multiclass(out, gt, k=1)
        prec3, num, list_prec3 = calc_precision_multiclass(out, gt, k=3)

        recall1, num, list_recall1 = recall_at_k(out, gt, k=1)
        recall3, num, list_recall3 = recall_at_k(out, gt, k=3)
        recall6, num, list_recall6 = recall_at_k(out, gt, k=6)

        f11 = f1_score(recall1, prec1)
        f13 = f1_score(recall3, prec3) 
        f16 = f1_score(recall6, prec)

        map6, list_map6 = map_at(out, gt, k=6)
        map1, list_map1 = map_at(out, gt, k=1)
        map3, list_map3 = map_at(out, gt, k=3)

        obj_list = {
            'list_prec1': list_prec1,
            'list_prec3': list_prec3,
            'list_prec': list_prec,
            'list_recall1': list_recall1,
            'list_recall3': list_recall3,
            'list_recall6': list_recall6,
            'list_map1': list_map1,
            'list_map3': list_map3,
            'list_map6': list_map6
        }
        obj_res = {'val_loss': loss, 'val_loss_aux': loss_aux, 'val_prec': prec, 'num_batch': out.size(0), 'num':num, 'map6': map6, 
                   'map1': map1, 'map3': map3, 'val_prec1': prec1, 'val_prec3': prec3, 'val_recall1': recall1, 
                   'val_recall3': recall3, 'val_recall6': recall6, 'val_f1_1': f11, 'val_f1_3': f13, 'val_f1_6': f16, 
                   'att_weights': att_weights, 'logits_items': logits_items, 'obj_list': obj_list}

        #Second Task
        if self.aux_task and self.iter_epoch >= self.epoch_to_win:
            out_win = y_hat['logits_win']
            
            gt_win = y['win'].reshape(-1)
            _, preds_win = torch.max(out_win, 1)

            loss_win = self.second_loss(out_win, gt_win)

            acc_win = torch.sum(preds_win == gt_win).item()/out_win.size(0)
            obj_res['val_acc'] = acc_win
            obj_res['val_loss_win'] = loss_win
            obj_res['val_main_loss'] = loss
            obj_res['num_batch_acc'] = out_win.size(0)

        return obj_res

    def validation_epoch_end(self, outputs):
        avg_loss = AverageMeter()
        avg_loss_aux = AverageMeter()
        avg_prec = AverageMeter()
        avg_prec1 = AverageMeter()
        avg_prec3 = AverageMeter()

        avg_recall1 = AverageMeter()
        avg_recall3 = AverageMeter()
        avg_recall6 = AverageMeter()

        avg_f1_1 = AverageMeter()
        avg_f1_3 = AverageMeter()
        avg_f1_6 = AverageMeter()

        avg_map = AverageMeter()
        avg_map1 = AverageMeter()
        avg_map3 = AverageMeter()

        list_att_weights = []
        list_logits_items = []

        list_prec1 = []
        list_prec3 = []
        list_prec6 = []

        list_recall1 = []
        list_recall3 = []
        list_recall6 = []

        list_map1 = []
        list_map3 = []
        list_map6 = []

        if self.aux_task and self.iter_epoch >= self.epoch_to_win:
          avg_main_loss = AverageMeter()
          avg_win_loss = AverageMeter()
          avg_acc = AverageMeter()

        device = None
        for x in outputs:

          avg_prec.update(x['val_prec'], x['num'])
          avg_prec1.update(x['val_prec1'], x['num'])
          avg_prec3.update(x['val_prec3'], x['num'])

          avg_recall1.update(x['val_recall1'], x['num'])
          avg_recall3.update(x['val_recall3'], x['num'])
          avg_recall6.update(x['val_recall6'], x['num'])

          avg_f1_1.update(x['val_f1_1'], x['num'])
          avg_f1_3.update(x['val_f1_3'], x['num'])
          avg_f1_6.update(x['val_f1_6'], x['num'])

          avg_map.update(x['map6'], x['num_batch'])
          avg_map1.update(x['map1'], x['num_batch'])
          avg_map3.update(x['map3'], x['num_batch'])

          list_att_weights.append(x['att_weights'])
          list_logits_items.append(x['logits_items'])

          list_prec1.extend(x['obj_list']['list_prec1'])
          list_prec3.extend(x['obj_list']['list_prec3'])
          list_prec6.extend(x['obj_list']['list_prec'])

          list_recall1.extend(x['obj_list']['list_recall1'])
          list_recall3.extend(x['obj_list']['list_recall3'])
          list_recall6.extend(x['obj_list']['list_recall6'])

          list_map1.extend(x['obj_list']['list_map1'])
          list_map3.extend(x['obj_list']['list_map3'])
          list_map6.extend(x['obj_list']['list_map6'])

          device = x['val_loss'].device

          if self.aux_task and self.iter_epoch >= self.epoch_to_win:
            avg_main_loss.update(x['val_main_loss'], x['num_batch'])
            avg_win_loss.update(x['val_loss_win'], x['num_batch_acc'])
            avg_acc.update(x['val_acc'], x['num_batch_acc'])

            avg_loss.update(self.alpha*x['val_main_loss'], x['num_batch'])
            avg_loss.update(self.beta*x['val_loss_win'], x['num_batch_acc'])
          else:
            avg_loss.update(x['val_loss'], x['num_batch'])
            avg_loss_aux.update(x['val_loss_aux'], x['num_batch'])

        tensorboard_logs = {'val_loss': torch.tensor([avg_loss.avg], device=device), 'val_prec': torch.tensor([avg_prec.avg], device=device), 
                            'val_map6': torch.tensor([avg_map.avg], device=device), 'val_map1': torch.tensor([avg_map1.avg], device=device),
                            'val_map3': torch.tensor([avg_map3.avg], device=device), 'val_prec1': torch.tensor([avg_prec1.avg], device=device), 
                            'val_prec3': torch.tensor([avg_prec3.avg], device=device), 'val_recall1': torch.tensor([avg_recall1.avg], device=device),
                            'val_recall3': torch.tensor([avg_recall3.avg], device=device), 'val_recall6': torch.tensor([avg_recall6.avg], device=device),
                            'val_f1_1': torch.tensor([avg_f1_1.avg], device=device), 'val_f1_3': torch.tensor([avg_f1_3.avg], device=device),
                            'val_f1_6': torch.tensor([avg_f1_6.avg], device=device)}

        if self.aux_task and self.iter_epoch >= self.epoch_to_win:
          tensorboard_logs['val_main_loss'] = torch.tensor([avg_main_loss.avg], device=device)
          tensorboard_logs['val_win_loss'] = torch.tensor([avg_win_loss.avg], device=device)
          tensorboard_logs['val_win_acc'] = torch.tensor([avg_acc.avg], device=device)
          print('| loss_val {:5.6f} | main_loss_val {:5.6f} | win_loss_val {:5.6f} | precision_val {:5.6f} | map6_val {:5.6f} | acc_val {:5.6f}'.format(avg_loss.avg, avg_main_loss.avg, 
                                                                                                                                                        avg_win_loss.avg, avg_prec.avg, 
                                                                                                                                                        avg_map.avg, avg_acc.avg))
        # else:
        #   print('| loss_val {:5.6f} | precision1_val {:5.6f} | precision3_val {:5.6f} | precision6_val {:5.6f} | map1_val {:5.6f} | map3_val {:5.6f} | map6_val {:5.6f} | recall1 {:5.6f} | recall3 {:5.6f} | recall6 {:5.6f} | f1_1 {:5.6f} | f1_3 {:5.6f} | f1_6 {:5.6f}'.format(
        #       avg_loss.avg, avg_prec1.avg, avg_prec3.avg, avg_prec.avg, avg_map1.avg, avg_map3.avg, avg_map.avg, avg_recall1.avg, avg_recall3.avg, avg_recall6.avg, avg_f1_1.avg, avg_f1_3.avg, avg_f1_6.avg))
        
        path_save_att = path_save_att_format.format(str(self.conf['index_split']), str(self.conf['exp']), str(self.iter_epoch))
        path_save_list_metrics = path_save_list_metrics_format.format(str(self.conf['index_split']), str(self.conf['exp']), str(self.iter_epoch))
        weights_items = {
            'list_att_weights': list_att_weights,
            'list_logits_items': list_logits_items
        }

        list_metrics = {
            'list_prec1': list_prec1, 
            'list_prec3': list_prec3,
            'list_prec6': list_prec6,
            'list_recall1': list_recall1,
            'list_recall3': list_recall3,
            'list_recall6': list_recall6,
            'list_map1': list_map1,
            'list_map3': list_map3,
            'list_map6': list_map6
        }
        save_att_weights(weights_items, path_save_att)
        save_att_weights(list_metrics, path_save_list_metrics)
        return {'avg_val_loss': avg_loss.avg,  'avg_val_prec': avg_prec.avg, 'val_map6': avg_map.avg,'progress_bar': tensorboard_logs,'log': tensorboard_logs}

    def test_step(self, batch, batch_idx):
        # OPTIONAL
        return self.validation_step(batch, batch_idx)

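    def test_epoch_end(self, outputs):
        # OPTIONAL: reuse the validation aggregation for the test set.
        # The class defines validation_epoch_end, not validation_end.
        return self.validation_epoch_end(outputs)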

    def configure_optimizers(self):
        # REQUIRED
        # can return multiple optimizers and learning_rate schedulers
        # (LBFGS it is automatically supported, no need for closure function)
        if self.optim == 'adabound':
          optimizer = adabound.AdaBound(self.model.parameters(), lr=1e-3, final_lr=0.1)
        else:
          optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
        return optimizer
    
    def train_dataset(self):

      data = get_partition(self.index_split, list_trainset)
      composed = transforms.Compose([RandomSort_Part(),
                               RandomSort_Team()])
      train_dataset = LolDataset(data, transform=composed)
      return train_dataset

    @pl.data_loader
    def train_dataloader(self):
        
        train_dataset = self.train_dataset()
        return DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)

    @pl.data_loader
    def val_dataloader(self):
        #data = list_testset[self.index_split]
        data = get_partition(self.index_split, list_testset)
        val_dataset = LolDataset(data)
        return DataLoader(val_dataset, batch_size=self.batch_size)
    
    @pl.data_loader
    def test_dataloader(self):
        # OPTIONAL
        return self.val_dataloader()

# Config file

This config sets the model hyperparameters and experiment options:

1. index_split - index of the data partition used for training.
2. optim - optimizer (adam or adabound).
3. batch_size - batch size.
4. embeddings_size - model dimension.
5. nhead - number of attention heads.
6. nlayers - number of encoder layers.
7. exp - experiment number.
8. epochs - number of training epochs.
9. exp_name - experiment name in comet.ml.
10. alpha, beta - weights of the main loss and the win loss in the total loss.
11. win_task - enables the auxiliary task.
12. learnable_team_emb - when True, the team embedding is learnable; otherwise it is static.
13. teacher_forcing - enables teacher forcing for the auxiliary task.
14. init_epoch - epoch at which the second task starts.
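
To show how `alpha`, `beta`, `win_task` and `init_epoch` interact, the sketch below mirrors the loss combination in `training_step` with plain floats. The function name `combine_losses` and the parameter `champ_weight` are hypothetical; the value 0.2 is the constant weight that `training_step` applies to the champion-prediction loss.

```python
def combine_losses(loss_main, loss_champ, loss_win, epoch, *, win_task,
                   init_epoch, alpha=1.0, beta=1.0, champ_weight=0.2):
    """Total training loss for one step, following the branches of training_step."""
    if win_task and epoch >= init_epoch:
        # Win task active: weighted sum of the item loss and the win loss.
        return alpha * loss_main + beta * loss_win
    # Otherwise: item loss plus the down-weighted champion-prediction loss.
    return loss_main + champ_weight * loss_champ


before = combine_losses(0.5, 1.0, 0.7, epoch=1, win_task=True, init_epoch=2)
after = combine_losses(0.5, 1.0, 0.7, epoch=2, win_task=True, init_epoch=2)
print(before, after)
```

With `win_task` set to False, only the second branch is used for every epoch.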


In [None]:
conf = {
    'index_split': 0,
    'optim': 'adam',
    'seed': 1642,
    'batch_size': 100,
    'embeddings_size': 512,
    'nhead': 2,
    'nlayers': 1, 
    'nhid': 2048, 
    'dropout': 0.5,
    'exp': 13,
    'epochs': 10,
    'exp_name': 'Main_tasks_rec_only_winners_final_prueba',
    'win_task': False,
    'alpha': 1,
    'beta': 1,
    'learnable_team_emb': True,
    'teacher_forcing': False,
    'init_epoch': 2
}
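
Before launching a run it can help to sanity-check the config. The sketch below is a hypothetical helper, `check_conf`, that is not part of the original notebook. It assumes the encoder uses standard multi-head attention, which requires the model dimension to be divisible by the number of heads, and it checks the two optimizer names handled in `configure_optimizers`.

```python
def check_conf(conf):
    """Return a list of human-readable problems found in a config dict."""
    problems = []
    if conf["embeddings_size"] % conf["nhead"] != 0:
        problems.append("embeddings_size must be divisible by nhead")
    if conf["optim"] not in ("adam", "adabound"):
        problems.append("optim should be 'adam' or 'adabound'")
    if not 0.0 <= conf["dropout"] < 1.0:
        problems.append("dropout must be in [0, 1)")
    return problems


# Subset of the values used in this notebook.
sample_conf = {"embeddings_size": 512, "nhead": 2, "optim": "adam", "dropout": 0.5}
print(check_conf(sample_conf))
```

Note that an unrecognized `optim` value does not fail in `configure_optimizers`; it silently falls back to Adam, so the second check only makes that fallback visible.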

# Training and evaluation executor

In [None]:
from pytorch_lightning import Trainer

path_save = '/content/gdrive/My Drive/Proyecto_RecSys/split/{}/exp_recsys/{}/checkpoints/'.format(str(conf['index_split']), str(conf['exp']))
path_save_att_format = '/content/gdrive/My Drive/Proyecto_RecSys/split/{}/exp_recsys/{}/checkpoints/att_weights_{}.pkl'
path_save_list_metrics_format = '/content/gdrive/My Drive/Proyecto_RecSys/split/{}/exp_recsys/{}/checkpoints/list_metrics_{}.pkl'

model = LolRecAttModel(conf)

checkpoint_callback = get_checkpointer(path_save,'avg_val_prec')


comet_logger = CometLogger(
    experiment_name=conf['exp_name'],
    api_key = 'YOUR_KEY',
    project_name="YOUR_PROJECT_NAME",
    workspace = 'YOUR_WORKSPACE'
)
trainer = Trainer(
    gpus=[0],
    distributed_backend='dp',
    logger=comet_logger,
    max_epochs=conf['epochs'],
    checkpoint_callback=checkpoint_callback,
    show_progress_bar=False,
    gradient_clip_val=0.5
)

trainer.fit(model)   

CometLogger will be initialized in online mode
COMET INFO: ----------------------------
COMET INFO: Comet.ml Experiment Summary:
COMET INFO:   Data:
COMET INFO:     url: https://www.comet.ml/afvilla/lolnet/5094e7cd80244d62bac9be446cbfeb0b
COMET INFO:   Metrics [count] (min, max):
COMET INFO:     sys.cpu.percent.01 [4]       : (1.0, 12.3)
COMET INFO:     sys.cpu.percent.02 [4]       : (1.0, 12.9)
COMET INFO:     sys.cpu.percent.03 [4]       : (0.9, 12.5)
COMET INFO:     sys.cpu.percent.04 [4]       : (1.0, 12.9)
COMET INFO:     sys.cpu.percent.avg [4]      : (0.975, 12.65)
COMET INFO:     sys.gpu.0.free_memory [4]    : (17061249024.0, 17061249024.0)
COMET INFO:     sys.gpu.0.gpu_utilization [4]: (0.0, 0.0)
COMET INFO:     sys.gpu.0.total_memory       : (17071734784.0, 17071734784.0)
COMET INFO:     sys.gpu.0.used_memory [4]    : (10485760.0, 10485760.0)
COMET INFO:     sys.ram.total [4]            : (27393740800.0, 27393740800.0)
COMET INFO:     sys.ram.used [4]             : (779220582
