# Feature Extraction (ResNet-34)

This notebook loads the prepared WikiArt dataset, builds a ResNet-34 feature extractor,
and saves a mapping from image path → feature vector.

- **Input:** `../wikiart_full.pkl` (from `example_dataset.ipynb`)
- **Backbone:** ResNet-34 (ImageNet pretrained)
- **Output:** `../image2features.pkl` (dict: `{image_path: np.ndarray(features)}`)

In [None]:
import numpy as np
import os
import pickle
import random
import torch

from torchvision import models, transforms
from tqdm.notebook import tqdm
from PIL import Image

In [None]:
class load_dataset(torch.utils.data.Dataset):
    """Dataset yielding ``(image_path, image)`` pairs for a list of image paths.

    ``list_`` is a container whose element ``0`` is the sequence of image
    paths; only that first element is read by this class.
    """

    def __init__(self, list_, transform=None):
        """Store the path container and the optional per-image transform.

        Args:
            list_: Sequence whose first element is the sequence of image paths.
            transform: Optional callable applied to every loaded PIL image.
        """
        self.list_ = list_
        self.transform = transform

    def __getitem__(self, index):
        """Load the image at ``index`` and return ``(path, image)``.

        The image is converted to RGB so that grayscale, palette, RGBA or
        CMYK files all produce three channels; otherwise a 3-channel
        normalisation and batch collation fail on such files.
        """
        path = self.list_[0][index]

        # Image.open is lazy and keeps the file open; the context manager
        # releases the handle, and convert() decodes the pixels into a new
        # image that stays usable after the file is closed.
        with Image.open(path) as handle:
            img = handle.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return path, img

    def __len__(self):
        """Return the number of image paths."""
        return len(self.list_[0])
    
    
def set_seed(seed):
    """Seed every random number generator used here for reproducible runs.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    exports ``PYTHONHASHSEED`` and switches cuDNN to deterministic mode with
    auto-tuning disabled.

    Args:
        seed: Integer seed shared by all generators.
    """
    # Exported for completeness; the running interpreter's hash seed is
    # already fixed at startup, so this matters for child processes.
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Python and NumPy generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators, CPU and every CUDA device.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Trade cuDNN speed for repeatable kernel selection.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
    
    
def create_dataset(data, *, batch_size=128, num_workers=0):
    """Build one non-shuffled DataLoader per split ('train', 'val', 'test').

    All three splits use the same deterministic preprocessing (resize,
    centre crop, tensor conversion, channel normalisation), so features are
    extracted under identical conditions regardless of split.

    Args:
        data: Mapping with 'train', 'val' and 'test' entries; each entry is
            an iterable of records whose first element is the image path.
        batch_size: Number of images per batch (default 128).
        num_workers: Number of DataLoader worker processes (default 0,
            i.e. loading happens in the main process).

    Returns:
        dict mapping split name to a ``torch.utils.data.DataLoader`` whose
        batches are ``(paths, image_tensor)`` pairs.

    Raises:
        KeyError: If one of the three splits is missing from ``data``.
    """

    def build_transform():
        # Mean/std are the per-channel statistics commonly used with
        # ImageNet-pretrained torchvision models.
        return transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    dataloaders_dict = {}
    for split in ('train', 'val', 'test'):
        paths = [record[0] for record in data[split]]
        # load_dataset expects the path list wrapped in an outer list.
        dataset = load_dataset(list_=[paths], transform=build_transform())
        # shuffle=False keeps the iteration order aligned with the input order.
        dataloaders_dict[split] = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
        )

    return dataloaders_dict

In [None]:
set_seed(42)

# Load data
# NOTE(review): the notebook header lists the input as `../wikiart_full.pkl`,
# while this path is relative to the current directory - confirm which one
# is intended.
data_path = 'wikiart_full.pkl'
# Use a context manager so the file handle is closed once loading finishes.
# NOTE: unpickling can execute arbitrary code; only load files you trust.
with open(data_path, 'rb') as data_file:
    data = pickle.load(data_file)
dataloader = create_dataset(data)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Load model
# Pretrained ResNet-34 used purely as a fixed feature extractor, so every
# parameter is frozen.
model_conv = models.resnet34(pretrained=True)
model_conv.requires_grad_(False)

# Keep every child module except the last one (for torchvision's ResNet this
# is the final fully connected classification layer).
backbone_layers = list(model_conv.children())[:-1]
feature_extractor = torch.nn.Sequential(*backbone_layers)

# Replace the ReLU of the last block in the last residual stage with a no-op.
# NOTE(review): torchvision's BasicBlock appears to reuse one `relu` module
# for both of its activations, so this likely disables both - confirm that
# this is the intended behaviour.
last_block = feature_extractor[-2][-1]
last_block.relu = torch.nn.Identity()

feature_extractor = feature_extractor.to(device)

In [None]:
# Feature extraction
# eval() puts the model's layers into inference mode.
feature_extractor.eval()

# Maps image path -> 1-D NumPy feature vector.
image2features = {}

for split, loader in dataloader.items():

    # Each batch is (paths, image tensor) as produced by load_dataset.
    for paths, inputs in tqdm(loader, desc=split):

        inputs = inputs.to(device)

        # No gradients are needed for pure feature extraction.
        with torch.no_grad():
            output = feature_extractor(inputs)

        # Flatten each sample to a vector and move the whole batch to the
        # host in one transfer instead of one transfer per image.
        features = output.flatten(start_dim=1).detach().cpu().numpy()

        for path, feature in zip(paths, features):
            image2features[path] = feature

In [None]:
# Persist the image path -> feature vector mapping as a pickle file so it can
# be reused later without re-running the extractor.
output_path = '../image2features.pkl'
with open(output_path, 'wb') as out_file:
    pickle.dump(image2features, out_file)