#  Librerías

### Librerías pesadas
Para ejecutar solo una vez

In [1]:
import math
import torch
import torch.optim as optim
import torch.nn as nn
import pandas as pd
import numpy as np
import tqdm
import wandb
import random

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
!wandb login 99217068fbd71985701543b0c0064e805ac87449

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/rafael/.netrc


In [3]:
seed = 42
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

# Configuración General

Variables relacionadas al procesamiento de datos y del modelo en sí

### Variables de Preprocesamiento

In [4]:

# Porcentaje para usar solo una fracción del dataset de usuario.
# si al eliminar usuarios quedan viajes o POI sin visitas, estos también
# serán eliminados
USER_FRAC = 0.65
MIN_POI_VISITS = 5
MAX_SEQUENCES_PER_USER = 100
SEQUENCE_LENGTH = 14

In [5]:
BATCH_SIZE=64
EPOCHS=100

In [6]:
EMBEDDING_DIM = 100
HIDDEN_DIM = 80

# Gowalla Dataset

In [7]:
! ./download-gowalla.sh

Already Downloaded


In [8]:
# !mkdir -p download
# !wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=FILEID' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=0BzpKyxX1dqTYRTFVYTd1UG81ZXc" -O download/gowalla.zip && rm -rf /tmp/cookies.txt
# !unzip download/gowalla.zip -d download

### Cargar Datos

In [9]:
users    = pd.read_csv('download/gowalla/gowalla_userinfo.csv')
friends  = pd.read_csv('download/gowalla/gowalla_friendship.csv')
checkins = pd.read_csv('download/gowalla/gowalla_checkins.csv')
pois_1   = pd.read_csv('download/gowalla/gowalla_spots_subset1.csv', encoding='iso-8859-1')
pois_2   = pd.read_csv('download/gowalla/gowalla_spots_subset2.csv', encoding='iso-8859-1')
pois     = pd.concat((pois_1, pois_2), ignore_index=True)

# Preprocesamiento

### Usuarios

Revisamos la distrubución de checkins de usuarios

In [10]:
users.sample(5)

Unnamed: 0,id,bookmarked_spots_count,challenge_pin_count,country_pin_count,highlights_count,items_count,photos_count,pins_count,province_pin_count,region_pin_count,state_pin_count,trips_count,friends_count,stamps_count,checkin_num,places_num
341610,2382731,0,1,1,0,0,0,2,0,1,0,0,2,2,2,2
116459,255417,1,14,3,0,6,0,17,0,3,0,0,7,661,1076,661
2173,4025,0,1,1,0,5,0,3,1,2,0,0,2,1,1,1
73987,152935,0,39,1,5,6,28,60,0,10,9,0,43,841,1215,844
191020,1084981,1,6,2,0,6,0,8,0,2,0,0,5,33,45,33


In [11]:
users.checkin_num.describe()

count    407533.000000
mean         88.341212
std         435.982581
min           0.000000
25%           1.000000
50%          10.000000
75%          52.000000
max       46981.000000
Name: checkin_num, dtype: float64

nos quedamos con un porcentaje de los usuarios y los filtramos los usuarios segun los checkins que tengan

In [12]:
users.checkin_num.describe()

count    407533.000000
mean         88.341212
std         435.982581
min           0.000000
25%           1.000000
50%          10.000000
75%          52.000000
max       46981.000000
Name: checkin_num, dtype: float64

In [13]:
print('Current users', len(users))
users = users.sample(frac=USER_FRAC)
users = users[(users.checkin_num >= users.checkin_num.quantile(0.1)) &
              (users.checkin_num <= users.checkin_num.quantile(0.9)) &
              (users.checkin_num >= SEQUENCE_LENGTH)]
users = users[['id']]
print('Reduced users', len(users))

Current users 407533
Reduced users 93889


### Amigos

In [14]:
friends.sample(5)

Unnamed: 0,userid1,userid2
4022741,2234455,398150
3922638,2398097,2147996
2872247,2430156,2417234
2386885,2345388,2285831
3278037,2531752,2224690


### Checkins

In [15]:
checkins.sample(5)

