In [38]:
import json
import os

import pandas as pd
import torch
from datasets import load_metric
from sklearn.cluster import KMeans, AgglomerativeClustering
from tqdm import tqdm
from transformers import AutoTokenizer, BertModel, LEDForConditionalGeneration

In [14]:
# Run on the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)

Using device: cpu


In [15]:
# load our augmented dataset json file
# JSON text is UTF-8 by specification, so say so explicitly instead of
# relying on the platform's default text encoding.
with open('augmented_test.json', 'r', encoding='utf-8') as f:
    augmented_dataset = json.load(f)

In [16]:
# BERT encoder and its tokenizer; embed_documents uses this pair to turn
# documents into vectors for clustering.
bert_checkpoint = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(bert_checkpoint)
model = BertModel.from_pretrained(bert_checkpoint)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [17]:
def embed_documents(documents):
    """Embed each document as one BERT vector by averaging its token states.

    Each document is tokenized on its own (truncated to 512 tokens), passed
    through the module-level BERT ``model``, and the final-layer hidden states
    are averaged over the token axis.

    Documents are deliberately NOT padded: with one document per forward pass
    padding serves no batching purpose, and averaging over padded positions
    would mix padding-token states into the document embedding.

    Args:
        documents: iterable of document strings.

    Returns:
        A list with one ``(1, hidden_size)`` tensor per document, in input
        order. The tensors carry no autograd history.
    """
    doc_embeddings = []
    # Inference only: no autograd graph is needed, which keeps memory flat
    # while looping over many long documents.
    with torch.no_grad():
        for doc in documents:
            inputs = tokenizer(
                doc,
                max_length=512,
                truncation=True,  # explicit; BERT cannot take more than 512 positions
                return_tensors="pt",
            )
            outputs = model(**inputs)

            last_hidden_states = outputs.last_hidden_state
            # average across the (real) tokens to get the document embedding
            doc_embeddings.append(torch.mean(last_hidden_states, dim=1))

    return doc_embeddings

In [18]:
def cluster_documents_kmeans(documents, n_clusters=3, random_state=0):
    """Cluster documents with k-means over their BERT embeddings.

    Args:
        documents: iterable of document strings, embedded via
            ``embed_documents``.
        n_clusters: number of clusters to form (default 3, the value the rest
            of the notebook assumes).
        random_state: seed for the k-means initialisation (default 0).

    Returns:
        Array of cluster labels, one per document, in input order.

    Raises:
        ValueError: propagated from scikit-learn when there are fewer
            documents than ``n_clusters``.
    """
    doc_embeddings = embed_documents(documents)
    # turn our list of (1, hidden) embeddings into a (n_docs, hidden) numpy matrix;
    # .cpu() is a no-op on CPU tensors and keeps .numpy() valid otherwise
    embedding_matrix = torch.cat(doc_embeddings).detach().cpu().numpy()
    # n_init is pinned because its default differs between scikit-learn
    # releases; fixing it keeps cluster assignments reproducible.
    kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=random_state)
    kmeans.fit(embedding_matrix)
    return kmeans.labels_

def cluster_documents_hierarchical(documents, n_clusters=3):
    """Cluster documents with agglomerative clustering over BERT embeddings.

    Args:
        documents: iterable of document strings, embedded via
            ``embed_documents``.
        n_clusters: number of clusters to form (default 3, the value the rest
            of the notebook assumes).

    Returns:
        Array of cluster labels, one per document, in input order.
    """
    doc_embeddings = embed_documents(documents)
    # turn our list of (1, hidden) embeddings into a (n_docs, hidden) numpy matrix;
    # .cpu() is a no-op on CPU tensors and keeps .numpy() valid otherwise
    embedding_matrix = torch.cat(doc_embeddings).detach().cpu().numpy()
    hierarchical = AgglomerativeClustering(n_clusters=n_clusters)
    hierarchical.fit(embedding_matrix)
    return hierarchical.labels_

