pool -> top k instances -> fine-tune LLM

In [1]:
'''
credit: https://github.com/prateekjoshi565/Fine-Tuning-BERT
'''
import numpy as np
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
import os
os.environ["CUDA_VISIBLE_DEVICES"]="1"
# specify GPU
device = torch.device("cuda")

from typing import List, Union
from datasets import load_dataset


  from .autonotebook import tqdm as notebook_tqdm


# Import BERT Model and BERT Tokenizer

In [2]:
# # import BERT-base pretrained model
# bert = AutoModel.from_pretrained('bert-base-uncased', return_dict=False)

# # Load the BERT tokenizer
# tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

gte = AutoModel.from_pretrained('thenlper/gte-small', return_dict=False)
tokenizer = AutoTokenizer.from_pretrained('thenlper/gte-small')



In [3]:
text_embedder = gte

# Define Model Architecture

In [4]:
class Model_Arch(nn.Module):

    def __init__(self, text_embedder):
      
      super(Model_Arch, self).__init__()

      self.text_embedder = text_embedder 
      
      # dropout layer
      self.dropout = nn.Dropout(0.1)
      
      # relu activation function
      self.relu =  nn.ReLU()

      # # dense layer 1
      # self.fc1 = nn.Linear(768,512)
      
      # # dense layer 2 (Output layer)
      # self.fc2 = nn.Linear(512,3)

      self.fc1 = nn.Linear(384, 3)

      #softmax activation function
      self.softmax = nn.LogSoftmax(dim=1)

    #define the forward pass
    def forward(self, sent_id, mask):

      #pass the inputs to the model  
      _, cls_hs = self.text_embedder(sent_id, attention_mask=mask, return_dict=False)
      # print(cls_hs.shape)
      
      x = self.fc1(cls_hs)

      # x = self.relu(x)

      # x = self.dropout(x)

      # # output layer
      # x = self.fc2(x)
      
      # apply softmax activation
      x = self.softmax(x)

      return x

In [5]:

# pass the pre-trained BERT to our define architecture
model = Model_Arch(text_embedder)

# push the model to GPU
model = model.to(device)

# Load Saved Model

In [6]:
model.load_state_dict(torch.load('saved_gte_weights.pt'))

<All keys matched successfully>

# Data Selection Pool

In [7]:

def load_raw_dataset(train_files: Union[List[str], str]):
    """ load raw dataset """
    if isinstance(train_files, str):
        train_files = [train_files]
    processed_datasets = load_dataset(
        "json",
        data_files=train_files,
    )
    return processed_datasets

In [8]:
# from datasets import disable_caching
# disable_caching()

In [9]:
data_selection = load_raw_dataset("../data/train/processed/dolly/train_dolly_data.jsonl")['train']
# data_selection = load_dataset('json', data_files='../data/train/processed/dolly/train_dolly_data.jsonl', download_mode='force_redownload')['train']
# import json
# data_selection = []
# with open("../data/train/processed/dolly/train_dolly_data.jsonl", 'r', encoding='utf-8', errors='ignore') as file:
#     for line in file:
#         data_selection.append(json.loads(line))

In [10]:
# from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

# # wrap tensors
# val_data = TensorDataset(val_seq, val_mask, val_y)

# # # sampler for sampling the data during training
# # val_sampler = SequentialSampler(val_data)

# # dataLoader for validation set
# val_dataloader = DataLoader(val_data, batch_size=32)

In [11]:
# data_selection = data_selection.map(cache_file_name=data_selection.cache_files[0]['filename'])

In [12]:
# data_selection.cleanup_cache_files()

In [None]:
# pool['messages'][:2]

[[{'role': 'user',
   'content': "Task: When did Virgin Australia start operating?\n\nVirgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.\n\nAnswer:"},
  {'role': 'assistant',
   'content': 'Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.'}],
 [{'role': 'user', 'content': 'Which is a species of fish? Tope or Rope'},
  {'role': 'assistant', 'content': '\nTope'}]]

In [None]:
# filtered = []
# with open("../data/train/processed/dolly/filtered_train_dolly_data.jsonl", 'r', encoding='utf-8', errors='ignore') as file:
#     for line in file:
#         filtered.append(json.loads(line))

