In [1]:
import os

from typing import Any, List, Dict, Mapping, Tuple, Union, Optional
import copy
from dataclasses import dataclass, field
import json
import pathlib
from typing import Dict, Optional, Sequence
import pickle
import tqdm

import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
import transformers
from transformers import Trainer
from transformers.trainer_pt_utils import LabelSmoother

from transformers import T5ForConditionalGeneration, AutoTokenizer

In [2]:
# Tokenizer for the FLAN-T5 XXL checkpoint; every later cell encodes text with it.
MODEL_NAME = "google/flan-t5-xxl"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [3]:
# Read the prepared fine-tuning records: one JSON object per line (JSONL).
# The encoding is pinned to UTF-8 so parsing does not depend on the platform's
# default locale, and blank lines (e.g. a trailing newline at end of file) are
# skipped instead of making json.loads raise json.JSONDecodeError.
with open("../finetune_data_prepared.jsonl", encoding="utf-8") as f:
    data = [json.loads(line) for line in f if line.strip()]

In [7]:
# Build one seq2seq training example per record in `data`.
#
# Each record is read through its "prompt" and "completion" keys. The prompt
# token ids become `input_ids`; the stripped completion's token ids, followed
# by the tokenizer's EOS id, become `labels`. Records whose prompt is longer
# than MAX_PROMPT_TOKENS are dropped and reported.
MAX_PROMPT_TOKENS = 2048  # prompts with more tokens than this are skipped

all_examples = []

for item in tqdm.tqdm(data):
    prompt = item["prompt"]
    answer = item["completion"].strip()

    # No special tokens are added automatically; only the answer receives an
    # explicit EOS below.
    # NOTE(review): the prompt is therefore encoded without a trailing EOS --
    # confirm that this is what the training code expects.
    prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)

    # Check the length before tokenizing the answer or allocating tensors, so
    # a skipped record costs no further work.
    if len(prompt_ids) > MAX_PROMPT_TOKENS:
        # tqdm.write prints the message without breaking up the progress bar,
        # unlike a bare print() inside the loop.
        tqdm.tqdm.write(f"skipping {len(prompt_ids)}")
        continue

    answer_ids = tokenizer.encode(answer, add_special_tokens=False) + [tokenizer.eos_token_id]

    input_ids = torch.tensor(prompt_ids, dtype=torch.long)
    answer_ids = torch.tensor(answer_ids, dtype=torch.long)

    all_examples.append(
        dict(
            input_ids=input_ids,
            labels=answer_ids,
            # All-True boolean mask with the same shape as input_ids: nothing
            # is padded at this stage.
            attention_mask=torch.ones_like(input_ids, dtype=torch.bool),
        )
    )

  2%|▏         | 117/6203 [00:00<00:15, 388.07it/s]

skipping 2196


  6%|▌         | 373/6203 [00:00<00:14, 411.52it/s]

skipping 2116


  8%|▊         | 498/6203 [00:01<00:14, 397.91it/s]

skipping 2069
skipping 2191
skipping 2106
skipping 2096
skipping 2134


 10%|█         | 622/6203 [00:01<00:13, 402.81it/s]

skipping 2205
skipping 2119


 13%|█▎        | 789/6203 [00:01<00:13, 412.20it/s]

skipping 2116


 17%|█▋        | 1050/6203 [00:02<00:12, 417.30it/s]

skipping 2238


 20%|██        | 1257/6203 [00:03<00:12, 395.76it/s]

skipping 2114
skipping 2097
skipping 2068


 22%|██▏       | 1337/6203 [00:03<00:12, 379.14it/s]

skipping 2135


 24%|██▎       | 1462/6203 [00:03<00:11, 396.76it/s]

skipping 2110
skipping 2060


 26%|██▌       | 1626/6203 [00:04<00:11, 390.52it/s]

skipping 2191
skipping 2183


 28%|██▊       | 1750/6203 [00:04<00:11, 403.65it/s]

skipping 2130


 30%|███       | 1870/6203 [00:04<00:11, 385.74it/s]

skipping 2106
skipping 2289
skipping 2110


 32%|███▏      | 1989/6203 [00:04<00:10, 387.57it/s]

skipping 2098


 35%|███▍      | 2156/6203 [00:05<00:10, 397.17it/s]

skipping 2104


 37%|███▋      | 2277/6203 [00:05<00:09, 393.76it/s]

skipping 2199
skipping 2100


 38%|███▊      | 2357/6203 [00:05<00:09, 390.98it/s]

skipping 2284


 41%|████      | 2528/6203 [00:06<00:09, 397.67it/s]

skipping 2055
skipping 2098


 43%|████▎     | 2686/6203 [00:06<00:09, 375.93it/s]

skipping 2291
skipping 2066


 45%|████▌     | 2805/6203 [00:07<00:08, 387.46it/s]

skipping 2071


 47%|████▋     | 2891/6203 [00:07<00:08, 407.70it/s]

skipping 2163


 48%|████▊     | 2973/6203 [00:07<00:08, 399.94it/s]

skipping 2189
skipping 2135


 59%|█████▊    | 3629/6203 [00:09<00:06, 382.08it/s]

skipping 2196


 61%|██████    | 3788/6203 [00:09<00:06, 384.01it/s]

skipping 2223


 66%|██████▌   | 4070/6203 [00:10<00:05, 381.00it/s]

skipping 2187
skipping 2223
skipping 2070


 69%|██████▊   | 4263/6203 [00:10<00:05, 371.15it/s]

skipping 2188


 74%|███████▍  | 4617/6203 [00:11<00:04, 388.97it/s]

skipping 2200


 77%|███████▋  | 4773/6203 [00:12<00:03, 381.30it/s]

skipping 2167
skipping 2069


 86%|████████▌ | 5340/6203 [00:13<00:02, 374.75it/s]

skipping 2092


 89%|████████▊ | 5501/6203 [00:14<00:01, 379.44it/s]

skipping 2199
skipping 2195


 91%|█████████ | 5617/6203 [00:14<00:01, 369.73it/s]

skipping 2222
skipping 2292
skipping 2100
skipping 2089


 92%|█████████▏| 5692/6203 [00:14<00:01, 362.01it/s]

skipping 2094


 94%|█████████▎| 5802/6203 [00:14<00:01, 349.93it/s]

skipping 2110
skipping 2110


 97%|█████████▋| 6016/6203 [00:15<00:00, 345.92it/s]

skipping 2168


 99%|█████████▉| 6166/6203 [00:15<00:00, 363.87it/s]

skipping 2166


100%|██████████| 6203/6203 [00:16<00:00, 386.60it/s]


In [9]:
# Serialize the list of example dicts to disk with pickle.
dataset_path = "./t5_finetune_dataset.pkl"
with open(dataset_path, mode="wb") as out_file:
    pickle.dump(all_examples, out_file)

In [10]:
# Scratch check: slicing with [:-1] returns the list without its last element.
[1, 2, 3][:-1]

[1, 2]