In [1]:
import copy
from datasets import load_dataset, load_from_disk

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM

# Single source of truth for the checkpoint id, so the tokenizer and the model
# cannot accidentally be loaded from two different repositories.
MODEL_NAME = "meta-llama/Llama-3.2-1B"

# Load the tokenizer and the causal-LM weights for the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

In [3]:
# Load the dataset from a directory given relative to the notebook's working
# directory (the path must already exist on disk).
ds = load_from_disk('flan2022/flan2021')

In [4]:
# Alias the loaded dataset under the name used by the preprocessing cells.
raw_dataset = ds

In [6]:
# Preprocessing configuration consumed by the tokenization / map cell.

# Truncation limit (in tokens) handed to the tokenizer as `max_length`.
max_seq_length = 512

# Label value written over the prompt positions so they do not contribute to
# the loss. It has to be -100: that is the default `ignore_index` of
# torch.nn.CrossEntropyLoss, which Hugging Face causal-LM heads rely on.
# The previous value (10) is an ordinary vocabulary id, so every prompt
# position would have been trained to predict token 10 instead of being skipped.
IGNORE_INDEX = -100

# Number of worker processes passed to `map(num_proc=...)`.
preprocessing_num_workers = 12

# When True, `map` is told not to reuse cached results (load_from_cache_file=False).
overwrite_cache = True

# Columns to drop after tokenization so that only the model inputs remain.
# NOTE(review): if `raw_dataset` is a DatasetDict, `column_names` is a dict
# keyed by split rather than a flat list — confirm which type is loaded.
raw_dataset_column_names = raw_dataset.column_names

In [7]:
def preprocess_function(examples):
    """Tokenize a batch of prompt/response pairs and build loss-masked labels.

    Every prompt is joined to its response with a single space and tokenized
    with truncation to ``max_seq_length``. The prompts are tokenized a second
    time on their own, only to learn how many leading positions belong to the
    prompt. ``labels`` is a copy of ``input_ids`` in which those leading
    positions are replaced by ``IGNORE_INDEX``.

    Args:
        examples: Batch mapping that provides two equally long lists of
            strings under the keys ``"prompt"`` and ``"response"``.

    Returns:
        dict: every field produced by the tokenizer for the joined text, plus
        ``"labels"`` (one list per example, always the same length as the
        matching ``input_ids`` entry).

    Raises:
        KeyError: if ``"prompt"`` or ``"response"`` is absent from the batch.
            The message lists the columns that are actually present.
    """
    # Fail with an actionable message instead of a bare KeyError: 'prompt'.
    missing = [col for col in ("prompt", "response") if col not in examples]
    if missing:
        raise KeyError(
            f"preprocess_function requires columns {missing}, "
            f"but the batch only provides {list(examples.keys())}"
        )

    prompts = examples["prompt"]
    joined_texts = [p + " " + r for p, r in zip(prompts, examples["response"])]

    joined_tokenized = tokenizer(joined_texts, truncation=True, max_length=max_seq_length)
    prompts_tokenized = tokenizer(prompts, truncation=True, max_length=max_seq_length)

    all_labels = []
    for input_ids, prompt_ids in zip(joined_tokenized["input_ids"], prompts_tokenized["input_ids"]):
        # Clamp the mask length: assigning a longer run through a slice would
        # silently grow `labels` beyond `input_ids` whenever the prompt alone
        # tokenizes to more tokens than the joined text.
        masked = min(len(prompt_ids), len(input_ids))
        all_labels.append([IGNORE_INDEX] * masked + list(input_ids[masked:]))

    result = dict(joined_tokenized.items())
    result["labels"] = all_labels
    return result

