In [1]:
import os
# Expose only GPU 0 to this process and limit OpenMP to a single thread.
# Both variables are set in the first cell, before torch is imported in the next one.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["OMP_NUM_THREADS"] = "1"

In [2]:
import time

import torch
import torch.nn as nn

from modelutils import *
from datautils import *

import datasets
import matplotlib.pyplot as plt
from bf16_fused_adam import BF16FusedAdamW
import datasets
from collections import Counter

In [3]:
def get_opt(model):
    """Load a causal language model with three torch initialisers turned into no-ops.

    ``torch.nn.init.kaiming_uniform_``, ``uniform_`` and ``normal_`` are replaced
    process-wide by a function that does nothing, and the replacement is not
    undone afterwards (presumably to skip random initialisation of weights that
    the checkpoint overwrites anyway).

    Args:
        model: Identifier or path forwarded to
            ``AutoModelForCausalLM.from_pretrained``.

    Returns:
        The loaded model, with an extra ``seqlen`` attribute fixed to 8192.
    """
    import torch

    def skip(*args, **kwargs):
        return None

    # Patch first, import transformers afterwards (same order as before).
    for init_name in ("kaiming_uniform_", "uniform_", "normal_"):
        setattr(torch.nn.init, init_name, skip)

    from transformers import AutoModelForCausalLM

    loaded = AutoModelForCausalLM.from_pretrained(
        model,
        torch_dtype='auto',
        cache_dir="/mnt/nvme/llm_weights",
    )
    print("ms", loaded.config.max_position_embeddings)
    # Hard-coded 8k context instead of the config value printed above.
    loaded.seqlen = 1024 * 8
    return loaded

In [4]:
# Checkpoint to compress. TARGET_BITS is read as a global by factorizeT/factorizef
# when they size the inner dimension of the two factors.
model_name = "meta-llama/Meta-Llama-3-8B"
TARGET_BITS = 2.0

model = get_opt(model_name)
model.eval()

# Total parameter count, followed by the count inside the decoder layers only.
print(sum(p.numel() for p in model.parameters()), sum(p.numel() for p in model.model.layers.parameters()))

# Bare expression: lets the notebook display the module tree.
model

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

