In [1]:
import numpy as np
import json
import tqdm

%load_ext autoreload
%autoreload 2
from run_abstractive_summarizer import main
from models.abstractive_summarizer import AbstractiveSummarizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import argparse
import json

from torchtext.datasets import WikiText2

from models.abstractive_summarizer import AbstractiveSummarizer

In [3]:
# Experiment configuration: tunable values live here so later cells read them
# instead of hard-coding numbers.
hparams = {
    # Dataset splits; each record is read for its 'article' and 'summary' fields.
    "train_data": "data/train.json",
    "validation_data": "data/validation.json",
    "test_data": "data/test.json",
    # Transformer architecture.
    "d_model": 512,
    "nhead": 4,
    "d_hid": 512,
    "nlayers": 2,
    "dropout": 0.1,
    # Tokenization.
    "tokenizer": "wordpiece",
    # Optimisation.
    "learning_rate": 0.001,
    "num_epochs": 10,
    "grad_acc": 1,
    "batch_size": 8,
    # Reproducibility and quick iteration.
    "seed": 0,            # RNG seed for reproducible sampling
    "max_examples": 50,   # per-split cap for quick smoke runs; None = use the full split
}

In [4]:
def load_split(path):
    """Load one dataset split and extract its article and summary texts.

    Args:
        path: Path to a JSON file whose top-level value is iterated as records,
            each indexed by 'article' and 'summary' (presumably a list of
            objects — confirm against the data files).

    Returns:
        (records, articles, summaries): the parsed JSON value plus two lists
        built from it in the same order, so articles[i] pairs with summaries[i].
    """
    # Explicit UTF-8: without it open() uses the locale's preferred encoding,
    # which can mis-decode non-ASCII news text on some platforms.
    with open(path, 'r', encoding='utf-8') as f:
        records = json.load(f)
    articles = [record['article'] for record in records]
    summaries = [record['summary'] for record in records]
    return records, articles, summaries


train_data, train_articles, train_summaries = load_split(hparams['train_data'])
validation_data, val_articles, val_summaries = load_split(hparams['validation_data'])
test_data, test_articles, test_summaries = load_split(hparams['test_data'])


In [5]:
# Preview one random training article, reproducibly. An index is drawn from a
# seeded Generator rather than calling np.random.choice(train_articles): choice()
# first converts the whole list into a fixed-width numpy string array (every
# entry padded to the longest article), a large and needless copy.
preview_rng = np.random.default_rng(hparams.get('seed', 0))
train_articles[preview_rng.integers(len(train_articles))]

'Washington (CNN) -- It is a case at the intersection of science and finance, an evolving 21st century dispute that comes down to a simple question: Should the government allow patents for human genes? The Supreme Court offered little other than confusion during oral arguments on Monday on nine patents held by a Utah biotech firm. Myriad Genetics isolated two related types of biological material, BRCA-1 and BRCA-2, linked to increased hereditary risk for breast and ovarian cancer. At issue is whether "products of nature" can be treated the same as "human-made" inventions, and held as the exclusive intellectual property of individuals and companies. A ruling is expected by late June. On one side, scientists and companies argue patents encourage medical innovation and investment that saves lives. On the other, patient rights groups and civil libertarians counter the patent holders are "holding hostage" the diagnostic care and access of information available to high-risk patients. How hum

In [6]:
# Preview one random validation article, reproducibly. An index is drawn from a
# seeded Generator rather than calling np.random.choice(val_articles), which
# would first copy the whole list into a fixed-width numpy string array.
preview_rng = np.random.default_rng(hparams.get('seed', 0))
val_articles[preview_rng.integers(len(val_articles))]

