## Preprocessing

In [2]:
import os
from glob import glob
import pandas as pd
from PIL import Image

In [4]:
# preprocessing
img_paths = []
labels = []
styles = []
data_types = []

dataset_paths = ['data/Real_AI_SD_LD_Dataset/test', 'data/Real_AI_SD_LD_Dataset/train']

for p in dataset_paths:
    data_type = os.path.basename(os.path.normpath(p))
    for dir in os.listdir(p):
        path = os.path.join(p, dir)
        
        if os.path.isdir(path):
            # extract label and style
            if dir[:2] == 'AI':
                label = "ai"
                style = dir[6:].replace('-', '_')
            else:
                label = "real"
                style = dir.replace('-', '_')
                
            image_files = glob(os.path.join(path, '*.jpg'))
            
            for image in image_files:
                try:
                    img = Image.open(image)
                    img.verify()
                    
                    img_paths.append(image)
                    labels.append(label)
                    styles.append(style)
                    data_types.append(data_type)
                except (IOError, SyntaxError) as e:
                    print(f"Error opening image {image}: {e}")
                


In [6]:
data = pd.DataFrame({
    'image_path': img_paths, 
    'label': labels,
    'style': styles,
    'type': data_types
})

data.head()

Unnamed: 0,image_path,label,style,type
0,data/Real_AI_SD_LD_Dataset/test/AI_LD_ukiyo-e/...,ai,ukiyo_e,test
1,data/Real_AI_SD_LD_Dataset/test/AI_LD_ukiyo-e/...,ai,ukiyo_e,test
2,data/Real_AI_SD_LD_Dataset/test/AI_LD_ukiyo-e/...,ai,ukiyo_e,test
3,data/Real_AI_SD_LD_Dataset/test/AI_LD_ukiyo-e/...,ai,ukiyo_e,test
4,data/Real_AI_SD_LD_Dataset/test/AI_LD_ukiyo-e/...,ai,ukiyo_e,test


In [7]:
data['style'].unique()

array(['ukiyo_e', 'art_nouveau', 'post_impressionism', 'impressionism',
       'expressionism', 'baroque', 'surrealism', 'realism', 'romanticism',
       'renaissance'], dtype=object)

In [10]:
data.to_csv('dataset.csv', index=False)

## Training the model

In [17]:
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from transformers import ViTForImageClassification, ViTImageProcessor
from PIL import Image
from sklearn.metrics import accuracy_score, classification_report
from tqdm import tqdm
import os

In [18]:
# load csv
df = pd.read_csv('dataset.csv')
test_df = df[df['type'] == 'test']
train_df = df[df['type'] == 'train']


# take a sample of the dataset
# test_df = test_df.groupby('label', group_keys=False).apply(lambda x: x.sample(n=250, random_state=42)).reset_index(drop=True)
# train_df = train_df.groupby('label', group_keys=False).apply(lambda x: x.sample(n=250, random_state=42)).reset_index(drop=True)
# sampled_df = pd.concat([test_df, train_df], axis=0, ignore_index=True)

In [19]:
class ArtDataset(Dataset):
    def __init__(self, df, feature_extractor, train=True):
        self.df = df[df['type'] == ('train' if train else 'test')]
        self.feature_extractor = feature_extractor
        self.label_mapping = {label: idx for idx, label in enumerate(df['style'].unique())}

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image = Image.open(row['image_path']).convert("RGB")
        inputs = self.feature_extractor(images=image, return_tensors="pt")
        pixel_values = inputs["pixel_values"].squeeze()
        
        label = self.label_mapping[row['style']]
        return pixel_values, label

In [25]:
# Load feature extractor and model
feature_extractor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224-in21k", num_labels=10)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [26]:
train_dataset = ArtDataset(df, feature_extractor, train=True)
test_dataset = ArtDataset(df, feature_extractor, train=False)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

In [27]:
print(train_dataset.label_mapping)

pixel_values, label = train_dataset[0]
print("Label :", label)

{'ukiyo_e': 0, 'art_nouveau': 1, 'post_impressionism': 2, 'impressionism': 3, 'expressionism': 4, 'baroque': 5, 'surrealism': 6, 'realism': 7, 'romanticism': 8, 'renaissance': 9}
Label : 0


In [28]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
criterion = torch.nn.CrossEntropyLoss()

Device: cuda


In [29]:
checkpoint_dir = "vit_style_checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)
for epoch in range(5):
    model.train()
    train_loss = 0.0
    for images, styles in tqdm(train_loader, desc=f"Training Epoch {epoch+1}"):
        images = images.to(device)
        styles = torch.tensor(styles).to(device)
        
        outputs = model(images).logits
        loss = criterion(outputs, styles)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
    print(f"Epoch {epoch+1} - Training loss: {train_loss/len(train_loader)}")
    checkpoint_path = os.path.join(checkpoint_dir, f"vit_checkpoint_epoch_{epoch+1}.pt")
    torch.save({
    'model_state': model.state_dict(),
    'optimizer_state': optimizer.state_dict(),
    'epoch': epoch,
    }, checkpoint_path)
    print(f"Checkpoint saved at {checkpoint_path}")

  styles = torch.tensor(styles).to(device)
Training Epoch 1: 100%|██████████| 4845/4845 [1:48:36<00:00,  1.34s/it]


Epoch 1 - Training loss: 0.5023614762844931
Checkpoint saved at vit_style_checkpoints/vit_checkpoint_epoch_1.pt


