In [3]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import random
import re
from tokenizers import ByteLevelBPETokenizer
import time
import math
import matplotlib.pyplot as plt
plt.switch_backend('agg')
import matplotlib.ticker as ticker
import numpy as np
random.seed(122090791)

## data preprocessing

In [None]:
with open('data/raw/partA/eng-cmn.txt', 'r', encoding='utf-8') as f:
    lines = f.readlines()


pairs = []
for line in lines:
    line = line.strip()
    if not line:
        continue
    parts = line.split('\t')
    if len(parts) != 3:
        # print(f"跳过无效行: {line}")
        continue
    en, zh, _ = parts
    
    # clean english text
    en_clean = re.sub(r'[^\w\s]', '', en.strip().lower()).strip()
    # clean chinese text
    zh_clean = re.sub(r'[^\w\s]', '', zh.strip()).strip()
    
    # skip empty lines
    if not en_clean or not zh_clean:
        # print(f"跳过空文本行: {line}")
        continue
    
    pairs.append((zh_clean, en_clean))  # chi -> eng

# set and filter pairs
seen = set()
cleaned_pairs = []
for pair in pairs:
    if pair not in seen:
        seen.add(pair)
        cleaned_pairs.append(pair)

max_length = 128
filtered_pairs = [
    (zh, en) for zh, en in cleaned_pairs
    if len(zh) <= max_length and len(en) <= max_length
]

# split dataset
random.shuffle(filtered_pairs)
total = len(filtered_pairs)
train_pairs = filtered_pairs[:int(0.9*total)]
test_pairs = filtered_pairs[int(0.9*total):]

# train tokenizer
tokenizer = ByteLevelBPETokenizer()

def corpus_iterator():
    for zh, en in train_pairs:
        yield zh  # chi input
        yield en  # eng target

tokenizer.train_from_iterator(
    iterator=corpus_iterator(),
    vocab_size=10000,
    min_frequency=2,
    special_tokens=["[SOS]", "[EOS]", "[PAD]", "[UNK]", "[SEP]"]
)

# generate token ids
SOS_ID = tokenizer.token_to_id("[SOS]")
SEP_ID = tokenizer.token_to_id("[SEP]")
EOS_ID = tokenizer.token_to_id("[EOS]")

def encode_pair(zh, en):
    # input: [SOS] chi [SEP] eng [EOS]
    input_text = f"[SOS]{zh}[SEP]{en}[EOS]"
    return tokenizer.encode(input_text).ids

def process_encodings(encodings):
    input_ids_list, target_ids_list, loss_mask_list = [], [], []
    for encoding in encodings:
        input_ids = encoding[:-1]
        target_ids = encoding[1:]
        sep_pos = encoding.index(SEP_ID) + 1
        loss_mask = [0] * len(input_ids)
        for i in range(sep_pos, len(input_ids)):
            loss_mask[i] = 1
        input_ids_list.append(input_ids)
        target_ids_list.append(target_ids)
        loss_mask_list.append(loss_mask)
    return input_ids_list, target_ids_list, loss_mask_list

train_encodings = [encode_pair(zh, en) for zh, en in train_pairs]
test_encodings = [encode_pair(zh, en) for zh, en in test_pairs]

train_inputs, train_targets, train_masks = process_encodings(train_encodings)
test_inputs, test_targets, test_masks = process_encodings(test_encodings)

## model

In [5]:
class GPTTranslator(nn.Module):
    def __init__(self, vocab_size, d_model=256, n_layers=4, n_heads=8):
        super().__init__()
        self.d_model = d_model
        
        # word embedding layer
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(1024, d_model)  # position embedding layer
        
        # Transformer Decoder层
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model, nhead=n_heads, batch_first=True
        )
        self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=n_layers)
        
        # output layer
        self.lm_head = nn.Linear(d_model, vocab_size)
        
    def forward(self, x, memory=None, tgt_mask=None):
        device = x.device
        batch_size, seq_len = x.shape
        
        # position encoding
        positions = torch.arange(0, seq_len, device=device).unsqueeze(0)
        pos_emb = self.pos_emb(positions)
        
        # word embedding + position encoding
        tok_emb = self.token_emb(x)
        x = tok_emb + pos_emb
        
        # generate mask
        if tgt_mask is None:
            tgt_mask = torch.triu(torch.ones(seq_len, seq_len) * float('-inf'), diagonal=1).to(device)
        
        memory = torch.zeros((batch_size, 1, self.d_model), device=device)

        # Transformer forward pass
        out = self.transformer(
            tgt=x,
            memory=memory,
            tgt_mask=tgt_mask,
        )
        
        # predict logits
        logits = self.lm_head(out)
        return logits

