# Assignment 1

Using text http://www.gutenberg.org/files/2600/2600-0.txt
1. Make text lowercase and remove all punctuation except spaces and dots.
2. Tokenize text by BPE with vocab_size = 100
3. Train 3-gram language model with laplace smoothing $\delta=1$
4. Using beam search with k=10 generate sequences of length=10 conditioned on provided inputs. Treat dots as terminal tokens.
5. Calculate perplexity of the language model for the first sentence.

In [1]:
text = open('peace.txt', 'r').read()[2:]
len(text)

3227579

In [2]:
from string import punctuation
import re

def preprocess_text(text):
    text = text.lower() # make lowercase
    text = text.replace("\n", " ")
    punct = punctuation.replace(".", "")
    for p in punct:
      text = text.replace(p, "")
    text = text.replace(". ", ",")
    text = text.replace(".", "")
    text = text.replace(",", ". ") # replace all punctuation except dots with spaces
    text = re.sub(r" +", " ", text) # collapse multiple spaces into one '   ' -> ' '
    return text

text = preprocess_text(text)

In [3]:
text = text.split('.')
text = [x.strip() for x in text]

In [4]:
from collections import Counter
from sklearn.base import TransformerMixin


class BPE(TransformerMixin):
    def __init__(self, vocab_size=100):
        super(BPE, self).__init__()
        self.vocab_size = vocab_size
        # index to token
        self.itos = []
        # token to index
        self.stoi = {}
        
    def fit(self, text):
        """
        fit itos and stoi
        text: list of strings 
        """
        symbols = set()

        for sent in text:
          for l in sent:
            if l not in symbols:
              symbols.add(l)
              i = len(self.itos)
              self.itos.append(l)
              self.stoi[l] = i
        
        text = [[self.stoi[l] for l in sent] for sent in text]

        while len(self.itos) < self.vocab_size:
            new_list = []
            for sent in text:
              for i in range(len(sent) - 2):
                new_list.append(sent[i:i+2])
            c = Counter(map(tuple, new_list))

            new_token = c.most_common(1)[0][0]
            new_id = len(self.itos)
            
            self.itos.append(new_token)
            
            for i in range(len(text)):
              sent = text[i]
              for j in range(len(sent) - 2):
                if sent[j:j+2] == list(new_token):
                  sent[j] = new_id
                  sent[j+1] = -1
              text[i] = [l for l in sent if l >= 0]

        return self
    
    def transform(self, text):
        """
        convert text to a sequence of token ids
        text: list of strings
        """
        text = [list(sent) for sent in text]

        for token_id, token in enumerate(self.itos):
            for i in range(len(text)):
              sent = text[i]
              if type(token) == str:
                for j in range(len(sent)):
                  if sent[j] == token:
                    sent[j] = token_id
              else:
                for j in range(len(sent) - 2):
                  if sent[j:j+2] == token:
                    sent[j] = new_id
                    sent[j+1] = -1
                text[i] = [l for l in sent if l >= 0]
        return text
    
    def decode_token(self, tok):
        """
        tok: int or tuple
        """
        if type(tok) == tuple:
          return self.decode_token(tok[0]) + self.decode_token(tok[1])
        else:
          if type(self.itos[tok]) == list:
            return(self.decode_token(tuple(tok)))
          else:
            return self.itos[tok]
        return result
            
    def decode(self, text):
        """
        convert token ids into text
        """
        return ''.join(map(self.decode_token, text))
        
        
vocab_size = 100
bpe = BPE(vocab_size)
tokenized_text = bpe.fit_transform(text)

In [5]:
assert bpe.decode(tokenized_text[0]) == text[0]

In [6]:
start_token = vocab_size
end_token = vocab_size + 1


class LM:
    def __init__(self, vocab_size, delta=1):
        self.delta = delta
        self.vocab_size = vocab_size + 2
        self.proba = {}
        
    def infer(self, a, b, tau=1):
        """
        return vector of probabilities of size self.vocab for 3-grams which start with (a,b) tokens
        a: first token id
        b: second token id
        tau: temperature
        """
        result = {tok: self.get_proba(a, b, tok, tau)
                  for tok in range(self.vocab_size) if tok != start_token}
        return result
        
    def get_proba(self, a, b, c, tau=1):
        """
        get probability of 3-gram (a,b,c)
        a: first token id
        b: second token id
        c: third token id
        tau: temperature
        """
        proba2 = pow(self.proba[a][b]["sum"],
                      1/tau) + (self.delta * self.vocab_size)
        proba3 = pow(self.proba[a][b][c], 1/tau) + self.delta
        result = proba3/proba2
        return result
    
    def fit(self, text):
        """
        train language model on text
        text: list of lists
        """
        ttext = []
        for sentence in text:
          ttext += [start_token] + sentence + [end_token]
        ttext = tuple(ttext)
        grams = Counter([ttext[i:i+3] for i in range(len(ttext) - 2)])
        for a in range(self.vocab_size):
          if a == end_token:
            continue
          self.proba[a] = {}
          for b in range(self.vocab_size):
            if b in (start_token, end_token):
              continue
            self.proba[a][b] = {}
            for c in range(self.vocab_size):
              if c == start_token:
                continue
              self.proba[a][b][c] = grams[(a, b, c)]
            self.proba[a][b]["sum"] = sum(self.proba[a][b].values())
        return self
    
