In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
import math
import os
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from torch.utils.data import Dataset, DataLoader, random_split
import pandas as pd

In [3]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

device(type='cuda')

In [4]:
alpaca_dataset = load_dataset('tatsu-lab/alpaca', split='train')

In [5]:
print(alpaca_dataset)

Dataset({
    features: ['instruction', 'input', 'output', 'text'],
    num_rows: 52002
})


In [6]:
data = alpaca_dataset.to_pandas()

In [7]:
PAD_token = 0
UNK_token = 1
SOS_token = 2
EOS_token = 3

class Vocab:
    def __init__(self, name) -> None:
        self.name = name
        self.trimmed = False
        self.word2index = {}
        self.word2count = {}
        self.index2word = {
            PAD_token: "<pad>",
            UNK_token: "<unk>",
            SOS_token: "<sos>",
            EOS_token: "<eos>"
        }
        self.num_words = 4
    
    def add_sentence(self, sentence):
        for word in sentence.split(' '):
            self.add_word(word)

    def add_word(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words += 1
        else:
            self.word2count[word] += 1
    
    def trim(self, min_count):
        if self.trimmed:
            return
        
        self.trimmed = True
        keep_words = []

        for k, v in self.word2count.items():
            if v >= min_count:
                keep_words.append(k)
        
        # reinitialize dictionaries
        self.word2index = {}
        self.word2count = {}
        self.index2word = {
            PAD_token: "<pad>",
            UNK_token: "<unk>",
            SOS_token: "<sos>",
            EOS_token: "<eos>"
        }
        self.num_words = 4

        for word in keep_words:
            self.add_word(word)

In [8]:
import unicodedata
import re

def unicode_to_ascii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

def normalize_str(s):
    s = unicode_to_ascii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    s = re.sub(r"\s+", r" ", s).strip()
    return s

In [9]:
def read_vocabs(datatable: pd.DataFrame, corpus_name="alpaca"):
    output_col = datatable['output'].to_list()
    instruction_col = datatable['instruction'].to_list()
    pairs = [[normalize_str(x), normalize_str(y)] for x, y in zip(instruction_col, output_col)]
    voc = Vocab(corpus_name)
    return voc, pairs

In [10]:
#output_col = data['output'].to_list()
#instruction_col = data['instruction'].to_list()
#a = [[normalize_str(x), normalize_str(y)] for x, y in zip(instruction_col, output_col)]

In [11]:
__MAX_LEN_KV_PAIR = 50

def is_under_maxlen(p):
    return len(p[0].split(' ')) < __MAX_LEN_KV_PAIR and len(p[1].split(' ')) < __MAX_LEN_KV_PAIR

def filter_pairs(pairs):
    return [pair for pair in pairs if is_under_maxlen(pair)]

In [12]:
def load_prepare_data(datatable: pd.DataFrame, save_dir):
    print("Start prepraing data...")
    voc, pairs = read_vocabs(datatable, "alpaca")
    print("Read {!s} sentence pairs".format(len(pairs)))
    pairs = filter_pairs(pairs)
    print("Trimmed to {!s} sentence pairs".format(len(pairs)))
    print("Counting words...")
    for pair in pairs:
        voc.add_sentence(pair[0])
        voc.add_sentence(pair[1])
    print("Counted words:", voc.num_words)

    with open('word_map.json', 'w') as p:
        json.dump(voc.word2index, p)

    return voc, pairs

In [13]:
voc, pairs = load_prepare_data(data, "")
for pair in pairs[:10]:
    print(pair)

Start prepraing data...


Read 52002 sentence pairs
Trimmed to 45912 sentence pairs
Counting words...
Counted words: 36341
['give three tips for staying healthy .', '.eat a balanced diet and make sure to include plenty of fruits and vegetables . . exercise regularly to keep your body active and strong . . get enough sleep and maintain a consistent sleep schedule .']
['what are the three primary colors ?', 'the three primary colors are red blue and yellow .']
['describe the structure of an atom .', 'an atom is made up of a nucleus which contains protons and neutrons surrounded by electrons that travel in orbits around the nucleus . the protons and neutrons have a positive charge while the electrons have a negative charge resulting in an overall neutral atom . the number of each particle determines the atomic number and the type of atom .']
['how can we reduce air pollution ?', 'there are a number of ways to reduce air pollution such as shifting to renewable energy sources encouraging the use of public transporta

In [14]:
__MIN_COUNT_WORD = 3

def trim_rare_words(voc, pairs, min_count = __MIN_COUNT_WORD):
    # trim words used under the minimum count
    voc.trim(min_count)
    # filter out pairs with trimmed words
    keep_pairs = []
    for pair in pairs:
        input_sentence = pair[0]
        output_sentence = pair[1]
        keep_input = True
        keep_output = True

        for word in input_sentence.split(' '):
            if word not in voc.word2index:
                keep_input = False
                break
        
        for word in output_sentence.split(' '):
            if word not in voc.word2index:
                keep_output = False
                break
        
        if keep_input and keep_output:
            keep_pairs.append(pair)
    
    print("Trimmed from {} pairs to {}, {:.4f} of total".format(len(pairs), len(keep_pairs), len(keep_pairs) / len(pairs)))
    return keep_pairs

In [15]:
pairs = trim_rare_words(voc, pairs, 3)

Trimmed from 45912 pairs to 33897, 0.7383 of total


In [16]:
import itertools

def indexes_from_sentence(voc, sentence):
    return [voc.word2index[word] for word in sentence.split(' ')] + [EOS_token]

def zero_padding(l, fill_value=PAD_token):
    return list(itertools.zip_longest(*l, fillvalue=fill_value))

def input_encode(l, voc):
    indexes_batch = [indexes_from_sentence(voc, sentence) for sentence in l]
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    pad_list = zero_padding(indexes_batch)
    pad_var = torch.LongTensor(pad_list)
    return pad_var, lengths

def output_encode(l, voc):
    indexes_batch = [indexes_from_sentence(voc, sentence) for sentence in l]
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    pad_list = zero_padding(indexes_batch)
    pad_var = torch.LongTensor(pad_list)
    return pad_var, lengths

In [17]:
[indexes_from_sentence(voc, word) for word in pairs[0][0]] + [PAD_token] * (200 - len(pairs[0][0]))

[[5553, 3],
 [4130, 3],
 [11, 3],
 [582, 3],
 [879, 879, 3],
 [11, 3],
 [489, 3],
 [1218, 3],
 [879, 879, 3],
 [582, 3],
 [4130, 3],
 [1218, 3],
 [879, 879, 3],
 [582, 3],
 [4130, 3],
 [489, 3],
 [1218, 3],
 [1218, 3],
 [879, 879, 3],
 [3646, 3],
 [489, 3],
 [350, 3],
 [1008, 3],
 [11, 3],
 [489, 3],
 [652, 3],
 [879, 879, 3],
 [457, 3],
 [4212, 3],
 [4514, 3],
 [4212, 3],
 [489, 3],
 [419, 3],
 [879, 879, 3],
 [41, 3],
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,


In [18]:
def indexes_from_sentence(voc, words):
    word_map = voc.word2index
    enc_c = [word_map.get(word, UNK_token) for word in words]
    return enc_c

max_len = 50

def encode_question(voc, sentence):
    words = sentence.split(' ')
    enc_c = indexes_from_sentence(voc, words) + [PAD_token] * (max_len - len(words))
    return enc_c

def encode_answer(voc, sentence):
    words = sentence.split(' ')
    enc_c = [SOS_token] + indexes_from_sentence(voc, words) + [EOS_token] + [PAD_token] * (max_len - len(words))
    return enc_c

In [19]:
pairs_encoded = []
for pair in pairs:
    qus = encode_question(voc, pair[0])
    ans = encode_answer(voc, pair[1])
    pairs_encoded.append([qus, ans])

with open('pairs_encoded.json', 'w') as p:
    json.dump(pairs_encoded, p)

In [20]:
import json

a = json.load(open('pairs_encoded.json'))

In [21]:
question = torch.LongTensor(a[1][0])
question

tensor([45, 38, 46, 20, 47, 48, 10,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0])

In [22]:
class AlpacaDataset(Dataset):
    def __init__(self) -> None:
        self.pairs = json.load(open('pairs_encoding.json'))
        self.dataset_size = len(self.pairs)
    
    def __getitem__(self, index: int):
        question = torch.LongTensor(self.pairs[index][0])
        answer = torch.LongTensor(self.pairs[index][1])

        return question, answer
    
    def __len__(self):
        return self.dataset_size

In [23]:
train_loader = torch.utils.data.DataLoader(AlpacaDataset(),
                                           batch_size = 100, 
                                           shuffle=True, 
                                           pin_memory=True)


FileNotFoundError: [Errno 2] No such file or directory: 'pairs_encoding.json'

In [None]:
import random

pair_batch = [random.choice(pairs) for _ in range(5)]
pair_batch.sort(key=lambda x: len(x[0].split(" ")), reverse=True)
pair_batch

[['create open ended questions about a given topic .',
  '. what are the advantages and disadvantages of d printing ? . how is d printing being used in various industries ? . what are the main components of a d printing machine ? . how does the technology of d printing work ? . what are the current trends in d printing ?'],
 ['describe a real world application of reinforcement learning .',
  'reinforcement learning has been used in a variety of real world applications such as robotics healthcare and finance . for instance robotics applications use reinforcement learning for end to end control where an agent i .e . a robot learns to maximize a reward by taking an action in an environment . in healthcare reinforcement learning can be used to develop treatments and optimize health systems . in finance reinforcement learning has been used to develop strategies for stock trading and portfolio management .'],
 ['explain how technology affects communication between people .',
  'technology ha

In [None]:
input_batch, output_batch = [], []
for pair in pair_batch:
    input_batch.append(pair[0])
    output_batch.append(pair[1])

In [None]:
input_batch

['create open ended questions about a given topic .',
 'describe a real world application of reinforcement learning .',
 'explain how technology affects communication between people .',
 'describe the structure of an atom .',
 'do the following equation']

In [None]:
inp, lengths = input_encode(input_batch, voc)

In [None]:
inp

tensor([[ 554,  201, 1777,  201,  201],
        [ 640, 1777,  908, 1777, 5029],
        [1777,  143, 4436,  143,    1],
        [  11,  554, 5451,  554,  846],
        [ 846,  640,   11,  640, 4922],
        [1777,  127,  127,  127, 1777],
        [   1, 3601, 1154, 3601,    1],
        [5029, 1777,    1, 1777,  555],
        [4436,    1, 4922,    1, 5029],
        [1777,   11, 5029,  846, 5451],
        [1154,    1, 6577, 4922, 5451],
        [   1,  640,    1, 1777, 5029],
        [1777, 1777,  846,    1, 6577],
        [1154,   11, 1777,  143,  127],
        [ 201, 5451,  554,  846, 1154],
        [1777,    1, 4922,  640, 2054],
        [ 201, 6577, 1154, 1848,    1],
        [   1, 5029, 5029,  554, 1777],
        [4591,  640, 5451,  846, 4591],
        [1848, 5451, 5029, 1848, 1848],
        [1777,  201, 2054,  640,   11],
        [ 143,    1,  909, 1777,  846],
        [ 846,   11,    1,    1,  127],
        [ 127, 4436,   11, 5029, 5029],
        [5029, 4436,  555,  555, 1154],


In [None]:
out, _ = output_encode(output_batch, voc)

In [None]:
out.shape

torch.Size([558, 5])

In [None]:
inp.shape

torch.Size([61, 5])

In [None]:
for i, (question, reply) in enumerate(train_loader):
    if i < 10:
        print(question)

TypeError: new(): invalid data type 'str'