<a href="https://colab.research.google.com/github/pgosar/AlphaHacks/blob/main/BrandEmbedding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install fire
!pip install wikipedia

In [None]:
import io
import os
import shutil
import re
import string
import tensorflow as tf
import numpy as np

import logging
import wikipedia
import random

import nltk
from nltk.corpus import stopwords

from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Embedding, GlobalAveragePooling1D
from tensorflow.keras.layers.experimental.preprocessing import TextVectorization

In [None]:
# Download the GloVe 6B archive and extract it into the current working
# directory (cache_dir='.', cache_subdir='').
# NOTE: get_file returns the local path of the downloaded archive, so
# `glove_embeddings` is a file path, not the word vectors themselves.
glove_url = "http://nlp.stanford.edu/data/glove.6B.zip"
glove_embeddings = tf.keras.utils.get_file(
    fname="glove.6B.zip",
    origin=glove_url,
    extract=True,
    cache_dir='.',
    cache_subdir='',
)

Downloading data from http://nlp.stanford.edu/data/glove.6B.zip


In [None]:
# Fetch NLTK's stop-word lists; IGNORE_WORDS is built from them below.
nltk.download('stopwords')

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.


True

In [None]:
# 200-dimensional vectors from the extracted GloVe 6B archive.
# NOTE(review): assumes the archive was extracted into /content (Colab's usual
# working directory) -- TODO confirm when running outside Colab.
GLOVE_PATH = "/content/glove.6B.200d.txt"

In [None]:
# NOTE(review): duplicate of the earlier stop-word download cell; redundant on a
# straight top-to-bottom run and can be removed.
nltk.download('stopwords')

In [None]:
class GloveEmbeddings:
    """Loader for pre-trained GloVe word vectors stored as plain text.

    Each line of the vectors file is ``<word> <v1> <v2> ... <vN>``, separated by
    whitespace.
    """

    # Default vectors file. It may contain a ``{dim}`` placeholder, which is
    # filled from ``embedding_dim`` when the file is opened.
    GLOVE_DIR = GLOVE_PATH
    # Dimensionality of the default vectors file (glove.6B.200d).
    EMBEDDING_DIM = 200

    @staticmethod
    def get_dict_word_embedding(path=GLOVE_DIR, embedding_dim=EMBEDDING_DIM):
        """Read a GloVe text file into a ``{word: vector}`` dict.

        Args:
            path: Path to the vectors file. ``path.format(dim=embedding_dim)`` is
                applied first, so a template such as ``"glove.6B.{dim}d.txt"``
                selects the file by dimension. A path without ``{dim}`` is used
                unchanged.
            embedding_dim: Value substituted for ``{dim}`` in ``path``.

        Returns:
            dict mapping each word (first field of a line) to a 1-D ``float32``
            numpy array built from the remaining fields.
        """
        word2emb = dict()
        # ``with`` guarantees the handle is closed even if parsing raises. The
        # encoding is explicit because GloVe text files are expected to be UTF-8,
        # and the platform default encoding may differ.
        with open(path.format(dim=embedding_dim), encoding='utf-8') as f:
            for line in f:
                values = line.split()
                if not values:
                    # Tolerate blank lines instead of failing with IndexError.
                    continue
                word = values[0]
                word2emb[word] = np.asarray(values[1:], dtype='float32')
        return word2emb

In [None]:
# Sanity check that the GloVe file loads. Show the vocabulary size instead of
# dumping every word vector into the cell output. The dict is discarded
# afterwards; build() loads its own copy.
len(GloveEmbeddings.get_dict_word_embedding())

In [None]:
# Stop words skipped when averaging token vectors in build().
# Use English only: stopwords.words() without a language returns the union of
# every language NLTK ships. That would also drop ordinary English words that
# happen to be stop words elsewhere (e.g. German "die").
IGNORE_WORDS = set(stopwords.words('english'))

In [None]:
# Brands to embed: one Wikipedia page title per line; blank lines are ignored.
_brand_list_fpath = "/content/data/brand_list2.txt"
DEFAULT_SET_BRANDS = set()
# Explicit UTF-8 so non-ASCII brand names are read the same on every platform.
with open(_brand_list_fpath, encoding='utf-8') as fp:
    DEFAULT_SET_BRANDS.update(name for name in (line.strip() for line in fp) if name)

In [None]:
# Where build() writes the brand -> embedding JSON, and where query()/query_list()
# read it by default. The path is relative to the current working directory.
DEFAULT_BRAND_EMB_SAVE_FPATH = 'data/brand_emb.json'
# NOTE(review): not referenced by any cell shown here -- GloveEmbeddings reads
# GLOVE_PATH instead. Remove it or wire it up.
ENV_EMBEDDING_GLOVE_6B_FPATH = 'data/glove_embeddings/glove.6B.200d.txt'

In [None]:
# Display the brand list sorted; raw set order of str varies between runs
# (hash randomization), so the cell output would not be reproducible.
sorted(DEFAULT_SET_BRANDS)

