In [9]:
from mytool.data import get_data
# Load the first 128 documents of the C4 validation split, tokenized.
# NOTE(review): `eos=True` presumably appends an EOS token per document, and later cells
# slice this as a list of token-id lists — confirm against mytool.data.get_data.
dataloader = get_data(
    "c4",
    "validation",
    first_n=128,
    tokenize=True,
    eos=True,
)

W0531 16:51:18.662310 139792472610624 builder.py:816] Found cached dataset json (/nvme/wangruohui/hfcache/datasets/allenai___json/allenai--c4-efc3d4f4606f44bd/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)


Dataset({
    features: ['text', 'timestamp', 'url'],
    num_rows: 45576
})


In [10]:
from single import Transformer, ModelArgs
import torch


def init_model(device="cuda", weight_path="full_fused.pth"):
    """Build a ``Transformer`` on ``device`` and load weights from ``weight_path``.

    Side effect: sets the *global* default tensor type to ``torch.cuda.HalfTensor``
    and leaves it set, so float tensors created later without an explicit dtype/device
    are fp16 on CUDA.

    Args:
        device: device the model and its weights are placed on, e.g. ``"cuda:3"``.
        weight_path: path to a state dict saved with ``torch.save``.

    Returns:
        The model, in eval mode.
    """
    args = ModelArgs()
    # NOTE(review): deprecated in recent PyTorch (torch.set_default_dtype /
    # torch.set_default_device replace it). Kept as-is because the model may rely on the
    # default type when it creates tensors at forward time.
    torch.set_default_tensor_type(torch.cuda.HalfTensor)
    model = Transformer(args).to(device)
    first_param = next(iter(model.parameters()))
    print(f"create model \t {first_param.device} {first_param.dtype}")
    # Load directly onto the target device instead of staging the whole checkpoint on
    # the current CUDA device (which is what map_location="cuda" does).
    state_dict = torch.load(weight_path, map_location=device)
    # strict=False tolerates key mismatches; report them rather than ignoring them silently.
    result = model.load_state_dict(state_dict, strict=False)
    if result.missing_keys or result.unexpected_keys:
        print(
            f"load_state_dict: missing={result.missing_keys} "
            f"unexpected={result.unexpected_keys}"
        )
    model.eval()
    return model

In [53]:
import math
from dataclasses import dataclass


@dataclass
class Stat:
    """Running min / max / mean / population variance over a stream of scalars.

    Each ``update`` call folds in one value in O(1) time and memory. ``var`` is the
    population variance (divides by ``count``, not ``count - 1``).
    """

    min: float = 0
    max: float = 0
    mean: float = 0
    var: float = 0
    count: int = 0

    def update(self, value):
        """Fold a single scalar ``value`` into the running statistics."""
        n = self.count
        if not n:
            # First sample: every statistic collapses to the value itself.
            self.min = self.max = self.mean = value
            self.var = 0
        else:
            self.min = min(self.min, value)
            self.max = max(self.max, value)
            # The variance update needs the *old* mean, so it runs before the mean moves:
            #   var' = var * n/(n+1) + (x - mean)^2 * n/(n+1)^2
            self.var = (
                self.var * n / (n + 1)
                + (value - self.mean) ** 2 * n / (n + 1) ** 2
            )
            self.mean = (self.mean * n + value) / (n + 1)
        self.count = n + 1

    def __str__(self):
        return (
            f"min: {self.min}, max: {self.max}, mean: {self.mean}, "
            f"var: {self.var}, count: {self.count}"
        )


# Sanity-check the streaming update against exact population statistics.
import statistics

vals = [3, 4, 7, 11]
stat = Stat()
for v in vals:
    stat.update(v)

assert (stat.min, stat.max, stat.count) == (min(vals), max(vals), len(vals))
assert math.isclose(stat.mean, statistics.mean(vals))
assert math.isclose(stat.var, statistics.pvariance(vals))
print(stat)

min: 3, max: 11, mean: 6.25, var: 9.687499999999998, count: 4


In [3]:
from dataclasses import dataclass