Unnamed: 0,userid,placeid,datetime
32287102,2178887,6468486,2010-12-17T05:09:22Z
4705834,98877,21192,2010-09-04T17:19:06Z
28854476,378467,5305167,2011-05-08T14:59:35Z
8786626,4267,84141,2010-01-02T10:44:26Z
11313982,80739,84407,2010-03-24T17:54:53Z


Eliminamos los checkins de los usuarios no sampleados

In [16]:
print('Current checkins', len(checkins))
checkins = pd.merge(checkins, users, how='inner', left_on='userid', right_on='id', copy=False)[checkins.columns]
checkins = checkins.reset_index(drop=True)
print('Reduced checkins', len(checkins))

Current checkins 36001959
Reduced checkins 5491669


### POIS

In [17]:
pois.sample(5)

Unnamed: 0,id,created_at,lng,lat,photos_count,checkins_count,users_count,radius_meters,highlights_count,items_count,max_items_count,spot_categories,name,city_state,Unnamed: 5,Unnamed: 6
153640,174460,2009-12-06T18:59:56Z,-75.137764,40.231966,3.0,33.0,22.0,75.0,0.0,3.0,10.0,"[{'url': '/categories/15', 'name': 'Mexican'}]",,,,
596687,668806,2010-03-09T02:07:15Z,-122.197139,37.463542,0.0,10.0,8.0,75.0,0.0,1.0,10.0,"[{'url': '/categories/166', 'name': 'Historic ...",,,,
2728830,554897,,-112.012533,33.304666,,,,,,,,,Music Makers,"Phoenix, AZ",,
2169462,6827452,2011-01-28T10:24:03Z,103.90482,1.329476,0.0,2.0,1.0,75.0,0.0,0.0,10.0,"[{'url': '/categories/18', 'name': 'Asian'}]",,,,
118096,135426,2009-11-26T20:08:21Z,-96.985616,32.84484,0.0,7.0,4.0,35.0,0.0,0.0,10.0,"[{'url': '/categories/89', 'name': 'Craftsman'}]",,,,


Filtramos los pois si han sido visitadas pocas veces según los parámetros que definimos

In [18]:
visited_pois = pd.merge(pois, checkins, left_on='id', right_on='placeid', how='inner', copy=False)

In [19]:
visited_pois['visited_count'] = np.zeros(len(visited_pois))

visited_pois = visited_pois[['id', 'visited_count']].groupby(by='id').count()
visited_pois = visited_pois[visited_pois.visited_count >= MIN_POI_VISITS]

pois = pd.merge(pois, visited_pois, on='id', how='inner', copy=False)

Nos quedamos con sólo las columnas que nos importan

In [20]:
pois = pois[['id', 'lat', 'lng', 'visited_count']]
pois 

Unnamed: 0,id,lat,lng,visited_count
0,8932,32.927662,-97.254356,6
1,8936,39.053318,-94.591995,5
2,8938,39.052824,-94.590311,25
3,8947,37.331880,-122.029631,344
4,8956,32.942655,-97.131200,9
...,...,...,...,...
256355,7519716,18.061014,-66.721559,5
256356,7527534,13.844906,100.855976,5
256357,7529626,51.435737,-3.174222,5
256358,7533476,35.847316,-78.805891,6


Ahora eliminamos los checkins de pois que ya no existen

In [21]:
print('Current checkins', len(checkins))
checkins = pd.merge(pois, checkins, left_on='id', right_on='placeid', how='inner', copy=False)[checkins.columns]
checkins = checkins.reset_index(drop=True)
print('Reduced checkins', len(checkins))

Current checkins 5491669
Reduced checkins 3775509


Finalmente eliminamos nuevamente a los usuarios que se quedaron sin suficientes checkins

In [22]:
print('Current users', len(users))
users = pd.merge(checkins, users, how='inner', left_on='userid', right_on='id', copy=False)[users.columns].drop_duplicates()
print('Reduced users', len(users))

Current users 93889
Reduced users 93653


### Reasignación de IDs

In [23]:
users = users.reset_index(drop=True)
users['user_sid'] = users.index

pois = pois.reset_index(drop=True)
pois['place_sid'] = pois.index

### Agregar Datos

Crearemos un dataset unificado que usaremos para entrenar el modelo de los embeddings

In [24]:
pois.head()

