# Imports

In [None]:
# TODO: Change PATH to desired file location where results will be saved.
# PATH is read by the final save cell, which combines it with the output
# file name 'text.npy'.
PATH = '.'

In [None]:
!pip install faiss-cpu

In [None]:
!pip install nltk

In [None]:
from datasets import load_dataset
import numpy as np
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from sentence_transformers import SentenceTransformer, util
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import scipy.stats as stats

In [None]:
# Load the "section" configuration of the ccdv/pubmed-summarization dataset.
# The summarization loop further down reads ds['test'][i]['article'].
ds = load_dataset("ccdv/pubmed-summarization", "section")

In [None]:
import faiss
import nltk
from nltk.tokenize import sent_tokenize

# Sentence-splitting data required by sent_tokenize.
nltk.download("punkt_tab")

# Keep the checkpoint name in its own variable so `model` is only ever bound
# to the loaded model object (the original rebound it from str to model).
model_name = "t5-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Sentence-embedding models: `embedder` is used for FAISS chunk retrieval;
# `eval_model` is not used in the cells visible here (presumably evaluation).
eval_model = SentenceTransformer('all-mpnet-base-v2')
embedder = SentenceTransformer("sentence-transformers/multi-qa-mpnet-base-dot-v1")

# Model

In [None]:
def chunk_paper(text, max_tokens=256):
  """Split a paper into sentence-aligned chunks of at most `max_tokens` tokens.

  Sentences are accumulated greedily; a new chunk is started whenever adding
  the next sentence would push the running token count past `max_tokens`.
  A single sentence that is by itself longer than `max_tokens` is kept whole
  as its own (oversized) chunk rather than being split.

  Args:
    text: Full paper text.
    max_tokens: Token budget per chunk, counted with the module-level
      `tokenizer`.

  Returns:
    A `(chunks, chunks_tokens)` tuple of equal-length lists: the raw chunk
    strings (needed for the FAISS embedding step) and the tokenizer output
    for each chunk. Both lists are empty for empty input.
  """
  chunks = []
  chunks_tokens = []
  current = []      # sentences belonging to the chunk being built
  token_count = 0   # tokens accumulated in `current`

  def flush():
    # Join with single spaces so no chunk carries a stray leading space.
    chunk = ' '.join(current)
    chunks.append(chunk)
    chunks_tokens.append(tokenizer.tokenize(chunk))

  for sentence in sent_tokenize(text):
    n_tokens = len(tokenizer.tokenize(sentence))
    # Only flush a non-empty chunk: if the very first sentence is already
    # over budget there is nothing to emit yet (avoids an empty '' chunk).
    if current and token_count + n_tokens > max_tokens:
      flush()
      current = []
      token_count = 0
    current.append(sentence)
    token_count += n_tokens

  if current:
    flush()

  return chunks, chunks_tokens

In [None]:
def rank_chunks(idx, top: int):
  """Merge per-query nearest-neighbour lists into one ranked list of chunk ids.

  Each query contributes rank-based points to the chunks it retrieved: with
  `n = idx.shape[1]` results per query, the result at position `r` earns
  `n - r` points. Points are summed across queries and the `top` highest
  scoring chunk ids are returned, best first. Ties keep first-seen order.

  Args:
    idx: 2-D integer array of shape (num_queries, results_per_query), as
      returned by a FAISS `index.search` call.
    top: Maximum number of chunk ids to return.

  Returns:
    List of at most `top` chunk indices (plain ints), highest score first.
  """
  num_chunks = idx.shape[1]
  scores = {}
  for q_indices in idx:
    for rank, index in enumerate(q_indices):
      index = int(index)
      # FAISS pads missing neighbours with -1; treating that as a chunk id
      # would make callers silently pick chunks[-1].
      if index < 0:
        continue
      scores[index] = scores.get(index, 0) + (num_chunks - rank)
  ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)[:top]
  return [index for index, _ in ranked]

def faiss_chunks(chunks, k, top=3):
  """Select the chunks most relevant to a fixed set of summary-oriented queries.

  Embeds every chunk with the module-level `embedder`, indexes the embeddings
  in an exact L2 FAISS index, retrieves the `k` nearest chunks for each query
  and merges the per-query results with `rank_chunks`.

  Args:
    chunks: List of raw chunk strings.
    k: Number of nearest chunks to retrieve per query. Clamped to
      `len(chunks)` so FAISS never has to pad results with -1.
    top: Maximum number of chunk indices to return (default 3).

  Returns:
    List of indices into `chunks`, best first; empty if `chunks` is empty.
  """
  # Nothing to index: encoding an empty list yields no embedding dimension.
  if len(chunks) == 0:
    return []

  queries = [
    "Study objective.",
    "Methods overview.",
    "Primary conclusions."
  ]
  embeddings = embedder.encode(chunks, convert_to_numpy=True)
  # NOTE(review): the embedder configured in this file is a "dot-v1" model,
  # which suggests dot-product scoring; L2 is kept to preserve results.
  index = faiss.IndexFlatL2(embeddings.shape[1])  # exact L2 (Euclidean) distance
  index.add(embeddings)
  query_embedding = embedder.encode(queries, convert_to_numpy=True)
  _, idx = index.search(query_embedding, min(k, len(chunks)))
  return rank_chunks(idx, top)

In [None]:
# T5 checkpoints are trained with the lowercase task prefix "summarize: "
# (with a trailing space). The prefix is applied to both summarization passes.
SUMMARIZE_PREFIX = "summarize: "
NUM_PAPERS = 200  # number of test-set papers to summarize

# One final summary string per paper, in dataset order.
llm_summaries = []

for i in range(NUM_PAPERS):
  chunks, _ = chunk_paper(ds['test'][i]['article'], max_tokens=512)

  # Indices of the chunks judged most relevant to the fixed queries.
  best_indices = faiss_chunks(chunks, 1)
  best_chunks = [chunks[j] for j in best_indices]

  # First pass: summarize each selected chunk independently.
  chunk_summaries = []
  for chunk in best_chunks:
    inputs = tokenizer(SUMMARIZE_PREFIX + chunk, return_tensors="pt", truncation=True, max_length=1024)
    summary_ids = model.generate(**inputs, max_length=128, min_length=64)
    chunk_summaries.append(tokenizer.decode(summary_ids[0], skip_special_tokens=True))

  # Second pass: condense the concatenated chunk summaries into one summary.
  all_summaries = ' '.join(chunk_summaries)
  final_inputs = tokenizer(SUMMARIZE_PREFIX + all_summaries, return_tensors="pt", truncation=True, max_length=1024)
  final_summary = model.generate(**final_inputs, min_length=100, max_length=606)
  model_summary_text = tokenizer.decode(final_summary[0], skip_special_tokens=True)
  print(f'Paper {i}:', model_summary_text)
  llm_summaries.append(model_summary_text)

In [None]:
import os

# Set to False to skip writing the summaries to disk.
save = True
basename = PATH
if save:
  # Join as a directory + file name: plain concatenation with PATH = '.'
  # produced the hidden file '.text.npy' instead of './text.npy'.
  np.save(os.path.join(basename, 'text.npy'), np.array(llm_summaries))