In [None]:
# Module logger used by build()/query().
# NOTE(review): no handler or level is configured in the cells shown, so
# logger.info/debug output is likely not displayed. Call e.g.
# logging.basicConfig(level=logging.INFO) to see build() progress.
logger = logging.getLogger(__name__)

In [None]:
import fire
import json
import codecs

from string import punctuation

In [None]:
def _fetch_brand_page(brand_name, max_disambiguation_tries):
    """Fetch the Wikipedia page for ``brand_name``.

    Returns ``(page, disambiguated)``. If the title is a disambiguation page,
    its listed options are tried in order, up to ``max_disambiguation_tries``
    of them, and the first one that loads is returned. ``page`` is ``None``
    when nothing could be loaded.
    """
    try:
        return wikipedia.page(brand_name, auto_suggest=False), False
    except wikipedia.PageError:
        return None, False
    except wikipedia.DisambiguationError as e:
        options = list(e.options)

    print("Disambiguation for ", brand_name, ", options: ", options)
    # Options are taken in listed order (a random pick made the knowledge base
    # differ from run to run). A fallback title can itself be ambiguous or
    # missing, so keep trying instead of aborting the whole build.
    for option in options[:max_disambiguation_tries]:
        try:
            page = wikipedia.page(option, auto_suggest=False)
        except (wikipedia.DisambiguationError, wikipedia.PageError):
            continue
        print("Using ", option, " for ", brand_name)
        return page, True
    return None, True


def _mean_token_embedding(text, wrd2emb, set_ignore_words):
    """Average the word vectors of the tokens in ``text``.

    Tokens are split on whitespace, lower-cased, and stripped of ASCII
    punctuation. Tokens in ``set_ignore_words`` or missing from ``wrd2emb``
    are skipped. Returns ``None`` if no token has a vector, because averaging
    an empty array yields NaN rather than a vector.
    """
    vectors = []
    for raw_token in text.split():
        token = raw_token.lower().strip(punctuation)
        if token in set_ignore_words:
            continue
        emb = wrd2emb.get(token)
        if emb is not None:
            vectors.append(emb)
    if not vectors:
        return None
    return np.array(vectors).mean(axis=0)


def build(set_brands=DEFAULT_SET_BRANDS, fpath_save=DEFAULT_BRAND_EMB_SAVE_FPATH,
          set_ignore_words=IGNORE_WORDS, max_disambiguation_tries=3):
    """Build the brand knowledge base: one averaged GloVe vector per brand.

    For each brand, its Wikipedia page text is fetched and the GloVe vectors of
    its known, non-stop-word tokens are averaged. The resulting
    ``{brand_name: vector_as_list}`` mapping is written to ``fpath_save`` as JSON.

    Args:
        set_brands: Iterable of brand names (Wikipedia page titles).
        fpath_save: Output JSON path. Its parent directory is created if missing.
        set_ignore_words: Tokens to skip when averaging.
        max_disambiguation_tries: How many options of a disambiguation page to
            try before giving up on a brand.

    Brands whose page cannot be resolved, or whose text contains no known
    token, are skipped and listed in the printed summary.
    """
    skipped_brands = []
    disamb_brands = []

    logger.info("building knowledge base")
    dict_brand_name_emb = dict()

    wrd2emb = GloveEmbeddings.get_dict_word_embedding()

    # Sorted so the processing order (and output file order) is reproducible.
    for brand_name in sorted(set_brands):
        logger.info("New loop with {}".format(brand_name))

        wiki_obj, disambiguated = _fetch_brand_page(brand_name, max_disambiguation_tries)
        if disambiguated:
            disamb_brands.append(brand_name)
        if wiki_obj is None:
            skipped_brands.append(brand_name)
            print("Page error with ", brand_name)
            continue

        logger.info("{brand_name}: {wiki_url}".format(brand_name=brand_name, wiki_url=wiki_obj.url))

        brand_emb = _mean_token_embedding(wiki_obj.content, wrd2emb, set_ignore_words)
        if brand_emb is None:
            skipped_brands.append(brand_name)
            print("No GloVe tokens found for ", brand_name)
            continue

        dict_brand_name_emb[brand_name] = brand_emb.tolist()

    if disamb_brands:
        print("Resolved via disambiguation: ", disamb_brands)
    if skipped_brands:
        print("Skipped brands: ", skipped_brands)

    # Create the output directory (e.g. ``data/``) if it does not exist yet.
    save_dir = os.path.dirname(fpath_save)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    logger.info("saving knowledge base to: `{}`".format(fpath_save))
    with codecs.open(fpath_save, 'w', encoding='utf-8') as fp:
        json.dump(dict_brand_name_emb, fp, separators=(',', ':'), indent=4)

    logger.info("knowledge base compiled")
    print("knowledge base compiled")



In [None]:
# Build and save the knowledge base for DEFAULT_SET_BRANDS.
# This fetches one Wikipedia page per brand, so it is network-bound.
build()

knowledge base compiled