In [19]:
# PRIMERA summarizer and its tokenizer; the model is placed on `device`.
primera_checkpoint = 'allenai/PRIMERA'
TOKENIZER = AutoTokenizer.from_pretrained(primera_checkpoint)
MODEL = LEDForConditionalGeneration.from_pretrained(primera_checkpoint)
MODEL = MODEL.to(device)

# Special-token ids needed when model inputs are assembled by hand.
DOCSEP_TOKEN_ID = TOKENIZER.convert_tokens_to_ids("<doc-sep>")
PAD_TOKEN_ID = TOKENIZER.pad_token_id

In [20]:
# Restrict the evaluation to the pre-selected sample: read the stored indices
# and pick the matching entries out of the full dataset.
sample_index_table = pd.read_csv('cohere_sample_indices.csv')
sample_indices = sample_index_table.to_numpy().flatten()
dataset_small = [augmented_dataset[sample_index] for sample_index in sample_indices]

In [29]:
def process_document(documents, max_input_length=4096):
    """Pack several documents into one PRIMERA input sequence.

    Layout of the returned ids: ``<s> doc_1 <doc-sep> doc_2 <doc-sep> ... </s>``.
    Every document is truncated to an equal share of ``max_input_length`` so
    the assembled sequence never exceeds that length.

    Args:
        documents: list of document strings (may be empty, in which case only
            ``<s> </s>`` is returned).
        max_input_length: upper bound on the assembled sequence length.
            NOTE(review): 4096 presumably matches PRIMERA's encoder position
            limit - confirm against ``MODEL.config``.

    Returns:
        ``LongTensor`` of shape ``(1, sequence_length)`` on ``device``.
    """
    n_docs = len(documents)
    input_ids = [TOKENIZER.bos_token_id]
    if n_docs:
        # encode() counts its own <s>/</s> pair against max_length; that pair
        # is stripped below and one <doc-sep> is added, so each document
        # contributes at most (per_doc_max - 1) ids. Requiring
        # n_docs * (per_doc_max - 1) + 2 <= max_input_length gives the second
        # bound; it only bites for a single document, where the plain
        # max_input_length // n_docs share would produce one id too many.
        per_doc_max = min(
            max_input_length // n_docs,
            (max_input_length - 2) // n_docs + 1,
        )
        #### concat with global attention on doc-sep
        for doc in documents:
            input_ids.extend(
                TOKENIZER.encode(
                    doc,
                    truncation=True,
                    max_length=per_doc_max,
                )[1:-1]
            )
            input_ids.append(DOCSEP_TOKEN_ID)
    input_ids.append(TOKENIZER.eos_token_id)

    # A single sequence needs no padding; wrap it as a batch of one.
    return torch.tensor([input_ids], device=device)

"""
For each cluster, use each document in input. Output 512/3 tokens for each cluster. Append them. 
"""
def process_clusters_concat(data, clusters, n_clusters=3, max_summary_length=512):
    """Summarize every cluster separately and join the cluster summaries.

    All documents of a cluster are packed into one input by
    ``process_document``; each cluster gets an equal share
    (``max_summary_length // n_clusters``) of the output budget, and the
    per-cluster summaries are joined with single spaces into one summary.

    Args:
        data: mapping read via ``data['documents']`` (list of document
            strings) and ``data['summary']`` (reference summary).
        clusters: cluster label per document, aligned with
            ``data['documents']``; labels must lie in ``range(n_clusters)``.
        n_clusters: number of clusters the labels were produced with.
        max_summary_length: total generation budget across all clusters.

    Returns:
        dict with ``'generated_summaries'`` - a one-element list holding the
        concatenated summary (same list-of-strings shape that
        ``process_clusters_one_doc`` returns) - and ``'gt_summaries'``, which
        is ``data['summary']`` unchanged.
    """
    documents = data['documents']
    # create a list of lists, where each list is a cluster
    cluster_docs = [[] for _ in range(n_clusters)]
    for i, doc in enumerate(documents):
        cluster_docs[clusters[i]].append(doc)

    per_cluster_max_length = max_summary_length // n_clusters
    cluster_summaries = []

    # now run the summarization on each cluster, using all of its documents
    for cluster in cluster_docs:
        if not cluster:
            # An empty cluster has nothing to summarize; generating from a
            # bare <s></s> input would only add noise to the final summary.
            continue

        input_ids = process_document(cluster)
        global_attention_mask = torch.zeros_like(input_ids)
        # put global attention on <s> token and on every <doc-sep>
        global_attention_mask[:, 0] = 1
        global_attention_mask[input_ids == DOCSEP_TOKEN_ID] = 1

        generated_ids = MODEL.generate(
            input_ids=input_ids,
            global_attention_mask=global_attention_mask,
            use_cache=True,
            max_length=per_cluster_max_length,
            num_beams=5,
        )

        # batch_decode returns one string per batch row; keep a flat list of
        # strings so the final join really concatenates across clusters.
        cluster_summaries.extend(
            TOKENIZER.batch_decode(generated_ids.tolist(), skip_special_tokens=True)
        )

    return {
        # concatenate the cluster summaries into a single summary string
        'generated_summaries': [' '.join(cluster_summaries)],
        'gt_summaries': data['summary'],
    }

