In [19]:
import torch
import os
from prettytable import PrettyTable
# Use the CUDA device when PyTorch reports one as available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def count_parameters(model):
    """Print a per-tensor table of trainable parameter counts and their total.

    Only parameters with ``requires_grad`` set are listed. Rows are ordered from
    the largest tensor to the smallest; tensors of equal size keep the order in
    which ``model.named_parameters()`` yields them.

    Args:
        model: an object exposing ``named_parameters()`` (e.g. a ``torch.nn.Module``).

    Returns:
        None. The table and the grand total are printed to stdout.
    """
    trainable = [
        (param_name, tensor.numel())
        for param_name, tensor in model.named_parameters()
        if tensor.requires_grad
    ]
    # list.sort is stable even with reverse=True, so equal sizes keep their original order.
    trainable.sort(key=lambda entry: entry[1], reverse=True)

    summary = PrettyTable(["Modules", "Parameters"])
    for param_name, count in trainable:
        summary.add_row([param_name, count])
    print(summary)

    grand_total = sum(count for _, count in trainable)
    print(f"Total Trainable Params: {grand_total}")


In [20]:
from models.gpt import GPT, GPTConfig
# Rebuild the GPT from a checkpoint on disk: hyper-parameters come from
# checkpoint['model_args'] and the weights from checkpoint['model'].
model_name = 'mini-gpt'
ckpt_path = os.path.join('out', model_name + '.pt')
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints you produced or trust. If this checkpoint holds only
# tensors and plain containers, passing weights_only=True would be safer.
checkpoint = torch.load(ckpt_path, map_location=device)
gptconf = GPTConfig(**checkpoint['model_args'])
model = GPT(gptconf)
# NOTE(review): `device` is only used for map_location above; nothing in this cell
# moves `model` there. Call model.to(device) before GPU inference if GPT does not
# already do so internally (confirm).

state_dict = checkpoint['model']
# Weights saved from a torch.compile()-wrapped module are presumably keyed with an
# '_orig_mod.' prefix (the compiled wrapper stores the original module as
# `_orig_mod`); strip it so the names match the plain, uncompiled GPT.
# Iterate over a snapshot of the keys because the dict is modified in the loop.
unwanted_prefix = '_orig_mod.'
for k in list(state_dict):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
# Last expression: Jupyter displays the key-match report returned by load_state_dict.
model.load_state_dict(state_dict)

<All keys matched successfully>

In [21]:
# Show trainable tensors of the loaded model, largest first, followed by the total.
count_parameters(model=model)

+-----------------------------------+------------+
|              Modules              | Parameters |
+-----------------------------------+------------+
|          tok_embd.weight          |  19316736  |
|             ff.weight             |  19316736  |
|      blocks.0.ff.lin1.weight      |   589824   |
|      blocks.0.ff.lin2.weight      |   589824   |
|      blocks.1.ff.lin1.weight      |   589824   |
|      blocks.1.ff.lin2.weight      |   589824   |
|      blocks.2.ff.lin1.weight      |   589824   |
|      blocks.2.ff.lin2.weight      |   589824   |
|      blocks.3.ff.lin1.weight      |   589824   |
|      blocks.3.ff.lin2.weight      |   589824   |
|      blocks.4.ff.lin1.weight      |   589824   |
|      blocks.4.ff.lin2.weight      |   589824   |
|      blocks.5.ff.lin1.weight      |   589824   |
|      blocks.5.ff.lin2.weight      |   589824   |
|      blocks.0.csa.proj.weight     |   147456   |
|      blocks.1.csa.proj.weight     |   147456   |
|      blocks.2.csa.proj.weight