In [1]:
import numpy as np
import cPickle
from collections import OrderedDict
import collections
from nltk.metrics import *
import operator

In [26]:
class PSTInfer(object):
    def __init__(self):
        self.query_to_id = {}
        self.id_to_query = []

    def _load_pickle(self, input_handle):
        self.tuple_dict = cPickle.load(input_handle)
        self.query_to_id = cPickle.load(input_handle)

    def load(self, input_path):
        print('Loading inference engine')

        input_handle = open(input_path, 'r')
        self._load_pickle(input_handle)
        input_handle.close()
        print('Preparing internal structures')

        # Transform the dict of tuples to a dict of dicts
        self.search_dict = collections.defaultdict(dict)
        for key, freq in self.tuple_dict.items():
            self.search_dict[key[:-1]][key[-1]] = freq
        
        self.tuple_dict.clear()
        self.id_to_query = [query_str for (query_str, query_id) in \
                            sorted(self.query_to_id.items(), key=operator.itemgetter(1))]

        print('Loaded inference engine')

    def _find(self, suffix, exact_match=False):
        _suffix = [self.query_to_id.get(x, -1) for x in suffix]

        # Back off to shorter suffixes,
        for i in range(len(_suffix)):
            key = tuple(_suffix[i:])
            if key in self.search_dict:
                return {'last_node': key, \
                        'is_found': i==0 and len(_suffix)==len(suffix), \
                        'empty': False, \
                        'probs': self.search_dict[key]}
        # and if nothing is found
        return {'last_node': (0,), \
                'is_found': False, \
                'empty': True, \
                'probs' : {}}

    def rerank(self, suffix, candidates, exact_match=False, no_normalize=False, fallback=False):
        probs = [ self._find(suffix) ]
        print ("start - probs ", probs)
        any_found = probs[0]['empty']
        found = probs[0]['is_found']

        #for i in range(len(suffix)):
        #    probs.append(self._find(suffix[i:]))
        # any_found = sum([p['empty'] for p in probs])
        # Fallback to prefix matches
        if any_found and fallback:
            probs = []
            last_suffix = suffix[-1].split()
            while len(last_suffix) > 1:
                last_suffix = last_suffix[:-1]
                p = self._find(suffix[:-1] + [' '.join(last_suffix)])
                if not p['empty']:
                    probs = [ p ]
            if len(probs) == 0:
                print '!!!! Warning: should this be found instead ? ', suffix
        # If we don't find anything matching the suffix
        # we just return the original candidates
        if exact_match and not found:
            return [(candidate, 0) for candidate in candidates]
        ids_candidates = map(lambda x : self.query_to_id.get(x, -1), \
                             candidates)
        candidates_found = []
        candidates_not_found = []
        n_total_queries = len(self.id_to_query)

        for (id_candidate, candidate) in zip(ids_candidates, \
                                             candidates):
            # smoothed probability
            candidate_prob = 0
            for prob in probs:
                if id_candidate in prob['probs']:
                    # smooth and renormalize.
                    if no_normalize:
                        candidate_prob = prob['probs'][id_candidate]
                    else:
                        n_remaining_queries = (n_total_queries - len(prob['probs']))
                        assert n_remaining_queries >= 0
                        freq = prob['probs'][id_candidate]
                        total_freq = sum(prob['probs'].values())
                        candidate_prob = float(freq)/total_freq
                        candidate_prob = candidate_prob/(candidate_prob
                                        + float(n_remaining_queries)/len(self.id_to_query))
                        candidate_prob = -np.log(candidate_prob)
                    break

            if candidate_prob == 0 and not no_normalize:
                candidate_prob = -np.log(1.0/n_total_queries)
            candidates_found.append((candidate, candidate_prob))
        return zip(*candidates_found)

    def suggest(self, suffix, N=100, exact_match=False):
        result = self._find(suffix)

        node = result['last_node']
        probs = result['probs']

        data = {'last_node_id' : node[0],
                'last_node_query': self.id_to_query[node[0]],
                'found' :   result['is_found'],
                'suggestions' : [],
                'scores' : []}
        if node[0] == 0 or (exact_match and not data['found']):
            return data
        # Get top N
        id_sugg_probs = sorted(probs.items(), key=operator.itemgetter(1), reverse=True)[:N]
        string_sugg_probs = [(self.id_to_query[sugg_id], sugg_score) for sugg_id, sugg_score in id_sugg_probs]
        sugg, score = map(list, zip(*string_sugg_probs))
        data['suggestions'] = sugg
        data['scores'] = score

        return data


