* https://github.com/little51/dinov3-train

In [1]:
# !pip install timm==1.0.20 -i https://pypi.mirrors.ustc.edu.cn/simple
# !pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu124

In [5]:
import torch.nn as nn  ## torch neural network
import timm  ## timm is a library for vision model 
from torchvision import datasets  # lib to load torchvision model
import torch  # torch

REPO_DIR = "/home/yang/MyRepos/dinov3"

In [6]:
import torch

REPO_DIR = "/home/yang/MyRepos/dinov3"
BACKBONE_DIR = "/home/yang/MyRepos/meta_dino_models/dinov3_vitl16_pretrain_lvd1689m-8aa4cbdd.pth"
# BACKBONE_DIR = "/home/yang/MyRepos/meta_dino_models/dinov3_vit7b16_pretrain_lvd1689m-a955f4ea.pth"
DEPTHER_DIR = "/home/yang/MyRepos/meta_dino_models/dinov3_vit7b16_imagenet1k_linear_head-90d8ed92.pth"

dinov3_vitl16 = torch.hub.load(REPO_DIR, 'dinov3_vitl16', source='local', weights=BACKBONE_DIR)
# dinov3_vit7b16_lc = torch.hub.load(REPO_DIR, 'dinov3_vith16plus', source="local", weights=DEPTHER_DIR, backbone_weights=BACKBONE_DIR)

# Download BackBone model
* Note that it uses num_classes=0 to disable classification head of the backbone
* It freezes all the parameters to prevent back-propagate

In [7]:
import torch.nn as nn
import timm
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.optim as optim

device = 'cuda' if torch.cuda.is_available() else 'cpu'

def load_model():
    # list all timm supported models
    model_list = timm.list_models()
    print(model_list)
    
    # timm.create_model is used to load pretrained vision model, pretrained to determine if load weights
    model = timm.create_model('timm/vit_small_patch16_dinov3.lvd1689m', pretrained=True, num_classes=0)
    print(model)

    input_temp = torch.randn((1, 3, 256, 256))
    print(model(input_temp).shape) # get the shape of the output
    
    # timm.create_model is used to
    model.eval()  # set the model in evaluation mode
    for param in model.parameters():
        param.requires_grad = False
    print("model loaded")
    print(sum(param.numel() for param in model.parameters()))
    return model

model = load_model()