"""
Take the first document from each cluster and summarize only those documents (idea: each document
gets a larger share of the input context). The output max length is still 512 tokens.
"""
def process_clusters_one_doc(data, clusters, n_clusters=3, max_summary_length=512):
    """Summarize one representative document per cluster in a single pass.

    The first document assigned to each cluster is kept, the kept documents
    are packed into one input by ``process_document``, and one summary of at
    most ``max_summary_length`` tokens is generated from them.

    Args:
        data: mapping read via ``data['documents']`` (list of document
            strings) and ``data['summary']`` (reference summary).
        clusters: cluster label per document, aligned with
            ``data['documents']``; labels must lie in ``range(n_clusters)``.
        n_clusters: number of clusters the labels were produced with.
        max_summary_length: ``max_length`` passed to generation.

    Returns:
        dict with ``'generated_summaries'`` (list of decoded strings, one per
        batch row) and ``'gt_summaries'`` (``data['summary']`` unchanged).
    """
    documents = data['documents']
    # create a list of lists, where each list is a cluster
    cluster_docs = [[] for _ in range(n_clusters)]
    for i, doc in enumerate(documents):
        cluster_docs[clusters[i]].append(doc)

    # Keep the first document of every non-empty cluster; a cluster that
    # received no documents has no representative and is skipped instead of
    # raising IndexError.
    new_docs = [cluster[0] for cluster in cluster_docs if cluster]
    input_ids = process_document(new_docs)

    global_attention_mask = torch.zeros_like(input_ids)
    # put global attention on <s> token and on every <doc-sep>
    global_attention_mask[:, 0] = 1
    global_attention_mask[input_ids == DOCSEP_TOKEN_ID] = 1

    generated_ids = MODEL.generate(
        input_ids=input_ids,
        global_attention_mask=global_attention_mask,
        use_cache=True,
        max_length=max_summary_length,
        num_beams=5,
    )

    generated_str = TOKENIZER.batch_decode(
        generated_ids.tolist(), skip_special_tokens=True
    )

    return {
        'generated_summaries': generated_str,
        'gt_summaries': data['summary'],
    }


In [None]:
# run our pipeline using kmeans clustering

# Checkpoints are written into results/; create it up front so a missing
# directory cannot abort the run after summaries have already been generated.
os.makedirs("results", exist_ok=True)

result_small_one_doc = []
result_small_concat = []

for i in tqdm(range(0, len(dataset_small))):
    data = dataset_small[i]
    k_clusters = cluster_documents_kmeans(data['documents'])

    result_small_one_doc.append(process_clusters_one_doc(data, k_clusters))
    result_small_concat.append(process_clusters_concat(data, k_clusters))

    # checkpoint everything produced so far every 50 examples (i = 0, 50, 100, ...)
    if i % 50 == 0:
        with open(f"results/kmeans_bert_one_doc{i}.json", "w") as f:
            json.dump(result_small_one_doc, f)

        with open(f"results/kmeans_bert_concat{i}.json", "w") as f:
            json.dump(result_small_concat, f)

