In [1]:
from datasets import load_metric
# Module-level ROUGE metric object; calc_rouge_scores below reads this global.
# NOTE(review): `datasets.load_metric` is deprecated in newer `datasets` releases in
# favour of the separate `evaluate` package - confirm the installed version before re-running.
metric = load_metric("rouge")

def calc_rouge_scores(candidates, references):
    """Compute ROUGE F-measures for generated texts against their references.

    Args:
        candidates: generated texts, passed to the metric as `predictions`.
        references: reference texts, passed to the metric as `references`.

    Returns:
        dict mapping each key returned by the module-level `metric` to its
        mid F-measure, scaled by 100 and rounded to one decimal place.
    """
    aggregated = metric.compute(
        predictions=candidates,
        references=references,
        use_stemmer=True,
    )
    scores = {}
    for rouge_type, aggregate in aggregated.items():
        # Report the mid F-measure as a percentage with one decimal.
        scores[rouge_type] = round(aggregate.mid.fmeasure * 100, 1)
    return scores

In [2]:
import argparse
from datetime import datetime
import os
import time

import numpy as np
from transformers import GPT2LMHeadModel,AdamW
from transformers import get_linear_schedule_with_warmup
from torch.utils.tensorboard import SummaryWriter
import torch
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from tqdm import tnrange, tqdm

from dataset import GPT21024Dataset 
from utils import add_special_tokens, generate_sample, generate_beam_sample, set_seed


In [3]:
# Locations of the preprocessed data directory and of the ids file.
root_dir = './gpt2_1024_data'
ids_file = './CNN/ids.json'

# Validation split, limited to 500 samples.
valid_data = GPT21024Dataset(
    root_dir,
    ids_file,
    mode='valid',
    length=500,
)

In [4]:
# Training split, limited to 3000 samples. The 500-sample validation split
# (`valid_data`) is already built in the previous cell with identical arguments,
# so it is not constructed a second time here.
train_data = GPT21024Dataset(root_dir, ids_file, mode='train', length=3000)

# Tokenizer with the extra special tokens; its pad id is kept as the index to ignore.
tokenizer = add_special_tokens()
ignore_idx = tokenizer.pad_token_id

# Load GPT-2 medium and resize its embedding table to the tokenizer's vocabulary size.
model = GPT2LMHeadModel.from_pretrained('gpt2-medium')
model.resize_token_embeddings(len(tokenizer))

Embedding(50259, 1024)

In [5]:
from transformers import GPT2Tokenizer
def add_special_tokens():
    """Build the GPT-2 tokenizer and register the padding and separator tokens.

    Returns:
        The `GPT2Tokenizer` loaded from the 'gpt2' checkpoint after '<|pad|>' has
        been added as its pad token and '<|sep|>' as its sep token.
    """
    gpt2_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    gpt2_tokenizer.add_special_tokens({'pad_token': '<|pad|>', 'sep_token': '<|sep|>'})
    return gpt2_tokenizer

# Rebuild the tokenizer with the add_special_tokens defined in this cell, which
# shadows the function of the same name imported from `utils` earlier in the notebook.
tokenizer = add_special_tokens()

In [6]:
# Load GPT-2 medium and resize its embedding table to the tokenizer's vocabulary size.
# NOTE(review): an earlier cell already loads and resizes the same checkpoint; this
# reload looks redundant unless the tokenizer's size changed in between - confirm.
# NOTE(review): no fine-tuned weights are loaded in the cells visible here, so the
# generation cells appear to run on the stock checkpoint - confirm this is intended.
model = GPT2LMHeadModel.from_pretrained('gpt2-medium')
model.resize_token_embeddings(len(tokenizer))

# Use GPU index 4 when the machine has it; otherwise fall back to the first GPU,
# or to the CPU when CUDA is unavailable, instead of failing on a missing device.
if torch.cuda.is_available():
    device = torch.device('cuda:4' if torch.cuda.device_count() > 4 else 'cuda:0')
else:
    device = torch.device('cpu')
model.to(device)


GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50259, 1024)
    (wpe): Embedding(1024, 1024)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): GPT2Block(
        (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): GPT2Block(
        (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout)

In [13]:
# Score the generated summaries against the references.
# NOTE(review): `cands` and `refs` are only assigned by the generate_beam_sample cell
# further down in this notebook, so on a fresh kernel that cell has to be run first.
calc_rouge_scores(cands, refs)

{'rouge1': 0.6, 'rouge2': 0.0, 'rougeL': 0.6, 'rougeLsum': 0.6}

In [8]:
# Sanity check: number of samples in the validation dataset.
len(valid_data)

500

In [12]:
# Inspect the first ten generated candidates.
cands[:10]

['\ned............,,..ism.',
 ' Theically.ically.icallyically.ically.icallyicallyically.icallyicallyicallyised.......ised.ism',
 '!',
 '!',
 ' the',
 ' Clark',
 ' I',
 'I',
 '\n',
 ' The']

In [11]:
# Inspect the first ten references.
# NOTE(review): the references displayed by this cell still carry trailing '<|pad|>'
# tokens, and they are passed unchanged to calc_rouge_scores - consider stripping them.
refs[:10]

['Kaleb Kula punched across face as he stood alone in a driveway. Middle school pupils cheered and filmed attack to put on Facebook. Parents say he has been target of bullies at Elkton Middle School. Police say they have charged juvenile with second degree assault. <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|>',
 "Ahmed Fareed strangled Major Janet Gilson with a rope. She was found hidden under a sofa three days after being reported missing. Fareed was the ex-husband of Maj Gilson's niece. <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|>

In [10]:
# Run generate_beam_sample (imported from utils) over the whole validation set.
# NOTE(review): the (references, candidates) return order is taken from this
# unpacking; the helper itself is not shown here - confirm against utils.
refs, cands = generate_beam_sample(valid_data, tokenizer, model, num=len(valid_data), device=device)

  next_token_probs = F.softmax(next_token_logits)


  0%|          | 0/99 [00:00<?, ?it/s]

  next_token_probs = F.softmax(next_token_logits)


1 1


  0%|          | 0/99 [00:00<?, ?it/s]

2 2


  0%|          | 0/99 [00:00<?, ?it/s]

3 3


  0%|          | 0/99 [00:00<?, ?it/s]

4 4


  0%|          | 0/99 [00:00<?, ?it/s]

5 5


  0%|          | 0/99 [00:00<?, ?it/s]

6 6


  0%|          | 0/99 [00:00<?, ?it/s]

7 7


  0%|          | 0/99 [00:00<?, ?it/s]

8 8


  0%|          | 0/99 [00:00<?, ?it/s]

9 9


  0%|          | 0/99 [00:00<?, ?it/s]

10 10


  0%|          | 0/99 [00:00<?, ?it/s]

11 11


  0%|          | 0/99 [00:00<?, ?it/s]

12 12


  0%|          | 0/99 [00:00<?, ?it/s]

13 13


  0%|          | 0/99 [00:00<?, ?it/s]

14 14


  0%|          | 0/99 [00:00<?, ?it/s]

15 15


  0%|          | 0/99 [00:00<?, ?it/s]

16 16


  0%|          | 0/99 [00:00<?, ?it/s]

17 17


  0%|          | 0/99 [00:00<?, ?it/s]

18 18


  0%|          | 0/99 [00:00<?, ?it/s]

19 19


  0%|          | 0/99 [00:00<?, ?it/s]

20 20


  0%|          | 0/99 [00:00<?, ?it/s]

21 21


  0%|          | 0/99 [00:00<?, ?it/s]

22 22


  0%|          | 0/99 [00:00<?, ?it/s]

23 23


  0%|          | 0/99 [00:00<?, ?it/s]

24 24


  0%|          | 0/99 [00:00<?, ?it/s]

25 25


  0%|          | 0/99 [00:00<?, ?it/s]

26 26


  0%|          | 0/99 [00:00<?, ?it/s]

27 27


  0%|          | 0/99 [00:00<?, ?it/s]

28 28


  0%|          | 0/99 [00:00<?, ?it/s]

29 29


  0%|          | 0/99 [00:00<?, ?it/s]

30 30


  0%|          | 0/99 [00:00<?, ?it/s]

31 31


  0%|          | 0/99 [00:00<?, ?it/s]

32 32


  0%|          | 0/99 [00:00<?, ?it/s]

33 33


  0%|          | 0/99 [00:00<?, ?it/s]

34 34


  0%|          | 0/99 [00:00<?, ?it/s]

35 35


  0%|          | 0/99 [00:00<?, ?it/s]

36 36


  0%|          | 0/99 [00:00<?, ?it/s]

37 37


  0%|          | 0/99 [00:00<?, ?it/s]

38 38


  0%|          | 0/99 [00:00<?, ?it/s]

39 39


  0%|          | 0/99 [00:00<?, ?it/s]

40 40


  0%|          | 0/99 [00:00<?, ?it/s]

41 41


  0%|          | 0/99 [00:00<?, ?it/s]

42 42


  0%|          | 0/99 [00:00<?, ?it/s]

43 43


  0%|          | 0/99 [00:00<?, ?it/s]

44 44


  0%|          | 0/99 [00:00<?, ?it/s]

45 45


  0%|          | 0/99 [00:00<?, ?it/s]

46 46


  0%|          | 0/99 [00:00<?, ?it/s]

47 47


  0%|          | 0/99 [00:00<?, ?it/s]

48 48


  0%|          | 0/99 [00:00<?, ?it/s]

49 49


  0%|          | 0/99 [00:00<?, ?it/s]

50 50


  0%|          | 0/99 [00:00<?, ?it/s]

51 51


  0%|          | 0/99 [00:00<?, ?it/s]

52 52


  0%|          | 0/99 [00:00<?, ?it/s]

53 53


  0%|          | 0/99 [00:00<?, ?it/s]

54 54


  0%|          | 0/99 [00:00<?, ?it/s]

55 55


  0%|          | 0/99 [00:00<?, ?it/s]

56 56


  0%|          | 0/99 [00:00<?, ?it/s]

57 57


  0%|          | 0/99 [00:00<?, ?it/s]

58 58


  0%|          | 0/99 [00:00<?, ?it/s]

59 59


  0%|          | 0/99 [00:00<?, ?it/s]

60 60


  0%|          | 0/99 [00:00<?, ?it/s]

61 61


  0%|          | 0/99 [00:00<?, ?it/s]

62 62


  0%|          | 0/99 [00:00<?, ?it/s]

63 63


  0%|          | 0/99 [00:00<?, ?it/s]

64 64


  0%|          | 0/99 [00:00<?, ?it/s]

65 65


  0%|          | 0/99 [00:00<?, ?it/s]

66 66


  0%|          | 0/99 [00:00<?, ?it/s]

67 67


  0%|          | 0/99 [00:00<?, ?it/s]

68 68


  0%|          | 0/99 [00:00<?, ?it/s]

69 69


  0%|          | 0/99 [00:00<?, ?it/s]

70 70


  0%|          | 0/99 [00:00<?, ?it/s]

71 71


  0%|          | 0/99 [00:00<?, ?it/s]

72 72


  0%|          | 0/99 [00:00<?, ?it/s]

73 73


  0%|          | 0/99 [00:00<?, ?it/s]

74 74


  0%|          | 0/99 [00:00<?, ?it/s]

75 75


  0%|          | 0/99 [00:00<?, ?it/s]

76 76


  0%|          | 0/99 [00:00<?, ?it/s]

77 77


  0%|          | 0/99 [00:00<?, ?it/s]

78 78


  0%|          | 0/99 [00:00<?, ?it/s]

79 79


  0%|          | 0/99 [00:00<?, ?it/s]

80 80


  0%|          | 0/99 [00:00<?, ?it/s]

81 81


  0%|          | 0/99 [00:00<?, ?it/s]

82 82


  0%|          | 0/99 [00:00<?, ?it/s]

83 83


  0%|          | 0/99 [00:00<?, ?it/s]

84 84


  0%|          | 0/99 [00:00<?, ?it/s]

85 85


  0%|          | 0/99 [00:00<?, ?it/s]

86 86


  0%|          | 0/99 [00:00<?, ?it/s]

87 87


  0%|          | 0/99 [00:00<?, ?it/s]

88 88


  0%|          | 0/99 [00:00<?, ?it/s]

89 89


  0%|          | 0/99 [00:00<?, ?it/s]

90 90


  0%|          | 0/99 [00:00<?, ?it/s]

91 91


  0%|          | 0/99 [00:00<?, ?it/s]

92 92


  0%|          | 0/99 [00:00<?, ?it/s]

93 93


  0%|          | 0/99 [00:00<?, ?it/s]

94 94


  0%|          | 0/99 [00:00<?, ?it/s]

95 95


  0%|          | 0/99 [00:00<?, ?it/s]

96 96


  0%|          | 0/99 [00:00<?, ?it/s]

97 97


  0%|          | 0/99 [00:00<?, ?it/s]

98 98


  0%|          | 0/99 [00:00<?, ?it/s]

99 99


  0%|          | 0/99 [00:00<?, ?it/s]

100 100


  0%|          | 0/99 [00:00<?, ?it/s]

101 101


  0%|          | 0/99 [00:00<?, ?it/s]

102 102


  0%|          | 0/99 [00:00<?, ?it/s]

103 103


  0%|          | 0/99 [00:00<?, ?it/s]

104 104


  0%|          | 0/99 [00:00<?, ?it/s]

105 105


  0%|          | 0/99 [00:00<?, ?it/s]

106 106


  0%|          | 0/99 [00:00<?, ?it/s]

107 107


  0%|          | 0/99 [00:00<?, ?it/s]

108 108


  0%|          | 0/99 [00:00<?, ?it/s]

109 109


  0%|          | 0/99 [00:00<?, ?it/s]

110 110


  0%|          | 0/99 [00:00<?, ?it/s]

111 111


  0%|          | 0/99 [00:00<?, ?it/s]

112 112


  0%|          | 0/99 [00:00<?, ?it/s]

113 113


  0%|          | 0/99 [00:00<?, ?it/s]

114 114


  0%|          | 0/99 [00:00<?, ?it/s]

115 115


  0%|          | 0/99 [00:00<?, ?it/s]

116 116


  0%|          | 0/99 [00:00<?, ?it/s]

117 117


  0%|          | 0/99 [00:00<?, ?it/s]

118 118


  0%|          | 0/99 [00:00<?, ?it/s]

119 119


  0%|          | 0/99 [00:00<?, ?it/s]

120 120


  0%|          | 0/99 [00:00<?, ?it/s]

121 121


  0%|          | 0/99 [00:00<?, ?it/s]

122 122


  0%|          | 0/99 [00:00<?, ?it/s]

123 123


  0%|          | 0/99 [00:00<?, ?it/s]

124 124


  0%|          | 0/99 [00:00<?, ?it/s]

125 125


  0%|          | 0/99 [00:00<?, ?it/s]

126 126


  0%|          | 0/99 [00:00<?, ?it/s]

127 127


  0%|          | 0/99 [00:00<?, ?it/s]

128 128


  0%|          | 0/99 [00:00<?, ?it/s]

129 129


  0%|          | 0/99 [00:00<?, ?it/s]

130 130


  0%|          | 0/99 [00:00<?, ?it/s]

131 131


  0%|          | 0/99 [00:00<?, ?it/s]

132 132


  0%|          | 0/99 [00:00<?, ?it/s]

133 133


  0%|          | 0/99 [00:00<?, ?it/s]

134 134


  0%|          | 0/99 [00:00<?, ?it/s]

135 135


  0%|          | 0/99 [00:00<?, ?it/s]

136 136


  0%|          | 0/99 [00:00<?, ?it/s]

137 137


  0%|          | 0/99 [00:00<?, ?it/s]

138 138


  0%|          | 0/99 [00:00<?, ?it/s]

139 139


  0%|          | 0/99 [00:00<?, ?it/s]

140 140


  0%|          | 0/99 [00:00<?, ?it/s]

141 141


  0%|          | 0/99 [00:00<?, ?it/s]

142 142


  0%|          | 0/99 [00:00<?, ?it/s]

143 143


  0%|          | 0/99 [00:00<?, ?it/s]

144 144


  0%|          | 0/99 [00:00<?, ?it/s]

145 145


  0%|          | 0/99 [00:00<?, ?it/s]

In [None]:
# Generate one summary per validation sample and collect parallel lists of
# references and candidates (refs[k] belongs to cands[k]).
cands = []
refs = []
for i in range(len(valid_data)):
    sample = valid_data[i]
    idx = sample['sum_idx']
    # Everything before `sum_idx` is treated as the article; the token at `sum_idx`
    # itself is skipped (presumably the separator - TODO confirm) and at most the
    # next 100 tokens are kept as the reference summary.
    context = sample['article'][:idx].tolist()
    summary = sample['article'][idx+1:][:100].tolist()
    print(tokenizer.decode(context), end='\n\n')
    print('actual_summary', end='\n\n')
    print(tokenizer.decode(summary), end='\n\n')
    print ("generated")
    ref, cand = generate_beam_sample([sample], tokenizer, model, num=1, device=device)
    # Record exactly one reference per candidate. Previously the decoded summary AND
    # the helper's reference were both appended, so `refs` grew twice as fast as
    # `cands` and the two lists no longer lined up for scoring.
    # The helper is used elsewhere in this notebook as returning lists, so flatten
    # its results instead of nesting them; a bare string is still accepted.
    cands.extend([cand] if isinstance(cand, str) else cand)
    refs.extend([ref] if isinstance(ref, str) else ref)
    if (i % 100 == 0):
        print("cand", cand, end='\n\n')
        print("ref", ref, end='\n\n')

In [14]:
# Show the article and reference summary currently held in `context` / `summary`.
# NOTE(review): both names are only assigned inside the per-sample loop cell, so
# this cell relies on that loop having been run in the same kernel session.
article_text = tokenizer.decode(context)
print(article_text, end='\n\n')
print('actual_summary', end='\n\n')
reference_text = tokenizer.decode(summary)
print(reference_text, end='\n\n')

At 95, an elderly man from Denmark is still able to lift weights that most people a quarter of his age wouldn't be able to budge from the ground. Even more astonishingly, it was only two-and-a-half years ago that powerlifter Svend Stensgaard was rushed to hospital after having a heart attack. Yet a video filmed of him working out at the gym - in a room full of boys young enough to be his great-grandchildren - proves just how fit and healthy he is in later life. In the footage, the jacked Dane defies stereotypes of pensioners as frail and feeble beings clad in their bedroom slippers. Both standing and lying on his back, he is seen shifting a whopping 290 pounds of weight as he controls his breathing to establish a rhythm. Mr Stensgaard says in the interview that exercising, which releases a lot of stress-fighting endorphins, is comparable to a ` dosage of morphine'for him. Power pa : Svend Stensgaard, 95, showcases his impressive strength while powerlifting at his local gym in Denmark. 

In [4]:
!pip3 install rouge_score

Collecting rouge_score
  Downloading https://files.pythonhosted.org/packages/e2/c5/9136736c37022a6ad27fea38f3111eb8f02fe75d067f9a985cc358653102/rouge_score-0.1.2.tar.gz
Collecting six>=1.14.0
  Downloading https://files.pythonhosted.org/packages/d9/5a/e7c31adbe875f2abbb91bd84cf2dc52d792b5a01506781dbcf25c91daf11/six-1.16.0-py2.py3-none-any.whl
Building wheels for collected packages: rouge-score
  Building wheel for rouge-score (setup.py) ... [?25ldone
[?25h  Created wheel for rouge-score: filename=rouge_score-0.1.2-cp37-none-any.whl size=24957 sha256=a6aa3daec5aefd4eac50166f9629a5441663a0fd4095b4f26f809f6e75e2b7cb
  Stored in directory: /root/.cache/pip/wheels/df/aa/59/74f33db3bbedf322bcaadca4a43750ea5eb523bbe742d78b25
Successfully built rouge-score
[31mERROR: deeppavlov 0.6.1 requires fastapi==0.38.1, which is not installed.[0m
[31mERROR: deeppavlov 0.6.1 requires flasgger==0.9.2, which is not installed.[0m
[31mERROR: deeppavlov 0.6.1 requires fuzzywuzzy==0.17.0, which is not in