class PST(object):
    def __init__(self, D=4, q_dict=None):
        if q_dict is None:
            self.query_dict = {'_root_' : 0}
        else:
            print("Parsing query dictionary!")
            self.query_dict = q_dict
        self.norm_dict = {}
        self.tuple_dict = {}
        self.normalized = False
        self.num_nodes = 1
        self.size = 0
        self.D = D

    def prune(self, epsilon=0.05):
        # Transform the dict of tuples to
        # a proper dict of dicts
        print('Started pruning with epsilon {}'.format(epsilon))
        search_dict = collections.defaultdict(lambda: {})
        for key, prob in self.tuple_dict.items():
            if not self.normalized:
                prob = float(prob) / self.norm_dict[key[:-1]]
            search_dict[key[:-1]][key[-1]] = prob
        self.tuple_dict.clear()
        logger.info('Checking constistency')
        for key, prob in search_dict.items():
            assert np.abs(sum(prob.values()) - 1.0) < 1e-5
        self.normalized = True

        smoothing = 1.0/len(self.query_dict)
        logger.info('{} nodes / {} smoothing'.format(len(search_dict), smoothing))

        for num, (child_key, child_probs) in enumerate(search_dict.items()):
            if num % 100000 == 0:
                logger.info('{} nodes explored'.format(num))
            # The parent of 1-length contexts is the root
            # thus we do not need to check here.
            if len(child_key) == 1:
                continue
            parent_key = child_key[1:]
            parent_probs = search_dict[parent_key]
            kl = kl_divergence(parent_probs, child_probs, smoothing)
            if kl <= epsilon:
                search_dict[child_key] = {}
        # Re-convert to tuple
        for num, (key, probs) in enumerate(search_dict.items()):
            for qid, qpr in probs.items():
                join_key = key + tuple([qid])
                assert len(join_key) >= 2
                assert join_key not in self.tuple_dict
                self.tuple_dict[join_key] = qpr
        # logger.info('{} nodes - pruning done'.format(len(self.tuple_dict)))
        self.num_nodes = len(self.tuple_dict)

    def save(self, output_path, no_normalize=False):
        print('Saving PST to {} / {} nodes.'.format(output_path, len(self.tuple_dict)))
        # Save the normalized format
        # if not self.normalized and not no_normalize:
        #    logger.info('Normalizing PST')
        #    for key, count in self.tuple_dict.iteritems():
        #        self.tuple_dict[key] = float(count) # /self.norm_dict.get(key[:-1])
        # self.norm_dict.clear()

        f = open(output_path, 'w')
        cPickle.dump(self.tuple_dict, f)
        cPickle.dump(self.query_dict, f)
        f.close()

    def add_session(self, session):
        def _update_prob(entry):
            key = entry[:-1]
            self.tuple_dict[entry] = self.tuple_dict.get(entry, 0) + 1
            self.norm_dict[key] = self.norm_dict.get(key, 0) + 1

        len_session = len(session)
        if len_session < 2:
            return
        for query in session:
            if query not in self.query_dict:
                self.query_dict[query] = len(self.query_dict)
        session = [self.query_dict[query] for query in session]
        # print("add session - session ", session, str(len_session))
        for x in range(len_session - 1):
            tgt_indx = len_session - x - 1
            for c in range(self.D):
                ctx_indx = tgt_indx - c - 1
                if ctx_indx < 0:
                    break

                entry = tuple(session[ctx_indx:tgt_indx + 1])
                # print("tuple entry ", entry)
                _update_prob(entry)

                self.num_nodes = len(self.tuple_dict)

In [2]:

