In [1]:
import sys
import copy

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from torch.utils.data import DataLoader
from peft import LoraConfig

sys.path.append("../")
from src import TaskAdapter, TextLabelDataset, PromtTuningConfig, HybridPeftWrapper, Params

In [2]:
# Training configuration shared by every cell below (model name, batch sizes).
config_path = "../configs/train_config.json"
config = Params(config_path)

# Seed the global NumPy and PyTorch RNGs so that Restart & Run All shows the
# same shuffled batches and the same randomly picked samples every time.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
# Name of the base checkpoint, read from the training config.
model_name = config.MODEL.BASE_MODEL_NAME
# Tokenizer for that checkpoint; reused by the dataset cells and by visualize_samples.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Look at Prepared Data

In [4]:
def visualize_samples(batch, num_examples=2, tokenizer_override=None):
    """Print a few randomly chosen samples of a batch, raw and decoded.

    For every chosen sample the raw input ids, the decoded input text, the
    attention mask, the raw labels and the decoded label text are printed,
    separated by dashed lines.

    Args:
        batch: An ``(input_ids, attention_mask, labels)`` triple of tensors
            that are indexed by sample along their first dimension.
        num_examples: Maximum number of *distinct* samples to print. When the
            batch holds fewer samples, every sample is printed once.
        tokenizer_override: Object providing ``decode(ids, skip_special_tokens=...)``
            and ``pad_token_id``. Defaults to the notebook-level ``tokenizer``.

    Returns:
        None. The samples are written to stdout.

    Note:
        Samples are drawn with NumPy's global RNG; seed it for repeatable output.
    """
    input_ids, attention_mask, labels = batch
    # The notebook-level tokenizer is only looked up when no override is given.
    active_tokenizer = tokenizer if tokenizer_override is None else tokenizer_override

    # Draw without replacement so the same sample is never shown twice, and
    # never ask for more samples than the batch contains (an empty batch
    # simply prints nothing).
    batch_size = input_ids.size(0)
    example_indices = np.random.permutation(batch_size)[:max(num_examples, 0)].tolist()

    dash_line = "-".join(" " * 100)

    for example_number, index in enumerate(example_indices, start=1):
        decoded_input_ids = active_tokenizer.decode(input_ids[index], skip_special_tokens=True)

        # Work on a copy so the caller's labels are left untouched.
        decoded_labels = copy.deepcopy(labels[index])
        # -100 is not a valid token id (presumably the marker for label
        # positions the loss should ignore); map it to the pad id so the
        # sequence can be decoded.
        decoded_labels[decoded_labels == -100] = active_tokenizer.pad_token_id
        decoded_labels = active_tokenizer.decode(decoded_labels, skip_special_tokens=True)

        print(dash_line)
        print("Example", example_number)
        print(dash_line)
        print(f"INPUT IDS:\n{input_ids[index]}")
        print(dash_line)
        print(f"DECODED INPUT IDS:\n{decoded_input_ids}")
        print(dash_line)
        print(f"ATTENTION_MASK:\n{attention_mask[index]}")
        print(dash_line)
        print(f"LABELS:\n{labels[index]}")
        print(dash_line)
        print(f"DECODED LABELS:\n{decoded_labels}")
        print(dash_line)

**WikiSQL Dataset**

In [5]:
wikisql_datasetname = "wikisql"
wikisql_dataset_adapter = TaskAdapter(wikisql_datasetname, tokenizer)

In [6]:
wikisql_dataset = TextLabelDataset(wikisql_dataset_adapter.dataset_dict["train"], tokenizer, wikisql_dataset_adapter.start_prompt, wikisql_dataset_adapter.end_prompt)
wikisql_dataloader = DataLoader(wikisql_dataset, batch_size=config.TRAINING.BATCH_SIZE.TRAIN, shuffle=True)

In [7]:
wikisql_batch = next(iter(wikisql_dataloader))

In [8]:
visualize_samples(wikisql_batch)

 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
Example 1
 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
