In [4]:
import torch
from pathlib import Path

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", DEVICE)

CKPT_PATH = (
    Path.home() / "BioVoice" / "wespeaker-voxceleb-resnet293-LM" / "avg_model.pt"
)
print("Checkpoint path:", CKPT_PATH)

assert CKPT_PATH.exists(), "Checkpoint not found"

ckpt = torch.load(CKPT_PATH, map_location="cpu")
print("Checkpoint loaded, type:", type(ckpt))

if isinstance(ckpt, dict):
    print("Checkpoint keys:", ckpt.keys())

Using device: cuda
Checkpoint path: /home/SpeakerRec/BioVoice/wespeaker-voxceleb-resnet293-LM/avg_model.pt


  ckpt = torch.load(CKPT_PATH, map_location="cpu")


Checkpoint loaded, type: <class 'collections.OrderedDict'>
Checkpoint keys: odict_keys(['conv1.weight', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var', 'bn1.num_batches_tracked', 'layer1.0.conv1.weight', 'layer1.0.bn1.weight', 'layer1.0.bn1.bias', 'layer1.0.bn1.running_mean', 'layer1.0.bn1.running_var', 'layer1.0.bn1.num_batches_tracked', 'layer1.0.conv2.weight', 'layer1.0.bn2.weight', 'layer1.0.bn2.bias', 'layer1.0.bn2.running_mean', 'layer1.0.bn2.running_var', 'layer1.0.bn2.num_batches_tracked', 'layer1.0.conv3.weight', 'layer1.0.bn3.weight', 'layer1.0.bn3.bias', 'layer1.0.bn3.running_mean', 'layer1.0.bn3.running_var', 'layer1.0.bn3.num_batches_tracked', 'layer1.0.shortcut.0.weight', 'layer1.0.shortcut.1.weight', 'layer1.0.shortcut.1.bias', 'layer1.0.shortcut.1.running_mean', 'layer1.0.shortcut.1.running_var', 'layer1.0.shortcut.1.num_batches_tracked', 'layer1.1.conv1.weight', 'layer1.1.bn1.weight', 'layer1.1.bn1.bias', 'layer1.1.bn1.running_mean', 'layer1.1.bn1.runn

In [6]:
from wespeaker.models.resnet import ResNet293

model = ResNet293(
    feat_dim=80,
    embed_dim=256,
)

# Ignore projection layer
missing, unexpected = model.load_state_dict(ckpt, strict=False)
print("Missing keys:", missing)
print("Unexpected keys:", unexpected)

model = model.to(DEVICE)

model.eval()
for p in model.parameters():
    p.requires_grad_(True)

print("Model loaded (projection ignored)")

Missing keys: []
Unexpected keys: ['projection.weight']
Model loaded (projection ignored)


In [9]:
x = torch.randn(1, 200, 80, device=DEVICE)

out = model(x)
embedding = out[-1]

print("Embedding shape:", embedding.shape)

loss = embedding.norm()
loss.backward()

for name, p in model.named_parameters():
    if "layer4.0.conv1.weight" in name:
        print("Grad exists:", p.grad is not None)
        break

Embedding shape: torch.Size([1, 256])
Grad exists: True
