<a href="https://colab.research.google.com/github/ymubarka/finetuning_methods/blob/main/Finetuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook is inspired by methods described in:
1. [Discriminative Fine-Tuning](https://paperswithcode.com/method/discriminative-fine-tuning)
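2. [Surgical Fine-Tuning](https://arxiv.org/pdf/2210.11466.pdf), with two differences in this notebook:
  1. The layers to fine-tune are specified individually (by parameter name) rather than by block.
  2. The Auto-RGN implementation here may differ from the one in the paper.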

In [None]:
# ---- Run configuration ----

# Fine-tuning strategy. Must be exactly one of
# "reg", "surgical", "discriminitive", "auto" (the spelling is matched literally).
finetuning_method = "reg"

# Optimisation settings
lr = 3e-3
momentum = 0.9
epochs = 20
T_max = epochs  # the scheduler horizon follows the epoch count
number_warmup_epochs = 2

# Data loading settings
batch_size = 256
num_workers = 8

# Only consulted when finetuning_method == "surgical".
# Two parameter names, taken from
# [param_name for param_name, _ in model.named_parameters()],
# marking the first and last layer (both inclusive) of the range to finetune.
layers_to_finetune = ['conv1.weight', "fc.bias"]  # This finetunes the entire network

In [None]:
import numpy as np
import torch
import torchvision.models as models
from torchvision.datasets import CIFAR10
from torchvision import transforms
from tqdm.notebook import tqdm

import os


def load_moco(checkpoint_path):
    """Build a torchvision model initialised from a MoCo checkpoint's query encoder.

    The checkpoint is loaded onto the CPU. Its ``'arch'`` entry selects the
    constructor looked up in ``models`` and its ``'state_dict'`` entry supplies
    the weights. Only ``module.encoder_q`` entries outside the ``fc`` head are
    kept, with the ``module.encoder_q.`` prefix stripped; everything else is
    dropped. The load must therefore report exactly ``fc.weight`` and
    ``fc.bias`` as missing, which is asserted before the model is returned.
    """
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    model = models.__dict__[ckpt['arch']]()

    strip_len = len("module.encoder_q.")
    weights = ckpt['state_dict']
    # Snapshot the keys first: entries are removed and re-added while looping.
    for key in tuple(weights.keys()):
        from_encoder = key.startswith('module.encoder_q')
        from_head = key.startswith('module.encoder_q.fc')
        tensor = weights.pop(key)  # every original key goes away, renamed or not
        if from_encoder and not from_head:
            weights[key[strip_len:]] = tensor

    load_report = model.load_state_dict(weights, strict=False)
    assert set(load_report.missing_keys) == {"fc.weight", "fc.bias"}
    return model

def train(model, data_loader, optimizer, criterion, epoch, use_cuda, total_epochs=None):
    """Run one training epoch and return the mean per-sample loss.

    Args:
        model: network to optimise; it is switched to train mode.
        data_loader: iterable yielding ``(data, labels)`` batches.
        optimizer: optimizer that is zeroed and stepped once per batch.
        criterion: loss callable. It is assumed to return the *mean* loss of
            the batch (as ``CrossEntropyLoss(reduction='mean')`` does).
        epoch: zero-based epoch index, used only in the progress-bar label.
        use_cuda: when true, every batch is moved to the GPU first.
        total_epochs: denominator shown in the progress-bar label. When
            omitted, the notebook-level ``epochs`` value is used if it exists,
            otherwise 20.

    Returns:
        Sum of per-sample losses divided by the number of samples seen, or
        ``0.0`` when the loader yields no batches.
    """
    if total_epochs is None:
        # Follow the configured epoch count instead of a hard-coded 20.
        total_epochs = globals().get("epochs", 20)

    model.train()
    train_bar = tqdm(data_loader)
    total_loss, correct, total = 0.0, 0, 0
    for data, labels in train_bar:
        if use_cuda:
            data, labels = data.cuda(non_blocking=True), labels.cuda(non_blocking=True)

        optimizer.zero_grad()
        logits = model(data)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        n_samples = labels.size(0)
        predicted = logits.detach().argmax(dim=1)
        correct += (predicted == labels).sum().item()
        total += n_samples
        # The criterion yields a batch mean; weight it by the batch size so that
        # dividing by the sample count gives a true per-sample average.
        total_loss += loss.item() * n_samples

        acc = correct / total * 100
        train_bar.set_description(f"Train Epoch: [{epoch+1}/{total_epochs}], Avg. Loss: {total_loss / total:.4f}, Avg. Acc: {acc:.2f}")

    return total_loss / total if total else 0.0

def test(model, data_loader, epoch, use_cuda):
    model.eval()
    total_acc, total_num, test_bar = 0.0, 0, tqdm(data_loader)
    with torch.no_grad():
      for data, labels in test_bar:
          if use_cuda:
              data = data.cuda(non_blocking=True)

          logits = model(data)
          _, test_preds = torch.max(logits.data, axis=1)

          total_num += data.size(0)
          total_acc += (test_preds.cpu() == labels).float().sum().item()
          acc = total_acc / total_num * 100

          test_bar.set_description(f"Test Epoch: [{epoch+1}/20], Avg. Acc: {acc:.2f}")

def autoRGN_step(model, optimizer):
    """Rescale each param group's learning rate by a relative gradient norm.

    For every parameter of ``model`` with ``requires_grad`` set, the relative
    gradient norm (RGN) ``||grad||_2 / ||param||_2`` is computed from the
    gradients currently stored on the parameters. The RGNs are min-max rescaled
    into ``[0.55, 1]`` and the i-th factor multiplies the *current* ``lr`` of
    the i-th entry of ``optimizer.param_groups``, so repeated calls compound.

    The optimizer must therefore hold exactly one param group per trainable
    parameter, in ``model.parameters()`` order.

    Edge cases:
        * A trainable parameter without a gradient counts as RGN 0.
        * A zero-norm parameter is guarded by a tiny epsilon in the denominator.
        * When all RGNs are equal the learning rates are left unchanged
          (instead of becoming NaN from a 0/0 normalisation).

    Raises:
        ValueError: if the number of param groups differs from the number of
            trainable parameters.
    """
    new_min, new_max = 0.55, 1.0
    eps = 1e-12

    rgns = []
    for param in model.parameters():
        if not param.requires_grad:
            continue
        if param.grad is None:
            rgns.append(0.0)
            continue
        grad_norm = torch.norm(param.grad.detach().reshape(-1), 2)
        weight_norm = torch.norm(param.detach().reshape(-1), 2).clamp_min(eps)
        # Plain floats keep the stored learning rates as Python numbers.
        rgns.append(float(grad_norm / weight_norm))

    groups = optimizer.param_groups
    if len(groups) != len(rgns):
        raise ValueError(
            f"autoRGN_step needs one param group per trainable parameter: "
            f"got {len(groups)} group(s) for {len(rgns)} parameter(s)"
        )
    if not rgns:
        return

    rgn_min, rgn_max = min(rgns), max(rgns)
    span = rgn_max - rgn_min
    for group, rgn in zip(groups, rgns):
        if span > 0:
            scale = (rgn - rgn_min) / span * (new_max - new_min) + new_min
        else:
            scale = new_max
        group['lr'] = group['lr'] * scale
            
def warmup(current_step: int):
    """Return the learning-rate multiplier for warm-up step ``current_step``.

    Evaluates ``1 / 10 ** (number_warmup_epochs - current_step)``, reading
    ``number_warmup_epochs`` from the notebook scope. The multiplier grows
    tenfold per step and equals 1 once ``current_step`` reaches
    ``number_warmup_epochs``.
    """
    steps_remaining = float(number_warmup_epochs - current_step)
    divisor = 10 ** steps_remaining
    return 1 / divisor


use_cuda = torch.cuda.is_available()

# Data Prepping
# Images are resized to 224 and normalised with the ImageNet channel mean/std
# (std is 0.229 / 0.224 / 0.225; the first value was previously mistyped as 0.228).
default_transformations = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

# Single source of truth for where CIFAR-10 is stored / downloaded.
cifar10_root = 'drive/MyDrive/finetuning_project/data/'
cifar10_train = CIFAR10(root=cifar10_root, train=True, transform=default_transformations, download=True)
cifar10_test = CIFAR10(root=cifar10_root, train=False, transform=default_transformations, download=True)

# Pinned host memory only helps host-to-GPU copies, so enable it only with CUDA.
train_loader = torch.utils.data.DataLoader(cifar10_train,
                                          batch_size=batch_size,
                                          num_workers=num_workers,
                                          shuffle=True,
                                          pin_memory=use_cuda)
test_loader = torch.utils.data.DataLoader(cifar10_test,
                                          batch_size=batch_size,
                                          num_workers=num_workers,
                                          shuffle=False,
                                          pin_memory=use_cuda)


# I got the pretrained model from the original moco repo from facebookresearch
# https://github.com/facebookresearch/moco
model = load_moco("drive/MyDrive/finetuning_project/models/moco_v2_800ep_pretrain.pth.tar")
# Swap in a fresh 10-way classification head. Its input width is read from the
# existing head so it stays correct whatever architecture the checkpoint names
# (instead of assuming 2048 features).
model.fc = torch.nn.Linear(model.fc.in_features, 10)
if use_cuda:
    model = model.cuda()

# Default optimisation target: all model parameters. A later cell may replace
# this with per-parameter groups depending on finetuning_method.
parameters = model.parameters()

# Checkpoint directory, named after the fine-tuning method and, when warm-up
# is enabled, the number of warm-up epochs.
results_dir = f"drive/MyDrive/finetuning_project/trained_models/FT_{finetuning_method}"
if number_warmup_epochs > 0:
    results_dir += f"_warm-{number_warmup_epochs}"
# makedirs also creates missing parent folders and does not fail if the
# directory already exists.
os.makedirs(results_dir, exist_ok=True)

Files already downloaded and verified
Files already downloaded and verified


In [None]:
# Method-specific setup of what the optimizer will train.
# For any other method value nothing happens here and `parameters` keeps the
# value assigned by an earlier cell.
if finetuning_method == "surgical":
    # Train only the inclusive range of parameters delimited by the two names
    # in layers_to_finetune; every other parameter is frozen.
    all_layers = [layer_name for layer_name, _ in model.named_parameters()]
    first_layer = all_layers.index(layers_to_finetune[0])
    last_layer = all_layers.index(layers_to_finetune[1])

    for position, param in enumerate(model.parameters()):
        param.requires_grad = first_layer <= position <= last_layer
elif finetuning_method == "discriminitive":
    # One param group per parameter, walking from the last parameter to the
    # first; each group takes the current lr, which is then multiplied by 0.9
    # for the next (earlier) parameter.
    # NOTE: this rebinds the notebook-level `lr`, so later cells see the fully
    # decayed value and re-running this cell compounds the decay.
    parameters = []
    for param in reversed(list(model.parameters())):
        parameters.append({'params': param, 'lr': lr})
        lr = lr * 0.9
elif finetuning_method == "auto":
    # One param group per parameter, all starting at the same lr.
    parameters = [{'params': param, 'lr': lr} for _, param in model.named_parameters()]

In [None]:
criterion = torch.nn.CrossEntropyLoss(reduction='mean')
optimizer = torch.optim.SGD(parameters, lr=lr, momentum=momentum)

# Schedule: warm-up multipliers for the first number_warmup_epochs epochs,
# then cosine annealing.
warmup_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup)
# eta_min is the floor of the cosine curve. With eta_min equal to the base lr
# the curve is flat and nothing is annealed, so decay towards 0 instead.
train_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max, eta_min=0.0)
scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, [warmup_scheduler, train_scheduler], [number_warmup_epochs])

for epoch in range(epochs):
    train(model, train_loader, optimizer, criterion, epoch, use_cuda)
    # Overwrite the "last" checkpoint after every epoch.
    torch.save({'epoch': epoch, 
                'state_dict': model.state_dict(), 
                'optimizer' : optimizer.state_dict()}, results_dir + '/model_last.pth')
    test(model, test_loader, epoch, use_cuda)

    # Building the schedulers above already scaled every learning rate by
    # warmup(0). In "auto" mode keep stepping the scheduler through the warm-up
    # epochs so the rates climb back to their configured values; only then let
    # Auto-RGN take over. Otherwise "auto" would stay at the warm-up rate.
    if finetuning_method == "auto" and epoch >= number_warmup_epochs:
        autoRGN_step(model, optimizer)
    else:
        scheduler.step()

  0%|          | 0/196 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/196 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]



  0%|          | 0/196 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/196 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/196 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/196 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/196 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/196 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/196 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/196 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/196 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/196 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/196 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/196 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/196 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/196 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/196 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/196 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/196 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/196 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

In [None]:
# Once the cells above have finished, ask Colab to unassign this runtime.
# Only works inside Google Colab (the google.colab package is Colab-specific).
from google.colab import runtime
runtime.unassign()