In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
import math
import os
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from torch.utils.data import Dataset, DataLoader, random_split
import pandas as pd

In [3]:
# Select the CUDA device when torch reports one is available; otherwise use the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Bare expression so the notebook displays the selected device as the cell output.
DEVICE

device(type='cuda')

In [4]:
# Load the 'train' split of the 'tatsu-lab/alpaca' dataset through `datasets.load_dataset`.
alpaca_dataset = load_dataset('tatsu-lab/alpaca', split='train')

In [5]:
# Show the dataset summary (feature names and row count).
print(alpaca_dataset)

Dataset({
    features: ['instruction', 'input', 'output', 'text'],
    num_rows: 52002
})


In [6]:
# Convert the dataset to a pandas DataFrame; later cells read its 'instruction' and 'output' columns.
data = alpaca_dataset.to_pandas()

In [7]:
# Reserved ids for the special tokens; regular words are numbered from 4 upwards.
PAD_token = 0
UNK_token = 1
SOS_token = 2
EOS_token = 3

class Vocab:
    """Word-level vocabulary mapping words to integer ids and back.

    Attributes:
        name: Label given to this vocabulary at construction time.
        trimmed: True once `trim` has run; later `trim` calls are no-ops.
        word2index: Maps each regular word to its id (special tokens are
            not included here).
        word2count: Maps each regular word to how often it has been added.
        index2word: Maps every id, including the four special tokens, to
            its string.
        num_words: Total number of ids in use, special tokens included.
    """

    def __init__(self, name) -> None:
        self.name = name
        self.trimmed = False
        self._reset()

    def _reset(self):
        """Empty all tables, leaving only the four special tokens registered."""
        self.word2index = {}
        self.word2count = {}
        self.index2word = {
            PAD_token: "<pad>",
            UNK_token: "<unk>",
            SOS_token: "<sos>",
            EOS_token: "<eos>"
        }
        self.num_words = 4

    def add_sentence(self, sentence):
        """Add every space-separated token of `sentence` to the vocabulary.

        Note: `split(' ')` keeps empty strings, so an empty sentence (or
        consecutive spaces) registers '' as a word.
        """
        for word in sentence.split(' '):
            self.add_word(word)

    def add_word(self, word):
        """Register `word` with the next free id, or bump its count if known."""
        if word in self.word2index:
            self.word2count[word] += 1
            return

        self.word2index[word] = self.num_words
        self.word2count[word] = 1
        self.index2word[self.num_words] = word
        self.num_words += 1

    def trim(self, min_count):
        """Drop every word seen fewer than `min_count` times and renumber the rest.

        Surviving words keep their relative order and receive new contiguous
        ids starting at 4. Their original frequencies are preserved in
        `word2count`. Only the first call has any effect.
        """
        if self.trimmed:
            return

        self.trimmed = True
        kept_counts = {
            word: count
            for word, count in self.word2count.items()
            if count >= min_count
        }

        # Rebuild the tables from scratch so ids stay contiguous.
        self._reset()
        for word, count in kept_counts.items():
            self.add_word(word)
            # add_word starts every new word at 1; restore the real frequency
            # so word2count stays meaningful after trimming.
            self.word2count[word] = count

In [8]:
import unicodedata
import re

def unicode_to_ascii(s):
    """Return `s` with combining marks removed.

    The string is NFD-normalized so accented characters split into a base
    character plus combining marks, then every character whose Unicode
    category is 'Mn' is dropped.
    """
    decomposed = unicodedata.normalize('NFD', s)
    kept_chars = [ch for ch in decomposed if unicodedata.category(ch) != 'Mn']
    return ''.join(kept_chars)

def normalize_str(s):
    """Lower-case, strip accents and reduce `s` to letters plus `.`, `!`, `?`.

    Steps, in order: lower-case and strip, remove combining marks, put a
    space before each `.`/`!`/`?`, replace every run of characters other
    than ASCII letters and `.!?` with one space, then collapse whitespace
    and strip.
    """
    cleaned = unicode_to_ascii(s.lower().strip())
    substitutions = (
        (r"([.!?])", r" \1"),
        (r"[^a-zA-Z.!?]+", r" "),
        (r"\s+", r" "),
    )
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.strip()

