In [1]:
import argparse
import os 
from datetime import datetime
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from importlib.metadata import version

from lib.prune import prune_wanda, prune_sparsegpt, prune_magnitude
from lib.prune import prune_sparsegpt_ww, prune_wanda_ww
from lib.eval import eval_ppl, eval_zero_shot
from lib.esd_utils import get_esd_metrics

from lib.sam_prune import prune_wanda_sam, prune_sparsegpt_sam

from lib.utils import check_sparsity

from accelerate import init_empty_weights, dispatch_model, infer_auto_device_map

from segment_anything import sam_model_registry

from segment_anything.utils.dataset import dataset_dis, dataset_dis_val, dataset_duts, dataset_duts_te
from segment_anything.utils.dataloader import get_im_gt_name_dict, create_dataloaders, RandomHFlip, Resize, LargeScaleJitter
from segment_anything.utils import misc
from segment_anything.utils.loss import norm_attn, pca_fit_transform, sig_ce_loss, dice_loss, mask_iou, sig_mae_score, f1_score


print('torch', version('torch'))
print('transformers', version('transformers'))
print('accelerate', version('accelerate'))
print('# of gpus: ', torch.cuda.device_count())

os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
os.environ['CUDA_VISIBLE_DEVICES'] = '7'

debug = False 

import time 
import heapq 
import torch 
import torch.nn as nn 
import transformers
import numpy as np
from tqdm import tqdm
import logging

from typing import List, Optional, Tuple, Union
from torch import nn

# from .sam_sparsegpt import SparseGPT, SparseGPT_NoReconstruct, SparseGPTV3
from lib.sam_layerwrapper import WrappedGPT, WrappedGPTV3, SparseGPTV3, SparseGPT, SparseGPTV2
from lib.data import get_loaders, prepare_calibration_input_sam

from lib.ablate import AblateGPT 

from lib.matmul_had import *
from lib.utils import *

torch 2.3.1+cu121
transformers 4.47.1
accelerate 0.29.1
# of gpus:  8


In [2]:
class PruneConfig:
    def __init__(self):
        self.model = '/h3cstore_ns/jcxie/SAM/SVD_SAM/pretrain/sam_vit_b_01ec64.pth'
        self.model_name = 'vit_b'
        self.seed = 0
        self.nsamples = 128
        self.sparsity_ratio = 0.5
        self.sparsity_type = "unstructured"
        self.prune_method = "sparsegpt_silu_ww"
        self.cache_dir = "llm_weights"
        self.use_variant = False
        self.save = '/h3cstore_ns/jcxie/LISA/wanda-main/ckpt'
        self.save_model = None
        self.exclude = 'gate_proj'
        self.ww_metric = "alpha_peak"
        self.ww_metric_cache = "/h3cstore_ns/jcxie/LISA/wanda-main/data/llama2-7b-hf"
        self.epsilon = 0.3
        self.mapping_type = "block_wise"
        self.Hyper_m = 3.0
        self.Lamda = 0.20
        self.eval_zero_shot = False
args = PruneConfig()
np.random.seed(args.seed)
torch.random.manual_seed(args.seed)

if "ww" in args.prune_method and not os.path.exists(f"{args.ww_metric_cache}/{args.ww_metric}.npy"):
    metric_values = get_esd_metrics(args.model, args.ww_metric, args.cache_dir)
    np.save(f"{args.ww_metric_cache}/{args.ww_metric}.npy", metric_values)

# Handling n:m sparsity
prune_n, prune_m = 0, 0
if args.sparsity_type != "unstructured":
    assert args.sparsity_ratio == 0.5, "sparsity ratio must be 0.5 for structured N:M sparsity"
    prune_n, prune_m = map(int, args.sparsity_type.split(":"))

model = sam_model_registry[args.model_name](args.model).cuda()

valid_datasets = [dataset_dis_val]

input_size = [1024,1024]
print("--- create valid dataloader ---")
valid_im_gt_list = get_im_gt_name_dict(valid_datasets, flag="valid")
valid_dataloaders, valid_datasets = create_dataloaders(valid_im_gt_list,
                                                        my_transforms = [
                                                                    Resize(input_size)
                                                                ],
                                                        batch_size=1,
                                                        training=False)
