In [1]:
import os

from typing import Any, List, Dict, Mapping, Tuple, Union, Optional
import copy
from dataclasses import dataclass, field
import json
import pathlib
from typing import Dict, Optional, Sequence
import pickle
import tqdm

import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
import transformers
from transformers import Trainer
from transformers.trainer_pt_utils import LabelSmoother

from transformers import LlamaTokenizer, LlamaForCausalLM

In [2]:
class FinQAFinetuneDataset(Dataset):
    def __init__(self, data_path, tokenizer):
        self.tokenizer = tokenizer
        self.tokenizer.pad_token_id = self.tokenizer.unk_token_id

        with open(data_path, "rb") as f:
            self.data = pickle.load(f)

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        input_ids = self.data[idx]["input_ids"]
        labels = self.data[idx]["labels"]
        attention_mask = self.data[idx]["attention_mask"]
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=attention_mask,
        )


In [3]:
tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-65b-hf")

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'LLaMATokenizer'. 
The class this function is called from is 'LlamaTokenizer'.


In [4]:
dataset = FinQAFinetuneDataset(data_path="./llama_finetune_dataset.pkl", tokenizer=tokenizer)

In [5]:
with open("../finetune_data_prepared.jsonl") as f:
    data = [json.loads(line) for line in f]

In [6]:
all_examples = []

for item in tqdm.tqdm(data):
    prompt = item["prompt"]
    answer = item["completion"]

    prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
    answer_ids = tokenizer.encode(answer, add_special_tokens=False) + [tokenizer.eos_token_id]

    input_ids = torch.tensor(prompt_ids + answer_ids, dtype=torch.long)

    if len(input_ids) > 2048:
        print("skipping", len(input_ids))
        continue

    labels = input_ids.clone()
    labels[: len(prompt_ids)] = -100

    all_examples.append(
        dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=torch.ones_like(input_ids, dtype=torch.bool),
        )
    )

  2%|▏         | 115/6203 [00:00<00:25, 235.66it/s]

skipping 2585
skipping 2409


  4%|▍         | 243/6203 [00:01<00:24, 243.60it/s]

skipping 2350


  6%|▌         | 343/6203 [00:01<00:24, 236.96it/s]

skipping 2292
skipping 2153
skipping 2355


  8%|▊         | 466/6203 [00:01<00:25, 229.43it/s]

skipping 2341
skipping 2425
skipping 2538


  8%|▊         | 514/6203 [00:02<00:25, 225.69it/s]

skipping 2456
skipping 2270
skipping 2211


 10%|▉         | 590/6203 [00:02<00:23, 237.28it/s]

skipping 2052
skipping 2556
skipping 2354


 12%|█▏        | 769/6203 [00:03<00:22, 242.19it/s]

skipping 2350


 14%|█▎        | 848/6203 [00:03<00:21, 252.06it/s]

skipping 2341


 16%|█▌        | 1004/6203 [00:04<00:21, 241.43it/s]

skipping 2055
skipping 2060
skipping 2261


 17%|█▋        | 1080/6203 [00:04<00:21, 236.62it/s]

skipping 2273
skipping 2100
skipping 2308


 19%|█▊        | 1152/6203 [00:04<00:21, 231.54it/s]

skipping 2098
skipping 2378


 20%|█▉        | 1223/6203 [00:05<00:22, 226.10it/s]

skipping 2280
skipping 2350
skipping 2450
skipping 2205


 21%|██        | 1272/6203 [00:05<00:21, 225.25it/s]

skipping 2295
skipping 2133
skipping 2169
skipping 2149


 21%|██        | 1318/6203 [00:05<00:21, 226.28it/s]

skipping 2151
skipping 2213
skipping 2432
skipping 2286
skipping 2360


 23%|██▎       | 1439/6203 [00:06<00:20, 227.24it/s]

skipping 2172
skipping 2170
skipping 2175


 24%|██▍       | 1513/6203 [00:06<00:19, 236.90it/s]

skipping 2084
skipping 2367
skipping 2083


 26%|██▌       | 1587/6203 [00:06<00:19, 235.34it/s]

skipping 2538
skipping 2259
skipping 2304
skipping 2551
skipping 2236


 27%|██▋       | 1679/6203 [00:07<00:21, 215.33it/s]

