In [None]:
!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

In [None]:
!pip install transformers datasets evaluate rouge_score accelerate

In [None]:
!pip install transformers[torch]

In [1]:
import nltk
nltk.download('punkt')

[nltk_data] Downloading package punkt to
[nltk_data]     /users/PAS0350/geng161/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [2]:
import json
from datasets import Dataset

meetings = {}
# read the dataset
# please enter the path of your data
splits = ('train', 'val', 'test')
for split in splits:
    data_path = '/users/PAS0350/geng161/MeetingSummary/QMSum/data/ALL/jsonl/' + split + '.jsonl'
    data = []
    with open(data_path) as f:
        for line in f:
            data.append(json.loads(line))
    n_meetings = len(data)
    print('Total {} meetings in the {} set.'.format(n_meetings, split))
    meetings[split] = data

Total 162 meetings in the train set.
Total 35 meetings in the val set.
Total 35 meetings in the test set.


In [3]:
from nltk import word_tokenize
# tokneize a sent
def tokenize(sent):
    tokens = ' '.join(word_tokenize(sent.lower()))
    return tokens
# filter some noises caused by speech recognition
def clean_data(text):
    text = text.replace('{ vocalsound } ', '')
    text = text.replace('{ disfmarker } ', '')
    text = text.replace('a_m_i_', 'ami')
    text = text.replace('l_c_d_', 'lcd')
    text = text.replace('p_m_s', 'pms')
    text = text.replace('t_v_', 'tv')
    text = text.replace('{ pause } ', '')
    text = text.replace('{ nonvocalsound } ', '')
    text = text.replace('{ gap } ', '')
    return text

In [4]:
# process data for BART
# the input of the model here is the entire content of the meeting

bart_data = {}
for split in splits:
    data = meetings[split]
    src_tgt = []
    for i in range(len(data)): # For each meeting
        # get meeting content
        src = []
        for k in range(len(data[i]['meeting_transcripts'])): # For each conversation
            cur_turn = data[i]['meeting_transcripts'][k]['speaker'].lower() + ': '
            cur_turn = cur_turn + tokenize(data[i]['meeting_transcripts'][k]['content'])
            src.append(cur_turn)
        src = ' '.join(src)
        for j in range(len(data[i]['general_query_list'])):
            cur = {}
            query = tokenize(data[i]['general_query_list'][j]['query'])
            cur['src'] = clean_data('<s> ' + query + ' </s> ' + src + ' </s>') # query + all transcripts
            target = tokenize(data[i]['general_query_list'][j]['answer'])
            cur['tgt'] = target
            src_tgt.append(cur)
        for j in range(len(data[i]['specific_query_list'])):
            cur = {}
            query = tokenize(data[i]['specific_query_list'][j]['query'])
            cur['src'] = clean_data('<s> ' + query + ' </s> ' + src + ' </s>')
            target = tokenize(data[i]['specific_query_list'][j]['answer'])
            cur['tgt'] = target
            src_tgt.append(cur)
    bart_data[split] = src_tgt
        
# print('Total {} query-summary pairs in the {} set'.format(len(bart_data), split))
# print(bart_data[2])
# with open('/users/PAS0350/geng161/MeetingSummary/QMSum/data/bart_' + split + '.jsonl', 'w') as f:
#     for i in range(len(bart_data)):
#         print(json.dumps(bart_data[i]), file=f)

In [5]:
# process data for BART
# the input of the model here is the gold span corresponding to each query
bart_data_gold = {}
for split in splits:
    data = meetings[split]
    src_tgt = []
    for i in range(len(data)):
        # get meeting content
        entire_src = []
        for k in range(len(data[i]['meeting_transcripts'])):
            cur_turn = data[i]['meeting_transcripts'][k]['speaker'].lower() + ': '
            cur_turn = cur_turn + tokenize(data[i]['meeting_transcripts'][k]['content'])
            entire_src.append(cur_turn)
        entire_src = ' '.join(entire_src)
        for j in range(len(data[i]['general_query_list'])):
            cur = {}
            query = tokenize(data[i]['general_query_list'][j]['query'])
            cur['src'] = clean_data('<s> ' + query + ' </s> ' + entire_src + ' </s>')
            target = tokenize(data[i]['general_query_list'][j]['answer'])
            cur['tgt'] = target
            src_tgt.append(cur)
        for j in range(len(data[i]['specific_query_list'])):
            cur = {}
            query = tokenize(data[i]['specific_query_list'][j]['query'])
            src = []
            # get the content in the gold span for each query
            for span in data[i]['specific_query_list'][j]['relevant_text_span']:
                assert len(span) == 2
                st, ed = int(span[0]), int(span[1])
                for k in range(st, ed + 1):
                    cur_turn = data[i]['meeting_transcripts'][k]['speaker'].lower() + ': '
                    cur_turn = cur_turn + tokenize(data[i]['meeting_transcripts'][k]['content'])
                    src.append(cur_turn)
            src = ' '.join(src)
            cur['src'] = clean_data('<s> ' + query + ' </s> ' + src + ' </s>')
            target = tokenize(data[i]['specific_query_list'][j]['answer'])
            cur['tgt'] = target
            src_tgt.append(cur)
    bart_data_gold[split] = src_tgt
