In [1]:
from transformers import T5Tokenizer, T5ForConditionalGeneration
import tiktoken
from datasets import load_dataset
from tqdm import tqdm
import dataset_creators.config as config
import pandas as pd
import os
import uuid
import re

In [3]:
def check_prompt_token_limits(article, instructions, tokenizer, token_limit):
    """Count the instruction prompts for ``article`` that stay under ``token_limit``.

    Every instruction text is appended to the article (the prompt is the
    literal ``Article: `` prefix, the article, a newline, then the
    instruction) and encoded with ``tokenizer.encode``. A prompt is counted
    when its encoded length is strictly smaller than ``token_limit``.

    Args:
        article: Article text embedded in each prompt.
        instructions: Mapping of instruction name to instruction text; only
            the texts are used.
        tokenizer: Object exposing ``encode(text)`` whose result supports ``len()``.
        token_limit: Exclusive upper bound on the encoded prompt length.

    Returns:
        int: Number of instructions whose prompt fits under the limit.
    """
    encoded_lengths = (
        len(tokenizer.encode(f'Article: {article}\n{instruction_text}'))
        for instruction_text in instructions.values()
    )
    return sum(1 for length in encoded_lengths if length < token_limit)

def get_shortlisted_data(articles, reference_summaries, ids = None, token_limit = 4096, dataset = 'xsum'):
    """Keep only the articles whose prompts fit the token limit for every model.

    For each article, one prompt per configured instruction is built for both
    the FLAN-T5 and the GPT tokenizer. The article is shortlisted only when
    every one of those prompts is under ``token_limit``.

    Args:
        articles: Sequence of article texts.
        reference_summaries: Sequence aligned with ``articles`` (indexed by
            the article's position).
        ids: Optional sequence of article ids aligned with ``articles``. When
            missing or empty, a fresh ``uuid4`` string is generated for every
            shortlisted article.
        token_limit: Exclusive upper bound on the encoded prompt length.
        dataset: Prefix of the instruction keys; the lookups are
            ``config.instructions[f'{dataset}_flant5']`` and
            ``config.instructions[f'{dataset}_gpt3']``.

    Returns:
        tuple: ``(shortlisted_articles, shortlisted_reference_summaries,
        shortlisted_ids)``, three lists of equal length.

    Raises:
        KeyError: If ``config.instructions`` has no entry for the derived keys.
        IndexError: If ``reference_summaries`` or a non-empty ``ids`` is
            shorter than ``articles``.
    """
    flan_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xl")
    flan_instructions = config.instructions[f'{dataset}_flant5']

    gpt_tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
    gpt_instructions = config.instructions[f'{dataset}_gpt3']

    # Derive the thresholds from the config instead of hard-coding 4, so the
    # filter stays correct if an instruction is added to or removed from a set.
    required_flan_fits = len(flan_instructions)
    required_gpt_fits = len(gpt_instructions)

    # Decide once whether caller-supplied ids exist; len() also works for
    # sequences whose truth value is ambiguous.
    has_ids = ids is not None and len(ids) > 0

    shortlisted_articles = []
    shortlisted_reference_summaries = []
    shortlisted_ids = []

    for idx, article in enumerate(articles):
        flan_fits = check_prompt_token_limits(article, flan_instructions, flan_tokenizer, token_limit)
        if flan_fits != required_flan_fits:
            # At least one FLAN prompt is too long; no need to tokenize for GPT.
            continue

        gpt_fits = check_prompt_token_limits(article, gpt_instructions, gpt_tokenizer, token_limit)
        if gpt_fits != required_gpt_fits:
            continue

        shortlisted_articles.append(article)
        shortlisted_reference_summaries.append(reference_summaries[idx])
        shortlisted_ids.append(ids[idx] if has_ids else str(uuid.uuid4()))

    return shortlisted_articles, shortlisted_reference_summaries, shortlisted_ids

