In [15]:
import os
import uuid

import pandas as pd
import tiktoken
from datasets import load_dataset
from transformers import T5Tokenizer, T5ForConditionalGeneration

import pipeline_scripts.config as config

In [20]:

def check_prompt_token_limits(article, instructions, tokenizer, token_limit):
    """Count how many instruction prompts for ``article`` fit under ``token_limit``.

    For every value in ``instructions`` a prompt is built as the literal
    ``Article: `` prefix, the article text, a newline, and the instruction.
    The prompt is encoded with ``tokenizer.encode`` and counted when the
    number of encoded tokens is strictly less than ``token_limit``.

    Args:
        article: Article text to embed in each prompt.
        instructions: Mapping whose values are instruction strings; the keys
            are not used.
        tokenizer: Any object exposing ``encode(text)`` that returns a sized
            sequence of tokens.
        token_limit: Exclusive upper bound on the encoded prompt length.

    Returns:
        int: Number of instructions whose prompt is shorter than ``token_limit``.
    """
    def _fits(instruction_text):
        # Same prompt template as used for the length check everywhere else.
        encoded = tokenizer.encode(f'Article: {article}\n{instruction_text}')
        return len(encoded) < token_limit

    return sum(1 for instruction_text in instructions.values() if _fits(instruction_text))

def get_shortlisted_data(articles, reference_summaries, ids = None, token_limit = 4096):
    """Keep only the articles whose prompts fit the token limit for both models.

    Every article is checked against all instructions in
    ``config.instructions['xsum_flan_t5']`` (using the ``google/flan-t5-xl``
    tokenizer) and all instructions in ``config.instructions['xsum_gpt3']``
    (using the tiktoken encoding for ``gpt-3.5-turbo``). An article is kept
    only when every one of those prompts is shorter than ``token_limit``.

    Args:
        articles: Sequence of article texts.
        reference_summaries: Sequence of reference summaries, index-aligned
            with ``articles``.
        ids: Optional sequence of article ids, index-aligned with
            ``articles``. When omitted or empty, a random UUID4 string is
            generated for each kept article.
        token_limit: Exclusive upper bound on the encoded prompt length.

    Returns:
        tuple: ``(shortlisted_articles, shortlisted_reference_summaries,
        shortlisted_ids)`` - three lists of equal length, in input order.
    """
    flan_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xl")
    flan_instructions = config.instructions['xsum_flan_t5']

    gpt_tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
    gpt_instructions = config.instructions['xsum_gpt3']

    # An article qualifies only if *every* prompt of both models fits. Deriving
    # the target from the config (instead of a hard-coded 4) keeps the filter
    # correct if instructions are added to or removed from either set.
    required_fits = len(flan_instructions) + len(gpt_instructions)

    shortlisted_articles = []
    shortlisted_reference_summaries = []
    shortlisted_ids = []

    for idx, article in enumerate(articles):
        fitting_prompts = check_prompt_token_limits(article, flan_instructions, flan_tokenizer, token_limit)
        fitting_prompts += check_prompt_token_limits(article, gpt_instructions, gpt_tokenizer, token_limit)

        if fitting_prompts != required_fits:
            continue

        shortlisted_articles.append(article)
        shortlisted_reference_summaries.append(reference_summaries[idx])
        # Fall back to a fresh random id when the caller supplied none.
        shortlisted_ids.append(ids[idx] if ids else str(uuid.uuid4()))

    return shortlisted_articles, shortlisted_reference_summaries, shortlisted_ids

