In [34]:
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
from peft import (
    get_peft_config,
    get_peft_model,
    get_peft_model_state_dict,
    set_peft_model_state_dict,
    LoraConfig,
    PeftType,
    PrefixTuningConfig,
    PromptEncoderConfig,
)
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm import tqdm


In [35]:
dataset = load_dataset("text", data_files={"train": ["dependency_score.txt"],"test": ["dependency_score_test.txt"]})

In [36]:
dataset

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 51
    })
    test: Dataset({
        features: ['text'],
        num_rows: 2
    })
})

In [37]:
batch_size = 2
model_name_or_path = "openai-community/gpt2"
peft_type = PeftType.LORA
num_epochs = 20

In [38]:
if any(k in model_name_or_path for k in ("gpt", "opt", "bloom")):
    padding_side = "left"
else:
    padding_side = "right"

In [39]:
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side=padding_side)
if getattr(tokenizer, "pad_token_id") is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

In [40]:
def tokenize_function(examples):
    # max_length=None => use the model max length (it's actually the default)
    outputs = tokenizer(examples["text"], padding=True,truncation=True)
    outputs["labels"] = outputs["input_ids"].copy()
    return outputs

In [41]:
tokenized_datasets = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=["text"]
)

In [42]:
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 51
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 2
    })
})

In [43]:
# it has 51 rows
(tokenized_datasets['train']).shape

(51, 3)

In [44]:
decoded_text = tokenizer.decode(tokenized_datasets['train'][0]['input_ids'], skip_special_tokens=True)

In [45]:
print(decoded_text)

Background


In [46]:
type(tokenized_datasets['train'])

datasets.arrow_dataset.Dataset

In [47]:
# try decode the encoded 51 rows . See if you see original text
for row in tokenized_datasets['train']:
  tokens = row['input_ids']
  decoded_text = tokenizer.decode(tokens, skip_special_tokens=True)
  print(f" decoded : {decoded_text}")

 decoded : Background
 decoded : Software product developed by a large organisation have inter team dependencies.
 decoded : These dependencies stall the flow of execution and paralyse the teams. Following teams are the actors in a cross team collaborated product
 decoded : 
 decoded : Team which develops application, focuses on functional aspects
 decoded : Infra team which provides infrastructure such as security, authentication, resilliency
 decoded : Devops team which provides access to resources such as kubernetes cluster, IDP's, active directory etc.
 decoded : Application development team gets stuck because of dependencies on infra and devops team. During execution, stakeholders (such as program managers, engineering directors ) are looking for input on the epics and requirement which has higher dependency on other teams.
 decoded : 
 decoded : This is an attempt to look at calculating dependency score of a epic. This will help managers to focus on high dependency epics also com

In [48]:
def collate_fn(examples):
    return tokenizer.pad(examples, padding="longest", return_tensors="pt")

In [49]:
# Instantiate dataloaders.
train_dataloader = DataLoader(tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size)
eval_dataloader = DataLoader(
    tokenized_datasets["test"], shuffle=False, collate_fn=collate_fn, batch_size=batch_size
)

In [50]:
# original LORA config copied from hugging face repo .
# task type changed to Seq generation 
peft_config = LoraConfig(task_type="SEQ_GEN", inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.1)
lr = 1e-5

In [51]:
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, return_dict=True)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
model

trainable params: 294,912 || all params: 124,734,720 || trainable%: 0.2364


