# Implementation of a character-based LSTM in PyTorch, trained on entries of an English dictionary

TODO:
- Documentation

In [None]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
from torch.autograd import Variable

torch.manual_seed(1)

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
try:
    plt.style.use('seaborn-whitegrid')
except OSError:
    # matplotlib >= 3.6 renamed the bundled seaborn styles (old names removed in 3.8).
    plt.style.use('seaborn-v0_8-whitegrid')

In [4]:
def to_vectensor(line, dict_):
    """Return a one-hot ``uint8`` matrix encoding the characters of ``line``.

    The result has shape ``(len(line), len(dict_))``; row ``i`` holds a single 1
    in column ``dict_[line[i]]``. A character missing from ``dict_`` raises
    ``KeyError``.
    """
    columns = np.array([dict_[ch] for ch in line], dtype=np.intp)
    onehot = np.zeros((len(line), len(dict_)), dtype="uint8")
    onehot[np.arange(len(columns)), columns] = 1
    return onehot

In [5]:
class EncodedCharDataset(torch.utils.data.Dataset):
    """OxfordDict Dataset: sliding windows of one-hot encoded characters.

    Every character of the text file is mapped to an integer id. Item ``i`` is
    the window of ``seq_len`` consecutive characters starting at position ``i``,
    returned as a float tensor of shape ``(seq_len, unique_chars)``.
    """

    def __init__(self, file, seq_len, char_to_idx=None, idx_to_char=None,
                 encoding="utf-8"):
        """
        Args:
            file: filepath of the text to read.
            seq_len: length of each sample (window) the dataset returns.
            char_to_idx: optional encoding dict (char -> id) to extend. Pass
                another dataset's dict to share one vocabulary; it is mutated
                in place when this file contains unseen characters.
            idx_to_char: optional decoding dict (id -> char), kept in sync
                with ``char_to_idx``.
            encoding: text encoding used to read ``file``.
        """
        self.offset = 0

        # length of sample sequences
        self.seq_len = seq_len

        # Fresh dicts per instance: the former mutable defaults ({}) were
        # shared by every dataset constructed without explicit dicts.
        # char to id; Encoding dict
        self.char_to_idx = {} if char_to_idx is None else char_to_idx
        # id to char; Decoding dict
        self.idx_to_char = {} if idx_to_char is None else idx_to_char

        with open(file, "r", encoding=encoding) as f:
            text = f.read()

        # Encode the text, growing both dicts with new chars in order of first appearance.
        encoded = np.empty(len(text), dtype=np.int64)
        for pos, char in enumerate(text):
            idx = self.char_to_idx.get(char)
            if idx is None:
                idx = len(self.char_to_idx)
                self.char_to_idx[char] = idx
                self.idx_to_char[idx] = char
            encoded[pos] = idx

        # Keep the compact uint8 storage when it can hold every id; uint8
        # would silently wrap ids once the vocabulary exceeds 256 characters.
        fits_uint8 = encoded.size == 0 or int(encoded.max()) <= np.iinfo(np.uint8).max
        self.data = encoded.astype(np.uint8 if fits_uint8 else np.int64)

        self.unique_chars = len(self.char_to_idx)
        print("#different chars:", self.unique_chars)

    def __len__(self):
        """Number of complete ``seq_len`` windows (0 if the text is shorter)."""
        return max(0, len(self.data) - self.seq_len + 1)

    def __getitem__(self, i):
        """
        Return the window starting at ``i`` as a ``(self.seq_len, self.unique_chars)``
        one-hot float tensor. Negative ``i`` counts from the end; out-of-range
        ``i`` raises ``IndexError``.
        """
        n = len(self)
        if i < 0:
            i += n
        if not 0 <= i < n:
            raise IndexError("index {} out of range for dataset of length {}".format(i, n))
        # Index with int64: torch interprets uint8 index tensors as boolean masks.
        ids = torch.from_numpy(self.data[i:i + self.seq_len].astype(np.int64))
        x_onehot = torch.zeros((self.seq_len, self.unique_chars))
        x_onehot[torch.arange(self.seq_len), ids] = 1
        return x_onehot
    


