In [None]:
import subprocess
import os

result = subprocess.run('bash -c "source /etc/network_turbo && env | grep proxy"',
                        shell=True, capture_output=True, text=True)
output = result.stdout
for line in output.splitlines():
    if '=' in line:
        var, value = line.split('=', 1)
        os.environ[var] = value

In [2]:
import functools
import random
from typing import Any, Dict, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchinfo import summary

from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

import math

import constants
import data_loaders

In [4]:
sequence_length = constants.CHUNK_SIZE_BYTES

# Dataset 與取 batch
enwik8_data_generator = data_loaders.get_enwik9_iterator(
    num_chunks=constants.NUM_CHUNKS // 10, # 只取了完整的 EnWik9 数据集的 10% 部分，也就是 EnWik8
    sequence_length=sequence_length,
)

next(enwik8_data_generator)

b'<mediawiki xmlns="http://www.mediawiki.org/xml/export-0.3/" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://www.mediawiki.org/xml/export-0.3/ http://www.mediawiki.org/xml/export-0.3.xsd" version="0.3" xml:lang="en">\n  <siteinfo>\n    <sitename>Wikipedia</sitename>\n    <base>http://en.wikipedia.org/wiki/Main_Page</base>\n    <generator>MediaWiki 1.6alpha</generator>\n    <case>first-letter</case>\n      <namespaces>\n      <namespace key="-2">Media</namespace>\n      <namespace key="-1">Special</namespace>\n      <namespace key="0" />\n      <namespace key="1">Talk</namespace>\n      <namespace key="2">User</namespace>\n      <namespace key="3">User talk</namespace>\n      <namespace key="4">Wikipedia</namespace>\n      <namespace key="5">Wikipedia talk</namespace>\n      <namespace key="6">Image</namespace>\n      <namespace key="7">Image talk</namespace>\n      <namespace key="8">MediaWiki</namespace>\n      <namespace key="9">MediaWiki talk</names

In [79]:
class Enwik8Dataset(Dataset):
    """Dataset for Enwik data."""
    def __init__(self, data_chunks) -> None:
        self.dataset = data_chunks

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int) -> torch.Tensor:
        seq = self.dataset[idx]
        seq_ascii = np.frombuffer(seq, dtype=np.uint8)
        # 依然回傳 uint8；模型內會轉為 long
        return torch.tensor(seq_ascii, dtype=torch.uint8)

In [5]:
# ----------------------------
# Config
# ----------------------------
class TransformerConfig:
    """Hyperparameters used in the Transformer architectures."""
    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int = 64,
        num_layers: int = 4,
        num_heads: int = 8,
        emb_init_scale: float = 0.02,
        widening_factor: int = 4,
        dropout: float = 0.0,
        max_length: int = constants.CHUNK_SIZE_BYTES,
        bos_token_id: int = 0,
        tie_weights: bool = False,  # 如需與輸入 embedding 綁定，改成 True
    ) -> None:
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.emb_init_scale = emb_init_scale
        self.widening_factor = widening_factor
        self.dropout = dropout
        self.max_length = max_length
        self.bos_token_id = bos_token_id
        self.tie_weights = tie_weights

In [6]:
# ----------------------------
# Sinusoidal Positional Encoding（PyTorch 沒內建，保留精簡版）
# ----------------------------
class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 4096):
        super().__init__()
        position = torch.arange(0, max_len).unsqueeze(1)                # [T, 1]
        div_term = torch.exp(torch.arange(0, d_model, 2) *
                             (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)                               # [T, D]
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, T, D] -> 回傳 [T, D]，之後會 broadcast 到 batch 維度
        T = x.size(1)
        return self.pe[:T]

