In [1]:
import torch

In [2]:
# Fetch the DINOv2 ViT-B/14 (with register tokens) entry point through torch.hub,
# then move the resulting module onto the GPU.
dinov2_vitb14_reg = torch.hub.load(
    repo_or_dir='facebookresearch/dinov2',
    model='dinov2_vitb14_reg',
)
dinov2_vitb14_reg = dinov2_vitb14_reg.cuda()

Using cache found in /home/singh/.cache/torch/hub/facebookresearch_dinov2_main


In [3]:
# Put the module in evaluation mode before inspecting it.
dinov2_vitb14_reg.eval()

# Print every registered parameter as "<qualified name> <shape>".
for name, param in dinov2_vitb14_reg.named_parameters():
    print(name, param.size())
    

cls_token torch.Size([1, 1, 768])
pos_embed torch.Size([1, 1370, 768])
register_tokens torch.Size([1, 4, 768])
mask_token torch.Size([1, 768])
patch_embed.proj.weight torch.Size([768, 3, 14, 14])
patch_embed.proj.bias torch.Size([768])
blocks.0.norm1.weight torch.Size([768])
blocks.0.norm1.bias torch.Size([768])
blocks.0.attn.qkv.weight torch.Size([2304, 768])
blocks.0.attn.qkv.bias torch.Size([2304])
blocks.0.attn.proj.weight torch.Size([768, 768])
blocks.0.attn.proj.bias torch.Size([768])
blocks.0.ls1.gamma torch.Size([768])
blocks.0.norm2.weight torch.Size([768])
blocks.0.norm2.bias torch.Size([768])
blocks.0.mlp.fc1.weight torch.Size([3072, 768])
blocks.0.mlp.fc1.bias torch.Size([3072])
blocks.0.mlp.fc2.weight torch.Size([768, 3072])
blocks.0.mlp.fc2.bias torch.Size([768])
blocks.0.ls2.gamma torch.Size([768])
blocks.1.norm1.weight torch.Size([768])
blocks.1.norm1.bias torch.Size([768])
blocks.1.attn.qkv.weight torch.Size([2304, 768])
blocks.1.attn.qkv.bias torch.Size([2304])
blocks

In [5]:
from PIL import Image
from torchvision import transforms
import os
import copy
from tqdm import tqdm

dataset_folder = './dataset/imagenet/images'
# os.listdir() returns entries in arbitrary, filesystem-dependent order. Sort
# before truncating so the same (at most) 5,000 files are selected on every
# run and on every machine.
image_files = sorted(os.listdir(dataset_folder))[:5_000]

# Image -> tensor pipeline applied to each input image.
preprocess = transforms.Compose([
    # With an int size, Resize scales the shorter edge to 224 px and keeps the
    # aspect ratio, so non-square inputs stay non-square.
    transforms.Resize(224),
    # Convert the image to a float tensor.
    transforms.ToTensor(),
    # Per-channel standardisation (the statistics commonly used with
    # ImageNet-pretrained models).
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])

100%|██████████| 5000/5000 [00:32<00:00, 154.82it/s]


In [None]:
from torchvision import datasets, transforms

# NOTE(review): this rebinds `dataset_folder`, which an earlier cell set to the
# ImageNet image directory; confirm that overwriting it is intended.
dataset_folder = './path/to/folder'  # Replace with the path to your image folder

# Create the ImageFolder dataset.
# Uses the `preprocess` pipeline defined earlier in the notebook; the name
# `transform` is not defined until a later cell, so referencing it here raised
# NameError when the notebook was run top to bottom.
dataset = datasets.ImageFolder(dataset_folder, transform=preprocess)

# Create the DataLoader
batch_size = 32  # Set the batch size
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [13]:
# Keep only the first 128 columns of the features held in `outputs`
# (`outputs` comes from a cell not shown here; presumably a 2-D torch tensor).
outputs = outputs[:, :128]
# Convert to a host-side NumPy array only if this is still a tensor. The cell
# overwrites its own input, so without this guard a second run would fail with
# AttributeError because NumPy arrays have no `.cpu()`.
if isinstance(outputs, torch.Tensor):
    outputs = outputs.cpu().numpy()
outputs.shape

