In [334]:
import matplotlib.pyplot as plt
import numpy as np
import collections
import random
from tqdm.auto import tqdm


%matplotlib inline
import warnings
warnings.filterwarnings("ignore")

Let's take a look at the data:

In [335]:
# Read the whole corpus into memory; the preview is printed once the file
# handle has been released.
with open('data/data_small.txt', 'r', encoding='utf-8') as file:
    text = file.read()
print(text[:1000])

TributespouredinfromaroundtheworldThursdaytothelateLabourPartyleaderJohnSmith,whodiedearlierfromamassiveheartattackaged55.InWashington,theUSStateDepartmentissuedastatementregretting"theuntimelydeath"oftherapier-tonguedScottishbarristerandparliamentarian."Mr.Smith,throughouthisdistinguishedcareeringovernmentandinopposition,leftaprofoundimpressiononthehistoryofhispartyandhiscountry,"StateDepartmentspokesmanMichaelMcCurrysaid."Secretary(ofStateWarren)ChristopherextendshisdeepestcondolencestoMrs.SmithandtotheSmithchildren."InBonn,theheadoftheGermanSocialDemocraticParty,RudolfScharping,saidinastatementhewas"veryaffectedbythesuddendeathofJohnSmith."AgoodfriendofGermansocialdemocracyhasleftustooearly.Hewasveryclosetoachievinghislife'sgoalofmakingtheLabourPartythelargestpoliticalforceinBritain"andwouldbe"cruellymissed"inEurope,hesaid.HongKongGovernorChrisPatten,aformerConservativePartychairman,offeredhiscondolencestotheSmithfamilyandsaidhisformerpolitcalopponentwasa"goodanddecentman,widelyresp

In [336]:
# Number of characters in the corpus; the bare expression below makes the
# notebook display the value as the cell output.
text_size = len(text)
text_size

652297

In [337]:
# Alphabet size: number of distinct characters occurring in the corpus.
# p0() reads this module-level value to build its uniform character model.
C = len(set(text))

In [338]:
# Report the alphabet size computed above.
print("Number of unique characters: {}".format(C))

Number of unique characters: 74


- There are $2^{n-1}$ possible segmentations of a text that is $n$ characters long.
- $n-1$ latent binary variables $s_i$, each denoting whether there is or isn't a separator between two adjacent characters.
- Collapsed Gibbs sampling: sample one variable conditioned on all the others.
- Exchangeability: if we reorder the words in the sequence, overall probability is the same.
- We can virtually move the changed words to the end of the sequence, compute the overall probability of the two possibilities, and then virtually move the words back.


If $s_i$ is 1, then there is a separator between characters $c_i$ and $c_{i+1}$.

In [339]:
# Fix the seeds of both random number generators used in this notebook
# (NumPy's global generator and the standard-library `random` module) so that
# repeated runs draw the same numbers.
np.random.seed(1234)
random.seed(1234)

In [340]:
import collections

def segmentation(text, text_size, s):
    """Split ``text`` into words according to the separator flags ``s``.

    Parameters
    ----------
    text : str
        The unsegmented text.
    text_size : int
        Number of characters of ``text`` to segment (expected to be
        ``len(text)``).
    s : sequence of int
        ``text_size - 1`` flags; ``s[k] == 1`` means that a word ends at
        ``text[k]``, i.e. there is a separator between ``text[k]`` and
        ``text[k + 1]``.

    Returns
    -------
    list of str
        The words in order. Their concatenation is ``text[:text_size]``; in
        particular the trailing word (the characters after the last
        separator) is included.
    """
    words = []
    word_start = 0
    # The last character has no flag of its own, hence ``text_size - 1``.
    for idx in range(text_size - 1):
        if s[idx] == 1:
            words.append(text[word_start:idx + 1])
            word_start = idx + 1

    # Whatever follows the last separator forms the final word; without this
    # step it would be lost.
    if word_start < text_size:
        words.append(text[word_start:text_size])
    return words