In [9]:
def read_vocabs(datatable: pd.DataFrame, corpus_name="alpaca"):
    """Build normalized [instruction, output] pairs and an empty vocabulary.

    Args:
        datatable: Frame with 'instruction' and 'output' columns.
        corpus_name: Name given to the new `Vocab`.

    Returns:
        (voc, pairs): a fresh `Vocab` with no words added yet, and a list of
        two-element lists holding the normalized instruction and output.
    """
    answers = datatable['output'].to_list()
    prompts = datatable['instruction'].to_list()

    pairs = []
    for prompt, answer in zip(prompts, answers):
        pairs.append([normalize_str(prompt), normalize_str(answer)])

    return Vocab(corpus_name), pairs

In [10]:
#output_col = data['output'].to_list()
#instruction_col = data['instruction'].to_list()
#a = [[normalize_str(x), normalize_str(y)] for x, y in zip(instruction_col, output_col)]

In [11]:
# Exclusive upper bound on the number of space-separated tokens per sentence.
__MAX_LEN_KV_PAIR = 50

def is_under_maxlen(p):
    """Return True when both sentences of pair `p` have fewer than
    `__MAX_LEN_KV_PAIR` space-separated tokens."""
    return all(
        len(sentence.split(' ')) < __MAX_LEN_KV_PAIR
        for sentence in (p[0], p[1])
    )

def filter_pairs(pairs):
    """Return, in order, the pairs for which `is_under_maxlen` is true."""
    return list(filter(is_under_maxlen, pairs))

In [12]:
def load_prepare_data(datatable: pd.DataFrame, save_dir):
    """Build the vocabulary and length-filtered sentence pairs, and save the word map.

    Args:
        datatable: Frame with 'instruction' and 'output' columns.
        save_dir: Directory that receives 'word_map.json'. An empty string
            (or None) writes to the current working directory.

    Returns:
        (voc, pairs): the populated `Vocab` and the filtered list of
        [instruction, output] pairs.
    """
    print("Start prepraing data...")
    voc, pairs = read_vocabs(datatable, "alpaca")
    print("Read {!s} sentence pairs".format(len(pairs)))
    pairs = filter_pairs(pairs)
    print("Trimmed to {!s} sentence pairs".format(len(pairs)))
    print("Counting words...")
    for pair in pairs:
        voc.add_sentence(pair[0])
        voc.add_sentence(pair[1])
    print("Counted words:", voc.num_words)

    # Honour save_dir instead of always writing to the working directory.
    target_dir = save_dir or ""
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)
    word_map_path = os.path.join(target_dir, 'word_map.json')

    # NOTE(review): this snapshot is taken before any Vocab.trim call, and
    # word2index holds regular words only (no special tokens). Trimming
    # renumbers the words, so confirm which mapping consumers of this file expect.
    with open(word_map_path, 'w') as p:
        json.dump(voc.word2index, p)

    return voc, pairs

In [13]:
# Build the vocabulary and the length-filtered pairs from the Alpaca frame ("" is the save_dir argument).
voc, pairs = load_prepare_data(data, "")
# Preview the first ten normalized [instruction, output] pairs.
for pair in pairs[:10]:
    print(pair)