# final, complete dumps
with open("results/kmeans_bert_one_doc_complete.json", "w") as f:
    json.dump(result_small_one_doc, f)

with open("results/kmeans_bert_concat_complete.json", "w") as f:
    json.dump(result_small_concat, f)

In [32]:
# Each result stores its generated text as a list of strings, while ROUGE takes
# one prediction string per example, so join every list into a single string.
# NOTE: must run right after the k-means pipeline cell, before the hierarchical
# cell reuses the result_small_* names.
generated_summaries_one_doc_kmeans = [' '.join(result['generated_summaries']) for result in result_small_one_doc]
generated_summaries_concat_kmeans = [' '.join(result['generated_summaries']) for result in result_small_concat]
gt_summaries = [result['gt_summaries'] for result in result_small_one_doc]

In [None]:
# run our pipeline using hierarchical clustering

# Checkpoints are written into results/; create it up front so a missing
# directory cannot abort the run after summaries have already been generated.
os.makedirs("results", exist_ok=True)

# NOTE: these names are shared with the k-means cell; collect the k-means
# summaries before running this cell, because it overwrites them.
result_small_one_doc = []
result_small_concat = []

for i in tqdm(range(0, len(dataset_small))):
    data = dataset_small[i]
    hier_clusters = cluster_documents_hierarchical(data['documents'])

    result_small_one_doc.append(process_clusters_one_doc(data, hier_clusters))
    result_small_concat.append(process_clusters_concat(data, hier_clusters))

    # checkpoint everything produced so far every 50 examples (i = 0, 50, 100, ...)
    if i % 50 == 0:
        with open(f"results/hierarchical_bert_one_doc{i}.json", "w") as f:
            json.dump(result_small_one_doc, f)

        with open(f"results/hierarchical_bert_concat{i}.json", "w") as f:
            json.dump(result_small_concat, f)

# final, complete dumps
with open("results/hierarchical_bert_one_doc_complete.json", "w") as f:
    json.dump(result_small_one_doc, f)

with open("results/hierarchical_bert_concat_complete.json", "w") as f:
    json.dump(result_small_concat, f)

In [33]:
# Each result stores its generated text as a list of strings, while ROUGE takes
# one prediction string per example, so join every list into a single string.
generated_summaries_one_doc_hier = [' '.join(result['generated_summaries']) for result in result_small_one_doc]
generated_summaries_concat_hier = [' '.join(result['generated_summaries']) for result in result_small_concat]

In [39]:
# ROUGE scorer used by the evaluation cells below.
# NOTE(review): `load_metric` appears to be deprecated in newer `datasets`
# releases in favour of the separate `evaluate` package - confirm against the
# installed version before upgrading.
rouge = load_metric("rouge")

  rouge = load_metric("rouge")


In [None]:
# ROUGE-1 / ROUGE-2 / ROUGE-L (mid estimate) for the k-means runs: first the
# one-document-per-cluster summaries, then the concatenated cluster summaries.
for predictions in (generated_summaries_one_doc_kmeans, generated_summaries_concat_kmeans):
    score = rouge.compute(predictions=predictions, references=gt_summaries)
    for rouge_type in ('rouge1', 'rouge2', 'rougeL'):
        print(score[rouge_type].mid)

In [None]:
# ROUGE-1 / ROUGE-2 / ROUGE-L (mid estimate) for the hierarchical runs: first
# the one-document-per-cluster summaries, then the concatenated cluster summaries.
for predictions in (generated_summaries_one_doc_hier, generated_summaries_concat_hier):
    score = rouge.compute(predictions=predictions, references=gt_summaries)
    for rouge_type in ('rouge1', 'rouge2', 'rougeL'):
        print(score[rouge_type].mid)