In [10]:
import torch
import torchvision.models as models
from torch.profiler import profile, record_function, ProfilerActivity

from siren_pytorch import SirenNet

from copy import deepcopy
import tracemalloc
import gc

In [2]:
net = SirenNet(
    dim_in = 2,                        # input dimension, ex. 2d coor
    dim_hidden = 1024,                  # hidden dimension
    dim_out = 3,                       # output dimension, ex. rgb value
    num_layers = 5,                    # number of layers
    final_activation = None,   # activation of final layer (nn.Identity() for direct output)
    w0_initial = 30.                   # different signals may require different omega_0 in the first layer - this is a hyperparameter
)

coor = torch.randn(64, 2, requires_grad=True)


In [3]:
def trace_gpu(input_net, input_coor, device: str = "cuda:0"):

    net = deepcopy(input_net).to(device)
    coor = deepcopy(input_coor).to(device)

    out = net(coor)
    for i in range(3):
        g = torch.autograd.grad(out[..., i] / 256, coor, torch.ones_like(out[..., i]), create_graph=True)


    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated(device=device)
    torch.cuda.reset_max_memory_cached(device=device)
    torch.cuda.reset_peak_memory_stats(device=device)
    torch.cuda.reset_accumulated_memory_stats(device=device)
    


    net(coor)
    for i in range(3):
        torch.autograd.grad(out[..., i] / 256, coor, torch.ones_like(out[..., i]), create_graph=True)

    peak_mem_gpu = torch.cuda.max_memory_allocated(device=device)
    # print(f"Peak memory usage on GPU: {peak_mem_gpu / (1024 ** 3):.03f} GB")
    # MB
    print(f"Peak memory usage on GPU: {peak_mem_gpu / (1024 ** 2):.03f} MB")


In [11]:
def trace_cpu(input_net, input_coor):
    gc.collect()
    tracemalloc.start()

    net = input_net.to("cpu")
    coor = input_coor.to("cpu")
    
    out = net(coor)
    for i in range(3):
        g = torch.autograd.grad(out[..., i] / 256, coor, torch.ones_like(out[..., i]), create_graph=True)

    current, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    print(f"Peak memory usage on CPU: {peak / (1024 ** 3):.03f} GB")

In [13]:
torch.cuda.init()
trace_gpu(net, coor, device="cuda:0")
trace_cpu(net, coor)

# net = net.to('cpu')
# coor = coor.to('cpu')

Peak memory usage on GPU: 42.544 MB
Peak memory usage on CPU: 0.000 GB


In [6]:
torch.cuda.reset_peak_memory_stats(device='cuda:0')
torch.cuda.empty_cache()

out = net(coor)
g = torch.autograd.grad(out[0].sum(), net.parameters(), create_graph=True)

peak_mem_gpu = torch.cuda.max_memory_allocated(device='cuda:0')
print(f"Peak memory usage on GPU: {peak_mem_gpu / (1024 ** 3):.03f} GB")

Peak memory usage on GPU: 0.016 GB


In [7]:
# with profile(
#     activities=[
#         torch.profiler.ProfilerActivity.CPU,
#         torch.profiler.ProfilerActivity.CUDA,
#     ],
#     profile_memory=True,
# ) as prof:
#     out = net(coor)
#     g = torch.autograd.grad(out[0].sum(), net.parameters(), create_graph=True)


# print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))
# # peak CPU memory usage

# total_mem_cpu = 0
# total_mem_gpu = 0

# for event in prof.events():
#     total_mem_cpu += event.cpu_memory_usage
#     total_mem_gpu += event.cuda_memory_usage

# print(f"Total CPU memory usage: {total_mem_cpu / 8 / 1024 / 1024} MB")
# print(f"Total GPU memory usage: {total_mem_gpu / 8 / 1024 / 1024} MB")