teramototoya/cca.pytorch

CCA.pytorch

PyTorch implementation of

Requirements

  • Python>=3.6
  • PyTorch>=0.4.1
  • torchvision>=0.2.1
  • homura (to run example.py)
  • matplotlib (to run example.py)

Usage

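import torch
from torchvision.models import resnet18
from cca import CCAHook

model = resnet18()  # example model; the layer names below refer to its submodules
hook1 = CCAHook(model, "layer3.0.conv1")
hook2 = CCAHook(model, "layer3.0.conv2")
model.eval()
with torch.no_grad():
    model(torch.randn(1200, 3, 224, 224))  # forward pass records both layers' activations
hook1.distance(hook2, size=8)  # resize to 8x8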
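The distance above is a CCA distance between the activations recorded by the two hooks. As a rough illustration of the idea (not this library's code), the sketch below computes a mean-CCA distance, 1 minus the average canonical correlation, for two activation matrices that share the same samples in the same order. The helper name cca_distance is hypothetical. It assumes each matrix has full column rank and more samples than neurons, and it needs a newer PyTorch (1.9 or later, for torch.linalg.qr and torch.linalg.svdvals) than the minimum listed under Requirements. SVCCA additionally truncates each matrix with an SVD before the CCA step; that step is omitted here.

```python
import torch


def cca_distance(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: 1 - mean canonical correlation between two layers.

    x: (n_samples, p1) activations of one layer
    y: (n_samples, p2) activations of another layer, same samples, same order
    Assumes n_samples > p1, p2 and full column rank. Returns a value in [0, 1].
    """
    # Center every neuron (column) over the samples.
    x = x - x.mean(dim=0, keepdim=True)
    y = y - y.mean(dim=0, keepdim=True)
    # Orthonormal bases for the two column spaces (reduced QR).
    qx, _ = torch.linalg.qr(x)
    qy, _ = torch.linalg.qr(y)
    # The singular values of qx^T qy are the canonical correlations.
    rho = torch.linalg.svdvals(qx.T @ qy).clamp(0.0, 1.0)
    return 1.0 - rho.mean()
```

Because CCA is invariant to invertible linear transforms, comparing a layer with a linearly mixed copy of itself gives a distance near 0, while two unrelated random matrices give a distance close to 1 when there are many more samples than neurons.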

Example

Running python example.py trains ResNet-20 on CIFAR-10 for 100 epochs, then measures the CCA distance between the trained model and its checkpoints.

Note

While the original SVCCA resizes feature maps with the discrete Fourier transform (DFT), we use global average pooling for simplicity.
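Two hooked convolutional layers can have different spatial resolutions, so their activations have to be brought to a common size before CCA can pair up samples. The sketch below shows one way to do this with pooling. It assumes (the README does not confirm this) that channels are treated as neurons and that each pooled spatial position of each image is one sample. The name conv_to_cca_matrix is hypothetical, and F.adaptive_avg_pool2d stands in for whatever pooling the library actually applies.

```python
import torch
import torch.nn.functional as F


def conv_to_cca_matrix(act: torch.Tensor, size: int) -> torch.Tensor:
    """Hypothetical helper: turn conv activations (N, C, H, W) into a 2-D CCA input.

    Assumption: channels are neurons, and every (image, row, column) position
    after pooling is one sample. Pooling replaces SVCCA's DFT-based resizing.
    """
    # Average-pool each feature map down to size x size.
    pooled = F.adaptive_avg_pool2d(act, output_size=size)  # (N, C, size, size)
    n, c = pooled.shape[:2]
    # (N, C, s, s) -> (N, s, s, C) -> (N * s * s, C): one row per sample.
    return pooled.permute(0, 2, 3, 1).reshape(n * size * size, c)
```

For example, activations of shape (4, 16, 32, 32) and (4, 32, 16, 16) both become matrices with 4 * 8 * 8 = 256 rows at size=8, so their rows line up sample by sample even though the layers have different widths.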

About

CCAs for looking into DNNs
