## Preprocessing

In [3]:
import os
from glob import glob
import pandas as pd
from PIL import Image

In [8]:
# preprocessing
img_paths = []
labels = []
data_types = []

dataset_paths = ['data/Real_AI_SD_LD_Dataset/test', 'data/Real_AI_SD_LD_Dataset/train']

for p in dataset_paths:
    data_type = os.path.basename(os.path.normpath(p))
    for dir in os.listdir(p):
        path = os.path.join(p, dir)
        
        if os.path.isdir(path):
            # extract label
            if dir[:2] == 'AI':
                label = "ai"
            else:
                label = "real"
                
            image_files = glob(os.path.join(path, '*.jpg'))
            
            for image in image_files:
                try:
                    img = Image.open(image)
                    img.verify()
                    
                    img_paths.append(image)
                    labels.append(label)
                    data_types.append(data_type)
                except (IOError, SyntaxError) as e:
                    print(f"Error opening image {image}: {e}")
                


In [9]:
data = pd.DataFrame({
    'image_path': img_paths, 
    'label': labels,
    'type': data_types
})

data.head()

Unnamed: 0,image_path,label,type
0,data/Real_AI_SD_LD_Dataset/test/AI_LD_ukiyo-e/...,ai,test
1,data/Real_AI_SD_LD_Dataset/test/AI_LD_ukiyo-e/...,ai,test
2,data/Real_AI_SD_LD_Dataset/test/AI_LD_ukiyo-e/...,ai,test
3,data/Real_AI_SD_LD_Dataset/test/AI_LD_ukiyo-e/...,ai,test
4,data/Real_AI_SD_LD_Dataset/test/AI_LD_ukiyo-e/...,ai,test


In [10]:
data.to_csv('dataset.csv', index=False)

## Training the model

In [26]:
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from transformers import ViTForImageClassification, ViTImageProcessor
from PIL import Image
from sklearn.metrics import accuracy_score, classification_report
from tqdm import tqdm
import os

In [59]:
# load csv
df = pd.read_csv('dataset.csv')
test_df = df[df['type'] == 'test']
train_df = df[df['type'] == 'train']


# take a sample of the dataset
# test_df = test_df.groupby('label', group_keys=False).apply(lambda x: x.sample(n=250, random_state=42)).reset_index(drop=True)
# train_df = train_df.groupby('label', group_keys=False).apply(lambda x: x.sample(n=250, random_state=42)).reset_index(drop=True)
# sampled_df = pd.concat([test_df, train_df], axis=0, ignore_index=True)

In [60]:
class ArtDataset(Dataset):
    def __init__(self, df, feature_extractor, train=True):
        self.df = df[df['type'] == ('train' if train else 'test')]
        self.feature_extractor = feature_extractor

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image = Image.open(row['image_path']).convert("RGB")
        inputs = self.feature_extractor(images=image, return_tensors="pt")
        pixel_values = inputs["pixel_values"].squeeze()
        label = 1 if row['label'] == 'ai' else 0
        return pixel_values, label

In [66]:
# Load feature extractor and model
feature_extractor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224-in21k", num_labels=2)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [67]:
train_dataset = ArtDataset(df, feature_extractor, train=True)
test_dataset = ArtDataset(df, feature_extractor, train=False)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

In [68]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
criterion = torch.nn.CrossEntropyLoss()

Device: cuda


In [None]:
checkpoint_dir = "vit_checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)
for epoch in range(5):
    model.train()
    train_loss = 0.0
    for images, labels in tqdm(train_loader, desc=f"Training Epoch {epoch+1}"):
        images = images.to(device)
        labels = torch.tensor(labels).to(device)
        
        outputs = model(images).logits
        loss = criterion(outputs, labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
    print(f"Epoch {epoch+1} - Training loss: {train_loss/len(train_loader)}")
    checkpoint_path = os.path.join(checkpoint_dir, f"vit_checkpoint_epoch_{epoch+1}.pt")
    torch.save({
    'model_state': model.state_dict(),
    'optimizer_state': optimizer.state_dict(),
    'epoch': epoch,
    }, checkpoint_path)
    print(f"Checkpoint saved at {checkpoint_path}")

  labels = torch.tensor(labels).to(device)
Training Epoch 1: 100%|██████████| 4845/4845 [1:49:43<00:00,  1.36s/it]


Epoch 1 - Training loss: 0.026729605788335213
Checkpoint saved at vit_checkpoints/vit_checkpoint_epoch_1.pt


Training Epoch 2: 100%|██████████| 4845/4845 [1:50:42<00:00,  1.37s/it]


Epoch 2 - Training loss: 0.00870591711777908
Checkpoint saved at vit_checkpoints/vit_checkpoint_epoch_2.pt


Training Epoch 3: 100%|██████████| 4845/4845 [1:50:47<00:00,  1.37s/it]


Epoch 3 - Training loss: 0.006089948401729621
Checkpoint saved at vit_checkpoints/vit_checkpoint_epoch_3.pt


Training Epoch 4: 100%|██████████| 4845/4845 [1:50:43<00:00,  1.37s/it]


Epoch 4 - Training loss: 0.0049372117144087626
Checkpoint saved at vit_checkpoints/vit_checkpoint_epoch_4.pt


Training Epoch 5: 100%|██████████| 4845/4845 [1:50:40<00:00,  1.37s/it]


Epoch 5 - Training loss: 0.004522934531089671
Checkpoint saved at vit_checkpoints/vit_checkpoint_epoch_5.pt


In [70]:
# save model
output_dir = "./vit_checkpoints"
os.makedirs(output_dir, exist_ok=True)

checkpoint_path = os.path.join(output_dir, "vit_art_classification_final.pth")
torch.save(model.state_dict(), checkpoint_path)
print(f"Model checkpoint saved at {checkpoint_path}")

Model checkpoint saved at ./vit_checkpoints/vit_art_classification_final.pth


## Testing

In [None]:
model.eval()
all_labels = []
all_preds = []

with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = torch.tensor(labels).to(device)
        outputs = model(images).logits
        _, preds = torch.max(outputs, dim=1)
        
        all_labels.extend(labels.cpu().numpy())
        all_preds.extend(preds.cpu().numpy())

print("Test Accuracy:", accuracy_score(all_labels, all_preds))
print(classification_report(all_labels, all_preds, target_names=["ai", "real"]))

  labels = torch.tensor(labels).to(device)


Test Accuracy: 0.9951666666666666
              precision    recall  f1-score   support

          ai       1.00      0.99      0.99     10000
        real       0.99      1.00      1.00     20000

    accuracy                           1.00     30000
   macro avg       1.00      0.99      0.99     30000
weighted avg       1.00      1.00      1.00     30000

