# Perplexity

混淆度（Perplexity, PPL）是评价语言模型拟合程度的性能指标。模型在测评集上的混淆度越低，性能越好。

给定交叉熵损失（cross-entropy loss）：

$$
L(\theta, x) = \frac{1}{n} \sum_{i=1}^{n} \frac{1}{T} \sum_{t=1}^{T} -\log p_\theta\left(x^{(i)}_t \mid x^{(i)}_{<t}\right)
$$

其中，$n$ 为数据条数，$T$ 为序列长度，$x^{(i)}_t$ 为第 $i$ 条数据的第 $t$ 个 token。混淆度为：

$$
PPL(\theta, x_\text{test}) = \exp(L(\theta, x_\text{test}))
$$

模型在训练集 $x_\text{train}$ 上训练后得到参数 $\theta$，其测评指标为 $PPL(\theta, x_\text{test})$。

## model

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# Seed PyTorch's global RNG so the random token ids and the randomly
# initialised model weights created below are reproducible across runs.
torch.manual_seed(42)

class MyModel(nn.Module):
    """Toy language model: token embedding, one linear layer, vocabulary head.

    There is no attention, non-linearity or positional information, so the
    logits at a position depend only on the token id at that position.

    Args:
        dim: width of the embedding and of the hidden linear layer.
        vocab_size: number of distinct token ids; also the logit width.
        max_len: accepted for interface compatibility; no layer uses it.
    """

    def __init__(self, dim=512, vocab_size=100, max_len=1024):
        super().__init__()
        # Creation order is kept so seeded weight initialisation is unchanged.
        self.embd = nn.Embedding(vocab_size, dim)
        self.w = nn.Linear(dim, dim)
        self.lm_head = nn.Linear(dim, vocab_size)

    def forward(self, x):
        """Map a 2-D tensor of token ids to logits with a trailing vocab axis."""
        # The unpacking only serves to reject input that is not 2-D.
        _rows, _cols = x.shape
        hidden = self.w(self.embd(x))
        return self.lm_head(hidden)

# Hyper-parameters for a quick shape check of the toy model.
dim = 512  # plain int; a trailing comma here would bind the 1-tuple (512,)
seq_len = 16
vocab_size = 100
batch_size = 2

# Random token ids of shape (batch_size, seq_len), drawn from [0, vocab_size).
input_ids = torch.randint(vocab_size, [batch_size, seq_len], )

# Pass the hyper-parameters explicitly so the model always matches the
# values declared above instead of relying on the constructor defaults.
model = MyModel(dim=dim, vocab_size=vocab_size)
logits = model(input_ids)
print(logits.shape)

torch.Size([2, 16, 100])


# Loss

In [4]:
# Label value that the cross-entropy loss skips.
IGNORE_INDEX = -100

# Next-token targets: position t is labelled with token t + 1. The last
# position has no successor, so it keeps the ignore value.
labels = torch.full_like(input_ids, IGNORE_INDEX, dtype=torch.long)
labels[:, :-1] = input_ids[:, 1:]

# Mean cross-entropy over every non-ignored position of the whole batch;
# exponentiating that mean gives the perplexity.
loss_fn = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
loss = loss_fn(
    logits.view(batch_size * seq_len, vocab_size),
    labels.view(batch_size * seq_len),
)
ppl = torch.exp(loss)
print(ppl)

tensor(106.4641, grad_fn=<ExpBackward0>)


## 批量计算方法

In [5]:
def cal_ppl(model, data, ignore_index=-100):
    """Return the mean per-sequence perplexity of ``model`` on ``data``.

    Each row of ``data`` is scored on its own: position ``t`` is labelled
    with token ``t + 1``, the cross-entropy is averaged over the row's
    labelled positions and exponentiated, and the per-row perplexities are
    then averaged. Note that this is the mean of per-row perplexities, not
    the exponential of the batch-mean loss.

    The function is self-contained: the class count is read from the logits
    and the loss is computed with ``F.cross_entropy``, so it does not depend
    on module-level names that may be rebound elsewhere.

    Args:
        model: callable mapping a ``(1, seq_len)`` tensor of token ids to
            logits that can be viewed as ``(seq_len, num_classes)``.
        data: 2-D tensor of token ids, one sequence per row.
        ignore_index: label value excluded from the loss; assigned to the
            last position of every row, which has no next token.

    Returns:
        0-D tensor holding the mean of the per-row perplexities.
    """
    bs, seq_len = data.shape

    # Shift left by one to get next-token targets; the final column keeps
    # ``ignore_index`` because nothing follows the last token.
    labels = torch.full_like(data, ignore_index, dtype=torch.long)
    labels[:, :-1] = data[:, 1:]

    total_ppl = torch.zeros(bs)

    for i in range(bs):
        logits = model(data[i:i + 1])
        # Take the class count from the logits themselves so the function
        # works for any model, whatever vocabulary size the caller uses.
        loss = F.cross_entropy(
            logits.view(seq_len, logits.size(-1)),
            labels[i],
            ignore_index=ignore_index,
        )
        total_ppl[i] = loss.exp()

    return total_ppl.mean()


# Score 4096 random sequences of 512 token ids each with the toy model.
data = torch.randint(0, vocab_size, (4096, 512))
ppl = cal_ppl(model, data)
print(ppl)

tensor(106.0619, grad_fn=<MeanBackward0>)


## 手动 PPL 计算方法

In [28]:
# A tiny configuration so the probabilities can be checked by hand.
dim = 512  # plain int; a trailing comma here would bind the 1-tuple (512,)
seq_len = 3
vocab_size = 4
batch_size = 1

input_ids = torch.randint(vocab_size, [batch_size, seq_len], )

# Build the model from the values declared above rather than repeating
# the literals in the constructor call.
model = MyModel(dim=dim, vocab_size=vocab_size)

# Inference only: no autograd graph is needed for the manual computation.
with torch.no_grad():
    logits = model(input_ids)
print(logits.shape)
print(input_ids)
# Every token but the last is an input; every token but the first is a label.
print('input:', input_ids[:, :-1])
print('label:', input_ids[:, 1:])

torch.Size([1, 3, 4])
tensor([[2, 3, 1]])
input: tensor([[2, 3]])
label: tensor([[3, 1]])


In [29]:
# Turn the logits into a probability distribution over the vocabulary
# at every position.
p = torch.softmax(logits, dim=-1)
print(p)
# For each position except the last, pick the probability the model gave
# to the token that actually comes next in the sequence.
p_gather = torch.gather(p[0, :-1], dim=1, index=input_ids[0, 1:].unsqueeze(1))
print(p_gather)

tensor([[[0.2148, 0.2090, 0.2403, 0.3359],
         [0.2998, 0.1619, 0.2801, 0.2581],
         [0.3909, 0.2395, 0.1924, 0.1772]]])
tensor([[0.3359],
        [0.1619]])


In [30]:
# Mean negative log-likelihood of the observed next tokens, i.e. the
# cross-entropy loss; taking exp() of this value gives the perplexity.
-torch.log(p_gather).mean()

tensor(1.4558)