In [None]:
import os
from typing import BinaryIO, Callable, Optional

In [None]:
def find_chunk_boundaries(
    file: BinaryIO,
    desired_num_chunks: int,
    split_special_token: bytes,
    mini_chunk_size: int = 4096,
) -> list[int]:
    """
    Chunk the file into parts that can be counted independently.
    May return fewer chunks if the boundaries end up overlapping.

    Args:
        file: Seekable binary file object. Its position is changed by this call.
        desired_num_chunks: Upper bound on the number of chunks (must be >= 1).
        split_special_token: Byte string that every interior boundary is
            aligned to (a boundary is placed at the first byte of a match).
        mini_chunk_size: Number of bytes read per scan step while searching
            forward for the token (must be >= 1).

    Returns:
        Sorted, de-duplicated byte offsets. The first is 0 and the last is the
        file size; consecutive pairs delimit the chunks.

    Raises:
        ValueError: If ``desired_num_chunks`` or ``mini_chunk_size`` is < 1.
    """
    assert isinstance(split_special_token, bytes), (
        "Must represent special token as a bytestring"
    )
    if desired_num_chunks < 1:
        raise ValueError("desired_num_chunks must be at least 1")
    if mini_chunk_size < 1:
        raise ValueError("mini_chunk_size must be at least 1")

    # Get total file size in bytes
    file.seek(0, os.SEEK_END)
    file_size = file.tell()
    file.seek(0)

    chunk_size = file_size // desired_num_chunks

    # Initial guesses for chunk boundary locations, uniformly spaced
    # Chunks start on previous index, don't include last index
    chunk_boundaries = [i * chunk_size for i in range(desired_num_chunks + 1)]
    chunk_boundaries[-1] = file_size

    # A match can start in one read and finish in the next. Keeping the last
    # len(token) - 1 bytes of the previous window makes such a match visible
    # without ever letting the carried bytes contain a full match themselves.
    overlap = max(len(split_special_token) - 1, 0)

    for bi in range(1, len(chunk_boundaries) - 1):
        position = chunk_boundaries[bi]  # Offset of the next unread byte
        file.seek(position)  # Start at boundary guess
        carry = b""
        while True:
            mini_chunk = file.read(mini_chunk_size)  # Read a mini chunk

            # If EOF, this boundary should be at the end of the file
            if mini_chunk == b"":
                chunk_boundaries[bi] = file_size
                break

            # Search the carried tail plus the fresh bytes
            window = carry + mini_chunk
            found_at = window.find(split_special_token)
            if found_at != -1:
                # `window` starts len(carry) bytes before `position`
                chunk_boundaries[bi] = position - len(carry) + found_at
                break

            # Advance by what was actually read (reads may be short)
            position += len(mini_chunk)
            carry = window[-overlap:] if overlap else b""

    # Make sure all boundaries are unique, but might be fewer than desired_num_chunks
    return sorted(set(chunk_boundaries))

In [None]:
# Scratch cell: load only the FIRST token-aligned chunk of the sample file
# into the global `chunk` for use by the cells below.
# `regex` is imported under the name `re`; later cells rely on that alias.
import regex as re
num_processes=4
# NOTE(review): absolute, machine-specific path.
input_path='/Users/zhanghao1/Downloads/cs336/assignment1-basics/tests/fixtures/tinystories_sample_5M.txt'
with open(input_path,"rb") as f:
    # Byte offsets at which the file is split, aligned to "<|endoftext|>".
    boundaries = find_chunk_boundaries(
        f, num_processes, "<|endoftext|>".encode("utf-8"))
    for start, end in zip(boundaries[:-1], boundaries[1:]):
        f.seek(start)
        # Undecodable bytes are dropped rather than raising.
        chunk = f.read(end - start).decode("utf-8", errors="ignore")
        print(1)
        # Stop after the first (start, end) pair; if `boundaries` has fewer
        # than two entries the loop body never runs and `chunk` is not set.
        break

In [None]:
chunk


In [None]:
special_tokens=["<|endoftext|>","<|endoftext|>"]
s="|".join([re.escape(item) for item in special_tokens])
s

In [None]:

re.split(s,chunk)

In [None]:
# Count positions where the character 'i' is immediately followed by 'n'
# in the global `chunk` (defined by an earlier cell).
cnt=0
for i in range(len(chunk)-1):
    if chunk[i]=='i' and chunk[i+1]=='n':
        cnt+=1
