In [1]:
from src.Common import parse_cmd_args, print_b
from src.datasets.text_datasets.RestaurantDataset import RestaurantDataset
from src.datasets.text_datasets.AmazonDataset import AmazonDataset
from src.datasets.text_datasets.POIDataset import POIDataset

from src.models.text_models.ATT2ITM import ATT2ITM

from bokeh.plotting import ColumnDataSource, figure, output_file, save, show
from bokeh.models import LinearColorMapper, Span, transforms
from sklearn.manifold import TSNE

import tensorflow as tf
import pandas as pd
import numpy as np
import nvgpu
import json

from bokeh.layouts import gridplot
from bokeh.resources import INLINE
from bokeh.io import output_notebook, export_svg
# Direct Bokeh output to the notebook, using the INLINE resources object imported above.
output_notebook(INLINE)

2023-05-24 12:56:41.264976: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE3 SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
  from .autonotebook import tqdm as notebook_tqdm

TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://github.com/tensorflow/addons/issues/2807 



In [3]:
# Pick the GPU that currently reports the lowest memory usage percentage.
gpu = int(np.argmin([info["mem_used_percent"] for info in nvgpu.gpu_info()]))

# Experiment selection. The dataset/subset strings are normalised (lowercase, no spaces)
# before being compared against the columns of best_models.csv.
model_name = "ATT2ITM"
dataset = "restaurants".lower().replace(" ", "")
subset = "newyorkcity".lower().replace(" ", "")

# Look up the md5 identifying the best stored model for this (dataset, subset, model) triple.
best_models = pd.read_csv("models/best_models.csv")
selected_md5 = best_models.loc[
    (best_models["dataset"] == dataset)
    & (best_models["subset"] == subset)
    & (best_models["model"] == model_name),
    "model_md5",
]
if selected_md5.empty:
    # Same exception type that `.values[0]` would raise on an empty selection,
    # but with a message that says what is missing.
    raise IndexError(
        f"No entry in models/best_models.csv for dataset='{dataset}', subset='{subset}', model='{model_name}'"
    )
best_model = selected_md5.values[0]
model_path = f"models/{model_name}/{dataset}/{subset}/{best_model}"

# The stored configuration is read once; both the dataset and the model settings come from it.
with open(f"{model_path}/cfg.json") as f:
    model_config = json.load(f)
dts_cfg = model_config["dataset_config"]
mdl_cfg = {"model": model_config["model"], "session": {"gpu": gpu, "mixed_precision": False, "in_md5": False}}

print_b(f"Loading best model: {best_model}")

# Map each supported dataset key to the class that loads it.
dataset_classes = {
    "restaurants": RestaurantDataset,
    "pois": POIDataset,
    "amazon": AmazonDataset,
}
if dataset not in dataset_classes:
    raise ValueError(f"Unknown dataset '{dataset}'; expected one of {sorted(dataset_classes)}")
text_dataset = dataset_classes[dataset](dts_cfg)

model = ATT2ITM(mdl_cfg, text_dataset)
# Per the recorded output ("Model already trained. Loading weights..."), this call
# loads the stored weights when the model has already been trained.
model.train(dev=True, save_model=True)

