In [11]:
import json
import os

import numpy as np
import pandas as pd
import torch
from datasets import load_metric
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.metrics import silhouette_score
from tqdm import tqdm
from transformers import AutoTokenizer, BertModel, LEDForConditionalGeneration

In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', device)

Using device: cpu


## Dataset

In [3]:
# Load the augmented test split. JSON text is UTF-8 by specification, so read it
# explicitly as UTF-8 instead of relying on the platform default encoding
# (which is e.g. cp1252 on Windows and would mis-decode non-ASCII text).
with open('augmented_test.json', 'r', encoding='utf-8') as f:
    augmented_dataset = json.load(f)

# Load the row indices of the evaluation subset; flatten() turns the CSV's
# values into a 1-D array of indices regardless of its column layout.
sample_indices = pd.read_csv('cohere_sample_indices.csv').values.flatten()
dataset_small = [augmented_dataset[i] for i in sample_indices]

## Embeddings

In [4]:
# BERT encoder + tokenizer used to embed documents for clustering.
BERT_CHECKPOINT = "bert-base-uncased"
tokenizer, model = (
    AutoTokenizer.from_pretrained(BERT_CHECKPOINT),
    BertModel.from_pretrained(BERT_CHECKPOINT),
)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [6]:
def embed_documents(documents, bert_tokenizer=None, bert_model=None, max_length=512):
    """Embed each document as the mean of BERT's final-layer token states.

    Args:
        documents: iterable of document strings.
        bert_tokenizer: tokenizer to use; defaults to the notebook-level
            ``tokenizer``.
        bert_model: encoder to use; defaults to the notebook-level ``model``.
        max_length: documents longer than this many tokens (special tokens
            included) are truncated.

    Returns:
        A list with one ``(1, hidden_size)`` CPU tensor per document, in input
        order. The tensors carry no autograd history.
    """
    # Resolve defaults lazily so the globals are only required when not overridden.
    bert_tokenizer = tokenizer if bert_tokenizer is None else bert_tokenizer
    bert_model = model if bert_model is None else bert_model

    doc_embeddings = []
    # Pure inference: don't build an autograd graph for every document.
    with torch.no_grad():
        for doc in documents:
            # One document per forward pass, so no padding is requested; any
            # padding a tokenizer does add is excluded by the mask below.
            inputs = bert_tokenizer(doc, truncation=True, max_length=max_length, return_tensors="pt")
            inputs = {name: tensor.to(bert_model.device) for name, tensor in inputs.items()}
            hidden = bert_model(**inputs).last_hidden_state  # (batch=1, seq_len, hidden)

            # Masked mean over real tokens only; averaging pad positions in would
            # dilute the document embedding with padding states.
            mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
            summed = (hidden * mask).sum(dim=1)
            counts = mask.sum(dim=1).clamp(min=1)
            # Return CPU tensors so callers can call .numpy() regardless of device.
            doc_embeddings.append((summed / counts).cpu())

    return doc_embeddings

## Clustering

In [7]:
def cluster_documents(documents, max_clusters=10):
    """Cluster documents by their BERT embeddings, choosing k by silhouette score.

    Embeds the documents with ``embed_documents`` and reduces them with PCA,
    keeping enough components to explain 99% of the variance. It then fits
    KMeans for every k in ``range(2, min(max_clusters, n_docs))`` and keeps the
    labelling with the highest silhouette score.

    Args:
        documents: list of document strings.
        max_clusters: exclusive upper bound on the number of clusters tried.

    Returns:
        A 1-D integer numpy array with one cluster label per document. Every
        document gets label 0 (a single cluster) when no k can be scored. This
        happens with fewer than 3 documents, or when every candidate labelling
        is degenerate.
    """
    n_docs = len(documents)
    # silhouette_score requires 2 <= n_labels <= n_samples - 1, so at least 3
    # documents are needed before any k can be scored.
    fallback = np.zeros(n_docs, dtype=int)
    if n_docs < 3:
        return fallback

    # Stack the per-document (1, hidden) embeddings into an (n_docs, hidden) matrix.
    doc_embeddings = torch.cat(embed_documents(documents)).detach().cpu().numpy()

    # Reduce dimensionality first; fit_transform fits and projects in one pass.
    doc_embeddings = PCA(n_components=0.99).fit_transform(doc_embeddings)

    # Tune the number of clusters by silhouette score.
    best_score = -2  # below silhouette's minimum of -1, so the first valid k wins
    best_labels = None
    for k in range(2, min(max_clusters, n_docs)):
        labels = KMeans(n_clusters=k, random_state=0, n_init=10).fit(doc_embeddings).labels_
        # KMeans can end up with fewer distinct clusters than requested (e.g. with
        # duplicate documents); silhouette_score raises on such labellings.
        if not 2 <= len(set(labels)) <= n_docs - 1:
            continue
        score = silhouette_score(doc_embeddings, labels)
        if score > best_score:
            best_score, best_labels = score, labels

    return fallback if best_labels is None else best_labels

