# Практическая работа №6
Необходимо разработать (автономную) диалоговую программную систему или подсистему для программного продукта, создаваемого в рамках магистерского исследования. Например, в формате чат-бота. Рекомендуется использовать результаты, полученные при выполнении предыдущих практических работ. Внимание: допускается использовать любые известные подходы к генерации осмысленных предложений, кроме шаблонного.

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import ast
from collections import Counter

## Загрузка данных

In [3]:
def load_lines(path):
    """Read a `` +++$+++ ``-separated lines file into an id -> text table.

    Every row is stripped and split on `` +++$+++ ``; only rows that yield
    exactly five fields are kept. The first field is used as the key and the
    fifth field, lower-cased, as the value. Undecodable bytes are dropped
    (``errors="ignore"``). When an id repeats, the later row wins.

    Args:
        path: Path to the text file.

    Returns:
        dict mapping line id to lower-cased utterance text.
    """
    separator = " +++$+++ "
    with open(path, encoding="utf-8", errors="ignore") as handle:
        rows = (raw.strip().split(separator) for raw in handle)
        return {
            fields[0]: fields[4].lower()
            for fields in rows
            if len(fields) == 5
        }


def load_pairs(lines_path, conv_path, limit=1000):
    """Build (utterance, reply) pairs from adjacent lines of each conversation.

    Args:
        lines_path: Utterances file, parsed with ``load_lines``.
        conv_path: `` +++$+++ ``-separated conversations file. The fourth
            field of each row is a Python list literal of line ids, e.g.
            ``['L1', 'L2', 'L3']``.
        limit: Maximum number of pairs to return; ``None`` means no limit.

    Returns:
        list of ``(question, answer)`` string tuples, at most ``limit`` long.
        A pair is emitted for every two adjacent ids of a conversation whose
        texts are both present in the lines file.
    """
    lines = load_lines(lines_path)
    pairs = []

    with open(conv_path, encoding="utf-8", errors="ignore") as f:
        for row in f:
            # Check before reading the next conversation, so limit=0 yields
            # nothing and no extra conversation is parsed needlessly.
            if limit is not None and len(pairs) >= limit:
                break
            parts = row.strip().split(" +++$+++ ")
            if len(parts) != 4:
                continue
            # literal_eval only evaluates Python literals, unlike eval.
            ids = ast.literal_eval(parts[3])
            for prev_id, next_id in zip(ids, ids[1:]):
                if prev_id in lines and next_id in lines:
                    pairs.append((lines[prev_id], lines[next_id]))

    # A single conversation can add several pairs, so the last one read may
    # have pushed the list past the limit; trim it.
    return pairs if limit is None else pairs[:limit]


## Словарь и токенизация

In [4]:
_DROPPED_PUNCTUATION = str.maketrans("", "", "?!.")


def tokenize(s):
    """Split *s* on whitespace after deleting every '?', '!' and '.'.

    Other punctuation (commas, quotes, apostrophes, ...) stays attached to
    its word, and no case folding happens here.
    """
    return s.translate(_DROPPED_PUNCTUATION).split()


pairs = load_pairs("data/movie_lines.txt", "data/movie_conversations.txt")

# Word frequencies over both sides of every dialogue pair. Counter keeps
# first-seen insertion order, which determines the word ids assigned below.
counter = Counter()
for question, answer in pairs:
    counter.update(tokenize(question) + tokenize(answer))

# Special tokens occupy ids 0-3; corpus words follow in first-seen order.
vocab = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
vocab.update(
    (word, idx) for idx, word in enumerate(counter, start=len(vocab))
)

# Reverse lookup: id -> word.
ivocab = {idx: word for word, idx in vocab.items()}

## Подготовка данных

In [5]:
def encode(sentence):
    """Turn *sentence* into a list of vocabulary ids using ``tokenize``.

    Words absent from the global ``vocab`` are mapped to the ``<UNK>`` id.
    """
    unknown_id = vocab["<UNK>"]
    return [vocab.get(token, unknown_id) for token in tokenize(sentence)]


# Training examples: the encoder reads the question ids; the decoder is
# trained on <SOS> + answer ids + <EOS>.
# Pairs whose question or answer tokenizes to nothing (e.g. a line that is
# only "..." or "?") are skipped. An empty question would become a
# zero-length float tensor (torch.tensor([])), which the embedding/LSTM in
# the training loop reject. An empty answer would only teach the bot to emit
# <EOS> immediately.
data = []
for q, a in pairs:
    src = encode(q)
    answer_ids = encode(a)
    if not src or not answer_ids:
        continue
    tgt = [vocab["<SOS>"]] + answer_ids + [vocab["<EOS>"]]
    data.append((src, tgt))


## Обучение

In [6]:
EMB = 128  # embedding width, shared by encoder and decoder
HID = 256  # LSTM hidden size; the encoder's final (h, c) seeds the decoder

n_tokens = len(vocab)

# Encoder: token ids -> embeddings -> single-layer LSTM.
encoder_emb = nn.Embedding(num_embeddings=n_tokens, embedding_dim=EMB)
encoder = nn.LSTM(input_size=EMB, hidden_size=HID, batch_first=True)

