In [1]:
import torch
from PIL import Image
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from transformers import ViTForImageClassification, TrainingArguments, Trainer, ViTImageProcessor

feature_extractor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")

In [2]:
# Define image transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

In [3]:
# Load Data Set
train_dataset = ImageFolder(root="./train", transform=transform)
val_dataset = ImageFolder(root="./val", transform=transform)
test_dataset = ImageFolder(root="./test", transform=transform)

print(train_dataset.class_to_idx) # print index, delete later

{'FAKE': 0, 'REAL': 1}


In [4]:
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, feature_extractor):
        self.dataset = dataset
        self.feature_extractor = feature_extractor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        encoding = self.feature_extractor(image, return_tensors="pt")
        encoding["labels"] = torch.tensor(label)
        return {key: val.squeeze() for key, val in encoding.items()}  # Flatten tensors

train_dataset = CustomDataset(train_dataset, feature_extractor)
val_dataset = CustomDataset(val_dataset, feature_extractor)

In [None]:
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=2,  # Binary classification (AI-generated vs. Real)
    ignore_mismatched_sizes=True
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([2]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([2, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
            (intermed

In [None]:
training_args = TrainingArguments(
    output_dir="./vit_cifake",
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    eval_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=5,
    save_total_limit=2,
    learning_rate=5e-5,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
    load_best_model_at_end=True,
    # dataloader_num_workers=4,
    # gradient_accumulation_steps=2,
    # optim="adamw_torch",
    # disable_tqdm=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

trainer.train()

It looks like you are trying to rescale already rescaled images. If the input images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again.


Epoch,Training Loss,Validation Loss


In [None]:
results = trainer.evaluate()
print(results)

In [None]:
from PIL import Image

def predict(image_path):
    image = Image.open(image_path).convert("RGB")
    encoding = feature_extractor(image, return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")

    with torch.no_grad():
        outputs = model(**encoding)
        predicted_class = torch.argmax(outputs.logits, dim=1).item()

    labels = ["Real", "AI-Generated"]
    return labels[predicted_class]

# print(predict("test/real/0001.jpg"))