print(len(valid_dataloaders), " valid dataloaders created")

for n, p in model.named_parameters():
    print(n, p.size())

--- create valid dataloader ---
------------------------------ valid --------------------------------
--->>> valid  dataset  0 / 1   DIS5K-VD <<<---
-im- DIS5K-VD /h3cstore_ns/jcxie/SAM/SVD_SAM/data/data/DIS5K/DIS-VD/im :  470
-gt- DIS5K-VD /h3cstore_ns/jcxie/SAM/SVD_SAM/data/data/DIS5K/DIS-VD/gt :  470
1  valid dataloaders created
image_encoder.pos_embed torch.Size([1, 64, 64, 768])
image_encoder.patch_embed.proj.weight torch.Size([768, 3, 16, 16])
image_encoder.patch_embed.proj.bias torch.Size([768])
image_encoder.blocks.0.norm1.weight torch.Size([768])
image_encoder.blocks.0.norm1.bias torch.Size([768])
image_encoder.blocks.0.attn.rel_pos_h torch.Size([27, 64])
image_encoder.blocks.0.attn.rel_pos_w torch.Size([27, 64])
image_encoder.blocks.0.attn.qkv.weight torch.Size([2304, 768])
image_encoder.blocks.0.attn.qkv.bias torch.Size([2304])
image_encoder.blocks.0.attn.proj.weight torch.Size([768, 768])
image_encoder.blocks.0.attn.proj.bias torch.Size([768])
image_encoder.blocks.0.norm2.w

In [3]:
device = 'cuda'
dataloader = valid_dataloaders
ratios = None
dual_ascent = False
valid_out = False

with torch.no_grad():
    dev = device

    inps, outs = prepare_calibration_input_sam(args, model, dataloader[0], args.nsamples)

    layers = model.image_encoder.blocks
    layer_num = len(find_layers(layers))

    if ratios is None:
        ratios = [args.sparsity_ratio for i in range(layer_num)]
    k=0

    for i in range(len(layers)):
        layer = layers[i]
        subset = find_layers(layer)

        gpts = {}
        for name in subset:
            if i == 4 and name == 'mlp.lin1':
                gpts[name] = SparseGPTV3(subset[name])
            else:
                gpts[name] = SparseGPT(subset[name])

        def add_batch(name):
            def tmp(_, inp, out):
                gpts[name].add_batch(inp[0].data, out.data)
            return tmp

        handles = []
        for name in gpts:
            handles.append(subset[name].register_forward_hook(add_batch(name)))

        for j in range(args.nsamples):
            with torch.no_grad(): 
                outs[j] = layer(inps[j].unsqueeze(0))[0]
        for h in handles:
            h.remove()

        if i == 4:
            test_H = gpts['mlp.lin1'].H
            test_H_B = gpts['mlp.lin1'].H_B
            ora_W = model.image_encoder.blocks[i].mlp.lin1.weight.data

        for name in subset:
            print(i, name)
            gpts[name].fasterprune(ratios[k], prune_n=prune_n, prune_m=prune_m, percdamp=0.01, blocksize=128)

            if dual_ascent:
                flag, alpha, beta = gpts[name].get_args()
                min_iter = 0
                if flag:
                    gpts[name].dual_ascent2(beta = beta, alpha = alpha, min_iter=min_iter, theld= args.dual_theld)
                    gpts[name].del_valid()
                else:
                    del gpts[name].H, gpts[name].H_B

            gpts[name].free()
            k+=1

        if i == 4:
            new_W = gpts['mlp.lin1'].layer.weight.data
            break

        for j in range(args.nsamples):
            with torch.no_grad():
                outs[j] = layer(inps[j].unsqueeze(0))[0]

        layers[i] = layer
        inps, outs = outs, inps

0 attn.qkv
0 attn.proj
0 mlp.lin1
0 mlp.lin2
1 attn.qkv
1 attn.proj
1 mlp.lin1
1 mlp.lin2
2 attn.qkv
2 attn.proj
2 mlp.lin1
2 mlp.lin2
3 attn.qkv
3 attn.proj
3 mlp.lin1
3 mlp.lin2
4 attn.qkv
4 attn.proj
4 mlp.lin1
4 mlp.lin2


