In [1]:
import glob
from collections import Counter
import pickle
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import TensorDataset, DataLoader
import numpy as np

from model import SimpleRNNModel
%load_ext autoreload
%autoreload 2

In [2]:
# load the dataset
poems = []

for path in glob.iglob("data/*.txt"):
    with open(path, 'r') as f:
        x = f.read()
        poems.append(x.lower().split(' '))

In [3]:
print(poems[0])

['coplas', 'elegíacas', '\n', '¡', 'ay', 'del', 'que', 'llega', 'sediento', '\n', 'a', 'ver', 'el', 'agua', 'correr', ',', '\n', 'y', 'dice', ':', 'la', 'sed', 'que', 'siento', '\n', 'no', 'me', 'la', 'calma', 'el', 'beber', '!', '\n', '¡', 'ay', 'de', 'quien', 'bebe', 'y', ',', 'saciada', '\n', 'la', 'sed', ',', 'desprecia', 'la', 'vida', ':', '\n', 'moneda', 'al', 'tahúr', 'prestada', ',', '\n', 'que', 'sea', 'al', 'azar', 'rendida', '!', '\n', 'del', 'iluso', 'que', 'suspira', '\n', 'bajo', 'el', 'orden', 'soberano', ',', '\n', 'y', 'del', 'que', 'sueña', 'la', 'lira', '\n', 'pitagórica', 'en', 'su', 'mano', '.', '\n', '¡', 'ay', 'del', 'noble', 'peregrino', '\n', 'que', 'se', 'para', 'a', 'meditar', ',', '\n', 'después', 'de', 'largo', 'camino', '\n', 'en', 'el', 'horror', 'de', 'llegar', '!', '\n', '¡', 'ay', 'de', 'la', 'melancolía', '\n', 'que', 'llorando', 'se', 'consuela', ',', '\n', 'y', 'de', 'la', 'melomanía', '\n', 'de', 'un', 'corazón', 'de', 'zarzuela', '!', '\n', '¡', '

In [4]:
word_counter = Counter()

for words in poems:
    word_counter.update(words)

print(len(word_counter)) # 6926 distinct words, OMG!!!
print("Most common:", word_counter.most_common(10))
print("Most uncommon:", word_counter.most_common()[-10:])

6298
Most common: [('\n', 12658), (',', 5864), ('.', 5174), ('de', 3176), ('y', 2805), ('la', 2739), ('el', 2583), ('en', 1611), ('que', 1466), ('un', 888)]
Most uncommon: [('arrancaste', 1), ('clamar', 1), ('voluntad', 1), ('solos', 1), ('perdemos', 1), ('perderá', 1), ('adivínala', 1), ('fetiche', 1), ('ofrenda', 1), ('puñetazos', 1)]


In [5]:
min_word_appearance = 1 # (optional) if > 1 gets rid of words that only appear N times

word_idx = {e[0]: idx+1 for idx, e in enumerate(word_counter.most_common()) if e[1] >= min_word_appearance}
idx_word = {v: k for k, v in word_idx.items()}

with open('word_idx.pkl', 'wb') as file:
    pickle.dump(word_idx, file)

with open('idx_word.pkl', 'wb') as file:
    pickle.dump(idx_word, file)

print(len(word_idx))

6298


In [5]:
with open('word_idx.pkl', 'rb') as file:
    word_idx = pickle.load(file)
    
with open('idx_word.pkl', 'rb') as file:
    idx_word = pickle.load(file)

In [6]:
tokenized = []
for poem in poems:
    tokenized.append([word_idx[word] for word in poem])

print(poems[0])
print(tokenized[0])
print("---")
print(poems[0][7], "->", word_idx[poems[0][7]])

['coplas', 'elegíacas', '\n', '¡', 'ay', 'del', 'que', 'llega', 'sediento', '\n', 'a', 'ver', 'el', 'agua', 'correr', ',', '\n', 'y', 'dice', ':', 'la', 'sed', 'que', 'siento', '\n', 'no', 'me', 'la', 'calma', 'el', 'beber', '!', '\n', '¡', 'ay', 'de', 'quien', 'bebe', 'y', ',', 'saciada', '\n', 'la', 'sed', ',', 'desprecia', 'la', 'vida', ':', '\n', 'moneda', 'al', 'tahúr', 'prestada', ',', '\n', 'que', 'sea', 'al', 'azar', 'rendida', '!', '\n', 'del', 'iluso', 'que', 'suspira', '\n', 'bajo', 'el', 'orden', 'soberano', ',', '\n', 'y', 'del', 'que', 'sueña', 'la', 'lira', '\n', 'pitagórica', 'en', 'su', 'mano', '.', '\n', '¡', 'ay', 'del', 'noble', 'peregrino', '\n', 'que', 'se', 'para', 'a', 'meditar', ',', '\n', 'después', 'de', 'largo', 'camino', '\n', 'en', 'el', 'horror', 'de', 'llegar', '!', '\n', '¡', 'ay', 'de', 'la', 'melancolía', '\n', 'que', 'llorando', 'se', 'consuela', ',', '\n', 'y', 'de', 'la', 'melomanía', '\n', 'de', 'un', 'corazón', 'de', 'zarzuela', '!', '\n', '¡', '

In [7]:
lengths = [len(sequence) for sequence in tokenized]
print(max(lengths))
print(min(lengths))

4758
13


In [8]:
max_seq_length = 400 #4758

padded = []
for sequence in tokenized:
    trimmed = sequence[-max_seq_length:]
    padding = [0] * (max_seq_length - len(trimmed))
    padded.append(padding + trimmed)
    
padded = np.array(padded)
print(padded.shape, padded[:,:-1].shape, padded[:,1:].shape)

(445, 400) (445, 399) (445, 399)


In [9]:
# create data loader
batch_size = 128

dataset = TensorDataset(torch.from_numpy(padded[:,:-1]), torch.from_numpy(padded[:,1:]))
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [10]:
x, y = next(iter(dataloader))

print(x.shape, y.shape)

print(x[0, -10:])
print(y[0, -10:])

torch.Size([128, 399]) torch.Size([128, 399])
tensor([ 451,   12, 5308,   20, 5309,    4,    6, 5310,    3,    1])
tensor([  12, 5308,   20, 5309,    4,    6, 5310,    3,    1,   24])


In [11]:
model = SimpleRNNModel(len(word_idx) + 1)
model

SimpleRNNModel(
  (embedding_layer): Embedding(6299, 200, padding_idx=0)
  (rnn): GRU(200, 512, num_layers=2, batch_first=True, dropout=0.5)
  (dropout): Dropout(p=0.5)
  (fc): Linear(in_features=512, out_features=6299, bias=True)
)

In [12]:
out = model(x)
print(out.shape)
print(torch.exp(out[0]).sum())

torch.Size([51072, 6299])
tensor(1.0000, grad_fn=<SumBackward0>)


In [17]:
epochs = 200

optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.NLLLoss() #ignore_index=0
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=10, verbose=True)

model.train()
model.cuda()
for e in range(1, epochs+1):
    total_loss = 0
    total_accuracy = 0
    batch = 0
    for x, y in dataloader:
        batch += 1
        x, y = x.cuda(), y.cuda()
        
        optimizer.zero_grad()
        model.hidden = None
        
        pred = model(x)
        loss = criterion(pred, y.view(-1))
        
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 4)
        optimizer.step()
        
        total_loss += loss.item()
        
        equals = torch.argmax(pred, dim=1).view(-1) == y.view(-1)
        total_accuracy += torch.mean(equals.type(torch.FloatTensor))
            
        print(f"EPOCH {e} ({batch}/{len(dataloader)}) - loss {total_loss/batch:.4f} - acc {total_accuracy/batch:.4f}", end='\r') 
        
    scheduler.step(total_loss)
    print(f"EPOCH {e} - loss {total_loss/len(dataloader):.4f} - acc {total_accuracy/len(dataloader):.4f} ---------------------- ")
    

