In [1]:
import torch

In [None]:
from sr_emu import Generator

In [None]:
G = Generator(6, 6, 1)

In [None]:
from prettytable import PrettyTable

def count_parameters(model):
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad: continue
        params = parameter.numel()
        table.add_row([name, params])
        total_params+=params
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params


In [None]:
count_parameters(G)

+------------------------------------------+------------+
|                 Modules                  | Parameters |
+------------------------------------------+------------+
|              block0.weight               |    3072    |
|               block0.bias                |    512     |
|       block0.style_block.0.weight        |     6      |
|        block0.style_block.0.bias         |     6      |
|               addnoise.std               |    512     |
|           hblock0.conv1.weight           |  3538944   |
|            hblock0.conv1.bias            |    256     |
|    hblock0.conv1.style_block.0.weight    |    512     |
|     hblock0.conv1.style_block.0.bias     |    512     |
|          hblock0.addnoise1.std           |    256     |
|           hblock0.conv2.weight           |  1769472   |
|            hblock0.conv2.bias            |    256     |
|    hblock0.conv2.style_block.0.weight    |    256     |
|     hblock0.conv2.style_block.0.bias     |    256     |
|          hbl

6989426

In [None]:
state = torch.load('/hildafs/home/xzhangn/state_710.pt')

In [6]:
G.cuda()

Generator(
  (block0): ConvStyled3d(
    (style_block): Sequential(
      (0): Linear(in_features=1, out_features=6, bias=True)
    )
  )
  (act): LeakyReLUStyled(negative_slope=0.2, inplace=True)
  (addnoise): AddNoise()
  (hblock0): HBlock(
    (act): LeakyReLUStyled(negative_slope=0.2, inplace=True)
    (upsample): Resampler()
    (conv1): ConvStyled3d(
      (style_block): Sequential(
        (0): Linear(in_features=1, out_features=512, bias=True)
      )
    )
    (addnoise1): AddNoise()
    (conv2): ConvStyled3d(
      (style_block): Sequential(
        (0): Linear(in_features=1, out_features=256, bias=True)
      )
    )
    (addnoise2): AddNoise()
    (proj): ConvStyled3d(
      (style_block): Sequential(
        (0): Linear(in_features=1, out_features=256, bias=True)
      )
    )
    (noise_proj1): ConvStyled3d(
      (style_block): Sequential(
        (0): Linear(in_features=1, out_features=6, bias=True)
      )
    )
    (noise_proj2): ConvStyled3d(
      (style_block): Seq

In [7]:
input_size = 24
padding = 3
N = input_size + 2 * padding
noise_padding = 4
noise0_N = input_size*2 + 2 * noise_padding
noise1_N = input_size*4 + 2 * noise_padding
noise2_N = input_size*8 + 2 * noise_padding

In [8]:
input = torch.randn(1, 6, N, N, N).cuda()
style = torch.randn(1, 1).cuda()
early_noise = torch.randn(1, 6, N, N, N).cuda()
noise0 = torch.randn(1, 6, noise0_N, noise0_N, noise0_N).cuda()
noise1 = torch.randn(1, 6, noise1_N, noise1_N, noise1_N).cuda()
noise2 = torch.randn(1, 6, noise2_N, noise2_N, noise2_N).cuda()

In [9]:
out = G(input, style, early_noise, noise0, noise1, noise2)

In [10]:
print(torch.cuda.memory_summary())

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |   22309 MB |   23874 MB |   43255 MB |   20945 MB |
|       from large pool |   22304 MB |   23869 MB |   43246 MB |   20942 MB |
|       from small pool |       5 MB |       5 MB |       8 MB |       3 MB |
|---------------------------------------------------------------------------|
| Active memory         |   22309 MB |   23874 MB |   43255 MB |   20945 MB |
|       from large pool |   22304 MB |   23869 MB |   43246 MB |   20942 MB |
|       from small pool |       5 MB |       5 MB |       8 MB |       3 MB |
|---------------------------------------------------------------

In [11]:
print(torch.cuda.max_memory_allocated())

25033991680


In [12]:
!nvidia-smi

Thu Jun  8 14:56:24 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.57.02    Driver Version: 470.57.02    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100-SXM...  On   | 00000000:01:00.0 Off |                    0 |
| N/A   30C    P0    67W / 400W |  27870MiB / 40536MiB |     52%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [13]:
loss = out.sum()
loss.backward()

In [14]:
print(torch.cuda.max_memory_allocated())

27884080128
