|
| 1 | +import numpy as np |
| 2 | +import re |
| 3 | +import itertools |
| 4 | +from collections import Counter |
| 5 | + |
| 6 | + |
def clean_str(string):
    """
    Tokenization/string cleaning for all datasets except for SST.
    Original taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py

    Strips characters outside [A-Za-z0-9(),!?'`], splits common English
    contractions into separate tokens, pads punctuation with spaces,
    collapses runs of whitespace, and lowercases the result.

    Note: the replacement templates for "(", ")" and "?" intentionally
    contain a literal backslash (the output tokens are "\\(", "\\)", "\\?"),
    matching the original upstream behavior. Raw strings are used so the
    literals are valid on Python 3.12+ (non-raw "\\(" is a deprecated
    escape sequence); the runtime output is unchanged.

    Args:
        string: raw sentence to normalize.

    Returns:
        The cleaned, lowercased, single-spaced string.
    """
    string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
    string = re.sub(r"\'s", r" 's", string)
    string = re.sub(r"\'ve", r" 've", string)
    string = re.sub(r"n\'t", r" n't", string)
    string = re.sub(r"\'re", r" 're", string)
    string = re.sub(r"\'d", r" 'd", string)
    string = re.sub(r"\'ll", r" 'll", string)
    string = re.sub(r",", " , ", string)
    string = re.sub(r"!", " ! ", string)
    string = re.sub(r"\(", r" \( ", string)
    string = re.sub(r"\)", r" \) ", string)
    string = re.sub(r"\?", r" \? ", string)
    string = re.sub(r"\s{2,}", " ", string)
    return string.strip().lower()
| 26 | + |
| 27 | + |
def load_data_and_labels(positive_data_file, negative_data_file):
    """
    Loads MR polarity data from files, splits the data into words and generates labels.

    Args:
        positive_data_file: path to a UTF-8 text file, one positive example per line.
        negative_data_file: path to a UTF-8 text file, one negative example per line.

    Returns:
        A list [x_text, y] where x_text is a list of cleaned sentences
        (all positives first, then all negatives) and y is a numpy array
        of one-hot labels aligned with x_text:
        [0, 1] = positive, [1, 0] = negative.
    """
    # Load data from files. Context managers close the handles
    # deterministically (the original left them open until GC).
    with open(positive_data_file, "r", encoding="utf-8") as f:
        positive_examples = [line.strip() for line in f]
    with open(negative_data_file, "r", encoding="utf-8") as f:
        negative_examples = [line.strip() for line in f]
    # Clean/tokenize every sentence; positives precede negatives so the
    # ordering matches the label concatenation below.
    x_text = [clean_str(sent) for sent in positive_examples + negative_examples]
    # Generate one-hot labels in the same order as x_text.
    positive_labels = [[0, 1] for _ in positive_examples]
    negative_labels = [[1, 0] for _ in negative_examples]
    y = np.concatenate([positive_labels, negative_labels], 0)
    return [x_text, y]
| 46 | + |
| 47 | + |
def batch_iter(data, batch_size, num_epochs, shuffle=True):
    """
    Generates a batch iterator for a dataset.

    Yields numpy-array slices of at most `batch_size` elements, covering
    the whole dataset once per epoch, for `num_epochs` epochs. When
    `shuffle` is True the order is re-randomized at the start of every
    epoch; otherwise the original order is kept.
    """
    data = np.array(data)
    n = len(data)
    # Ceiling division written via float truncation to match the
    # original formula exactly (one empty batch when n == 0).
    batches_per_epoch = int((n - 1) / batch_size) + 1
    for _ in range(num_epochs):
        # Pick this epoch's ordering: a fresh random permutation, or
        # the data as given.
        if shuffle:
            epoch_data = data[np.random.permutation(np.arange(n))]
        else:
            epoch_data = data
        for b in range(batches_per_epoch):
            lo = b * batch_size
            hi = min(lo + batch_size, n)
            yield epoch_data[lo:hi]
0 commit comments