<a href="https://colab.research.google.com/github/vishal-burman/PyTorch-Architectures/blob/master/misc/TrueCase_GPT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## **This is an in-progress draft!**

In [None]:
! pip install -Uq transformers sentence-transformers datasets

In [1]:
import math
import random
import re
from typing import Iterable, List, Optional, Tuple

import nltk
import torch
from datasets import load_dataset
from nltk.tokenize import sent_tokenize, word_tokenize
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import DataCollatorForLanguageModeling
# Seed every RNG this notebook draws from: Python's `random` (sentence sampling)
# and torch (DataLoader shuffling uses torch's default generator).
SEED = 1234
random.seed(SEED)
torch.manual_seed(SEED)
# Tokenizer models for nltk's sent_tokenize/word_tokenize. Recent nltk releases
# look up "punkt_tab" instead of "punkt"; fetching both keeps either version working.
nltk.download("punkt", quiet=True)
nltk.download("punkt_tab", quiet=True)

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [2]:
# Alternative corpus: load_dataset("wikitext", "wikitext-103-v1", split="train", streaming=True)
# No split is requested, so the result is indexed by split name (e.g. dataset["train"]).
dataset = load_dataset(path="cnn_dailymail", name="3.0.0", streaming=True)

In [3]:
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="gpt2")
# GPT-2 ships without a padding token, so batches are padded with EOS instead.
# Side effect: pad_token_id == eos_token_id, so padding cannot be told apart from
# a real EOS by token id alone -- use the attention mask for that.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Pretrained GPT-2 with a causal language-modeling head.
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path="gpt2")

In [4]:
def num_word_tokens(text: str) -> int:
  """Return how many tokens `nltk.word_tokenize` produces for `text`."""
  return len(word_tokenize(text))

def num_sent_tokens(text: str) -> int:
  """Return how many sentences `nltk.sent_tokenize` splits `text` into."""
  return len(sent_tokenize(text))

# Sanity-check both counters on tiny inputs: 3 words, then 2 sentences.
assert (num_word_tokens("This is great"),
        num_sent_tokens("This is great. This is sentence 2.")) == (3, 2)

In [5]:
def _article_sentences(text: str) -> List[str]:
  """Return the usable sentences of one article (empty list if it is too short).

  An article qualifies when it has more than 4 sentences; its first 3 sentences
  are skipped and, of the rest, only those with more than 5 word tokens are kept.
  """
  # Sentence-split once. Any article with more than 4 sentences necessarily has
  # word tokens, so no separate word_tokenize emptiness check is needed.
  article_sentences = sent_tokenize(text)
  if len(article_sentences) <= 4:
    return []
  return [sent for sent in article_sentences[3:] if num_word_tokens(sent) > 5]


def create_dataset(limit: int,
                   articles: Optional[Iterable[dict]] = None,
                   train_fraction: float = 0.8) -> Tuple[List[str], List[str]]:
  """Sample up to `limit` article sentences and split them into train/validation.

  Articles are read in order until at least `limit` candidate sentences (see
  `_article_sentences`) have been collected; `limit` of them are then drawn with
  the module-level `random` generator and split by `train_fraction`.

  Args:
    limit: number of sentences to return across both splits. Fewer are returned
      if `articles` runs out first.
    articles: iterable of records with an "article" text field. Defaults to
      `dataset["train"]`, the split loaded earlier in this notebook.
    train_fraction: share of the sampled sentences put in the train split;
      the split index is rounded up.

  Returns:
    A `(train_sentences, validation_sentences)` tuple of lists of strings.

  Raises:
    ValueError: if `limit` is negative.
  """
  if limit < 0:
    raise ValueError(f"limit must be non-negative, got {limit}")
  if articles is None:
    articles = dataset["train"]

  sentences: List[str] = []
  if limit > 0:
    # A streamed split does not expose its article count up front (`dataset_size`
    # is a size in bytes, not a number of articles), so this bar only counts.
    with tqdm(desc="Articles", unit="article") as articles_bar, \
         tqdm(total=limit, desc="Sentences") as sentences_bar:
      for record in articles:
        articles_bar.update(1)
        extracted = _article_sentences(record["article"])
        sentences.extend(extracted)
        sentences_bar.update(len(extracted))
        # Check right after collecting so no extra article is pulled from the stream.
        if len(sentences) >= limit:
          break

  # random.sample raises when k exceeds the population; cap k at what was collected.
  sentences = random.sample(sentences, min(limit, len(sentences)))
  split = math.ceil(train_fraction * len(sentences))
  return sentences[:split], sentences[split:]

In [6]:
# Total sentences sampled before create_dataset's train/validation split.
NUM_SENTENCES = 1000
train_sentences, validation_sentences = create_dataset(limit=NUM_SENTENCES)
# Both splits feed DataLoaders further down; fail fast if either came back empty.
assert train_sentences and validation_sentences, "empty train or validation split"
print(f"Total train sentences: {len(train_sentences)}")
print(f"Total validation sentences: {len(validation_sentences)}")

Articles:   0%|          | 0/1369362325 [00:00<?, ?it/s]

Sentences:   0%|          | 0/1000 [00:00<?, ?it/s]

Total train sentences: 800
Total validation sentences: 200


In [7]:
# Spot-check 10 randomly chosen training sentences
random.sample(population=train_sentences, k=10)

