In [1]:
# coding=utf-8
import models
import itertools
from copy import deepcopy
from collections import namedtuple
import numpy as np

# 4 seems to be the general idea. Could be different. Need to find out
d = distortion_limit = 4

# Start working on a single sentence generalize later
sentence_orig = "honorables sénateurs , que se est - il passé ici , mardi dernier ?"
sentence = sentence_orig.split(' ')

# Get a translation for each set of words using tm
tm = models.TM("../data/tm", 1)
lm = models.LM("../data/lm")

sc_P = []

for i in range(0, len(sentence)):
    for j in range(i, len(sentence)):
        if i == j:
            subset = (sentence[i],)
        else:
            subset = tuple(sentence[i:j + 1])
        try:
            tm_output = tm[subset]
            sc_P.append((i, j, tm_output[0]))
        except KeyError:
            continue

Reading translation model from ../data/tm...
Reading language model from ../data/lm...


In [2]:
sc_P

[(0, 0, phrase(english='honourable', logprob=0.0)),
 (1, 1, phrase(english='senators', logprob=-0.124938733876)),
 (1, 2, phrase(english='senators ,', logprob=-0.176091253757)),
 (2, 2, phrase(english=',', logprob=-0.0358480773866)),
 (2, 3, phrase(english=', that', logprob=-0.195274576545)),
 (2, 4, phrase(english=', that if', logprob=0.0)),
 (2, 7, phrase(english=', what', logprob=0.0)),
 (2, 8, phrase(english=', what happened', logprob=0.0)),
 (3, 3, phrase(english='that', logprob=-0.154812589288)),
 (3, 4, phrase(english='that if', logprob=-0.412180453539)),
 (3, 5, phrase(english='what has', logprob=-0.301030009985)),
 (3, 7, phrase(english='what', logprob=0.0)),
 (3, 8, phrase(english='what happened', logprob=-0.176091253757)),
 (4, 4, phrase(english='is', logprob=-0.628903687)),
 (4, 5, phrase(english='has', logprob=-0.39545121789)),
 (4, 6, phrase(english='did', logprob=0.0)),
 (4, 7, phrase(english='come he did', logprob=-0.301030009985)),
 (4, 8, phrase(english='has happened'

# Let's start building helper functions

In [3]:
# define a numedtuple which will be the state vector q
State = namedtuple('State', "e1 e2 bitstring r alpha")

def ph(sc_P, q, d=4):
    ph_states = []
    for state in sc_P:
        flag = True  # we assume it as a valid state
        s = state[0]
        t = state[1]

        orig_bitstring = q.bitstring

        '''
        Step 1: Ensure bit string is not overlapped
        '''
        for _num in range(s, t+1):
            if orig_bitstring[_num] != 0:
                flag = False

        '''
        Step 2: Ensure distortion limit is still obeyed
        '''

        r = q.r

        if not (abs(r + 1 - s) <= d):
            flag = False

        if flag == True:
            ph_states.append(state)

    return ph_states


def next(q, p, eta=-1):
    s = p[0]
    t = p[1]

    num_translated_words = len(p[2].english.split(' '))

    if num_translated_words >= 2:
        last_word = p[2].english.split(' ')[-1]
        second_last = p[2].english.split(' ')[-2]
    elif num_translated_words == 1:
        last_word = p[2].english.split(' ')[-1]
        second_last = q.e2

    # Calculate new_bitstring

    new_bitstring = deepcopy(q.bitstring)

    for i in range(s, t+1):
        # just a safeguard..check whether earlier these values were zeros or not

        if q.bitstring[i] == 1:
            print('MAJOR ERROR!')
            print('The Ph(q) function has issues if you can read this...')
        new_bitstring[i] = 1

    '''
    Calculate new logprob (alpha value)
    '''

    # CALCULATE LANGUAGE MODEL PROBABILITY

    e1 = q.e1
    e2 = q.e2
    prob = 0
    if len(p[2].english.split(' ')) == 1: # just a single word
#         print('Single english word')
        # first try bigram probability
        word = p[2].english.split(' ')[0]
#         print('\n\tFinding probability of {}, {}, {}'.format(e1, e2, word))
        try:
    #         print('Finding probability of {}, {}, {}'.format(e1, e2, word))
            # the index 1 contains the logprob, 0 index contains the repeated state
            prob += lm.score((e1, e2), word)[1]
        except KeyError:
#             print("\t\tCouldn't find trigram prob..backing off to bigram")
            try:
#                 print('\t\t\tFinding probability of {}, {}'.format(e2, word))
                prob += lm.score((e2, ), word)[1]
            except KeyError:
                prob += lm.score((), word)[1]
    else: # there are multiple words
#         print('Multiple english words')
        for _num in range(0, len(p[2].english.split(' '))):
            word = p[2].english.split(' ')[_num]
#             print('\n\tFinding probability of {}, {}, {}'.format(e1, e2, word))
            try:
                prob += lm.score((e1, e2), word)[1]
            except KeyError:
#                 print("\t\tCouldn't find trigram prob..backing off to bigram")
                try:
#                     print('\t\t\tFinding probability of {}, {}'.format(e2, word))
                    prob += lm.score((e2, ), word)[1]
                except KeyError:
                    prob += lm.score((), word)[1]

            e1 = deepcopy(e2)
            e2 = deepcopy(word)

    # CALCULATE G(x) LOGPROB

    g_x = p[2].logprob

    # CALCULATE DISTORTION VALUE
    dist_val = eta * abs(q.r + 1 - s)

    # UPDATE ALPHA VALUE

    new_alpha = q.alpha + g_x + prob + dist_val

    new_state = State(e1=second_last, e2=last_word, bitstring=new_bitstring, r=t, alpha=new_alpha)
    return new_state


def eq(q1, q2):
    if q1.e1 != q2.e1:
        return False
    elif q1.e2 != q2.e2:
        return False
    elif q1.bitstring != q2.bitstring:
        return False
    elif q1.r != q2.r:
        return False
    else:
        return True


def add(Q_main, index, q_new, q, valid_phrase, back_pointer):
    for q_dd in Q_main[index]: # q double dash
        if eq(q_new, q_dd) == True:
            # print('Found similar q..')
            if q_new.alpha > q_dd.alpha: # score of new thing is greater than older
                # print('Changing the older q for the newer one..')
                Q_main[index].remove(q_dd)
                Q_main[index].append(q_new)
                back_pointer.append((q_new, q, valid_phrase))
                return
            else:
                return
    Q_main[index].append(q_new)
    back_pointer.append((q_new, q, valid_phrase))
    return


def beam(Q_main, index, beam_width=5):
    running_max = -10000

    # find the highest scoring state in the set
    for q in Q_main[index]:
        if q.alpha > running_max:
            running_max = q.alpha
            curr_max_state = q

    final_beam = running_max - beam_width

    final_list = []
    for q in Q_main[index]:
        if q.alpha >= final_beam:
            final_list.append(q)

    return final_list


# Run the code

In [4]:
import sys

In [37]:
bitstring = [0]*len(sentence)


q0 = State('<s>', '<s>', bitstring, 0, 0)
Q_main = {k: [] for k in range(len(sentence)+1)}
Q_main[0] = [q0]
back_pointer = []

for i in range(0, len(sentence)-1):
    sys.stdout.write('.')
    for q in beam(Q_main, i, beam_width=12):
        for valid_phrase in ph(sc_P, q, d=4):
            q_new = next(q, valid_phrase)
            index = len(np.nonzero(q_new.bitstring)[0])
            assert index > i
            add(Q_main, index, q_new, q, valid_phrase, back_pointer)

.............

In [38]:
def getBestPerformerState(Q_main):
    running_max = -10000
    best_state = None
    
    for state in Q_main[14]:
        if state.alpha > running_max:
            running_max = state.alpha
            best_state = state
            
    return best_state

In [39]:
Q_main[14]

[State(e1='last', e2='Tuesday', bitstring=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], r=12, alpha=-35.04663133071855)]

In [40]:
end_point = getBestPerformerState(Q_main)

In [41]:
for entry in back_pointer:
    print(entry[0])

State(e1='<s>', e2='honourable', bitstring=[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], r=0, alpha=-5.573022)
State(e1='<s>', e2='senators', bitstring=[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], r=1, alpha=-5.397384733876001)
State(e1='senators', e2=',', bitstring=[0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], r=2, alpha=-6.905608253757)
State(e1='<s>', e2=',', bitstring=[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], r=2, alpha=-4.2746270773866)
State(e1=',', e2='that', bitstring=[0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], r=3, alpha=-6.102653576545)
State(e1='that', e2='if', bitstring=[0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], r=4, alpha=-8.07728)
State(e1=',', e2='what', bitstring=[0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], r=7, alpha=-6.522851)
State(e1='what', e2='happened', bitstring=[0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0], r=8, alpha=-8.213728)
State(e1='<s>', e2='that', bitstring=[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], r=3, alpha=-3.838979589288)
State(e1='that', e2='if', bitst

State(e1='it', e2='past', bitstring=[0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0], r=8, alpha=-19.346814952787)
State(e1='it', e2='here', bitstring=[0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0], r=9, alpha=-18.770536843157)
State(e1='here', e2=',', bitstring=[0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0], r=10, alpha=-19.8354785949)
State(e1='it', e2=',', bitstring=[0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0], r=10, alpha=-18.2244420359156)
State(e1='last', e2='Tuesday', bitstring=[0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0], r=12, alpha=-23.155805758529002)
State(e1='it', e2='Tuesday', bitstring=[0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0], r=11, alpha=-22.78245762218647)
State(e1='last', e2='Tuesday', bitstring=[0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0], r=12, alpha=-23.586423512286004)
State(e1='it', e2='last', bitstring=[0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0], r=12, alpha=-22.59510650272)
State(e1=',', e2='honourable', bitstring=[1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], r=0, alpha=-23.7722335

State(e1='past', e2=',', bitstring=[0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0], r=10, alpha=-19.507552974004597)
State(e1='last', e2='Tuesday', bitstring=[0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0], r=12, alpha=-24.301185696618)
State(e1='past', e2='Tuesday', bitstring=[0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0], r=11, alpha=-21.54174526027547)
State(e1='last', e2='Tuesday', bitstring=[0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0], r=12, alpha=-24.294272050375)
State(e1='past', e2='last', bitstring=[0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0], r=12, alpha=-23.302955040808996)
State(e1='past', e2='?', bitstring=[0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1], r=13, alpha=-22.845744159332547)
State(e1='here', e2='it', bitstring=[0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0], r=7, alpha=-22.360964722353)
State(e1='here', e2='past', bitstring=[0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0], r=8, alpha=-23.126872281246)
State(e1='last', e2='Tuesday', bitstring=[0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0], r=12, alp

State(e1='here', e2='last', bitstring=[0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0], r=12, alpha=-20.6683022161906)
State(e1='here', e2='?', bitstring=[0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1], r=13, alpha=-21.20747823471415)
State(e1=',', e2='it', bitstring=[0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0], r=7, alpha=-21.321592800123202)
State(e1=',', e2='past', bitstring=[0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0], r=8, alpha=-22.5319812590162)
State(e1=',', e2='here', bitstring=[0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0], r=9, alpha=-20.7798404493862)
State(e1=',', e2='Tuesday', bitstring=[0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0], r=11, alpha=-19.39908992841567)
State(e1=',', e2='last', bitstring=[0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0], r=12, alpha=-19.4471718089492)
State(e1=',', e2='?', bitstring=[0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1], r=13, alpha=-20.37892672747275)
State(e1='that', e2='honourable', bitstring=[1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], r=0, alpha=-23.7830148666746)

State(e1='last', e2='here', bitstring=[0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0], r=9, alpha=-26.909133418957204)
State(e1='last', e2='Tuesday', bitstring=[0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0], r=11, alpha=-22.567234597986673)
State(e1='last', e2='?', bitstring=[0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1], r=13, alpha=-22.07582479704375)
State(e1='?', e2='Tuesday', bitstring=[0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1], r=11, alpha=-29.67172851651022)
State(e1='last', e2='Tuesday', bitstring=[0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1], r=12, alpha=-30.47569440660975)
State(e1='?', e2='last', bitstring=[0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1], r=12, alpha=-27.48437739704375)
State(e1='Tuesday', e2='here', bitstring=[0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0], r=9, alpha=-26.70633970921807)
State(e1='Tuesday', e2='last', bitstring=[0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0], r=12, alpha=-23.86530806878107)
State(e1='Tuesday', e2='?', bitstring=[0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1], 

State(e1='Tuesday', e2=',', bitstring=[1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], r=10, alpha=-27.687086349305073)
State(e1='Tuesday', e2='last', bitstring=[1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0], r=12, alpha=-28.614372216109473)
State(e1='Tuesday', e2='?', bitstring=[1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1], r=13, alpha=-28.622095234633022)
State(e1='last', e2=',', bitstring=[1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0], r=10, alpha=-28.4265417698386)
State(e1='last', e2='Tuesday', bitstring=[1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0], r=11, alpha=-27.654279056109473)
State(e1='last', e2='?', bitstring=[1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1], r=13, alpha=-26.43699445516655)
State(e1='?', e2=',', bitstring=[1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1], r=10, alpha=-32.64494474836215)
State(e1='last', e2='Tuesday', bitstring=[1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], r=12, alpha=-37.29327967097555)
State(e1='?', e2='Tuesday', bitstring=[1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1], r=11, 

In [42]:
def search(back_pointer, value, which_index):
    for entry in back_pointer:
        if value == entry[which_index]:
            return entry[2], entry[1]

In [43]:
# search through the back_pointer list of tuples, in tuple[0] for the end_point

end_point = getBestPerformerState(Q_main)

ptr = end_point
phrase = [1,1]
phrase_list_final = []
while(phrase[0] != 0 and phrase[1] != 0):
    phrase, ptr = search(back_pointer=back_pointer, value=ptr, which_index=0)
    print(phrase, ptr)
    print('Translation: {}\n'.format(phrase))
    phrase_list_final.append(phrase)

phrase_list_final = sorted(phrase_list_final, key=lambda tup: tup[0])

for phrase in phrase_list_final:
    print(phrase[2].english + ' '),
    

((11, 12, phrase(english='last Tuesday', logprob=-0.176091253757)), State(e1=',', e2='?', bitstring=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1], r=13, alpha=-24.466961076961546))
Translation: (11, 12, phrase(english='last Tuesday', logprob=-0.176091253757))

((13, 13, phrase(english='?', logprob=-0.000198262714548)), State(e1='here', e2=',', bitstring=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], r=10, alpha=-19.180860314247))
Translation: (13, 13, phrase(english='?', logprob=-0.000198262714548))

((9, 10, phrase(english='here ,', logprob=-0.163856804371)), State(e1='what', e2='happened', bitstring=[1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0], r=8, alpha=-15.681275073876002))
Translation: (9, 10, phrase(english='here ,', logprob=-0.163856804371))

((2, 8, phrase(english=', what happened', logprob=0.0)), State(e1='honourable', e2='senators', bitstring=[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], r=1, alpha=-10.228200433876001))
Translation: (2, 8, phrase(english=', what happened', logprob=0.0

In [44]:
lm.score(('happened', 'here'), ',')

(('here', ','), -1.074051036)

In [45]:
lm.score(('happened', 'here'), 'last')

(('here', 'last'), -2.753970036)