In [75]:
import nltk
import re
from nltk import bigrams, trigrams
from collections import Counter, defaultdict

# Cleanup

We use all 40k queries from the Million Query Track. First, we clean up the queries by replacing punctuation and other special characters that mg4j cannot parse with spaces.

In [69]:
# Every ASCII punctuation character (the 32 characters of string.punctuation). Each one is
# replaced by a space before tokenization so no special character reaches mg4j.
QUERY_SPECIAL_CHARS = re.compile(r'[~`!@#$%^&*()_\-+=\'";:,.<>/?\[\]{}|\\]')

def cleanup(query, sep=' '):
    """Normalize a raw query into a flat sequence of tokens.

    The query is stripped of surrounding whitespace, and every ASCII punctuation character
    is replaced by a space. The result is tokenized with ``nltk.word_tokenize``.

    Args:
        query: raw query text (may include a trailing newline).
        sep: separator used to join the tokens, e.g. ``' '`` or ``' OR '``.

    Returns:
        The tokens joined by ``sep``; an empty string when no tokens remain.
    """
    return sep.join(nltk.word_tokenize(QUERY_SPECIAL_CHARS.sub(' ', query.strip())))

# Produce two cleaned copies of the raw query log, each line-aligned with the input:
# tokens separated by single spaces, and the same tokens joined with ' OR '.
with open('/data/queries/train/mq.all', 'r', encoding='ISO-8859-1') as raw_queries, \
        open('/data/queries/train/mq.all.clean', 'w', encoding='UTF-8') as clean_out, \
        open('/data/queries/train/mq.all.clean.or', 'w', encoding='UTF-8') as clean_or_out:
    for raw_query in raw_queries:
        clean_out.write(cleanup(raw_query) + '\n')
        clean_or_out.write(cleanup(raw_query, sep=' OR ') + '\n')

# Learn Model

In [81]:
# Load the cleaned queries as lists of tokens (str.split also discards the trailing newline).
# Blank lines come from queries that cleaned up to nothing, and they are skipped. Once padded,
# an empty query adds a start -> end transition to the n-gram models below, which makes the
# generators emit empty queries.
with open('/data/queries/train/mq.all.clean', 'r', encoding='UTF-8') as inp:
    queries = [tokens for tokens in (line.split() for line in inp) if tokens]

## Trigram

In [93]:
# Trigram model: context (w1, w2) -> {next word w3: P(w3 | w1, w2)}.
# Queries are padded on both sides with nltk's pad symbol. The default is None, which is
# what generate3 relies on: (None, None) is the start context, and a None successor
# means the query is over.
trigram_model = defaultdict(lambda: defaultdict(int))

# First pass: raw counts of each (context, next word) pair.
for tokens in queries:
    for first, second, third in trigrams(tokens, pad_right=True, pad_left=True):
        trigram_model[first, second][third] += 1

# Second pass: divide each context's counts by its total to get conditional probabilities.
for context, successors in trigram_model.items():
    context_total = float(sum(successors.values()))
    for word in successors:
        successors[word] /= context_total

## Bigram

In [94]:
# Bigram model: previous word -> {next word: P(next | previous)}.
# Padding (nltk's default pad symbol, None) makes None the start state. A None successor
# marks the end of a query, which is how generate2 starts and stops.
bigram_model = defaultdict(lambda: defaultdict(int))

# First pass: raw counts of each (previous, next) pair.
for tokens in queries:
    for previous, current in bigrams(tokens, pad_right=True, pad_left=True):
        bigram_model[previous][current] += 1

# Second pass: normalize each row of counts into a conditional distribution.
for previous, successors in bigram_model.items():
    previous_total = float(sum(successors.values()))
    for word in successors:
        successors[word] /= previous_total

## Unigram

In [98]:
# Unigram model: word -> relative frequency over all query tokens (no padding is added,
# so None never appears as a key).
unigram_model = defaultdict(int)

for tokens in queries:
    for token in tokens:
        unigram_model[token] += 1

token_total = float(sum(unigram_model.values()))
for token in unigram_model:
    unigram_model[token] /= token_total

# Generate

In [None]:
import random

def _sample_next(distribution):
    """Draw one outcome from ``distribution``, a mapping of outcome -> probability.

    Walks the outcomes in iteration order, adding up their probabilities until the running
    total reaches a uniform draw ``r = random.random()``.

    Raises:
        ValueError: if ``distribution`` is empty (e.g. a context never seen in training).
    """
    if not distribution:
        raise ValueError('cannot sample from an empty distribution')
    r = random.random()
    accumulator = 0.0
    for outcome, probability in distribution.items():
        accumulator += probability
        if accumulator >= r:
            return outcome
    # Floating-point rounding can leave the summed probabilities a hair below r. Give that
    # residual mass to the last outcome instead of silently drawing nothing.
    return outcome


def generate3(model):
    """Sample one query from a trigram model.

    Args:
        model: mapping ``(w1, w2) -> {w3: P(w3 | w1, w2)}`` in which None is the pad
            symbol: ``(None, None)`` is the start context and two trailing Nones end the query.

    Returns:
        The sampled words joined by single spaces, with padding removed
        (an empty string if the model ends the query immediately).
    """
    text = [None, None]
    while True:
        # Always append a word, so the stop check below never fires on the initial padding.
        text.append(_sample_next(model[tuple(text[-2:])]))
        if text[-2:] == [None, None]:
            break
    return ' '.join([t for t in text if t])


def generate2(model):
    """Sample one query from a bigram model.

    Args:
        model: mapping ``previous -> {next: P(next | previous)}`` in which None is both
            the start state and the end-of-query symbol.

    Returns:
        The sampled words joined by single spaces, with padding removed
        (an empty string if the model ends the query immediately).
    """
    text = [None]
    while True:
        text.append(_sample_next(model[text[-1]]))
        if text[-1] is None:
            break
    return ' '.join([t for t in text if t])

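# NOTE(review): unused draft. As written the loop never terminates: unigram_model only holds
# real words (never None), so `text[-1:] == [None]` is never true. `prob` is also never updated.
# def generate1(model):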
#     text = []
#     prob = 1.0  # <- Init probability
#     sentence_finished = False
#     while not sentence_finished:
#         r = random.random()
#         accumulator = .0

#         for word in model.keys():
#             accumulator += model[word]
#             if accumulator >= r:
#                 text.append(word)
#                 break

#         if text[-1:] == [None]:
#             sentence_finished = True
#     return ' '.join([t for t in text if t]), prob

# Sample 500,000 synthetic queries from the bigram model, one query per line.
n_generated = 500000
with open('/data/queries/train/mq.gen.bigram', 'w', encoding='UTF-8') as out:
    for _ in range(n_generated):
        out.write(generate2(bigram_model) + '\n')

In [None]:
# Sample 500,000 synthetic queries from the trigram model, one query per line.
n_generated = 500000
with open('/data/queries/train/mq.gen.trigram', 'w', encoding='UTF-8') as out:
    for _ in range(n_generated):
        out.write(generate3(trigram_model) + '\n')

# OR Queries

Rewrite each generated query as a disjunction of its terms (tokens joined by ` OR `).

In [173]:
# Rewrite each generated bigram query as an OR query, one output line per input line.
# mq.gen.bigram is written with UTF-8 by the generation cell, so it must be read back as
# UTF-8. Decoding it as ISO-8859-1 would turn every non-ASCII character into mojibake.
with open('/data/queries/train/mq.gen.bigram', 'r', encoding='UTF-8') as inp,\
     open('/data/queries/train/mq.gen.bigram.or', 'w', encoding='UTF-8') as out:
    for line in inp:
        print(cleanup(line, sep=' OR '), file=out)

In [174]:
# Rewrite each generated trigram query as an OR query, one output line per input line.
# mq.gen.trigram is written with UTF-8 by the generation cell, so it must be read back as
# UTF-8. Decoding it as ISO-8859-1 would turn every non-ASCII character into mojibake.
with open('/data/queries/train/mq.gen.trigram', 'r', encoding='UTF-8') as inp,\
     open('/data/queries/train/mq.gen.trigram.or', 'w', encoding='UTF-8') as out:
    for line in inp:
        print(cleanup(line, sep=' OR '), file=out)