def get_word_counts(words):
    """Count how many times each word occurs in ``words``.

    Parameters
    ----------
    words : iterable of str
        The words to count.

    Returns
    -------
    collections.Counter
        Mapping from word to number of occurrences. Being a ``Counter``,
        looking up a word that never occurred yields 0 rather than raising
        ``KeyError``.
    """
    return collections.Counter(words)

In [341]:
def get_prev_word(text, s, i):
    """Return the piece of ``text`` that ends right before index ``i``.

    The piece is ``text[k + 1:i]``, where ``k`` is the closest index
    ``<= i - 2`` with ``s[k] == 1``; if no such index exists the piece starts
    at the beginning of ``text``. The flag ``s[i - 1]`` is read but never
    terminates the backward scan.
    """
    cursor = i - 1
    # A separator flag sitting directly at i - 1 is stepped over.
    if s[cursor] == 1:
        cursor -= 1
    # Walk left until a flagged position or the start of the text is reached.
    while cursor >= 0 and s[cursor] != 1:
        cursor -= 1
    return text[cursor + 1:i]

In [342]:
def get_next_word(text, s, i):
    """Return the piece of ``text`` that starts at index ``i``.

    The piece is ``text[i:end + 1]``, where ``end`` is the first index
    ``>= i + 1`` with ``s[end] == 1``, capped at the last index of ``text``.
    The flag ``s[i]`` itself is not inspected.
    """
    last_index = len(text) - 1
    stop = i + 1
    # Walk right until a flagged position or the last character is reached.
    while stop < last_index and s[stop] != 1:
        stop += 1
    return text[i:stop + 1]

In [343]:
def p0(word, p_c, alphabet_size=None):
    """Base (prior) probability of ``word``.

    Characters are drawn uniformly from the alphabet and the word length is
    geometric: after each character the word continues with probability
    ``p_c`` and stops with probability ``1 - p_c``.

    Parameters
    ----------
    word : str
        The candidate word.
    p_c : float
        Probability of appending another character.
    alphabet_size : int, optional
        Number of distinct characters. Defaults to the module-level ``C``.

    Returns
    -------
    float
        ``(1 / alphabet_size) ** len(word) * p_c ** (len(word) - 1) * (1 - p_c)``,
        or ``0.0`` for the empty word, which this length model cannot
        generate.
    """
    length = len(word)
    if length == 0:
        # Without this guard the formula evaluates p_c ** -1, which yields a
        # meaningless value (or a ZeroDivisionError when p_c == 0).
        return 0.0
    if alphabet_size is None:
        alphabet_size = C
    uniform = 1.0 / float(alphabet_size)
    return uniform**length * p_c**(length - 1) * (1 - p_c)

In [344]:
def increment(count_dict, word):
    """Add one to the count stored for ``word``; a missing word starts at 0."""
    count_dict[word] = count_dict.get(word, 0) + 1

In [345]:
def decrement(count_dict, word):
    """Subtract one from the count stored for ``word``.

    A word that is not present yet is treated as having count 1, so it ends
    up stored with count 0. Counts of existing words are not clamped and may
    become negative.
    """
    count_dict[word] = count_dict.get(word, 1) - 1

In [346]:
def _split_words(text, text_size, s):
    """Cut ``text[:text_size]`` after every index ``k`` with ``s[k] == 1``; the trailing word is included."""
    cuts = [0]
    cuts.extend(int(k) + 1 for k in np.flatnonzero(s))
    cuts.append(text_size)
    return [text[a:b] for a, b in zip(cuts, cuts[1:])]


def _neighbour_span(s, i, text_size):
    """Return ``(start, stop)`` so that ``text[start:i + 1]`` / ``text[i + 1:stop]`` are the words left / right of boundary ``i``, ignoring ``s[i]`` itself."""
    start = i
    while start > 0 and s[start - 1] != 1:
        start -= 1
    end = i + 1
    # s has text_size - 1 entries, so the last character carries no flag.
    while end < text_size - 1 and s[end] != 1:
        end += 1
    return start, end + 1


def _forget(count, word):
    """Remove one occurrence of ``word`` from ``count``, dropping the entry once it reaches zero."""
    remaining = count[word] - 1
    if remaining > 0:
        count[word] = remaining
    else:
        count.pop(word, None)