def preprocess_html_tags(article):
    """Return ``article`` with every ``<...>`` span removed.

    The match is non-greedy and ``.`` does not cross newlines, so a tag that
    is split over several lines is left untouched.
    """
    tag_pattern = re.compile('<.*?>')
    return tag_pattern.sub('', article)

def make_sample_chemsum(data_path, token_limit = 4096, cache_dir = '/scratch/ramprasad.sa/huggingface_datasets'):
    """Build the shortlisted ChemSum test sample and save it as a CSV.

    Loads the ``test`` split of ``griffin/ChemSum``, strips ``<...>`` tags from
    the ``sections`` field, keeps the articles accepted by
    ``get_shortlisted_data`` and writes the result to
    ``<data_path>/test_sample.csv`` (the DataFrame index is written too).

    Args:
        data_path: Output directory; created if it does not exist.
        token_limit: Token limit forwarded to ``get_shortlisted_data``.
        cache_dir: Hugging Face datasets cache directory passed to
            ``load_dataset``.

    Returns:
        pandas.DataFrame: Columns ``article``, ``reference_summary``, ``id``
        and ``origin`` (always ``'chemsum'``).
    """
    dataset = load_dataset("griffin/ChemSum", split = 'test', cache_dir = cache_dir)
    articles = [preprocess_html_tags(each) for each in dataset['sections']]
    reference_summaries = dataset['abstract']
    ids = dataset['uuid']

    # Forward token_limit explicitly; otherwise the callee silently falls
    # back to its own default and this function's argument has no effect.
    shortlisted_articles, shortlisted_reference_summaries, shortlisted_ids = get_shortlisted_data(
        articles, reference_summaries, ids, token_limit = token_limit, dataset = 'chemsum')

    shortlisted_data = {
        'article': list(shortlisted_articles),
        'reference_summary': list(shortlisted_reference_summaries),
        'id': list(shortlisted_ids),
        'origin': ['chemsum'] * len(shortlisted_ids),
    }

    os.makedirs(data_path, exist_ok = True)

    df = pd.DataFrame(shortlisted_data)
    df.to_csv(os.path.join(data_path, 'test_sample.csv'))
    return df

In [4]:
# Build the ChemSum test sample; the result is written to test_sample.csv
# inside the directory passed here and also returned as a DataFrame.
# NOTE(review): absolute, user-specific path; consider a configurable data directory.
make_sample_chemsum('/home/ramprasad.sa/factual_annotation_llm_summaries/datasets/chemsum')

Found cached dataset parquet (/scratch/ramprasad.sa/huggingface_datasets/griffin___parquet/griffin--ChemSum-f0741275208cf814/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


KeyError: 'chemsum_flan_t5'

'bi-mediated_allylation_of_aldehydes_in_[bmim][br]:_a_mechanistic_investigation'

In [2]:
# Inspect the instruction templates. get_shortlisted_data looks them up as
# '<dataset>_flant5' and '<dataset>_gpt3'.
# NOTE(review): the output recorded below mixes 'news_flan_t5' with
# 'pubmed_flant5' / 'chemsum_flant5', so dataset='news' would raise KeyError
# on the FLAN lookup; confirm the intended key spelling in config.
config.instructions

{'news_gpt3': {'Generic_summary': 'Summarize the above article',
  'Faithful_summary': 'Summarize the above article such that all the information in the summary is supported by the article'},
 'news_flan_t5': {'Generic_summary': 'Summarize the above article',
  'Faithful_summary': 'Summarize the above article such that all the information in the summary is supported by the article'},
 'pubmed_gpt3': {'Generic_summary': 'Summarize the above article',
  'Faithful_summary': 'Summarize the above article such that all the information in the summary is supported by the article'},
 'pubmed_flant5': {'Generic_summary': 'Summarize the above article',
  'Faithful_summary': 'Summarize the above article such that all the information in the summary is supported by the article'},
 'chemsum_gpt3': {'Generic_summary': 'Summarize the above article',
  'Faithful_summary': 'Summarize the above article such that all the information in the summary is supported by the article'},
 'chemsum_flant5': {'Generic