# load the data molde
input_handle = open('data/bg_session.ctx_ADJ.mdl', 'r')

# load the tuple dict and the query dict
tuple_dict = cPickle.load(input_handle)
query_to_id = cPickle.load(input_handle)

In [3]:
def save_pickle_dict(a_dict, output_path):
    # Save the query to ID dictionary because we need it for
    # VMM feature construction
    # /home/jogi/git/repository/ir2_jorg/data/query_dict.pkl
    f = open(output_path, 'wb')
    cPickle.dump(a_dict, f)
# save_pickle_dict(query_to_id, '/home/jogi/git/repository/ir2_jorg/data/query_dict.pkl')
# d = cPickle.load(open('/home/jogi/git/repository/ir2_jorg/data/query_dict.pkl', 'rb'))

In [11]:
# make a inverted version of the query to id dict
id_to_query =  {v: k for k, v in query_to_id.iteritems()}

In [12]:
"""
When using enumerate you can only use this ones for the data set, you need to reload the data
before you can use emenumerate again
"""
def open_data():
    val_sessions = open('data/val_session.ctx', 'r')
    train_session = open('data/tr_session.ctx', 'r')
    bg_session = open('data/bg_session.ctx', 'r')
    
    return train_session, val_sessions, bg_session

In [13]:
# use the keys (tuples with two query id's) of the tuple dict to make a new dict 
tuple_pairs = tuple_dict.keys()

In [14]:
search_dict = collections.defaultdict(dict)
"""
make a new dict with key anchor query, as value we have a new dict with keys previous query and 
their value count 

dict[anchor_query] = { previous_query: count_value}

"""

for _tuple in tuple_pairs:
    search_dict[_tuple[1]][_tuple[0]] = tuple_dict[_tuple] 

In [15]:
"""
Func to print the suggested query id's as strings using the id_to_query map
"""
def print_suggestion(suggestions):
    for suggest in suggestions:
        print id_to_query[suggest[0]]

In [17]:
"""
Function that makes suggestions for a session

Input: session file, *.ctx
Output: dict with key:session_idx value: (target_query,anchor_query, session, suggestions)

"""

def make_suggestions(session_file, recent_queries=1,num_suggestions=20):
    # make a dict to save all the results
    suggestion_dict = {}
    c = 1
    # loop over every session in the *.ctx file
    for idx, line in enumerate(session_file):
        # queries are tab-separated 
        session = line.strip().split('\t')
        
        if len(session) >= recent_queries+1:
            target_query = session[-1] # target query is the last query Qm
            anchor_query = session[-2] # Anchor query is the query Qm-1
            context = session[:-1] # Qm-1 till Q1 are the context queries
            
            # find anchor in the background set
            if anchor_query in query_to_id:
                key =  query_to_id[anchor_query] # the key of the query in the bg-set 
                # check if target query and anchor query are in the background set
                if key in search_dict and target_query in query_to_id:
                    """
                    We could use the search dict to find all the queries that follow the anchor query 
                    in the bg set, we use this queries as suggestions
                    """
                    suggestions = search_dict[key]
                    if len(suggestions) > num_suggestions: # we need at least 20 suggestions 
                        target_key = query_to_id[target_query] # find the key of the target query
                        list_suggestions = [(key, suggestions[key] )for key in suggestions.keys()]
                        # sort list of tuples by second tuple entry which is the frequency count
                        # also reverse order so it is in descending order
                        sorted_suggestions = sorted(list_suggestions, key=lambda x: x[1])[::-1]
                        #take only the top 20 suggestions based on counts 
                        suggestions = sorted_suggestions[0:num_suggestions]
                        # final check, is the target query really in the set of suggestions? 
                        if target_key in (x[0] for x in suggestions): 
                            # we have a valid session, now we list all the suggestions and sort them
                            # save this in the dict key(idx):(target_query,anchor_query, session, suggestions)
                            suggestion_dict[idx] = (target_query,anchor_query, session, suggestions)
    return suggestion_dict

In [18]:
train_session, val_sessions, bg_session = open_data() # reload the data