Start prepraing data...
Read 52002 sentence pairs
Trimmed to 31568 sentence pairs
Counting words...
Counted words: 27304
['give three tips for staying healthy .', '.eat a balanced diet and make sure to include plenty of fruits and vegetables . . exercise regularly to keep your body active and strong . . get enough sleep and maintain a consistent sleep schedule .']
['what are the three primary colors ?', 'the three primary colors are red blue and yellow .']
['identify the odd one out .', 'telegram']
['explain why the following fraction is equivalent to', 'the fraction is equivalent to because both numerators and denominators are divisible by . dividing both the top and bottom numbers by yields the fraction .']
['render a d model of a house', 'nooutput this type of instruction cannot be fulfilled by a gpt model .']
['evaluate this sentence for spelling and grammar mistakes', 'he finished his meal and left the restaurant .']
['how did julius caesar die ?', 'julius caesar was assassinated 

In [14]:
# Default minimum frequency a word needs to survive trimming.
__MIN_COUNT_WORD = 3

def trim_rare_words(voc, pairs, min_count = __MIN_COUNT_WORD):
    """Trim rare words from `voc` and drop every pair that uses a trimmed word.

    Args:
        voc: Vocabulary exposing `trim(min_count)` and `word2index`.
        pairs: List of [input_sentence, output_sentence] pairs.
        min_count: Minimum frequency passed to `voc.trim`.

    Returns:
        The pairs, in order, whose space-separated tokens are all still in
        `voc.word2index` after trimming. An empty `pairs` yields an empty list.
    """
    # trim words used under the minimum count
    voc.trim(min_count)
    known_words = voc.word2index

    def _all_known(sentence):
        # True when every space-separated token survived the trim.
        return all(word in known_words for word in sentence.split(' '))

    # filter out pairs with trimmed words
    keep_pairs = [
        pair for pair in pairs
        if _all_known(pair[0]) and _all_known(pair[1])
    ]

    # Guard the ratio so an empty input reports 0 instead of raising ZeroDivisionError.
    kept_fraction = len(keep_pairs) / len(pairs) if pairs else 0.0
    print("Trimmed from {} pairs to {}, {:.4f} of total".format(len(pairs), len(keep_pairs), kept_fraction))
    return keep_pairs

In [15]:
# Trim words seen fewer than 3 times from `voc` and keep only pairs made entirely of surviving words.
# `pairs` is rebound to the filtered list.
pairs = trim_rare_words(voc, pairs, 3)

Trimmed from 31568 pairs to 21718, 0.6880 of total


In [16]:
def indexes_from_sentence(voc, words):
    """Map each token in `words` to its id in `voc.word2index`.

    Tokens missing from the vocabulary are encoded as `UNK_token`.
    """
    lookup = voc.word2index.get
    return [lookup(token, UNK_token) for token in words]

# Maximum number of tokens kept per encoded sentence.
max_len = 50

def encode_question(voc, sentence):
    """Encode `sentence` as exactly `max_len` ids.

    The sentence is split on single spaces, truncated to `max_len` tokens,
    mapped to ids, and right-padded with `PAD_token`.
    """
    tokens = sentence.split(' ')[:max_len]
    token_ids = indexes_from_sentence(voc, tokens)
    padding = [PAD_token] * (max_len - len(tokens))
    return token_ids + padding

def encode_answer(voc, sentence):
    """Encode `sentence` as `<sos>` + ids + `<eos>` + padding, always `max_len + 1` ids long.

    The sentence is split on single spaces and truncated to `max_len - 1`
    tokens. Together with the two markers and the `PAD_token` fill, this
    gives every answer the same length.
    """
    # Truncate to max_len - 1 so the pad count below can never go negative.
    # A negative repeat yields an empty list, which previously made answers
    # with max_len or more tokens one id longer than every other answer.
    words = sentence.split(' ')[:max_len - 1]
    padding = [PAD_token] * (max_len - 1 - len(words))
    return [SOS_token] + indexes_from_sentence(voc, words) + [EOS_token] + padding

In [17]:
# Encode every remaining pair as [question_ids, answer_ids] using the current vocabulary.
# NOTE(review): `voc` has been trimmed by this point, so these ids come from the
# post-trim numbering, while 'word_map.json' was written by load_prepare_data
# before trimming. Confirm which mapping consumers of these files expect.
pairs_encoded = []
for pair in pairs:
    qus = encode_question(voc, pair[0])
    ans = encode_answer(voc, pair[1])
    pairs_encoded.append([qus, ans])

# Persist the encoded pairs as JSON in the current working directory.
with open('pairs_encoded.json', 'w') as p:
    json.dump(pairs_encoded, p)