In [3]:
device = 'cuda'
dataloader = valid_dataloaders
ratios = None
dual_ascent = False
valid_out = False

with torch.no_grad():
    dev = device

    inps, outs = prepare_calibration_input_sam(args, model, dataloader[0], args.nsamples)

    layers = model.image_encoder.blocks
    layer_num = len(find_layers(layers))

    if ratios is None:
        ratios = [args.sparsity_ratio for i in range(layer_num)]
    k=0

    for i in range(len(layers)):
        layer = layers[i]
        subset = find_layers(layer)

        gpts = {}
        for name in subset:
            if i == 4 and name == 'mlp.lin1':
                gpts[name] = SparseGPTV2(subset[name])
            else:
                gpts[name] = SparseGPT(subset[name])

        def add_batch(name):
            def tmp(_, inp, out):
                gpts[name].add_batch(inp[0].data, out.data)
            return tmp

        handles = []
        for name in gpts:
            handles.append(subset[name].register_forward_hook(add_batch(name)))

        for j in range(args.nsamples):
            with torch.no_grad(): 
                outs[j] = layer(inps[j].unsqueeze(0))[0]
        for h in handles:
            h.remove()

        if i == 4:
            test_inps = gpts['mlp.lin1'].inps
            test_outs = gpts['mlp.lin1'].outs
            ora_W = model.image_encoder.blocks[i].mlp.lin1.weight.data

        for name in subset:
            print(i, name)
            gpts[name].fasterprune(ratios[k], prune_n=prune_n, prune_m=prune_m, percdamp=0.01, blocksize=128)

            if dual_ascent:
                flag, alpha, beta = gpts[name].get_args()
                min_iter = 0
                if flag:
                    gpts[name].dual_ascent2(beta = beta, alpha = alpha, min_iter=min_iter, theld= args.dual_theld)
                    gpts[name].del_valid()
                else:
                    del gpts[name].H, gpts[name].H_B

            gpts[name].free()
            k+=1

        if i == 4:
            new_W = gpts['mlp.lin1'].layer.weight.data
            break

        for j in range(args.nsamples):
            with torch.no_grad():
                outs[j] = layer(inps[j].unsqueeze(0))[0]

        layers[i] = layer
        inps, outs = outs, inps

0 attn.qkv
0 attn.proj
0 mlp.lin1
0 mlp.lin2
1 attn.qkv
1 attn.proj
1 mlp.lin1
1 mlp.lin2
2 attn.qkv
2 attn.proj
2 mlp.lin1
2 mlp.lin2
3 attn.qkv
3 attn.proj
3 mlp.lin1
3 mlp.lin2
4 attn.qkv
4 attn.proj
4 mlp.lin1
4 mlp.lin2


In [4]:
test_inps = torch.stack(test_inps, dim=0)
test_outs = torch.stack(test_outs, dim=0)

print('ora_W', ora_W.shape)
print('new_W', new_W.shape)
print('test_inps', test_inps.shape)
print('test_outs', test_outs.shape)
import torch
ora_W = ora_W.cuda()
new_W = new_W.cuda()
test_inps = test_inps.cpu()
test_outs = test_outs.cpu()

ora_W torch.Size([3072, 768])
new_W torch.Size([3072, 768])
test_inps torch.Size([128, 768, 4096])
test_outs torch.Size([128, 3072, 4096])


In [5]:
(new_W == 0).sum() / new_W.numel()

tensor(0.5000, device='cuda:0')