[94mLoading best model: 8b74c00371d98f236fb265dd46b234c4[0m


[nltk_data] Downloading package stopwords to /home/pperez/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


Model: "ATT2ITM_0"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_3 (InputLayer)           [(None, 200)]        0           []                               
                                                                                                  
 input_4 (InputLayer)           [(None, 1985)]       0           []                               
                                                                                                  
 embedding_1 (Embedding)        (None, 200, 384)     4634112     ['input_3[0][0]']                
                                                                                                  
 in_rsts (Embedding)            (None, 1985, 384)    762240      ['input_4[0][0]']                
                                                                                          

https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


None
[92m[INFO] Model already trained. Loading weights...[0m


In [4]:
# Sub-models that expose the outputs of the "word_emb" and "rest_emb" layers,
# fed from the first and second inputs of the trained network respectively.
word_emb_model = tf.keras.models.Model(
    inputs=[model.MODEL.input[0]],
    outputs=[model.MODEL.get_layer("word_emb").output],
)
rest_emb_model = tf.keras.models.Model(
    inputs=[model.MODEL.input[1]],
    outputs=[model.MODEL.get_layer("rest_emb").output],
)

# Restaurant embeddings: every item id is passed in a single row and the singleton axes are squeezed away.
all_item_ids = list(range(model.DATASET.DATA["N_ITEMS"]))
rst_embs = rest_emb_model.predict([all_item_ids], verbose=0).squeeze()

# Restaurant names ordered by id_item.
# NOTE(review): assumes one distinct name per id_item so that list positions line up
# with the embedding rows — confirm against the dataset.
item_name_pairs = model.DATASET.DATA["TRAIN_DEV"][["id_item", "name"]].sort_values("id_item").drop_duplicates()
rest_names = item_name_pairs.name.values.tolist()

# Vocabulary labels: "UNK" is placed at position 0, followed by the WORD_INDEX keys in their stored order.
# NOTE(review): presumably WORD_INDEX values start at 1 and follow key order — confirm.
vocabulary = list(model.DATASET.DATA["WORD_INDEX"].keys())
word_names = np.array(["UNK"] + vocabulary)

# Word embeddings: one token id per sample, squeezed in the same way.
all_word_ids = list(range(model.DATASET.DATA["VOCAB_SIZE"]))
wrd_embs = word_emb_model.predict(all_word_ids, verbose=0).squeeze()



In [5]:
# Both embedding tables are projected to 2-D with identically configured t-SNE estimators
# (cosine metric, PCA initialisation, fixed random_state).
tsne_kwargs = dict(n_components=2, learning_rate="auto", init="pca", metric="cosine", random_state=2032)

tsne_r = TSNE(**tsne_kwargs)
rst_tsne = tsne_r.fit_transform(rst_embs)

tsne_w = TSNE(**tsne_kwargs)
wrd_tsne = tsne_w.fit_transform(wrd_embs)

In [64]:
from scipy.spatial.distance import cdist
import bokeh

plot_size = 200
max_cols = 4
TOOLTIPS = [("Name", "@desc"), ("Color", "@col")]

# Colour every vocabulary word by its cosine distance to a reference word.
# Other reference words tried previously: "fresh", "cheap", "hotdog", "tacos", "pasta", "burger".
words = ["croissant", "sushi", "i", "pizza"]
plots = []

for word in words:
    matches = np.flatnonzero(word_names == word)
    if matches.size == 0:
        # Same exception type as the previous bare indexing failure, now with an explanation.
        raise IndexError(f"Word '{word}' is not in the vocabulary")
    word_id = matches[0]

    # Cosine distance from the reference word to every word, min-max scaled to [0, 1].
    word_colors = cdist([wrd_embs[word_id]], wrd_embs, metric="cosine")[0]
    word_colors = (word_colors - word_colors.min()) / (word_colors.max() - word_colors.min())
    # Closer words are more opaque. The distances are already scaled, so no second normalisation is needed.
    word_alpha = 1 - word_colors

    # Positions 1-3 of the ascending ranking; position 0 is skipped
    # (presumably the reference word itself, at distance 0).
    print(word_names[word_id], word_names[np.argsort(word_colors)][1:4])

    # Sort by decreasing distance so the closest points are drawn last and are not hidden by overlap.
    sb = np.argsort(-word_colors)
    source_w = ColumnDataSource(data=dict(
        x=wrd_tsne[:, 0][sb],
        y=wrd_tsne[:, 1][sb],
        desc=word_names[sb],
        col=word_colors[sb],
        alpha=word_alpha[sb],
    ))

    lc = LinearColorMapper(palette=bokeh.palettes.OrRd[9], low=word_colors.min(), high=word_colors.max())
    fig = figure(width=plot_size, height=plot_size, tooltips=TOOLTIPS, title=word.title(), output_backend="svg")
    fig.scatter('x', 'y', size=5, source=source_w, line_color=None, fill_color={"field": "col", "transform": lc}, fill_alpha="alpha")
    fig.axis.visible = False
    plots.append(fig)

# Arrange the figures in a grid with at most `max_cols` figures per row.
plot_rows = [plots[i:i + max_cols] for i in range(0, len(plots), max_cols)]
p = gridplot(plot_rows)
show(p)

# Uncomment to also write the grid to an SVG file.
#export_svg(p, filename=f"tsne_words.svg")

croissant ['patisserie' 'latte' 'starbuck']
sushi ['sushis' 'japanese' 'nigiri']
i ['my' 'there' 'that']
pizza ['pepperoni' 'margherita' 'kitchenette']


['tsne_words.svg']

In [180]:
# Positions 1-3 of the ascending-distance ranking for the last word handled by the
# plotting loop (`word_colors` is the variable left over from its final iteration).
nearest_word_ids = np.argsort(word_colors)[1:4]
word_names[nearest_word_ids]

array(['margherita', 'pizzeria', 'pepperoni'], dtype='<U22')

In [8]:
from scipy.spatial.distance import cdist
import bokeh

plot_size = 200
max_cols = 4
TOOLTIPS = [("Name", "[@id] @desc"), ("Color", "@col")]

# Colour every restaurant by its cosine distance to a reference restaurant.
# The ids are positions in `rest_names`. In the recorded run they printed as
# 99 Cent Fresh Pizza, Planet Hollywood, Starbucks, Wolfgang's Steakhouse and Pastrami Queen.
# NOTE(review): ids are specific to this dataset/subset — re-check them after changing either.
items = [865, 418, 728, 216, 510]
plots = []

# Loop-invariant arrays, built once instead of on every iteration.
rest_names_arr = np.array(rest_names)
all_item_ids = np.arange(len(rest_names))

for item_id in items:
    # Cosine distance from the reference restaurant to every restaurant, min-max scaled to [0, 1].
    item_colors = cdist([rst_embs[item_id]], rst_embs, metric="cosine")[0]
    item_colors = (item_colors - item_colors.min()) / (item_colors.max() - item_colors.min())
    # Closer restaurants are more opaque. The distances are already scaled, so no second normalisation is needed.
    item_alpha = 1 - item_colors

    # Positions 1-9 of the de-duplicated ascending ranking; position 0 is skipped
    # (presumably the reference restaurant's own name).
    print(rest_names[item_id], pd.unique(rest_names_arr[np.argsort(item_colors)])[1:10])

    # Sort by decreasing distance so the closest points are drawn last and are not hidden by overlap.
    sb = np.argsort(-item_colors)
    source_r = ColumnDataSource(data=dict(
        id=all_item_ids[sb],
        x=rst_tsne[:, 0][sb],
        y=rst_tsne[:, 1][sb],
        desc=rest_names_arr[sb],
        col=item_colors[sb],
        alpha=item_alpha[sb],
    ))

    lc = LinearColorMapper(palette=bokeh.palettes.OrRd[9], low=item_colors.min(), high=item_colors.max())
    fig = figure(width=plot_size, height=plot_size, tooltips=TOOLTIPS, title=f"{rest_names[item_id]}", output_backend="svg")
    fig.scatter('x', 'y', size=5, source=source_r, line_color=None, fill_color={"field": "col", "transform": lc}, fill_alpha="alpha")
    fig.axis.visible = False
    plots.append(fig)

# Arrange the figures in a grid with at most `max_cols` figures per row.
plot_rows = [plots[i:i + max_cols] for i in range(0, len(plots), max_cols)]
p = gridplot(plot_rows)

show(p)
# Uncomment to also write the grid to an SVG file.
#export_svg(p, filename=f"tsne_items.svg")

99 Cent Fresh Pizza ['99 Cents Fresh Pizza' '2 Bros Pizza' "Mamoun's Falafel" 'Gotham Pizza'
 "Joe's Pizza" '53rd & 6th Halal' "Big Nick's Pizza Joint"
 'Little Italy Pizza' 'Sacco Pizza']
Planet Hollywood ['Buca di Beppo Italian Restaurant' 'Hard Rock Cafe' "Applebee's"
 'TGI Friday’s' "TGI Friday's" 'TGI Fridays' 'Ruby Tuesday' 'Ihop'
 'Starbucks']
Starbucks ["McDonald's" 'Stumptown Coffee Roasters' "Dunkin' Donuts" 'Grom'
 'Ground Central Coffee Company' 'Caffe Roma Pastry' 'Amorino'
 'Le Pain Quotidien' 'Magnolia Bakery']
Wolfgang's Steakhouse ["Wolfgang's Steakhouse--Midtown 54th Street"
 "Wolfgang's Steakhouse - Tribeca" "Bobby Van's Steakhouse - 54th Street"
 'Benjamin Steakhouse' 'Novita' "Bobby Van's Steak House"
 'Benjamin Steakhouse Prime' "Michael Jordan's The Steak House N.Y.C."
 'Keens Steakhouse']
Pastrami Queen ['JG Melon' "Luigi's Gourmet Pizza"
 '2nd Avenue Deli and 2nd Floor Bar & Essen' '2nd Avenue Deli' 'ruchi'
 'JG Melons' "Vanessa's Dumpling House" 'Madangsui Kor

In [None]:
# Positions 1-3 of the de-duplicated ascending-distance name ranking for the last restaurant
# handled by the plotting loop (`item_colors` is the variable left over from its final iteration).
names_by_distance = np.array(rest_names)[np.argsort(item_colors)]
pd.unique(names_by_distance)[1:4]

In [6]:
# List (position, name) pairs for every restaurant whose name contains the query text.
name_query = "Pastrami Queen"
[(position, name) for position, name in enumerate(rest_names) if name_query in name]

[(510, 'Pastrami Queen')]