Training Epoch 2: 100%|██████████| 4845/4845 [1:49:56<00:00,  1.36s/it]


Epoch 2 - Training loss: 0.28467922668531154
Checkpoint saved at vit_style_checkpoints/vit_checkpoint_epoch_2.pt


Training Epoch 3: 100%|██████████| 4845/4845 [1:50:00<00:00,  1.36s/it]


Epoch 3 - Training loss: 0.16707889793667283
Checkpoint saved at vit_style_checkpoints/vit_checkpoint_epoch_3.pt


Training Epoch 4: 100%|██████████| 4845/4845 [1:49:58<00:00,  1.36s/it]


Epoch 4 - Training loss: 0.09191392988516296
Checkpoint saved at vit_style_checkpoints/vit_checkpoint_epoch_4.pt


Training Epoch 5: 100%|██████████| 4845/4845 [1:50:05<00:00,  1.36s/it]


Epoch 5 - Training loss: 0.06046665606593042
Checkpoint saved at vit_style_checkpoints/vit_checkpoint_epoch_5.pt


In [None]:
# Continue from last epoch
start_epoch = 4
checkpoint_dir = "vit_style_checkpoints"
checkpoint_path = os.path.join(checkpoint_dir, "vit_checkpoint_epoch_4.pt")

# Load feature extractor and model
feature_extractor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224-in21k", num_labels=10)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
criterion = torch.nn.CrossEntropyLoss()

train_dataset = ArtDataset(df, feature_extractor, train=True)
test_dataset = ArtDataset(df, feature_extractor, train=False)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

if os.path.exists(checkpoint_path):
    print(f"Loading checkpoint from {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path)
    print(checkpoint.keys())
    model.load_state_dict(checkpoint['model_state'])
    optimizer.load_state_dict(checkpoint['optimizer_state'])
    start_epoch = checkpoint['epoch'] + 1
    print(f"Resuming from epoch {start_epoch}")

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Device: cuda
Loading checkpoint from vit_style_checkpoints/vit_checkpoint_epoch_4.pt


  checkpoint = torch.load(checkpoint_path)


odict_keys(['vit.embeddings.cls_token', 'vit.embeddings.position_embeddings', 'vit.embeddings.patch_embeddings.projection.weight', 'vit.embeddings.patch_embeddings.projection.bias', 'vit.encoder.layer.0.attention.attention.query.weight', 'vit.encoder.layer.0.attention.attention.query.bias', 'vit.encoder.layer.0.attention.attention.key.weight', 'vit.encoder.layer.0.attention.attention.key.bias', 'vit.encoder.layer.0.attention.attention.value.weight', 'vit.encoder.layer.0.attention.attention.value.bias', 'vit.encoder.layer.0.attention.output.dense.weight', 'vit.encoder.layer.0.attention.output.dense.bias', 'vit.encoder.layer.0.intermediate.dense.weight', 'vit.encoder.layer.0.intermediate.dense.bias', 'vit.encoder.layer.0.output.dense.weight', 'vit.encoder.layer.0.output.dense.bias', 'vit.encoder.layer.0.layernorm_before.weight', 'vit.encoder.layer.0.layernorm_before.bias', 'vit.encoder.layer.0.layernorm_after.weight', 'vit.encoder.layer.0.layernorm_after.bias', 'vit.encoder.layer.1.atten

KeyError: 'param_groups'

In [30]:
# save model
output_dir = "./vit_style_checkpoints"
os.makedirs(output_dir, exist_ok=True)

checkpoint_path = os.path.join(output_dir, "vit_art_style_classification_final.pth")
torch.save(model.state_dict(), checkpoint_path)
print(f"Model checkpoint saved at {checkpoint_path}")

Model checkpoint saved at ./vit_style_checkpoints/vit_art_style_classification_final.pth


In [31]:
model.eval()
all_styles = []
all_preds = []

with torch.no_grad():
    for images, styles in test_loader:
        images = images.to(device)
        styles = torch.tensor(styles).to(device)
        outputs = model(images).logits
        _, preds = torch.max(outputs, dim=1)
        
        all_styles.extend(styles.cpu().numpy())
        all_preds.extend(preds.cpu().numpy())

print("Test Accuracy:", accuracy_score(all_styles, all_preds))
print(classification_report(all_styles, all_preds, target_names=['ukiyo_e', 'art_nouveau', 'post_impressionism', 'impressionism',
       'expressionism', 'baroque', 'surrealism', 'realism', 'romanticism',
       'renaissance']))

  styles = torch.tensor(styles).to(device)


Test Accuracy: 0.8591333333333333
                    precision    recall  f1-score   support

           ukiyo_e       1.00      0.99      1.00      3000
       art_nouveau       0.89      0.89      0.89      3000
post_impressionism       0.72      0.74      0.73      3000
     impressionism       0.72      0.79      0.75      3000
     expressionism       0.89      0.82      0.85      3000
           baroque       0.92      0.88      0.90      3000
        surrealism       0.95      0.92      0.94      3000
           realism       0.79      0.79      0.79      3000
       romanticism       0.82      0.84      0.83      3000
       renaissance       0.90      0.93      0.91      3000

          accuracy                           0.86     30000
         macro avg       0.86      0.86      0.86     30000
      weighted avg       0.86      0.86      0.86     30000



In [32]:
pd.Series(all_styles).unique()

array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])