In [1]:
import os
import argparse
import json
import math
import os
import random
from pprint import pformat

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import nltk
import datasets
import evaluate

import transformers
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoModelForCausalLM,
    AutoTokenizer,
    SchedulerType,
    get_scheduler,
    set_seed,
    DataCollatorForLanguageModeling,
)

from accelerate import Accelerator
from accelerate.utils import set_seed
from datasets import load_dataset

import wandb
from tqdm.auto import tqdm, trange
from loguru import logger

import scripts
from adapters.models.llama.adapter_model import LlamaAdapterModel
import peft_comparison
import peft_comparison.text2text_utils
import peft_comparison.mappings
from peft_comparison.collation import DataCollatorForSeq2SeqWithMetadata, DataCollatorForCausalLMWithMetadata
from peft_comparison.tokenization_llama_fast import LlamaTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
padding = "max_length"
truncation = True

source_prefix = ""
max_source_length = 512
decoder_only = True
max_target_length = 512

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})

0

In [6]:
def preprocess_function(examples, is_eval=False, decoder_only=False):
    inputs = examples["source_text"]
    targets = examples["target_text"]
    inputs = [source_prefix + inp for inp in inputs]

    if not decoder_only:
        model_inputs = tokenizer(inputs, max_length=max_source_length, padding=padding, truncation=truncation)
        labels = tokenizer(text_target=targets, max_length=max_target_length, padding=padding, truncation=truncation)
        if padding == "max_length":
            labels["input_ids"] = [
                [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
            ]
        model_inputs["labels"] = labels["input_ids"]

        if is_eval:
            model_inputs["metadata"] = [{"targets": t} for t in targets]
    else:
        model_inputs = tokenizer(inputs, targets, max_length=max_source_length, padding=padding, truncation=True)
        if is_eval:
            tokenized_source_text = tokenizer(inputs, max_length=max_source_length, padding=False, truncation=False)
            model_inputs["metadata"] = [{"input_len": len(i)} for i in tokenized_source_text["input_ids"]]
        #model_inputs["labels"] = []
        #model_inputs["input_len"] = []
        """
        for idx, example in enumerate(model_inputs["input_ids"]):
            label = [tokenizer.pad_token_id] * len(example)
            label[0:len(example)-1] = example[1:]
            label[label == tokenizer.pad_token_id] = -100
            model_inputs["labels"].append(label)
            #model_inputs["input_len"].append(len(example))

            for i in label:
                assert type(i) == int
            for i in example:
                assert type(i) == int
            assert type(label) == list
            assert type(example) == list
            #print(type(model_inputs["input_ids"]))
            #print(example)
            #print(label)
        """
        
    #

    return model_inputs

In [7]:
raw_datasets = load_dataset("super_glue", "boolq")
raw_datasets, postprocess_fn = peft_comparison.text2text_utils.dataset_to_text2text(
    raw_datasets,
    task_type="classification",
    dataset_name="boolq",
)
column_names = list(raw_datasets["train"].column_names)

In [8]:
eval_dataset = raw_datasets["validation"].map(
    preprocess_function,
    batched=True,
    num_proc=8,
    remove_columns=column_names,
    desc="Running tokenizer on val dataset  ",
    fn_kwargs={"is_eval": True, "decoder_only": decoder_only},
)
train_dataset = raw_datasets["train"].map(
    preprocess_function,
    batched=True,
    batch_size=min(5000, len(raw_datasets["train"]) // 8),
    num_proc=8,
    remove_columns=column_names,
    desc="Running tokenizer on train dataset",
    fn_kwargs={"decoder_only": decoder_only}
)

Running tokenizer on val dataset   (num_proc=8): 100%|██████████| 3270/3270 [00:00<00:00, 5571.70 examples/s]
Running tokenizer on train dataset (num_proc=8): 100%|██████████| 9427/9427 [00:00<00:00, 9973.09 examples/s] 


In [9]:
for i in range(0, len(train_dataset)):
    if train_dataset[i]["input_ids"].__len__() != 512:
        print(i)
        print(train_dataset[i]["input_ids"].__len__())

In [None]:
train_dataset[1178]

In [11]:
data_collator = DataCollatorForCausalLMWithMetadata(
    tokenizer=tokenizer,
    padding=padding,
    max_length=max_source_length,
    pad_to_multiple_of=8,
)

train_dataloader = DataLoader(train_dataset, shuffle=True, collate_fn=data_collator, batch_size=2)
eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=2)

In [13]:
for batch_idx, batch in enumerate(train_dataloader):
    if batch_idx == 0:
        
        logger.info("============= CHECKING FIRST BATCH =============")
        logger.info("\nTensor shapes: ")
        logger.info(batch["input_ids"].shape)

        logger.info("\nFirst example in tensor: ")
        logger.info(batch["input_ids"][0, :])

        print(sum(sum(batch["input_ids"] > 32000)))

        logger.info("\nDecoded text of first example in the batch:")
        s_text = tokenizer.batch_decode(batch["input_ids"][0, :].unsqueeze(0), skip_special_tokens=False)
        logger.info(f"Source text: {s_text}")

[32m2023-10-27 00:20:00.773[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m5[0m - [1m
Tensor shapes: [0m
[32m2023-10-27 00:20:00.775[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1mtorch.Size([2, 512])[0m
[32m2023-10-27 00:20:00.776[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m8[0m - [1m
First example in tensor: [0m
[32m2023-10-27 00:20:00.778[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m9[0m - [1mtensor([    1,  6120, 29939, 13382, 29901,   341, 24495, 10969, 29879,  1192,
         6811,   278,  2440, 29892,   341, 24495, 16692,   278,   323,  4727,
        29899, 29933,  2782, 29892,  8922,   575, 29892, 28618, 21542,   300,
        29892,   322,   612,   538, 29899,  2517,  1506,  4167,   322, 29914,
          272, 14582, 29889,  1139, 29901,   526,  3147, 29891,   289,  2782,
          322, 13630, 13840,   300,   278,  1021,     1,  1565,     2,     2,
            2,     2,     2

[32m2023-10-27 00:20:00.796[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m13[0m - [1m
Decoded text of first example in the batch:[0m
[32m2023-10-27 00:20:00.800[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m15[0m - [1mSource text: ['<s> boolq passage: MTD Products -- Over the years, MTD acquired the Troy-Bilt, Bolens, Cub Cadet, and Yard-Man brands and/or companies. question: are troy bilt and cub cadet the same<s> true</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></

tensor(0)


In [14]:
for batch_idx, batch in enumerate(eval_dataloader):
    if batch_idx == 0:
        
        logger.info("============= CHECKING FIRST BATCH =============")
        logger.info("\nTensor shapes: ")
        logger.info(batch["input_ids"].shape)

        logger.info("\nFirst example in tensor: ")
        logger.info(batch["input_ids"][0, :])

        print(sum(sum(batch["input_ids"] > 32000)))

        logger.info("\nDecoded text of first example in the batch:")
        s_text = tokenizer.batch_decode(batch["input_ids"][0, :].unsqueeze(0), skip_special_tokens=False)
        logger.info(f"Source text: {s_text}")

[32m2023-10-27 00:20:12.098[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m5[0m - [1m
Tensor shapes: [0m
[32m2023-10-27 00:20:12.099[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1mtorch.Size([2, 512])[0m
[32m2023-10-27 00:20:12.101[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m8[0m - [1m
First example in tensor: [0m
[32m2023-10-27 00:20:12.102[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m9[0m - [1mtensor([    1,  6120, 29939, 13382, 29901, 13772,   273,   324, 26413,  1192,
         2178,  4768,   290,   465,  5771,  1549,   472,  3203,   777,   310,
         1438,  6576, 29901,   372,  4225,   304,   367, 21633, 29892, 16531,
        29892,   270,  1255, 29892,  6013,   358,   287, 29892,  1320, 24455,
        29892,   322, 12138,   287, 29889,  2178,   310,  1438,  6576,  1996,
         7788,   322,   385, 22035, 12425, 29889,   450,  3001,  5253,   310,
         5864,  1881,   964

tensor(0)
