In [1]:
import pandas as pd
import time

In [2]:
df = pd.read_csv('data.csv')

In [3]:
df.head()

Unnamed: 0.1,Unnamed: 0,id,episode_id,number,raw_text,timestamp_in_ms,speaking_line,character_id,location_id,raw_character_text,raw_location_text,spoken_words,normalized_text,word_count
0,0,10368,35,29,"Lisa Simpson: Maggie, look. What's that?",235000,True,9,5.0,Lisa Simpson,Simpson Home,"Maggie, look. What's that?",maggie look whats that,4.0
1,1,10369,35,30,Lisa Simpson: Lee-mur. Lee-mur.,237000,True,9,5.0,Lisa Simpson,Simpson Home,Lee-mur. Lee-mur.,lee-mur lee-mur,2.0
2,2,10370,35,31,Lisa Simpson: Zee-boo. Zee-boo.,239000,True,9,5.0,Lisa Simpson,Simpson Home,Zee-boo. Zee-boo.,zee-boo zee-boo,2.0
3,3,10372,35,33,Lisa Simpson: I'm trying to teach Maggie that ...,245000,True,9,5.0,Lisa Simpson,Simpson Home,I'm trying to teach Maggie that nature doesn't...,im trying to teach maggie that nature doesnt e...,24.0
4,4,10374,35,35,"Lisa Simpson: It's like an ox, only it has a h...",254000,True,9,5.0,Lisa Simpson,Simpson Home,"It's like an ox, only it has a hump and a dewl...",its like an ox only it has a hump and a dewlap...,18.0


In [4]:
phrases = df['normalized_text'].tolist()

In [5]:
phrases[:10]

['maggie look whats that',
 'lee-mur lee-mur',
 'zee-boo zee-boo',
 'im trying to teach maggie that nature doesnt end with the barnyard i want her to have all the advantages that i didnt have',
 'its like an ox only it has a hump and a dewlap hump and dew-lap hump and dew-lap',
 'you know his blood type how romantic',
 'oh yeah whats my shoe size',
 'ring',
 'yes dad',
 'ooh look maggie what is that do-dec-ah-edron dodecahedron']

In [6]:
text = [[c for c in ph] for ph in phrases if type(ph) is str]

In [7]:
CHARS = set('abcdefghijklmnopqrstuvwxyz ')

In [8]:
INDEX_TO_CHAR = ['none'] + [w for w in CHARS]

In [9]:
INDEX_TO_CHAR

['none',
 'o',
 'l',
 'w',
 'j',
 'd',
 'b',
 'q',
 'c',
 'h',
 'z',
 'n',
 'p',
 'y',
 'g',
 'u',
 'e',
 'r',
 ' ',
 'k',
 'f',
 's',
 'm',
 'x',
 'a',
 't',
 'v',
 'i']

In [10]:
CHAR_TO_INDEX = {w: i for i, w in enumerate(INDEX_TO_CHAR)}

In [11]:
CHAR_TO_INDEX

{'none': 0,
 'o': 1,
 'l': 2,
 'w': 3,
 'j': 4,
 'd': 5,
 'b': 6,
 'q': 7,
 'c': 8,
 'h': 9,
 'z': 10,
 'n': 11,
 'p': 12,
 'y': 13,
 'g': 14,
 'u': 15,
 'e': 16,
 'r': 17,
 ' ': 18,
 'k': 19,
 'f': 20,
 's': 21,
 'm': 22,
 'x': 23,
 'a': 24,
 't': 25,
 'v': 26,
 'i': 27}

In [12]:
import torch

In [13]:
MAX_LEN = 50

In [14]:
X = torch.zeros((len(text), MAX_LEN), dtype=int)

In [15]:
for i in range(len(text)):
    for j, w in enumerate(text[i]):
        if j >= MAX_LEN:
            break
        X[i, j] = CHAR_TO_INDEX.get(w, CHAR_TO_INDEX['none'])

In [16]:
X[0:1]

tensor([[22, 24, 14, 14, 27, 16, 18,  2,  1,  1, 19, 18,  3,  9, 24, 25, 21, 18,
         25,  9, 24, 25,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]])