INPUT IDS:
tensor([30355,    15,    48, 11417,   139, 12558,    10,  2645,  2832,     8,
         5640,   213,     8,   926,   799,   833,    19,     3,  2047,   120,
         1755,     6, 20615,    58, 12558,    10,     1,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     

**Samsum Dataset**

In [9]:
samsum_datasetname = "samsum"
samsum_dataset_adapter = TaskAdapter(samsum_datasetname, tokenizer)

In [10]:
samsum_dataset = TextLabelDataset(samsum_dataset_adapter.dataset_dict["train"], tokenizer, samsum_dataset_adapter.start_prompt, samsum_dataset_adapter.end_prompt)
samsum_dataloader = DataLoader(samsum_dataset, batch_size=config.TRAINING.BATCH_SIZE.TRAIN, shuffle=True)

In [11]:
samsum_batch = next(iter(samsum_dataloader))

In [12]:
visualize_samples(samsum_batch)

 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
Example 1
 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
INPUT IDS:
tensor([12198,  1635,  1737,     8,   826,  3634,    10,  9137,    10,     3,
           23,    56,    36,  1480,  5721,  9137,    10,  2361,     3,    23,
           56,    43,    12,  1049,  1200,    44,   161, 11712,    10,   572,
           58,  9137,    10,    62,    33,  8619,  1450,   516, 11712,    10,
           78,  2087,    62,    56,   942,   430,   239,    58,  9137,    10,
          150,     6,   131,   428,   140,   128,    97,     3,    99,     3,
           23,    56,    36,   865,   145,   489, 11712,    10,     3,  1825,
            6,    78,   752,    3

**SST2 Dataset**

In [13]:
# Task name handed to TaskAdapter together with the shared tokenizer; the
# resulting adapter supplies dataset_dict, start_prompt and end_prompt to the
# dataset cell that follows.
sst2_datasetname = "sst2"
sst2_dataset_adapter = TaskAdapter(sst2_datasetname, tokenizer)

Map:   0%|          | 0/67349 [00:00<?, ? examples/s]

Map:   0%|          | 0/872 [00:00<?, ? examples/s]

Map:   0%|          | 0/1821 [00:00<?, ? examples/s]

In [14]:
# Build the training dataset from the adapter's "train" split and its
# start/end prompts, then batch it with the configured training batch size.
sst2_dataset = TextLabelDataset(
    sst2_dataset_adapter.dataset_dict["train"],
    tokenizer,
    sst2_dataset_adapter.start_prompt,
    sst2_dataset_adapter.end_prompt,
)
sst2_dataloader = DataLoader(
    sst2_dataset,
    batch_size=config.TRAINING.BATCH_SIZE.TRAIN,
    shuffle=True,
)

Map:   0%|          | 0/67349 [00:00<?, ? examples/s]

In [15]:
# Take the first batch of a fresh (shuffled) pass over the SST-2 loader.
sst2_batch = next(iter(sst2_dataloader))

In [16]:
# Print randomly chosen samples of that batch, raw and decoded.
visualize_samples(sst2_batch)

 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
Example 1
 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
INPUT IDS:
tensor([ 5331,   120,   776,     8,  6493,    13,     8,   826,  7142,    10,
           80,    13,     8,  2592,    49, 11571,     8, 12082,  5349,    65,
         2546,    16,  1100,  2594,     3,     6,   237,     3,    99,    34,
            3,    31,     7,   623,     3,    17,     9,   935,   145, 22152,
            3,     5,  4892,  2998,   295,    10,     1,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     

# Look at Model Outputs

In [17]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [18]:
lora_config = LoraConfig()
pt_config = PromtTuningConfig()

In [19]:
input_ids, attention_mask, labels = next(iter(wikisql_dataloader))

In [20]:
input_ids = input_ids.to(device)
labels = labels.to(device)
attention_mask = attention_mask.to(device)

**Model Output - Original Model Only**

In [21]:
original_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

In [22]:
peft_model = HybridPeftWrapper.from_config(original_model)
peft_model = peft_model.to(device)

In [23]:
loss = peft_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels).loss

In [24]:
loss

tensor(3.0055, device='cuda:0', grad_fn=<NllLossBackward0>)

In [25]:
loss.backward()

In [26]:
del original_model
del peft_model

In [27]:
if torch.cuda.is_available():
    torch.cuda.empty_cache()

**Model Output - PEFT Model LoRA**

In [28]:
original_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

In [29]:
peft_model = HybridPeftWrapper.from_config(original_model, lora_config=lora_config)
peft_model = peft_model.to(device)

In [30]:
loss = peft_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels).loss

In [31]:
loss

tensor(3.0055, device='cuda:0', grad_fn=<NllLossBackward0>)

In [32]:
loss.backward()

In [33]:
del original_model
del peft_model

In [34]:
if torch.cuda.is_available():
    torch.cuda.empty_cache()

**Model Output - PEFT Model Prompt Tuning**

In [35]:
original_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

In [36]:
peft_model = HybridPeftWrapper.from_config(original_model, pt_config=pt_config)
peft_model = peft_model.to(device)

In [37]:
loss = peft_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels).loss

In [38]:
loss

tensor(3.0603, device='cuda:0', grad_fn=<NllLossBackward0>)

In [39]:
loss.backward()

In [40]:
del original_model
del peft_model

In [41]:
if torch.cuda.is_available():
    torch.cuda.empty_cache()

**Model Output - PEFT Model LoRA and Prompt Tuning**

In [42]:
original_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

In [43]:
peft_model = HybridPeftWrapper.from_config(original_model, lora_config=lora_config, pt_config=pt_config)
peft_model = peft_model.to(device)

In [44]:
loss = peft_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels).loss

In [45]:
loss

tensor(3.0603, device='cuda:0', grad_fn=<NllLossBackward0>)

In [46]:
loss.backward()

In [47]:
del original_model
del peft_model

In [48]:
if torch.cuda.is_available():
    torch.cuda.empty_cache()