In [17]:
from tqdm import tqdm
import math
def dual_ascent_method3(H_A, H_B, W_old, M, beta, alpha, gama, rho, epsilon=3e-2, max_iter=10, lambda_zero=False, percdamp=.01):
    M = M.to(torch.float32)
    H_A = H_A.to(torch.float32)
    H_B = H_B.to(torch.float32)
    W_old = W_old.to(torch.float32)
    
    # 初始化 W 和 Lambda
    W = W_old.clone()
    if lambda_zero:
        Lambda = torch.zeros_like(W)
    else:
        term1 = beta * (torch.mm(W, H_A) - H_B)
        term2 = alpha * (W - W_old)
        Lambda = -M * (term1 + term2)

    for k in tqdm(range(max_iter)):
        # 保存上一次的 W
        W_prev = W.clone()

        # 更新 W
        A = (beta + gama) * H_A + alpha * torch.eye(H_A.shape[0], device=H_A.device)
        try:
            damp = percdamp * torch.mean(torch.diag(A))
            diag = torch.arange(A.shape[-1], device=A.device)
            A[diag, diag] += damp
            A = torch.linalg.cholesky(A)
            A_inv = torch.cholesky_inverse(A)
            # A_inv = torch.linalg.cholesky(A, upper=True)
        except RuntimeError as e:
            print(f"Cholesky decomposition failed: {e}. Falling back to direct inverse.")
            raise e
        
        B = beta * H_B + alpha * W_old
        W = torch.mm(B - (M * Lambda), A_inv)

        # 更新 Lambda
        Lambda = Lambda + rho * (M * W)
        # W = W - M * W
        # 收敛判断

        if k % 10 == 0 :
            

            # print( (W.to(torch.float32) @ test_inps[0].cuda() - test_outs[0].cuda()).abs().max(), (W.to(torch.float32) @ test_inps[0].cuda() - test_outs[0][0].cuda()).abs().mean())
            # print(torch.norm(W - W_prev))
            if torch.norm(W - W_prev) < epsilon:
                print( (W.to(torch.float32) @ test_inps[0].cuda() - test_outs[0].cuda()).abs().max(), (W.to(torch.float32) @ test_inps[0].cuda() - test_outs[0][0].cuda()).abs().mean())
                print(f"Converged at iteration {k}")
                print(torch.norm(W - W_prev))
                break
        # if k % 10 == 0 :
        #     print(torch.norm(W - W_prev))
    W = W - M * W
    return W.to(torch.float32)

test_x = test_inps[0]
test_y = test_outs[0]

H_A = torch.zeros((test_x.shape[-2], test_x.shape[-2]), device="cuda")
# H_B = torch.zeros((test_y.shape[-2], test_x.shape[-2]), device="cuda")
nsamples = 0
for i in range(120):
    H_A *= nsamples / (nsamples + 1)
    # H_B *= nsamples / (nsamples + 1)

    nsamples += 1

    H_A += (math.sqrt(2 / nsamples) * test_inps[i].cuda().to(torch.float32)) @ (math.sqrt(2 / nsamples) * test_inps[i].cuda().to(torch.float32)).T
    # H_B += (math.sqrt(2 / nsamples) * test_outs[i][0].cuda().to(torch.float32)) @ (math.sqrt(2 / nsamples) * test_inps[i].cuda().to(torch.float32)).T

H_B = ora_W @ H_A
    

M = (new_W == 0).to(torch.float32).to(H_B.device)
print(M.device)
# M = torch.zeros_like(M).to(torch.float32).cuda()
print((M == 0).sum() / M.numel())
W_old = new_W.clone().to(H_B.device)
beta = 0.99
alpha = 0.01
gama = 0.0000
rho = 1

W_update = dual_ascent_method3(H_A, H_B, W_old, M, beta, alpha, gama, rho, lambda_zero=True, max_iter=10000)

cuda:0
tensor(0.5000, device='cuda:0')


  2%|▏         | 180/10000 [00:00<00:20, 490.64it/s]

tensor(3.1055, device='cuda:0') tensor(0.8515, device='cuda:0')
Converged at iteration 180
tensor(0.0298, device='cuda:0')





In [30]:
print((W_update == 0).sum() / W_update.numel())

for i in [0, 32, 64, 125, 126]:
    W_update = W_update - M.to(torch.float32) * W_update
    print(test_outs[i].abs().max(), test_outs[i].abs().mean())
    print(i, (W_update @ test_inps[i].cuda() - ora_W @ test_inps[i].cuda()).abs().max(), (W_update @ test_inps[i].cuda() - ora_W @ test_inps[i].cuda()).abs().mean())
    print(i, (new_W.to(H_B.device) @ test_inps[i].cuda() - ora_W @ test_inps[i].cuda()).abs().max(), (new_W.to(H_B.device) @ test_inps[i].cuda() - ora_W @ test_inps[i].cuda()).abs().mean())
    print("-----------")
    # print((new_W @ test_inps[i]).abs().max(), (new_W @ test_inps[i]).abs().mean())
    # print((W_update2 @ test_inps[i]).abs().max(), (W_update2 @ test_inps[i]).abs().mean())
    # print(test_outs[i][0].abs().max(), test_outs[i][0].abs().mean())
    print("-------------------------------------------------------------------------")