In [6]:
## generating a smaller testing set
#with open(fp) as f:
#    text = f.read()
#    file = open("/Users/valentinwolf/data/oxford_dict/tiny_tiny_shakespeare.txt",'w') 
#    file.write(text[:10000])
#    print(len(text),)
#    file.close()
#    
#fp = "/Users/valentinwolf/data/oxford_dict/tiny_tiny_shakespeare.txt"

In [7]:
# Split the corpus: the last `n_val_chars` characters become the validation file,
# and everything before them is written to `test_fp`, which (despite its name)
# is the file the training dataset is built from.
fp = "/Users/valentinwolf/data/oxford_dict/tiny_shakespeare.txt"
val_fp = "/Users/valentinwolf/data/oxford_dict/tiny_shakespeare_val.txt"
test_fp = "/Users/valentinwolf/data/oxford_dict/tiny_shakespeare_test.txt"
n_val_chars = 15000

with open(fp) as f:
    text = f.read()
if len(text) <= n_val_chars:
    raise ValueError("corpus has {} chars; need more than {} to leave any training text".format(
        len(text), n_val_chars))
with open(val_fp, 'w') as val_file:
    val_file.write(text[-n_val_chars:])
with open(test_fp, 'w') as test_file:
    test_file.write(text[:-n_val_chars])

In [8]:
# Training samples: 51-char windows = 50 input chars + targets shifted by one.
train_dataset = EncodedCharDataset(file=test_fp, seq_len=51)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16,
                                           shuffle=True, num_workers=2)

# Validation: the whole 15000-char held-out file as a single window. Passing
# the training dicts makes both datasets use the same character encoding.
val_dataset = EncodedCharDataset(file=val_fp, seq_len=15000,
                                 char_to_idx=train_dataset.char_to_idx,
                                 idx_to_char=train_dataset.idx_to_char)
# Unseen validation chars get appended to the shared dicts, which would make the
# validation one-hot vectors wider than the model input; fail early and clearly.
if val_dataset.unique_chars != train_dataset.unique_chars:
    raise ValueError("validation text has characters not seen in training "
                     "({} vs {} unique chars)".format(val_dataset.unique_chars,
                                                      train_dataset.unique_chars))
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1,
                                           shuffle=False, num_workers=2)

#different chars: 65
#different chars: 65