ms 8192
8030261248 6979584000


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)
    (rotary_

In [5]:
# Number of calibration sequences; also read as a global by opt_sequential.
n_samples = 256

ds = datasets.load_from_disk("redpajama_tokenized_llama3/")

# Fixed seed so the same rows are drawn on every run. randint draws each index
# independently, so the same row can be picked more than once.
# NOTE(review): `np` is not imported explicitly in the visible cells; it
# presumably arrives through one of the star imports - confirm.
np.random.seed(47)
inds = np.random.randint(0, len(ds), size=n_samples)

# One row of token ids per sampled sequence; the shape is displayed below.
dataloader = torch.LongTensor(ds[inds]["input_ids"])
dataloader.shape

torch.Size([256, 8192])

In [6]:
from cut_cross_entropy import linear_cross_entropy

# Put the whole model on the GPU and prepare it for the backward passes below:
# gradient checkpointing on, KV cache off, train mode.
model.cuda()
model.gradient_checkpointing_enable()
model.config.use_cache = False
model.train()

def f_hook(m, i, o):
    """Forward hook that accumulates the mean squared input per feature.

    Args:
        m: Hooked module; its ``i_norm`` tensor is updated in place.
        i: Tuple of forward inputs; only ``i[0]`` is read.
        o: Forward output (not used).
    """
    acts = i[0].detach().float()
    # Collapse every leading dimension so each row is one feature vector.
    rows = acts.reshape(-1, acts.shape[-1])
    per_feature_mean_sq = rows.square().mean(dim=0)
    m.i_norm += per_feature_mean_sq
    
def b_hook(m, _, go):
    """Backward hook that accumulates the scaled mean squared output gradient.

    Args:
        m: Hooked module; its ``o_norm`` tensor is updated in place.
        _: Gradient w.r.t. the module inputs (not used).
        go: Tuple of gradients w.r.t. the module outputs; only ``go[0]`` is read.
    """
    grads = go[0].detach().float()
    rows = grads.reshape(-1, grads.shape[-1])
    # The per-feature mean of squares is multiplied by a fixed 1e6 factor.
    per_feature_mean_sq = rows.square().mean(dim=0)
    m.o_norm += per_feature_mean_sq * 1e6

# Freeze every parameter except the embedding table. The embedding stays
# trainable, presumably so that autograd still builds a graph through the
# decoder and the backward hooks registered below get called.
for n, p in model.named_parameters():
    print(n)
    if "embed_tokens" not in n:
        p.requires_grad = False

handles = []

# Attach per-feature accumulators and the two hooks to every plain nn.Linear
# except lm_head. i_norm has one entry per input feature, o_norm one per
# output feature.
for n, m in model.named_modules():
    if type(m) == nn.Linear and "lm_head" not in n:
        print(n)
        m.i_norm = torch.zeros(m.weight.shape[1], device=m.weight.device)
        m.o_norm = torch.zeros(m.weight.shape[0], device=m.weight.device)
        handles.append(m.register_forward_hook(f_hook))
        handles.append(m.register_full_backward_hook(b_hook))

# NOTE(review): gradient checkpointing is enabled for this pass, so the forward
# of each block is probably re-run during backward and f_hook may fire twice
# per sample. That would scale every i_norm by the same factor - confirm it is
# acceptable.
try:
    for idx, bx in enumerate(dataloader):
        print(bx.shape)
        # Add the batch dimension and move the tokens to the GPU once.
        bx = bx.cuda().unsqueeze(0)
        embs = model.model(bx)[0]
        # Next-token loss computed from the final hidden states and the
        # lm_head weight; the backward pass is what feeds b_hook.
        loss = linear_cross_entropy(embs, model.lm_head.weight, bx, shift=1)
        print(idx, loss)
        loss.backward()
finally:
    # Always detach the hooks, even if a batch fails, so that later forward
    # and backward passes do not keep mutating i_norm / o_norm.
    for h in handles:
        h.remove()
    # The gradient accumulated on the embedding table is never read; release it.
    model.zero_grad(set_to_none=True)

model.embed_tokens.weight
model.layers.0.self_attn.q_proj.weight
model.layers.0.self_attn.k_proj.weight
model.layers.0.self_attn.v_proj.weight
model.layers.0.self_attn.o_proj.weight
model.layers.0.mlp.gate_proj.weight
model.layers.0.mlp.up_proj.weight
model.layers.0.mlp.down_proj.weight
model.layers.0.input_layernorm.weight
model.layers.0.post_attention_layernorm.weight
model.layers.1.self_attn.q_proj.weight
model.layers.1.self_attn.k_proj.weight
model.layers.1.self_attn.v_proj.weight
model.layers.1.self_attn.o_proj.weight
model.layers.1.mlp.gate_proj.weight
model.layers.1.mlp.up_proj.weight
model.layers.1.mlp.down_proj.weight
model.layers.1.input_layernorm.weight
model.layers.1.post_attention_layernorm.weight
model.layers.2.self_attn.q_proj.weight
model.layers.2.self_attn.k_proj.weight
model.layers.2.self_attn.v_proj.weight
model.layers.2.self_attn.o_proj.weight
model.layers.2.mlp.gate_proj.weight
model.layers.2.mlp.up_proj.weight
model.layers.2.mlp.down_proj.weight
model.layers.2.inp

In [7]:
# Scratch check of `sum(dim=1)` on a random 3x5 tensor; `a` is not referenced
# by any other cell shown in this notebook.
a = torch.randn(3, 5)
a.sum(dim=1)

tensor([ 2.2889,  1.1245, -0.0119])

In [8]:
def power_iteration(A, num_iters=5):
    """
    Performs power iteration to estimate the top singular triplet of a matrix.

    Arguments:
        A (torch.Tensor): The input matrix of shape (m, n).
        num_iters (int): Number of iterations to perform.

    Returns:
        u (torch.Tensor): Dominant left singular vector (m,). All zeros when
            the estimated singular value is zero (e.g. ``A`` is all zeros).
        sigma (torch.Tensor): Dominant singular value (scalar).
        v (torch.Tensor): Dominant right singular vector (n,).
    """
    # Random start vector on A's device and in A's dtype, so torch.mv also
    # works for inputs that are not float32.
    n = A.shape[1]
    v = torch.randn(n, device=A.device, dtype=A.dtype)
    v = v / torch.norm(v)

    for _ in range(num_iters):
        # Multiply A*v
        u = torch.mv(A, v)
        u_norm = torch.norm(u)
        if u_norm == 0:
            break
        u = u / u_norm

        # Multiply A^T*u
        v = torch.mv(A.t(), u)
        v_norm = torch.norm(v)
        if v_norm == 0:
            break
        v = v / v_norm

    # Estimate the dominant singular value as ||A*v||; A*v is computed once.
    Av = torch.mv(A, v)
    sigma = torch.norm(Av)
    # Guard the normalisation: with sigma == 0 the plain division gives NaN.
    u = torch.where(sigma > 0, Av / sigma, torch.zeros_like(Av))
    return u, sigma, v

def svd_abs(W):
    """Rank-1 approximation of ``|W|`` with the signs of ``W`` re-applied.

    Arguments:
        W (torch.Tensor): 2-D input matrix.

    Returns:
        torch.Tensor: ``sigma * outer(u, v) * sign(W)`` where ``(u, sigma, v)``
        comes from ``power_iteration`` on ``W.abs()``; zero entries of ``W``
        are given sign +1.
    """
    signs = W.sign()
    signs[signs == 0] = 1
    left, sigma, right = power_iteration(W.abs(), num_iters=5)
    magnitude = sigma * torch.outer(left, right)
    return magnitude * signs

def find_other2(A, W, nnz, Z, U, print_sc=None, debug=False, reg=0, rho_start=0.03, iters=3, prune_iters=1, flip=False, final=False):
    """Update one factor ``Z`` so that ``A @ Z`` approximates ``W``.

    Alternates a regularised least-squares solve for ``B`` with a projection
    ``Z = svd_abs(B + U)`` and a running update of the auxiliary variable ``U``.
    The first solve uses penalty ``rho_start``; every later solve uses a
    penalty of 1.

    Arguments:
        A (torch.Tensor): Fixed factor, shape (rows, k).
        W (torch.Tensor): Target, shape (rows, cols).
        Z (torch.Tensor): Current projected factor, shape (k, cols).
        U (torch.Tensor): Current auxiliary variable, same shape as ``Z``.
        reg (float): Diagonal regulariser, as a fraction of the mean diagonal
            of ``A.T @ A``.
        rho_start (float): Penalty used for the first solve.
        iters (int): Number of projection steps (at least one is always run).
        nnz, print_sc, debug, prune_iters, flip, final: Accepted for call
            compatibility; not read by this function.

    Returns:
        (Z, U): the updated projected factor and auxiliary variable.
    """
    gram = A.T.matmul(A)
    gram += torch.diag(torch.ones_like(gram.diag())) * gram.diag().mean() * reg

    rho = 1
    rhs = A.T.matmul(W)
    identity = torch.eye(gram.shape[1], device=gram.device)
    inv_rho = torch.inverse(gram + identity * rho)
    inv_start = torch.inverse(gram + identity * rho_start)

    B = inv_start.matmul(rhs + rho_start * (Z - U))

    for _ in range(iters - 1):
        Z = svd_abs(B + U)
        U = U + (B - Z)
        B = inv_rho.matmul(rhs + rho * (Z - U))

    # Final projection: no further solve after it.
    Z = svd_abs(B + U)
    U = U + (B - Z)

    return Z, U


def factorizeT(W, XX, o_norm, asp=0.16, sp=0.16, iters=80):
    """Factorise a tall (already transposed) weight into two factors.

    ``W`` is scaled row-wise by ``sqrt(XX)`` and column-wise by
    ``sqrt(o_norm)``, then the two factors are refined alternately with
    ``find_other2`` for ``iters`` rounds. The inner dimension is derived from
    the global ``TARGET_BITS``.

    Arguments:
        W (torch.Tensor): Matrix to factorise; ``factorizef`` passes the
            transposed weight here.
        XX (torch.Tensor): Per-row statistic of ``W`` (one value per row).
        o_norm (torch.Tensor): Per-column statistic of ``W``.
        asp, sp (float): Budgets forwarded to ``find_other2`` as ``nnz``.
        iters (int): Number of alternating rounds.

    Returns:
        Tuple ``(approx.T, B.T, A.T, 1 / mid)`` with both factors un-scaled
        and ``mid`` the column norms of ``A`` from the last round.
    """
    rows, cols = W.shape[0], W.shape[1]
    budget_a = int(rows**2 * asp)
    budget_b = int(W.numel() * sp - budget_a)

    norm = XX.sqrt().unsqueeze(1) + 1e-8
    norm_o = o_norm.sqrt() + 1e-8
    Wn = W * norm * norm_o

    rank = int(TARGET_BITS*(rows*cols) / (rows + cols))
    # `mid` only keeps this integer value if the loop below never runs.
    mid = rank

    Az = torch.randn((rows, rank), device=W.device)
    Au = torch.zeros_like(Az)

    Bz = torch.randn((rank, cols), device=W.device)
    Bu = torch.zeros_like(Bz)

    last_round = iters - 1
    for itt in range(iters):
        # Cubic ramp of the first-solve penalty from 0 up to 1.
        rho_start = min(1.0, itt / (iters-3))**3
        final = itt == last_round

        # Update A against the row-normalised B (transposed problem).
        mid = Bz.norm(dim=1) + 1e-12
        a_t, au_t = find_other2(Bz.T / mid, Wn.T, budget_a, Az.T, Au.T, reg=3e-2, debug=False, rho_start=rho_start, flip=True, final=final)
        Az, Au = a_t.T, au_t.T

        # Update B against the column-normalised A.
        mid = Az.norm(dim=0) + 1e-12
        Bz, Bu = find_other2(Az / mid, Wn, budget_b, Bz, Bu, reg=3e-2, debug=False, rho_start=rho_start, final=final)

        if final:
            print("err", itt, ((Az / mid).matmul(Bz) - Wn).square().sum().item(), (Wn).square().sum().item())

    A_unscaled = Az / norm
    B_unscaled = Bz / norm_o
    return A_unscaled.matmul(B_unscaled).T, B_unscaled.T, A_unscaled.T, 1 / mid


def factorizef(W, XX, o_norm, asp=0.16, sp=0.16, iters=80, l_prev=None):
    """Factorise a weight matrix into two factors, dispatching on its shape.

    Matrices with at least as many rows as columns are handed to
    ``factorizeT`` on the transpose. Wide matrices are scaled column-wise by
    ``sqrt(XX)`` and row-wise by ``sqrt(o_norm)`` and refined here with
    alternating ``find_other2`` calls. The inner dimension is derived from the
    global ``TARGET_BITS``.

    Arguments:
        W (torch.Tensor): Weight of shape (out_features, in_features).
        XX (torch.Tensor): Per-input-feature statistic.
        o_norm (torch.Tensor): Per-output-feature statistic.
        asp, sp (float): Budgets forwarded to ``find_other2`` as ``nnz``.
        iters (int): Number of alternating rounds.
        l_prev: Not read by this function.

    Returns:
        Tuple ``(approx, A, B, 1 / mid)`` with both factors un-scaled and
        ``mid`` the column norms of ``A`` from the last round.
    """
    rows, cols = W.shape[0], W.shape[1]
    if rows >= cols:
        return factorizeT(W.T, XX, o_norm, asp, sp=sp, iters=iters)

    budget_a = int(rows**2 * asp)
    budget_b = int(W.numel() * sp - budget_a)
    norm = XX.sqrt() + 1e-8
    norm_o = (o_norm.sqrt() + 1e-8).unsqueeze(1)

    Wn = W * norm * norm_o
    rank = int(TARGET_BITS*(rows*cols) / (rows + cols))
    # `mid` only keeps this integer value if the loop below never runs.
    mid = rank

    Az = torch.randn((rows, rank), device=W.device)
    Au = torch.zeros_like(Az)

    Bz = torch.randn((rank, cols), device=W.device)
    Bu = torch.zeros_like(Bz)

    last_round = iters - 1
    for itt in range(iters):
        # Cubic ramp of the first-solve penalty from 0 up to 1.
        rho_start = min(1.0, itt / (iters-3))**3
        final = itt == last_round

        # Update A against the row-normalised B (transposed problem).
        mid = Bz.norm(dim=1) + 1e-12
        a_t, au_t = find_other2(Bz.T / mid, Wn.T, budget_a, Az.T, Au.T, reg=3e-2, debug=False, rho_start=rho_start, final=final)
        Az, Au = a_t.T, au_t.T

        # Update B against the column-normalised A.
        mid = Az.norm(dim=0) + 1e-12
        Bz, Bu = find_other2(Az / mid, Wn, budget_b, Bz, Bu, reg=3e-2, debug=False, rho_start=rho_start, flip=True, final=final)

        if final:
            print("err", itt, ((Az / mid).matmul(Bz) - Wn).square().sum().item(), (Wn).square().sum().item())

    return (Az / norm_o).matmul(Bz / norm), Az / norm_o, Bz / norm, 1 / mid

def factorize(lx, l_prev=None):
    """Factorise the weight of one linear layer and balance the two factors.

    A density is first picked from a 301-point grid on [0.1, 0.4]: the last
    grid value whose estimated storage cost per parameter stays below 1.25.
    The weight is then factorised by ``factorizef`` (260 rounds) using the
    layer's ``i_norm`` / ``o_norm`` statistics, and the two factors are
    rescaled in place so that their norms match.

    Arguments:
        lx: Layer exposing ``weight``, ``i_norm`` and ``o_norm``.
        l_prev: Forwarded unchanged to ``factorizef``.

    Returns:
        Tuple ``(W3, A, B, mid)`` where ``W3 = (A * mid) @ B`` has the shape
        of ``lx.weight``.
    """
    W = lx.weight.detach().float()

    small, large = min(W.shape), max(W.shape)
    print("mid", small)

    size_limit = 1.25
    dense_params = small * large
    mask_slots = small * small + small * large
    for density in np.linspace(0.1, 0.4, 301):
        nonzeros = dense_params * density
        p = nonzeros / mask_slots
        # Binary entropy of the mask occupancy.
        entropy = -p*np.log2(p) - (1-p)*np.log2(1-p)
        cost_per_param = (nonzeros*2 + mask_slots*entropy) / dense_params
        if cost_per_param < size_limit:
            sp = density

    asp = sp/2 if W.shape[0] == W.shape[1] else sp
    W2, Ab, Bb, mid = factorizef(W, lx.i_norm, lx.o_norm, asp=asp, sp=sp, l_prev=l_prev, iters=260)

    # Same tensor objects under new names: the rescaling below is in place.
    Ac, Bc = Ab, Bb
    norm_a = Ac.norm() + 1e-12
    norm_b = Bc.norm() + 1e-12
    Ac *= (norm_b/norm_a).sqrt()
    Bc *= (norm_a/norm_b).sqrt()

    W3 = (Ac * mid).matmul(Bc)
    assert W3.shape == lx.weight.shape
    nonzero_entries = (Ac != 0).sum() + (Bb != 0).sum()
    print("sparsity check", nonzero_entries.item() / W3.numel())
    return W3, Ac, Bc, mid

In [9]:
import cupy

def my_pack(x):
    """Pack a sign vector into bytes, eight entries per byte, first entry in the top bit.

    Entries equal to 1 become bit 1; every other value becomes bit 0.

    Arguments:
        x (torch.Tensor): Vector whose length is a multiple of 8
            (assumed 1-D; ``BitLinear`` passes a flattened matrix).

    Returns:
        torch.Tensor: ``uint8`` tensor of length ``len(x) // 8`` on ``x``'s device.

    Raises:
        ValueError: If the length of ``x`` is not a multiple of 8. Previously a
            length below 8 silently produced an empty result and other lengths
            failed with an unrelated shape error.
    """
    length = x.shape[0]
    if length % 8 != 0:
        raise ValueError(f"my_pack expects a length that is a multiple of 8, got {length}")

    bits = (x == 1).to(torch.uint8)
    out = torch.zeros((length // 8), device=bits.device, dtype=torch.uint8)
    for i in range(8):
        # Entry i of every group of eight lands in bit (7 - i) of its byte.
        out += bits[i::8] << (7 - i)
    return out

@torch.compile
def my_unpack(x):
    """Expand packed bytes into a flat ``int8`` vector of -1 / +1 values.

    Each byte yields eight entries, top bit first: bit 1 becomes +1 and
    bit 0 becomes -1.

    Arguments:
        x (torch.Tensor): 1-D ``uint8`` tensor of packed bits.

    Returns:
        torch.Tensor: ``int8`` tensor of length ``8 * len(x)``.
    """
    n_bytes = x.shape[0]
    bits = torch.zeros((n_bytes, 8), device=x.device, dtype=torch.int8)
    for column, shift in enumerate(range(7, -1, -1)):
        bits[:, column] = (x >> shift) & 1
    # Map {0, 1} to {-1, +1}.
    return bits.flatten() * 2 - 1

def power_iteration(A, num_iters=5):
    """
    Performs power iteration to estimate the top singular triplet of a matrix.

    This definition has the same name as the one in the previous cell and
    replaces it once this cell has run.

    Arguments:
        A (torch.Tensor): The input matrix of shape (m, n).
        num_iters (int): Number of iterations to perform.

    Returns:
        u (torch.Tensor): Dominant left singular vector (m,). All zeros when
            the estimated singular value is zero (e.g. ``A`` is all zeros).
        sigma (torch.Tensor): Dominant singular value (scalar).
        v (torch.Tensor): Dominant right singular vector (n,).
    """
    # Random start vector on A's device and in A's dtype, so torch.mv also
    # works for inputs that are not float32.
    n = A.shape[1]
    v = torch.randn(n, device=A.device, dtype=A.dtype)
    v = v / torch.norm(v)

    for _ in range(num_iters):
        # Multiply A*v
        u = torch.mv(A, v)
        u_norm = torch.norm(u)
        if u_norm == 0:
            break
        u = u / u_norm

        # Multiply A^T*u
        v = torch.mv(A.t(), u)
        v_norm = torch.norm(v)
        if v_norm == 0:
            break
        v = v / v_norm

    # Estimate the dominant singular value as ||A*v||; A*v is computed once.
    Av = torch.mv(A, v)
    sigma = torch.norm(Av)
    # Guard the normalisation: with sigma == 0 the plain division gives NaN.
    u = torch.where(sigma > 0, Av / sigma, torch.zeros_like(Av))
    return u, sigma, v

def svd_abs2(W):
    """Split ``W`` into a sign matrix and rank-1 magnitude factors.

    Unlike ``svd_abs`` the rank-1 product is not materialised: only its two
    vectors are returned, so no full-size temporary is allocated.

    Arguments:
        W (torch.Tensor): 2-D input matrix of shape (m, n).

    Returns:
        Tuple ``(u * s, Sg, v)``: the left singular vector of ``W.abs()``
        scaled by the singular value (m,), the sign matrix of ``W`` with zeros
        mapped to +1 (m, n), and the right singular vector (n,).
    """
    Sg = W.sign()
    Sg[Sg == 0] = 1
    u, s, v = power_iteration(W.abs(), num_iters=5)
    return u * s, Sg, v

class BitLinear(nn.Module):
    """Linear map without bias whose weight is a bit-packed sign matrix.

    The sign matrix is packed with ``my_pack`` into the buffer ``bp``. Its
    shape is kept as the plain attribute ``shape``, which is not a buffer and
    therefore not part of ``state_dict()``.
    """

    def __init__(self, b):
        """Store ``b`` (entries equal to 1 are kept as +1, all others as -1)."""
        super().__init__()
        self.shape = b.shape
        self.register_buffer("bp", my_pack(b.flatten()))

    def forward(self, x):
        """Return ``x @ S.T`` where ``S`` is the unpacked sign matrix."""
        signs = my_unpack(self.bp).reshape(self.shape)
        weight_t = signs.T.to(x.dtype)
        return x.matmul(weight_t)
        

        
class Mul(nn.Module):
    """Element-wise multiplication by a fixed tensor held as the buffer ``w``."""

    def __init__(self, w):
        """Register ``w`` as a buffer (saved in ``state_dict``, not a parameter)."""
        super().__init__()
        self.register_buffer("w", w)

    def forward(self, x):
        """Return ``x * w`` with ``w`` cast to the dtype of ``x``."""
        scale = self.w.to(x.dtype)
        return x * scale
        

def replace(lx):
    """Build the packed replacement for a factorised linear layer.

    Reads the factors ``A``, ``B`` and the scale ``mid`` that were attached to
    ``lx.weight``, splits each factor with ``svd_abs2`` and chains the pieces:
    scale, sign matrix of ``B``, combined scale, sign matrix of ``A``, scale.

    Arguments:
        lx: Layer whose ``weight`` carries the attributes ``A``, ``B``, ``mid``.

    Returns:
        nn.Sequential: Five stages alternating ``Mul`` and ``BitLinear``.
    """
    inner = lx.weight.B
    outer = lx.weight.A

    u1, b1, v1 = svd_abs2(inner.float())
    u2, b2, v2 = svd_abs2(outer.float())

    stages = [
        Mul(v1),
        BitLinear(b1),
        Mul(u1*lx.weight.mid*v2),
        BitLinear(b2),
        Mul(u2),
    ]
    return nn.Sequential(*stages)

@torch.no_grad()
def opt_sequential(model, dataloader, dev):
    """Compress the decoder layers one at a time, replacing each linear layer.

    For every decoder layer the function:
      1. records the layer's current outputs on the reference inputs,
      2. walks the seven projections in a fixed order; before ``q_proj``
         (layers >= 1) and before ``k_proj`` (every layer) the remaining
         non-layernorm weights of the layer are fine-tuned for 8 epochs against
         those recorded outputs,
      3. factorises each projection with ``factorize`` and swaps it for the
         module returned by ``replace``,
      4. pushes the compressed-path inputs through the finished layer.

    Reads the globals ``n_samples``, ``find_layers``, ``factorize``,
    ``replace`` and ``BF16FusedAdamW``. The model is modified in place.

    Args:
        model: Causal LM exposing ``model.layers``, ``model.embed_tokens``,
            ``model.rotary_emb``, ``config`` and ``seqlen``; its linear layers
            must already carry ``i_norm`` / ``o_norm``.
        dataloader: Iterable of token-id tensors, one sequence each.
        dev: Device used for all layer computation. The ``o_norm`` weighting
            lives wherever the statistics were collected, so this must be the
            same device.
    """
    print('Starting ...')

    model.cpu()
    model.gradient_checkpointing_disable()
    model.eval()
    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers

    # Only the pieces needed to reach the first decoder layer go to the device.
    model.model.embed_tokens = model.model.embed_tokens.to(dev)
    model.model.rotary_emb = model.model.rotary_emb.to(dev)
    layers[0] = layers[0].to(dev)

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros(
        (n_samples, model.seqlen, model.config.hidden_size), dtype=dtype, device="cpu"
    )
    cache = {'i': 0, 'attention_mask': None}

    class Catcher(nn.Module):
        """Stand-in for layer 0: stores its input and kwargs, then aborts the forward."""
        def __init__(self, module):
            super().__init__()
            self.module = module
        def forward(self, inp, **kwargs):
            inps[cache['i']] = inp
            cache['i'] += 1
            cache['attention_mask'] = kwargs['attention_mask']
            cache['position_embeddings'] = kwargs['position_embeddings']
            # Deliberate early exit; caught by the loop below.
            raise ValueError
    layers[0] = Catcher(layers[0])
    for batch in dataloader:
        try:
            model(batch.unsqueeze(0).to(dev))
        except ValueError:
            pass
    layers[0] = layers[0].module

    layers[0] = layers[0].cpu()
    model.model.embed_tokens = model.model.embed_tokens.cpu()
    torch.cuda.empty_cache()

    # inps follows the original layers, comp_inps follows the compressed ones.
    comp_inps = inps.clone()
    attention_mask = cache['attention_mask']
    position_embeddings = cache['position_embeddings']

    print('Ready.')

    layers = model.model.layers

    for i in range(len(layers)):
        layer = layers[i].to(dev)

        subset = find_layers(layer)
        # Reference outputs of the still-uncompressed layer; overwrite inps in place.
        for j in range(n_samples):
            inps[j] = layer(inps[j].unsqueeze(0).to(dev), attention_mask=attention_mask, position_embeddings=position_embeddings)[0]

        # Per-feature weighting of the output error, taken from down_proj
        # before that projection is replaced.
        imp = layer.mlp.down_proj.o_norm

        for name in [
            "self_attn.q_proj",
            "self_attn.v_proj",
            "self_attn.o_proj",
            "self_attn.k_proj",
            "mlp.up_proj",
            "mlp.gate_proj",
            "mlp.down_proj",
        ]:
            # Weights still present as parameters (already replaced projections
            # only hold buffers), layernorms excluded.
            to_opt = {n: p for n, p in layer.named_parameters() if "weight" in n and "layernorm" not in n}
            if len(to_opt) > 0 and (("q_proj" in name and i >= 1) or "k_proj" in name):

                for n, p in to_opt.items():
                    p.requires_grad = True
                print(to_opt.keys())

                lr = 3e-5
                opt = BF16FusedAdamW(to_opt.values(), lr, weight_decay=1e-4)
                # One scheduler step per optimiser step: 8 epochs, one step every 8 samples.
                sch = torch.optim.lr_scheduler.OneCycleLR(opt, lr, total_steps=n_samples*8 // 8, cycle_momentum=False)
                err_before = 0
                for j in range(n_samples):
                    cur_out = layer(comp_inps[j].unsqueeze(0).to(dev), attention_mask=attention_mask, position_embeddings=position_embeddings)[0]
                    err_before += ((cur_out.float() - inps[j].to(dev).float()).square() * imp).mean().item()

                print("err before", err_before)

                # The function runs under no_grad; re-enable autograd for the fine-tune.
                with torch.enable_grad():
                    for ep in range(8):
                        err_total = 0
                        for j in range(n_samples):
                            cur_out = layer(comp_inps[j].unsqueeze(0).to(dev), attention_mask=attention_mask, position_embeddings=position_embeddings)[0]
                            err = ((cur_out.float() - inps[j].to(dev).float()).square() * imp).sum()
                            err.backward()
                            # Gradients are accumulated over 8 samples per step.
                            if j % 8 == 7:
                                opt.step()
                                sch.step()
                                layer.zero_grad(set_to_none=True)
                            err_total += err.item() / inps.shape[1] / inps.shape[2]
                        print(ep, err_total)

                err_after = 0
                for j in range(n_samples):
                    cur_out = layer(comp_inps[j].unsqueeze(0).to(dev), attention_mask=attention_mask, position_embeddings=position_embeddings)[0]
                    err_after += ((cur_out.float() - inps[j].to(dev).float()).square() * imp).mean().item()
                print("err after ", err_after)

            print(i, name)
            print('Pruning ...')
            lx = subset[name]

            s1 = time.time()
            W2, Ac, Bb, mid = factorize(lx)
            err_spxsp = (W2 - lx.weight).square().sum().item()
            print("err_spxsp", err_spxsp)

            # Hand the factors to replace() through attributes on the weight.
            lx.weight.data = W2.to(lx.weight)
            lx.weight.A = Ac
            lx.weight.B = Bb
            lx.weight.mid = mid
            parts = name.split('.')
            block = getattr(layer, parts[0])
            setattr(block, parts[1], replace(lx))

        # Advance the compressed-path inputs through the finished layer.
        for j in range(n_samples):
            comp_inps[j] = layer(comp_inps[j].unsqueeze(0).to(dev), attention_mask=attention_mask, position_embeddings=position_embeddings)[0]

        layers[i] = layer.cpu()
        del layer

        torch.cuda.empty_cache()

    # Put back the caching flag that was switched off at the top.
    model.config.use_cache = use_cache

        
# Run the layer-by-layer compression on the calibration sequences and report
# the wall-clock time. DEV is not defined in the visible cells; it presumably
# comes from one of the star imports.
start = time.time()
model.cpu()
opt_sequential(model, dataloader, DEV)
print("total time", time.time() - start)

Starting ...
Ready.
0 self_attn.q_proj
Pruning ...
mid 4096
err 259 0.4383566975593567 9841.158203125
sparsity check 2.0
err_spxsp 533.7374267578125
0 self_attn.v_proj
Pruning ...
mid 1024
err 259 13.441218376159668 1171.7940673828125
sparsity check 1.99951171875
err_spxsp 33.56348419189453
0 self_attn.o_proj
Pruning ...
mid 4096
err 259 9.076663970947266 436.83599853515625
sparsity check 2.0
err_spxsp 134.54632568359375
dict_keys(['self_attn.k_proj.weight', 'mlp.gate_proj.weight', 'mlp.up_proj.weight', 'mlp.down_proj.weight'])
err before 0.0025646664507803507
0 0.002087427475998993
1 0.0014450392134222056
2 0.001128885020079906
3 0.0012699482997504674
4 0.001326797969340987
5 0.0009757646530488273
6 0.0008886884204457601
7 0.0008791659238340799
err after  0.0008783414446043025
0 self_attn.k_proj
Pruning ...
mid 1024
err 259 0.3716686964035034 378.6822204589844
sparsity check 1.99951171875
err_spxsp 253.10183715820312
0 mlp.up_proj
Pruning ...
mid 4096
err 259 27.28969383239746 394.871

In [10]:
@torch.no_grad()
def opt_eval(model, testenc, dev, dataset: str, log_wandb: bool = False):
    """Compute and print the perplexity of ``model`` on a tokenised test set.

    The test tokens are cut into ``nsamples`` windows of ``model.seqlen``
    tokens. Decoder layers are moved to ``dev`` one at a time, applied to all
    windows and moved back to the CPU; the final norm and ``lm_head`` then
    produce the logits for the loss.

    Args:
        model: Causal LM exposing ``model.layers``, ``model.embed_tokens``,
            ``model.rotary_emb``, ``model.norm``, ``lm_head``, ``config`` and
            ``seqlen``.
        testenc: Object with an ``input_ids`` tensor; indexed as (row, token).
        dev: Device used for the computation.
        dataset: Name used as prefix of the logged metric.
        log_wandb: When true, also log the perplexity through ``wandb``.
    """
    print('Evaluating ...')

    testenc = testenc.input_ids
    # Tokens beyond the last full window are dropped by the integer division.
    nsamples = testenc.numel() // model.seqlen

    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers

    model.model.embed_tokens = model.model.embed_tokens.to(dev)
    model.model.rotary_emb = model.model.rotary_emb.to(dev)
    layers[0] = layers[0].to(dev)

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros(
        (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
    )
    cache = {'i': 0, 'attention_mask': None}

    class Catcher(nn.Module):
        # Stand-in for layer 0: stores its input and kwargs, then aborts the forward.
        def __init__(self, module):
            super().__init__()
            self.module = module
        def forward(self, inp, **kwargs):
            inps[cache['i']] = inp
            cache['i'] += 1
            cache['attention_mask'] = kwargs['attention_mask']
            cache['position_embeddings'] = kwargs['position_embeddings']
            # Deliberate early exit; caught by the loop below.
            raise ValueError
    layers[0] = Catcher(layers[0])
    for i in range(nsamples):
        batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev)
        try:
            model(batch)
        except ValueError:
            pass
    layers[0] = layers[0].module

    layers[0] = layers[0].cpu()
    model.model.embed_tokens = model.model.embed_tokens.cpu()
    torch.cuda.empty_cache()

    outs = torch.zeros_like(inps)
    attention_mask = cache['attention_mask']
    position_embeddings = cache['position_embeddings']

    for i in range(len(layers)):
        print(i)
        layer = layers[i].to(dev)

        
        for j in range(nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_embeddings=position_embeddings)[0]
        layers[i] = layer.cpu()
        del layer
        torch.cuda.empty_cache()
        # Swap buffers: this layer's outputs are the next layer's inputs.
        inps, outs = outs, inps

    if model.model.norm is not None:
        model.model.norm = model.model.norm.to(dev)
    model.lm_head = model.lm_head.to(dev)

    testenc = testenc.to(dev)
    nlls = []
    for i in range(nsamples):
        hidden_states = inps[i].unsqueeze(0)
        if model.model.norm is not None:
            hidden_states = model.model.norm(hidden_states)
        lm_logits = model.lm_head(hidden_states)
        # Position t predicts token t + 1: drop the last logit and the first label.
        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = testenc[
            :, (i * model.seqlen):((i + 1) * model.seqlen)
        ][:, 1:]
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        # Mean loss of the window scaled back by the window length.
        neg_log_likelihood = loss.float() * model.seqlen
        nlls.append(neg_log_likelihood)
    ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
    # NOTE(review): `:3f` is a minimum width of 3 with default precision, not
    # `.3f` - confirm which one is intended.
    print(f"Perplexity: {ppl.item():3f}")
    if log_wandb:
         # NOTE(review): `wandb` is not imported in the visible cells; this
         # branch relies on the name being provided elsewhere - confirm.
         wandb.log({f'{dataset}/perplexity': ppl.item()})

    model.config.use_cache = use_cache

In [11]:
# Evaluate the compressed model: checkpointing off, eval mode.
model.gradient_checkpointing_disable()
model.eval()

for dataset in ['wikitext2']:
    # Rebinds the global `dataloader`, which until here held the calibration
    # tensor. get_loaders is not defined in the visible cells (presumably
    # provided by a star import).
    dataloader, testloader = get_loaders(
        dataset, seed=0, model=model_name, seqlen=model.seqlen
    )
    print(dataset)
    opt_eval(model, testloader, DEV, dataset, False)

tok 128000 128001
wikitext2
Evaluating ...
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
Perplexity: 9.049220


In [12]:
# Persist the compressed weights. BitLinear keeps its matrix shape in a plain
# attribute rather than a buffer, so the state dict holds the packed bits but
# not the shapes needed to unpack them.
torch.save(model.state_dict(), "/mnt/nvme/llamapush/llama3-8B-dsf1bit-20-ft.pt")