# Decoder: separate embedding table and LSTM, plus a projection from the
# LSTM output back to vocabulary logits.
decoder_emb = nn.Embedding(num_embeddings=n_tokens, embedding_dim=EMB)
decoder = nn.LSTM(input_size=EMB, hidden_size=HID, batch_first=True)
fc = nn.Linear(in_features=HID, out_features=n_tokens)


In [7]:
# Targets equal to the <PAD> id are ignored by the loss. Examples are fed one
# at a time and never padded, so this only matters if <PAD> is ever a target.
criterion = nn.CrossEntropyLoss(ignore_index=vocab["<PAD>"])
modules = (encoder, decoder, encoder_emb, decoder_emb, fc)
opt = optim.Adam(
    [p for module in modules for p in module.parameters()],
    lr=0.001,
)

EPOCHS = 10

for epoch in range(EPOCHS):
    total_loss = 0
    steps = 0

    random.shuffle(data)
    for src, tgt in data:
        if not src:
            # A zero-length source cannot be run through the LSTM; skip it.
            continue
        # Batch of one example: shape (1, seq_len), integer ids.
        src = torch.tensor(src, dtype=torch.long).unsqueeze(0)
        tgt = torch.tensor(tgt, dtype=torch.long).unsqueeze(0)

        opt.zero_grad()

        # Only the encoder's final (h, c) is used: it initializes the decoder.
        _, hidden = encoder(encoder_emb(src))

        # Teacher forcing: feed tgt[:-1] (starting with <SOS>), predict tgt[1:].
        dec_in = tgt[:, :-1]
        dec_out, _ = decoder(decoder_emb(dec_in), hidden)

        logits = fc(dec_out)
        loss = criterion(
            logits.reshape(-1, logits.size(-1)),
            tgt[:, 1:].reshape(-1),
        )

        loss.backward()
        opt.step()
        total_loss += loss.item()
        steps += 1

    # Average over the examples actually trained on; max() avoids a
    # ZeroDivisionError when data is empty.
    print(f"Epoch {epoch+1}, loss={total_loss/max(steps, 1):.4f}")


Epoch 1, loss=5.8586
Epoch 2, loss=4.8555
Epoch 3, loss=3.8161
Epoch 4, loss=2.5839
Epoch 5, loss=1.5493
Epoch 6, loss=0.9007
Epoch 7, loss=0.5091
Epoch 8, loss=0.3037
Epoch 9, loss=0.1868
Epoch 10, loss=0.1253


## Генерация ответа

In [8]:
def reply(text, max_len=20, temperature=1.0):
    """Generate a reply to *text* with the trained encoder-decoder model.

    The message is encoded with ``encode``. The encoder's final state seeds
    the decoder, which then runs one token at a time from ``<SOS>`` until it
    emits ``<EOS>`` or ``max_len`` tokens have been produced.

    Args:
        text: User message. Vocabulary words are lower-case (lines are
            lower-cased on load), so pass lower-cased text.
        max_len: Maximum number of generated tokens.
        temperature: Softmax temperature for sampling. Values > 1 flatten the
            distribution and values < 1 sharpen it. ``0`` picks the most
            likely token at every step (greedy decoding).

    Returns:
        The generated words joined by single spaces (may be empty).

    Raises:
        ValueError: if ``temperature`` is negative.
    """
    if temperature < 0:
        raise ValueError(f"temperature must be >= 0, got {temperature}")

    # An empty or punctuation-only message encodes to [], which would become a
    # zero-length float tensor the embedding/LSTM cannot handle. Feed a single
    # <UNK> instead so the model still answers.
    src_ids = encode(text) or [vocab["<UNK>"]]
    # <PAD> and <SOS> are never meaningful output words; never sample them.
    banned = [vocab["<PAD>"], vocab["<SOS>"]]

    result = []
    with torch.no_grad():  # inference only: no autograd graph needed
        src = torch.tensor(src_ids, dtype=torch.long).unsqueeze(0)
        _, hidden = encoder(encoder_emb(src))

        cur = torch.tensor([[vocab["<SOS>"]]])
        for _ in range(max_len):
            out, hidden = decoder(decoder_emb(cur), hidden)
            logits = fc(out[:, -1])
            logits[:, banned] = float("-inf")

            if temperature == 0:
                next_token = logits.argmax(dim=-1).item()
            else:
                probs = torch.softmax(logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, 1).item()

            if next_token == vocab["<EOS>"]:
                break

            result.append(ivocab.get(next_token, "<UNK>"))
            cur = torch.tensor([[next_token]])

    return " ".join(result)


## Чат

In [11]:
# Interactive loop: read a message, print it back, print the bot's reply.
print("'exit' - выйти из чата")
while True:
    try:
        # Surrounding whitespace is dropped so " exit " also quits.
        msg = input("Вы: ").strip().lower()
    except (EOFError, KeyboardInterrupt):
        # Closed stdin, Ctrl-D or an interrupt ends the chat cleanly instead
        # of raising a traceback.
        break
    if msg == "exit":
        break
    if not msg:
        # Nothing to answer; prompt again.
        continue
    print("Вы: ", msg)
    print("Бот:", reply(msg))


'exit' - выйти из чата
Вы:  hi
Бот: looks like things worked out tonight, huh
Вы:  i don't know
Бот: what i think of it
Вы:  tell me
Бот: you need therapy has anyone ever told you that
Вы:  no
Бот: what about back home