In [19]:
# dicts with results
suggestion_train = make_suggestions(train_session, recent_queries=5)
save_pickle_dict(suggestion_train, '/home/jogi/git/repository/ir2_jorg/baselines/tests/tr_suggest.pkl')
print len(suggestion_train)
# suggestion_val = make_suggestions(val_sessions)

24332


In [73]:
"""
Input: session file with string queries
Output: dict with the query frequencies 
"""
def make_query_frequncies(session_file):
    query_freq = {}
    total_freq = 0
    for num, session in enumerate(session_file):
        session = session.strip().split('\t')
        for query in session:
            query_freq[query] = query_freq.get(query, 0.) + 1.
            total_freq += 1
    return query_freq

query_freq = make_query_frequncies(bg_session)

In [3]:
# save the suggestions for the training set, so we can start from here later on
suggestion_train = cPickle.load(open('../../baselines/tests/tr_suggest.pkl', 'rb'))

# load the VMM model made with Allesandro's Probabilistic Suffix Tree (PST)
# currently the context scope is limited to D=2 which means the tuple dict contains
# tuples with max lenght of 3 (so the memory span is look 2 queries ahead)
pstree = PSTInfer()
pstree.load('../../baselines/tests/bg_session.ctx_VMM.mdl')
print("======== READY ===========")

Loading inference engine
Preparing internal structures
Loaded inference engine


In [90]:

# print(len(pstree.search_dict))
c = 0
for session_key in suggestion_train.keys():
        # tuple 
        session_tuple = suggestion_train[session_key]
        target_query = session_tuple[0]
        session = session_tuple[2]
        context_queries = session_tuple[2][:-1]
        anchor_query = session_tuple[1]
        suggestions = session_tuple[3]
        print "target query ", target_query
        print "anchor_query ", anchor_query
        # print "session ", session
        candidates = []
        for idx, suggestion in enumerate(suggestions):
            candidates = []
            suggestion_id = suggestion[0]
            candidates.append(pstree.id_to_query[suggestion_id])
            candidates_new, scores = pstree.rerank(session, candidates, no_normalize=True, fallback=False)
            print "candidates ", candidates, scores[0]
        # probs = [ pstree._find(session) ]
        # print("probs ", probs)
        # print probs[0]['empty']
        # print probs[0]['is_found']
        # candidates_new, scores = pstree.rerank(session, candidates, no_normalize=True, fallback=False)
        # print "candidates ", candidates, scores
        if c > 10:
            break
        c += 1
           

target query  jesse mccartney
anchor_query  jesse mccartney
candidates  ['jesse mccartney'] 543
candidates  ['jesse mcartney'] 0
candidates  ['gefilte fish'] 0
candidates  ['myspace chatrooms'] 1
candidates  ['jensen ackles'] 1
candidates  ['kelly clarkson'] 0
candidates  ['jesse macartney'] 0
candidates  ['yahoo com'] 0
candidates  ['jessie mccartney'] 0
candidates  ['ryan phill'] 0
candidates  ['rieker shoes'] 0
candidates  ['zac efron'] 3
candidates  ['summerland'] 0
candidates  ['jessy mcartny'] 0
candidates  ['jesse mccartney com'] 1
candidates  ['jesse mccarney'] 0
candidates  ['jesse mccartnry'] 0
candidates  ['teddy geiger'] 0
candidates  ['taylor ball'] 0
candidates  ['nycotb com'] 0
target query  jesse mccartney
anchor_query  jesse mccartney
candidates  ['jesse mccartney'] 543
candidates  ['jesse mcartney'] 0
candidates  ['gefilte fish'] 0
candidates  ['myspace chatrooms'] 1
candidates  ['jensen ackles'] 1
candidates  ['kelly clarkson'] 0
candidates  ['jesse macartney'] 0
can

In [107]:
def count_letter_ngram(sentence, n=3):
    """
    How many n-grams fits in this sentenec 
    """
    if len(sentence) < n:
        return set(sentence)
    local_counts = set()
    for k in range(len(sentence.strip()) - n + 1): 
        local_counts.add(sentence[k:k+n])
    return local_counts