skipping 2149
skipping 2236
skipping 2070
skipping 2222


 28%|██▊       | 1754/6203 [00:07<00:18, 235.38it/s]

skipping 2100


 29%|██▉       | 1802/6203 [00:07<00:19, 225.36it/s]

skipping 2129
skipping 2263
skipping 2065


 30%|███       | 1869/6203 [00:08<00:20, 213.20it/s]

skipping 2460
skipping 2499
skipping 2079
skipping 2259


 31%|███       | 1915/6203 [00:08<00:19, 219.48it/s]

skipping 2173


 32%|███▏      | 1984/6203 [00:08<00:18, 222.40it/s]

skipping 2362
skipping 2188


 33%|███▎      | 2034/6203 [00:08<00:17, 233.59it/s]

skipping 2166
skipping 2185


 34%|███▍      | 2106/6203 [00:09<00:17, 230.57it/s]

skipping 2405


 35%|███▍      | 2153/6203 [00:09<00:18, 223.09it/s]

skipping 2365
skipping 2135


 35%|███▌      | 2199/6203 [00:09<00:18, 216.59it/s]

skipping 2384
skipping 2525


 37%|███▋      | 2315/6203 [00:10<00:17, 218.08it/s]

skipping 2456
skipping 2429
skipping 2259


 38%|███▊      | 2361/6203 [00:10<00:17, 221.43it/s]

skipping 2125
skipping 2170


 40%|████      | 2483/6203 [00:10<00:16, 230.42it/s]

skipping 2150
skipping 2362


 41%|████      | 2553/6203 [00:11<00:16, 218.90it/s]

skipping 2239
skipping 2317
skipping 2097


 42%|████▏     | 2620/6203 [00:11<00:16, 217.31it/s]

skipping 2259


 43%|████▎     | 2665/6203 [00:11<00:16, 216.39it/s]

skipping 2504


 44%|████▎     | 2709/6203 [00:11<00:16, 214.86it/s]

skipping 2440
skipping 2309


 45%|████▍     | 2775/6203 [00:12<00:16, 210.84it/s]

skipping 2427


 46%|████▋     | 2878/6203 [00:12<00:13, 243.01it/s]

skipping 2271


 48%|████▊     | 2951/6203 [00:12<00:13, 232.66it/s]

skipping 2169
skipping 2550
skipping 2212


 48%|████▊     | 2999/6203 [00:13<00:14, 228.15it/s]

skipping 2104
skipping 2115
skipping 2495


 49%|████▉     | 3070/6203 [00:13<00:14, 221.89it/s]

skipping 2335
skipping 2190


 51%|█████     | 3168/6203 [00:13<00:13, 233.01it/s]

skipping 2271
skipping 2179


 53%|█████▎    | 3290/6203 [00:14<00:12, 228.86it/s]

skipping 2110
skipping 2273


 54%|█████▍    | 3362/6203 [00:14<00:12, 226.77it/s]

skipping 2487
skipping 2130
skipping 2417


 56%|█████▌    | 3458/6203 [00:15<00:11, 233.13it/s]

skipping 2313
skipping 2145
skipping 2151


 58%|█████▊    | 3576/6203 [00:15<00:11, 220.00it/s]

skipping 2168
skipping 2384
skipping 2524
skipping 2519


 58%|█████▊    | 3624/6203 [00:15<00:11, 223.90it/s]

skipping 2207
skipping 2240


 60%|█████▉    | 3718/6203 [00:16<00:11, 224.73it/s]

skipping 2240
skipping 2330
skipping 2184


 61%|██████    | 3767/6203 [00:16<00:10, 226.83it/s]

skipping 2080
skipping 2306
skipping 2390


 62%|██████▏   | 3862/6203 [00:16<00:10, 230.38it/s]

skipping 2076


 63%|██████▎   | 3910/6203 [00:17<00:10, 225.32it/s]

skipping 2068
skipping 2226


 64%|██████▍   | 3980/6203 [00:17<00:09, 226.43it/s]

skipping 2273
skipping 2261


 65%|██████▌   | 4050/6203 [00:17<00:09, 223.48it/s]

