# AlexNet-based feature extractor

In [None]:
# imports
import warnings; warnings.filterwarnings('ignore')
import os
import torch
from multiprocessing import Pool
from torchvision.models.feature_extraction import create_feature_extractor
from sklearn.decomposition import PCA
import numpy as np
from tqdm import tqdm
from PIL import Image
from src.utils import DATA_DIR

In [None]:
# get data and model
subjs = ['subj01', 'subj02', 'subj03', 'subj04', 'subj05', 'subj06', 'subj07', 'subj08']
N_SAMPLES = 0  # images to load per subject; 0 means "use all images"
# Without `pretrained=True` the v0.10.0 hub entrypoint builds AlexNet with random
# weights; the ImageNet mean/std normalization applied to the images below is
# only meaningful for the ImageNet-pretrained network.
model = torch.hub.load('pytorch/vision:v0.10.0', 'alexnet', pretrained=True)
# Inference only: switch to eval mode before building the extractor so it
# starts in the same mode (affects layers such as dropout).
model.eval()
# Return the activations of node "features.2" (in torchvision's AlexNet this is
# presumably the first max-pool of the conv stack -- verify against model.features).
feature_extractor = create_feature_extractor(model, return_nodes=["features.2"])

## Get image data

In [None]:
def get_img_files(subj):
    """Return the sorted paths of all ``.png`` files in a subject's training image folder.

    The folder searched is ``<DATA_DIR>/<subj>/training_split/training_images``.
    Only entries whose name ends with ``.png`` are kept, and the full paths are
    sorted so the image order is deterministic.
    """
    img_dir = os.path.join(DATA_DIR, subj, 'training_split/training_images')
    png_names = (name for name in os.listdir(img_dir) if name.endswith('.png'))
    return sorted(os.path.join(img_dir, name) for name in png_names)

def load_img_files(subj):
    """Load, resize and normalize a subject's training images.

    Files come from ``get_img_files(subj)`` (sorted order) and are truncated to
    the first ``N_SAMPLES`` when ``N_SAMPLES`` is non-zero. Each image is
    converted to RGB, resized to 224x224, scaled to [0, 1], moved to
    channels-first layout and passed through ``normalize``.

    Returns:
        Float tensor of shape (num_images, 3, 224, 224).
    """
    paths = get_img_files(subj)
    if N_SAMPLES:
        paths = paths[:N_SAMPLES]

    def _read(path):
        # Open and close each file individually so we never hold many handles at once.
        with Image.open(path) as im:
            rgb = im.convert('RGB').resize((224, 224))
            return torch.from_numpy(np.array(rgb))

    batch = torch.stack([_read(p) for p in tqdm(paths)])  # (N, H, W, C)
    batch = (batch / 255.0).permute(0, 3, 1, 2)  # -> (N, C, H, W) in [0, 1]
    return normalize(batch)

def normalize(imgs, means=(0.485, 0.456, 0.406), stds=(0.229, 0.224, 0.225)):
    """Standardize a channels-first image tensor channel by channel.

    Channel ``c`` is mapped to ``(x - means[c]) / stds[c]``. The defaults are the
    ImageNet mean/std commonly used with torchvision's pretrained models.

    Args:
        imgs: channels-first tensor, ``(N, C, H, W)`` or ``(C, H, W)``; it is
            cast to float32 first.
        means: one mean per channel.
        stds: one standard deviation per channel.

    Returns:
        A new float32 tensor with the same shape as ``imgs``. The input tensor is
        never modified.
    """
    imgs = imgs.float()
    # Shape (C, 1, 1) broadcasts over the channel axis of both (N, C, H, W) and
    # (C, H, W). Out-of-place arithmetic avoids writing into the caller's tensor,
    # which `.float()` returns unchanged when it is already float32.
    mean = torch.as_tensor(means, dtype=imgs.dtype, device=imgs.device).view(-1, 1, 1)
    std = torch.as_tensor(stds, dtype=imgs.dtype, device=imgs.device).view(-1, 1, 1)
    return (imgs - mean) / std


In [None]:
def run_subj(subj, *, batch_size=256, n_components=100):
    """Extract AlexNet features for one subject, reduce them with PCA and save them.

    Images come from ``load_img_files(subj)``. Every returned node of
    ``feature_extractor`` is flattened per image and concatenated. PCA is fit on
    the resulting (num_images, num_features) matrix, and the transformed array is
    written to ``<DATA_DIR>/<subj>/training_split/alexnet_pca.npy``.

    Args:
        subj: subject identifier, e.g. ``'subj01'``.
        batch_size: number of images per forward pass (bounds peak memory).
        n_components: requested PCA dimensionality. It is capped at
            ``min(num_images, num_features)`` because scikit-learn's PCA rejects
            larger values, e.g. when ``N_SAMPLES`` limits the image count.
    """
    data = load_img_files(subj)

    # The features are never backpropagated, so disable autograd; otherwise
    # every intermediate activation is kept alive for the whole dataset.
    # Chunking further bounds the activation memory of each forward pass.
    chunks = []
    with torch.no_grad():
        for batch in torch.split(data, batch_size):
            out = feature_extractor(batch)
            chunks.append(torch.hstack([torch.flatten(t, start_dim=1) for t in out.values()]))
    del data  # free the image tensor before the memory-hungry PCA fit

    feats = torch.cat(chunks).numpy()  # already 2-D: (num_images, num_features)
    del chunks

    pca = PCA(n_components=min(n_components, *feats.shape))
    feats = pca.fit_transform(feats)
    np.save(os.path.join(DATA_DIR, subj, 'training_split', 'alexnet_pca.npy'), feats)



In [None]:
# Parallel alternative over all subjects (currently disabled):
# with Pool(2) as p:
#    p.map(run_subj, subjs)
# Run the last 4 subjects (subj05-subj08) sequentially, one after another.
for subj in subjs[4:]:
    print(f'running {subj}')
    run_subj(subj)
    print()