In [None]:

import torch
from torchvision import datasets, transforms

# Frozen DINOv2 ViT-S/14 backbone from torch.hub (downloads repo code + weights on first run).
model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')

# Preprocessing shared by both MNIST splits (defined once instead of copy-pasted).
# - Grayscale(3) turns the single-channel digit into 3 identical channels for the RGB
#   backbone. It gives the same pixels as repeating the channel, but unlike a lambda it is
#   picklable, so DataLoader worker processes can use it.
# - Normalize applies the ImageNet mean/std that DINOv2's own preprocessing uses; without
#   it the backbone sees inputs outside its training distribution.
# - DINOv2's patch embedding expects H and W to be multiples of the 14-px patch size.
#   MNIST's native 28x28 gives a 2x2 patch grid. Set RESIZE_TO to a multiple of 14
#   (e.g. 224) for a denser grid at a much higher compute/memory cost.
RESIZE_TO = None
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

_mnist_steps = [transforms.Grayscale(num_output_channels=3)]
if RESIZE_TO is not None:
    _mnist_steps.append(transforms.Resize(RESIZE_TO))
_mnist_steps += [
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
]
mnist_transform = transforms.Compose(_mnist_steps)

# Load the MNIST dataset
train_dataset = datasets.MNIST(root='./data', train=True, download=True,
                               transform=mnist_transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True,
                              transform=mnist_transform)


Downloading: "https://github.com/facebookresearch/dinov2/zipball/main" to /root/.cache/torch/hub/main.zip
Downloading: "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth" to /root/.cache/torch/hub/checkpoints/dinov2_vits14_pretrain.pth
100%|██████████| 84.2M/84.2M [00:00<00:00, 192MB/s]


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:02<00:00, 4303807.18it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 129726.30it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:05<00:00, 277275.92it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 4509000.89it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






In [None]:


from torch.utils.data import DataLoader

# Images per forward pass during feature extraction; lower this if the device runs out of memory.
BATCH_SIZE = 1024

# Create data loaders.
# In this notebook both loaders only feed frozen-feature extraction, so neither is shuffled.
# The extracted rows then follow dataset index order and come out in the same order on every
# re-run. Shuffle later, when training on the saved features.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# Embed every training image with the frozen backbone, one batch at a time.
# `features` / `labels_list` are rebuilt from scratch, so re-running this cell is safe.
# torch.no_grad (not inference_mode) keeps the resulting tensors usable in later autograd ops.
# NOTE(review): for the hub DINOv2 model `out` is presumably one embedding row per image
# (batch, embed_dim) — confirm before relying on the shape.
feature_batches, label_batches = [], []

with torch.no_grad():
    for images, labels in train_loader:
        # Only images go to the device: labels never enter the forward pass, so moving
        # them to the GPU and straight back would just be wasted transfers.
        out = model(images.to(device))
        feature_batches.append(out.cpu())
        label_batches.append(labels)

features = torch.cat(feature_batches, dim=0)
labels_list = torch.cat(label_batches, dim=0)

In [None]:
# Pair the train-split embeddings with their labels as one dataset for saving.
# `features` / `labels_list` are overwritten by the test-extraction cell too, so make sure
# they really hold the train split (guards against running the cells out of order).
assert len(features) == len(labels_list) == len(train_dataset), \
    "features are not from the train split; re-run the train extraction cell"
train_features = torch.utils.data.TensorDataset(features, labels_list)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# Embed every test image with the frozen backbone, one batch at a time.
# This reuses (and overwrites) the `features` / `labels_list` names from the train cell;
# both are rebuilt from scratch here, so re-running this cell is safe.
# torch.no_grad (not inference_mode) keeps the resulting tensors usable in later autograd ops.
feature_batches, label_batches = [], []

with torch.no_grad():
    for images, labels in test_loader:
        # Labels stay on the CPU: they are not needed on the device for the forward pass.
        out = model(images.to(device))
        feature_batches.append(out.cpu())
        label_batches.append(labels)

features = torch.cat(feature_batches, dim=0)
labels_list = torch.cat(label_batches, dim=0)

In [None]:
# Pair the test-split embeddings with their labels as one dataset for saving.
# `features` / `labels_list` are shared with the train-extraction cell, so confirm they
# hold the test split before wrapping them (guards against out-of-order execution).
assert len(features) == len(labels_list) == len(test_dataset), \
    "features are not from the test split; re-run the test extraction cell"
test_features = torch.utils.data.TensorDataset(features, labels_list)

In [None]:
# Persist the train-split TensorDataset (embeddings + labels) with torch's serializer.
# NOTE(review): this pickles a TensorDataset object; recent torch versions default
# torch.load to weights_only=True, so readers may need weights_only=False — confirm.
train_features_path = "vit_train_features"
with open(train_features_path, mode="wb") as out_file:
    torch.save(train_features, out_file)

In [None]:
# Persist the test-split TensorDataset (embeddings + labels) with torch's serializer.
# NOTE(review): like the train file, this is a pickled TensorDataset; loading it with
# recent torch may require torch.load(..., weights_only=False) — confirm.
test_features_path = "vit_test_features"
with open(test_features_path, mode="wb") as out_file:
    torch.save(test_features, out_file)