def CRP_TextSegmentation(text, text_size, iterations, alpha, p_c, p_cont,
                         output_path='data/output.txt'):
    """Segment ``text`` into words with a collapsed Gibbs sampler over a
    Chinese-restaurant-process word model and write the result to a file.

    The latent variable ``s[i]`` is 1 when there is a separator between
    ``text[i]`` and ``text[i + 1]`` and 0 otherwise. Every sweep visits all
    ``text_size - 1`` boundaries in random order. For each boundary the words
    that currently depend on it are removed from the counts, both hypotheses
    (one joined word vs. two separate words) are scored, ``s[i]`` is
    resampled, and the words of the chosen hypothesis are added back.

    Parameters
    ----------
    text : str
        The unsegmented text.
    text_size : int
        Number of characters; must equal ``len(text)``.
    iterations : int
        Number of sweeps over all boundaries.
    alpha : float
        Concentration parameter: weight of the base distribution ``p0``
        relative to the observed word counts.
    p_c : float
        Character-continuation probability forwarded to ``p0``.
    p_cont : float
        Extra factor applied to the two-word hypothesis.
    output_path : str, optional
        File that receives the segmented text, words separated by single
        spaces. Defaults to ``'data/output.txt'``.

    Returns
    -------
    None
        The segmentation is written to ``output_path``.
    """
    s = np.random.randint(low=0, high=2, size=text_size-1)
    count = get_word_counts(_split_words(text, text_size, s))
    t = sum(count.values())

    n_boundaries = text_size - 1
    opsa_progress_bar = tqdm(total=n_boundaries, desc="Text processing")

    for _ in tqdm(range(iterations), desc="Iterations"):
        opsa_progress_bar.reset()
        for i in np.random.permutation(n_boundaries):
            i = int(i)
            start, stop = _neighbour_span(s, i, text_size)
            prev_word = text[start:i + 1]
            next_word = text[i + 1:stop]
            joined = text[start:stop]

            # Take the words that depend on boundary i out of the model so
            # that the counts describe "everything else".
            if s[i] == 0:
                _forget(count, joined)
                t -= 1
            else:
                _forget(count, prev_word)
                _forget(count, next_word)
                t -= 2

            # Weight of "no separator": a single word `joined`.
            p_0 = (alpha * p0(joined, p_c) + count[joined]) / (alpha + t)
            # Weight of "separator": `prev_word` followed by `next_word`.
            # When the second word is scored the first one has already been
            # generated, hence t + 1 in the denominator and one extra
            # occurrence in the numerator if both words are identical.
            p_1 = (alpha * p0(prev_word, p_c) + count[prev_word]) / (alpha + t)
            p_1 *= (alpha * p0(next_word, p_c) + count[next_word]
                    + (1 if prev_word == next_word else 0)) / (alpha + t + 1)
            p_1 *= p_cont

            # Sample s[i]: 1 with probability p_1 / (p_0 + p_1), else 0.
            if random.random() * (p_0 + p_1) < p_1:
                s[i] = 1
                count[prev_word] += 1
                count[next_word] += 1
                t += 2
            else:
                s[i] = 0
                count[joined] += 1
                t += 1

            opsa_progress_bar.update(1)

    opsa_progress_bar.close()

    with open(output_path, 'w', encoding='utf-8') as file:
        file.write(" ".join(_split_words(text, text_size, s)))

In [349]:
# Hyper-parameters handed to CRP_TextSegmentation below.
# iterations: number of passes of the sampler's outer loop.
iterations = 100
# alpha: multiplies the base probability p0(word) before it is added to the
# word's count.
alpha = 100
# p_c: forwarded to p0, which uses p_c**(len(word) - 1) * (1 - p_c) as the
# word-length factor.
p_c = 0.5
# p_cont: factor multiplied into the score of the two-word hypothesis.
p_cont = 0.99

In [350]:
# Run the sampler on the full corpus; the segmented text is written to
# data/output.txt.
CRP_TextSegmentation(text, text_size, iterations, alpha, p_c, p_cont)

Text processing:   0%|          | 0/652295 [00:00<?, ?it/s]

Iterations:   0%|          | 0/100 [00:00<?, ?it/s]