In [1]:
import os
import argparse
import json
import math
import os
import random
from pprint import pformat

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import nltk
import datasets
import evaluate

import transformers
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoModelForCausalLM,
    AutoTokenizer,
    SchedulerType,
    get_scheduler,
    set_seed,
    DataCollatorForLanguageModeling,
)

from accelerate import Accelerator
from accelerate.utils import set_seed
from datasets import load_dataset

import wandb
from tqdm.auto import tqdm, trange
from loguru import logger

import scripts
from adapters.models.llama.adapter_model import LlamaAdapterModel
import peft_comparison
import peft_comparison.text2text_utils
import peft_comparison.mappings
from peft_comparison.collation import DataCollatorForSeq2SeqWithMetadata, DataCollatorForCausalLMWithMetadata
from peft_comparison.tokenization_llama_fast import LlamaTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
padding = "max_length"
truncation = True

source_prefix = ""
max_source_length = 512
decoder_only = True
max_target_length = 512

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})

0

In [3]:

# First we tokenize all the texts.
def preprocess_function(examples, is_eval=False, decoder_only=False):
    inputs = examples["source_text"]
    targets = examples["target_text"]
    inputs = [source_prefix + inp for inp in inputs]

    if not decoder_only:
        model_inputs = tokenizer(inputs, max_length=max_source_length, padding=padding, truncation=True)
        labels = tokenizer(text_target=targets, max_length=max_target_length, padding=padding, truncation=True)
        if padding == "max_length":
            labels["input_ids"] = [
                [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
            ]
        model_inputs["labels"] = labels["input_ids"]
        if is_eval:
            model_inputs["metadata"] = [{"targets": t} for t in targets]

    else:
        model_inputs = tokenizer(inputs, targets, max_length=max_source_length, padding=padding, truncation=True)
        model_inputs["labels"] = model_inputs["input_ids"]
        if is_eval:
            input_wo_label = tokenizer(inputs, max_length=max_source_length, padding=False, truncation=False)
            input_wo_label = input_wo_label["input_ids"]
            model_inputs["metadata"] = []
            for idx in range(len(targets)):
                model_inputs["metadata"].append(
                    {
                        "targets": targets[idx],
                        "input_len": len(input_wo_label[idx]),
                    }
                )

    return model_inputs

In [4]:
raw_datasets = load_dataset("super_glue", "rte")
raw_datasets, postprocess_fn = peft_comparison.text2text_utils.dataset_to_text2text(
    raw_datasets,
    task_type="classification",
    dataset_name="rte",
)
column_names = raw_datasets["train"].column_names
padding = "max_length" #if pad_to_max_length else False

Map:   0%|          | 0/2490 [00:00<?, ? examples/s]


KeyError: 'passage'

In [None]:
raw_datasets["train"][0]

In [None]:
eval_dataset = raw_datasets["validation"].map(
    preprocess_function,
    batched=True,
    num_proc=8,
    remove_columns=column_names,
    desc="Running tokenizer on val dataset  ",
    fn_kwargs={"is_eval": True, "decoder_only": decoder_only},
)
train_dataset = raw_datasets["train"].map(
    preprocess_function,
    batched=True,
    batch_size=min(5000, len(raw_datasets["train"]) // 8),
    num_proc=8,
    remove_columns=column_names,
    desc="Running tokenizer on train dataset",
    fn_kwargs={"decoder_only": decoder_only},
)

In [None]:
tokenizer.decode(eval_dataset[0]["input_ids"])

In [None]:
for i in range(0, len(train_dataset)):
    if train_dataset[i]["input_ids"].__len__() != 512:
        print(i)
        print(train_dataset[i]["input_ids"].__len__())

In [None]:
train_dataset[1178]

In [None]:
data_collator = DataCollatorForCausalLMWithMetadata(
    tokenizer=tokenizer,
    padding=padding,
    max_length=max_source_length,
    pad_to_multiple_of=8,
)

train_dataloader = DataLoader(train_dataset, shuffle=True, collate_fn=data_collator, batch_size=2)
eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=2)

In [None]:
for batch_idx, batch in enumerate(train_dataloader):
    if batch_idx == 0:
        
        logger.info("============= CHECKING FIRST BATCH =============")
        logger.info("\nTensor shapes: ")
        logger.info(batch["input_ids"].shape)

        logger.info("\nFirst example in tensor: ")
        logger.info(batch["input_ids"][0, :])

        print(sum(sum(batch["input_ids"] > 32000)))

        logger.info("\nDecoded text of first example in the batch:")
        s_text = tokenizer.batch_decode(batch["input_ids"][0, :].unsqueeze(0), skip_special_tokens=False)
        logger.info(f"Source text: {s_text}")

In [None]:
for batch_idx, batch in enumerate(eval_dataloader):
    if batch_idx == 0:
        
        logger.info("============= CHECKING FIRST BATCH =============")
        logger.info("\nTensor shapes: ")
        logger.info(batch["input_ids"].shape)

        logger.info("\nFirst example in tensor: ")
        logger.info(batch["input_ids"][0, :])

        print(sum(sum(batch["input_ids"] > 32000)))

        logger.info("\nDecoded text of first example in the batch:")
        s_text = tokenizer.batch_decode(batch["input_ids"][0, :].unsqueeze(0), skip_special_tokens=False)
        logger.info(f"Source text: {s_text}")