tensor(0.5000, device='cuda:0')
tensor(7.3101) tensor(0.9603)
0 tensor(0.5665, device='cuda:0') tensor(0.0662, device='cuda:0')
0 tensor(0.5296, device='cuda:0') tensor(0.0501, device='cuda:0')
-----------
-------------------------------------------------------------------------
tensor(7.0926) tensor(0.9997)
32 tensor(0.6468, device='cuda:0') tensor(0.0814, device='cuda:0')
32 tensor(0.5481, device='cuda:0') tensor(0.0730, device='cuda:0')
-----------
-------------------------------------------------------------------------
tensor(6.8198) tensor(0.9867)
64 tensor(0.5995, device='cuda:0') tensor(0.0764, device='cuda:0')
64 tensor(0.5153, device='cuda:0') tensor(0.0665, device='cuda:0')
-----------
-------------------------------------------------------------------------
tensor(7.1617) tensor(0.9814)
125 tensor(0.7372, device='cuda:0') tensor(0.0766, device='cuda:0')
125 tensor(0.5629, device='cuda:0') tensor(0.0660, device='cuda:0')
-----------
------------------------------------------

In [14]:
print((W_update == 0).sum() / W_update.numel())

for i in [0, 32, 64, 125, 126]:
    W_update = W_update - M.to(torch.float32) * W_update
    print(test_outs[i].abs().max(), test_outs[i].abs().mean())
    print(i, (W_update @ test_inps[i].cuda() - test_outs[i].cuda()).abs().max(), (W_update @ test_inps[i].cuda() - test_outs[i].cuda()).abs().mean())
    print(i, (new_W.to(H_B.device) @ test_inps[i].cuda() - test_outs[i][0].cuda()).abs().max(), (new_W.to(H_B.device) @ test_inps[i].cuda() - test_outs[i][0].cuda()).abs().mean())
    print("-----------")
    # print((new_W @ test_inps[i]).abs().max(), (new_W @ test_inps[i]).abs().mean())
    # print((W_update2 @ test_inps[i]).abs().max(), (W_update2 @ test_inps[i]).abs().mean())
    # print(test_outs[i][0].abs().max(), test_outs[i][0].abs().mean())
    print("-------------------------------------------------------------------------")

tensor(0.5000, device='cuda:0')
tensor(7.3101) tensor(0.9603)
0 tensor(3.2922, device='cuda:0') tensor(0.8560, device='cuda:0')
0 tensor(7.0352, device='cuda:0') tensor(0.8510, device='cuda:0')
-----------
-------------------------------------------------------------------------
tensor(7.0926) tensor(0.9997)
32 tensor(3.3288, device='cuda:0') tensor(0.8556, device='cuda:0')
32 tensor(7.5410, device='cuda:0') tensor(0.9369, device='cuda:0')
-----------
-------------------------------------------------------------------------
tensor(6.8198) tensor(0.9867)
64 tensor(3.2881, device='cuda:0') tensor(0.8565, device='cuda:0')
64 tensor(7.6727, device='cuda:0') tensor(0.9397, device='cuda:0')
-----------
-------------------------------------------------------------------------
tensor(7.1617) tensor(0.9814)
125 tensor(3.3770, device='cuda:0') tensor(0.8556, device='cuda:0')
125 tensor(7.6420, device='cuda:0') tensor(0.8889, device='cuda:0')
-----------
------------------------------------------

