In [None]:
import pandas as pd
from pathlib import Path

# Root directory of the WikiArt dataset on Kaggle; the dataset classes below
# resolve the `{task}_*.csv` / `{task}_class.txt` listings and the image files
# relative to this path.
path = Path('/kaggle/input/wikiart') 

# Specific

## Processing filename

Because Kaggle does not allow non-ASCII characters in file names, I wrote a script that renames the affected files, replacing each non-ASCII character with an underscore (_). The code below was executed before uploading the dataset.

In [None]:
# import os
# import unicodedata
# import re
# import shutil

# def normalize_filename(filename):
#     normalized = unicodedata.normalize('NFKD', filename)
#     ascii_filename = ''
#     for char in normalized:
#         if ord(char) < 128 and ord(char) != 39:
#             ascii_filename += char
#         else:
#             replacements = {
#                 'ä': 'a', 'ö': 'o', 'ü': 'u', 'ß': 'ss',
#                 'á': 'a', 'é': 'e', 'í': 'i', 'ó': 'o', 'ú': 'u',
#                 'à': 'a', 'è': 'e', 'ì': 'i', 'ò': 'o', 'ù': 'u',
#                 'â': 'a', 'ê': 'e', 'î': 'i', 'ô': 'o', 'û': 'u',
#             }
#             ascii_filename += replacements.get(char, '_') # Replace with _ 

#     ascii_filename = re.sub(r'[^a-zA-Z0-9_\-\.]', '_', ascii_filename)
    
#     return ascii_filename

# def rename_non_ascii_files(dir):
#     renamed_files = []
#     skipped_files = []
    
#     for art_style in os.listdir(dir):
#         style_path = os.path.join(dir, art_style)
#         if os.path.isdir(style_path):
#             print(f"Processing {art_style}...")
#             for file in os.listdir(style_path):
#                 filepath = os.path.join(style_path, file)
#                 has_non_ascii = any(ord(char) > 127 for char in file)
#                 if has_non_ascii:
#                     new_filename = normalize_filename(file)
#                     new_filepath = os.path.join(style_path, new_filename)
#                     try:
#                         print(f"  Renaming: {filepath} -> {new_filepath}")
#                         shutil.move(filepath, new_filepath)
#                         renamed_files.append((filepath, new_filepath))
#                     except Exception as e:
#                         print(f"  Error renaming {filepath}: {e}")
#                         skipped_files.append((filepath, str(e)))
    
#     print(f"\nRenamed {len(renamed_files)} files")
#     print(f"Skipped {len(skipped_files)} files due to errors")
    
#     return {
#         "renamed_files": renamed_files,
#         "skipped_files": skipped_files
#     }

# rename_non_ascii_files('wikiart')

## Building custom dataset

In [None]:
import numpy as np
import pandas as pd
from pathlib import Path
from PIL import Image
import torch
from torch.utils.data import Dataset, WeightedRandomSampler
import torchvision.transforms as transforms

# Classification tasks accepted by the dataset classes; each one selects the
# `{task}_class.txt` / `{task}_*.csv` listing files.
TASKS_LIST = [
    'style',
    'artist',
    'genre',
]

# Relative image paths (as written in the CSV listings) that the dataset
# classes filter out before building their sample lists.
corrupted_images = [
    'Baroque/rembrandt_woman-standing-with-raised-hands.jpg',
    'Post_Impressionism/vincent-van-gogh_l-arlesienne-portrait-of-madame-ginoux-1890.jpg',
]