PeftModel(
  (base_model): LoraModel(
    (model): GPT2LMHeadModel(
      (transformer): GPT2Model(
        (wte): Embedding(50257, 768)
        (wpe): Embedding(1024, 768)
        (drop): Dropout(p=0.1, inplace=False)
        (h): ModuleList(
          (0-11): 12 x GPT2Block(
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (attn): GPT2Attention(
              (c_attn): lora.Linear(
                (base_layer): Conv1D()
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=768, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=2304, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (c_proj

In [52]:
optimizer = AdamW(params=model.parameters(), lr=lr)

In [53]:
# Instantiate scheduler
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0.06 * (len(train_dataloader) * num_epochs),
    num_training_steps=(len(train_dataloader) * num_epochs),
)

In [54]:
from evaluate import load, Metric
# Define the evaluation metric
metric = load("perplexity")

In [55]:
model.to(torch.device("cpu"))
for epoch in range(num_epochs):
    model.train()
    for step, batch in enumerate(tqdm(train_dataloader)):
        batch.to(torch.device("cpu"))
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

    model.eval()
    for step, batch in enumerate(tqdm(eval_dataloader)):
        batch.to(torch.device("cpu"))
        with torch.no_grad():
            outputs = model(**batch)
        predictions = outputs.logits.argmax(dim=-1)
        predicted_labels = [tokenizer.decode(label, skip_special_tokens=True) for label in predictions]
        print( f"batch[labels] {batch['labels']}")
        labels_text = [tokenizer.decode(label, skip_special_tokens=True) for label in batch["labels"]]
        metric.add_batch(
            predictions=predicted_labels,
            references=labels_text,
        )
    eval_metric = metric.compute(model_id=model_name_or_path)
    print(f"epoch {epoch}:", eval_metric)

  0%|          | 0/26 [00:00<?, ?it/s]You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
100%|██████████| 26/26 [00:08<00:00,  3.18it/s]
100%|██████████| 1/1 [00:00<00:00, 14.91it/s]


batch[labels] tensor([[ 1890, 20203,  4776, 17952,   837,  4511,  3463,   496,  2925,   284,
          1720,  8475,   764],
        [ 2504,   561,  1612,  2440,   262,  1720,  8475,    11,  2440,   262,
         20203,  4776,   764]])


  0%|          | 0/1 [00:00<?, ?it/s]

epoch 0: {'perplexities': [1469.541259765625, 716.8955688476562], 'mean_perplexity': 1093.2184143066406}


100%|██████████| 26/26 [00:08<00:00,  2.99it/s]
100%|██████████| 1/1 [00:00<00:00, 14.47it/s]


batch[labels] tensor([[ 1890, 20203,  4776, 17952,   837,  4511,  3463,   496,  2925,   284,
          1720,  8475,   764],
        [ 2504,   561,  1612,  2440,   262,  1720,  8475,    11,  2440,   262,
         20203,  4776,   764]])


  0%|          | 0/1 [00:00<?, ?it/s]

epoch 1: {'perplexities': [1469.541259765625, 716.8955688476562], 'mean_perplexity': 1093.2184143066406}


100%|██████████| 26/26 [00:08<00:00,  3.05it/s]
100%|██████████| 1/1 [00:00<00:00, 14.48it/s]


batch[labels] tensor([[ 1890, 20203,  4776, 17952,   837,  4511,  3463,   496,  2925,   284,
          1720,  8475,   764],
        [ 2504,   561,  1612,  2440,   262,  1720,  8475,    11,  2440,   262,
         20203,  4776,   764]])


  0%|          | 0/1 [00:00<?, ?it/s]

epoch 2: {'perplexities': [1469.541259765625, 716.8955688476562], 'mean_perplexity': 1093.2184143066406}


100%|██████████| 26/26 [00:08<00:00,  2.96it/s]
100%|██████████| 1/1 [00:00<00:00, 12.71it/s]


batch[labels] tensor([[ 1890, 20203,  4776, 17952,   837,  4511,  3463,   496,  2925,   284,
          1720,  8475,   764],
        [ 2504,   561,  1612,  2440,   262,  1720,  8475,    11,  2440,   262,
         20203,  4776,   764]])


  0%|          | 0/1 [00:00<?, ?it/s]

epoch 3: {'perplexities': [1469.541259765625, 716.8955688476562], 'mean_perplexity': 1093.2184143066406}


100%|██████████| 26/26 [00:09<00:00,  2.85it/s]
100%|██████████| 1/1 [00:00<00:00, 14.99it/s]


batch[labels] tensor([[ 1890, 20203,  4776, 17952,   837,  4511,  3463,   496,  2925,   284,
          1720,  8475,   764],
        [ 2504,   561,  1612,  2440,   262,  1720,  8475,    11,  2440,   262,
         20203,  4776,   764]])


  0%|          | 0/1 [00:00<?, ?it/s]

epoch 4: {'perplexities': [1469.541259765625, 716.8955688476562], 'mean_perplexity': 1093.2184143066406}


100%|██████████| 26/26 [00:09<00:00,  2.80it/s]
100%|██████████| 1/1 [00:00<00:00, 15.00it/s]


batch[labels] tensor([[ 1890, 20203,  4776, 17952,   837,  4511,  3463,   496,  2925,   284,
          1720,  8475,   764],
        [ 2504,   561,  1612,  2440,   262,  1720,  8475,    11,  2440,   262,
         20203,  4776,   764]])


  0%|          | 0/1 [00:00<?, ?it/s]

epoch 5: {'perplexities': [1469.541259765625, 716.8955688476562], 'mean_perplexity': 1093.2184143066406}


100%|██████████| 26/26 [00:09<00:00,  2.77it/s]
100%|██████████| 1/1 [00:00<00:00, 11.98it/s]


batch[labels] tensor([[ 1890, 20203,  4776, 17952,   837,  4511,  3463,   496,  2925,   284,
          1720,  8475,   764],
        [ 2504,   561,  1612,  2440,   262,  1720,  8475,    11,  2440,   262,
         20203,  4776,   764]])


  0%|          | 0/1 [00:00<?, ?it/s]

epoch 6: {'perplexities': [1469.541259765625, 716.8955688476562], 'mean_perplexity': 1093.2184143066406}


100%|██████████| 26/26 [00:09<00:00,  2.67it/s]
100%|██████████| 1/1 [00:00<00:00, 12.00it/s]


batch[labels] tensor([[ 1890, 20203,  4776, 17952,   837,  4511,  3463,   496,  2925,   284,
          1720,  8475,   764],
        [ 2504,   561,  1612,  2440,   262,  1720,  8475,    11,  2440,   262,
         20203,  4776,   764]])


  0%|          | 0/1 [00:00<?, ?it/s]

epoch 7: {'perplexities': [1469.541259765625, 716.8955688476562], 'mean_perplexity': 1093.2184143066406}


100%|██████████| 26/26 [00:09<00:00,  2.66it/s]
100%|██████████| 1/1 [00:00<00:00, 13.60it/s]


batch[labels] tensor([[ 1890, 20203,  4776, 17952,   837,  4511,  3463,   496,  2925,   284,
          1720,  8475,   764],
        [ 2504,   561,  1612,  2440,   262,  1720,  8475,    11,  2440,   262,
         20203,  4776,   764]])


  0%|          | 0/1 [00:00<?, ?it/s]

epoch 8: {'perplexities': [1469.541259765625, 716.8955688476562], 'mean_perplexity': 1093.2184143066406}


100%|██████████| 26/26 [00:09<00:00,  2.63it/s]
100%|██████████| 1/1 [00:00<00:00, 12.13it/s]


batch[labels] tensor([[ 1890, 20203,  4776, 17952,   837,  4511,  3463,   496,  2925,   284,
          1720,  8475,   764],
        [ 2504,   561,  1612,  2440,   262,  1720,  8475,    11,  2440,   262,
         20203,  4776,   764]])


  0%|          | 0/1 [00:00<?, ?it/s]

epoch 9: {'perplexities': [1469.541259765625, 716.8955688476562], 'mean_perplexity': 1093.2184143066406}


100%|██████████| 26/26 [00:09<00:00,  2.76it/s]
100%|██████████| 1/1 [00:00<00:00, 12.39it/s]


batch[labels] tensor([[ 1890, 20203,  4776, 17952,   837,  4511,  3463,   496,  2925,   284,
          1720,  8475,   764],
        [ 2504,   561,  1612,  2440,   262,  1720,  8475,    11,  2440,   262,
         20203,  4776,   764]])


  0%|          | 0/1 [00:00<?, ?it/s]

epoch 10: {'perplexities': [1469.541259765625, 716.8955688476562], 'mean_perplexity': 1093.2184143066406}


100%|██████████| 26/26 [00:09<00:00,  2.70it/s]
100%|██████████| 1/1 [00:00<00:00, 12.35it/s]


batch[labels] tensor([[ 1890, 20203,  4776, 17952,   837,  4511,  3463,   496,  2925,   284,
          1720,  8475,   764],
        [ 2504,   561,  1612,  2440,   262,  1720,  8475,    11,  2440,   262,
         20203,  4776,   764]])


  0%|          | 0/1 [00:00<?, ?it/s]

epoch 11: {'perplexities': [1469.541259765625, 716.8955688476562], 'mean_perplexity': 1093.2184143066406}


100%|██████████| 26/26 [00:10<00:00,  2.59it/s]
100%|██████████| 1/1 [00:00<00:00, 12.40it/s]


batch[labels] tensor([[ 1890, 20203,  4776, 17952,   837,  4511,  3463,   496,  2925,   284,
          1720,  8475,   764],
        [ 2504,   561,  1612,  2440,   262,  1720,  8475,    11,  2440,   262,
         20203,  4776,   764]])


  0%|          | 0/1 [00:00<?, ?it/s]

epoch 12: {'perplexities': [1469.541259765625, 716.8955688476562], 'mean_perplexity': 1093.2184143066406}


100%|██████████| 26/26 [00:09<00:00,  2.70it/s]
100%|██████████| 1/1 [00:00<00:00, 12.08it/s]


batch[labels] tensor([[ 1890, 20203,  4776, 17952,   837,  4511,  3463,   496,  2925,   284,
          1720,  8475,   764],
        [ 2504,   561,  1612,  2440,   262,  1720,  8475,    11,  2440,   262,
         20203,  4776,   764]])


  0%|          | 0/1 [00:00<?, ?it/s]

epoch 13: {'perplexities': [1469.541259765625, 716.8955688476562], 'mean_perplexity': 1093.2184143066406}


100%|██████████| 26/26 [00:09<00:00,  2.67it/s]
100%|██████████| 1/1 [00:00<00:00, 14.94it/s]


batch[labels] tensor([[ 1890, 20203,  4776, 17952,   837,  4511,  3463,   496,  2925,   284,
          1720,  8475,   764],
        [ 2504,   561,  1612,  2440,   262,  1720,  8475,    11,  2440,   262,
         20203,  4776,   764]])


  0%|          | 0/1 [00:00<?, ?it/s]

epoch 14: {'perplexities': [1469.541259765625, 716.8955688476562], 'mean_perplexity': 1093.2184143066406}


100%|██████████| 26/26 [00:09<00:00,  2.67it/s]
100%|██████████| 1/1 [00:00<00:00, 11.72it/s]


batch[labels] tensor([[ 1890, 20203,  4776, 17952,   837,  4511,  3463,   496,  2925,   284,
          1720,  8475,   764],
        [ 2504,   561,  1612,  2440,   262,  1720,  8475,    11,  2440,   262,
         20203,  4776,   764]])


  0%|          | 0/1 [00:00<?, ?it/s]

epoch 15: {'perplexities': [1469.541259765625, 716.8955688476562], 'mean_perplexity': 1093.2184143066406}


100%|██████████| 26/26 [00:14<00:00,  1.80it/s]
100%|██████████| 1/1 [00:00<00:00,  7.42it/s]


batch[labels] tensor([[ 1890, 20203,  4776, 17952,   837,  4511,  3463,   496,  2925,   284,
          1720,  8475,   764],
        [ 2504,   561,  1612,  2440,   262,  1720,  8475,    11,  2440,   262,
         20203,  4776,   764]])


  0%|          | 0/1 [00:00<?, ?it/s]

epoch 16: {'perplexities': [1469.541259765625, 716.8955688476562], 'mean_perplexity': 1093.2184143066406}


100%|██████████| 26/26 [00:15<00:00,  1.64it/s]
100%|██████████| 1/1 [00:00<00:00,  8.33it/s]


batch[labels] tensor([[ 1890, 20203,  4776, 17952,   837,  4511,  3463,   496,  2925,   284,
          1720,  8475,   764],
        [ 2504,   561,  1612,  2440,   262,  1720,  8475,    11,  2440,   262,
         20203,  4776,   764]])


  0%|          | 0/1 [00:00<?, ?it/s]

epoch 17: {'perplexities': [1469.541259765625, 716.8955688476562], 'mean_perplexity': 1093.2184143066406}


100%|██████████| 26/26 [00:15<00:00,  1.66it/s]
100%|██████████| 1/1 [00:00<00:00,  8.07it/s]


batch[labels] tensor([[ 1890, 20203,  4776, 17952,   837,  4511,  3463,   496,  2925,   284,
          1720,  8475,   764],
        [ 2504,   561,  1612,  2440,   262,  1720,  8475,    11,  2440,   262,
         20203,  4776,   764]])


  0%|          | 0/1 [00:00<?, ?it/s]

epoch 18: {'perplexities': [1469.541259765625, 716.8955688476562], 'mean_perplexity': 1093.2184143066406}


100%|██████████| 26/26 [00:15<00:00,  1.65it/s]
100%|██████████| 1/1 [00:00<00:00,  5.73it/s]

batch[labels] tensor([[ 1890, 20203,  4776, 17952,   837,  4511,  3463,   496,  2925,   284,
          1720,  8475,   764],
        [ 2504,   561,  1612,  2440,   262,  1720,  8475,    11,  2440,   262,
         20203,  4776,   764]])





  0%|          | 0/1 [00:00<?, ?it/s]

epoch 19: {'perplexities': [1469.541259765625, 716.8955688476562], 'mean_perplexity': 1093.2184143066406}