print('Total {} query-summary pairs in the {} set'.format(len(bart_data_gold), split))

Total 3 query-summary pairs in the test set


In [6]:
from datasets import Dataset

train_bart_data_gold = Dataset.from_list(bart_data_gold['train'])
test_bart_data_gold = Dataset.from_list(bart_data_gold['test'])

In [7]:
# Preprocess
def preprocess_function(examples):
    model_inputs = tokenizer(examples['src'], max_length=2048, truncation=True)
    labels = tokenizer(text_target=examples['tgt'], max_length=128, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [8]:
from transformers import AutoTokenizer

checkpoint = "google-t5/t5-small"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [9]:
train_tokenized_bart_data_gold = train_bart_data_gold.map(preprocess_function, batched=True)
test_tokenized_bart_data_gold = test_bart_data_gold.map(preprocess_function, batched=True)

Map:   0%|          | 0/1257 [00:00<?, ? examples/s]

Map:   0%|          | 0/281 [00:00<?, ? examples/s]

In [10]:
# Evaluate
import evaluate
import numpy as np

rouge = evaluate.load("rouge")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)

    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in result.items()}

In [11]:
# Train
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import DataCollatorForSeq2Seq

data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

training_args = Seq2SeqTrainingArguments(
    output_dir="my_awesome_billsum_model",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=4,
    predict_with_generate=True,
    fp16=True,
    push_to_hub=False,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_tokenized_bart_data_gold,
    eval_dataset=test_tokenized_bart_data_gold,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum,Gen Len
1,No log,2.67189,0.1771,0.0557,0.1414,0.1413,18.9858
2,2.896100,2.602722,0.1937,0.0636,0.1554,0.1552,18.9822
3,2.896100,2.57935,0.1985,0.0661,0.1598,0.1597,18.9609
4,2.709800,2.57313,0.2005,0.0661,0.1623,0.1621,18.9858




TrainOutput(global_step=1260, training_loss=2.775908648778522, metrics={'train_runtime': 421.2249, 'train_samples_per_second': 11.937, 'train_steps_per_second': 2.991, 'total_flos': 2492577269415936.0, 'train_loss': 2.775908648778522, 'epoch': 4.0})

In [16]:
from transformers import pipeline

summarizer = pipeline("summarization", model="stevhliu/my_awesome_billsum_model")
summarizer("summarize: " + test_bart_data_gold[0]['src'])

Token indices sequence length is longer than the specified maximum sequence length for this model (16612 > 512). Running this sequence through the model will result in indexing errors


[{'summary_text': "barry hughes: 'i think that if a child is convicted, i would be able to do a lot of things, and that i think . that . is a . smacking , and . it 's a good thing . to be convicted of . assault on the child . and i'm a foolish man . who . have been involved in . this ."}]

{'src': "<s> summarize the whole meeting . </s> lynne neagle am: good afternoon , everyone . welcome to the children , young people and education committee . we 've received apologies for absence from hefin david and jack sargeant . vikki howells is substituting for jack sargeant . so , vikki , welcome ; it 's good to see you in the committee . item 2 this afternoon is our eleventh evidence session on the children ( abolition of defence of reasonable punishment ) ( wales ) bill . i 'm very pleased to welcome barry hughes , who is chief crown prosecutor for wales ; kwame biney , who is senior policy advisor , cps ; and iwan jenkins , who is head of the complex casework unit , crown prosecution service cymru wales . so thank you all for attending this afternoon . we 're really looking forward to hearing your views on the bill . if you 're happy , we 'll go straight into questions from members , and the first ones are from siân gwenllian . barry hughes: perfectly happy . sian gwenllian am