In [3]:
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import random
import os
import sys
import nltk
from collections import Counter
import re
import numpy as np
from argparse import Namespace

In [86]:
#nltk.download('punkt')

In [4]:
flags = Namespace(
    seq_size=32,
    batch_size=1,
    embedding_size=64,
    lstm_size=64,
    gradients_norm=5,
    initial_words=['Et', 'je'],
    predict_top_k=5,
)

In [5]:
## rap data
directory = "rap/Jul"
all = ""
for album in os.listdir(directory):
    dir_album = "{}/{}".format(directory, album)
    for son in os.listdir(dir_album):
        adr = "{}/{}".format(dir_album,son)
        with open(adr, 'r') as f:
            data = f.read()
#             decoded_data = data.decode('utf8')
            all+=data



In [6]:
##american dad data
directory = "scripts_american_dad"
all = ""
for saison in os.listdir(directory):
    dir_saison = "{}/{}".format(directory, saison)
    for ep in os.listdir(dir_saison):
        adr = "{}/{}".format(dir_saison,ep)
        with open(adr, 'r') as f:
            data = f.read()
            all+=data
            

In [10]:
sentences = all.split("\n")
print(sentences[0])
print("nombre de phrases : {}".format(len(sentences)))
sentences = [nltk.word_tokenize(s.lower()) for s in sentences]
sentences = [s for s in sentences if len(s)>0]


Shut up, Steve.
nombre de phrases : 119508


In [7]:
token = nltk.word_tokenize(all.lower())
words = Counter(token)
words = sorted(words, key=words.get, reverse=True)
vocab_size = len(words)
word2idx = {o:i for i,o in enumerate(words)}
idx2word = {i:o for i,o in enumerate(words)}
print(vocab_size)

30262


In [9]:
int_text = [word2idx[w] for w in token]
num_batches = int(len(int_text) / (flags.seq_size * flags.batch_size))
in_text = int_text[:num_batches * flags.batch_size * flags.seq_size]
out_text = np.zeros_like(in_text)
out_text[:-1] = in_text[1:]
out_text[-1] = in_text[0]
in_text = np.reshape(in_text, (flags.batch_size, -1))
out_text = np.reshape(out_text, (flags.batch_size, -1))

In [10]:
def get_batches(in_text, out_text, batch_size, seq_size):
    num_batches = np.prod(in_text.shape) // (seq_size * batch_size)
    for i in range(0, num_batches * seq_size, seq_size):
        yield in_text[:, i:i+seq_size], out_text[:, i:i+seq_size]

In [11]:
class Model(nn.Module):
    def __init__(self, nb_cells, hidden_size, vocab_size, embeddings_dim): 
        super(Model, self).__init__()
        self.gru = nn.GRU(embeddings_dim, hidden_size, nb_cells, batch_first = True)
        self.embeddings = nn.Embedding(vocab_size, embeddings_dim)
        self.hidden_size = hidden_size
        self.nb_cells = nb_cells
        self.dense1 = nn.Linear(hidden_size, vocab_size)

        
    def forward(self, x, hidden):
        embeds = self.embeddings(x)
        gru_out, hidden = self.gru(embeds, hidden)
        out = self.dense1(gru_out)
        return out, hidden
    
        
    def init_hidden(self, batch_size):
        weight = next(self.parameters()).data
        hidden = weight.new(self.nb_cells, batch_size, self.hidden_size).zero_()
        return hidden

In [28]:
###Hyper paramètres
vocab_size = len(word2idx)
embedding_dim = 50
hidden_dim = 256
nb_cells = 2
model = Model(nb_cells,hidden_dim, vocab_size, embedding_dim)
lr=0.005
epochs = range(10)
loss_fn = nn.CrossEntropyLoss()
optimizer = th.optim.Adam(model.parameters(), lr=lr) #Adam adapté aux pb de NLP

In [38]:
###Apprentissage
for epoch in epochs:
    batches = get_batches(in_text, out_text, flags.batch_size, flags.seq_size)
    h = model.init_hidden(flags.batch_size)
    i=0
    print("EPOCH {}".format(epoch))
    for x, y in batches:
        i+=1
        optimizer.zero_grad()
        h = h.data
        x = th.tensor(x)
        y = th.tensor(y)
        pred, h = model.forward(x, h)
        loss = loss_fn(pred.transpose(1, 2), y)
        h = h.detach()
        loss.backward()  
        nn.utils.clip_grad_norm_( model.parameters(), flags.gradients_norm)
        optimizer.step()
        
        if(i%1000==0):
            print("{}/{}".format(i, num_batches))
            print("predict iteration:\n")
            predict(model, ["i", "am"], word2idx, idx2word, 50)
            
    print("predict epoch :\n")
    predict(model, ["i", "am"], word2idx, idx2word)
    print('\n')

EPOCH 0
1000/31773
predict iteration:

i am temptation ] area. ofwood red-hot ofwood area. red-hot ofwood red-hot scream area. red-hot red-hot ofwood scream affected red-hot area. scream affected area. scream red-hot red-hot area. area. scream scream calling red-hot scream area. area. area. ofwood area. calling ofwood red-hot scream area. calling area. area. area. scream red-hot scream scream calling
2000/31773
predict iteration:

i am sense ] ] kreme krispy kreme krispy sense krispy ah-ah sense ah-ah sense kreme kreme kreme krispy krispy sense ah-ah cloak kreme sense kreme sense krispy sense krispy cloak kreme krispy cloak ah-ah kreme ah-ah cloak ah-ah kreme ah-ah ah-ah krispy cloak krispy cloak sense cloak sense cloak kreme krispy kreme


KeyboardInterrupt: 

In [35]:
def predict(model, words, vocab_to_int, int_to_vocab,nbwords, top_k=5):
    model.eval()
    h = model.init_hidden(1)
    
    for w in words:
        idx = th.LongTensor([[vocab_to_int[w]]])
        out, h = model(idx, h)
    
    _, top_idx = th.topk(out[0], k=top_k)
    choices = top_idx.tolist()
    choice = np.random.choice(choices[0])

    words.append(int_to_vocab[choice])
    
    for _ in range(nbwords):
        ix = th.LongTensor([[choice]])
        out, h = model(idx, h)

        _, top_idx = th.topk(out[0], k=top_k)
        choices = top_idx.tolist()
        choice = np.random.choice(choices[0])
        words.append(int_to_vocab[choice])

    print(' '.join(words))

In [37]:
predict(model, ["i", "am"], word2idx, idx2word, 30)

i am are , nazi coveted nazi coveted off-the-charts nazi diddy diddy failing failing off-the-charts off-the-charts roller off-the-charts nazi nazi off-the-charts off-the-charts off-the-charts roller nazi failing off-the-charts nazi meet- off-the-charts off-the-charts meet- failing


In [40]:
th.save(model.state_dict(), "model.pth")