skipping 2257
skipping 2554
skipping 2250
skipping 2065
skipping 2662


 66%|██████▋   | 4119/6203 [00:18<00:09, 221.83it/s]

skipping 2359
skipping 2150


 68%|██████▊   | 4212/6203 [00:18<00:08, 223.34it/s]

skipping 2390
skipping 2158
skipping 2555


 69%|██████▉   | 4280/6203 [00:18<00:08, 216.82it/s]

skipping 2075
skipping 2273
skipping 2369
skipping 2239
skipping 2317


 70%|███████   | 4347/6203 [00:19<00:08, 217.61it/s]

skipping 2167
skipping 2194


 71%|███████   | 4415/6203 [00:19<00:08, 219.71it/s]

skipping 2211
skipping 2076
skipping 2363


 73%|███████▎  | 4534/6203 [00:19<00:07, 223.70it/s]

skipping 2141


 74%|███████▍  | 4581/6203 [00:20<00:07, 227.57it/s]

skipping 2555
skipping 2112
skipping 2074


 75%|███████▌  | 4676/6203 [00:20<00:07, 215.30it/s]

skipping 2087


 76%|███████▌  | 4723/6203 [00:20<00:06, 221.22it/s]

skipping 2289
skipping 2364
skipping 2267
skipping 2165


 77%|███████▋  | 4792/6203 [00:21<00:06, 221.19it/s]

skipping 2049
skipping 2529
skipping 2083
skipping 2099


 78%|███████▊  | 4862/6203 [00:21<00:06, 223.23it/s]

skipping 2239


 80%|███████▉  | 4958/6203 [00:21<00:05, 222.58it/s]

skipping 2237
skipping 2295


 82%|████████▏ | 5079/6203 [00:22<00:04, 226.20it/s]

skipping 2253
skipping 2066


 83%|████████▎ | 5126/6203 [00:22<00:04, 227.99it/s]

skipping 2069
skipping 2267
skipping 2300


 84%|████████▍ | 5218/6203 [00:22<00:04, 214.45it/s]

skipping 2518
skipping 2522


 86%|████████▌ | 5327/6203 [00:23<00:04, 208.43it/s]

skipping 2278
skipping 2185
skipping 2182


 87%|████████▋ | 5420/6203 [00:23<00:03, 214.16it/s]

skipping 2197
skipping 2068


 88%|████████▊ | 5464/6203 [00:24<00:03, 212.66it/s]

skipping 2143
skipping 2563
skipping 2586


 90%|████████▉ | 5557/6203 [00:24<00:02, 221.51it/s]

skipping 2200
skipping 2301


 90%|█████████ | 5605/6203 [00:24<00:02, 225.32it/s]

skipping 2426
skipping 2363
skipping 2272


 91%|█████████▏| 5672/6203 [00:25<00:02, 208.26it/s]

skipping 2394
skipping 2266


 92%|█████████▏| 5735/6203 [00:25<00:02, 203.84it/s]

skipping 2337
skipping 2117
skipping 2248
skipping 2165


 94%|█████████▎| 5801/6203 [00:25<00:01, 209.96it/s]

skipping 2188
skipping 2247
skipping 2117


 94%|█████████▍| 5845/6203 [00:25<00:01, 206.66it/s]

skipping 2180
skipping 2061
skipping 2070
skipping 2091


 95%|█████████▍| 5889/6203 [00:26<00:01, 205.38it/s]

skipping 2255
skipping 2269


 96%|█████████▌| 5958/6203 [00:26<00:01, 214.47it/s]

skipping 2381
skipping 2212
skipping 2054
skipping 2165
skipping 2151


 97%|█████████▋| 6002/6203 [00:26<00:00, 208.57it/s]

skipping 2524
skipping 2337
skipping 2212


 98%|█████████▊| 6068/6203 [00:26<00:00, 209.68it/s]

skipping 2389
skipping 2100


 99%|█████████▉| 6131/6203 [00:27<00:00, 204.18it/s]

skipping 2540
skipping 2328
skipping 2265
skipping 2061
skipping 2336


100%|██████████| 6203/6203 [00:27<00:00, 224.65it/s]

skipping 2381
skipping 2287
skipping 2389





In [9]:
with open("./llama_finetune_dataset.pkl", "wb") as f:
    pickle.dump(all_examples, f)