In [7]:
# ----------------------------
# 使用 PyTorch 內建 TransformerDecoder 的語言模型
# ----------------------------
class TransformerDecoder(nn.Module):
    """Transformer decoder model (PyTorch built-ins)."""

    def __init__(self, config: TransformerConfig) -> None:
        super().__init__()
        self.config = config

        # Embedding
        self.embedding = nn.Embedding(
            num_embeddings=config.vocab_size,
            embedding_dim=config.embedding_dim,
        )
        # 对神经网络中的 embedding layer 权重进行初始化
        # 使用的是截断正态分布（truncated normal distribution）
        nn.init.trunc_normal_(self.embedding.weight, std=config.emb_init_scale)

        # Positional encoding（固定正弦）
        self.pos_encoding = SinusoidalPositionalEncoding(
            d_model=config.embedding_dim,
            max_len=config.max_length,
        )

        # 內建 Decoder Layer + 堆疊
        d_model = config.embedding_dim
        dim_ff = d_model * config.widening_factor
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=config.num_heads,
            dim_feedforward=dim_ff,
            dropout=config.dropout,
            activation="gelu",
            batch_first=True,      # 讓張量是 [B, T, D]
            norm_first=True,       # Pre-LN 比較穩定
        )
        self.decoder = nn.TransformerDecoder(
            decoder_layer=decoder_layer,
            num_layers=config.num_layers,
        )

        # 最終 LayerNorm（常見做法，可留可去）
        self.final_norm = nn.LayerNorm(d_model)

        # 輸出線性層
        self.output_layer = nn.Linear(d_model, config.vocab_size, bias=False)

        # 可選：權重綁定（weight tying）
        # 是否让 output 层的權重與 input embedding 層的權重綁定（一般默认是 fasle）
        if config.tie_weights:
            self.output_layer.weight = self.embedding.weight

    @torch.no_grad()
    def _causal_mask(self, T: int, device) -> torch.Tensor:
        # PyTorch 的 attn_mask: True 表示不允許注意（被遮蔽）
        return torch.triu(torch.ones(T, T, device=device, dtype=torch.bool), diagonal=1)

    def shift_right(self, sequences: torch.Tensor) -> torch.Tensor:
        """Right-shift the input by padding on the temporal axis."""
        sequences = sequences.long()
        bos = torch.full(
            (sequences.size(0), 1),
            fill_value=self.config.bos_token_id,
            dtype=sequences.dtype,
            device=sequences.device,
        )
        return torch.cat([bos, sequences[:, :-1]], dim=1)

    def forward(self, targets: torch.Tensor) -> torch.Tensor:
        """
        Returns:
            log_probs: [B, T, V]（與原程式一致：回傳 log_softmax，供 to_marginals 使用）
        """
        # 右移得到自迴歸輸入
        inputs = self.shift_right(targets)                   # [B, T] long

        # Token + Positional
        x = self.embedding(inputs)                           # [B, T, D]
        x = x * math.sqrt(self.config.embedding_dim)
        pos = self.pos_encoding(x).to(x.device)              # [T, D]
        x = x + pos                                          # broadcast 到 [B, T, D]

        # 因果遮罩（純 LM 無 encoder memory）
        T = x.size(1)
        attn_mask = self._causal_mask(T, x.device)           # [T, T] bool

        # Decoder forward（不需要 memory）
        h = self.decoder(tgt=x, memory=x, tgt_mask=attn_mask)
        h = self.final_norm(h)

        logits = self.output_layer(h)                        # [B, T, V]
        return logits

In [83]:
import itertools
from typing import Tuple
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch

import torch.nn as nn

class SentenceLevelNLoss(nn.Module):
    def __init__(self):
        super(SentenceLevelNLoss, self).__init__()

    def forward(self, logits, targets):
        log_probs = torch.log_softmax(logits, dim=-1)
        true_predictions = torch.gather(log_probs, 2, targets.long().unsqueeze(2)).squeeze(2)
        sentence_loss = -torch.mean(torch.sum(true_predictions, dim=1))
        return sentence_loss

def train_transformer_decoder(
    model: nn.Module,
    data_loader: DataLoader,
    training_steps: int,
    log_every: int,
    use_tqdm: bool = True,
    device: str = 'cuda',
) -> Tuple[nn.Module, float]:
    
    model.to(device)
    model.train()
    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    loss_fn = SentenceLevelNLoss()
    
    print('Initialization done, starting training...')
    last_loss = 0.0
    data_iter = itertools.cycle(data_loader)
    for step in tqdm(range(training_steps), disable=not use_tqdm):
        batch = next(data_iter).to(device)
        optimizer.zero_grad()
        logits = model(batch)                  # [B, T, V] logits
        loss = loss_fn(logits, batch)
        loss.backward()
        optimizer.step()

        if log_every > 0 and step % log_every == 0:
            print(f'Step {step}, Loss {loss.item()}')

        last_loss = loss.item()

    return model, last_loss