def matches(ng1, ng2):
    """
    For both n-gram sets how many sim elements they contain
    """
    return len(ng1 & ng2)

def n_gram_sim(query1, query2,n=3):
    """
    return n-gram similarity between two queries 
    """
    return matches(count_letter_ngram(query1, n), count_letter_ngram(query2, n))

def make_n_gram_sim_features(context_queries,suggestion):
    """
    For every suggestion make the n-gram similarity for the context queries (at most 10)
    """
    n_sim = [0] * 10
    for idx, context_query in enumerate(context_queries):
        if idx >=10:
            """
            only do this for at most 10 context queries 
            """
            break
        n_sim[idx] = n_gram_sim(suggestion, context_query,n=3)
    
    return n_sim


def get_VMM_score(session, suggestion, no_normalize=True, fallback=False):
    """
    For every suggestion determine the VMM score (variable memory Markov score)
    """
    
    _, scores = pstree.rerank(session, suggestion, no_normalize=no_normalize, fallback=fallback)
    
    return scores[0]

In [108]:
"""
Function that returens a feature vector for every suggestion 

Input: suggestion_dict
Output: per session a matrix [17,20] with the feature vectors 
"""

def make_suggestion_features(suggestion_dict, num_features=17):
    global pstree
    
    c = 0
    for session_key in suggestion_dict.keys():
        # tuple 
        session_tuple = suggestion_dict[session_key]
        target_query = session_tuple[0]
        context_queries = session_tuple[2][:-1]
        anchor_query = session_tuple[1]
        suggestions = session_tuple[3]
        VMM_scores = []
        candidates = []
        for idx, suggestion in enumerate(suggestions):
            suggestion_id = suggestion[0]
            query_string = pstree.id_to_query[suggestion_id]
            """"
            For each candidate suggestion, we count how many times it follows 
            the anchor query in the background data and add this count as a feature.
            """
            follow_anchor_count = suggestion[1]

            """
            Additionally, we use the frequency of the anchor query in the background data.
            """
            # bg_freq = query_freq[query_string]

            """
            We also add the Levenshtein distance between the anchor and the suggestion.
            """
            levenshtein_distance = edit_distance(anchor_query, query_string)

            """
            The suggestion length (characters and words)
            """
            chars_leng = len(query_string) 
            word_leng = len(query_string.split())
            
            
            """
            We add 10 features corresponding to the character n-gram similarity 
            between the suggestion and the 10 most recent queries in the context.
            """
            n_gram_sim =  make_n_gram_sim_features(context_queries, suggestion)
            
            candidates.append(query_string)
            VMM_scores.append(get_VMM_score(context_queries, [query_string]))
            
            """
            HRED Score
            """
            hred_score = None 
            
        if anchor_query == "jesse mccartney":
            print("session ", context_queries)
            print("anchor_query ", anchor_query)
            print("candidates ", candidates)
            print("VMM scores ", VMM_scores)
            break
        
    
            
        

In [109]:
make_suggestion_features(suggestion_train, num_features=17)

('session ', ['new mexico', 'jessemccartney', 'jesse mccartney', 'jesse mccartney', 'jesse mccartney', 'jesse mccartney', 'jesse mccartney', 'jesse mccartney', 'jesse mccartney', 'jesse mccartney', 'jesse mccartney', 'jesse mccartney', 'jesse mccartney', 'jesse mccartney'])
('anchor_query ', 'jesse mccartney')
('candidates ', ['jesse mccartney', 'jesse mcartney', 'gefilte fish', 'myspace chatrooms', 'jensen ackles', 'kelly clarkson', 'jesse macartney', 'yahoo com', 'jessie mccartney', 'ryan phill', 'rieker shoes', 'zac efron', 'summerland', 'jessy mcartny', 'jesse mccartney com', 'jesse mccarney', 'jesse mccartnry', 'teddy geiger', 'taylor ball', 'nycotb com'])
('VMM scores ', [543, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 3, 0, 0, 1, 0, 0, 0, 0, 0])
