1. Install Dependencies

In [15]:

%pip install transformers datasets torch torchvision evaluate

Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Downloading evaluate-0.4.3-py3-none-any.whl (84 kB)
Installing collected packages: evaluate
Successfully installed evaluate-0.4.3
Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 24.2 -> 24.3.1
[notice] To update, run: python.exe -m pip install --upgrade pip


 2. Prepare the MNIST Dataset

In [9]:
from datasets import load_dataset
from torchvision.transforms import Compose, Resize, ToTensor, Normalize

# Load MNIST dataset
train_dataset = load_dataset("fashion_mnist", split="train[:1%]")
test_dataset = load_dataset("fashion_mnist", split="test[:1%]")

# Define transformations
transform = Compose([
    Resize((224, 224)),  # Resize to 224x224 for ViT
    ToTensor(),          # Convert to Tensor
    Normalize((0.5,), (0.5,))  # Normalize grayscale images
])

# Apply transformations
def transform_dataset(batch):
    batch["image"] = [transform(img.convert("RGB")) for img in batch["image"]]
    return batch

train_dataset = train_dataset.with_transform(transform_dataset)
test_dataset = test_dataset.with_transform(transform_dataset)


3. Load a Pretrained Vision Transformer

In [10]:
from transformers import ViTForImageClassification, ViTFeatureExtractor

# Load feature extractor and model
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    num_labels=10,  # MNIST has 10 classes
    id2label={i: str(i) for i in range(10)},  # Map IDs to digit labels
    label2id={str(i): i for i in range(10)}
)


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


4. Prepare the DataLoader

In [11]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=32)


5. Fine-tune the Model

In [None]:
import torch
from torch.optim import AdamW
from transformers import get_scheduler

# Define optimizer and scheduler
optimizer = AdamW(model.parameters(), lr=5e-5)
num_training_steps = len(train_dataloader) * 3  # Assuming 3 epochs
lr_scheduler = get_scheduler("linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)

# Define loss function and move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Training loop
model.train()
for epoch in range(3):  # Train for 3 epochs
    for batch in train_dataloader:
        # Kiểm tra dữ liệu
        #print(type(batch["image"]))  # Kiểm tra kiểu dữ liệu

        # Nếu batch["image"] là list
        if isinstance(batch["image"], list):
            inputs = torch.stack(batch["image"]).to(device)
        else:
        # Nếu batch["image"] đã là tensor
            inputs = batch["image"].to(device)

        labels = torch.tensor(batch["label"]).to(device)
        outputs = model(pixel_values=inputs, labels=labels)
        loss = outputs.loss
        
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()


    print(f"Epoch {epoch + 1} complete. Loss: {loss.item():.4f}")


<class 'torch.Tensor'>


  labels = torch.tensor(batch["label"]).to(device)


<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
Epoch 1 complete. Loss: 1.8444
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
Epoch 2 complete. Loss: 1.4378
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tensor'>
<class 'torch.Tens

6. Evaluate the Model

In [17]:
import evaluate

# Load metric
metric = evaluate.load("accuracy")

# Evaluation loop
model.eval()
for batch in test_dataloader:
    
    if isinstance(batch["image"], list):
        inputs = torch.stack(batch["image"]).to(device)
    else:
        # Nếu batch["image"] đã là tensor
        inputs = batch["image"].to(device)

    labels = torch.tensor(batch["label"]).to(device)
    
    with torch.no_grad():
        outputs = model(pixel_values=inputs)
    
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)
    metric.add_batch(predictions=predictions, references=labels)

# Compute final accuracy
final_accuracy = metric.compute()["accuracy"]
print(f"Test Accuracy: {final_accuracy * 100:.2f}%")



  labels = torch.tensor(batch["label"]).to(device)


Test Accuracy: 80.00%