def train_transformer_decoder_by_epoch(
    model: nn.Module,
    data_loader: DataLoader,
    num_epochs: int,
    log_every: int,
    use_tqdm: bool = True,
    device: str = 'cuda',
) -> Tuple[nn.Module, float]:
    """
    按轮次训练Transformer解码器的函数。

    参数:
    model (nn.Module): 要训练的模型
    data_loader (DataLoader): 数据加载器
    num_epochs (int): 训练的轮数
    log_every (int): 每隔多少步打印一次日志
    use_tqdm (bool): 是否使用tqdm显示进度条，默认为True
    device (str): 使用的设备，默认为'cuda'

    返回:
    Tuple[nn.Module, float]: 训练后的模型和最后一轮的最后一个损失值
    """
    model.to(device)
    model.train()
    # 优化器
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    loss_fn = SentenceLevelNLoss()
    
    print('Initialization done, starting training...')
    last_loss = 0.0
    for epoch in tqdm(range(num_epochs), disable=not use_tqdm):
        for step, batch in enumerate(data_loader):
            batch = batch.to(device)
            optimizer.zero_grad()
            logits = model(batch)                  # [B, T, V] logits
            loss = loss_fn(logits, batch)
            loss.backward()
            optimizer.step()

            if log_every > 0 and step % log_every == 0:
                print(f'Epoch {epoch}, Step {step}, Loss {loss.item()}')

            last_loss = loss.item()

    return model, last_loss

In [84]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create config and model
config = TransformerConfig(vocab_size=constants.ALPHABET_SIZE)
print(config.__dict__)
model = TransformerDecoder(config)
print(summary(model, input_size=(2, constants.CHUNK_SIZE_BYTES), dtypes=[torch.long]))

{'vocab_size': 256, 'embedding_dim': 64, 'num_layers': 4, 'num_heads': 8, 'emb_init_scale': 0.02, 'widening_factor': 4, 'dropout': 0.0, 'max_length': 2048, 'bos_token_id': 0, 'tie_weights': False}
Layer (type:depth-idx)                        Output Shape              Param #
TransformerDecoder                            [2, 2048, 256]            --
├─Embedding: 1-1                              [2, 2048, 64]             16,384
├─SinusoidalPositionalEncoding: 1-2           [2048, 64]                --
├─TransformerDecoder: 1-3                     [2, 2048, 64]             --
│    └─ModuleList: 2-1                        --                        --
│    │    └─TransformerDecoderLayer: 3-1      [2, 2048, 64]             66,752
│    │    └─TransformerDecoderLayer: 3-2      [2, 2048, 64]             66,752
│    │    └─TransformerDecoderLayer: 3-3      [2, 2048, 64]             66,752
│    │    └─TransformerDecoderLayer: 3-4      [2, 2048, 64]             66,752
├─LayerNorm: 1-4            

In [None]:
sequence_length = constants.CHUNK_SIZE_BYTES

# Dataset 與取 batch
enwik8_data_generator = data_loaders.get_enwik9_iterator(
    num_chunks=constants.NUM_CHUNKS // 10, # 只取了完整的 EnWik9 数据集的 10% 部分，也就是 EnWik8
    sequence_length=sequence_length,
)
enwik8_chunks = list(enwik8_data_generator)
enwik8Dataset = Enwik8Dataset(enwik8_chunks)
enwik8DataLoader = DataLoader(enwik8Dataset, batch_size=32, shuffle=True)
for batch in enwik8DataLoader:
    print(batch.shape, batch.dtype)
    break
print(len(enwik8DataLoader))

torch.Size([32, 2048]) torch.uint8
1526