Unnamed: 0,id,lat,lng,visited_count,place_sid
0,8932,32.927662,-97.254356,6,0
1,8936,39.053318,-94.591995,5,1
2,8938,39.052824,-94.590311,25,2
3,8947,37.33188,-122.029631,344,3
4,8956,32.942655,-97.1312,9,4


In [25]:
pois.sort_values(by='visited_count', ascending=False)

Unnamed: 0,id,lat,lng,visited_count,place_sid
5424,23519,13.689897,100.748320,5876,5424
16339,55033,59.330158,18.058079,5283,16339
44029,155746,13.746659,100.534912,5071,44029
19684,66171,60.193511,11.098251,5040,19684
17359,58725,59.650051,17.932262,4804,17359
...,...,...,...,...,...
211598,6474356,65.497450,21.911010,5,211598
118205,732513,44.267990,-88.476317,5,118205
118201,732418,42.048607,-87.685467,5,118201
118193,732344,35.683227,139.615369,5,118193


In [26]:
users.head()

Unnamed: 0,id,user_sid
0,217738,0
1,344284,1
2,1808,2
3,312345,3
4,391806,4


In [27]:
checkins.head()

Unnamed: 0,userid,placeid,datetime
0,217738,8932,2010-07-25T18:13:48Z
1,217738,8932,2010-04-20T17:56:37Z
2,344284,8932,2010-11-13T20:07:16Z
3,1808,8932,2009-05-27T20:59:09Z
4,312345,8932,2010-08-02T01:17:30Z


In [28]:
users_checkins = pd.merge(users, checkins, left_on='id', right_on='userid', copy=False).drop('id', axis=1)
users_checkins = pd.merge(users_checkins, pois[['id', 'place_sid']], left_on='placeid', right_on='id', copy=False)
users_checkins['date'] = pd.to_datetime(users_checkins['datetime'])
users_checkins = users_checkins.drop('datetime', axis=1)
users_checkins.sort_values(by=['user_sid', 'date'], inplace=True)
users_checkins.tail()

Unnamed: 0,user_sid,userid,placeid,id,place_sid,date
3775505,93652,2679347,7511294,7511294,256353,2011-06-27 04:06:53+00:00
3775496,93652,2679347,7510411,7510411,256351,2011-06-28 15:33:48+00:00
3775504,93652,2679347,7511294,7511294,256353,2011-06-29 05:01:17+00:00
3775503,93652,2679347,7511294,7511294,256353,2011-06-30 04:10:25+00:00
3775495,93652,2679347,7510411,7510411,256351,2011-06-30 16:21:21+00:00


Podemos ver que hay un problema: Muchas veces se repiten los POI consecutivos de un usuario, sin embargo eso no aporta mucha información al embedding, por lo que los eliminaremos en el siguiente paso

In [29]:
users_checkins['last_place_sid'] = users_checkins['place_sid'].shift(1)
users_checkins = users_checkins[users_checkins.place_sid != users_checkins.last_place_sid]
users_checkins = users_checkins.drop('last_place_sid', axis=1)

In [30]:
users_checkins.tail()

Unnamed: 0,user_sid,userid,placeid,id,place_sid,date
3775497,93652,2679347,7510411,7510411,256351,2011-06-23 17:21:21+00:00
3775505,93652,2679347,7511294,7511294,256353,2011-06-27 04:06:53+00:00
3775496,93652,2679347,7510411,7510411,256351,2011-06-28 15:33:48+00:00
3775504,93652,2679347,7511294,7511294,256353,2011-06-29 05:01:17+00:00
3775495,93652,2679347,7510411,7510411,256351,2011-06-30 16:21:21+00:00


In [31]:
from collections import defaultdict

In [32]:
user_poi_seq = defaultdict(lambda: [])
user_date_seq = defaultdict(lambda: [])

In [33]:
for user_sid, place_sid, date in zip(users_checkins['user_sid'], users_checkins['place_sid'], users_checkins['date']):
    user_poi_seq[user_sid].append(place_sid)
    user_date_seq[user_sid].append(date)

In [34]:
vocab_size = len(pois)
emb_dim = 16

