<a href="https://colab.research.google.com/github/planctao/data-flower/blob/main/llm_from_scratch_ch02_self.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import urllib.request
import torch
from torch.utils.data import Dataset, DataLoader
import tiktoken

In [2]:
# 下载 数据集
if not os.path.exists("the-verdict.txt"):
    url = ("https://raw.githubusercontent.com/rasbt/"
           "LLMs-from-scratch/main/ch02/01_main-chapter-code/"
           "the-verdict.txt")
    file_path = "the-verdict.txt"
    urllib.request.urlretrieve(url, file_path)

In [None]:
# 分词

# 读取文本
with open('the-verdict.txt','r') as f:
    raw_text = f.read()

In [None]:
import re
def split_sentence(text):
    result = re.split(r'([,.:;?_!"()\']|--|\s)', text)
    result = [item.strip() for item in result if item.strip()]
    return result

In [None]:
split_list = split_sentence(raw_text) # 变成单个单词的形式

In [None]:
len(split_list)

4690

In [None]:
sorted_set = sorted(set(split_list)) # 对其进行set并去重，构建token列表
len(sorted_set)

def make_token_dict(sorted_set): # 通过构建一个dict，构建ID->token以及token->ID的映射列表
    ret1 = {}
    ret2 = {}
    for i, voc in enumerate(sorted_set):
        ret1[i] = voc
        ret2[voc] = i
    return ret1, ret2

ID2voc, voc2ID = make_token_dict(sorted_set=sorted_set)

for i in range(0, 20):
    print(ID2voc[i]," : ",i)

!  :  0
"  :  1
'  :  2
(  :  3
)  :  4
,  :  5
--  :  6
.  :  7
:  :  8
;  :  9
?  :  10
A  :  11
Ah  :  12
Among  :  13
And  :  14
Are  :  15
Arrt  :  16
As  :  17
At  :  18
Be  :  19


In [None]:
class Tokenizer1:
    def __init__(self, id2voc, voc2id):
        self.ID2voc = id2voc
        self.voc2ID = voc2ID
    def encode(self, texts):
        def split_sentence(text):
            result = re.split(r'([,.:;?_!"()\']|--|\s)', text)
            result = [item.strip() for item in result if item.strip()]
            return result
        texts = split_sentence(texts)
        return [voc2ID[text] for text in texts]
    def decode(self, ids):
        return [ID2voc[id] for id in ids]


In [5]:
class GPT2DatasetV1(Dataset):
    def __init__(self, txt, max_length, stride, tokenizer):
        ids = tokenizer.encode(txt) #
        self.input_ids = []
        self.target_ids = []
        for i in range(0, len(ids) - max_length, stride):
            x = torch.tensor(ids[i: i + max_length])
            y = torch.tensor(ids[i+1: i + max_length + 1])
            self.input_ids.append(x)
            self.target_ids.append(y)
    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.target_ids[idx]

with open("the-verdict.txt", "r", encoding="utf-8") as f:
    raw_text = f.read()
def create_dataloader_v1(txt, batch_size=4, max_length=256, stride=128, shuffle=True, drop_last=True, num_workers=0):

    tokenizer = tiktoken.get_encoding("gpt2")
    dataset = GPT2DatasetV1(
        txt=raw_text, max_length=256, stride=1, tokenizer=tokenizer
    )
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
        drop_last=drop_last,
        num_workers=0
    )
    return dataloader
dataloader = create_dataloader_v1(
    txt = raw_text,
    batch_size=1,
    max_length=4,
    stride=1,
    shuffle=True,
    drop_last=True,
    num_workers=1
)

data_iter = iter(dataloader)
first_batch = next(data_iter)
print(first_batch)

[tensor([[  550, 11001,    11,   262,  2756,   286,   366,    38,   271, 10899,
            82,     1,  1816,   510,    13,   198,   198,  1026,   373,   407,
         10597,  1115,   812,  1568,   326,    11,   287,   262,  1781,   286,
           257,  1178,  2745,     6,  4686,  1359,   319,   262, 34686, 41976,
            11,   340,  6451,  5091,   284,   502,   284,  4240,  1521,   402,
           271, 10899,   550,  1813,   510,   465, 12036,    13,  1550, 14580,
            11,   340,  1107,   373,   257, 29850,  1917,    13,  1675, 24456,
           465,  3656,   561,   423,   587,  1165,  2562,   438, 14363,  3148,
          1650,  1010,   550,   587,  6699,   262,  1540,   558,   286,  2282,
           326,  9074,    13,   402,   271, 10899,   550,   366,  7109, 14655,
           683,   866,   526,  1114,  9074,    13,   402,   271, 10899,   438,
           292,   884,   438, 18108,   407, 11196, 10597,  3016,   257,   614,
           706,  3619,   338, 10568,   550,   587, 