# get vocab size
vocab_size = tokenizer.get_vocab_size()


model = GPTTranslator(vocab_size=vocab_size)
# print(f"参数量: {sum(p.numel() for p in model.parameters())/1e6:.1f}M")

参数量: 11.7M


## loss fn

In [None]:
def masked_loss(logits, targets, mask):
    # flat
    logits_flat = logits.view(-1, logits.size(-1))
    targets_flat = targets.view(-1)
    mask_flat = mask.view(-1).bool()
    
    # calculate loss
    loss = F.cross_entropy(logits_flat, targets_flat, reduction='none')
    masked_loss = loss * mask_flat
    
    # average
    return masked_loss.sum() / mask_flat.sum()

## batch generator

In [None]:
def batch_generator(data_inputs, data_targets, data_masks, batch_size=32):
    indices = list(range(len(data_inputs)))
    random.shuffle(indices)
    
    for i in range(0, len(indices), batch_size):
        batch_indices = indices[i:i+batch_size]
        
        # find the max length of the current batch
        max_len = max(len(data_inputs[idx]) for idx in batch_indices)
        
        # fill the current batch with padding to the max length
        batch_inputs = []
        batch_targets = []
        batch_masks = []
        for idx in batch_indices:
            inputs = data_inputs[idx]
            targets = data_targets[idx]
            mask = data_masks[idx]
            
            # fill the current sequence with padding to the max length
            pad_len = max_len - len(inputs)
            pad_id = tokenizer.token_to_id("[PAD]")
            
            batch_inputs.append(inputs + [pad_id] * pad_len)
            batch_targets.append(targets + [pad_id] * pad_len)
            batch_masks.append(mask + [0] * pad_len)
        
        yield (
            torch.tensor(batch_inputs),
            torch.tensor(batch_targets),
            torch.tensor(batch_masks).float()
        )

## train

In [8]:
def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (- %s)' % (asMinutes(s), asMinutes(rs))


def showPlot(points):
    plt.figure()
    fig, ax = plt.subplots()
    # this locator puts ticks at regular intervals
    loc = ticker.MultipleLocator(base=0.2)
    ax.yaxis.set_major_locator(loc)
    plt.plot(points)
    plt.savefig('plot_gpt.png')
    plt.close()

In [9]:
n_epochs = 128
print_every = 5
plot_every = 5

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

start = time.time()
plot_losses = []
print_loss_total = 0  # Reset every print_every
plot_loss_total = 0  # Reset every plot_every


for epoch in range(1, n_epochs+1):
    model.train()
    total_loss = 0.0
    
    batch_iter = batch_generator(train_inputs, train_targets, train_masks, batch_size=32)
    
    for batch_inputs, batch_targets, batch_masks in batch_iter:
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)
        batch_masks = batch_masks.to(device)
        
        # forward pass
        logits = model(batch_inputs)
        
        # loss calculation
        loss = masked_loss(logits, batch_targets, batch_masks)
        
        # backpropagation and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    print_loss_total += total_loss/len(train_inputs)
    plot_loss_total += total_loss/len(train_inputs)

    if epoch % print_every == 0:
        print_loss_avg = print_loss_total / print_every
        print_loss_total = 0
        print('%s (%d %d%%) %.4f' % (timeSince(start, epoch / n_epochs),
                                        epoch, epoch / n_epochs * 100, print_loss_avg))
    
    if epoch % plot_every == 0:
        plot_loss_avg = plot_loss_total / plot_every
        plot_losses.append(plot_loss_avg)
        plot_loss_total = 0
    