@dataclass
class Stat:
    min: float = 0
    max: float = 0
    mean: float = 0
    var: float = 0
    count: int = 0

    def update(self, value):
        if self.count == 0:
            self.min = value.clone().to(torch.float64)
            self.max = value.clone().to(torch.float64)
            self.mean = value.clone().to(torch.float64)
            self.var = 0
        else:
            self.min = torch.minimum(self.min, value)
            self.max = torch.maximum(self.max, value)
            self.var = (
                self.var * self.count / (self.count + 1)
                + (value - self.mean) ** 2 * self.count / (self.count + 1) ** 2
            )
            self.mean = (self.mean * self.count + value) / (self.count + 1)

        self.count += 1

    @property
    def std(self):
        return torch.sqrt(self.var)

    def __str__(self):
        return f"min: {self.min}\nmax: {self.max}\nmean: {self.mean}\nstd: {self.std}\ncount: {self.count}"

# Feed four 3-element samples one at a time, then show the element-wise summary.
stat = Stat()
for sample in ([1, 42, 53], [3, 32, 55], [6, 32, 35], [5, 72, 57]):
    stat.update(torch.tensor(sample))
print(stat)

min: tensor([ 1., 32., 35.], dtype=torch.float64)
max: tensor([ 6., 72., 57.], dtype=torch.float64)
mean: tensor([ 3.7500, 44.5000, 50.0000], dtype=torch.float64)
std: tensor([ 1.9203, 16.3936,  8.7750], dtype=torch.float64)
count: 4


In [11]:
from dataclasses import dataclass


@dataclass
class Stat:
    min: float = 0
    max: float = 0
    mean: float = 0
    M2: float = 0
    count: int = 0

    def update(self, value):
        assert value.ndim == 2
        local_count = value.shape[0]
        local_min = value.amin(dim=0).to(torch.float64)
        local_max = value.amax(dim=0).to(torch.float64)
        local_mean = value.mean(dim=0).to(torch.float64)
        local_M2 = value.var(dim=0, correction=0).to(torch.float64) * local_count

        if self.count == 0:
            self.min, self.max = local_min, local_max
            self.mean, self.M2 = local_mean, local_M2
        else:
            self.min = torch.minimum(self.min, local_min)
            self.max = torch.maximum(self.max, local_max)
            delta = local_mean - self.mean
            new_count = self.count + local_count
            self.M2 = self.M2 + local_M2 + delta ** 2 * self.count * local_count / new_count
            self.mean = self.mean + delta * local_count / new_count
        
        self.count += local_count

    @property
    def std(self):
        return torch.sqrt(self.M2 / self.count)

    def __str__(self):
        return f"min: {self.min}\nmax: {self.max}\nmean: {self.mean}\nstd: {self.std}\ncount: {self.count}"


# Fold uneven batches into one Stat and check the merge against a one-shot computation.
batches = [
    torch.tensor([[1., 42, 53], [3., 32, 55]]),
    torch.tensor([[6., 32, 35]]),
    torch.tensor([[5., 72, 57]]),
]
stat = Stat()
for batch in batches:
    stat.update(batch)
print(stat)

all_rows = torch.cat(batches).to(torch.float64)
assert stat.count == all_rows.shape[0]
assert torch.equal(stat.min, all_rows.amin(dim=0))
assert torch.equal(stat.max, all_rows.amax(dim=0))
assert torch.allclose(stat.mean, all_rows.mean(dim=0))
assert torch.allclose(stat.std, all_rows.std(dim=0, correction=0))

min: tensor([ 1., 32., 35.], dtype=torch.float64)
max: tensor([ 6., 72., 57.], dtype=torch.float64)
mean: tensor([ 3.7500, 44.5000, 50.0000], dtype=torch.float64)
std: tensor([ 1.9203, 16.3936,  8.7750], dtype=torch.float64)
count: 4


In [18]:
class PerChannelStatHook:
    """Forward hook that accumulates per-channel ``Stat`` for a module's input or output.

    Each instance creates its own ``Stat`` and registers it in the class-level dict
    ``PerChannelStatHook.stat`` under the key ``"<name>-<which>"``. Call ``clear()``
    to reset that registry before registering hooks again.
    """

    # Class-level registry "<name>-<which>" -> Stat. Each instance also sets an
    # instance attribute ``stat`` (its own Stat), which shadows this on instances.
    stat = {}

    def __init__(self, name, which="output") -> None:
        if which not in ("output", "input"):
            raise ValueError(f"which must be 'output' or 'input', got {which!r}")
        name = str(name) + "-" + which
        assert name not in self.__class__.stat
        self.name = name
        self.which = which
        self.stat = Stat()
        self.__class__.stat[name] = self.stat

    def __call__(self, m, i, o):
        # Forward hooks receive the positional inputs as a tuple; some modules also
        # return tuples. In both cases the first element is tracked.
        value = o if self.which == "output" else i
        if isinstance(value, (tuple, list)) and value:
            value = value[0]

        assert isinstance(value, torch.Tensor)
        # detach(): the running stats must not keep autograd graphs alive.
        # reshape: flatten all leading dims into rows so batch-1 / single-token
        # activations stay 2-D (squeeze() would collapse them to 1-D).
        value = value.detach()
        self.stat.update(value.reshape(-1, value.shape[-1]))

    @classmethod
    def clear(cls):
        cls.stat = {}