lm = LM(vocab_size, 1).fit(tokenized_text)

In [7]:
def perplexity(snt, lm, tau=1):
    """
    snt: sequence of token ids
    lm: language model
    """
    result = 0
    for i in range(3, len(snt)+1):
      result += lm.get_proba(snt[i-3], snt[i-2], snt[i-1], tau)
    return result

perplexity(tokenized_text[0], lm)

38.02930569801232

In [8]:
def beam_search(input_seq, lm, max_len=10, k=5, tau=1):
    """
    generate sequence from language model *lm* conditioned on input_seq
    input_seq: sequence of token ids for conditioning
    lm: language model
    max_len: max generated sequence length
    k: size of beam
    tau: temperature
    """

    input_seq = [start_token] + input_seq
    max_len += 1
    
    candidates = [tuple(input_seq)] * k

    for i in range(len(input_seq)-1, max_len):
      beam = {}
      for cand in candidates[:k]:
        a = cand[-2]
        b = cand[-1]
        infer = lm.infer(a, b, tau)
        beam.update({cand + (c,): infer[c] for c in infer})
      topk = sorted(beam, key=beam.get, reverse=True)[:k]
      candidates = [cand for cand in topk] + candidates[k:]
      final_idx = [topk.index(c) for c in topk if c[-1] == end_token]
      for idx in sorted(final_idx, reverse=True):
        k -= 1
        candidates[idx], candidates[k] = candidates[k], candidates[idx]
    
    for i, cand in enumerate(candidates):
      if cand[-1] == end_token:
        candidates[i] = cand[:-1]
    
    beam = {cand[1:]: perplexity(list(cand[1:]), lm, tau) for cand in candidates}
    return beam

In [9]:
input1 = 'horse '
input1 = bpe.transform([input1])[0]
result = beam_search(input1, lm, max_len=10, k=10, tau=0.1)
for tok_seq in result:
  print(bpe.decode(list(tok_seq))+": "+str(result[tok_seq]))

horse of th: 0.3205284009362348
horse whis : 0.010779854215407572
horse whim : 7.417695060398774e-05
horse of to: 0.30457968530688223
horse ing t: 0.11830332061596541
horse con t: 0.0007609758095371546
horse and t: 0.07385003584715494
horse whime: 5.163174084631868e-06
horse whoun: 1.268587291925252e-05
horse whout: 1.2671659078252562e-05


In [10]:
input1 = 'her'
input1 = bpe.transform([input1])[0]
result = beam_search(input1, lm, max_len=10, k=10, tau=0.1)
for tok_seq in result:
  print(bpe.decode(list(tok_seq))+": "+str(result[tok_seq]))

herring to : 0.17375038397334014
herring the: 0.14577341628206936
hering the : 0.15736994688547143
herrin the : 0.039457885444548015
hervin the : 0.039428677879375505
hering to t: 0.17374980333103449
herrin to t: 0.055837741890111044
hervin to t: 0.05580853432493854
hering ther: 0.14577281864637437
her: 1.17716386592672e-08


In [11]:
input1 = 'what'
input1 = bpe.transform([input1])[0]
result = beam_search(input1, lm, max_len=10, k=10, tau=1)
for tok_seq in result:
  print(bpe.decode(list(tok_seq))+": "+str(result[tok_seq]))

whatch and : 3.7562284192722237
whation to : 3.4081621493779317
whation the: 3.732546539319406
whatch the : 3.930293395976595
whatáshound: 3.4083774460125533
whatáshere : 3.4707321807135987
whatáshount: 3.2663582694756106
whatáshad t: 3.568807885079943
whatch to t: 3.1418967264841404
whatáshat t: 3.70200399143037


In [12]:
input1 = 'gun '
input1 = bpe.transform([input1])[0]
result = beam_search(input1, lm, max_len=10, k=10, tau=0.1)
for tok_seq in result:
  print(bpe.decode(list(tok_seq))+": "+str(result[tok_seq]))

gun whound : 0.06811305022739671
gun ing to : 0.17408737336268865
gun ing the: 0.14611040567141784
gun con the: 0.028568060864170935
gun and the: 0.10165712542011845
gun of the : 0.34398347760340886
gun whimen : 1.2828408237037665e-05
gun whiment: 7.583012182698783e-06
gun of to t: 0.36036333404897186
gun whout t: 0.01930485432917487
