In [2]:
## Quietly installing transformers package to import
## the GPT2Tokenizer and TFGPT2LMHeadModel
!pip install transformers -q

[K     |████████████████████████████████| 665kB 2.8MB/s 
[K     |████████████████████████████████| 890kB 47.2MB/s 
[K     |████████████████████████████████| 3.8MB 47.0MB/s 
[K     |████████████████████████████████| 1.1MB 54.2MB/s 
[?25h  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone


## Imports

In [0]:
import tensorflow as tf

from transformers import GPT2Tokenizer, TFGPT2LMHeadModel

import os
import numpy as np
import pandas as pd

from tqdm.notebook import tqdm

## Downloading the Data

The dataset is available at https://www.kaggle.com/abhinavmoudgil95/short-jokes/data

Sign In to Kaggle and begin the download process for **shortjokes.csv** file. Then copy the link address for the download file and update the _URL varaiable.

Once done run all the cells in the notebook. Also, you can cancel the donwload process for the file. :)

In [0]:
_URL = 'https://storage.googleapis.com/kaggle-data-sets/781/1457/compressed/shortjokes.csv.zip?GoogleAccessId=web-data@kaggle-161607.iam.gserviceaccount.com&Expires=1590552129&Signature=b5VddktUlUoZ9TzTuto9MttfiH0mVKPaTwABgiaQ9PttC54CUDZEGOPnVMRaB8xYZ3VqzkmEAAUq5hDuS1mZOeEcNS5Fi5GFFKhfRMmoWX9FJT8aY%2B3mccSF3a7XlPSYqlU4KcHlulZqwow3LYKod7B1E6Eilh4GdAWLa59Re0wXBdcvd0mBiNx12hRFj9PzSyyGQMEXf%2Bgs7O3ZvVU9QO2zAvH0Kg2wLj4XRNXZsZFEx1TapdPtcyO%2BBcxpbCu7OCg%2F3tuqV81sKtmkJD4LPo532PX9q9M8p08LorXTRrVCid5fIN5QPFbkd%2FRRqOe6%2BsVyhUk65jl6nvodSrHICQ%3D%3D&response-content-disposition=attachment%3B+filename%3Dshortjokes.csv.zip'

In [5]:
path_to_zip = tf.keras.utils.get_file('shortjokes.csv.zip', origin=_URL, extract=True)

FILE_PATH = os.path.join(os.path.dirname(path_to_zip), 'shortjokes.csv')

Downloading data from https://storage.googleapis.com/kaggle-data-sets/781/1457/compressed/shortjokes.csv.zip?GoogleAccessId=web-data@kaggle-161607.iam.gserviceaccount.com&Expires=1590552129&Signature=b5VddktUlUoZ9TzTuto9MttfiH0mVKPaTwABgiaQ9PttC54CUDZEGOPnVMRaB8xYZ3VqzkmEAAUq5hDuS1mZOeEcNS5Fi5GFFKhfRMmoWX9FJT8aY%2B3mccSF3a7XlPSYqlU4KcHlulZqwow3LYKod7B1E6Eilh4GdAWLa59Re0wXBdcvd0mBiNx12hRFj9PzSyyGQMEXf%2Bgs7O3ZvVU9QO2zAvH0Kg2wLj4XRNXZsZFEx1TapdPtcyO%2BBcxpbCu7OCg%2F3tuqV81sKtmkJD4LPo532PX9q9M8p08LorXTRrVCid5fIN5QPFbkd%2FRRqOe6%2BsVyhUk65jl6nvodSrHICQ%3D%3D&response-content-disposition=attachment%3B+filename%3Dshortjokes.csv.zip


## Preparing the Dataset

### Extracting jokes list from CSV

In [0]:
 pd.options.display.max_colwidth = None

In [7]:
jokes = pd.read_csv(FILE_PATH)
jokes.head()

Unnamed: 0,ID,Joke
0,1,"[me narrating a documentary about narrators] ""I can't hear what they're saying cuz I'm talking"""
1,2,"Telling my daughter garlic is good for you. Good immune system and keeps pests away.Ticks, mosquitos, vampires... men."
2,3,I've been going through a really rough period at work this week It's my own fault for swapping my tampax for sand paper.
3,4,"If I could have dinner with anyone, dead or alive... ...I would choose alive. -B.J. Novak-"
4,5,Two guys walk into a bar. The third guy ducks.


In [8]:
jokeslist = jokes['Joke'].to_list()
jokeslist[:5]

['[me narrating a documentary about narrators] "I can\'t hear what they\'re saying cuz I\'m talking"',
 'Telling my daughter garlic is good for you. Good immune system and keeps pests away.Ticks, mosquitos, vampires... men.',
 "I've been going through a really rough period at work this week It's my own fault for swapping my tampax for sand paper.",
 'If I could have dinner with anyone, dead or alive... ...I would choose alive. -B.J. Novak-',
 'Two guys walk into a bar. The third guy ducks.']

### Creating the Tokenizer for word tokenization

In [9]:
Tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1042301.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=456318.0, style=ProgressStyle(descripti…




In [10]:
special_tokens_dict = {'pad_token': 'pad'}
num_added_toks = Tokenizer.add_special_tokens(special_tokens_dict)
print('We have added', num_added_toks, 'tokens')

We have added 0 tokens


### Create Dataset from List

In [0]:
# A utility method to create a tf.data dataset from a List of jokes
def jokeslist_to_dataset(jokeslist, tokenizer, 
                  shuffle=True, batch_size=16, MAX_LEN = 64):
  

  jokeslist = ['<|start|> ' + joke + ' <|end|>' for joke in jokeslist]

  encodings = [tokenizer.encode_plus(joke,
                                  None,
                                  add_special_tokens = True,
                                  max_length = MAX_LEN,
                                  pad_to_max_length = True) 
              for joke in jokeslist]

  ids = [x['input_ids'] for x in encodings]
  masks = [x['attention_mask'] for x in encodings]
  types = [x['token_type_ids'] for x in encodings]

  inputs = {}
  inputs['input_ids'] = ids
  inputs['attention_mask'] = masks
  inputs['token_type_ids'] = types

  ds = tf.data.Dataset.from_tensor_slices(inputs)

  if shuffle:
    ds = ds.shuffle(buffer_size=len(jokeslist))

  ds = ds.batch(batch_size)

  return ds

Note: This is a costly process since all the tokenization is done immediately. Thus it is expected to be slow. The only advantage is since everything is processed and kept in memory we are saving repated operationg while training.

In [0]:
jokes_dataset = jokeslist_to_dataset(jokeslist, Tokenizer)

## The Model

In [13]:
model = TFGPT2LMHeadModel.from_pretrained('gpt2-medium')

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=718.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1419628976.0, style=ProgressStyle(descr…




In [14]:
model.summary()

Model: "tfgp_t2lm_head_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
transformer (TFGPT2MainLayer multiple                  354823168 
Total params: 354,823,168
Trainable params: 354,823,168
Non-trainable params: 0
_________________________________________________________________


## Loss Function and Optimizer

In [0]:
loss_function = tf.losses.SparseCategoricalCrossentropy(from_logits=True)

In [0]:
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5, epsilon=1e-08, clipnorm=1.0)

## Training and Checkpointing

In [17]:
from google.colab import drive
drive.mount('/content/gdrive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/gdrive


In [0]:
CHECKPOINT_PATH = "/content/gdrive/My Drive/Weights/JokeGenGPT2"

In [0]:
checkpoint_path = CHECKPOINT_PATH

ckpt = tf.train.Checkpoint(model = model)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
  ckpt.restore(ckpt_manager.latest_checkpoint)
  print ('Latest checkpoint restored!!')

In [0]:
@tf.function
def train_step(data_dict):  
  with tf.GradientTape() as tape:
    outputs = model(data_dict)
    lm_logits = outputs[0]
    labels = data_dict['input_ids']

    shift_logits = lm_logits[..., :-1, :]
    shift_labels = labels[..., 1:]
    
    loss = loss_function(tf.reshape(shift_labels, (-1,)),
         tf.reshape(shift_logits, 
                   (-1, shift_logits.shape[-1])))

  gradients = tape.gradient(loss, model.trainable_variables)    
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  return loss

In [21]:
EPOCHS = 3

for epoch in range(EPOCHS):
  
  for batch, data in tqdm(enumerate(jokes_dataset)):
    loss = train_step(data)
    if batch % 100 == 0:
      print('Epoch : {0} Batch : {1} ---- Loss : {2}'.format(epoch+1, batch+1, loss))

  

  ckpt_save_path = ckpt_manager.save()
  print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                        ckpt_save_path))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "


Epoch : 1 Batch : 1 ---- Loss : 10.06446647644043
Epoch : 1 Batch : 101 ---- Loss : 1.3731834888458252
Epoch : 1 Batch : 201 ---- Loss : 1.5216294527053833
Epoch : 1 Batch : 301 ---- Loss : 1.1768791675567627
Epoch : 1 Batch : 401 ---- Loss : 1.2565675973892212
Epoch : 1 Batch : 501 ---- Loss : 1.232465147972107
Epoch : 1 Batch : 601 ---- Loss : 1.3808581829071045
Epoch : 1 Batch : 701 ---- Loss : 1.240048885345459
Epoch : 1 Batch : 801 ---- Loss : 1.1473602056503296
Epoch : 1 Batch : 901 ---- Loss : 1.2834093570709229
Epoch : 1 Batch : 1001 ---- Loss : 1.3706344366073608
Epoch : 1 Batch : 1101 ---- Loss : 1.3044953346252441
Epoch : 1 Batch : 1201 ---- Loss : 1.0721495151519775
Epoch : 1 Batch : 1301 ---- Loss : 1.2143783569335938
Epoch : 1 Batch : 1401 ---- Loss : 1.1996228694915771
Epoch : 1 Batch : 1501 ---- Loss : 1.3116751909255981
Epoch : 1 Batch : 1601 ---- Loss : 1.285050868988037
Epoch : 1 Batch : 1701 ---- Loss : 1.3583149909973145
Epoch : 1 Batch : 1801 ---- Loss : 1.4715481

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Epoch : 2 Batch : 1 ---- Loss : 1.1087332963943481
Epoch : 2 Batch : 101 ---- Loss : 1.10383939743042
Epoch : 2 Batch : 201 ---- Loss : 1.20725417137146
Epoch : 2 Batch : 301 ---- Loss : 1.2143092155456543
Epoch : 2 Batch : 401 ---- Loss : 0.9976906180381775
Epoch : 2 Batch : 501 ---- Loss : 0.9533005356788635
Epoch : 2 Batch : 601 ---- Loss : 1.15667724609375
Epoch : 2 Batch : 701 ---- Loss : 0.9574194550514221
Epoch : 2 Batch : 801 ---- Loss : 1.2616252899169922
Epoch : 2 Batch : 901 ---- Loss : 1.1970351934432983
Epoch : 2 Batch : 1001 ---- Loss : 1.042126178741455
Epoch : 2 Batch : 1101 ---- Loss : 1.0042039155960083
Epoch : 2 Batch : 1201 ---- Loss : 1.0429633855819702
Epoch : 2 Batch : 1301 ---- Loss : 1.1570923328399658
Epoch : 2 Batch : 1401 ---- Loss : 0.9533225297927856
Epoch : 2 Batch : 1501 ---- Loss : 1.2842003107070923
Epoch : 2 Batch : 1601 ---- Loss : 1.2865604162216187
Epoch : 2 Batch : 1701 ---- Loss : 1.321840763092041
Epoch : 2 Batch : 1801 ---- Loss : 1.15001356601

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Epoch : 3 Batch : 1 ---- Loss : 1.0839698314666748
Epoch : 3 Batch : 101 ---- Loss : 0.8334741592407227
Epoch : 3 Batch : 201 ---- Loss : 1.077292799949646
Epoch : 3 Batch : 301 ---- Loss : 1.3091721534729004
Epoch : 3 Batch : 401 ---- Loss : 1.1783109903335571
Epoch : 3 Batch : 501 ---- Loss : 1.0431628227233887
Epoch : 3 Batch : 601 ---- Loss : 1.0868970155715942
Epoch : 3 Batch : 701 ---- Loss : 1.0794204473495483
Epoch : 3 Batch : 801 ---- Loss : 0.9824198484420776
Epoch : 3 Batch : 901 ---- Loss : 1.0760180950164795
Epoch : 3 Batch : 1001 ---- Loss : 1.0809065103530884
Epoch : 3 Batch : 1101 ---- Loss : 0.9881730675697327
Epoch : 3 Batch : 1201 ---- Loss : 0.9192593693733215
Epoch : 3 Batch : 1301 ---- Loss : 0.9902061223983765
Epoch : 3 Batch : 1401 ---- Loss : 0.9855520129203796
Epoch : 3 Batch : 1501 ---- Loss : 1.1338046789169312
Epoch : 3 Batch : 1601 ---- Loss : 0.8668370246887207
Epoch : 3 Batch : 1701 ---- Loss : 1.0577040910720825
Epoch : 3 Batch : 1801 ---- Loss : 1.0791