In [1]:
import sys
sys.path.append('..')

from VarPro import *
from resnet import *
import torchvision
import torchvision.transforms.v2 as transforms
from torchinfo import summary



In [6]:
# Build the model for 3-channel inputs and 10 output classes.
resnet = SimpleResNet18(
    in_channels=3,
    num_classes=10,
)

In [7]:
# Per-layer output shapes and parameter counts for a batch of 128 inputs of size 3x32x32.
summary(model=resnet, input_size=(128, 3, 32, 32))

Layer (type:depth-idx)                        Output Shape              Param #
VarProModel                                   [128, 10]                 --
├─SimpleResNetFeatureModel: 1-1               [128, 64]                 --
│    └─Conv2d: 2-1                            [128, 16, 32, 32]         432
│    └─BatchNorm2d: 2-2                       [128, 16, 32, 32]         32
│    └─Sequential: 2-3                        [128, 16, 32, 32]         --
│    │    └─BasicBlock: 3-1                   [128, 16, 32, 32]         4,672
│    │    └─BasicBlock: 3-2                   [128, 16, 32, 32]         4,672
│    └─Sequential: 2-4                        [128, 32, 16, 16]         --
│    │    └─BasicBlock: 3-3                   [128, 32, 16, 16]         14,528
│    │    └─BasicBlock: 3-4                   [128, 32, 16, 16]         18,560
│    └─Sequential: 2-5                        [128, 64, 8, 8]           --
│    │    └─BasicBlock: 3-5                   [128, 64, 8, 8]           57,728
│

In [5]:
# Training-time preprocessing: random 32x32 crop after 4-pixel padding, random
# horizontal flip, conversion to a float32 image tensor (scale=True rescales
# the integer pixel values), then per-channel normalization.
transform_train = transforms.Compose(
    [
        transforms.RandomCrop(size=32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToImage(),
        transforms.ToDtype(dtype=torch.float32, scale=True),
        transforms.Normalize(
            mean=(0.4914, 0.4822, 0.4465),
            std=(0.2023, 0.1994, 0.2010),
        ),
    ]
)

# CIFAR-10 training split, downloaded into ./cifar10_data if not already present.
trainset = torchvision.datasets.CIFAR10(
    root='./cifar10_data',
    train=True,
    transform=transform_train,
    download=True,
)

# Shuffled mini-batches of 128 samples, loaded by two worker processes.
train_loader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)

In [6]:
# Pull one batch from the training loader to use as a probe input in later cells.
# The loader shuffles, so which samples end up in this batch varies between runs.
inputs, targets = next(iter(train_loader))

In [12]:
# Display the module tree to inspect each layer's configuration.
resnet

VarProModel(
  (feature_model): SimpleResNetFeatureModel(
    (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (shortcut): Sequential(
          (0): Conv2d(64, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): BasicBlock(
        (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias

In [18]:
# Step through the first layers by hand to check that the channel counts line up
# between the stem (conv1 + bn1 + ReLU) and the first conv of layer1's first block.
out = resnet.feature_model.conv1(inputs)
out = F.relu(resnet.feature_model.bn1(out))
try:
    out = resnet.feature_model.layer1[0].conv1(out)
except RuntimeError as err:
    # A failure here means layer1[0].conv1 expects a different number of input
    # channels than the stem produces. Report it instead of raising, so that
    # Restart & Run All is not aborted by this diagnostic cell.
    print(f"{type(err).__name__}: {err}")

RuntimeError: Given groups=1, weight of size [16, 64, 3, 3], expected input[128, 16, 32, 32] to have 64 channels, but got 16 channels instead

In [19]:
# Inspect the first conv of layer1's first block to check its in/out channel counts.
resnet.feature_model.layer1[0].conv1

Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)

In [7]:
# Forward one batch through the feature model without recording an autograd
# graph: the features are only used for the diagnostic below, so there is no
# need to build a graph and then clone/detach the result.
with torch.no_grad():
    features = resnet.feature_model(inputs)

# Scaled Gram matrix of the batch features.
# NOTE(review): 128 matches the loader's batch_size; 512 is presumably meant to be
# the feature width, but the torchinfo summary in this notebook reports a 64-dim
# feature output -- confirm which normalization is intended.
K = features @ features.T / (128 * 512)

# Spectral norm (largest singular value) of K.
torch.linalg.norm(K, ord=2)

tensor(5.2625)

In [5]:
# --- Training configuration ---
# First argument of VarProCriterion (presumably the regularization weight --
# confirm against the VarPro module).
lmbda = 1e-1

# Step size: base time scale 2^-10 multiplied by 512.
time_scale = 2 ** -10
lr = 512 * time_scale

# Loss (biased or unbiased variant).
criterion = VarProCriterion(lmbda, num_classes=10)

# SGD over the feature-model parameters only.
optimizer = torch.optim.SGD(params=resnet.feature_model.parameters(), lr=lr)
# Disabled: extra param group for the outer weights. `student` and `student_width`
# are not defined in the cells visible here.
#optimizer.add_param_group({'lr': 5*student_width*lmbda, 'params': student.outer.weight})

In [8]:
# Bundle the model, data loader, optimizer and loss into one LearningProblem object.
problem = LearningProblem(
    resnet,
    train_loader,
    optimizer,
    criterion,
)