def make_sample_news(data_path, token_limit = 4096, cache_dir = '/scratch/ramprasad.sa/huggingface_datasets'):
    """Build and save the shortlisted news sample from CNN/DailyMail and XSum.

    Loads the ``test`` split of ``ccdv/cnn_dailymail`` (config ``3.0.0``) and
    of ``xsum``, filters each through ``get_shortlisted_data``, tags every
    kept row with its origin (``'cnndm'`` or ``'xsum'``), and writes the
    combined table to ``<data_path>/test_sample.csv``.

    Args:
        data_path: Output directory; created if it does not exist.
        token_limit: Prompt token limit forwarded to ``get_shortlisted_data``.
        cache_dir: Hugging Face datasets cache directory passed to
            ``load_dataset``.

    Returns:
        pandas.DataFrame: Columns ``article``, ``reference_summary``, ``id``
        and ``origin``, CNN/DM rows first, then XSum rows.
    """
    cnndm_dataset = load_dataset("ccdv/cnn_dailymail", '3.0.0', split = 'test', cache_dir = cache_dir)
    xsum_dataset = load_dataset("xsum", split = 'test', cache_dir = cache_dir)

    # (origin label, dataset, article column, reference-summary column)
    sources = [
        ('cnndm', cnndm_dataset, 'article', 'highlights'),
        ('xsum', xsum_dataset, 'document', 'summary'),
    ]

    shortlisted_data = {'article': [], 'reference_summary': [], 'id': [], 'origin': []}

    for origin, dataset, article_column, summary_column in sources:
        # token_limit is forwarded explicitly; previously the caller's value
        # was dropped and the callee's default was always used.
        kept_articles, kept_summaries, kept_ids = get_shortlisted_data(
            dataset[article_column],
            dataset[summary_column],
            dataset['id'],
            token_limit,
        )
        shortlisted_data['article'] += kept_articles
        shortlisted_data['reference_summary'] += kept_summaries
        shortlisted_data['id'] += kept_ids
        shortlisted_data['origin'] += [origin] * len(kept_ids)

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(data_path, exist_ok = True)

    df = pd.DataFrame(shortlisted_data)
    # The row index is written too, so the CSV layout stays as before.
    df.to_csv(os.path.join(data_path, 'test_sample.csv'))
    return df
 

In [21]:
# Build the shortlisted sample and write test_sample.csv under the given
# directory; the returned DataFrame is displayed as this cell's output.
make_sample_news('/home/ramprasad.sa/factual_annotation_llm_summaries/datasets/news')

Found cached dataset cnn_dailymail (/scratch/ramprasad.sa/huggingface_datasets/ccdv___cnn_dailymail/3.0.0/3.0.0/0107f7388b5c6fae455a5661bcd134fc22da53ea75852027040d8d1e997f101f)
Found cached dataset xsum (/scratch/ramprasad.sa/huggingface_datasets/xsum/default/1.2.0/082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71)
Token indices sequence length is longer than the specified maximum sequence length for this model (750 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (783 > 512). Running this sequence through the model will result in indexing errors


Unnamed: 0,article,reference_summary,id,origin
0,"(CNN)James Best, best known for his portrayal ...","James Best, who played the sheriff on ""The Duk...",00200e794fa41d3f7ce92cbf43e9fd4cd652bb09,cnndm
1,(CNN)The attorney for a suburban New York card...,A lawyer for Dr. Anthony Moschetto says the ch...,0021fe8d65bd0d6d76d5fefba2ac02f0c48a43f4,cnndm
2,(CNN)President Barack Obama took part in a rou...,"""No challenge poses more of a public threat th...",0041698b4463a633f912681b96f73648cb012e33,cnndm
3,Moscow (CNN)A Russian TV channel aired Hillary...,"Presidential hopeful's video, featuring gay co...",0095ce085581314285f894af73a55ea9ef003412,cnndm
4,(CNN)Marco Rubio is all in. The Republican se...,"Raul Reyes: In seeking Latino vote, Marco Rubi...",00a51d5454f2ef7dbf4c53471223a27fb9c20681,cnndm
...,...,...,...,...
22812,Amnesty International supporters are zipped in...,Two hundred body bags have been placed on Brig...,32411431,xsum
22813,The project was launched in Grenada by the pri...,Budding hospitality workers from the Caribbean...,38139638,xsum
22814,The world number two won 6-2 6-0 in 66 minutes...,Britain's Andy Murray is through to the Monte ...,36054448,xsum
22815,Mr Pistorius says he mistakenly shot his girlf...,A forensics expert has swung a cricket bat at ...,26541765,xsum


In [7]:
# Scratch check: uuid.uuid4() yields a UUID object, while str(...) gives the
# hyphenated string form that get_shortlisted_data uses as a fallback id.
import uuid
uuid.uuid4(), str(uuid.uuid4())

(UUID('d8aa8716-850b-4955-834c-f8769c5375b4'),
 'c224d676-8f12-4038-888e-55345263ef1b')