['aimv2_1b_patch14_224', 'aimv2_1b_patch14_336', 'aimv2_1b_patch14_448', 'aimv2_3b_patch14_224', 'aimv2_3b_patch14_336', 'aimv2_3b_patch14_448', 'aimv2_huge_patch14_224', 'aimv2_huge_patch14_336', 'aimv2_huge_patch14_448', 'aimv2_large_patch14_224', 'aimv2_large_patch14_336', 'aimv2_large_patch14_448', 'bat_resnext26ts', 'beit3_base_patch16_224', 'beit3_giant_patch14_224', 'beit3_giant_patch14_336', 'beit3_large_patch16_224', 'beit_base_patch16_224', 'beit_base_patch16_384', 'beit_large_patch16_224', 'beit_large_patch16_384', 'beit_large_patch16_512', 'beitv2_base_patch16_224', 'beitv2_large_patch16_224', 'botnet26t_256', 'botnet50ts_256', 'caformer_b36', 'caformer_m36', 'caformer_s18', 'caformer_s36', 'cait_m36_384', 'cait_m48_448', 'cait_s24_224', 'cait_s24_384', 'cait_s36_384', 'cait_xs24_384', 'cait_xxs24_224', 'cait_xxs24_384', 'cait_xxs36_224', 'cait_xxs36_384', 'coat_lite_medium', 'coat_lite_medium_384', 'coat_lite_mini', 'coat_lite_small', 'coat_lite_tiny', 'coat_mini', 'coat_s

# Define new target model
* Define a simple 10-classes linear classification head

In [8]:
def def_model(model):
    # Extract dinov3 features
    feature_dim = model.num_features
    print(f"feature dimensio is {feature_dim}")
    # Set 10 classes
    num_classes = 10
    # Create a simple MLP
    classifier = nn.Sequential(
        nn.Linear(feature_dim, 256),
        # nn.ReLU(inplace=True), # inplace replace the input with output
        nn.ReLU(), # inplace replace the input with output
        nn.Dropout(p=0.1),
        nn.Linear(256, num_classes)
    )
    # Connect the backbone with the classifier head
    class CustomDINOv3(nn.Module):
        def __init__(self, backbone, head):
            super().__init__()
            self.backbone = backbone
            self.head = head

        def forward(self, x):
            features = self.backbone(x)
            output = self.head(features)
            return output

    custom_model = CustomDINOv3(model, classifier)
    print("Defined classification model")
    return custom_model, num_classes, feature_dim

custom_model, num_classes, feature_dim = def_model(model)
print(num_classes)
print(feature_dim)
print(custom_model)

feature dimensio is 384
Defined classification model
10
384
CustomDINOv3(
  (backbone): Eva(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (rope): RotaryEmbeddingDinoV3()
    (norm_pre): Identity()
    (blocks): ModuleList(
      (0-11): 12 x EvaBlock(
        (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (attn): EvaAttention(
          (qkv): Linear(in_features=384, out_features=1152, bias=False)
          (q_norm): Identity()
          (k_norm): Identity()
          (attn_drop): Dropout(p=0.0, inplace=False)
          (norm): Identity()
          (proj): Linear(in_features=384, out_features=384, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path1): Identity()
        (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=38

# Load Data
* CIFAR-10 dataset: https://www.cs.toronto.edu/~kriz/cifar.html
* This data set has 60000 32*32 images in 10 classes, 6000 per class
* Use datasets & dataloader lib: https://docs.pytorch.org/tutorials/beginner/basics/data_tutorial.html

In [12]:
import matplotlib.pyplot as plt
from torchvision import transforms
from torchvision.transforms import ToPILImage
from PIL import Image, ImageDraw, ImageFont
    
def load_dataset(model):
    # obtain the data config corresponding to the model
    # timm get the data preprocess config for vision model, which can be used to get transforms
    data_config = timm.data.resolve_model_data_config(model) 
    print(data_config)
    transforms = timm.data.create_transform(
        **data_config,
        is_training=True
    )
    # Download CIFAR-10 dataset
    # CIFAR-10 dataset: 60000 32x32 colour images in 10 classes, with 6000 images per class
    train_dataset = datasets.CIFAR10(
        root = './data',
        train = True,
        download = True,
        transform = transforms
    )
    test_dataset = datasets.CIFAR10(
        root = './data',
        train = False,
        download = True,
        transform = transforms
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size = 32,
        shuffle = True,
        num_workers = 12
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size = 32,
        shuffle = True,
        num_workers = 12
    )
    print("dataset loaded")
    return train_loader, test_loader, data_config

train_loader, test_loader, dataconfig = load_dataset(model)
count = 0
for images, labels in train_loader:
    to_pil = ToPILImage()
    image_tensor = images[0, :, :, :]
    print(image_tensor.shape)
    image = to_pil(image_tensor)
    print(image_tensor)
    image.show()
    print(labels[0])
    count+=1
    if count > 1: break

totensor_transform = transforms.ToTensor()
test_dataset = datasets.CIFAR10(
    root = './data',
    train = False,
    download = True,
)
image_pil = test_dataset[1][0]
image_pil.show()
image_tensor = totensor_transform(image_pil)
print(image_tensor.shape)

{'input_size': (3, 256, 256), 'interpolation': 'bicubic', 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 'crop_pct': 1.0, 'crop_mode': 'center'}
Files already downloaded and verified
Files already downloaded and verified
dataset loaded
torch.Size([3, 256, 256])
tensor([[[-1.5014, -1.5014, -1.5014,  ..., -1.6042, -1.6042, -1.6042],
         [-1.5014, -1.5014, -1.5014,  ..., -1.6042, -1.6042, -1.6042],
         [-1.5014, -1.5014, -1.5014,  ..., -1.6042, -1.6042, -1.6042],
         ...,
         [-0.7993, -0.7993, -0.7993,  ..., -1.2959, -1.3130, -1.3130],
         [-0.7993, -0.7993, -0.7993,  ..., -1.2959, -1.3130, -1.3130],
         [-0.7993, -0.7993, -0.7993,  ..., -1.2959, -1.3130, -1.3130]],

        [[-1.5105, -1.5105, -1.5105,  ..., -1.4405, -1.4405, -1.4405],
         [-1.5105, -1.5105, -1.5105,  ..., -1.4405, -1.4405, -1.4405],
         [-1.5105, -1.5105, -1.5105,  ..., -1.4405, -1.4405, -1.4405],
         ...,
         [-0.7052, -0.7052, -0.7052,  ..., -1.2129, -1.

# Train
* Train 10 epoch, get loss close or below 0.3

In [18]:
def train_model(custom_model, train_loader):
    custom_model.to(device)
    criterion = nn.CrossEntropyLoss()
    # Adam to optimze the classification hear
    optimizer = optim.Adam(custom_model.head.parameters(), lr=0.001)
    num_epochs = 10
    for epoch in range(num_epochs):
        custom_model.train()  # set the model to train mode
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = custom_model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")
    print("Train completed")

train_model(custom_model, train_loader)

Epoch [1/10], Loss: 0.5261
Epoch [2/10], Loss: 0.4342
Epoch [3/10], Loss: 0.4184
Epoch [4/10], Loss: 0.4075
Epoch [5/10], Loss: 0.3980
Epoch [6/10], Loss: 0.3928
Epoch [7/10], Loss: 0.3952
Epoch [8/10], Loss: 0.3858
Epoch [9/10], Loss: 0.3806
Epoch [10/10], Loss: 0.3766
Train completed


# Save model

In [13]:
def save_model(custom_model, num_classes, feature_dim, data_config):
    torch.save({
        'classifier_state_dict': custom_model.head.state_dict(),
        'num_classes': num_classes,
        'feature_dim': feature_dim,
        'training_config': {
            'model_name': 'facebook/dinov3-base-pretrain-lvd1689m',
            'input_size': data_config['input_size']
        }
    }, 'dino_classifier_head.pth')
    print('model saved')

save_model(custom_model, num_classes, feature_dim, dataconfig)

model saved


# Evaluated model

In [20]:
def eval_model(custom_model, test_loader):
    correct = 0
    total = 0
    custom_model.eval()
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = custom_model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f"accuracy is: {100* correct / total:.2f}%")

eval_model(custom_model, test_loader)

accuracy is: 87.15%


# Inference

In [21]:
def process_image(data_config, file_path):
    transforms = timm.data.create_transform(
        **data_config,
        is_training=False
    )
    if isinstance(file_path, str):
        image = Image.open(file_path).convert('RGB')
    else:
        image = file_path.convert('RGB')
    input_tensor = transforms(image).unsqueeze(0).to(device)
    return input_tensor

def classifier(model, input_tensor):
    with torch.no_grad():
        output = model(input_tensor)
        prob = torch.softmax(output, dim=1)
        confidence, predicted = torch.max(prob, 1)
    class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                  'dog', 'frog', 'horse', 'ship', 'truck']
    print(f"Predicted: {class_names[predicted.item()]}")
    return class_names[predicted.item()]

In [22]:
image_path_1 = "/home/yang/MyRepos/object_detection/images/cat2.jpg"
image_path_2 = "/home/yang/MyRepos/object_detection/images/dog1.jpg"
valid_dataset = datasets.CIFAR10(
    root = './data',
    train = False,
    download = True,
)

image = Image.open(image_path_1).convert('RGB')
image.show()
input_tensor = process_image(dataconfig, image_path_1)
class_name = classifier(custom_model, input_tensor)

image = Image.open(image_path_2).convert('RGB')
image.show()
input_tensor = process_image(dataconfig, image_path_2)
class_name = classifier(custom_model, input_tensor)

image = valid_dataset[10][0]
image.show()
input_tensor = process_image(dataconfig, image)
class_name = classifier(custom_model, input_tensor)

Files already downloaded and verified
Predicted: cat
Predicted: dog
Predicted: airplane


In [15]:
image = valid_dataset[53][0]
image_resized = image.resize((100, 100))
image_resized.show()
input_tensor = process_image(dataconfig, image)
class_name = classifier(custom_model, input_tensor)

Predicted: cat


In [93]:
valid_dataset = datasets.CIFAR10(
    root = './data',
    train = False,
    download = True
)

valid_loader = DataLoader(
    valid_dataset,
    batch_size = 32,
    shuffle = True,
    num_workers = 12
)

print(valid_loader)

Files already downloaded and verified
<torch.utils.data.dataloader.DataLoader object at 0x76edc473e480>
