In [16]:
import regex as re
import pickle

In [17]:
text = open("Hindi_Premchand_Story.txt", "r", encoding="utf-8").read()
tokens = text.encode("utf-8")
tokens = list(map(int, tokens))

In [18]:
from train_util import get_stats, merge, encode, decode

In [19]:
hindi_split_pattern = r""" ?\p{Devanagari}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

compiled_pattern = re.compile(hindi_split_pattern)
text_chunks = re.findall(compiled_pattern, text)
print(text_chunks[0: 100])

# input text preprocessing
ids = [list(ch.encode("utf-8")) for ch in text_chunks]

vocab_size = 5000

num_merges = vocab_size - 256
merges = {} # (int, int) -> int
vocab = {idx: bytes([idx]) for idx in range(256)} # idx -> bytes

for i in range(num_merges):
    # count the number of times every consecutive pair appears
    stats = {}
    for chunk_ids in ids:
        # passing in stats will update it in place, adding up counts
        get_stats(chunk_ids, stats)
    # find the pair with the highest count
    pair = max(stats, key=stats.get)
    # mint a new token: assign it the next available id
    idx = 256 + i
    # replace all occurrences of pair in ids with idx
    ids = [merge(chunk_ids, pair, idx) for chunk_ids in ids]
    # save the merge
    merges[pair] = idx
    vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
    # prints
    print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")

['मुंशी', ' \n\n', '\n', 'प्रेमचंद', ' साहित्य', ' \n', '\n', 'प्रेमचंद', ' की', ' ', '\n', 'सर्वश्रेष्ठ', ' कहानियां', ' \n\n', '\n', 'जुलूस', ' तथा', ' अन्य', ' कहानियां', ' \n\n', '\n', 'प्रेमचंद', ' की', ' प्रत्यीक', ' कहानी', ' मानव', ' -', ' मन', ' के', ' अनेक', ' दशयों', ',', ' चेतना', ' के', ' अनेक', ' छोरों', ',', ' सामाजिक', ' ', '\n', 'कुरीतियों', ' तथा', ' आर्थिक', ' उत्पीड़न', ' के', ' विविध', ' आयामों', ' को', ' अपनी', ' संपूर्ण', ' कलात्मकता', ' के', ' साथ', ' ', '\n', 'अनावृत्त', ' करती', ' है', ' ।', ' कफन', ' ,', ' नमक', ' का', ' दारोगा', ',', ' शतरंज', ' के', ' खिलाड़ी', ' ,', ' वासना', ' की', ' कड़ियाँ', ' ,', ' ', '\n', 'दुनिया', ' का', ' सबसे', ' अनमोल', ' रतन', ' आदि', ',', ' सैकड़ों', ' रचनाएँ', ' ऐसी', ' हैं', ' ,', ' जो', ' विचार', ' और', ' अनुभूति', ' दोनों', ' ', '\n', 'स्तरों', ' पर', ' पाठकों', ' को', ' आज', ' भी']
merge 1/4744: (224, 164) -> 256 (b'\xe0\xa4') had 159961 occurrences
merge 2/4744: (32, 256) -> 257 (b' \xe0\xa4') had 52793 occurrences
merge 

In [20]:
enc_text = encode(text, merges, compiled_pattern)

In [21]:
dec_text = decode(enc_text, vocab)
dec_text == text

False

In [27]:
print(text_chunks[1:2])
enc_text = encode(''.join(text_chunks[1:2]), merges, compiled_pattern)
print(enc_text)
print([decode([i], vocab) for i in enc_text])

[' \n\n']
[467]
[' \n\n']


In [28]:
print(', '.join([str(i) for i in enc_text]))
print(decode(enc_text, vocab))
print([(str(i),decode([i], vocab)) for i in enc_text])

467
 


[('467', ' \n\n')]


In [29]:

model = {"vocab": vocab,
         "merges": merges,
         "compiled_pattern": compiled_pattern
         }

In [30]:

with open('model.pkl', 'wb') as handle:
    pickle.dump(model, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [31]:
with open('model.pkl', 'rb') as handle:
    model1 = pickle.load(handle)

In [32]:
model1

{'vocab': {0: b'\x00',
  1: b'\x01',
  2: b'\x02',
  3: b'\x03',
  4: b'\x04',
  5: b'\x05',
  6: b'\x06',
  7: b'\x07',
  8: b'\x08',
  9: b'\t',
  10: b'\n',
  11: b'\x0b',
  12: b'\x0c',
  13: b'\r',
  14: b'\x0e',
  15: b'\x0f',
  16: b'\x10',
  17: b'\x11',
  18: b'\x12',
  19: b'\x13',
  20: b'\x14',
  21: b'\x15',
  22: b'\x16',
  23: b'\x17',
  24: b'\x18',
  25: b'\x19',
  26: b'\x1a',
  27: b'\x1b',
  28: b'\x1c',
  29: b'\x1d',
  30: b'\x1e',
  31: b'\x1f',
  32: b' ',
  33: b'!',
  34: b'"',
  35: b'#',
  36: b'$',
  37: b'%',
  38: b'&',
  39: b"'",
  40: b'(',
  41: b')',
  42: b'*',
  43: b'+',
  44: b',',
  45: b'-',
  46: b'.',
  47: b'/',
  48: b'0',
  49: b'1',
  50: b'2',
  51: b'3',
  52: b'4',
  53: b'5',
  54: b'6',
  55: b'7',
  56: b'8',
  57: b'9',
  58: b':',
  59: b';',
  60: b'<',
  61: b'=',
  62: b'>',
  63: b'?',
  64: b'@',
  65: b'A',
  66: b'B',
  67: b'C',
  68: b'D',
  69: b'E',
  70: b'F',
  71: b'G',
  72: b'H',
  73: b'I',
  74: b'J',
  75: b'K',

In [33]:

print("compression_ratio  = ", len(tokens)/len(ids))

compression_ratio  =  9.32457055707176