"Cheltenham have appointed Gary Johnson as manager until the end of the season. Johnson, who left his post in charge of Yeovil last month, has been tasked with saving the club, who are bottom of Sky Bet League Two, from relegation from the Football League. Gary Johnson has been appointed Cheltenham manager until the end of the season . The 59-year-old told the club's official website: 'I understand the position the club are in and I will be doing my best along with (assistant) Russell (Milton) to keep the club in the Football League.' Johnson was at Cheltenham's 2-2 draw with Portsmouth at Fratton Park on March 17 and the 2-1 home defeat by Exeter City four days later. 'I feel I know a lot about the squad, but it's also important to have that continuity with Russell and the lads so we'll be having a good chat regarding all of the players' abilities,' Johnson said. Cheltenham Town are bottom of Sky Bet League Two with seven games remaining . Milton, who has taken charge of nine matches 

In [7]:
import torch

# Ask CUDA once, report the answer, then pick the matching device.
cuda_available = torch.cuda.is_available()
print(cuda_available)

device = torch.device("cuda") if cuda_available else torch.device("cpu")
print(device)

True
cuda


In [8]:
# Instantiate the summarizer from the architecture settings in `hparams`
# (passed positionally, in this order), then move it to `device`.
model = AbstractiveSummarizer(
    *(hparams[key] for key in ('d_model', 'nhead', 'd_hid', 'nlayers', 'dropout', 'tokenizer'))
).to(device)


2023-11-27 17:25:03,221 - INFO - Initializing tokenization vocabulary with method: wordpiece.


2023-11-27 17:25:03,487 - INFO - WordPiece tokenizer initialized with custom special tokens.


In [9]:
from torch.utils.data import DataLoader

In [10]:
# Inspect the first training article; it pairs with train_summaries[0] because
# both lists are built from the same records in the same order.
train_articles[0]

"Playing computer games such as Angry Birds teaches children important life skills including concentration, resilience and problem solving, an academic has said. Professor Angela Mcfarlane, an education expert who will become head of training body the College of Teachers next month, said many games were complex and required deep learning and lateral thinking to solve them. Prof Mcfarlane said she herself had become 'hooked' on the Lemmings computer game, as well as Angry Birds, and said such games could have a place in the classroom provided they were used under supervision. Professor Angela Mcfarlane says computer games like Angry Birds can teach children valuable life-skills . Expert: Prof Mcfarlane says the games can help children learn problem solving, resilience and concentration . She said: 'There are many computer games that require quite deep learning to master the games. 'Some of that learning applies beyond games to wider life, such as concentration, problem solving, and resi

In [11]:
# Reference summary paired with train_articles[0].
train_summaries[0]

"Professor Angela Mcfarlane is an education expert and former teacher .\nShe says complex computer games require concentration and resilience .\nThe former government adviser says they also teach problem-solving .\nProf Mcfarlane says she herself has been 'hooked' on Lemmings game .\nThe academic is to become head of the College of Teachers next month .\nShe is currently writing a book on education for the 'digital generation'"

In [12]:
# Tokenize and batch the training pairs via model.preprocess. Only the first
# `max_examples` pairs are used for a quick smoke run; set
# hparams['max_examples'] = None to use the whole split (slicing with None keeps all).
max_examples = hparams.get('max_examples', 50)
train_dataloader = model.preprocess(
    train_articles[:max_examples], train_summaries[:max_examples], hparams['batch_size']
)

Tokenizing and converting text to indices using WordPiece in batches.


In [13]:
# lengths = []
# for batch_idx, batch in tqdm.tqdm(enumerate(train_dataloader)):
#     input_ids, attention_mask, labels, _ = batch
#     for input_id in input_ids:
#         lengths.append(len(input_id))

In [14]:
# Tokenize and batch the validation pairs via model.preprocess, capped at
# `max_examples` pairs (None = whole split) to match the training subset size.
max_examples = hparams.get('max_examples', 50)
val_dataloader = model.preprocess(
    val_articles[:max_examples], val_summaries[:max_examples], hparams['batch_size']
)

Tokenizing and converting text to indices using WordPiece in batches.