In [17]:
class Network(torch.nn.Module):

    def __init__(self):
        super(Network, self).__init__()
        self.word_embeddings = torch.nn.Embedding(len(INDEX_TO_CHAR), 28)
        self.gru = torch.nn.RNN(28, 128, batch_first=True)
        self.hidden2tag = torch.nn.Linear(128, len(INDEX_TO_CHAR))

    def forward(self, sentences):
        embeds = self.word_embeddings(sentences)
        gru_out, state = self.gru(embeds)
        tag_space = self.hidden2tag(gru_out.reshape(-1, 128))
        return tag_space.reshape(sentences.shape[0], sentences.shape[1], -1), state

    def forward_state(self, sentences, state):
        embeds = self.word_embeddings(sentences)
        gru_out, state = self.gru(embeds, state)
        tag_space = self.hidden2tag(gru_out.reshape(-1, 128))
        return tag_space.reshape(sentences.shape[0], sentences.shape[1], -1), state

In [18]:
model = Network()

In [19]:
model.forward(X[0:1])[0].shape

torch.Size([1, 50, 28])

In [20]:
def generate_sentence():
    sentence = ['h', 'e', 'l', 'l', 'o']
    state = None
    for i in range(MAX_LEN):
        X = torch.Tensor([[CHAR_TO_INDEX[sentence[i]]]]).type(torch.long)
        if i == 0:
            result, state = model.forward(X)
        else:
            result, state = model.forward_state(X, state)
        prediction = result[0, -1, :]
        index_of_prediction = prediction.argmax()
        if i >= len(sentence) - 1:
            if index_of_prediction == 0:
                break
        sentence.append(INDEX_TO_CHAR[index_of_prediction])
    print(''.join(sentence))

In [21]:
generate_sentence()

helloxpssxxaxxxxsxbxxxxxxxbxxxxxxxxxxxxxxxxxxxxxxxxxxxx


In [22]:
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=.05)

In [23]:
for ep in range(300):
    start = time.time()
    train_loss = 0.
    train_passed = 0

    for i in range(int(len(X) / 100)):
        batch = X[i * 100:(i + 1) * 100]
        X_batch = batch[:, :-1]
        Y_batch = batch[:, 1:].flatten()

        optimizer.zero_grad()
        answers, _ = model.forward(X_batch)
        answers = answers.view(-1, len(INDEX_TO_CHAR))
        loss = criterion(answers, Y_batch)
        train_loss += loss.item()

        loss.backward()
        optimizer.step()
        train_passed += 1

    print("Epoch {}. Time: {:.3f}, Train loss: {:.3f}".format(ep, time.time() - start, train_loss / train_passed))
    generate_sentence()


Epoch 0. Time: 9.855, Train loss: 2.112
helloe     toaaa      oaaaa      oaaaa      oaaaa      
Epoch 1. Time: 9.099, Train loss: 1.812
helloe     tooaaou     tooaaou     tooaaou     tooaaou 
Epoch 2. Time: 9.748, Train loss: 1.731
helloe     toaaahu     toaaahu     toaaahu     toaaahu 
Epoch 3. Time: 9.362, Train loss: 1.684
helloe     toaaahu     toaaahu     toaaahu     toaaahu 
Epoch 4. Time: 9.409, Train loss: 1.652
helloe     tooaahu     tooaahu     tooaahu     tooaahu 
Epoch 5. Time: 8.897, Train loss: 1.627
helloe     tooaahu     tooaahu     tooaahu     tooaahu 
Epoch 6. Time: 6.931, Train loss: 1.607
helloe     tooaaou     tooaahu     taoaahr     tooaahu 
Epoch 7. Time: 7.292, Train loss: 1.590
helloe l   toaaoou d   taaooort     iooaaau    nonetooaohu
Epoch 8. Time: 7.161, Train loss: 1.575
helloe l   toaaoou d   taaooort     iooaoan     tooaoou
Epoch 9. Time: 7.330, Train loss: 1.562
helloe l   toaaoou d   taaooort     iooootn     taoooor
Epoch 10. Time: 7.322, Train loss: 1.

KeyboardInterrupt: 

## Полезное

In [None]:
torchnlp.samplers.BucketBatchSampler - для сэмплирования данных