## Summarization

In [8]:
# PRIMERA (an LED model) performs the multi-document summarization step.
PRIMERA_CHECKPOINT = 'allenai/PRIMERA'
TOKENIZER = AutoTokenizer.from_pretrained(PRIMERA_CHECKPOINT)
MODEL = LEDForConditionalGeneration.from_pretrained(PRIMERA_CHECKPOINT).to(device)

# Special-token ids needed to build inputs: padding and the document separator.
PAD_TOKEN_ID, DOCSEP_TOKEN_ID = (
    TOKENIZER.pad_token_id,
    TOKENIZER.convert_tokens_to_ids("<doc-sep>"),
)

In [14]:
def process_document(documents, max_input_length=4096):
    """Pack several documents into a single PRIMERA input sequence.

    Layout: ``<s> doc_1 <doc-sep> doc_2 <doc-sep> ... doc_n <doc-sep> </s>``.
    Each document is truncated to an equal share of ``max_input_length``, so
    the packed sequence never exceeds that many tokens.

    Args:
        documents: non-empty list of document strings.
        max_input_length: total token budget for the packed sequence (4096,
            the budget this notebook uses for PRIMERA).

    Returns:
        A ``(1, seq_len)`` int64 tensor of token ids on ``device``.

    Raises:
        ValueError: if ``documents`` is empty.
    """
    if not documents:
        raise ValueError("process_document() needs at least one document")

    n_docs = len(documents)
    # Each encoded document keeps (max_length - 2) tokens once its own <s>/</s>
    # are stripped, and gains one <doc-sep>. The whole sequence then adds one
    # <s> and one </s>. Solving n * (max_length - 1) + 2 <= budget gives the
    # limit below. A plain budget // n overshoots by one token when n == 1.
    per_doc_max_length = (max_input_length - 2) // n_docs + 1

    input_ids = [TOKENIZER.bos_token_id]
    for doc in documents:
        token_ids = TOKENIZER.encode(doc, truncation=True, max_length=per_doc_max_length)
        input_ids.extend(token_ids[1:-1])  # drop this document's own <s> and </s>
        input_ids.append(DOCSEP_TOKEN_ID)
    input_ids.append(TOKENIZER.eos_token_id)

    # A single sequence is already a batch of one; no padding is needed.
    return torch.tensor([input_ids], dtype=torch.long, device=device)

"""
For each cluster, use each document in input. Output 512/k tokens for each cluster. Append them. 
"""
def process_clusters(data, clusters):
    documents = data['documents']
    # create a list of lists, where each list is a cluster
    cluster_docs = [[] for _ in range(len(clusters))]
    for i, doc in enumerate(documents):
        cluster_docs[clusters[i]].append(doc)

    # now run the summarization on each cluster
    cluster_summaries = []

    for cluster in cluster_docs:
        input_ids = process_document(cluster)
        
        # get the input ids and attention masks together
        global_attention_mask = torch.zeros_like(input_ids).to(input_ids.device)
        
        # put global attention on <s> token
        global_attention_mask[:, 0] = 1
        global_attention_mask[input_ids == DOCSEP_TOKEN_ID] = 1

        generated_ids = MODEL.generate(
            input_ids=input_ids,
            global_attention_mask=global_attention_mask,
            use_cache=True,
            max_length=512//len(cluster),
            num_beams=5,
        )

        generated_str = TOKENIZER.batch_decode(
                generated_ids.tolist(), skip_special_tokens=True
            )
        cluster_summaries.append(generated_str)
    
    result={}
    # concatenate the summaries
    result['generated_summaries'] = " ".join(cluster_summaries)
    result['gt_summaries']= data['summary']
    return result

In [15]:
# Run the full pipeline (cluster -> summarize) over the evaluation subset.
# Partial results are checkpointed so that an interrupted run keeps its progress.
RESULTS_DIR = 'results'
CHECKPOINT_EVERY = 50
# open(..., "w") below fails if the directory does not exist yet.
os.makedirs(RESULTS_DIR, exist_ok=True)

results = []
for i, data in enumerate(tqdm(dataset_small)):
    k_clusters = cluster_documents(data['documents'])
    results.append(process_clusters(data, k_clusters))
    # tqdm.write prints above the progress bar instead of breaking it like print().
    tqdm.write(str(results[-1]))

    # Checkpoint after every CHECKPOINT_EVERY completed points; the file name
    # records how many points it contains.
    n_done = i + 1
    if n_done % CHECKPOINT_EVERY == 0:
        with open(os.path.join(RESULTS_DIR, f"final_pipeline_{n_done}.json"), "w") as f:
            json.dump(results, f)

with open(os.path.join(RESULTS_DIR, "final_pipeline_complete.json"), "w") as f:
    json.dump(results, f)

  0%|          | 0/200 [00:27<?, ?it/s]


KeyboardInterrupt: 