showPlot(plot_losses)

  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)


1m 28s (- 36m 18s) (5 3%) 0.1471
3m 1s (- 35m 42s) (10 7%) 0.0989
4m 37s (- 34m 47s) (15 11%) 0.0658
6m 13s (- 33m 36s) (20 15%) 0.0413
7m 53s (- 32m 30s) (25 19%) 0.0248
9m 27s (- 30m 55s) (30 23%) 0.0150
11m 3s (- 29m 22s) (35 27%) 0.0100
12m 42s (- 27m 57s) (40 31%) 0.0077
14m 16s (- 26m 20s) (45 35%) 0.0063
15m 49s (- 24m 40s) (50 39%) 0.0056
17m 27s (- 23m 9s) (55 42%) 0.0050
19m 2s (- 21m 35s) (60 46%) 0.0046
20m 36s (- 19m 58s) (65 50%) 0.0042
22m 8s (- 18m 21s) (70 54%) 0.0040
23m 41s (- 16m 44s) (75 58%) 0.0038
25m 21s (- 15m 13s) (80 62%) 0.0036
26m 57s (- 13m 38s) (85 66%) 0.0035
28m 35s (- 12m 4s) (90 70%) 0.0033
30m 9s (- 10m 28s) (95 74%) 0.0032
31m 47s (- 8m 54s) (100 78%) 0.0031
33m 20s (- 7m 18s) (105 82%) 0.0030
34m 53s (- 5m 42s) (110 85%) 0.0030
36m 23s (- 4m 6s) (115 89%) 0.0029
37m 54s (- 2m 31s) (120 93%) 0.0028
39m 22s (- 0m 56s) (125 97%) 0.0027


## translate

In [None]:
def translate(zh_sentence, max_len=50):
    model.eval()
    with torch.no_grad():
        # encoder_input_ids: [SOS] chi [SEP]
        input_text = f"[SOS]{zh_sentence}[SEP]"
        input_ids = tokenizer.encode(input_text).ids
        input_tensor = torch.tensor([input_ids]).to(device)
        
        # generate
        for _ in range(max_len):
            logits = model(input_tensor)
            next_token = logits[0, -1].argmax()
            if next_token == EOS_ID:
                break
            input_tensor = torch.cat([input_tensor, next_token.unsqueeze(0).unsqueeze(0)], dim=1)
        
        # decode
        output_ids = input_tensor[0].tolist()
        sep_pos = output_ids.index(SEP_ID)
        en_ids = output_ids[sep_pos+1:]
        return tokenizer.decode(en_ids)

In [11]:
def evaluate_translation(test_pairs, num_samples=10):
    samples = random.sample(test_pairs, min(num_samples, len(test_pairs)))
    
    for zh, en_true in samples:
        en_pred = translate(zh)
        en_pred += " <EOS>"
        
        print(f"> {zh}")
        print(f"= {en_true}")
        print(f"<{en_pred}\n")

evaluate_translation(test_pairs, num_samples=10)

> 湯姆幫瑪麗買了她所有需要買的東西
= tom helped mary buy everything she needed
< <EOS>

> 老人們很早就起床
= old people get up very early
< privacy is good for eight hours <EOS>

> 他给我讲述了他的一生
= he told me the story of his life
< his mistakes to me his took took took <EOS>

> 汤姆总是旷课
= tom is always absent
< is always fighting <EOS>

> 你能冷凍它嗎
= can you freeze it
< you cold in it <EOS>

> 如果你不介意的话我想一个人呆着
= id like to be alone if you dont mind
< not like to go out how long time <EOS>

> 你現在有空嗎
= are you free right now
< free free is free <EOS>

> 我燒了紙
= i burned the paper
< got a paper <EOS>

> 我現在正在彈鋼琴
= i am playing the piano now
< is playing the piano playing the piano <EOS>

> 为什么你不来看我们
= why dont you come visit us
< will not have enough for us <EOS>

