In [1]:
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
import nltk
import string
from nltk.corpus import stopwords
import torch.nn as nn
import yaml
from tqdm import tqdm
nltk.download('stopwords')
import matplotlib.pyplot as plt
import numpy as np
import re
from gensim.models import Word2Vec

  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package stopwords to
[nltk_data]     /home/jazzblazzer/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [2]:
# Load the "gigaword" corpus with `datasets.load_dataset`; printing the result
# lists the available splits together with their columns and row counts.
dataset = load_dataset("gigaword")
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['document', 'summary'],
        num_rows: 3803957
    })
    validation: Dataset({
        features: ['document', 'summary'],
        num_rows: 189651
    })
    test: Dataset({
        features: ['document', 'summary'],
        num_rows: 1951
    })
})


In [45]:
# Show the first training record to inspect the raw 'document' / 'summary' text.
print(dataset['train'][0])

{'document': "australia 's current account deficit shrunk by a record #.## billion dollars -lrb- #.## billion us -rrb- in the june quarter due to soaring commodity prices , figures released monday showed .", 'summary': 'australian current account deficit narrows sharply'}


In [3]:
# Special vocabulary symbols shared by the dataset, embedding and batching cells.
START_TOKEN = "<s>"        # prepended to every input sequence in LSTMDataset.__getitem__
END_TOKEN = "</s>"         # appended to every target sequence in LSTMDataset.__getitem__
UNKNOWN_TOKEN = "<unk>"    # fallback vocab key used by collate() for out-of-vocabulary words
PAD_TOKEN = "<pad>"        # vocab key whose index collate() uses as the padding value

In [16]:
class LSTMDataset(Dataset):
    """Language-model view over the 'document' column of a dataset split.

    Each item is a pair of token lists built from one lower-cased document
    with punctuation removed:

    * input:  ``[START_TOKEN] + tokens``
    * target: ``tokens + [END_TOKEN]``

    so the target is the input shifted left by one position.

    Args:
        dataset: Mapping-like split exposing a ``'document'`` column of strings.
        tokenizer: Callable turning a string into a list of string tokens.
        max_length: Upper bound on the length of each returned sequence,
            boundary token included. Documents with more than
            ``max_length - 1`` tokens are truncated. ``None`` disables
            truncation.
    """

    # Drops everything except word characters, whitespace and apostrophes.
    # Compiled once here instead of being re-parsed on every item access.
    _PUNCTUATION = re.compile(r'[^\w\s\']')

    def __init__(self, dataset, tokenizer, max_length):
        self.dataset = dataset['document']
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of documents in the split."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the ``(input_tokens, target_tokens)`` pair for document ``idx``."""
        cleaned = self._PUNCTUATION.sub('', self.dataset[idx].lower())
        tokens = self.tokenizer(cleaned)
        if self.max_length is not None:
            # Reserve one slot for the START/END token so that both returned
            # sequences respect max_length (previously the limit was ignored).
            tokens = tokens[:max(self.max_length - 1, 0)]
        return [START_TOKEN] + tokens, tokens + [END_TOKEN]

        
    
    
# Wrap the test split with NLTK's word tokenizer; 100 is passed as max_length.
# NOTE(review): this is the *test* split, and the Word2Vec vocabulary further
# down is built from `lstm_data` — confirm that training embeddings on test
# documents is intended rather than a quick-iteration shortcut.
lstm_data = LSTMDataset(dataset['test'],nltk.word_tokenize, 100)
# Print the first (input, target) pair as a sanity check.
print(lstm_data[0])

(['<s>', 'japan', "'s", 'nec', 'corp', 'and', 'unk', 'computer', 'corp', 'of', 'the', 'united', 'states', 'said', 'wednesday', 'they', 'had', 'agreed', 'to', 'join', 'forces', 'in', 'supercomputer', 'sales'], ['japan', "'s", 'nec', 'corp', 'and', 'unk', 'computer', 'corp', 'of', 'the', 'united', 'states', 'said', 'wednesday', 'they', 'had', 'agreed', 'to', 'join', 'forces', 'in', 'supercomputer', 'sales', '</s>'])


In [17]:
# Sentences for Word2Vec: the START-prefixed input of every item with END_TOKEN
# appended, so both boundary markers are present in each training sentence.
embed_train_data = [
    lstm_data[position][0] + [END_TOKEN]
    for position in range(len(lstm_data))
]

    

## Embeddings

In [18]:
# Fit 300-dimensional word vectors on the tokenised sentences. Words seen fewer
# than 3 times are dropped from the vocabulary (min_count=3).
model = Word2Vec(
    sentences=embed_train_data,
    vector_size=300,
    window=5,
    min_count=3,
    workers=4,
)