In [13]:
print((W_update == 0).sum() / W_update.numel())
for i in [0, 32, 64, 125, 126]:
    W_update = W_update - M.to(torch.float32) * W_update
    print(test_outs[i].abs().max(), test_outs[i].abs().mean())
    print(i, (W_update @ test_inps[i].cuda() - test_outs[i].cuda()).abs().max(), (W_update @ test_inps[i].cuda() - test_outs[i].cuda()).abs().mean())
    print(i, (new_W.to(H_B.device) @ test_inps[i].cuda() - test_outs[i][0].cuda()).abs().max(), (new_W.to(H_B.device) @ test_inps[i].cuda() - test_outs[i][0].cuda()).abs().mean())
    print("-----------")
    # print((new_W @ test_inps[i]).abs().max(), (new_W @ test_inps[i]).abs().mean())
    # print((W_update2 @ test_inps[i]).abs().max(), (W_update2 @ test_inps[i]).abs().mean())
    # print(test_outs[i][0].abs().max(), test_outs[i][0].abs().mean())
    print("-------------------------------------------------------------------------")

tensor(0.8000, device='cuda:0')
tensor(7.3737) tensor(0.9544)
0 tensor(6.7985, device='cuda:0') tensor(0.8591, device='cuda:0')
0 tensor(6.7128, device='cuda:0') tensor(0.8431, device='cuda:0')
-----------
-------------------------------------------------------------------------
tensor(7.5892) tensor(1.0019)
32 tensor(7.2514, device='cuda:0') tensor(0.8324, device='cuda:0')
32 tensor(7.3916, device='cuda:0') tensor(0.9059, device='cuda:0')
-----------
-------------------------------------------------------------------------
tensor(7.4044) tensor(0.9736)
64 tensor(6.4009, device='cuda:0') tensor(0.7625, device='cuda:0')
64 tensor(8.5443, device='cuda:0') tensor(0.9947, device='cuda:0')
-----------
-------------------------------------------------------------------------
tensor(7.8228) tensor(0.9756)
125 tensor(7.3288, device='cuda:0') tensor(0.8225, device='cuda:0')
125 tensor(8.1914, device='cuda:0') tensor(0.8897, device='cuda:0')
-----------
------------------------------------------

In [5]:
from tqdm import tqdm
import math

# test_x = test_inps[0]
# test_y = test_outs[0]

H_A = test_H.clone().to("cuda")
H_B = test_H_B.clone().to("cuda")

M = (new_W == 0).to(torch.float32).to(H_B.device)
print(M.device)
# M = torch.zeros_like(M).to(torch.float32).cuda()
print((M == 0).sum() / M.numel())
W_old = new_W.clone().to(H_B.device)

def dual_ascent2(W_old, H_A, H_B, beta=0.01, alpha=0.99, gama=0.0000, rho=1, epsilon=1e-2, max_iter=1000, lambda_zero=False, percdamp=.01, min_iter=300, theld=0.07):
    

    old_score = 0
    new_score = 0
    # for v in range(len(self.valid_inps)):
    #     old_score += (W_old @ self.valid_inps[v] - self.valid_outs[v]).abs().mean()
    # old_score = old_score / len(self.valid_inps)
    W_old = W_old.to(torch.float32)
    M = (W_old == 0).to(torch.float32)


    W = W_old.clone()
    if lambda_zero:
        Lambda = torch.zeros_like(W)
    else:
        term1 = beta * (torch.mm(W, H_A) - H_B)
        term2 = alpha * (W - W_old)
        Lambda = -M * (term1 + term2)

    for k in tqdm(range(max_iter)):
        # 保存上一次的 W
        W_prev = W.clone()

        # 更新 W
        A = (beta + gama) * H_A + alpha * torch.eye(H_A.shape[0], device=H_A.device)
        try:
            damp = percdamp * torch.mean(torch.diag(A))
            diag = torch.arange(A.shape[-1], device=A.device)
            A[diag, diag] += damp
            A = torch.linalg.cholesky(A)
            A_inv = torch.cholesky_inverse(A)
            # A_inv = torch.linalg.cholesky(A, upper=True)
        except RuntimeError as e:
            print(f"Cholesky decomposition failed: {e}. Falling back to direct inverse.")
            raise e
        
        B = beta * H_B + alpha * W_old
        W = torch.mm(B - (M * Lambda), A_inv)

        # 更新 Lambda
        Lambda = Lambda + rho * (M * W)

        # 收敛判断
        if k % 100 == 0 :
            # W = W - M * W
            
            if k > min_iter:
                if torch.norm(W - W_prev) < epsilon:
                    print(f"Converged at iteration {k}")
                    print(torch.norm(W - W_prev))
                    break
        # if k % 10 == 0 :
        #     print(torch.norm(W - W_prev))

    torch.cuda.synchronize()
    W = W - M * W
    print(torch.norm(W - W_prev))

    # for v in range(len(self.valid_inps)):
    #     new_score = (W.to(torch.float32) @ self.valid_inps[v] - self.valid_outs[v]).abs().mean()
    # new_score = new_score / len(self.valid_inps)

    # print("old_score:", old_score, "new_score:", new_score)
    # if new_score < (old_score * (1 - theld)):
    #     print("Converged!")
    #     self.layer.weight.data = W.to(torch.float32)
    # else:
    #     print("Not converged!")
    #     self.layer.weight.data = W_old.to(torch.float32)

    # print("Dual ascent finished!")
    del W_old, H_A, H_B, A, B, Lambda
    torch.cuda.empty_cache()
    return W.to(torch.float32)