In [15]:
# def unfold_QA_short(data):
#     # Create a new dataset with the unfolded QA format
#     new_dataset = []
#     for text_entry in data['messages']:
#         unfolded_text_entry = ''
#         for text_pair in text_entry:
#             unfolded_text_entry += text_pair['role'] + ': ' + text_pair['content'] + ' '
#         new_dataset.append(unfolded_text_entry)
#     return new_dataset

def unfold_QA_short(data):
    # Create a new dataset with the unfolded QA format
    new_dataset = []
    for QA_entry in data:
        unfolded_QA_entry = ''
        for QA_pair in QA_entry:
            unfolded_QA_entry += QA_pair['role'] + ': ' + QA_pair['content'] + ' '
        new_dataset.append(unfolded_QA_entry)
    return new_dataset

In [23]:
pool = data_selection.map(lambda x: tokenizer(unfold_QA_short(x['messages']), padding='max_length', truncation=True, max_length=512), batched=True)

Map: 100%|██████████| 7505/7505 [00:01<00:00, 4358.06 examples/s]


In [24]:
pool.set_format(type='torch', columns=['input_ids', 'attention_mask'])

In [None]:
# for test set
pool_seq = torch.tensor(pool['input_ids'])
pool_mask = torch.tensor(pool['attention_mask'])
# pool['input_ids'].clone().detach()

  pool_seq = torch.tensor(pool['input_ids'])
  pool_mask = torch.tensor(pool['attention_mask'])


In [26]:
pool_seq.shape

torch.Size([7505, 512])

In [None]:
# max_seq_len = 512
# # tokenize and encode sequences in the test set
# tokens_pool = tokenizer.batch_encode_plus(
#     unfold_QA_short(pool['messages']),
#     max_length = max_seq_len,
#     pad_to_max_length=True,
#     truncation=True,
#     return_token_type_ids=False
# )



# Filter Pool

In [17]:
# for test set
pool_seq = torch.tensor(tokens_pool['input_ids'])
pool_mask = torch.tensor(tokens_pool['attention_mask'])
# test_y = torch.tensor(get_labels(test_data))

In [27]:
# get predictions for test data
with torch.no_grad():
  batch_size = 32
  preds = []

  for i in range(0, len(pool_seq), batch_size):
    batch_seq = pool_seq[i:i+batch_size].to(device)
    batch_mask = pool_mask[i:i+batch_size].to(device)
    batch_preds = model(batch_seq, batch_mask)
    preds.append(batch_preds.detach().cpu().numpy())
    if i % 100 == 0:
      print("Processed ", i)
    # print("Processed ", i)

  preds = np.concatenate(preds, axis=0)
  # preds = model(pool_seq.to(device), pool_mask.to(device))
  # preds = preds.detach().cpu().numpy()

Processed  0
Processed  800
Processed  1600
Processed  2400
Processed  3200
Processed  4000
Processed  4800
Processed  5600
Processed  6400
Processed  7200


In [28]:
# model's performance
preds = np.argmax(preds, axis = 1)
# print(classification_report(test_y, preds))

In [29]:
preds.shape

(7505,)

In [30]:
target_indices = np.arange(len(preds))[preds == 2]
target_indices[:5]

array([0, 1, 3, 5, 6])

In [38]:
results = data_selection[target_indices]

In [40]:
# Create a subset of the dataset by selecting specific indices
# subset_indices = [0, 1, 2, 3, 4]  # Example indices
results = data_selection.select(target_indices)

# # Display the subset
# print(subset)

In [39]:
results.keys()

dict_keys(['dataset', 'id', 'messages'])

In [None]:
# no need if all train_data.jsonl is used
keys = results.keys()
results = []
for idx in target_indices:
    results.append({key:pool[key][idx]  for key in keys})

AttributeError: 'Dataset' object has no attribute 'keys'

In [42]:
len(results)

5482

In [44]:
import json

# Save the data rows of text into a JSONL file
with open('../data/train/processed/dolly/filtered_train_dolly_data.jsonl', 'w', encoding='utf-8') as file:
    for entry in results:
        file.write(json.dumps(entry) + '\n')

In [36]:
results['train'][:2]

KeyError: 'train'