In [9]:
class CharLSTM(torch.nn.Module):
    """Stacked LSTM followed by a linear layer producing one output vector per time step.

    Args:
        input_size: size of each input vector (the one-hot vocabulary size here).
        hidden_size: number of hidden units in each LSTM layer.
        output_size: size of each per-step output (vocabulary size for next-char logits).
        num_layers: number of stacked LSTM layers.
        dropout: dropout applied between LSTM layers (no effect with a single layer).
        batch_first: if True, inputs and outputs are ``(batch, seq, feature)``;
            otherwise ``(seq, batch, feature)``.
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers=2, dropout=0.25, batch_first=True):
        super(CharLSTM, self).__init__()

        # batch_first must be forwarded: without it nn.LSTM reads batch-first
        # inputs as (seq, batch, feature) and recurs across unrelated samples.
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers,
                            dropout=dropout, batch_first=batch_first)
        self.lin = nn.Linear(hidden_size, output_size)

    def forward(self, inputs, hidden=None, force=False, steps=0):
        """Run the LSTM over ``inputs`` and project every step through the linear layer.

        Args:
            inputs: input sequence laid out according to ``batch_first``.
            hidden: optional ``(h_0, c_0)`` state passed straight to ``nn.LSTM``.
            force, steps: unused; kept for interface compatibility.

        Returns:
            ``(output, hidden)`` where ``output`` has ``output_size`` features per
            step and ``hidden`` is the final ``(h_n, c_n)`` from ``nn.LSTM``.
        """
        output, hidden = self.lstm(inputs, hidden)
        output = self.lin(output)

        return output, hidden

In [10]:
def sample_seq(model, start_char="\n", temperature=1.0,
               idx_to_char=train_dataset.idx_to_char, char_to_idx=train_dataset.char_to_idx,
              max_len=250):
    """
    Generate text from ``model`` one character at a time.

    Starting from ``start_char``, each step feeds the previous character as a
    one-hot ``(1, 1, vocab)`` input, turns the model output into probabilities
    with a softmax, and draws the next character with ``sample`` at
    ``temperature``. There is no end token: generation stops once the returned
    string (which includes ``start_char``) reaches ``max_len`` characters.

    The model is switched to eval mode (dropout off) and autograd is disabled
    while sampling; its previous train/eval mode is restored afterwards.
    """
    n_chars = train_dataset.unique_chars
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            step_input = torch.zeros((1, 1, n_chars), dtype=torch.float)
            step_input[0, 0, char_to_idx[start_char]] = 1

            pieces = [start_char]
            length = len(start_char)
            hidden = None
            while length < max_len:
                logits, hidden = model(step_input, hidden)
                probs = F.softmax(logits, dim=-1).detach().numpy()[0, 0]
                char_id = sample(probs, temperature)

                step_input = torch.zeros((1, 1, n_chars), dtype=torch.float)
                step_input[0, 0, char_id] = 1

                char = idx_to_char[char_id]
                pieces.append(char)
                length += len(char)
    finally:
        model.train(was_training)
    return "".join(pieces)

def sample(a, temperature=1.0):
    """Draw an index from the probability vector ``a`` re-weighted by ``temperature``.

    Each probability is raised to the power ``1 / temperature`` and the result
    renormalized: ``temperature < 1`` sharpens the distribution (more
    conservative), ``temperature > 1`` flattens it.

    Args:
        a: 1-D array of non-negative probabilities (e.g. a softmax output).
        temperature: strictly positive sampling temperature.

    Returns:
        An index in ``range(len(a))`` drawn with ``np.random``.

    Raises:
        ValueError: if ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive, got {}".format(temperature))
    # Zero probabilities become -inf log-weights and end up with weight 0.
    with np.errstate(divide="ignore"):
        logits = np.log(np.asarray(a, dtype=np.float64)) / temperature
    # Shift by the max so low temperatures cannot underflow every weight to 0
    # (which used to yield NaN probabilities and a ValueError from choice).
    logits -= np.max(logits)
    weights = np.exp(logits)
    probs = weights / np.sum(weights)
    return np.random.choice(len(probs), p=probs)

In [39]:
# TODO: Implement model loading
# The vocabulary size is both the one-hot input width and the number of output logits.
model = CharLSTM(
    input_size=train_dataset.unique_chars,
    output_size=train_dataset.unique_chars,
    hidden_size=512,
    num_layers=3,
    batch_first=True,
)

In [40]:
# Cross-entropy over the next-character logits, optimized with Adam.
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=0.001)

# Bookkeeping: per-batch training losses grouped by epoch, the mean training
# loss of each epoch, and the validation losses.
losses, mean_losses, val_losses = [], [], []

In [41]:
import os


def _evaluate(model, loader, loss_function):
    """Return the mean next-char loss of ``model`` over ``loader`` without autograd.

    Leaves the model in eval mode; the training loop switches it back with
    ``model.train()`` at the start of each epoch.
    """
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for sequences in loader:
            inputs = sequences[:, :-1]
            targets = torch.argmax(sequences[:, 1:], dim=-1)
            out, _ = model(inputs)
            batch_losses.append(loss_function(out.permute(0, 2, 1), targets).item())
    return float(np.mean(batch_losses))


snapshot_path = './model/tiny_tiny_Shakespeare_LSTM_512.pkl'

