In [257]:
##Class: ANLY-580_03
##Professor: Trevor Adriaanse
##Topic: Therapy Chatbot
##Team Members: Xinran Zhang, Qian Chen, Ting Huang, Yifan (Jake) Zhu

In [258]:
#Load necessary packages
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import torch
from torch.jit import script, trace
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import csv
import random
import re
import os
import unicodedata
import codecs
from io import open
import itertools
import math
import json


# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda") if USE_CUDA else torch.device("cpu")

In [259]:
# Splits each line of csv file to create conversation_object and conversations
def loadConversations(filepath):

    conversations = {}
    with open(filepath) as csvfile:
        reader = csv.DictReader(csvfile)
        id_index = 1
        for row in reader:
            conversation_object = {}
            conversation_object['id'] = id_index
            conversation_object['tag'] = row['tag']
            conversation_object['pattern'] = row['pattern']
            conversation_object['response'] = row['response']

            conversations[conversation_object['id']] = conversation_object

            id_index +=1

    return conversations


# Extracts pairs of pattern and response from conversations
def extractSentencePairs(conversations):
    """Build ``[pattern, response]`` lists from conversation records.

    Args:
        conversations: dict whose values are dicts with string entries
            ``"pattern"`` and ``"response"``.

    Returns:
        list of two-element lists, in the dict's iteration order, with
        surrounding whitespace stripped from both strings.
    """
    return [
        [record["pattern"].strip(), record["response"].strip()]
        for record in conversations.values()
    ]

In [260]:
# Path of the intermediate file: one "pattern<TAB>response" pair per line
datafile = os.path.join('data', 'formatted_lines.txt')

# Field delimiter for that file. readVocs() splits every line on a literal
# tab character, so this must remain a real tab.
delimiter = '\t'

# Load the raw pattern/response rows from the source CSV
conversations = loadConversations('data/final.csv')

# Write new file
with open(datafile, 'w', encoding='utf-8') as outputfile:
    writer = csv.writer(outputfile, delimiter=delimiter, lineterminator='\n')
    for pair in extractSentencePairs(conversations):
        # readVocs() splits the file on newlines and each line on tabs, so any
        # tab or line break left inside a field would break the pair apart.
        # Collapse every whitespace run to a single space; normalizeString()
        # collapses whitespace anyway, so the normalized text is unchanged.
        writer.writerow([' '.join(field.split()) for field in pair])

In [261]:
# Default word tokens
PAD_token = 0  # Used for padding short sentences
SOS_token = 1  # Start-of-sentence token
EOS_token = 2  # End-of-sentence token

class Voc:
    """Word <-> index vocabulary built from space-separated sentences.

    Indices 0, 1 and 2 are reserved for PAD, SOS and EOS; real words get the
    following indices in order of first appearance.
    """

    def __init__(self):
        self.trimmed = False
        self.word2index = {}   # word -> index
        self.word2count = {}   # word -> number of occurrences seen
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3  # Count SOS, EOS, PAD

    def addSentence(self, sentence):
        """Add every word of `sentence`, splitting on single spaces."""
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        """Assign `word` the next free index on first sight; otherwise bump its count."""
        if word not in self.word2index:
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words += 1
        else:
            self.word2count[word] += 1

    def trim(self, min_count):
        """Drop words seen fewer than `min_count` times and renumber the rest.

        Only the first call has any effect; later calls return immediately.
        Surviving words keep their relative order and their original counts.
        """
        if self.trimmed:
            return
        self.trimmed = True

        kept_counts = {w: c for w, c in self.word2count.items() if c >= min_count}
        keep_words = list(kept_counts)

        # Guard the ratio so trimming an empty vocabulary does not divide by zero.
        total = len(self.word2index)
        ratio = len(keep_words) / total if total else 0.0
        print('keep_words {} / {} = {:.4f}'.format(len(keep_words), total, ratio))

        # Reinitialize dictionaries
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3 # Count default tokens

        for word in keep_words:
            self.addWord(word)
        # addWord() starts every re-added word at count 1; restore the real
        # frequencies so word2count stays meaningful after trimming.
        self.word2count = kept_counts

In [262]:
# filterPair() discards any pair whose pattern or response splits into
# MAX_LENGTH or more space-separated tokens (leaving room for the EOS token).
MAX_LENGTH = 20  # Maximum sentence length to consider

