In [1]:
import logging

In [2]:
logging.basicConfig(level=logging.DEBUG, 
                    format='%(asctime)s - %(levelname)s - %(message)s', 
                    datefmt="%H:%M:%S")

In [24]:
def read_input(fpath="") -> str:
    import os
    if fpath=="" or fpath=="tinyshakespeare":
        fpath = os.path.join("data", "tinyshakespeare", "input.txt") # data/tinysharespeare/input.txt
    if os.path.exists(fpath):
        logging.info(f"Reading: {fpath}")
        with open(fpath) as f:
            text = f.read()

        return text
    
    else:
        logging.error(f"File not found: {fpath}")
        return None

In [25]:
text = read_input("tinyshakespeare")
total_len = len(text)
total_len

00:05:07 - INFO - Reading: data\tinyshakespeare\input.txt


1115394

### Train test split

In [5]:
def get_train_val_set(train_split=0.9):
    train_text = text[:int(total_len*train_split)]
    val_text = text[int(total_len*train_split):]
    return train_text, val_text

In [6]:
train_text, val_text = get_train_val_set()
len(train_text), len(val_text)

(1003854, 111540)

In [7]:
char_set = set(train_text)
vocab_size = len(char_set)
''.join(sorted(char_set)), vocab_size

("\n !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz", 65)

### Tokenizer

In [13]:
# Mapping from string to list of int 
stoi = {ch: i for i, ch in enumerate(char_set)}

# Reverse mapping from list of int to string
itos = {i: ch for i, ch in enumerate(char_set)}

# Tokenize 
encode = lambda s: [stoi[ch] for ch in s] # Handle unknown chars
decode = lambda li: ''.join([itos[i] for i in li])

In [17]:
# test encode
test_tk = encode("hi!\nhello world!")
print(test_tk)
# test decode
decode(test_tk)

[48, 62, 55, 17, 48, 18, 12, 12, 36, 63, 39, 36, 50, 12, 54, 55]


'hi!\nhello world!'

In [18]:
tokenized_train_text = encode(train_text)