In [1]:
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from peft import (
    get_peft_config,
    get_peft_model,
    get_peft_model_state_dict,
    set_peft_model_state_dict,
    LoraConfig,
    PeftType,
    PrefixTuningConfig,
    PromptEncoderConfig,
)

import evaluate
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
from tqdm import tqdm

In [2]:
batch_size = 2
model_name_or_path = "distilbert/distilgpt2"
peft_type = PeftType.LORA
num_epochs = 20

In [3]:
from evaluate import load, Metric
# Define the evaluation metric
metric = load("perplexity")

In [4]:
# original LORA config copied from hugging face repo .
# task type changed to Seq generation 
peft_config = LoraConfig(task_type="SEQ_GEN", inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.1)
lr = 3e-4

In [5]:
if any(k in model_name_or_path for k in ("gpt", "opt", "bloom")):
    padding_side = "left"
else:
    padding_side = "right"

padding_side

'left'

In [6]:
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side=padding_side)
if getattr(tokenizer, "pad_token_id") is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id



In [7]:
dataset = load_dataset("json", data_files={"train" : "email.jsonl","test" : "email2.jsonl" })

In [8]:
def tokenize_function(examples):
    # max_length=None => use the model max length (it's actually the default)
    outputs = tokenizer(examples["text"], padding=True,truncation=True)
    outputs["labels"] = outputs["input_ids"].copy()
    return outputs

In [9]:
tokenized_datasets = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=["type","text"]
)

Map:   0%|          | 0/6 [00:00<?, ? examples/s]

In [10]:
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 6
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 6
    })
})

In [11]:
def collate_fn(examples):
    return tokenizer.pad(examples, padding="longest", return_tensors="pt")

In [12]:
# Instantiate dataloaders.
train_dataloader = DataLoader(tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size)
eval_dataloader = DataLoader(
    tokenized_datasets["test"], shuffle=False, collate_fn=collate_fn, batch_size=batch_size
)

In [13]:
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, return_dict=True)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
model

trainable params: 147,456 || all params: 82,060,032 || trainable%: 0.1797




