In [1]:
"""
Testing generalization performance of pre-trained ResNet-26 on VisDA real test
with simple linear probe
"""

# Do linear probe with both synthetic training data and real training data

import torch.nn as nn
from torchvision.models import resnet18, ResNet18_Weights
import torch.optim as optim
import wandb

from main import get_model
from visda17 import get_visda_dataloaders
from targeted_synth_training import train_epoch, validate_epoch


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

num_classes = 12

config = {
    'model' : 'resnet26',
    'ft_strategy' : 'full',
    'img_size' : 'small',
    'init' : 'scratch'
}

if config['model'] == 'resnet26':
    net = get_model('resnet26', num_class=num_classes, dataset='imagenet12' if not config['init'] == 'scratch' else None)
elif config['model'] == 'resnet18':
    net = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
    net.fc = nn.Linear(512, num_classes)
else:
    raise ValueError("Bad model name specified.")

print(f"Network trainable parameter count: {count_parameters(net)/1.0e6:.1f} M")
net.to(device='cuda:0')

# Freeze params
if config['fit_strategy'] == 'linear_probe':
    needs_training = lambda name : name in ['linear.weight', 'linear.bias', 'fc.weight', 'fc.bias']
elif config['fit_strategy'] == 'last-1':
    needs_training = lambda name : '2.3' in name
elif config['fit_strategy'] == 'last-2':
    needs_training = lambda name : '2.3' in name or '2.2' in name
elif config['fit_strategy'] == 'last-3':
    needs_training = lambda name : '2.3' in name or '2.2' in name or '2.1'
elif config['fit_strategy'] == 'bn_only':
    needs_training = lambda name : 'bn' in name

frozen = []
unfrozen = []
for name, m in net.named_parameters():
    if needs_training(name) or config['ft_strategy'] == 'full':
        m.requires_grad = True
        unfrozen.append(name)
    else:
        m.requires_grad = False
        frozen.append(name)

print("frozen params: ", frozen)
print("trainable params: ", unfrozen)

train_loader, val_loader, test_loader = get_visda_dataloaders(
    train_dir='/export/r32/data/visda17/train',
    val_dir='/export/r32/data/visda17/validation',
    test_dir='/export/r32/data/visda17/test'
)

lr = 0.01
optimizer = optim.SGD(filter(lambda p: p.requires_grad, net.parameters()),
                        lr=lr,
                        momentum=0.9,
                        weight_decay=1e-4)


Network trainable parameter count: 11.6 M


KeyError: 'fit_strategy'

In [None]:

wandb.init(
        project='targeted-generalization',
        config=config
)

print("training...")
train_epoch(train_loader, net, optimizer)

print("validating...")
validate_epoch(test_loader, net)

In [2]:

for name, m in net.named_parameters():
    print(name)

conv1.weight
bn1.weight
bn1.bias
blocks.0.0.conv1.weight
blocks.0.0.bn1.weight
blocks.0.0.bn1.bias
blocks.0.0.conv2.1.weight
blocks.0.0.bn2.weight
blocks.0.0.bn2.bias
blocks.0.1.conv1.weight
blocks.0.1.bn1.weight
blocks.0.1.bn1.bias
blocks.0.1.conv2.1.weight
blocks.0.1.bn2.weight
blocks.0.1.bn2.bias
blocks.0.2.conv1.weight
blocks.0.2.bn1.weight
blocks.0.2.bn1.bias
blocks.0.2.conv2.1.weight
blocks.0.2.bn2.weight
blocks.0.2.bn2.bias
blocks.0.3.conv1.weight
blocks.0.3.bn1.weight
blocks.0.3.bn1.bias
blocks.0.3.conv2.1.weight
blocks.0.3.bn2.weight
blocks.0.3.bn2.bias
blocks.1.0.conv1.weight
blocks.1.0.bn1.weight
blocks.1.0.bn1.bias
blocks.1.0.conv2.1.weight
blocks.1.0.bn2.weight
blocks.1.0.bn2.bias
blocks.1.1.conv1.weight
blocks.1.1.bn1.weight
blocks.1.1.bn1.bias
blocks.1.1.conv2.1.weight
blocks.1.1.bn2.weight
blocks.1.1.bn2.bias
blocks.1.2.conv1.weight
blocks.1.2.bn1.weight
blocks.1.2.bn1.bias
blocks.1.2.conv2.1.weight
blocks.1.2.bn2.weight
blocks.1.2.bn2.bias
blocks.1.3.conv1.weight
block

In [None]:
"""
Now using sklearn linear regression to rule out any linear layer training issues with PyTorch
"""