# Run the tokenization over the whole dataset in batches, spread across worker
# processes, and drop the original columns so only the tokenizer fields and
# "labels" remain.
# NOTE(review): `remove_columns` receives `raw_dataset.column_names` (a flat
# list only for a single Dataset), while the lines below index the result by
# split name (which only works for a DatasetDict) — confirm which type
# `raw_dataset` actually is.
preprocessed_dataset = raw_dataset.map(
    function=preprocess_function,
    batched=True,
    remove_columns=raw_dataset_column_names,
    num_proc=preprocessing_num_workers,
    load_from_cache_file=(not overwrite_cache),
    desc="Preprocessing the raw dataset",
)

# Pick the training and evaluation splits out of the processed result.
train_dataset = preprocessed_dataset["train"]
eval_dataset = preprocessed_dataset["validation"]

Preprocessing the raw dataset (num_proc=12):   0%|          | 0/5362361 [00:00<?, ? examples/s]


KeyError: 'prompt'

In [1]:
import utils

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Fetch the task-index structure from the project-local `utils` module; the
# later cells pass it to the embedding, totals and (task, template) budget helpers.
task_indices=utils.get_task_indices()

Trying to load task indices from task_indices directory...


In [3]:
# Get task embeddings
# The helper returns a pair: `tasks` (shown in the next cell as a list of
# task-name strings) and `embeddings`; both are later passed to
# utils.get_task_budgets.
# NOTE(review): presumably embeddings[i] belongs to tasks[i] — confirm in utils.
tasks, embeddings= utils.get_task_embeddings(task_indices)

Trying to load task embeddings from task_embeddings directory...


In [4]:
# Display the task names (the bare last expression is rendered as cell output).
tasks