# Turn a Unicode string to plain ASCII, thanks to
# https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    """Strip accents by NFD-decomposing `s` and dropping combining marks (category 'Mn').

    Characters that have no combining-mark decomposition (e.g. 'ß') are kept
    unchanged, so the result is not guaranteed to be pure ASCII.
    """
    decomposed = unicodedata.normalize('NFD', s)
    kept = [ch for ch in decomposed if unicodedata.category(ch) != 'Mn']
    return ''.join(kept)

# Lowercase, trim, and remove non-letter characters
def normalizeString(s):
    """Canonicalize a sentence for whitespace tokenization.

    Lowercases and strips accents, puts a space before each '.', '!' or '?',
    replaces every other run of non-letter characters with a space, then
    collapses and trims whitespace.
    """
    text = unicodeToAscii(s.lower().strip())
    substitutions = (
        (r"([.!?])", r" \1"),      # detach sentence punctuation as its own token
        (r"[^a-zA-Z.!?]+", r" "),  # any other non-letter run becomes a space
        (r"\s+", r" "),            # collapse repeated whitespace
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text.strip()

# Read pattern/response pairs and return a voc object
def readVocs(datafile):
    """Read the tab-separated pairs file and normalize every field.

    Args:
        datafile: path to a UTF-8 text file holding one tab-separated
            pattern/response pair per line.

    Returns:
        (voc, pairs): a fresh, empty Voc and a list with one entry per line,
        each entry being the list of that line's normalized tab-separated fields.
    """
    print("Reading lines...")
    # Read through a context manager so the file handle is closed promptly
    # instead of being left open until garbage collection.
    with open(datafile, encoding='utf-8') as infile:
        lines = infile.read().strip().split('\n')
    # Split every line into pairs and normalize
    pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]
    voc = Voc()
    return voc, pairs

# Returns True if both sentences in a pair 'p' are under the MAX_LENGTH threshold
def filterPair(p):
    """Return True when p[0] and p[1] each split into fewer than MAX_LENGTH space-separated tokens.

    The strict '<' leaves room for the EOS token on input sequences.
    p[1] is only inspected when p[0] passes the check.
    """
    return all(len(p[idx].split(' ')) < MAX_LENGTH for idx in (0, 1))

# Filter pairs using filterPair condition
def filterPairs(pairs):
    """Return a new list of the pairs accepted by filterPair(), in their original order."""
    return list(filter(filterPair, pairs))

# Using the functions defined above, return a populated voc object and pairs list
def loadPrepareData(datafile):
    """Load, length-filter and index the pairs stored in `datafile`.

    Returns:
        (voc, pairs): a Voc holding every word of the pattern and response of
        each kept pair, and the list of pairs that passed filterPairs().
    """
    print("Start preparing training data ...")
    voc, raw_pairs = readVocs(datafile)
    print("Read {!s} sentence pairs".format(len(raw_pairs)))

    kept_pairs = filterPairs(raw_pairs)
    print("Trimmed to {!s} sentence pairs".format(len(kept_pairs)))

    print("Counting words...")
    for pair in kept_pairs:
        # Pattern (index 0) then response (index 1) feed the vocabulary
        for sentence in pair[:2]:
            voc.addSentence(sentence)
    print("Counted words:", voc.num_words)
    return voc, kept_pairs


# Load/Assemble voc and pairs: build the vocabulary and the length-filtered
# pattern/response pairs from the formatted file written earlier.
voc, pairs = loadPrepareData(datafile)
# Print the first ten pairs as a sanity check of the preprocessing
print("\npairs:")
for pair in pairs[:10]:
    print(pair)

Start preparing training data ...
Reading lines...
Read 700 sentence pairs
Trimmed to 666 sentence pairs
Counting words...
Counted words: 770

pairs:
['hi', 'hello there . tell me how are you feeling today ?']
['hey', 'hi there . what brings you here today ?']
['is anyone there', 'hi there . how are you feeling today ?']
['hi there', 'great to see you . how do you feel currently ?']
['hey there', 'how do you do ? do you have a good day ?']
['howdy', 'nice to meet you sir']
['hola', 'good to see you sir']
['hiya', 'hey ! long time no see . how are you ?']
['yo', 'what s the story ?']
['hey chatbot', 'it s a pleasure to meet you']