W_update = dual_ascent2(W_old, H_A, H_B, beta=0.99, alpha=0.01, gama=0.0000, rho=1, lambda_zero=True, max_iter=10000, percdamp=.01, min_iter=300, theld=0.07)

cuda:0
tensor(0.5000, device='cuda:0')


  8%|▊         | 800/10000 [00:03<00:40, 224.55it/s]

Converged at iteration 800
tensor(0.0096, device='cuda:0')
tensor(5.2047, device='cuda:0')





In [6]:
print(test_inps.shape)
print(test_outs.shape)

NameError: name 'test_inps' is not defined

In [16]:
print((W_update == 0).sum() / W_update.numel())
for i in [0, 32, 64, 125, 126]:
    W_update = W_update - M.to(torch.float32) * W_update
    print(test_outs[i].abs().max(), test_outs[i].abs().mean())
    print((W_update @ test_inps[i].cuda()).abs().max(), (W_update @ test_inps[i].cuda()).abs().mean())
    print(i, (W_update @ test_inps[i].cuda() - test_outs[i].cuda()).abs().max(), (W_update @ test_inps[i].cuda() - test_outs[i].cuda()).abs().mean())
    print(i, (new_W.to(H_B.device) @ test_inps[i].cuda() - test_outs[i].cuda()).abs().max(), (new_W.to(H_B.device) @ test_inps[i].cuda() - test_outs[i].cuda()).abs().mean())
    print("-----------")
    # print((new_W @ test_inps[i]).abs().max(), (new_W @ test_inps[i]).abs().mean())
    # print((W_update2 @ test_inps[i]).abs().max(), (W_update2 @ test_inps[i]).abs().mean())
    # print(test_outs[i][0].abs().max(), test_outs[i][0].abs().mean())
    print("-------------------------------------------------------------------------")

tensor(0.5000, device='cuda:0')
tensor(7.3101) tensor(0.9603)
tensor(6.9375, device='cuda:0') tensor(0.8099, device='cuda:0')
0 tensor(2.7277, device='cuda:0') tensor(0.1954, device='cuda:0')
0 tensor(3.2381, device='cuda:0') tensor(0.8544, device='cuda:0')
-----------
-------------------------------------------------------------------------
tensor(7.0926) tensor(0.9997)
tensor(6.9747, device='cuda:0') tensor(0.8353, device='cuda:0')
32 tensor(2.6844, device='cuda:0') tensor(0.2219, device='cuda:0')
32 tensor(3.2661, device='cuda:0') tensor(0.8551, device='cuda:0')
-----------
-------------------------------------------------------------------------
tensor(6.8198) tensor(0.9867)
tensor(6.5462, device='cuda:0') tensor(0.8258, device='cuda:0')
64 tensor(2.7837, device='cuda:0') tensor(0.2123, device='cuda:0')
64 tensor(3.2789, device='cuda:0') tensor(0.8557, device='cuda:0')
-----------
-------------------------------------------------------------------------
tensor(7.1617) tensor(0.9814