# Persist the trained model so later cells can reload it without retraining.
model.save("word2vec.model")

In [53]:
# Build the token -> index table (`vocab`) and the matching embedding layer
# from the Word2Vec model saved on disk.
embedding_model = Word2Vec.load("word2vec.model")

# Work on a copy: inserting the special tokens straight into
# `wv.key_to_index` would leave the loaded model's own index pointing at rows
# that its vector matrix does not contain.
vocab = dict(embedding_model.wv.key_to_index)
pretrained = torch.FloatTensor(embedding_model.wv.vectors)

# Every index stored in `vocab` must address a real row of `weights`.
# Previously PAD was given index len(vocab) without a matching row (an
# out-of-range lookup as soon as a padded batch is embedded) and UNK was
# aliased to a randomly chosen existing word. Each special token now gets its
# own appended row: zeros for padding, the mean pretrained vector for unknowns.
special_rows = (
    (PAD_TOKEN, torch.zeros(1, pretrained.size(1))),
    (UNKNOWN_TOKEN, pretrained.mean(dim=0, keepdim=True)),
)
extra_rows = []
for token, row in special_rows:
    if token not in vocab:  # keep the trained vector if the token already exists
        vocab[token] = len(vocab)
        extra_rows.append(row)

weights = torch.cat([pretrained, *extra_rows], dim=0)
assert weights.size(0) == len(vocab), "every vocab index needs an embedding row"

# padding_idx marks the PAD row so it never receives gradient updates should
# the embedding later be unfrozen.
embeddings = nn.Embedding.from_pretrained(weights, padding_idx=vocab[PAD_TOKEN])

In [54]:
from torch.nn.utils.rnn import pad_sequence
def collate(batch):
    """Turn a list of (input_tokens, target_tokens) pairs into padded id tensors.

    Tokens are mapped to indices through the module-level ``vocab``; tokens
    that are not in ``vocab`` fall back to the index of ``UNKNOWN_TOKEN``.
    Sequences are right-padded with the index of ``PAD_TOKEN`` up to the
    longest sequence in the batch.

    Args:
        batch: Sequence of samples where ``sample[0]`` is the input token list
            and ``sample[1]`` is the target token list.

    Returns:
        Tuple ``(inputs, targets)`` of batch-first padded index tensors.
    """

    def token_id(word):
        # The unknown-token entry is only looked up when actually needed.
        if word in vocab:
            return vocab[word]
        return vocab[UNKNOWN_TOKEN]

    inputs, targets = [], []
    for sample in batch:
        inputs.append(torch.tensor([token_id(word) for word in sample[0]]))
        targets.append(torch.tensor([token_id(word) for word in sample[1]]))

    padded_inputs = pad_sequence(inputs, batch_first=True, padding_value=vocab[PAD_TOKEN])
    padded_targets = pad_sequence(targets, batch_first=True, padding_value=vocab[PAD_TOKEN])
    return padded_inputs, padded_targets

In [55]:
# Shuffled mini-batches of 32 samples; `collate` converts the token lists into
# padded index tensors.
lstm_data_loader = DataLoader(
    dataset=lstm_data,
    batch_size=32,
    shuffle=True,
    collate_fn=collate,
)

In [56]:
# Smoke test: fetch a single batch and show the shapes of both tensors.
for _ in range(1):
    for inputs, targets in lstm_data_loader:
        print(inputs.shape, targets.shape, sep="\n")
        break

torch.Size([32, 50])
torch.Size([32, 50])


['SoS', 'australia', 'current', 'account', 'deficit', 'shrunk', 'record', 'billion', 'dollars', 'lrb', 'billion', 'us', 'rrb', 'june', 'quarter', 'due', 'soaring', 'commodity', 'prices', 'figures', 'released', 'monday', 'showed', 'EoS']
['SoS', 'australia', 'current', 'account', 'deficit', 'shrunk', 'record', 'billion', 'dollars', 'lrb', 'billion', 'us', 'rrb', 'june', 'quarter', 'due', 'soaring', 'commodity', 'prices', 'figures', 'released', 'monday', 'showed']
['australia', 'current', 'account', 'deficit', 'shrunk', 'record', 'billion', 'dollars', 'lrb', 'billion', 'us', 'rrb', 'june', 'quarter', 'due', 'soaring', 'commodity', 'prices', 'figures', 'released', 'monday', 'showed', 'EoS']
['SoS', 'least', 'two', 'people', 'killed', 'suspected', 'bomb', 'attack', 'passenger', 'bus', 'strifetorn', 'southern', 'philippines', 'monday', 'military', 'said', 'EoS']
['SoS', 'least', 'two', 'people', 'killed', 'suspected', 'bomb', 'attack', 'passenger', 'bus', 'strifetorn', 'southern', 'philippi

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0