# Bare expression so the notebook displays the count.
cnt

In [None]:
# Split `chunk` with a regex whose alternatives match, in order: common
# English contraction suffixes, letter runs, digit runs, runs of other
# non-space characters (each optionally preceded by one space), and whitespace.
# `\p{L}` / `\p{N}` require the third-party `regex` module, which an earlier
# cell imports under the name `re`.
pat=r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
# Maps each matched piece (as UTF-8 bytes) to its number of occurrences.
word_cnt={}
iterlist=re.finditer(pat,chunk)
for match in iterlist:
    # NOTE: rebinds the global `s`, which an earlier cell used for the joined
    # special-token pattern.
    s=match.group().encode("utf-8")
    if s in word_cnt:
        word_cnt[s]+=1
    else:
        word_cnt[s]=1

In [None]:
word_cnt

In [None]:
# Inspect a pickled test snapshot: print its keys, then list every merge whose
# left or right element contains a newline byte.
import pickle
# NOTE(review): absolute, machine-specific path.
file_path = "/Users/zhanghao1/Downloads/cs336/assignment1-basics/tests/_snapshots/test_train_bpe_special_tokens.pkl"

# SECURITY: pickle.load can execute arbitrary code; only use on trusted files.
with open(file_path, "rb") as f:
    data = pickle.load(f)
print(data.keys())
reference_merges=data['merges']
for idx,item in enumerate(reference_merges):
    # item[0] / item[1] are tested with a bytes `in` check, so each merge is
    # presumably a pair of byte strings -- confirm against the snapshot.
    if b"\n" in item[0] or b'\n' in item[1]:
        print(idx,item)

In [None]:
for idx,item in enumerate(data['vocab_values']):
    if b'\n' in item:
        print(idx,item)

In [None]:
ord(b'\n')

In [None]:
def replace_pair(lst, pair, new_val):
    """Return a copy of ``lst`` with each occurrence of ``pair`` merged.

    The sequence is scanned left to right. Whenever two adjacent elements equal
    ``pair[0]`` and ``pair[1]`` they are replaced by the single value
    ``new_val`` and scanning resumes after them, so matches never overlap
    (``[x, x, x]`` with pair ``(x, x)`` becomes ``[new_val, x]``).

    Args:
        lst: Indexable sequence of items to scan; it is not modified.
        pair: Two-element iterable ``(first, second)`` to look for.
        new_val: Value substituted for each matched pair.

    Returns:
        A new list with the replacements applied.
    """
    a, b = pair
    res = []
    i = 0
    n = len(lst)
    while i < n:
        # `i + 1 < n` guards the look-ahead on the final element.
        if i + 1 < n and lst[i] == a and lst[i + 1] == b:
            res.append(new_val)
            i += 2
        else:
            res.append(lst[i])
            i += 1
    return res
lst=[97,97,97,97,97]+[97,97,97,97,97]
pair=(97,97)
replace_pair(lst,pair,257)

In [None]:
import numpy as np
str_tuples=[(b"t",b"h"),(b" c",b"om")]
np.lexsort(np.array(str_tuples).T)

In [None]:
import time
def find_max_pairs_vec(pairs_cnt,vocab):
    """Return the most frequent pair and its count, using NumPy for the scan.

    Ties on the count are broken by choosing the pair whose
    ``(vocab[pair[0]], vocab[pair[1]])`` tuple is greatest, the same ordering
    that ``find_max_pairs`` uses.

    Args:
        pairs_cnt: Mapping from ``(id_a, id_b)`` pairs to their counts.
        vocab: Mapping from id to a comparable value (e.g. bytes or str).

    Returns:
        ``(pair, count)``. For an empty ``pairs_cnt`` the sentinel
        ``((-1, -1), -1)`` is returned, matching ``find_max_pairs``.
    """
    if not pairs_cnt:
        return (-1, -1), -1

    # Derive everything from the argument instead of module-level globals so
    # the result always reflects the mapping that was passed in.
    pair_list = list(pairs_cnt.keys())
    counts = np.asarray(list(pairs_cnt.values()))

    max_val = counts.max()
    tied = np.flatnonzero(counts == max_val)

    # Choose the lexicographically greatest pair. This is done in plain Python:
    # the tied set is normally tiny, and fixed-width NumPy string arrays do not
    # preserve trailing NUL bytes, which could change the order of byte tokens.
    best = max(
        (pair_list[j] for j in tied),
        key=lambda pr: (vocab[pr[0]], vocab[pr[1]]),
    )

    # .item() converts the NumPy scalar back to a plain Python number.
    return best, max_val.item()

