In [1]:
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from sr_emu import Generator

In [3]:
# Instantiate the model under inspection.
# NOTE(review): the meaning of the positional args (6, 6, 1) is defined in
# sr_emu.Generator; the torchinfo summary below feeds 6-channel 3-D volumes and a
# (1, 1) style tensor, so these are presumably (in_chan, out_chan, style_size) —
# confirm against sr_emu.
G = Generator(6, 6, 1)

In [4]:
from prettytable import PrettyTable

def count_parameters(model):
    """Print a per-tensor breakdown of trainable weights and return their total.

    Walks ``model.named_parameters()`` and keeps only tensors whose
    ``requires_grad`` flag is set, so frozen weights are excluded from both the
    printed table and the returned total.

    Args:
        model: object exposing ``named_parameters()`` yielding ``(name, tensor)``
            pairs (e.g. a ``torch.nn.Module``).

    Returns:
        int: total number of elements across all trainable parameter tensors.
    """
    # (qualified name, element count) for every trainable tensor, in module order.
    trainable = [
        (name, tensor.numel())
        for name, tensor in model.named_parameters()
        if tensor.requires_grad
    ]

    breakdown = PrettyTable(["Modules", "Parameters"])
    for name, count in trainable:
        breakdown.add_row([name, count])

    grand_total = sum(count for _, count in trainable)
    print(breakdown)
    print(f"Total Trainable Params: {grand_total}")
    return grand_total


In [5]:
# count_parameters already prints the total; binding the return value stops the
# notebook from echoing it a second time as this cell's Out value, and keeps the
# number reusable later without relying on `_` / `Out[N]`.
n_trainable_params = count_parameters(G)

+------------------------------------------+------------+
|                 Modules                  | Parameters |
+------------------------------------------+------------+
|              block0.weight               |    3072    |
|               block0.bias                |    512     |
|       block0.style_block.0.weight        |     6      |
|        block0.style_block.0.bias         |     6      |
|               addnoise.std               |    512     |
|           hblock0.conv1.weight           |  3538944   |
|            hblock0.conv1.bias            |    256     |
|    hblock0.conv1.style_block.0.weight    |    512     |
|     hblock0.conv1.style_block.0.bias     |    512     |
|          hblock0.addnoise1.std           |    256     |
|           hblock0.conv2.weight           |  1769472   |
|            hblock0.conv2.bias            |    256     |
|    hblock0.conv2.style_block.0.weight    |    256     |
|     hblock0.conv2.style_block.0.bias     |    256     |
|          hbl

6989426

In [6]:
from torchinfo import summary

In [7]:
# Dummy input shapes for torchinfo, one tuple per tensor input passed to G.
# NOTE(review): the roles below are inferred from the shapes and from the layer
# output sizes in the summary (per HBlock: Resampler 22 -> 42, conv -> 40, then 38);
# confirm against sr_emu.Generator.forward.
summary(
    G,
    [
        (1, 6, 22, 22, 22),     # presumably the low-resolution input volume
        (1, 1),                 # presumably the style input
        (1, 6, 22, 22, 22),     # same 22^3 size as the stem (block0) output
        (1, 6, 40, 40, 40),     # matches hblock0's first-conv output size (40^3)
        (1, 6, 38, 38, 38),     # matches hblock0's output size (38^3)
        (1, 6, 72, 72, 72),     # presumably next HBlock, first conv (2*38-2-2)
        (1, 6, 70, 70, 70),     # presumably next HBlock, output
        (1, 6, 136, 136, 136),  # presumably last HBlock, first conv (2*70-2-2)
        (1, 6, 134, 134, 134),  # presumably last HBlock, output
    ],
)

Layer (type:depth-idx)                   Output Shape              Param #
Generator                                --                        --
├─ConvStyled3d: 1-1                      [1, 512, 22, 22, 22]      --
│    └─Sequential: 2-1                   [1, 6]                    --
│    │    └─Linear: 3-1                  [1, 6]                    12
├─LeakyReLUStyled: 1-2                   [1, 512, 22, 22, 22]      --
├─ConvStyled3d: 1-3                      [1, 512, 22, 22, 22]      --
│    └─Sequential: 2-2                   [1, 6]                    --
│    │    └─Linear: 3-2                  [1, 6]                    12
├─AddNoise: 1-4                          [1, 512, 22, 22, 22]      512
├─HBlock: 1-5                            [1, 256, 38, 38, 38]      --
│    └─Resampler: 2-3                    [1, 512, 42, 42, 42]      --
│    └─ConvStyled3d: 2-4                 [1, 256, 40, 40, 40]      --
│    │    └─Sequential: 3-3              [1, 512]                  1,024
│    └─Leak