# training LSTM
for epoch in range(300):
    print("\nEpoch {}:".format(epoch))
    model.train()  # dropout on; validation below switches to eval mode
    losses.append([])
    for batch, sequences in enumerate(train_loader):
        model.zero_grad()

        # Predict character t+1 from the characters up to t.
        inputs = sequences[:, :-1]
        targets = torch.argmax(sequences[:, 1:], dim=-1)

        out, _ = model(inputs)

        # CrossEntropyLoss expects the class dimension second: (batch, classes, seq).
        loss = loss_function(out.permute(0, 2, 1), targets)
        loss.backward()
        optimizer.step()
        losses[-1].append(loss.item())
        rmean_loss = np.mean(losses[-1][-15:])
        print("\rLoss: {:.5f} | Running loss: {:.5f}, Batch: {}       ".format(
            losses[-1][-1], rmean_loss, batch), end="")

    # Mean over this epoch's batch losses (previously indexed the scalar loss tensor).
    mean_loss = np.mean(losses[-1])
    mean_losses.append(mean_loss)
    print("\nMean Loss:", mean_loss)

    # validate
    val_loss = _evaluate(model, val_loader, loss_function)
    val_losses.append(val_loss)
    print("\n Validation Loss:", val_loss)

    # plot batchwise train error; show() renders and releases the figure so
    # they do not pile up over hundreds of epochs
    flat_losses = [item for epoch_losses in losses for item in epoch_losses]
    plt.figure()
    plt.plot(range(len(flat_losses)), flat_losses)
    plt.show()

    if mean_loss <= min(mean_losses):
        # save model to file if mean loss of epoch is the lowest so far
        print("taking snapshot")
        os.makedirs(os.path.dirname(snapshot_path), exist_ok=True)
        torch.save(model, snapshot_path)

    if epoch % 10 == 0:
        # show some examples from the current model
        print("Samples with increasing Temp -> conservative first:")
        for i in range(5):
            print(sample_seq(model, temperature=0.2 * (i + 1)))


Epoch 0:
Loss: 2.73779 | Running loss: 2.77474, Batch: 349       

Process Process-9:
Process Process-10:
Traceback (most recent call last):


Loss: 2.78366 | Running loss: 2.77293, Batch: 350       

Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/multiprocessing/process.py", line 254, in _bootstrap
    self.run()
  File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)


KeyboardInterrupt: 

  File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/multiprocessing/process.py", line 254, in _bootstrap
    self.run()
  File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/torch/utils/data/dataloader.py", line 52, in _worker_loop
    r = index_queue.get()
  File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/torch/utils/data/dataloader.py", line 52, in _worker_loop
    r = index_queue.get()
  File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/multiprocessing/queues.py", line 343, in get
    res = self._reader.recv_bytes()
  File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/multiprocessing/connection.py", line 216, in recv_bytes
    buf = self._recv_bytes(maxlength)
  File "/Library/Frameworks/Python.fra

In [24]:
# Generate 1000 characters starting from a space at a moderate temperature.
out = sample_seq(model, max_len=1000, temperature=0.75, start_char=' ')
print(out)

 ri? w ns uoPt nh t In pid s f  nehan l t ow meo f,  lii
Amoa! 
s
t de
teh  oe  yilo d
w eet nr kooeco tel hesor oo 
Lu la t 
TI d.ce e y 
Ori oib
ceeoth n Od c
me s nn
l r t w  te t s t , to, Tiiee a g rot 
ose hh vn tn.do,e
t sey nteotir s sd wp, lohr t, t sfat re tew
t
s  beh s tes thoen  o ns u f aIian atog,auis t,eoteeeshay s.oh s rheolh oiy aNuh le h t boc i
n d t uoBto
Hl t w teeite looaaaan
s teed ce sem ng f t tate nr tin OCAolta! kteo
f treo seasr r 
cte
rstn w m natem
se aateeua
At nf seeiIl d !y t  see s d ew  Ec oireluoe ne nhh lund s t oe
Pshy w ut ne leh Irseith ou nlhod to  EAm
d 
Tsh n s iry y msheno-s  s d , t tiu h 
nl m toh l Esl n nes se t
T  or t l vlo meiushe,ar;oseaIr te rne l doo hale,e to t
m n lae t t d te Rp .t :e ap ud of a areoot l se ehoin e weth fh nm lh 
n d ctt tl 
f rbero d weeote ernotevnry ter
Ao oeo uiv geeonon s seeodte v v tod tt cat
ws, tase bnt foa, t lolr h oKeanhooTaratl l konnii nge Ebege neoeo Vs fam d ites b h t w OSular n memuiIl te ones 