In [1]:
import torch
import time
import tabulate
import exact.ops as ex_ops

  from .autonotebook import tqdm as notebook_tqdm


OSError: libmetis.so: cannot open shared object file: No such file or directory

In [83]:

@torch.no_grad()
def seed_gen_rad_mat(rm_size, feat_size, device, dtype, seed):
    torch.cuda.manual_seed(seed)
    bern = torch.randint(2, size=rm_size, device=device, requires_grad=False, dtype=dtype)
    return (2.0 * bern - 1) / feat_size **0.5
@torch.no_grad()
def low_mem_input2rp(input, kept_acts):
    assert len(input.size()) == 2
    rand_mat_size = (input.shape[1], kept_acts)
    # Create random matrix
    seed = int(time.time()*1000)
    rand_matrix = seed_gen_rad_mat(rand_mat_size, kept_acts, input.device, input.dtype, seed)
    dim_reduced_input = torch.matmul(input, rand_matrix)
    print('='*20, 'input2rp', '='*20)
    print(rand_matrix)
    return dim_reduced_input, rand_mat_size, seed


@torch.no_grad()
def low_mem_rp2input(dim_reduced_input, input_shape, seed, rm_size):
    assert len(dim_reduced_input.size()) == 2
    rand_matrix = seed_gen_rad_mat(rm_size, rm_size[1], dim_reduced_input.device, dim_reduced_input.dtype, seed)
    input = torch.matmul(dim_reduced_input, rand_matrix.t())    
    print('='*20, 'rp2input', '='*20)
    print(rand_matrix)
    return input.view(input_shape)

In [84]:
input = torch.rand((4,4), dtype=torch.float32, device='cuda')
kept_acts = 2
dim_reduced_input, rand_mat_size, seed = low_mem_input2rp(input, kept_acts)
input_recovered = low_mem_rp2input(dim_reduced_input, input.shape, seed, rand_mat_size)


tensor([[ 0.7071, -0.7071],
        [-0.7071, -0.7071],
        [-0.7071,  0.7071],
        [ 0.7071,  0.7071]], device='cuda:0')
tensor([[ 0.7071, -0.7071],
        [-0.7071, -0.7071],
        [-0.7071,  0.7071],
        [ 0.7071,  0.7071]], device='cuda:0')


In [93]:
def test_rp_speed():
    runtime = 1_0000
    M, N = 1024, 128 
    input = torch.rand(M, N, dtype=torch.float32).cuda()
    kept_acts = 0.5
    kept_acts = int(input.shape[1]*kept_acts+0.99)
    low_mem_input2rp_time = 0.0
    low_mem_rp2input_time = 0.0
    input2rp_time = 0.0
    rp2input_time = 0.0
    for _ in range(runtime//100):
        dim_reduced_input, rm_size, seed = ex_ops.low_mem_input2rp(input, kept_acts)
        rec_input = ex_ops.low_mem_rp2input(dim_reduced_input, input.shape, seed, rm_size)
        dim_reduced_input, rd_mat = ex_ops.input2rp(input, kept_acts)
        rec_input = ex_ops.rp2input(dim_reduced_input, input.shape, rd_mat)
    torch.cuda.synchronize()
    for _ in range(runtime):
        
        start = time.time()
        dim_reduced_input, rm_size, seed = ex_ops.low_mem_input2rp(input, kept_acts)
        # print(rd_mat)
        end = time.time()
        # torch.cuda.synchronize()
        low_mem_input2rp_time += end - start
        start = time.time()
        rec_input = ex_ops.low_mem_rp2input(dim_reduced_input, input.shape, seed, rm_size)
        end = time.time()
        # torch.cuda.synchronize()
        low_mem_rp2input_time += end - start


        start = time.time()
        dim_reduced_input, rd_mat = ex_ops.input2rp(input, kept_acts)
        end = time.time()
        # torch.cuda.synchronize()
        input2rp_time += end - start
        start = time.time()
        rec_input = ex_ops.rp2input(dim_reduced_input, input.shape, rd_mat)
        end = time.time()
        # torch.cuda.synchronize()
        rp2input_time += end - start

        
    torch.cuda.synchronize()

        
    low_mem_input2rp_time = low_mem_input2rp_time/runtime*1e6
    low_mem_rp2input_time = low_mem_rp2input_time/runtime*1e6
    input2rp_time = input2rp_time/runtime*1e6
    rp2input_time = rp2input_time/runtime*1e6
    print(tabulate.tabulate([
        ["exact input2rp avg (us)"] + [input2rp_time],
        ["exact rp2input avg (us)"] + [rp2input_time],
        ["low mem input2rp avg (us)"] + [low_mem_input2rp_time],
        ["low mem rp2input avg (us)"] + [low_mem_rp2input_time],
    ]))
test_rp_speed()

-------------------------  --------
exact input2rp avg (us)    19.0703
exact rp2input avg (us)     7.72867
low mem input2rp avg (us)  21.665
low mem rp2input avg (us)  22.2179
-------------------------  --------


In [1]:
from exact.ops import low_mem_input2rp, low_mem_rp2input
input = torch.rand(4,4)
kept_acts = 2
dim_reduced_input, rand_mat_size, seed = low_mem_input2rp(input, kept_acts)
input_recovered = low_mem_rp2input(dim_reduced_input, input.shape, seed, rand_mat_size)


OSError: libmetis.so: cannot open shared object file: No such file or directory