PeftModel(
  (base_model): LoraModel(
    (model): GPT2LMHeadModel(
      (transformer): GPT2Model(
        (wte): Embedding(50257, 768)
        (wpe): Embedding(1024, 768)
        (drop): Dropout(p=0.1, inplace=False)
        (h): ModuleList(
          (0-5): 6 x GPT2Block(
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (attn): GPT2Attention(
              (c_attn): lora.Linear(
                (base_layer): Conv1D()
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=768, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=2304, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (c_proj):

In [14]:
optimizer = AdamW(params=model.parameters(), lr=lr)

In [15]:
# Instantiate scheduler
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0.06 * (len(train_dataloader) * num_epochs),
    num_training_steps=(len(train_dataloader) * num_epochs),
)

In [16]:
model.to(torch.device("cpu"))

PeftModel(
  (base_model): LoraModel(
    (model): GPT2LMHeadModel(
      (transformer): GPT2Model(
        (wte): Embedding(50257, 768)
        (wpe): Embedding(1024, 768)
        (drop): Dropout(p=0.1, inplace=False)
        (h): ModuleList(
          (0-5): 6 x GPT2Block(
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (attn): GPT2Attention(
              (c_attn): lora.Linear(
                (base_layer): Conv1D()
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=768, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=2304, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (c_proj):

In [18]:
for epoch in range(num_epochs):
    model.train()
    for step, batch in enumerate(tqdm(train_dataloader)):
        batch.to(torch.device("cpu"))
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

    model.eval()
    for step, batch in enumerate(tqdm(eval_dataloader)):
        batch.to(torch.device("cpu"))
        with torch.no_grad():
            outputs = model(**batch)
        predictions = outputs.logits.argmax(dim=-1)
        # Decode the predictions to actual text
        predicted_texts = [tokenizer.decode(pred, skip_special_tokens=True) for pred in predictions]

        # Print the predicted texts neatly
        for idx, text in enumerate(predicted_texts, 1):
          print(f"{idx}. {text}")
          
        predictions, references = predictions, batch["labels"]
        '''
        metric.add_batch(
            predictions=predictions,
            references=references,
        )
        '''

    #eval_metric = metric.compute()
    #print(f"epoch {epoch}:", eval_metric)

100%|██████████| 3/3 [00:00<00:00,  5.56it/s]
 67%|██████▋   | 2/3 [00:00<00:00,  8.91it/s]

1. 
p
2. 
p of the payment to payment the coprorate
1. 
pgo on az cloud as of i i is can be hosted cloud platforms too
2. 
p in pipeline


100%|██████████| 3/3 [00:00<00:00,  8.48it/s]


1. , Most of customers of p are having F payment rail
2.  The,.. beulyregate theize, the aED.- it the means standards are be to. the will will be replaced.


100%|██████████| 3/3 [00:00<00:00,  3.26it/s]
 67%|██████▋   | 2/3 [00:00<00:00,  6.67it/s]

1. 
p
2. 
p of the payment to payment the coprorate
1. 
pgo on az cloud as of i i is can be hosted cloud platforms too
2. 
p in pipeline


100%|██████████| 3/3 [00:00<00:00,  6.22it/s]


1. , Most of customers of p are having F payment rail
2.  The,.. beulyregate theize, the aED.- it the means standards are be to. the will will be replaced.


100%|██████████| 3/3 [00:00<00:00,  3.01it/s]
 67%|██████▋   | 2/3 [00:00<00:00,  5.61it/s]

1. 
p
2. 
p of the payment to payment the coprorate
1. 
pgo on az cloud as of i i is can be hosted cloud platforms too
2. 
p in pipeline


100%|██████████| 3/3 [00:00<00:00,  5.73it/s]


1. , Most of customers of p are having F payment rail
2.  The,.. beulyregate theize, the aED.- it the means standards are be to. the will will be replaced.


100%|██████████| 3/3 [00:00<00:00,  3.38it/s]
 67%|██████▋   | 2/3 [00:00<00:00,  6.77it/s]

1. 
p
2. 
p of the payment to payment the coprorate
1. 
pgo on az cloud as of i i is can be hosted cloud platforms too
2. 
p in pipeline


100%|██████████| 3/3 [00:00<00:00,  6.48it/s]


1. , Most of customers of p are having F payment rail
2.  The,.. beulyregate theize, the aED.- it the means standards are be to. the will will be replaced.


100%|██████████| 3/3 [00:00<00:00,  3.06it/s]
 67%|██████▋   | 2/3 [00:00<00:00,  4.96it/s]

1. 
p
2. 
p of the payment to payment the coprorate
1. 
pgo on az cloud as of i i is can be hosted cloud platforms too
2. 
p in pipeline


100%|██████████| 3/3 [00:00<00:00,  4.95it/s]


1. , Most of customers of p are having F payment rail
2.  The,.. beulyregate theize, the aED.- it the means standards are be to. the will will be replaced.


100%|██████████| 3/3 [00:00<00:00,  3.10it/s]
 67%|██████▋   | 2/3 [00:00<00:00,  6.62it/s]

1. 
p
2. 
p of the payment to payment the coprorate
1. 
pgo on az cloud as of i i is can be hosted cloud platforms too
2. 
p in pipeline


100%|██████████| 3/3 [00:00<00:00,  6.42it/s]


1. , Most of customers of p are having F payment rail
2.  The,.. beulyregate theize, the aED.- it the means standards are be to. the will will be replaced.


100%|██████████| 3/3 [00:00<00:00,  3.72it/s]
 67%|██████▋   | 2/3 [00:00<00:00,  6.94it/s]

1. 
p
2. 
p of the payment to payment the coprorate
1. 
pgo on az cloud as of i i is can be hosted cloud platforms too
2. 
p in pipeline


100%|██████████| 3/3 [00:00<00:00,  6.99it/s]


1. , Most of customers of p are having F payment rail
2.  The,.. beulyregate theize, the aED.- it the means standards are be to. the will will be replaced.


100%|██████████| 3/3 [00:00<00:00,  3.01it/s]
 67%|██████▋   | 2/3 [00:00<00:00,  6.32it/s]

1. 
p
2. 
p of the payment to payment the coprorate
1. 
pgo on az cloud as of i i is can be hosted cloud platforms too
2. 
p in pipeline


100%|██████████| 3/3 [00:00<00:00,  6.33it/s]


1. , Most of customers of p are having F payment rail
2.  The,.. beulyregate theize, the aED.- it the means standards are be to. the will will be replaced.


100%|██████████| 3/3 [00:00<00:00,  3.30it/s]
 67%|██████▋   | 2/3 [00:00<00:00,  6.54it/s]

1. 
p
2. 
p of the payment to payment the coprorate
1. 
pgo on az cloud as of i i is can be hosted cloud platforms too
2. 
p in pipeline


100%|██████████| 3/3 [00:00<00:00,  6.45it/s]


1. , Most of customers of p are having F payment rail
2.  The,.. beulyregate theize, the aED.- it the means standards are be to. the will will be replaced.


100%|██████████| 3/3 [00:00<00:00,  3.04it/s]
 67%|██████▋   | 2/3 [00:00<00:00,  5.77it/s]

1. 
p
2. 
p of the payment to payment the coprorate
1. 
pgo on az cloud as of i i is can be hosted cloud platforms too
2. 
p in pipeline


100%|██████████| 3/3 [00:00<00:00,  5.81it/s]


1. , Most of customers of p are having F payment rail
2.  The,.. beulyregate theize, the aED.- it the means standards are be to. the will will be replaced.


100%|██████████| 3/3 [00:00<00:00,  3.12it/s]
 67%|██████▋   | 2/3 [00:00<00:00,  6.60it/s]

1. 
p
2. 
p of the payment to payment the coprorate
1. 
pgo on az cloud as of i i is can be hosted cloud platforms too
2. 
p in pipeline


100%|██████████| 3/3 [00:00<00:00,  5.95it/s]


1. , Most of customers of p are having F payment rail
2.  The,.. beulyregate theize, the aED.- it the means standards are be to. the will will be replaced.


100%|██████████| 3/3 [00:00<00:00,  3.46it/s]
 67%|██████▋   | 2/3 [00:00<00:00,  6.71it/s]

1. 
p
2. 
p of the payment to payment the coprorate
1. 
pgo on az cloud as of i i is can be hosted cloud platforms too
2. 
p in pipeline


100%|██████████| 3/3 [00:00<00:00,  6.83it/s]


1. , Most of customers of p are having F payment rail
2.  The,.. beulyregate theize, the aED.- it the means standards are be to. the will will be replaced.


100%|██████████| 3/3 [00:00<00:00,  3.82it/s]
 67%|██████▋   | 2/3 [00:00<00:00,  5.45it/s]

1. 
p
2. 
p of the payment to payment the coprorate
1. 
pgo on az cloud as of i i is can be hosted cloud platforms too
2. 
p in pipeline


100%|██████████| 3/3 [00:00<00:00,  5.44it/s]


1. , Most of customers of p are having F payment rail
2.  The,.. beulyregate theize, the aED.- it the means standards are be to. the will will be replaced.


100%|██████████| 3/3 [00:00<00:00,  3.26it/s]
 67%|██████▋   | 2/3 [00:00<00:00,  6.27it/s]

1. 
p
2. 
p of the payment to payment the coprorate
1. 
pgo on az cloud as of i i is can be hosted cloud platforms too
2. 
p in pipeline


100%|██████████| 3/3 [00:00<00:00,  6.06it/s]


1. , Most of customers of p are having F payment rail
2.  The,.. beulyregate theize, the aED.- it the means standards are be to. the will will be replaced.


100%|██████████| 3/3 [00:00<00:00,  3.38it/s]
 67%|██████▋   | 2/3 [00:00<00:00,  5.69it/s]

1. 
p
2. 
p of the payment to payment the coprorate
1. 
pgo on az cloud as of i i is can be hosted cloud platforms too
2. 
p in pipeline


100%|██████████| 3/3 [00:00<00:00,  5.96it/s]


1. , Most of customers of p are having F payment rail
2.  The,.. beulyregate theize, the aED.- it the means standards are be to. the will will be replaced.


100%|██████████| 3/3 [00:00<00:00,  3.57it/s]
 67%|██████▋   | 2/3 [00:00<00:00,  6.10it/s]

1. 
p
2. 
p of the payment to payment the coprorate
1. 
pgo on az cloud as of i i is can be hosted cloud platforms too
2. 
p in pipeline


100%|██████████| 3/3 [00:00<00:00,  5.39it/s]


1. , Most of customers of p are having F payment rail
2.  The,.. beulyregate theize, the aED.- it the means standards are be to. the will will be replaced.


100%|██████████| 3/3 [00:00<00:00,  3.14it/s]
 67%|██████▋   | 2/3 [00:00<00:00,  5.85it/s]

1. 
p
2. 
p of the payment to payment the coprorate
1. 
pgo on az cloud as of i i is can be hosted cloud platforms too
2. 
p in pipeline


100%|██████████| 3/3 [00:00<00:00,  6.05it/s]


1. , Most of customers of p are having F payment rail
2.  The,.. beulyregate theize, the aED.- it the means standards are be to. the will will be replaced.


100%|██████████| 3/3 [00:00<00:00,  3.50it/s]
 67%|██████▋   | 2/3 [00:00<00:00,  6.75it/s]

1. 
p
2. 
p of the payment to payment the coprorate
1. 
pgo on az cloud as of i i is can be hosted cloud platforms too
2. 
p in pipeline


100%|██████████| 3/3 [00:00<00:00,  6.38it/s]


1. , Most of customers of p are having F payment rail
2.  The,.. beulyregate theize, the aED.- it the means standards are be to. the will will be replaced.


100%|██████████| 3/3 [00:00<00:00,  3.39it/s]
 67%|██████▋   | 2/3 [00:00<00:00,  7.34it/s]

1. 
p
2. 
p of the payment to payment the coprorate
1. 
pgo on az cloud as of i i is can be hosted cloud platforms too
2. 
p in pipeline


100%|██████████| 3/3 [00:00<00:00,  6.64it/s]


1. , Most of customers of p are having F payment rail
2.  The,.. beulyregate theize, the aED.- it the means standards are be to. the will will be replaced.


100%|██████████| 3/3 [00:00<00:00,  3.12it/s]
 67%|██████▋   | 2/3 [00:00<00:00,  6.82it/s]

1. 
p
2. 
p of the payment to payment the coprorate
1. 
pgo on az cloud as of i i is can be hosted cloud platforms too
2. 
p in pipeline


100%|██████████| 3/3 [00:00<00:00,  6.68it/s]

1. , Most of customers of p are having F payment rail
2.  The,.. beulyregate theize, the aED.- it the means standards are be to. the will will be replaced.