In [165]:
# Tokenize and batch the test pairs via model.preprocess, capped at
# `max_examples` pairs (None = whole split) like the other splits.
max_examples = hparams.get('max_examples', 50)
test_dataloader = model.preprocess(
    test_articles[:max_examples], test_summaries[:max_examples], hparams['batch_size']
)

Tokenizing and converting text to indices using WordPiece in batches.


In [17]:
# Fit the model; the optimisation settings are forwarded as keyword arguments
# whose names match their keys in `hparams`.
# NOTE(review): the saved run of this cell failed with
# "RuntimeError: the batch number of src and tgt must be equal", and the debug
# print right before it appears to show a 1-D token tensor — presumably a single
# sequence without a batch dimension reaches the transformer inside model.train.
# TODO confirm in AbstractiveSummarizer.train.
model.train(
    train_dataloader,
    val_dataloader,
    **{key: hparams[key] for key in ('learning_rate', 'grad_acc', 'num_epochs')},
)

2023-11-27 17:28:59,405 - INFO - Beginning training.
Training epoch 0: 0it [00:00, ?it/s]

s=tensor([  101,  2652,  3274,  2399,  2107,  2004,  4854,  5055, 12011,  2336,
         2590,  2166,  4813,  2164,  6693,  1010, 24501, 18622, 10127,  1998,
         3291, 13729,  1010,  2019,  3834,  2038,  2056,  1012,  2934, 10413,
        11338, 23511,  1010,  2019,  2495,  6739,  2040,  2097,  2468,  2132,
         1997,  2731,  2303,  1996,  2267,  1997,  5089,  2279,  3204,  1010,
         2056,  2116,  2399,  2020,  3375,  1998,  3223,  2784,  4083,  1998,
        11457,  3241,  2000,  9611,  2068,  1012, 11268, 11338, 23511,  2056,
         2016,  2841,  2018,  2468,  1005, 13322,  1005,  2006,  1996,  3393,
        25057,  2015,  3274,  2208,  1010,  2004,  2092,  2004,  4854,  5055,
         1010,  1998,  2056,  2107,  2399,  2071,  2031,  1037,  2173,  1999,
         1996,  9823,  3024,  2027,  2020,  2109,  2104, 10429,  1012,  2934,
        10413, 11338, 23511,  2758,  3274,  2399,  2066,  4854,  5055,  2064,
         6570,  2336,  7070,  2166,  1011,  4813,  1012,  6739

Training epoch 0: 0it [00:00, ?it/s]

s=tensor([  101,  2111,  4147,  5499,  1011,  2806,  5929,  1998,  2312, 10154,
         2015,  2024,  7917,  2013,  5559,  2270,  7793,  2076,  1996, 25904,
         2998,  2399,  1999,  7855,  2859,  1010,  1037,  4750,  2283,  1011,
         2448,  3780,  4311,  1012,  1996,  7221,  2253,  2046,  3466,  6928,
         1999,  1996,  2103,  1997, 13173, 27871,  1010,  1999,  1996,  2406,
         1005,  1055,  2717,  3512,  7855, 25904,  2555,  1012,  2009, 12033,
         2000,  5467,  2040,  4929, 15562,  2015,  1010,  2312, 10154,  2015,
         1010,  2004,  2092,  2004,  2093,  4127,  1997,  5499, 14464,  1011,
         1011,  2164,  2216,  2007,  1996,  2732,  1998, 13152,  6454,  1011,
         1011,  2429,  2000,  1996, 13173, 27871,  3679,  1010,  1037,  3780,
         6989,  2000,  1996,  2822,  4750,  2283,  1012,  9877,  1997,  2270,
         3902,  3703,  1999,  1996,  2103,  2097,  2022, 15371,  2011,  3036,
         5073,  2000,  6204, 14148,  1012,  1996,  3902,  7221




RuntimeError: the batch number of src and tgt must be equal