DEVICE = "cuda:3"
model = init_model(device=DEVICE, weight_path="full_fused.pth")

# Start from an empty registry so re-running this cell doesn't trip the duplicate-name
# assertion in PerChannelStatHook.__init__ (the class-level dict survives re-runs).
PerChannelStatHook.clear()
# Keep the handles so the hooks can be detached later with `handle.remove()`.
hook_handles = []
for name, module in model.named_modules():
    if any(c in name for c in ["_norm", "attention.", "feed_forward."]):
        hook = PerChannelStatHook(name, "output")
        hook_handles.append(module.register_forward_hook(hook))

In [19]:
import torch.nn as nn

# Cap on tokens per document fed to the model; longer documents are truncated.
MAX_SEQ_LEN = 2048

total_loss = 0
total_token = 0
ns = 128

# Summed (not averaged) so the corpus mean is total_loss / total_token.
loss_fct = nn.CrossEntropyLoss(reduction="sum")

# Evaluation only. Without no_grad, if the model's parameters require grad (not visible
# here), every iteration's graph would stay alive through total_loss and the hook stats.
with torch.no_grad():
    for token in dataloader[:ns]:
        token = torch.tensor([token[:MAX_SEQ_LEN]]).to(DEVICE)
        logits = model(token, 0)
        # Next-token prediction: logits at position t are scored against token t + 1.
        # Upcast so the summed loss isn't rounded to fp16 (no-op if logits are fp32).
        shift_logits = logits[:, :-1, :].float()
        shift_labels = token[:, 1:]
        loss = loss_fct(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1),
        )
        total_loss += loss.float()
        total_token += shift_labels.numel()

ppl = torch.exp(total_loss / total_token)
print("PPL", ppl.item())

PPL 6.7389750480651855


In [20]:
# Show every "<module>-<which>" key registered by the stat hooks.
registry = PerChannelStatHook.stat
print(registry.keys())