class TRAIN_SpecificArtGANDataset(Dataset):
    """Training split of the WikiArt dataset for a single classification task.

    Reads `{task}_class.txt` (class id -> class name) and `{task}_train.csv`
    (relative image path, label) from `data_path`, drops the entries listed in
    `corrupted_images`, and rewrites every file name with the same
    sanitization that was applied to the files before they were uploaded.

    Args:
        data_path: `pathlib.Path` of the dataset root.
        task: one of `TASKS_LIST`.

    Each item is a dict with keys `'images'` (augmented, normalized tensor)
    and `'labels'` (label read from the CSV).
    """

    # Fallback substitutions for characters that are still non-ASCII after
    # NFKD normalization. NFKD splits accented letters into a base letter plus
    # a combining mark, so the accented entries are effectively never hit (the
    # combining mark becomes '_'); 'ß' has no decomposition and maps to 'ss'.
    # The table is kept verbatim so results stay identical to the renaming
    # script that was run on the files before upload.
    _REPLACEMENTS = {
        'ä': 'a', 'ö': 'o', 'ü': 'u', 'ß': 'ss',
        'á': 'a', 'é': 'e', 'í': 'i', 'ó': 'o', 'ú': 'u',
        'à': 'a', 'è': 'e', 'ì': 'i', 'ò': 'o', 'ù': 'u',
        'â': 'a', 'ê': 'e', 'î': 'i', 'ô': 'o', 'û': 'u',
    }

    def __init__(self, data_path, task='genre'):
        super().__init__()
        assert (task in TASKS_LIST), f'Task should be either {TASKS_LIST}\n'

        self.data_path = data_path
        self.task = task
        self.classes_csv = pd.read_csv(data_path / f'{task}_class.txt', sep=" ", names=['label', 'name'])

        # --------------------------- Cleaning data
        data_csv = pd.read_csv(data_path / f'{task}_train.csv', names=['filename', 'label'])
        # Take an explicit copy of the filtered rows so the column assignment
        # below is never a chained assignment on a view.
        data_csv = data_csv.loc[~data_csv["filename"].isin(corrupted_images)].copy()
        data_csv["filename"] = data_csv["filename"].map(self._process_filename)

        self.data_csv = data_csv
        self.imgs_path = data_csv["filename"].tolist()
        self.labels = data_csv["label"].tolist()

        # --------------------------- Custom transforms for train dataset
        self.train_transforms = transforms.Compose([
            transforms.RandomResizedCrop((224,224)),
            transforms.RandomHorizontalFlip(0.3),
            transforms.RandomVerticalFlip(0.3),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    @staticmethod
    def _process_filename(f):
        """Return `f` ("<dir>/<file>") with the file part reduced to `[a-zA-Z0-9_.-]`.

        The directory part (everything before the first '/') is kept as is.
        In the file part, non-ASCII characters and the apostrophe are replaced
        through `_REPLACEMENTS` (default '_'); every remaining character
        outside `[a-zA-Z0-9_.-]` is then replaced by '_'.
        """
        import re
        import unicodedata

        dirname, filename = f.split('/', 1)
        normalized = unicodedata.normalize('NFKD', filename)
        pieces = []
        for char in normalized:
            if ord(char) < 128 and ord(char) != 39:  # 39 == apostrophe
                pieces.append(char)
            else:
                pieces.append(TRAIN_SpecificArtGANDataset._REPLACEMENTS.get(char, '_'))

        ascii_filename = re.sub(r'[^a-zA-Z0-9_\-\.]', '_', ''.join(pieces))
        return dirname + "/" + ascii_filename

    def __len__(self):
        """Number of samples left after removing the corrupted entries."""
        return len(self.imgs_path)

    def get_label_name(self, label):
        """Return the class name stored for `label` in `classes_csv`.

        `label` is converted with `int()` so tensor / NumPy scalars coming
        out of a batch can be passed directly.
        """
        return self.classes_csv['name'][int(label)]

    def transform(self, img):
        """Apply the training augmentation and normalization pipeline to `img`."""
        img = self.train_transforms(img)
        return img

    def get_raw(self, idx):
        """Return the relative image path and label of sample `idx` without loading the image."""
        img = self.imgs_path[idx]
        label = self.labels[idx]

        return {
            'images': img,
            'labels': label
        }

    def __getitem__(self, idx):
        """Load sample `idx`, force it to RGB, and return the transformed tensor with its label."""
        # `Image.open` is lazy: the context manager closes the file handle,
        # and `convert('RGB')` both loads the pixels and guarantees three
        # channels, which the 3-channel Normalize and batch collation require.
        with Image.open(self.data_path / self.imgs_path[idx]) as raw:
            img = raw.convert('RGB')
        label = self.labels[idx]

        img = self.transform(img)
        return {
            'images': img,
            'labels': label
        }


# Build the training split for the 'style' task from the dataset root `path`.
train_dataset = TRAIN_SpecificArtGANDataset(data_path=path, task='style')

In [None]:
import numpy as np
import pandas as pd
from pathlib import Path
from PIL import Image
import torch
from torch.utils.data import Dataset, WeightedRandomSampler
import torchvision.transforms as transforms

# Same constants as in the training-dataset cell, repeated here so this cell
# can be run on its own.

# Tasks for which class/label listing files are read.
TASKS_LIST = [
    'style',
    'artist',
    'genre',
]

# CSV entries skipped when the dataset sample list is built.
corrupted_images = [
    'Baroque/rembrandt_woman-standing-with-raised-hands.jpg',
    'Post_Impressionism/vincent-van-gogh_l-arlesienne-portrait-of-madame-ginoux-1890.jpg',
]

class VAL_SpecificArtGANDataset(Dataset):
    """Validation split of the WikiArt dataset for a single classification task.

    Reads `{task}_class.txt` (class id -> class name) and `{task}_val.csv`
    (relative image path, label) from `data_path`, drops the entries listed in
    `corrupted_images`, and rewrites every file name with the same
    sanitization that was applied to the files before they were uploaded.

    Args:
        data_path: `pathlib.Path` of the dataset root.
        task: one of `TASKS_LIST`.

    Each item is a dict with keys `'images'` (resized, normalized tensor) and
    `'labels'` (label read from the CSV). No random augmentation is applied.
    """

    # Fallback substitutions for characters that are still non-ASCII after
    # NFKD normalization. NFKD splits accented letters into a base letter plus
    # a combining mark, so the accented entries are effectively never hit (the
    # combining mark becomes '_'); 'ß' has no decomposition and maps to 'ss'.
    # The table is kept verbatim so results stay identical to the renaming
    # script that was run on the files before upload.
    _REPLACEMENTS = {
        'ä': 'a', 'ö': 'o', 'ü': 'u', 'ß': 'ss',
        'á': 'a', 'é': 'e', 'í': 'i', 'ó': 'o', 'ú': 'u',
        'à': 'a', 'è': 'e', 'ì': 'i', 'ò': 'o', 'ù': 'u',
        'â': 'a', 'ê': 'e', 'î': 'i', 'ô': 'o', 'û': 'u',
    }

    def __init__(self, data_path, task='genre'):
        super().__init__()
        assert (task in TASKS_LIST), f'Task should be either {TASKS_LIST}\n'

        self.data_path = data_path
        self.task = task
        self.classes_csv = pd.read_csv(data_path / f'{task}_class.txt', sep=" ", names=['label', 'name'])

        # --------------------------- Cleaning data
        data_csv = pd.read_csv(data_path / f'{task}_val.csv', names=['filename', 'label'])
        # Take an explicit copy of the filtered rows so the column assignment
        # below is never a chained assignment on a view.
        data_csv = data_csv.loc[~data_csv["filename"].isin(corrupted_images)].copy()
        data_csv["filename"] = data_csv["filename"].map(self._process_filename)

        self.data_csv = data_csv
        self.imgs_path = data_csv["filename"].tolist()
        self.labels = data_csv["label"].tolist()

        # --------------------------- Custom transforms for valid dataset
        self.val_transforms = transforms.Compose([
            transforms.Resize((224,224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    @staticmethod
    def _process_filename(f):
        """Return `f` ("<dir>/<file>") with the file part reduced to `[a-zA-Z0-9_.-]`.

        The directory part (everything before the first '/') is kept as is.
        In the file part, non-ASCII characters and the apostrophe are replaced
        through `_REPLACEMENTS` (default '_'); every remaining character
        outside `[a-zA-Z0-9_.-]` is then replaced by '_'.
        """
        import re
        import unicodedata

        dirname, filename = f.split('/', 1)
        normalized = unicodedata.normalize('NFKD', filename)
        pieces = []
        for char in normalized:
            if ord(char) < 128 and ord(char) != 39:  # 39 == apostrophe
                pieces.append(char)
            else:
                pieces.append(VAL_SpecificArtGANDataset._REPLACEMENTS.get(char, '_'))

        ascii_filename = re.sub(r'[^a-zA-Z0-9_\-\.]', '_', ''.join(pieces))
        return dirname + "/" + ascii_filename

    def __len__(self):
        """Number of samples left after removing the corrupted entries."""
        return len(self.imgs_path)

    def get_label_name(self, label):
        """Return the class name stored for `label` in `classes_csv`.

        `label` is converted with `int()` so tensor / NumPy scalars coming
        out of a batch can be passed directly.
        """
        return self.classes_csv['name'][int(label)]

    def transform(self, img):
        """Apply the deterministic resize + normalization pipeline to `img`."""
        img = self.val_transforms(img)
        return img

    def get_label(self, idx):
        """Return the label of sample `idx` without loading the image."""
        return self.labels[idx]

    def __getitem__(self, idx):
        """Load sample `idx`, force it to RGB, and return the transformed tensor with its label."""
        # `Image.open` is lazy: the context manager closes the file handle,
        # and `convert('RGB')` both loads the pixels and guarantees three
        # channels, which the 3-channel Normalize and batch collation require.
        with Image.open(self.data_path / self.imgs_path[idx]) as raw:
            img = raw.convert('RGB')
        label = self.labels[idx]

        img = self.transform(img)
        return {
            'images': img,
            'labels': label
        }

# Build the validation split for the 'style' task from the dataset root `path`.
test_dataset = VAL_SpecificArtGANDataset(data_path=path, task='style')

In [None]:
# Because the dataset is imbalanced, a custom sampler oversamples the minority classes.
# WeightedRandomSampler draws each sample with a weight equal to the inverse
# frequency of its class.
def oversampling_dataset(labels):
    """Build a `WeightedRandomSampler` that balances classes by inverse frequency.

    Args:
        labels: sequence of class labels, one per dataset sample, in dataset
            order. Label values do not have to be contiguous or start at 0.

    Returns:
        A `WeightedRandomSampler` that draws `len(labels)` indices per epoch
        (with replacement), where each sample's weight is `1 / count` of its
        class.
    """
    label_array = np.asarray(labels).ravel()
    # `inverse[i]` is the position of `labels[i]` inside the sorted unique
    # values, which is exactly the index into `counts`. Indexing by the raw
    # label value instead breaks as soon as a class id is missing.
    _, inverse, counts = np.unique(label_array, return_inverse=True, return_counts=True)
    sample_weights = (1.0 / counts)[inverse.ravel()]
    return WeightedRandomSampler(sample_weights.tolist(), len(sample_weights))

In [None]:
import matplotlib.pyplot as plt

sample_data = train_dataset[0]
image, label = sample_data['images'], sample_data['labels']

# The dataset returns a channel-first tensor normalized with the mean/std used
# in its transform pipeline. Undo that normalization and clip to [0, 1] before
# plotting; otherwise imshow clips the out-of-range values and the colours are
# distorted.
norm_mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
norm_std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)
image_to_show = np.clip(np.array(image) * norm_std + norm_mean, 0.0, 1.0)

plt.imshow(image_to_show.transpose(1, 2, 0))  # CHW -> HWC for matplotlib
plt.title(train_dataset.get_label_name(label))
plt.show()

## Model architecture

In [None]:
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, random_split
import torchvision.transforms.v2 as v2

batch_size = 64
num_workers = 0

# NOTE(review): this seeded generator is not consumed anywhere in this cell.
generator1 = torch.Generator().manual_seed(86)

task = 'artist'
train_dataset = TRAIN_SpecificArtGANDataset(data_path=path, task=task)
val_dataset = VAL_SpecificArtGANDataset(data_path=path, task=task)

# Training: oversample minority classes to counter the class imbalance.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    sampler=oversampling_dataset(train_dataset.labels)
)
# Validation: iterate over every sample exactly once, in order. Resampling the
# validation set with replacement would make the reported loss/accuracy (and
# the value fed to the LR scheduler) random and unrepresentative of the split.
val_dataloader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=False
)

def labels_getter(batch):
    """Return the labels from a batch passed to CutMix/MixUp (second element of the batch)."""
    return batch[1]

# Batch-level augmentations: one of CutMix / MixUp is picked at random per call.
cutmix = v2.CutMix(num_classes=len(train_dataset.classes_csv), labels_getter=labels_getter)
mixup = v2.MixUp(num_classes=len(train_dataset.classes_csv), labels_getter=labels_getter)
cutmix_or_mixup = v2.RandomChoice([cutmix, mixup])

In [None]:
from timm import create_model
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# ResNet-50 from timm with pretrained weights and a classification head sized
# to the number of classes of the current task.
model = resnet50 = create_model(
    'resnet50',
    pretrained=True,
    num_classes=len(train_dataset.classes_csv),
)
# vit_model = create_model('vit_base_patch16_224', pretrained=True, num_classes=len(train_dataset.classes_csv))
# model = vit_model
model.to(device)
print(type(model))

In [None]:
n_epochs = 3
learning_rate = 1e-4
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-5)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
# scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.66)

# Reduce the LR by 10x when the validation loss has not improved for more than
# `patience` epochs, never going below `min_lr`.
# `eps` is the minimal LR *change* that is applied: an update is skipped when
# old_lr - new_lr <= eps. The previous value of 0.05 is far larger than the
# learning rate itself (1e-4), so every reduction was silently ignored; the
# library default (1e-8) is used instead.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.1,
    patience=2,
    min_lr=1e-6,
    eps=1e-8
)

In [None]:
def evaluate(model, dataloader, criterion, device):
    """Run one full pass over `dataloader` and return `(avg_loss, accuracy)`.

    Args:
        model: network called as `model(images)` returning per-class logits.
        dataloader: `DataLoader` yielding dict batches with keys `'images'`
            and `'labels'`; its `.dataset` length sizes the progress bar.
        criterion: loss called as `criterion(logits, labels)`.
        device: device the batches are moved to.

    Returns:
        avg_loss: mean of the per-batch loss values.
        accuracy: fraction of samples whose arg-max prediction equals the label.
        Both are NaN when the dataloader yields no batches.

    The model is switched to eval mode for the pass and restored to whatever
    mode it was in on entry (train mode when called from the training loop).
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    losses = []
    try:
        with torch.no_grad():
            # Size the bar from the loader that is actually being evaluated,
            # not from a module-level dataset.
            with tqdm(total=len(dataloader.dataset), desc=f'Validating', unit='img') as pbar:
                for batch in dataloader:
                    images, labels = batch['images'], batch['labels']
                    images = images.to(device, dtype=torch.float32, memory_format=torch.channels_last)
                    labels = labels.to(device, dtype=torch.long)

                    labels_pred = model(images)

                    loss = criterion(labels_pred, labels)
                    losses.append(loss.item())

                    # Softmax is monotonic, so arg-max over logits is enough.
                    predicted = torch.argmax(labels_pred, dim=1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()
                    pbar.update(images.shape[0])
    finally:
        model.train(was_training)

    avg_loss = sum(losses) / len(losses) if losses else float('nan')
    accuracy = correct / total if total else float('nan')
    return avg_loss, accuracy

In [None]:
from tqdm import tqdm

# Counters
global_step = 0                  # optimizer steps taken so far
n_train = len(train_dataset)     # progress-bar total for one epoch

# Histories recorded once per epoch
train_losses, val_losses, val_accuracy = [], [], []

# History recorded once per optimizer step
train_losses_steps = []

# Best checkpoint tracking (by validation accuracy)
best_val_acc = 0.0
best_model = model

In [None]:
import copy  # used to snapshot the best model below

# Raise PIL's decompression-bomb limit so very large images can be opened.
Image.MAX_IMAGE_PIXELS = 1000000000
for epoch in range(1, n_epochs+1):
    model.train()
    epoch_loss = 0.0
    with tqdm(total=n_train, desc=f'Epoch {epoch}/{n_epochs}', unit='img') as pbar:
        for batch in train_dataloader:
            images, labels = batch['images'], batch['labels']
            # images, labels = cutmix_or_mixup(images, labels)
            # labels = labels.argmax(dim=1)
            images = images.to(device, dtype=torch.float32, memory_format=torch.channels_last)
            labels = labels.to(device, dtype=torch.long)

            labels_pred = model(images)
            loss = criterion(labels_pred, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Read the scalar once (each .item() forces a device sync).
            batch_loss = loss.item()
            epoch_loss += batch_loss
            # The criterion already averages over the batch, so the value is
            # stored as is; dividing by batch_size again would under-report
            # the per-step loss by that factor.
            train_losses_steps.append(batch_loss)

            global_step += 1
            pbar.update(images.shape[0])
            pbar.set_postfix(**{f'loss (batch)': batch_loss})


    train_losses.append(epoch_loss / len(train_dataloader))
    val_loss, val_acc = evaluate(model, val_dataloader, criterion, device)
    scheduler.step(val_loss)
    val_losses.append(val_loss)
    val_accuracy.append(val_acc)
    print(f'Validation loss: {val_loss}, Validation accuracy: {val_acc}')
    if val_acc > best_val_acc:
        # Snapshot the weights: `best_model = model` would only alias the live
        # model, which keeps changing in later epochs, so the "best" model
        # would always end up being the last one.
        best_model = copy.deepcopy(model)
        best_val_acc = val_acc

In [None]:
# Free memory after training: run Python's garbage collector, then ask
# PyTorch to release its cached CUDA memory blocks.
import gc
gc.collect()
torch.cuda.empty_cache()

In [None]:
# From here on, `model` refers to the model tracked as best during training.
model = best_model

## Plotting model training/validation loss and accuracy

In [None]:
import math
# Three side-by-side panels: per-step training loss, per-epoch train/val loss,
# and per-epoch validation accuracy.
fig, (ax_steps, ax_loss, ax_acc) = plt.subplots(1, 3, figsize=(12, 8))

# --- Training loss per optimizer step
step_axis = np.arange(1, len(train_losses_steps) + 1)
ax_steps.plot(step_axis, train_losses_steps, label='Training Loss (per step)')
ax_steps.set_title('Train loss (per step)')
ax_steps.set_ylabel('Loss')
ax_steps.set_xlabel('Steps')
ax_steps.set_xticks(np.arange(1, len(train_losses_steps) + 1, 100))
ax_steps.legend()
ax_steps.grid(True)

# --- Train / validation loss per epoch
# The x-axis is derived from the recorded histories rather than from
# `n_epochs`, so the plot still works after an interrupted run or when the
# training cell has been executed more than once.
ax_loss.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss', marker='o')
ax_loss.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss', marker='x')
ax_loss.set_title('Train/Valid loss (per epoch)')
ax_loss.set_ylabel('Loss')
ax_loss.set_xlabel('Epoch')
ax_loss.legend()
ax_loss.grid(True)

# --- Validation accuracy per epoch
ax_acc.plot(range(1, len(val_accuracy) + 1), val_accuracy, label='Validation Accuracy')
ax_acc.set_title('Validation accuracy (per epoch)')
ax_acc.set_ylabel('Accuracy')
ax_acc.set_xlabel('Epoch')
ax_acc.legend()
ax_acc.grid(True)

plt.show()

## Saving model

In [None]:
# state_dict = {
#     'lr': learning_rate,
#     'global_step': global_step,
#     'current_epochs': epoch,
#     'n_epochs': n_epochs,
#     'model_state_dict': model.state_dict(),
#     'optim_state_dict': optimizer.state_dict(),
#     'scheduler_state_dict': scheduler.state_dict()
# }

# torch.save(state_dict, f'/kaggle/working/wikiart_{task}_ResNet_epoch{epoch}.pt')

## Load model

In [None]:
# model_path = "Enter model path here"
# state_dict = torch.load(model_path, map_location=device)
# model.load_state_dict(state_dict['model_state_dict'])
# optimizer.load_state_dict(state_dict['optim_state_dict'])
# scheduler.load_state_dict(state_dict['scheduler_state_dict'])
# global_step = state_dict['global_step']
# epoch = state_dict['current_epochs']
# n_epochs = state_dict['n_epochs']

# Test

## Confusion matrix

In [None]:
# Use validation dataset for testing.
# Built for the same `task` the model was trained on, so the class count of
# the model head and of the dataset always agree.
test_dataset = VAL_SpecificArtGANDataset(data_path=path, task=task)

# No sampler here: every test sample must be evaluated exactly once. Drawing
# the test set with a weighted random sampler (with replacement) would make
# the confusion matrix, F1 scores and classification report depend on a random
# resample instead of the real class distribution.
test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=False
)

In [None]:
from tqdm import tqdm
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, f1_score, accuracy_score

# Results of the most recent `evaluate_test` call; read by the classification
# report cell below.
all_pred, all_label, all_label_pred = [], [], []
def evaluate_test(model, dataloader, device):
    """Evaluate `model` on `dataloader`, plot a confusion matrix and print metrics.

    Args:
        model: network called as `model(images)` returning per-class logits.
        dataloader: `DataLoader` yielding dict batches with keys `'images'`
            and `'labels'`; its dataset must expose `classes_csv` (one row
            per class).
        device: device the batches are moved to.

    Side effects:
        Overwrites the module-level `all_pred` and `all_label` (1-D arrays of
        predicted / true class ids) and `all_label_pred` (list of per-batch
        softmax probability arrays). The results are rebuilt from scratch on
        every call, so the function can be run repeatedly.
        The model is restored to the train/eval mode it had on entry.
    """
    global all_pred, all_label, all_label_pred

    # Accumulate locally; the globals are replaced only once the pass is done.
    batch_preds, batch_labels, batch_probs = [], [], []
    eval_dataset = dataloader.dataset

    was_training = model.training
    model.eval() # Set the model to evaluation mode
    try:
        with torch.no_grad():
            with tqdm(total=len(eval_dataset), desc=f'Testing', unit='img') as pbar:
                for batch in dataloader:
                    images, labels = batch['images'], batch['labels']
                    images = images.to(device, dtype=torch.float32, memory_format=torch.channels_last)
                    labels = labels.to(device, dtype=torch.long)

                    probs = torch.softmax(model(images), dim=1)
                    batch_probs.append(probs.cpu().numpy())

                    predicted = torch.argmax(probs, dim=1) # Get the index of the class with the highest probability
                    batch_preds.append(predicted.cpu().numpy())
                    batch_labels.append(labels.cpu().numpy())
                    pbar.update(images.shape[0])
    finally:
        model.train(was_training)

    all_pred = np.concatenate(batch_preds)
    all_label = np.concatenate(batch_labels)
    all_label_pred = batch_probs

    # Create confusion matrix over every class id, including classes that do
    # not occur in this pass.
    class_ids = np.arange(0, len(eval_dataset.classes_csv))
    cm = confusion_matrix(all_label, all_pred, labels=class_ids)
    cmDisplay = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_ids)
    fig, ax = plt.subplots(figsize=(15,15))
    cmDisplay.plot(ax=ax)

    # Calculate metrics
    print(f"Accuracy: {accuracy_score(all_label, all_pred, normalize=True)}")
    print(f"Micro F1: {f1_score(all_label, all_pred, average='micro')}")
    print(f"Weighted F1: {f1_score(all_label, all_pred, average='weighted')}")

evaluate_test(model, test_dataloader, device)

### Classification report

In [None]:
from sklearn.metrics import classification_report

# Class names in class-id order (row order of the class listing).
class_names = test_dataset.classes_csv['name'].tolist()
true_labels = [test_dataset.get_label_name(label) for label in test_dataset.classes_csv['label']]

# Pass the full list of class ids explicitly: without `labels`, sklearn infers
# the classes from the data and raises when a class is absent from both the
# ground truth and the predictions, because `target_names` would no longer
# match in length. The id range matches the one used for the confusion matrix.
print(classification_report(
    all_label,
    all_pred,
    labels=np.arange(len(class_names)),
    target_names=class_names,
    zero_division=0,
))

## GradCAM (for ViT)

In [None]:
# !pip install grad-cam

In [None]:
# import torch
# import numpy as np
# import cv2
# import matplotlib.pyplot as plt
# from pytorch_grad_cam import GradCAM
# from pytorch_grad_cam.utils.reshape_transforms import vit_reshape_transform

# target_labels=[1, 2, 4, 5, 6]
# examples_per_label=2
# label_examples = {label: [] for label in target_labels}

# for batch in test_dataloader:
#     imgs, labels = batch['images'], batch['labels']
#     for i, label in enumerate(labels):
#         label_item = label.item()
#         if label_item in target_labels and len(label_examples[label_item]) < examples_per_label:
#             label_examples[label_item].append(imgs[i])

#     if all(len(label_examples[label]) == examples_per_label for label in target_labels):
#         break

In [None]:
# model.eval()
# target_layers = [model.blocks[-1].norm1]
# cam = GradCAM(model=model, target_layers=target_layers, reshape_transform=vit_reshape_transform)

# for label, imgs in label_examples.items():
#     for i, img in enumerate(imgs):
#         input_tensor = img.unsqueeze(0).to(device)
#         grayscale_cam = cam(input_tensor=input_tensor, targets=None)
#         grayscale_cam = grayscale_cam[0, :]
        
#         # Convert the image tensor to a NumPy array for visualization.
#         # Assumes the image is in (C, H, W) format and normalized between 0 and 1.
#         img_np = img.cpu().numpy().transpose(1, 2, 0)
        
#         heatmap = cv2.applyColorMap(np.uint8(255 * grayscale_cam), cv2.COLORMAP_JET)
#         heatmap = np.float32(heatmap) / 255
        
#         # Overlay the heatmap on the original image.
#         overlay = heatmap + np.float32(img_np)
#         overlay = overlay / np.max(overlay)
        
#         plt.figure(figsize=(12, 4))
#         plt.subplot(1, 3, 1)
#         plt.imshow(img_np)
#         plt.title("Original Image")
#         plt.axis("off")
        
#         plt.subplot(1, 3, 2)
#         plt.imshow(grayscale_cam, cmap="jet")
#         plt.title("GradCAM Heatmap")
#         plt.axis("off")
        
#         plt.subplot(1, 3, 3)
#         plt.imshow(overlay)
#         plt.title("Overlay")
#         plt.axis("off")
        
#         plt.suptitle(f"Label: {label} | Example: {i+1}", fontsize=14)
#         plt.tight_layout()
#         plt.show()
