The task is to explore masked language modeling with the Hugging Face Transformers library.

In [1]:
import numpy as np
import pandas as pd

from pprint import pprint
from transformers import pipeline

In [2]:
# download the BBC text classification dataset
# original dataset on Kaggle: https://www.kaggle.com/datasets/shivamkushwaha/bbc-full-text-document-classification
!wget -nc https://lazyprogrammer.me/course_files/nlp/bbc_text_cls.csv

File ‘bbc_text_cls.csv’ already there; not retrieving.



In [3]:
# load the dataset into a pandas DataFrame
df = pd.read_csv('bbc_text_cls.csv')

In [4]:
# check the dataset
df.head()

Unnamed: 0,text,labels
0,Ad sales boost Time Warner profit\n\nQuarterly...,business
1,Dollar gains on Greenspan speech\n\nThe dollar...,business
2,Yukos unit buyer faces loan claim\n\nThe owner...,business
3,High fuel prices hit BA's profits\n\nBritish A...,business
4,Pernod takeover talk lifts Domecq\n\nShares in...,business


In [5]:
# check labels
df['labels'].unique()

array(['business', 'entertainment', 'politics', 'sport', 'tech'],
      dtype=object)

In [6]:
# select business texts
business_texts = df[df['labels'] == 'business']['text']

In [7]:
# select a random text
i = np.random.choice(business_texts.shape[0])
doc = business_texts.iloc[i]
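
Because the index above is drawn without a seed, each run picks a different article. A minimal sketch of a reproducible draw follows, assuming a seeded NumPy `Generator` is acceptable here; `sample_texts` and the seed value `0` are hypothetical stand-ins, not part of the original notebook.

```python
import numpy as np

# Hypothetical stand-in for business_texts, so the sketch runs without the CSV.
sample_texts = ['first article', 'second article', 'third article']

# A Generator created with a fixed seed yields the same draw on every run.
rng = np.random.default_rng(seed=0)
i = int(rng.integers(len(sample_texts)))  # integer in [0, len(sample_texts))
doc = sample_texts[i]
print(i, doc)
```

With the real data, the same index could be passed to `business_texts.iloc[i]`.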

In [8]:
print(doc)

Virgin Blue shares plummet 20%

Shares in Australian budget airline Virgin Blue plunged 20% after it warned of a steep fall in full year profits.

Virgin Blue said profits after tax for the year to March would be between 10% to 15% lower than the previous year. "Sluggish demand reported previously for November and now December 2004 continues," said Virgin Blue chief executive Brett Godfrey. Virgin Blue, which is 25% owned by Richard Branson, has been struggling to fend off pressure from rival Jetstar. It cut its full year passenger number forecast by "approximately 2.5%". Virgin Blue reported a 22% fall in first quarter profits in August 2004 due to tough competition. In November, first half profits were down due to slack demand and rising fuel costs. Virgin Blue was launched four years ago and now has roughly one third of Australia's domestic airline market. But the national carrier, Qantas, has fought back with its own budget airline, Jetstar, which took to the skies in May 2004. Syd
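
To query a fill-mask model with a sentence from an article like this one, one word has to be replaced by the model's mask token. The helper below is a rough sketch of that step: `mask_word` is a hypothetical name, the sentence is a shortened version of the article's opening line, and `'<mask>'` is assumed to be the mask token (as in the examples that follow; a loaded pipeline also exposes it as `tokenizer.mask_token`).

```python
def mask_word(sentence, position, mask_token='<mask>'):
    """Replace the whitespace-separated word at `position` with `mask_token`."""
    words = sentence.split()
    if not 0 <= position < len(words):
        raise IndexError(f'position {position} is out of range')
    words[position] = mask_token
    return ' '.join(words)

# Shortened from the article above for illustration.
sentence = 'Shares in Australian budget airline Virgin Blue plunged 20%.'
masked = mask_word(sentence, 7)  # position 7 is the word 'plunged'
print(masked)
```

Splitting on whitespace keeps punctuation attached to its word, so masking the last word here would also drop the final period; a tokenizer-aware approach would avoid that.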

In [9]:
# create a fill-mask pipeline; no model is named, so the library default is loaded
mlm = pipeline('fill-mask')

No model was supplied, defaulted to distilroberta-base and revision ec58a5b (https://huggingface.co/distilroberta-base).
Using a pipeline without specifying a model name and revision in production is not recommended.
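
The warning suggests naming the model and revision explicitly. One way to do that is sketched below, with the model name and revision copied from the warning text; `FILL_MASK_KWARGS` and `load_fill_mask` are hypothetical names, and whether that revision still resolves may depend on the installed library version and the state of the model hub.

```python
# Values copied from the warning above; adjust them for other checkpoints.
FILL_MASK_KWARGS = {
    'task': 'fill-mask',
    'model': 'distilroberta-base',
    'revision': 'ec58a5b',
}

def load_fill_mask():
    """Build the fill-mask pipeline with an explicit model and revision."""
    # Imported inside the function so the settings can be inspected
    # without loading transformers or downloading any weights.
    from transformers import pipeline
    return pipeline(**FILL_MASK_KWARGS)
```

Calling `mlm = load_fill_mask()` would then take the place of the bare `pipeline('fill-mask')` call; the first call downloads the weights.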

In [10]:
# try random sentences with a mask
text = 'Children in the UK between the <mask> of 10 and 17 had an annual income of £775, \
said market analyst Datamonitor.'
pprint(mlm(text))

[{'score': 0.9955180287361145,
  'sequence': 'Children in the UK between the ages of 10 and 17 had an annual '
              'income of £775, said market analyst Datamonitor.',
  'token': 4864,
  'token_str': ' ages'},
 {'score': 0.004034961573779583,
  'sequence': 'Children in the UK between the age of 10 and 17 had an annual '
              'income of £775, said market analyst Datamonitor.',
  'token': 1046,
  'token_str': ' age'},
 {'score': 0.00022817989520262927,
  'sequence': 'Children in the UK between the years of 10 and 17 had an annual '
              'income of £775, said market analyst Datamonitor.',
  'token': 107,
  'token_str': ' years'},
 {'score': 0.00020057418441865593,
  'sequence': 'Children in the UK between the aged of 10 and 17 had an annual '
              'income of £775, said market analyst Datamonitor.',
  'token': 5180,
  'token_str': ' aged'},
 {'score': 3.930165348720038e-06,
  'sequence': 'Children in the UK between the months of 10 and 17 had an '
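
The pipeline returns a list of dictionaries, one per candidate token, which can be hard to scan. A small sketch of condensing it follows; `top_candidates` and the `0.01` threshold are hypothetical choices, and the two sample entries are copied from the output above with the `sequence` field left out.

```python
# Two entries copied from the output above; a real call returns one dict per candidate.
predictions = [
    {'score': 0.9955180287361145, 'token': 4864, 'token_str': ' ages'},
    {'score': 0.004034961573779583, 'token': 1046, 'token_str': ' age'},
]

def top_candidates(predictions, min_score=0.01):
    """Return (word, score) pairs whose score is at least `min_score`."""
    return [
        (p['token_str'].strip(), p['score'])
        for p in predictions
        if p['score'] >= min_score
    ]

print(top_candidates(predictions))
```

With the real pipeline, the same function could be applied directly to the result of `mlm(text)`.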
      

In [12]:
text = 'British teenage girls, <mask> to their counterparts in seven European countries, \
are the most keen to use make-up products.'
pprint(mlm(text))

[{'score': 0.8782907724380493,
  'sequence': 'British teenage girls, compared to their counterparts in seven '
              'European countries, are the most keen to use make-up products.',
  'token': 1118,
  'token_str': ' compared'},
 {'score': 0.03993658721446991,
  'sequence': 'British teenage girls, relative to their counterparts in seven '
              'European countries, are the most keen to use make-up products.',
  'token': 5407,
  'token_str': ' relative'},
 {'score': 0.03592807054519653,
  'sequence': 'British teenage girls, similar to their counterparts in seven '
              'European countries, are the most keen to use make-up products.',
  'token': 1122,
  'token_str': ' similar'},
 {'score': 0.010741285048425198,
  'sequence': 'British teenage girls, according to their counterparts in seven '
              'European countries, are the most keen to use make-up products.',
  'token': 309,
  'token_str': ' according'},
 {'score': 0.006287650670856237,
  'sequence': 'B