dict_keys(['layers.0.attention.wq-output', 'layers.0.attention.wk-output', 'layers.0.attention.wv-output', 'layers.0.attention.wo-output', 'layers.0.feed_forward.w1-output', 'layers.0.feed_forward.w2-output', 'layers.0.feed_forward.w3-output', 'layers.0.attention_norm-output', 'layers.0.ffn_norm-output', 'layers.1.attention.wq-output', 'layers.1.attention.wk-output', 'layers.1.attention.wv-output', 'layers.1.attention.wo-output', 'layers.1.feed_forward.w1-output', 'layers.1.feed_forward.w2-output', 'layers.1.feed_forward.w3-output', 'layers.1.attention_norm-output', 'layers.1.ffn_norm-output', 'layers.2.attention.wq-output', 'layers.2.attention.wk-output', 'layers.2.attention.wv-output', 'layers.2.attention.wo-output', 'layers.2.feed_forward.w1-output', 'layers.2.feed_forward.w2-output', 'layers.2.feed_forward.w3-output', 'layers.2.attention_norm-output', 'layers.2.ffn_norm-output', 'layers.3.attention.wq-output', 'layers.3.attention.wk-output', 'layers.3.attention.wv-output', 'layers.

In [None]:

# Inspect the captured activations and count histogram for one q_rot output.
# NOTE(review): `Hook` and `DistributionHook` are not defined in any cell visible here,
# and these "model.layers.*.self_attn.*" keys differ from the "layers.*" names the
# PerChannelStatHook cell registers. On a fresh kernel this cell raises NameError —
# TODO: import/define them or drop the cell.
print(Hook._acts['model.layers.0.self_attn.q_rot-output'])
print(DistributionHook._act_counts.keys())
print(DistributionHook._act_counts['model.layers.0.self_attn.q_rot-output'])

In [None]:
import matplotlib.pyplot as plt
import os
import numpy as np
# One log-scale bar chart per entry of DistributionHook._act_counts, saved as PNG.
# NOTE(review): DistributionHook is not defined in any cell visible here.
os.makedirs('vis-act-counts', exist_ok=True)
for k, v in DistributionHook._act_counts.items():
    counts = v.cpu().numpy()
    # Bar positions follow the number of bins instead of a hard-coded 42.
    # NOTE(review): the -25 offset presumably maps bin index -> bucket value; TODO confirm.
    positions = np.arange(len(counts)) - 25

    fig, ax = plt.subplots(figsize=(5, 5))
    ax.bar(positions, height=counts)
    ax.set_title(k)
    ax.set_yscale('log')
    ax.grid()
    fig.savefig(os.path.join('vis-act-counts', f'{k}.png'))
    # Close explicitly: many figures are created in this loop.
    plt.close(fig)


In [None]:
# torch.save(Hook._acts, 'acts-mix.pt')

In [None]:
# Take the first (key, captured-activations) pair from Hook._acts for ad-hoc inspection;
# v is presumably a list of tensors (it is passed to torch.cat / len() elsewhere).
# NOTE(review): `Hook` is not defined in any cell visible here — NameError on a fresh kernel.
k, v  = next(iter(Hook._acts.items()))

In [None]:
# Stack the captured tensors, swap dims 1 and 2, and flatten to 4096-wide rows
# (4096 == 32 * 128, presumably n_heads * head_dim — TODO confirm).
cat = (
    torch.cat(v, dim=0)
    .transpose(1, 2)
    .contiguous()
    .view(-1, 4096)
)

In [None]:
# Scratch arithmetic: 32 * 128 == 4096, the row width used by the reshapes in these cells
# (presumably n_heads * head_dim — TODO confirm).
32* 128

In [None]:
import matplotlib.pyplot as plt
import os

# Heatmap of every rotary ("rot") activation captured in Hook._acts, saved as PNG.
# NOTE(review): `Hook` is not defined in any cell visible here.
save_dir = 'vis-acts'
os.makedirs(save_dir, exist_ok=True)

for k, v in Hook._acts.items():
    print(k, len(v))
    # Only the rotary outputs are plotted; everything else is skipped.
    if 'rot' not in k:
        continue
    # Stack captured batches, swap dims 1 and 2, then flatten to 4096-wide rows
    # (presumably n_heads * head_dim = 32 * 128 — TODO confirm).
    cat = torch.cat(v, dim=0).transpose(1, 2).reshape(-1, 4096)
    print(cat.shape)
    absmax = cat.abs().max().item()
    # float(): hand matplotlib fp32 rather than a possibly-fp16 array.
    data = cat.float().cpu().numpy()

    fig, ax = plt.subplots(figsize=(15, 20))
    # Symmetric limits so zero maps to the neutral middle of the 'bwr' colormap.
    im = ax.imshow(data, interpolation='none', cmap='bwr', vmin=-absmax, vmax=absmax)
    ax.set_title(k)
    fig.colorbar(im, ax=ax)
    fig.savefig(os.path.join(save_dir, f'{k}.png'))
    plt.close(fig)



In [None]:
# Side-by-side heatmaps of layer-22 query activations: q_proj output vs. q_rot output.
k1 = "model.layers.22.self_attn.q_proj-output"
k2 = "model.layers.22.self_attn.q_rot-output"

acts = Hook._acts
a1, a2 = acts[k1], acts[k2]

# First captured q_rot tensor, dims 1 and 2 swapped, flattened to 4096-wide rows.
cat = a2[0]
print(cat.shape)
cat = cat.transpose(1, 2)
print(cat.shape)
cat = cat.reshape(-1, 4096)

for captured in (a1[0], a2[0]):
    print(captured.shape)

fig, (ax_proj, ax_rot) = plt.subplots(1, 2, figsize=(15, 20))
im_proj = ax_proj.imshow(a1[0][0].cpu().numpy(), interpolation='none', cmap='bwr')
fig.colorbar(im_proj, ax=ax_proj)
im_rot = ax_rot.imshow(cat.cpu().numpy(), interpolation='none', cmap='bwr')
fig.colorbar(im_rot, ax=ax_rot)

In [None]:
def __init__(dim, max_position_embeddings=2048, base=10000, device=None):
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    print(inv_freq)
    print(1.0/inv_freq)

__init__(128)