embeddings = nn.Embedding(vocab_size, emb_dim)
embeddings.load_state_dict(torch.load(f"emb-{vocab_size}-{emb_dim}D.pt"))

def get_user_embedding(poi_sequence):
    with torch.no_grad():
        return embeddings(torch.tensor(poi_sequence)).mean(dim=0)

def get_poi_embedding(poi_sid):
     with torch.no_grad():
        return embeddings(torch.tensor(poi_sid))

In [35]:
get_user_embedding([3, 4]).shape

torch.Size([16])

In [36]:
user_embeddings = {}

for user_sid, seq in user_poi_seq.items():
    user_embeddings[user_sid] = get_user_embedding(seq)

In [37]:
user_id_to_sid = {user_id: user_sid for user_id, user_sid in zip(users_checkins['userid'], users_checkins['user_sid'])}

In [38]:
user_friends = defaultdict(lambda: set())

for user1, user2 in zip(friends['userid1'], friends['userid2']):
    if user1 not in user_id_to_sid or user2 not in user_id_to_sid:
        continue

    user_friends[user_id_to_sid[user1]].add(user_id_to_sid[user2])
    user_friends[user_id_to_sid[user2]].add(user_id_to_sid[user1])

In [39]:
friends_embeddings = defaultdict(lambda: torch.zeros(emb_dim))

for user_sid, friends_sids in user_friends.items():
    friends_embeddings[user_sid] = torch.stack([user_embeddings[f] for f in friends_sids]).mean(dim=0)


mean_friend = torch.stack([v for v in friends_embeddings.values()]).mean(dim=0)
friends_embeddings.default_factory = lambda: mean_friend

Generamos secuencias de puntos de interes visitados por usuarios de un largo predefinido para entrenar el modelo de embeddings

In [40]:
from random import sample

poi_sequence_dataset = []

for user_sid, sequence in user_poi_seq.items():
    if len(sequence) < SEQUENCE_LENGTH: continue

    user_embedding = user_embeddings[user_sid]
    friends_embedding = friends_embeddings[user_sid]
    

    candidate_indexes = list(range(0, len(sequence) - SEQUENCE_LENGTH, SEQUENCE_LENGTH))

    n_sequences = min(len(candidate_indexes), MAX_SEQUENCES_PER_USER)
    start_indexes = sample(candidate_indexes, n_sequences)

    for idx in start_indexes:
        new_seq = [get_poi_embedding(poi) for poi in sequence[idx:idx + SEQUENCE_LENGTH]]
        poi_sequence_dataset.append((new_seq, user_embedding, friends_embedding))
    

In [41]:
# 

# rands = np.random.rand(len(poi_sequence_dataset))

# total = 0
# for idx, seq in enumerate(poi_sequence_dataset):
#   if idx % 100000 == 0:
#     print(idx, '/', len(poi_sequence_dataset))

#   index = int(rands[idx] *  len(pois_list))
#   total += 33991 in seq[-5:]

# print(total / len(poi_sequence_dataset))
  

In [42]:
def split_list(input, frac=0.5):
    split_index = int(len(input) * frac)
    return input[:split_index], input[split_index:]

## Split de Datos

Realizamos separacion en train / test / split de 80 / 10 / 10

In [43]:
train_poi_sequence, rest = split_list(poi_sequence_dataset, 0.8)
test_poi_sequence, val_poi_sequence = split_list(rest)

In [44]:
def split_history_target(sequences):
    history = [ (seq[:-1], user, friends) for (seq, user, friends) in sequences]
    targets = [ seq[-1] for (seq, user, friends) in sequences]
    return history, targets

In [45]:
train_seq_history, train_seq_target = split_history_target(train_poi_sequence)
test_seq_history, test_seq_target = split_history_target(test_poi_sequence)
val_seq_history, val_seq_target = split_history_target(val_poi_sequence)

In [46]:
unique_pois = { poi for (sequence, user, friend) in poi_sequence_dataset  for poi in sequence }

print("Total pois incluidos en dataset:", len(unique_pois))
print("Porcentaje de POIs que se usaran en el modelo",  len(unique_pois) * 100 / len(pois))

Total pois incluidos en dataset: 2664060
Porcentaje de POIs que se usaran en el modelo 1039.187080667811