def find_max_pairs(pairs_cnt,vocab):
    """Scan ``pairs_cnt`` once and return ``(pair, count)`` for the top pair.

    A pair wins if its count is strictly larger than the best so far, or if the
    count is equal and ``(vocab[pair[0]], vocab[pair[1]])`` compares greater
    than that of the current best. With an empty mapping the initial sentinel
    ``((-1, -1), -1)`` is returned.
    """
    best_pair = (-1, -1)
    best_cnt = -1
    best_key = None
    for pair, cnt in pairs_cnt.items():
        if cnt > best_cnt:
            wins = True
        elif cnt == best_cnt:
            # Only reached on a tie, so the vocab lookup is skipped otherwise.
            wins = (vocab[pair[0]], vocab[pair[1]]) > best_key
        else:
            wins = False
        if wins:
            best_cnt = cnt
            best_pair = pair
            best_key = (vocab[pair[0]], vocab[pair[1]])
    return best_pair, best_cnt

# Micro-benchmark: time 10 calls of each max-pair finder on a synthetic
# 5,000,000-entry table.
pairs_cnt={}
vocab={}
for i in range(5000000):
    # Pair (i, i) has count i, so the counts are all distinct.
    pairs_cnt[(i,i)]=i
    vocab[i]=str(i)
t=time.time()
for i in range(10):
    find_max_pairs(pairs_cnt,vocab)
print(time.time()-t)
# Array copies of the table's keys and values; built before the second timer
# starts, so their construction cost is not part of the second measurement.
keys = np.array(list(pairs_cnt.keys()))
values = np.array(list(pairs_cnt.values()))
t=time.time()
for i in range(10):
    find_max_pairs_vec(pairs_cnt,vocab)
print(time.time()-t)

In [None]:
keys

In [4]:
from collections import Counter
dic=Counter()
dic[(3,5)]-=1

In [5]:
dic

Counter({(3, 5): -1})

In [3]:
del dic[(3,5)]

In [1]:
import pickle

# Load the saved vocabulary and merges from disk.
# SECURITY: pickle.load can execute arbitrary code; only use on trusted files.
with open("../owt_vocab.pkl", "rb") as f:
    vocab = pickle.load(f)

with open("../owt_merges.pkl", "rb") as f:
    merges = pickle.load(f)

In [2]:
vocab[256]

b'<|endoftext|>'

In [9]:
# Find the longest entry in `vocab`: `s` ends up holding the entry, `max_k` its
# key and `max_len` its length. Because the comparison is strict, the first
# entry encountered wins among entries of equal length.
max_len=0
s=b""
max_k=0
for k,item in vocab.items():
    if len(item)>max_len:
        max_len=len(item)
        s=item
        max_k=k


In [10]:
print(s,max_k)

b' accomplishment' 7160


In [11]:
max([1,3,3,4,3,16])

16

In [12]:
for i in b"s":
    print(i)

115


In [16]:
int.to_bytes(115)

b's'

In [2]:
special_tokens=['a','aa']
special_tokens=sorted(special_tokens, key=len, reverse=True)
special_tokens

['aa', 'a']

In [33]:
# Measure the encoding throughput (UTF-8 bytes per second) of tiktoken's
# "gpt2" encoding on a validation text file.
import tiktoken
import time
reference_tokenizer = tiktoken.get_encoding("gpt2")
# NOTE(review): no `encoding=` is passed, so the platform default is used.
with open("../data/TinyStoriesV2-GPT4-valid.txt","r") as f:
    data=f.read()
# Size of the text in bytes once re-encoded as UTF-8.
b_cnt=len(data.encode('utf-8'))
t=time.time()
# The token ids are discarded; only the elapsed time matters here.
reference_tokenizer.encode(data,allowed_special={'<|endoftext|>'})
t=time.time()-t
print(t)
# Bare expression so the notebook displays bytes per second.
b_cnt/t

1.7577948570251465


12801608.168363236

In [46]:
reference_tokenizer.n_vocab

50257

In [None]:
token
with open("../data/owt_valid.txt") as f:
    for id in f:
        break
item

<regex.Match object; span=(14, 14), match=''>

