In [None]:
# Set up the environment:
!pip install transformers
!pip install torch torchvision
!pip install datasets


In [None]:
# load packages
import inspect
import os

import numpy as np
import torch
from torchvision import transforms, datasets
from torch.utils.data import random_split, DataLoader
from datasets import Dataset


In [None]:
# Model loading happens further below, after the dataset has been prepared
# (see the "Load the pre-trained ViT model" cell).



In [None]:
# Preprocess the data:
# Images are resized, center-cropped to 224x224, converted to tensors and
# normalized once here; the resulting tensors are what the model trains on.
data_dir = "archive/"
SEED = 42  # fixes the train/val split so Restart & Run All is reproducible

data_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    # ImageNet channel mean/std.
    # NOTE(review): the ViT in21k checkpoint's own preprocessing uses different
    # statistics; fine-tuning usually absorbs this, but confirm it is intended.
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

image_dataset = datasets.ImageFolder(data_dir, data_transforms)
train_size = int(0.8 * len(image_dataset))
val_size = len(image_dataset) - train_size
train_dataset, val_dataset = random_split(
    image_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SEED),
)


def _subset_to_hf(subset):
    """Materialize a torch dataset of (image, label) pairs into a Hugging Face `Dataset`.

    `subset` is iterated exactly once. Every item access re-reads and
    re-transforms an image from disk, so building the "image" and "label"
    columns in two separate passes would double the preprocessing cost.

    Returns a `Dataset` with columns "image" and "label".
    """
    images, labels = [], []
    for image, label in subset:
        images.append(image)
        labels.append(label)
    return Dataset.from_dict({"image": images, "label": labels})


train_hf_dataset = _subset_to_hf(train_dataset)
val_hf_dataset = _subset_to_hf(val_dataset)


In [None]:
# Load the pre-trained ViT model with a classification head sized to the dataset.
from transformers import ViTForImageClassification, ViTFeatureExtractor, TrainingArguments, Trainer

model_name = "google/vit-base-patch16-224-in21k"
# Derive the label count and id<->name mappings from the ImageFolder classes
# instead of hard-coding 7, so the head always matches the classes on disk and
# predictions carry human-readable class names.
id2label = dict(enumerate(image_dataset.classes))
label2id = {name: idx for idx, name in id2label.items()}
model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)



In [None]:
# Convert the torchvision-preprocessed dataset into model inputs ("pixel_values", "labels").
def preprocess_images(dataset):
    """Map `dataset` to the column names `ViTForImageClassification` expects.

    Args:
        dataset: object with a `.map(fn, batched=True)` method whose batches
            have "image" and "label" columns.

    Returns:
        The mapped dataset, with added "pixel_values" (copied from "image") and
        "labels" (copied from "label") columns.

    The images were already resized, cropped, converted to tensors and
    normalized by `data_transforms`, so they are passed through unchanged.
    Running `feature_extractor` on them again would re-apply its own rescaling
    (by 1/255 with its default settings) and normalization to already-normalized
    values, collapsing every image to a near-constant input. Batches read back
    from a `Dataset.from_dict` dataset are also nested Python lists, not torch
    tensors, so `torch.squeeze(...).permute(...)` cannot be applied to them.
    """
    def preprocess_function(examples):
        return {"pixel_values": examples["image"], "labels": examples["label"]}
    return dataset.map(preprocess_function, batched=True)

# Run the same preprocessing over both splits; the train split is processed first.
train_hf_dataset, val_hf_dataset = [
    preprocess_images(split) for split in (train_hf_dataset, val_hf_dataset)
]

In [None]:
# Train the model
def compute_metrics(eval_pred):
    """Return classification accuracy for a `Trainer` evaluation pass.

    Args:
        eval_pred: object exposing `.predictions` (model logits) and
            `.label_ids` (integer targets).

    Returns:
        dict with a single float entry, "accuracy".
    """
    predicted = np.argmax(eval_pred.predictions, axis=-1)
    return {"accuracy": float((predicted == eval_pred.label_ids).mean())}


# `evaluation_strategy` was renamed to `eval_strategy` in newer transformers
# releases and the old spelling was later removed; use whichever keyword the
# installed version accepts so this cell works with an unpinned install.
_eval_strategy_kwarg = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(TrainingArguments.__init__).parameters
    else "evaluation_strategy"
)

training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=5,
    learning_rate=2e-5,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
    save_strategy="epoch",
    **{_eval_strategy_kwarg: "epoch"},
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_hf_dataset,
    eval_dataset=val_hf_dataset,
    # Without this, evaluation reports only the loss.
    compute_metrics=compute_metrics,
)

trainer.train()


In [None]:
# Evaluate on the validation split and display the resulting metrics dict.
eval_metrics = trainer.evaluate()
eval_metrics

In [None]:
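# TODO: upload the fine-tuned model to the Hugging Face Hub (for example with
# `trainer.push_to_hub()`). The commented-out lines below only reload the base
# pre-trained checkpoint; they do not upload anything.
# from transformers import ViTFeatureExtractor, ViTForImageClassification
# feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')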