In [47]:
print("Total Train", len(train_seq_history))
print("Total Test ", len(test_seq_history))
print("Total Val  ", len(val_seq_history))

Total Train 152232
Total Test  19029
Total Val   19029


Ya tenemos nuestros datos listos para entrenar!

In [48]:
train_seq_history[0]

([tensor([-1.0815,  1.7382, -0.5828,  0.3109, -0.6681,  0.3705,  0.2970,  0.1471,
          -1.0182,  1.3231, -0.3098, -0.3890,  0.0528,  2.0355, -0.3125,  0.2087]),
  tensor([ 1.3917,  0.8755,  0.6181, -0.1709,  0.2053,  0.7079,  0.0043,  0.9738,
           0.1164,  0.7789, -0.5652,  0.4827,  1.0263,  0.1525, -1.2215, -0.3930]),
  tensor([ 0.0656, -0.4870, -0.8642, -1.0535,  1.1900,  0.4123, -1.4416, -0.3157,
          -0.1253,  0.7605, -0.5281,  2.8620, -0.6858, -1.0389, -0.3812, -0.8714]),
  tensor([ 1.3917,  0.8755,  0.6181, -0.1709,  0.2053,  0.7079,  0.0043,  0.9738,
           0.1164,  0.7789, -0.5652,  0.4827,  1.0263,  0.1525, -1.2215, -0.3930]),
  tensor([ 0.7794,  1.2590, -1.2216, -0.0528,  0.2550,  1.0299,  0.0704,  1.2135,
          -0.3993,  2.1099,  0.3683, -0.1485,  0.1717,  0.7308, -0.2099, -1.3154]),
  tensor([-0.8216,  0.5673, -0.6387,  0.6098,  0.5859, -1.5497,  0.3307,  0.9557,
          -0.7622, -1.4526, -0.2503, -0.6267,  0.1771,  1.4767, -0.0459,  2.2959]),
  te

# Modelo

## Embeddings

In [49]:
class EmbeddingModel(nn.Module):
    def __init__(self, vocab_size=None, emb_dim=None, hidden_dim=None, sample_length=None):
        super(EmbeddingModel, self).__init__()
        
        self.emb = nn.Embedding(vocab_size, emb_dim)
        self.hidden = nn.Linear(sample_length * emb_dim, hidden_dim)
        self.hidden_activation = nn.ReLU()
        self.output = nn.Linear(hidden_dim, vocab_size)
        self.output_activation = nn.LogSoftmax(dim=-1)

    def forward(self, xs):
        batch_size = xs.size()[0]

        # embed and merge
        xs = self.emb(xs)
        xs = torch.reshape(xs, (batch_size, -1))

        # hidden layer
        hidden = self.hidden(xs)
        hidden = self.hidden_activation(hidden)

        # output log probabilities
        output_logits = self.output(hidden)
        output_log_probs = self.output_activation(output_logits)
        
        return output_log_probs

    def predict(self, xs):
        return torch.argmax(self.forward(xs))

## BiLSTM-Attention

In [62]:
class BiLSTM_Attention(nn.Module):
    def __init__(self, emb_dim=None, hidden_dim=None):
        super(BiLSTM_Attention, self).__init__()

        self.emb_dim = emb_dim
        self.hidden_dim = hidden_dim
        self.lstm = nn.LSTM(self.emb_dim, self.hidden_dim, bidirectional=True)

    # lstm_output : [batch_size, n_step, self.hidden_dim * num_directions(=2)], F matrix
    def attention_net(self, lstm_output, final_state):
        hidden = final_state.view(-1, self.hidden_dim * 2, 1)   # hidden : [batch_size, self.hidden_dim * num_directions(=2), 1(=n_layer)]
        attn_weights = torch.bmm(lstm_output, hidden).squeeze(2) # attn_weights : [batch_size, n_step]
        soft_attn_weights = torch.nn.functional.softmax(attn_weights, 1)
        
        context = torch.bmm(lstm_output.transpose(1, 2), soft_attn_weights.unsqueeze(2)).squeeze(2)
        return context, None # context : [batch_size, self.hidden_dim * num_directions(=2)]

    def forward(self, X):
        # input : [batch_size, len_seq, embedding_dim]
        input = X.permute(1, 0, 2) # input : [len_seq, batch_size, embedding_dim]

        hidden_state = torch.autograd.Variable(torch.zeros(1*2, len(X), self.hidden_dim)).to(device) # [num_layers(=1) * num_directions(=2), batch_size, self.hidden_dim]
        cell_state = torch.autograd.Variable(torch.zeros(1*2, len(X), self.hidden_dim)).to(device) # [num_layers(=1) * num_directions(=2), batch_size, self.hidden_dim]



        # final_hidden_state, final_cell_state : [num_layers(=1) * num_directions(=2), batch_size, self.hidden_dim]
        output, (final_hidden_state, final_cell_state) = self.lstm(input, (hidden_state, cell_state))
        output = output.permute(1, 0, 2) # output : [batch_size, len_seq, self.hidden_dim]
        attn_output, attention = self.attention_net(output, final_hidden_state)
        return attn_output # [batch_size, hidden_dim * 2]

# Modelo Completo

In [63]:
class Model(nn.Module):
    def __init__(self, vocab_size=None, emb_dim=None, bilstm_hidden=None):
        super(Model, self).__init__()

        self.vocab_size = vocab_size
        self.emb_dim = emb_dim
        self.bilstm_hidden = bilstm_hidden

        self.lstm = BiLSTM_Attention(emb_dim=self.emb_dim, hidden_dim=self.bilstm_hidden).to(device)
        self.location_fc = nn.Linear(self.bilstm_hidden * 2 + self.emb_dim * 2, self.emb_dim)
        self.fc = nn.Linear(self.bilstm_hidden * 2 + self.emb_dim * 2, 2)
        self.relu = torch.nn.ReLU()

    def forward(self, X, user_emb, friends_emb):
        # input : [batch_size, len_seq]
        batch_size = X.size()[0]
        
        x = self.lstm(X) # [batch_size, hidden_dim * 2]
        x = torch.cat((x, user_emb, friends_emb), 1)
        
        emb_out = self.relu(self.location_fc(x))
        geo_out = self.relu(self.fc(x))
        
        return emb_out, geo_out

# Entrenamiento

In [64]:
from torch.utils.data import DataLoader

def dataset_to_tensors(sequences, targets):
    return [[(torch.stack(seq), user, friend), y] for (seq, user, friend), y in zip(sequences, targets)]


train_tensors = dataset_to_tensors(train_seq_history, train_seq_target)
train_dataloader = DataLoader(train_tensors, batch_size=BATCH_SIZE, shuffle=True)

test_tensors = dataset_to_tensors(test_seq_history, test_seq_target)
test_dataloader = DataLoader(test_tensors, batch_size=BATCH_SIZE, shuffle=True)

val_tensors = dataset_to_tensors(val_seq_history, val_seq_target)
val_dataloader = DataLoader(val_tensors, batch_size=BATCH_SIZE, shuffle=True)

In [65]:
class EarlyStopper:
    def __init__(self, skip_first_n=2, patience=5, min_delta=0.05):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = np.inf
        self.total_counter = 0
        self.skip_first_n = skip_first_n

    def early_stop(self, validation_loss):
        self.total_counter += 1
        if self.total_counter <= self.skip_first_n: return

        if validation_loss < self.min_validation_loss - self.min_delta:
            self.min_validation_loss = validation_loss
            self.counter = 0
            return False

        self.counter += 1
        return self.counter >= self.patience

In [75]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class Trainer:
    BASE_LR = 1e-2
    EPOCHS = 100
    PRINT_EVERY = 100
    VAL_EVERY = 500

    def __init__(self, vocab_size, emb_dim, bilstm_hidden, train_dataloader, val_dataloader):
        self.vocab_size = vocab_size
        self.emb_dim = emb_dim
        self.bilstm_hidden = bilstm_hidden
        self.model = Model(vocab_size=self.vocab_size, emb_dim=self.emb_dim, bilstm_hidden=self.bilstm_hidden).to(device)

        self.optimizer = optim.Adam(self.model.parameters(), amsgrad=True, lr=self.BASE_LR)
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, factor=0.1, patience=2, verbose=True)
        self.stopper = EarlyStopper(patience=30, min_delta=0.01)
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader

    def _location_loss(self, y_pred, y_true):
        return torch.nn.functional.cosine_similarity(y_pred, y_true).abs().mean()
    
    def train(self):
        run = wandb.init(project="proyecto-recsys",
            name=f"model-{self.emb_dim}-hidden-{self.bilstm_hidden}",
            tags=['model'],
            config={
                "vocab_size": self.vocab_size,
                "embedding_dim": self.emb_dim,
                "bilstm_hidden": self.bilstm_hidden,
                "epochs": self.EPOCHS
            })

        wandb.watch(self.model)

        self.model.to(device)

        for epoch in range(self.EPOCHS):    
            stopped = self._train_epoch(epoch)
            
            if stopped:
                print(f"Early stopping at epoch {epoch}")
                break

        run.finish()
        return self.model

    def _train_epoch(self, epoch):
        print(f"Training model on epoch {epoch}")
        i = 1
        losses = []
        for (seqs, users, friends), ys in self.train_dataloader:
            self.optimizer.zero_grad()
            
            seqs = seqs.to(device)
            users = users.to(device)
            friends = friends.to(device)
            ys = ys.to(device)
            emb_out, loc_out = self.model(seqs, users, friends)
        
            output = self._location_loss(emb_out, ys)
            output.backward()
            self.optimizer.step()

            losses.append(output.item())

            if i % self.PRINT_EVERY == 0:
                avg_loss = sum(losses) / len(losses)
                losses = []
                wandb.log({"train_loss": avg_loss, "epoch": epoch, "step": i})
            
            i += 1

        print("\nEvaluating model on val set ...")
        self.model.eval()

        with torch.no_grad():
            val_iter = iter(self.val_dataloader)

            val_losses = []            
            for (seqs, users, friends), ys in tqdm.tqdm(val_iter, total=len(val_iter)):
                seqs = seqs.to(device)
                users = users.to(device)
                friends = friends.to(device)
                ys = ys.to(device)
                emb_out, loc_out = self.model(seqs, users, friends)
                output = self._location_loss(emb_out, ys)
                val_losses.append(output.item())

        avg_val_loss = sum(val_losses) / len(val_losses)
        print(avg_loss)

        wandb.log({"val_loss": avg_val_loss, "epoch": epoch, "step": i, 
                    "lr" : self.optimizer.param_groups[0]['lr']})
        
        self.scheduler.step(avg_val_loss)
        stop = self.stopper.early_stop(avg_val_loss)
        if stop: return True
        
        self.model.train()


### Parameters

In [None]:
for hidden_dim in [8, 16, 32, 64, 128, 256]:
    trainer = Trainer(vocab_size=len(pois), emb_dim=emb_dim, bilstm_hidden=8,
                    train_dataloader=train_dataloader, val_dataloader=val_dataloader)
    trainer.train()

In [89]:
from scipy.spatial import KDTree

In [82]:
vector_to_poi_sid = { get_poi_embedding(sid) : sid for sid in pois['place_sid'] }

In [93]:
pois_vectors = torch.stack(tuple(vector_to_poi_sid.keys()))
tree = KDTree(pois_vectors)

In [159]:
def model_accuracy(model, test_dataloader):
    model = model.to(device)
    model = model.eval()
    total = 0
    correct = 0
    for (seqs, users, friends), ys in tqdm.tqdm(test_dataloader):
        seqs = seqs.to(device)
        users = users.to(device)
        friends = friends.to(device)
        ys = ys.detach().cpu()
        with torch.no_grad():
            emb_out, loc_out = model(seqs, users, friends)

        for idx, vect in enumerate(emb_out.detach().cpu()):
            _, index = tree.query(vect)
            closest = pois_vectors[index]
            correct += torch.equal(closest,  ys[idx])

        total += len(ys)

    return correct / total

In [160]:
model_accuracy(q.model, test_dataloader)

100%|█████████████████████████████████████████████████████████████████████████████| 298/298 [02:31<00:00,  1.96it/s]


0.0

Lamentablemente, nuestro modelo no muestra buen performance. Esto se puede deber a que la implementación se desvió del modelo implementado por los autores originalmente, y se tomaron soluciones suboptimas al momento de diseñar.