In [95]:
import torch
from einops import rearrange,einsum
d_k=20
theta=10000

In [105]:
# Precompute the cos/sin tables of rotation angles for every
# (position, frequency) combination.
max_seq_len=12
# One inverse frequency theta ** (-k / d_k) for each even k in [0, d_k).
theta_list=1.0 / (theta ** (torch.arange(0, d_k, 2).float() / d_k))  # (dim/2,)
# Positions as a column (max_seq_len, 1) broadcast against the frequency
# vector: t[i, j] = i * theta_list[j].
t=torch.arange(max_seq_len).unsqueeze(1)*theta_list
cos=torch.cos(t)
sin=torch.sin(t)

In [62]:
x_t=rearrange(t,"... (a b) -> ... a b",a=10,b=2)
x_t.shape

torch.Size([2, 3, 10, 2])

In [79]:
# Split the trailing axis of size 2 into the two components of each pair.
even=x_t[...,0]
odd=x_t[...,1]
# Rotate every (even, odd) pair by its angle using the standard RoPE matrix
#   [[cos, -sin],
#    [sin,  cos]]
# The previous second row, -cos*even + cos*odd, was not a rotation.
out=torch.concat([(cos*even-sin*odd).unsqueeze(-1),(sin*even+cos*odd).unsqueeze(-1)],dim=-1)
# Interleave the rotated components back into a single feature axis.
out=rearrange(out,"... a b -> ... (a b)")

In [100]:
theta_list

tensor([ 500., 1500., 2500., 3500., 4500., 5500., 6500., 7500., 8500., 9500.])

In [107]:
torch.arange(0, d_k, 2)

tensor([ 0,  2,  4,  6,  8, 10, 12, 14, 16, 18])

In [120]:
import torch
from einops import einsum
x=torch.randn(2,5,3)
mask=torch.ones(5,3).to(dtype=torch.bool)
x+mask*torch.inf

tensor([[[inf, inf, inf],
         [inf, inf, inf],
         [inf, inf, inf],
         [inf, inf, inf],
         [inf, inf, inf]],

        [[inf, inf, inf],
         [inf, inf, inf],
         [inf, inf, inf],
         [inf, inf, inf],
         [inf, inf, inf]]])

In [125]:
torch.tril(torch.ones(5,5)).to(dtype=torch.bool)

tensor([[ True, False, False, False, False],
        [ True,  True, False, False, False],
        [ True,  True,  True, False, False],
        [ True,  True,  True,  True, False],
        [ True,  True,  True,  True,  True]])

In [150]:
context_length=1024
d_model=1600
d_ff=6400
vocab_size=50257
num_heads=25
num_layers=48

def cal_flops(
    context_length=1024,
    d_model=1600,
    d_ff=6400,
    vocab_size=50257,
    num_heads=25,
    num_layers=48,
):
    """Estimate matmul FLOPs for one forward pass over a single sequence.

    Every (m x k) @ (k x n) product is counted as ``2 * m * k * n`` FLOPs.
    The head width is ``d_model // num_heads`` (floor division).

    Returns:
        A 4-tuple ``(attention_flops, ffn_flops, lm_head_flops, total_flops)``
        where the first two are summed over all ``num_layers`` layers and the
        total is their sum plus the output projection.
    """
    seq = context_length
    head_dim = d_model // num_heads

    # Three d_model x d_model projections (presumably Q, K and V).
    qkv_proj = 3 * 2 * d_model * seq * d_model
    # Per head: a (seq x head_dim) @ (head_dim x seq) product ...
    scores = num_heads * 2 * seq * head_dim * seq
    # ... followed by a (seq x seq) @ (seq x head_dim) product.
    weighted = num_heads * 2 * seq * seq * head_dim
    # Final d_model x d_model projection.
    out_proj = 2 * seq * d_model * d_model
    attention_per_layer = qkv_proj + scores + weighted + out_proj

    # 6 = 3 matmuls of size d_model x d_ff, at 2 FLOPs per multiply-add.
    ffn_per_layer = 6 * d_ff * d_model * seq

    head = 2 * seq * d_model * vocab_size

    attention_all = attention_per_layer * num_layers
    ffn_all = ffn_per_layer * num_layers
    everything = (attention_per_layer + ffn_per_layer) * num_layers + head
    return attention_all, ffn_all, head, everything

cal_flops()