In [86]:
model, loss = train_transformer_decoder_by_epoch(
    model=model,
    data_loader=enwik8DataLoader,
    num_epochs=3,
    log_every=500,
    device=device
)
print(f'Final loss: {loss}')
# Save model
torch.save(model.state_dict(), 'params.pth')
print('Parameters saved in file params.pth')

Initialization done, starting training...


  0%|          | 0/3 [00:00<?, ?it/s]

Epoch 0, Step 0, Loss 11639.6328125
Epoch 0, Step 500, Loss 6236.7177734375
Epoch 0, Step 1000, Loss 5742.93408203125
Epoch 0, Step 1500, Loss 5561.9580078125


 33%|███▎      | 1/3 [07:42<15:24, 462.07s/it]

Epoch 1, Step 0, Loss 5458.1064453125
Epoch 1, Step 500, Loss 5413.98046875
Epoch 1, Step 1000, Loss 5168.072265625
Epoch 1, Step 1500, Loss 5097.517578125


 67%|██████▋   | 2/3 [15:32<07:47, 467.26s/it]

Epoch 2, Step 0, Loss 5192.1435546875
Epoch 2, Step 500, Loss 5176.4443359375
Epoch 2, Step 1000, Loss 5035.2255859375
Epoch 2, Step 1500, Loss 5007.82763671875


100%|██████████| 3/3 [22:58<00:00, 459.34s/it]


Final loss: 4857.38671875
Parameters saved in file params.pth


In [5]:
import numpy as np
from language_modeling_is_compression import constants
from language_modeling_is_compression import data_loaders

enwik9_data_generator = data_loaders.get_enwik9_iterator(
      num_chunks=10,
      chunk_start_idx= constants.NUM_CHUNKS // 10,
      sequence_length=constants.CHUNK_SIZE_BYTES,
)

rawdata = next(enwik9_data_generator)
print(f'Data sample size: {len(rawdata)} bytes')

tokenized_data = np.frombuffer(rawdata, dtype=np.uint8)
print(type(tokenized_data), tokenized_data)
print(rawdata)

Data sample size: 2048 bytes
<class 'numpy.ndarray'> [ 60 109 101 ...  32  32  32]
b'<mediawiki xmlns="http://www.mediawiki.org/xml/export-0.3/" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://www.mediawiki.org/xml/export-0.3/ http://www.mediawiki.org/xml/export-0.3.xsd" version="0.3" xml:lang="en">\n  <siteinfo>\n    <sitename>Wikipedia</sitename>\n    <base>http://en.wikipedia.org/wiki/Main_Page</base>\n    <generator>MediaWiki 1.6alpha</generator>\n    <case>first-letter</case>\n      <namespaces>\n      <namespace key="-2">Media</namespace>\n      <namespace key="-1">Special</namespace>\n      <namespace key="0" />\n      <namespace key="1">Talk</namespace>\n      <namespace key="2">User</namespace>\n      <namespace key="3">User talk</namespace>\n      <namespace key="4">Wikipedia</namespace>\n      <namespace key="5">Wikipedia talk</namespace>\n      <namespace key="6">Image</namespace>\n      <namespace key="7">Image talk</namespace>\n      <name

In [88]:
def predict_fn(tokenized_data_batch):
    config = TransformerConfig(vocab_size=constants.ALPHABET_SIZE)
    model = TransformerDecoder(config)
    params = torch.load('params.pth')
    model.load_state_dict(params)

    # Set model to evaluation mode
    model.eval()
    with torch.no_grad():
        # Convert numpy array to PyTorch tensor
        x = torch.tensor(tokenized_data_batch, dtype=torch.int64)
        # Get logits from the model
        logits = model(x)
        # Convert to log probabilities and then to numpy
        log_probs = torch.log_softmax(logits, dim=-1).cpu().numpy()
    return log_probs