['Carlos Alberto started his career with Fluminense, and helped them to lift the Campeonato Carioca in 2002.',
 'But in 1997, a jury found him liable for the deaths in a civil case brought by the Goldman family.',
 "Hakan, a 27-year-old security guard in Istanbul with two young children who also requested only his first name be published, told Reuters he received five or six offers from Turkey and Germany, offering 10,000-15,000 lira ($11,600), but he's holding out for 40,000 lira.",
 '"That\'s a lie.',
 'Oakley, described as a low-level employee, apparently did not make contact with any foreign government and is not a foreign agent of any kind, an official familiar with the case said.',
 "He elaborated in Arabic: His homeland won't be enjoying such freedoms anytime soon.",
 'The judge in the case will have the final say over the plea agreement.',
 '"It sometimes happens -- not often -- that the state will follow a federal prosecution by charging its own crimes for exactly the same beh

In [8]:
# Spot-check 10 randomly chosen validation sentences
random.sample(population=validation_sentences, k=10)

["His and President Sarkozy's concern for human rights lies behind their eagerness to join Gordon Brown's Britain in a new push for action in Darfur.",
 '"The president wants to encourage everybody to use surveillance," Snow said.',
 'She will always be remembered for her amazing public work.',
 'They were snatched off the street by gunmen.',
 'As we have heard from Prince Harry, his mother Diana did all that she could to prepare her sons for the work which lies ahead.',
 "Surgeries will be performed by Dr. Peter Grossman, a plastic surgeon with the affiliated Grossman Burn Center who is donating his services for Youssif's cause.",
 'We also get a rare glimpse inside the David Beckham academy in L.A, find out what drives the kids and who are their heroes.',
 '"It was fairly mild."',
 'Romney is leading polls in the early voting states of Iowa and New Hampshire.',
 "He said doctors usually help in finding people willing to sell their organs from their patients' lists."]

In [46]:
# Hyperparameters:
#   MAX_LENGTH -- token cap passed to the tokenizer as `max_length`
#   BATCH_SIZE -- examples per DataLoader batch
MAX_LENGTH, BATCH_SIZE = 128, 16

In [40]:
class CustomDataset(Dataset):
  """Sentence pairs for causal-LM training.

  Each original sentence is the prompt; its lowercased form followed by the
  tokenizer's EOS token is the target. `collate_fn` packs every pair into ONE
  token sequence ``<sentence><eos><lowercased sentence> <eos>`` so that
  `input_ids`, `attention_mask` and `labels` all share the same shape, which a
  causal-LM head needs when it scores `labels` against the logits of `input_ids`.

  NOTE(review): truecasing usually means predicting cased text from lowercased
  input; here the target is the lowercased text. Confirm the intended direction.
  """

  def __init__(self,
               tokenizer,
               sentences: List[str],
               max_len: int):
    """
    Args:
      tokenizer: Hugging Face-style tokenizer; must define `eos_token` and be
        able to pad (e.g. GPT-2 with `pad_token` set).
      sentences: cased source sentences.
      max_len: token cap for each packed prompt+target sequence.
    """
    self.tokenizer = tokenizer
    self.input_sentences = sentences
    # Target = lowercased sentence + " <eos>" so the model also learns where to stop.
    self.target_sentences = [f"{sent.lower()} {tokenizer.eos_token}" for sent in sentences]
    assert len(self.input_sentences) == len(self.target_sentences)
    self.max_len = max_len

  def __len__(self):
    return len(self.input_sentences)

  def __getitem__(self, idx):
    """Return the raw (untokenized) prompt/target strings for example `idx`."""
    return {
        "input_sentences": self.input_sentences[idx],
        "target_sentences": self.target_sentences[idx],
    }

  def collate_fn(self, batch):
    """Tokenize a list of `__getitem__` dicts into a padded causal-LM batch.

    Returns the tokenizer's encoding (`input_ids`, `attention_mask`) plus
    `labels`: a copy of `input_ids` in which prompt tokens and padding are set
    to -100 (ignored by the loss), so only the target and its EOS are supervised.
    If truncation to `max_len` cuts off the whole target, that row has no
    supervised tokens.
    """
    eos = self.tokenizer.eos_token
    # The EOS after the prompt marks where the target begins. Tokenizers split
    # registered special tokens out of the text before sub-word merging, so the
    # prompt tokenizes the same alone as inside the packed sequence -- which is
    # what lets its standalone length be used to mask it below.
    prompts = [example["input_sentences"] + eos for example in batch]
    packed = [prompt + example["target_sentences"]
              for prompt, example in zip(prompts, batch)]

    encoded = self.tokenizer(packed,
                             max_length=self.max_len,
                             padding=True,
                             truncation=True,
                             return_tensors="pt")
    prompt_lengths = [len(ids) for ids in self.tokenizer(prompts)["input_ids"]]

    attention_mask = encoded["attention_mask"]
    labels = encoded["input_ids"].clone()
    # Mask padding through the attention mask, NOT by comparing against
    # pad_token_id: when pad_token is the EOS token (as set for GPT-2 in this
    # notebook) an id comparison would also hide the EOS ending every target.
    labels[attention_mask == 0] = -100
    for row, prompt_len in enumerate(prompt_lengths):
      # Count leading padding so left-padding tokenizers are handled too.
      start = int((attention_mask[row].cumsum(0) == 0).sum())
      labels[row, start:start + prompt_len] = -100
    encoded["labels"] = labels

    return encoded

In [47]:
train_dataset = CustomDataset(tokenizer, train_sentences, max_len=MAX_LENGTH)
validation_dataset = CustomDataset(tokenizer, validation_sentences, max_len=MAX_LENGTH)

# Each loader batches with its own dataset's collate_fn; only training data is shuffled.
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE,
                          shuffle=True, collate_fn=train_dataset.collate_fn)
validation_loader = DataLoader(dataset=validation_dataset, batch_size=BATCH_SIZE,
                               shuffle=False, collate_fn=validation_dataset.collate_fn)

print(f"Length of train loader: {len(train_loader)}")
print(f"Length of validation loader: {len(validation_loader)}")

Length of train loader: 50
Length of validation loader: 13