EPOCH 1 - loss 0.4178 - acc 0.8740 ---------------------- 
EPOCH 2 - loss 0.1226 - acc 0.9747 ---------------------- 
EPOCH 3 - loss 0.1167 - acc 0.9764 ---------------------- 
EPOCH 4 - loss 0.1088 - acc 0.9780 ---------------------- 
EPOCH 5 - loss 0.1047 - acc 0.9787 ---------------------- 
EPOCH 6 - loss 0.0992 - acc 0.9801 ---------------------- 
EPOCH 7 - loss 0.0954 - acc 0.9806 ---------------------- 
EPOCH 8 - loss 0.0929 - acc 0.9816 ---------------------- 
EPOCH 9 - loss 0.0910 - acc 0.9817 ---------------------- 
EPOCH 10 - loss 0.0897 - acc 0.9821 ---------------------- 
EPOCH 11 - loss 0.0867 - acc 0.9826 ---------------------- 
EPOCH 12 - loss 0.0848 - acc 0.9833 ---------------------- 
EPOCH 13 - loss 0.0848 - acc 0.9832 ---------------------- 
EPOCH 14 - loss 0.0824 - acc 0.9836 ---------------------- 
EPOCH 15 - loss 0.0820 - acc 0.9836 ---------------------- 
EPOCH 16 - loss 0.0811 - acc 0.9842 ---------------------- 
EPOCH 17 - loss 0.0796 - acc 0.9846 -------------

In [18]:
model.cpu()
torch.save(model.state_dict(), 'model.pt')

In [16]:
checkpoint = torch.load('model.pt')
model.load_state_dict(checkpoint)

In [19]:
# sample generation
seed = 'el amor'
predictions = []

model.cpu()
model.eval()
with torch.no_grad():
    model.hidden = None
    last_token = None
    
    # initialize hidden state using seed
    for word in seed.split(' '):
        token = np.array([word_idx[word]])
        token = torch.from_numpy(token).view(1, 1)
        last_state = model(token)
        
    for i in range(300):
        last_prediction = torch.argmax(last_state, dim=1).view(1, 1)
        last_state = model(last_prediction)
        idx = torch.argmax(last_state, dim=1).numpy()[0]
        word = idx_word[idx]
        if word == '<end>':
            break
        predictions.append(word)
        
print(seed + ' ' +  ' '.join(predictions))

el amor enfático . 
 en el solitario parque , la sonora 
 copia borbollante del agua cantora 
 me guió a la fuente . la fuente vertía 
 sobre el blanco mármol su monotonía . 
 la fuente cantaba : ¿ te recuerda , hermano , 
 un sueño lejano mi canto presente ? 
 fue una tarde lenta del lento verano . 
 respondí a la fuente : 
 no recuerdo , hermana , 
 mas sé que tu copla presente es lejana . 
 fue esta misma tarde : mi cristal vertía 
 como hoy sobre el mármol su monotonía . 
 ¿ recuerdas , hermano ? . . . los mirtos talares , 
 que ves , sombreaban los claros cantares 
 que escuchas . del rubio color de la llama , 
 el fruto maduro pendía en la rama , 
 lo mismo que ahora . ¿ recuerdas , hermano ? . . . 
 fue esta misma lenta tarde de verano . 
 — no sé qué me dice tu copla riente 
 de ensueños lejanos , hermana la fuente . 
 yo sé que tu claro cristal de alegría 
 ya supo del árbol la fruta bermeja ; 
 yo sé que es lejana la amargura mía 
 que sueña en la tarde de verano vieja . 
 yo