In [None]:
# TEST CODE FOR WIKI API: check that every brand in the list resolves to a
# Wikipedia page. On a disambiguation page, fall back to its first option.
for brand_name in sorted(DEFAULT_SET_BRANDS):
    logger.info("New loop with {}".format(brand_name))

    try:
        wiki_obj = wikipedia.page(brand_name, auto_suggest=False)
    except wikipedia.DisambiguationError as e:
        s = e.options[0]
        print("EXCEPTION TRIGGERED, WITH: ", e.options)
        print("EXCEPTION TRIGGERED, TRYING WITH: ", s)
        try:
            wiki_obj = wikipedia.page(s, auto_suggest=False)
        except (wikipedia.DisambiguationError, wikipedia.PageError):
            # The fallback can itself be ambiguous or missing; skip the brand.
            print("Fallback failed for ", brand_name)
            continue
    except wikipedia.PageError:
        print("Page error with ", brand_name)
        continue

    # Reached for every resolved brand (previously unreachable after `continue`).
    logger.info("{brand_name}: {wiki_url}".format(brand_name=brand_name, wiki_url=wiki_obj.url))
    text = wiki_obj.content

In [None]:
# Size of the saved knowledge base in MB. Use the same path build() writes to,
# rather than a hardcoded absolute path.
os.path.getsize(DEFAULT_BRAND_EMB_SAVE_FPATH) / 1000000

0.23748

In [None]:
# Scratch check of the wikipedia API: the title "501" resolves to the
# calendar-year page (see output).
# NOTE(review): .encode("ascii", "ignore") passes bytes and drops non-ASCII
# characters -- presumably probing how ASCII-stripped brand names resolve; confirm intent.
wikipedia.summary("501".encode("ascii", "ignore"), auto_suggest = False)

'Year 501 (DI) was a common year starting on Monday (link will display the full calendar) of the Julian calendar. At the time, it was known as the Year of the Consulship of Avienus and Pompeius (or, less frequently, year 1254 Ab urbe condita). The denomination 501 for this year has been used since the early medieval period, when the Anno Domini calendar era became the prevalent method in Europe for naming years.\n\n'

In [None]:
import operator
def query(target_brand_name, top_n=None, kb_fpath=DEFAULT_BRAND_EMB_SAVE_FPATH, dict_kb=None):
    """Rank the other brands in the knowledge base by distance to one brand.

    Args:
        target_brand_name: Key of the brand in the knowledge base, as written by
            ``build``.
        top_n: If not ``None``, keep only the ``top_n`` closest brands
            (``0`` yields an empty list). ``None`` keeps all of them.
        kb_fpath: JSON knowledge base to load when ``dict_kb`` is not supplied.
        dict_kb: Pre-loaded ``{brand_name: embedding}`` mapping. Passing it
            avoids re-reading ``kb_fpath`` on every call.

    Returns:
        List of ``(brand_name, euclidean_distance)`` tuples, closest first. The
        target brand itself is excluded.

    Raises:
        KeyError: if ``target_brand_name`` is not in the knowledge base.
    """
    if dict_kb is None:
        with codecs.open(kb_fpath, encoding='utf-8') as fp:
            dict_kb = json.load(fp)

    if target_brand_name not in dict_kb:
        raise KeyError("brand {!r} not found in knowledge base".format(target_brand_name))
    target_brand_emb = np.array(dict_kb[target_brand_name])

    dict_brand_name_emb_distance = {
        candidate_brand_name: np.linalg.norm(target_brand_emb - np.array(candidate_emb))
        for candidate_brand_name, candidate_emb in dict_kb.items()
        if candidate_brand_name != target_brand_name
    }

    sorted_dict = sorted(dict_brand_name_emb_distance.items(), key=operator.itemgetter(1))

    # `is not None` rather than truthiness, so top_n=0 does not mean "everything".
    if top_n is not None:
        sorted_dict = sorted_dict[:top_n]

    logger.debug("{}: {}".format(target_brand_name, sorted_dict))

    return sorted_dict


In [None]:
# The five brands closest to Nestle in embedding space (Euclidean distance).
query("Nestle", top_n = 5)

[("Kellogg's", 0.6958630893555775),
 ('The Hershey Company', 0.7453138843880832),
 ('Amazon Inc', 0.8028214759360586),
 ('Zara (retailer)', 0.8373109468850889),
 ('Gap Inc.', 0.8650837031828886)]

In [None]:
def query_list(list_target_brand_name, top_n=None, kb_fpath=DEFAULT_BRAND_EMB_SAVE_FPATH, dict_kb=None):
    """Run ``query`` for several brands, loading the knowledge base only once.

    Args:
        list_target_brand_name: Iterable of brand names to look up.
        top_n: Forwarded to ``query`` for every brand.
        kb_fpath: JSON knowledge base to load when ``dict_kb`` is not supplied.
        dict_kb: Optional pre-loaded ``{brand_name: embedding}`` mapping,
            matching ``query``'s parameter of the same name.

    Returns:
        dict mapping each target brand to its ``query`` result.
    """
    if dict_kb is None:
        with codecs.open(kb_fpath, encoding='utf-8') as fp:
            dict_kb = json.load(fp)

    return {
        target_brand_name: query(target_brand_name, top_n=top_n, dict_kb=dict_kb)
        for target_brand_name in list_target_brand_name
    }