In [95]:
use_slow_lossless_compression = False
if use_slow_lossless_compression:
  # use_slow_lossless_compression:
  log_probs = list()
  for t in range(len(tokenized_data)):
    # assume tokenized_data = [h,e,l,l,o]
    # t0: input = [h] 實際模型的輸入是 [BOS]，因為還要做 right_shift,模型的輸出是 p(h | <bos>)
    # t1: input = [h,e] 實際模型的輸入是 [h]，因為還要做 right_shift，模型的輸出是 p(e | h)
    # t2: input = [h,e,l]
    # t3: input = [h,e,l,l]
    # t4: input = [h,e,l,l,o]
    # 为什么这里必须一步步来，不能一次输入所有的 tokenized_data 呢，因为解码是一个一个token解码的。
    # 直接一次 forward "<bos> h e l l" 在 hell位置上得到的 logits 其实跟一步步的logits不同
    # 为什么不同？因为 casual mask 只是应用在了 attention 的计算里，而整个transformer 的计算组件里还有很多
    # 没有用到 casual mask 的组件，比如 layer normalization 和 FFN 等。
    input_seq = tokenized_data[None, : t + 1]
    subsequence_probs = predict_fn(input_seq)
    last_token_probs = subsequence_probs[0, -1]
    log_probs.append(last_token_probs)
  log_probs = np.vstack(log_probs)
else:
  log_probs = predict_fn(tokenized_data[None])[0, ...]
probs = np.exp(log_probs)

In [96]:
from language_modeling_is_compression import utils
from language_modeling_is_compression import arithmetic_coder

output = list()
encoder = arithmetic_coder.Encoder(
    base=constants.ARITHMETIC_CODER_BASE,
    precision=constants.ARITHMETIC_CODER_PRECISION,
    output_fn=output.append,
)

for pdf, symbol in zip(probs, tokenized_data):
  encoder.encode(utils.normalize_pdf_for_arithmetic_coding(pdf), symbol)

encoder.terminate()
compressed_bits = ''.join(map(str, output))
print(len(compressed_bits)) # 压缩后的bits 比如 8519 bits
# 而因為 8519 不是 8 的倍數，所以需要 padding，最少需要在左邊pad一個bit-0
compressed_data, num_padded_bits = utils.bits_to_bytes(compressed_bits)
compression_ratio = len(compressed_data) / len(rawdata)
print(f'Compression ratio: {compression_ratio:.4f}')

7168
Compression ratio: 0.4375


In [91]:
from collections.abc import Iterator
from typing import Callable, Tuple, Union

data_iter = iter(utils.bytes_to_bits(compressed_data, num_padded_bits=num_padded_bits))
# The decoder requires a function that reads digits from {0, 1, ..., base - 1}
# from the compressed input and returns `None` when the input is exhausted.
def _input_fn(bit_sequence: Iterator[str] = data_iter) -> Union[int, None]:
  try:
    return int(next(bit_sequence))
  except StopIteration:
    return None

decoder = arithmetic_coder.Decoder(
    base=constants.ARITHMETIC_CODER_BASE,
    precision=constants.ARITHMETIC_CODER_PRECISION,
    input_fn=_input_fn,
)

# We need a dummy token because the language model right-shifts the sequence
# by one when computing the conditional probabilities. Concretely, at every
# step, we need the `pdf` of the next token given all currently decompressed
# tokens, but without a dummy token, the last `pdf` would be that of the last
# already decompressed token. The value of the dummy token is irrelevant.
sequence_array = np.empty((1,), dtype=np.uint8)
probs = np.exp(predict_fn(sequence_array[None])[0, ...])

In [92]:
uncompressed_length = constants.CHUNK_SIZE_BYTES
for idx in range(uncompressed_length):
  token = decoder.decode(
      utils.normalize_pdf_for_arithmetic_coding(probs[idx])
  )
  sequence_array = np.insert(sequence_array, -1, token)
  if len(sequence_array) == constants.CHUNK_SIZE_BYTES+1:
    break
  probs = np.exp(predict_fn(sequence_array[None])[0, ...])  
decoded_data = sequence_array[:-1].tobytes()

In [93]:
if rawdata == decoded_data:
  print('SUCCESS: Data was successfully compressed and decompressed!')
else:
  print('ERROR: Decompressed data does not match original data!')

SUCCESS: Data was successfully compressed and decompressed!
