PyTorch implementation of:
- SVCCA: Singular Vector Canonical Correlation Analysis for Deep Learning Dynamics and Interpretability
- Insights on representational similarity in neural networks with canonical correlation

Requirements:
- Python>=3.6
- PyTorch>=0.4.1
- torchvision>=0.2.1
- homura (to run `example.py`)
- matplotlib (to run `example.py`)
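import torch
from cca import CCAHook

# `model` is assumed to be a torch.nn.Module with submodules of these names
# (for example, a torchvision ResNet).
hook1 = CCAHook(model, "layer3.0.conv1")
hook2 = CCAHook(model, "layer3.0.conv2")

model.eval()
with torch.no_grad():
    # A single forward pass lets both hooks record their layers' activations.
    model(torch.randn(1200, 3, 224, 224))
hook1.distance(hook2, size=8)  # resize to 8x8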
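
For readers unfamiliar with forward hooks, the following is a minimal sketch of how the output of a named submodule can be captured in plain PyTorch. `ActivationRecorder` is a hypothetical helper written for illustration; it is not the implementation of `CCAHook`, which may work differently.

```python
import torch
from torch import nn


class ActivationRecorder:
    """Hypothetical helper: keeps the latest output of one named submodule."""

    def __init__(self, model: nn.Module, name: str):
        modules = dict(model.named_modules())
        if name not in modules:
            raise KeyError(f"no submodule named {name!r}")
        self.output = None
        # register_forward_hook calls self._record after every forward pass
        self._handle = modules[name].register_forward_hook(self._record)

    def _record(self, module, inputs, output):
        self.output = output.detach()

    def remove(self):
        self._handle.remove()


# Toy model for illustration; in nn.Sequential, submodules are named "0", "1", ...
toy_model = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(4, 8, kernel_size=3, padding=1),
)
recorder = ActivationRecorder(toy_model, "2")

toy_model.eval()
with torch.no_grad():
    toy_model(torch.randn(5, 3, 16, 16))

recorded_shape = tuple(recorder.output.shape)
recorder.remove()
```

Looking the layer up by its dotted name through `named_modules()` is what makes strings such as "layer3.0.conv1" usable as identifiers.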
Running `python example.py` trains ResNet-20 on CIFAR-10 for 100 epochs, then measures the CCA distance between a trained model and its checkpoints.
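
To make "CCA distance" concrete, here is a minimal sketch that computes canonical correlations between two activation matrices via QR decomposition and reports one minus their mean. This definition of the distance, the function name `cca_distance`, and the use of `torch.linalg` (which requires a more recent PyTorch than the minimum listed above) are assumptions for illustration; the library's own computation may differ, for example by applying an SVD step first as in SVCCA.

```python
import torch


def cca_distance(x: torch.Tensor, y: torch.Tensor) -> float:
    """Illustrative sketch: 1 - mean canonical correlation.

    x, y: (num_samples, num_features) matrices with num_samples > num_features.
    """
    # Center each feature
    x = x - x.mean(dim=0, keepdim=True)
    y = y - y.mean(dim=0, keepdim=True)
    # Orthonormal bases of the two column spaces
    qx, _ = torch.linalg.qr(x)
    qy, _ = torch.linalg.qr(y)
    # Singular values of qx^T qy are the canonical correlations
    rho = torch.linalg.svdvals(qx.T @ qy).clamp(0.0, 1.0)
    return 1.0 - rho.mean().item()


torch.manual_seed(0)
a = torch.randn(200, 5, dtype=torch.float64)
# An invertible linear map of `a` spans the same subspace
mix = torch.randn(5, 5, dtype=torch.float64) + 5 * torch.eye(5, dtype=torch.float64)
b = a @ mix
c = torch.randn(200, 5, dtype=torch.float64)  # unrelated features

dist_same = cca_distance(a, b)
dist_diff = cca_distance(a, c)
```

Because CCA is invariant to invertible linear transformations, `dist_same` should be numerically zero, while unrelated features give a noticeably larger value.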
While the original SVCCA uses DFT for resizing, we use global average pooling for simplicity.
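
As a rough illustration of pooling-based resizing, the sketch below shrinks convolutional activations with `torch.nn.functional.adaptive_avg_pool2d` and flattens them into a (samples, channels) matrix suitable for CCA. How the library maps its `size` argument onto pooling, and the choice to treat each spatial position as a separate sample, are assumptions here and not confirmed details of the implementation.

```python
import torch
import torch.nn.functional as F

acts = torch.randn(10, 16, 32, 32)  # (batch, channels, height, width)

# Average-pool every feature map down to 8x8
resized = F.adaptive_avg_pool2d(acts, 8)

# Global average pooling is the special case of output size 1
pooled = F.adaptive_avg_pool2d(acts, 1).flatten(1)

# Treat each spatial position as one sample: (batch * 8 * 8, channels)
flat = resized.permute(0, 2, 3, 1).reshape(-1, resized.size(1))

resized_shape = tuple(resized.shape)
pooled_shape = tuple(pooled.shape)
flat_shape = tuple(flat.shape)
```

Pooling keeps the channel dimension intact and only reduces spatial resolution, so layers with different feature-map sizes can be compared on the same grid.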