In [1]:
# import requests
# input_txt_href = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
# text = requests.get(input_txt_href).text

with open('/kaggle/input/wodejingshenjiayuan/.txt', 'r') as f:
    text = f.read()

chars = sorted(list( set(text)))
vocab_size = len(chars)
s2i = {c:i for i,c in enumerate(chars)}
i2s = {i:c for i,c in enumerate(chars)}
encode = lambda s: [s2i[c] for c in s]
decode = lambda l: ''.join([i2s[i] for i in l])

data = encode(text)
n = int(len(data) * 0.9)
train_data = data[:n]
val_data = data[n:]

n,vocab_size

(108611, 2748)

In [2]:
import torch
from torch import nn
from torch.nn import functional as F

# torch.manual_seed(317)

def get_batch(data, batch_size = 4, block_size = 8):
    idx = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.tensor([data[i: i + block_size] for i in idx],dtype=torch.long)
    y = torch.tensor([data[i+1:i+1 + block_size] for i in idx],dtype=torch.long)
    return x,y

get_batch(train_data, 2)

(tensor([[ 166,  120, 2249, 1155,   82,  804, 2013, 1949],
         [2031,  849,   89,   82, 2304,  455, 2316,  606]]),
 tensor([[ 120, 2249, 1155,   82,  804, 2013, 1949,  283],
         [ 849,   89,   82, 2304,  455, 2316,  606, 1343]]))

In [3]:
# 多头
class Head(nn.Module):
    def __init__(self, n_embd, head_embed,dropout):
        super().__init__()
        self.key = nn.Linear(n_embd,head_embed,bias=False)
        self.query = nn.Linear(n_embd,head_embed,bias=False)
        self.value = nn.Linear(n_embd,head_embed,bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_x):
        C = input_x.size(-1)
        k = self.key(input_x)
        q = self.query(input_x)
        weight = q @ k.transpose(-2,-1) * C ** -0.5

        T = weight.size(-1)
        tril = torch.tril(torch.ones(T,T))
        weight = weight.masked_fill(tril == 0, float('-inf'))
        v = self.value(input_x)
        weight = weight.softmax(dim=-1)
        weight = self.dropout(weight)
        out = weight @ v
        return out

class MultiHead(nn.Module):
    def __init__(self, num_heads, n_embd, head_embd,dropout):
        super().__init__()
        self.norm = nn.LayerNorm(n_embd)
        self.heads = nn.ModuleList([Head(n_embd,head_embd,dropout) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd,n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        input = self.norm(x)
        out = torch.cat([head(input) for head in self.heads], dim=-1)
        out = self.proj(out)
        out = self.dropout(out)
        return out

class Block(nn.Module):
    def __init__(self,n_embd,num_heads,dropout):
        super().__init__()
        self.sa_heads = MultiHead(num_heads,n_embd,n_embd//num_heads,dropout)
        self.feed_forward = nn.Sequential(
            nn.LayerNorm(n_embd),
            nn.Linear(n_embd,4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd,n_embd),
            nn.Dropout(dropout)
        )

    def forward(self,x):
        x = x + self.sa_heads(x)
        x = x + self.feed_forward(x)
        return x

In [4]:
class BingramLanguageModel(nn.Module):
    def __init__(self, vocab_size, block_size, n_embd,num_heads,dropout,num_block):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, n_embd)
        self.position_embedding = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(
            *[Block(n_embd, num_heads, dropout) for _ in range(num_block)],
            nn.LayerNorm(n_embd),
        )
        self.lm_head = nn.Linear(n_embd, vocab_size)
        self.block_size = block_size

    def forward(self, idx, targets=None):
        token_emb = self.embedding(idx)
        pos_emb = self.position_embedding(torch.arange(idx.size(-1)))
        x = token_emb + pos_emb
        x = self.blocks(x)
        logits = self.lm_head(x)
        # print(logits.shape)
        if targets != None:
            logits = logits.view(-1, logits.size(-1))
            targets = targets.view(-1)
            loss = F.cross_entropy(logits, targets)
        else:
            loss = None
        return logits, loss

    def generate(self, idx, max_len):
        for _ in range(max_len):
            logits, loss = self(idx[:, -self.block_size:])
            # print(logits.shape)
            logits = logits[:,-1,:]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs,num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

In [5]:
batch_size = 4096
block_size = 40
train_steps = int(1e4)
val_steps = train_steps / 100
n_embd = 64
num_heads=8
dropout = 0.2
num_block = 4

if torch.cuda.is_available(): 
    torch.set_default_device('cuda')
    print('device: cuda')
model = BingramLanguageModel(vocab_size,block_size,n_embd,num_heads,dropout,num_block)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

@torch.no_grad()
def estimate_loss(batch_size,block_size):
    model.eval()
    x, y = get_batch(val_data, batch_size, block_size)
    _, loss = model(x, y)
    model.train()
    return loss

def test_generate():
    test_idx = torch.tensor([encode('我在北京的街头看到')], dtype=torch.long)
    test_logits = model.generate(test_idx, max_len=block_size)
    print(decode(test_logits[0].tolist()))

for steps in range(train_steps):
    x, y = get_batch(train_data, batch_size, block_size)
    _, loss = model(x, y)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    if steps % val_steps == 0: 
        val_loss = estimate_loss(batch_size,block_size)
        print(f"step {steps}: train loss {loss.item()} , validate loss {val_loss.item()}")
        if steps % (10 * val_steps) == 0:
            test_generate()
test_generate()
torch.save(model, 'model.pth')


device: cuda
step 0: train loss 8.074625015258789 , validate loss 7.977770805358887
我在北京的街头看到况粗普察业士谬巨聊齿藐甘炸胥侧版捕眠学牛师剔释告码惯律硝罩判仁煺如案鳄珍千拍者罢
step 100: train loss 5.6737518310546875 , validate loss 5.836033821105957
step 200: train loss 4.743711471557617 , validate loss 5.2092156410217285
step 300: train loss 4.298342704772949 , validate loss 5.048758029937744
step 400: train loss 4.038207054138184 , validate loss 4.991559028625488
step 500: train loss 3.826084613800049 , validate loss 4.99304723739624
step 600: train loss 3.7003636360168457 , validate loss 5.060132026672363
step 700: train loss 3.5778374671936035 , validate loss 5.095638751983643
step 800: train loss 3.465257167816162 , validate loss 5.190592288970947
step 900: train loss 3.367635726928711 , validate loss 5.256810188293457
step 1000: train loss 3.2895050048828125 , validate loss 5.291468620300293
我在北京的街头看到了骂；后一口也不安翁一面，我看不给她之处。据说实，我就喜欢把这四这两饭吃凉来面，
step 1100: train loss 3.2233943939208984 , validate loss 5.390722751617432
step 1