(5000, 128)

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models

In [5]:
# FGSM attack (same as before)
def fgsm_attack(model, x, y, epsilon):
    """Craft FGSM adversarial examples for one batch.

    Takes a single step of size ``epsilon`` along the sign of the gradient of
    the cross-entropy loss with respect to the input, then clamps the result
    to the range [0, 1].

    Unlike a plain ``loss.backward()`` implementation, this function has no
    side effects on its arguments: ``x`` is neither flagged as requiring grad
    nor given a ``.grad`` (which would otherwise accumulate across repeated
    calls on the same tensor and corrupt later perturbations), and no
    gradients are written to the parameters of ``model``.

    Args:
        model: Callable mapping the input batch to logits accepted by
            ``F.cross_entropy``.
        x: Floating-point input tensor. It is not modified.
        y: Targets passed to ``F.cross_entropy``.
        epsilon: Step size of the perturbation.

    Returns:
        A detached tensor with the same shape as ``x`` holding the clamped
        adversarial examples.
    """
    # Work on a private leaf copy so the caller's tensor keeps its autograd state.
    x_leaf = x.clone().detach().requires_grad_(True)
    loss = F.cross_entropy(model(x_leaf), y)

    # autograd.grad returns only the input gradient; it does not populate
    # `.grad` on the model parameters.
    (input_grad,) = torch.autograd.grad(loss, x_leaf)

    x_adv = x_leaf.detach() + epsilon * input_grad.sign()
    # Clamping to [0, 1] assumes un-normalised image inputs — TODO confirm for
    # pipelines that apply mean/std normalisation.
    x_adv = torch.clamp(x_adv, 0, 1)

    return x_adv.detach()

In [14]:

# Fast Adversarial Finetuning
def fast_adversarial_finetuning(model, train_loader, optimizer, epsilon, device):
    """Run one pass of FGSM adversarial finetuning over ``train_loader``.

    For every batch, adversarial examples are crafted with ``fgsm_attack`` and
    the model is updated on the average of the clean and adversarial
    cross-entropy losses. The loss is printed every 100 batches.

    Args:
        model: Network to finetune; it is switched to training mode.
        train_loader: Iterable yielding ``(data, target)`` batches.
        optimizer: Optimizer over the parameters to update.
        epsilon: Perturbation size forwarded to ``fgsm_attack``.
        device: Device the batches are moved to before use.

    Returns:
        None.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        # Generate adversarial examples
        data_adv = fgsm_attack(model, data, target, epsilon)

        # Make sure the clean batch carries no autograd state left by the
        # attack; otherwise the backward pass below would also compute an
        # unused gradient with respect to the clean input.
        data = data.detach()

        # Finetune on clean and adversarial data. zero_grad() comes after the
        # attack so any parameter gradients produced while crafting the
        # adversarial batch are discarded before the real update.
        optimizer.zero_grad()
        output_clean = model(data)
        output_adv = model(data_adv)
        loss_clean = F.cross_entropy(output_clean, target)
        loss_adv = F.cross_entropy(output_adv, target)
        loss = 0.5 * (loss_clean + loss_adv)
        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            print(f'Batch {batch_idx}, Loss: {loss.item()}')


# for param in list(model.parameters())[:-2]:  # Finetune only the last two layers
#     param.requires_grad = False

# Optimise only parameters that are still trainable.
# NOTE(review): `model` and `device` are not defined in the visible cells;
# they must already exist in the kernel when this cell runs.
optimizer = optim.Adam(
    (p for p in model.parameters() if p.requires_grad),
    lr=0.001,
)

# Upscale the CIFAR-10 images to 224 px and convert them to tensors
# (original note: ResNet18 expects 224x224 images).
transform = transforms.Compose([transforms.Resize(224), transforms.ToTensor()])

train_loader = DataLoader(
    datasets.CIFAR10('../data', train=True, download=True, transform=transform),
    batch_size=64,
    shuffle=True,
)

# Maximum perturbation (might need adjustment for your dataset)
epsilon = 0.03
# Typically fewer epochs for finetuning
num_epochs = 5

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    fast_adversarial_finetuning(model, train_loader, optimizer, epsilon, device)

print("Finetuning complete")