(1328755507200, 3019898880000, 164682137600, 4513336524800)

In [159]:
# GPT-2 small (12 layers, d_model=768, 12 heads), GPT-2 medium (24 layers, d_model=1024, 16 heads), and GPT-2 large
items=cal_flops(context_length=16384)
[item/1e12 for item in items]

[98.5694994432, 48.31838208, 2.6349142016, 149.5227957248]

In [181]:
x=torch.randn(5,10)

In [185]:
ind=torch.arange(5,10,dtype=torch.int64).unsqueeze(1)

torch.gather(x,index=ind,dim=1).squeeze(1)

tensor([-0.6428, -0.0197,  0.3973,  0.2688, -1.0461])

In [None]:
class adamw(torch.optim.Optimizer):
    """AdamW: Adam with decoupled weight decay.

    Args:
        weights: Iterable of parameters, or of parameter-group dicts, as
            accepted by ``torch.optim.Optimizer``.
        lr: Learning rate.
        beta_1: Decay rate of the first-moment (mean) estimate.
        beta_2: Decay rate of the second-moment (uncentred variance) estimate.
        lamb: Weight-decay coefficient.
        eps: Small constant added to the denominator for numerical stability.

    Raises:
        ValueError: If ``lr`` is negative.
    """

    def __init__(self,weights,lr,beta_1,beta_2,lamb,eps=1e-8):
        if lr < 0:
            raise ValueError(f"Invalid learning rate: {lr}")
        # Hyper-parameters live in `defaults` so each parameter group can
        # override them; the base class also materialises `weights` exactly
        # once, which matters when a generator is passed.
        defaults = dict(lr=lr, beta_1=beta_1, beta_2=beta_2, lamb=lamb, eps=eps)
        super().__init__(weights, defaults)
        # Kept as attributes for convenient inspection.
        self.beta_1=beta_1
        self.beta_2=beta_2
        self.lamb=lamb
        self.eps=eps

    def step(self,closure: Optional[Callable]=None):
        """Perform one optimisation step.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if none was given.
        """
        loss = None
        if closure is not None:
            # The closure may need autograd even if the caller disabled it.
            with torch.enable_grad():
                loss = closure()

        with torch.no_grad():
            for group in self.param_groups:
                lr = group["lr"]
                beta_1 = group["beta_1"]
                beta_2 = group["beta_2"]
                lamb = group["lamb"]
                eps = group["eps"]
                for p in group["params"]:
                    if p.grad is None:
                        continue  # Parameter took no part in the backward pass
                    grad = p.grad

                    # Moment buffers are created lazily, per parameter.
                    state = self.state[p]
                    if len(state) == 0:
                        state["step"] = 0
                        state["m"] = torch.zeros_like(p)
                        state["v"] = torch.zeros_like(p)
                    state["step"] += 1
                    step_no = state["step"]
                    m, v = state["m"], state["v"]

                    # Decoupled weight decay: shrink the weights directly
                    # instead of adding lamb * p to the gradient.
                    p.mul_(1 - lr * lamb)

                    # Exponential moving averages of the gradient and its square.
                    m.mul_(beta_1).add_(grad, alpha=1 - beta_1)
                    v.mul_(beta_2).addcmul_(grad, grad, value=1 - beta_2)

                    # Bias correction for the zero-initialised averages.
                    m_hat = m / (1 - beta_1 ** step_no)
                    v_hat = v / (1 - beta_2 ** step_no)

                    p.addcdiv_(m_hat, v_hat.sqrt().add_(eps), value=-lr)
        return loss

In [1]:
# Demo: minimise mean(weights**2) with plain SGD for 10 steps and print the
# loss before each update.
import torch
from torch.optim import SGD
from torch import nn
# No seed is set, so the starting weights differ between runs.
weights=nn.Parameter(5*torch.randn((10,10)))
# With lr=1e3 the recorded output below grows at every step, i.e. the run
# diverges.
opt=SGD([weights],lr=1e3)
for t in range(10):
    opt.zero_grad()
    loss=(weights**2).mean()
    print(loss.cpu().item())
    loss.backward()
    opt.step()



26.056364059448242
9406.3466796875
3395691.0
1225844352.0
442529775616.0
159753254731776.0
5.767092247514317e+16
2.0819199895380427e+19
7.515730907145636e+21
2.7131791119329537e+24


In [178]:
ind

tensor([[5, 6, 7, 8, 9]])