In [93]:
import random
import json
import glob
import pickle
from collections import Counter
import re
import torch
import pandas as pd
import nltk 
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import argparse
import logging

In [94]:
MAX_SEQ_LEN = 512

In [95]:
hf_dataset = load_dataset("roneneldan/TinyStories")

In [96]:
train_data = hf_dataset['train']
val_data = hf_dataset['validation']

In [97]:
train_stories = []
val_stories = []

max_len = 0
word_count = {}

for i in tqdm(range(int(0.0005 * len(train_data))), desc='Processing train data'):
    story = train_data[i]['text']
    story = story.lower()
    story = re.sub(r'[^a-zA-Z\s]', ' ', story)
    story = re.sub(r'\s+', ' ', story)

    story = story.strip()
    story = nltk.word_tokenize(story)
    story = ['<bos>'] + story + ['<eos>']

    for word in story:
        if word in word_count:
            word_count[word] += 1
        else:
            word_count[word] = 1

    train_stories.append(story)

    if len(story) > max_len:
        max_len = len(story)

val_stories = []
for i in tqdm(range(int(0.005 * len(val_data))), desc='Processing val data'):
    story = val_data[i]['text']
    story = story.lower()
    story = re.sub(r'[^a-zA-Z\s]', ' ', story)
    story = re.sub(r'\s+', ' ', story)

    story = story.strip()
    story = nltk.word_tokenize(story)
    story = ['<bos>'] + story + ['<eos>']

    for word in story:
        if word in word_count:
            word_count[word] += 1
        else:
            word_count[word] = 1

    val_stories.append(story)

    if len(story) > max_len:
        max_len = len(story)

Processing train data: 100%|██████████| 1059/1059 [00:00<00:00, 1416.93it/s]
Processing val data: 100%|██████████| 109/109 [00:00<00:00, 1862.31it/s]


In [98]:
max_len

839

In [99]:
vocab_file = None

if vocab_file is not None:
    with open(vocab_file, 'rb') as f:
        vocab = pickle.load(f)

else:
    vocab = {}
    vocab['<bos>'] = len(vocab)
    vocab['<eos>'] = len(vocab)
    vocab['<pad>'] = len(vocab)
    vocab['<unk>'] = len(vocab)

    for data in [train_stories, val_stories]:
        for story in data:
            for word in story:
                if word_count[word] >= 5 and word not in vocab.keys():
                    vocab[word] = len(vocab)

In [100]:
print(len(vocab))

print(vocab)

2407
{'<bos>': 0, '<eos>': 1, '<pad>': 2, '<unk>': 3, 'one': 4, 'day': 5, 'a': 6, 'little': 7, 'girl': 8, 'named': 9, 'lily': 10, 'found': 11, 'needle': 12, 'in': 13, 'her': 14, 'room': 15, 'she': 16, 'knew': 17, 'it': 18, 'was': 19, 'difficult': 20, 'to': 21, 'play': 22, 'with': 23, 'because': 24, 'sharp': 25, 'wanted': 26, 'share': 27, 'the': 28, 'mom': 29, 'so': 30, 'could': 31, 'button': 32, 'on': 33, 'shirt': 34, 'went': 35, 'and': 36, 'said': 37, 'i': 38, 'this': 39, 'can': 40, 'you': 41, 'me': 42, 'my': 43, 'smiled': 44, 'yes': 45, 'we': 46, 'fix': 47, 'your': 48, 'together': 49, 'they': 50, 'shared': 51, 's': 52, 'not': 53, 'for': 54, 'them': 55, 'were': 56, 'sharing': 57, 'helping': 58, 'each': 59, 'other': 60, 'after': 61, 'finished': 62, 'thanked': 63, 'both': 64, 'felt': 65, 'happy': 66, 'had': 67, 'worked': 68, 'once': 69, 'upon': 70, 'time': 71, 'there': 72, 'car': 73, 'beep': 74, 'loved': 75, 'go': 76, 'fast': 77, 'sun': 78, 'healthy': 79, 'he': 80, 'always': 81, 'good':

In [101]:
index_train_stories = []
padded_train_stories = []
index_val_stories = []
padded_val_stories = []

for story in tqdm(train_stories, desc='Encoding train stories'):
    index_story = [vocab.get(word, vocab['<unk>']) for word in story]
    new_story = story[:MAX_SEQ_LEN]
    new_story = new_story + ['<pad>'] * (MAX_SEQ_LEN - len(new_story))
    
    if len(index_story) > MAX_SEQ_LEN:
        index_story = index_story[:MAX_SEQ_LEN]
    else:
        index_story = index_story + [vocab['<pad>']] * (MAX_SEQ_LEN - len(index_story))

    index_train_stories.append(index_story)
    padded_train_stories.append(new_story)

for story in tqdm(val_stories, desc='Encoding val stories'):
    index_story = [vocab.get(word, vocab['<unk>']) for word in story]
    new_story = story[:MAX_SEQ_LEN]
    new_story = new_story + ['<pad>'] * (MAX_SEQ_LEN - len(new_story))

    if len(index_story) > MAX_SEQ_LEN:
        index_story = index_story[:MAX_SEQ_LEN]
    else:
        index_story = index_story + [vocab['<pad>']] * (MAX_SEQ_LEN - len(index_story))

    index_val_stories.append(index_story)
    padded_val_stories.append(new_story)

Encoding train stories: 100%|██████████| 1059/1059 [00:00<00:00, 24318.47it/s]


Encoding val stories: 100%|██████████| 109/109 [00:00<00:00, 9884.53it/s]


In [102]:
print(index_train_stories[1])
print(padded_train_stories[1])

[0, 69, 70, 6, 71, 72, 19, 6, 7, 73, 9, 74, 74, 75, 21, 76, 77, 36, 22, 13, 28, 78, 74, 19, 6, 79, 73, 24, 80, 81, 67, 82, 83, 82, 83, 84, 74, 66, 36, 85, 4, 5, 74, 19, 86, 13, 28, 87, 88, 80, 89, 6, 90, 91, 28, 91, 67, 92, 93, 94, 56, 95, 74, 96, 97, 28, 93, 98, 36, 26, 21, 22, 23, 55, 74, 99, 100, 28, 91, 36, 101, 28, 93, 98, 33, 102, 80, 103, 36, 3, 104, 3, 74, 105, 23, 28, 95, 93, 106, 5, 88, 18, 19, 71, 21, 76, 107, 74, 17, 80, 108, 109, 83, 80, 35, 21, 28, 83, 110, 36, 111, 109, 79, 83, 112, 74, 19, 113, 21, 76, 77, 36, 22, 114, 28, 115, 5, 36, 74, 116, 117, 118, 61, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 

In [106]:
class Small_Transformers_Dataset(Dataset):
    def __init__(self, index_stories, vocab):
        self.index_stories = index_stories
        self.vocab = vocab

    def __len__(self):
        return len(self.index_stories)
    
    def __getitem__(self, idx):
        story = self.index_stories[idx]
        story_shifted = story[1:] + [self.vocab['<pad>']]

        return torch.tensor(story), torch.tensor(story_shifted)

In [107]:
dataset = Small_Transformers_Dataset(index_train_stories, vocab)

In [108]:
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)