['snli:1.1.0',
 'multi_news:1.0.0',
 'wmt16_translate/fi-en:1.0.0',
 'glue/mnli:2.0.0',
 'wmt16_translate/cs-en:1.0.0',
 'drop:2.0.0',
 'gigaword:1.2.0',
 'squad/v1.1:3.0.0',
 'anli/r2:0.1.0',
 'word_segment',
 'gem/web_nlg_en:1.1.0',
 'cnn_dailymail:3.4.0',
 'opinion_abstracts_rotten_tomatoes',
 'super_glue/record:1.0.2',
 'wmt16_translate/de-en:1.0.0',
 'math_dataset/algebra__linear_1d:1.0.0',
 'paws_wiki:1.1.0',
 'newsroom:1.0.0',
 'wmt16_translate/ru-en:1.0.0',
 'gem/dart:1.1.0',
 'true_case',
 'wmt14_translate/fr-en:1.0.0',
 'huggingface:xsum',
 'super_glue/multirc:1.0.2',
 'quac:1.0.0',
 'sentiment140:1.0.0',
 'glue/qnli:2.0.0',
 'yelp_polarity_reviews:0.2.0',
 'coqa:1.0.0',
 'anli/r3:0.1.0',
 'gem/common_gen:1.1.0',
 'imdb_reviews/plain_text:1.0.0',
 'cosmos_qa:1.0.0',
 'fix_punct',
 'lambada:1.0.0',
 'bool_q:1.0.0',
 'piqa:1.0.0',
 'glue/cola:2.0.0',
 'para_crawl_enes',
 'winogrande:1.1.0',
 'story_cloze/2016:1.0.0',
 'natural_questions_open:1.0.0',
 'gem/wiki_lingua_english_en

In [7]:
# Get task totals
# The displayed result is a dict mapping each task name to an integer count.
# NOTE(review): presumably the number of instances available per task — confirm in utils.
task_totals=utils.get_task_totals(task_indices)
task_totals

Getting task totals...


{'snli:1.1.0': 86739,
 'multi_news:1.0.0': 47161,
 'wmt16_translate/fi-en:1.0.0': 87065,
 'glue/mnli:2.0.0': 173695,
 'wmt16_translate/cs-en:1.0.0': 86663,
 'drop:2.0.0': 86882,
 'gigaword:1.2.0': 86565,
 'squad/v1.1:3.0.0': 86241,
 'anli/r2:0.1.0': 87127,
 'word_segment': 86949,
 'gem/web_nlg_en:1.1.0': 86525,
 'cnn_dailymail:3.4.0': 86668,
 'opinion_abstracts_rotten_tomatoes': 9093,
 'super_glue/record:1.0.2': 86787,
 'wmt16_translate/de-en:1.0.0': 86728,
 'math_dataset/algebra__linear_1d:1.0.0': 87118,
 'paws_wiki:1.1.0': 87076,
 'newsroom:1.0.0': 86807,
 'wmt16_translate/ru-en:1.0.0': 86864,
 'gem/dart:1.1.0': 86622,
 'true_case': 84874,
 'wmt14_translate/fr-en:1.0.0': 87458,
 'huggingface:xsum': 86924,
 'super_glue/multirc:1.0.2': 78361,
 'quac:1.0.0': 86877,
 'sentiment140:1.0.0': 86656,
 'glue/qnli:2.0.0': 87249,
 'yelp_polarity_reviews:0.2.0': 86288,
 'coqa:1.0.0': 20605,
 'anli/r3:0.1.0': 87133,
 'gem/common_gen:1.1.0': 86753,
 'imdb_reviews/plain_text:1.0.0': 71470,
 'cosmos_

In [10]:
# Get task budgets
# NOTE(review): the recorded output of this cell gives 24998 of the 25000
# instances to a single task and 2 to another — confirm that such a skewed
# split is intended.
submod_fnc_tasks = 'fl'  # selector for the submodular function; this run's log reports a facility location function
num_tasks = 8  # number of tasks to select (the recorded progress bar shows 8 iterations)
num_instances = 25000  # total instance budget; the run's log reports a subset of 25000 instances
task_budgets=utils.get_task_budgets(tasks, task_totals, embeddings, submod_fnc_tasks, num_tasks, num_instances)
task_budgets

Getting task budgets...
Creating similarity kernel...
Instantiating facility location function...
Running the lazy greedy algorithm...
The total number of instances in the selected tasks is 133007. The total number of instances in the subset is 25000.
Applying Taylor softmax on gains...
Creating a budget split according to the probabilities computed...
Checking if we need to redistribute budget because some tasks may be assigned a budget more than the number of instances in the task
Need to redistribute a budget of 0. Redistributing...
Setting the budget of the unselected tasks to 0...


[||||||||||||||||||||]100% [Iteration 8 of 8]

{'anli/r3:0.1.0': 24998,
 'task1007_pib_translation_english_punjabi': 2,
 'task091_all_elements_from_index_i_to_j': 0,
 'task1691_qed_amara_translation': 0,
 'task519_aquamuse_question_generation': 0,
 'task493_review_polarity_classification': 0,
 'task782_pawsx_english_japanese_translation': 0,
 'task663_global_voices_en_fa_translation': 0,
 'snli:1.1.0': 0,
 'multi_news:1.0.0': 0,
 'wmt16_translate/fi-en:1.0.0': 0,
 'glue/mnli:2.0.0': 0,
 'wmt16_translate/cs-en:1.0.0': 0,
 'drop:2.0.0': 0,
 'gigaword:1.2.0': 0,
 'squad/v1.1:3.0.0': 0,
 'anli/r2:0.1.0': 0,
 'word_segment': 0,
 'gem/web_nlg_en:1.1.0': 0,
 'cnn_dailymail:3.4.0': 0,
 'opinion_abstracts_rotten_tomatoes': 0,
 'super_glue/record:1.0.2': 0,
 'wmt16_translate/de-en:1.0.0': 0,
 'math_dataset/algebra__linear_1d:1.0.0': 0,
 'paws_wiki:1.1.0': 0,
 'newsroom:1.0.0': 0,
 'wmt16_translate/ru-en:1.0.0': 0,
 'gem/dart:1.1.0': 0,
 'true_case': 0,
 'wmt14_translate/fr-en:1.0.0': 0,
 'huggingface:xsum': 0,
 'super_glue/multirc:1.0.2': 0,

In [11]:
# Get (task, template) budgets
# The displayed result is nested: collection name (e.g. 'flan2021') -> task
# name -> list of integers.
# NOTE(review): presumably one list entry per template of that task — confirm in utils.
task_template_budgets=utils.get_task_template_budgets(task_indices, task_budgets)
task_template_budgets

Getting budget for each (task, template) pair...
Processing flan2021


  0%|          | 0/70 [00:00<?, ?it/s]

Processing t0


100%|██████████| 70/70 [00:00<00:00, 23409.45it/s]


Processing niv2


100%|██████████| 193/193 [00:00<00:00, 56293.51it/s]


Processing cot


100%|██████████| 1556/1556 [00:00<00:00, 149982.47it/s]


Processing dialog


100%|██████████| 18/18 [00:00<00:00, 7871.70it/s]
100%|██████████| 4/4 [00:00<00:00, 5966.29it/s]


{'flan2021': {'snli:1.1.0': [0, 0, 0, 0],
  'multi_news:1.0.0': [0, 0, 0, 0],
  'wmt16_translate/fi-en:1.0.0': [0, 0, 0, 0],
  'glue/mnli:2.0.0': [0, 0, 0, 0],
  'wmt16_translate/cs-en:1.0.0': [0, 0, 0, 0],
  'drop:2.0.0': [0, 0, 0, 0],
  'gigaword:1.2.0': [0, 0, 0, 0],
  'squad/v1.1:3.0.0': [0, 0, 0, 0],
  'anli/r2:0.1.0': [0, 0, 0, 0],
  'word_segment': [0, 0, 0, 0],
  'gem/web_nlg_en:1.1.0': [0, 0, 0, 0],
  'cnn_dailymail:3.4.0': [0, 0, 0, 0],
  'opinion_abstracts_rotten_tomatoes': [0, 0, 0, 0],
  'super_glue/record:1.0.2': [0, 0, 0, 0],
  'wmt16_translate/de-en:1.0.0': [0, 0, 0, 0],
  'math_dataset/algebra__linear_1d:1.0.0': [0, 0, 0, 0],
  'paws_wiki:1.1.0': [0, 0, 0, 0],
  'newsroom:1.0.0': [0, 0, 0, 0],
  'wmt16_translate/ru-en:1.0.0': [0, 0, 0, 0],
  'gem/dart:1.1.0': [0, 0, 0, 0],
  'true_case': [0, 0, 0, 0],
  'wmt14_translate/fr-en:1.0.0': [0, 0, 0, 0],
  'huggingface:xsum': [0, 0, 0, 0],
  'super_glue/multirc:1.0.2': [0, 0, 0, 0],
  'quac:1.0.0': [0, 0, 0, 0],
  'sentiment1

In [None]:


# NOTE(review): this cell has no execution count, and none of the names it
# calls (generate_submodular_ordering, save_submodular_ordering,
# load_instances_submodular_ordering, get_subset_indices, get_final_dataset)
# nor `args` are defined or imported in the visible cells, which only
# `import utils` — confirm where they are meant to come from before running.

# Generate and save the submodular ordering
submod_ordering = generate_submodular_ordering()
save_submodular_ordering(submod_ordering, args.submod_fnc_instances)

# Load the submodular ordering of instances in each task
# (this rebinds `submod_ordering`, replacing the value generated just above)
submod_ordering=load_instances_submodular_ordering(args.submod_fnc_instances)

# get a list of indices to select based on task_template_budgets
indices=get_subset_indices(submod_ordering, task_template_budgets, task_indices)

# prepare final dataset based on indices
dataset=get_final_dataset(indices)

# Sanity check: the train split must contain exactly the requested number of instances.
assert len(dataset["train"])==args.num_instances