In [10]:
import numpy as np
import matplotlib.pyplot as plt

from utils import *

import torchvision.models as models
from torchvision import transforms
from torch import nn
from PIL import Image


In [3]:
# Root directory of the cropped/resized dataset; split sub-directories
# (e.g. "train") are appended to it further down. The path is relative to the
# notebook's current working directory.
data_dir = "../Data_Cropped_and_Resized"

### Feature Extraction

In [11]:
def get_resnet_features(img, model, device="cpu"):
    """Extract a pooled ResNet feature vector for a single image.

    The model's last child module (the classification head) is dropped, and the
    remaining trunk is run on the image after standard ImageNet preprocessing:
    resize to 256, center-crop to 224, and normalize with ImageNet mean/std.

    Args:
        img: image as a NumPy array. If its maximum is above 1 it is treated as
            already on a 0-255 scale (values are clipped to [0, 255]). Otherwise
            it is treated as a [0, 1] image and scaled by 255. Single-channel
            input is converted to 3-channel RGB.
        model: a torchvision ResNet-style ``nn.Module`` whose final child is the
            classifier. Its layers are put into eval mode, because the trunk
            built here shares the same module objects.
        device: device to run inference on. Defaults to ``"cpu"``.

    Returns:
        np.ndarray: trunk output with singleton dimensions squeezed away
        (presumably shape (2048,) for ResNet-50/101/152; TODO confirm).
    """
    import torch  # function-local: the notebook only imports ``nn`` from torch

    # nn.Sequential wraps the model's existing layers (no copy), minus the fc head.
    model_conv_features = nn.Sequential(*list(model.children())[:-1]).to(device)
    # Without eval(), BatchNorm normalizes with per-image statistics and also
    # overwrites its stored running mean/var on every call.
    model_conv_features.eval()

    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    img = np.asarray(img)
    if np.max(img) > 1:
        # Already 0-255: clip so out-of-range floats don't wrap around in uint8.
        img = np.clip(img, 0, 255).astype(np.uint8)
    else:
        img = (np.clip(img, 0, 1) * 255.0).astype(np.uint8)
    img = Image.fromarray(img).convert('RGB')
    batch = preprocess(img).unsqueeze(0).to(device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        features = model_conv_features(batch)

    return features.squeeze().cpu().numpy()

def extract_features(split_path, feature_func, kwargs=None):
    """Load every image of one dataset split and compute a feature for each.

    Expects the layout ``split_path/<class_name>/<image files>`` for every
    ``(label, class_name)`` pair in the global ``class_mappings``. Each image is
    read as grayscale with OpenCV, passed through ``preprocess_image``, and a
    copy is handed to ``feature_func``. Files OpenCV cannot read are skipped
    with a warning.

    Args:
        split_path: directory of one split, e.g. ``f"{data_dir}/train"``.
        feature_func: callable invoked as ``feature_func(image, **kwargs)`` on
            each preprocessed image.
        kwargs: optional dict of extra keyword arguments for ``feature_func``.
            ``None`` means no extra arguments.

    Returns:
        tuple ``(images, features, labels)`` of parallel lists: the preprocessed
        images, their features, and the matching key from ``class_mappings``.
    """
    import warnings

    if kwargs is None:
        kwargs = {}  # allow calling without extra feature_func arguments

    images, features, labels = [], [], []

    for label, class_name in class_mappings.items():
        class_path = os.path.join(split_path, class_name)
        # os.listdir order is arbitrary; sorting makes results reproducible.
        for img_name in sorted(os.listdir(class_path)):
            img_path = os.path.join(class_path, img_name)
            image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
            if image is None:
                # cv2.imread returns None for unreadable or non-image files
                # (e.g. .DS_Store); skip them instead of crashing downstream.
                warnings.warn(f"Skipping unreadable image: {img_path}")
                continue
            image = preprocess_image(image)
            feat = feature_func(image.copy(), **kwargs)
            images.append(image)
            features.append(feat)
            labels.append(label)

    return images, features, labels


In [12]:
train_path = f"{data_dir}/train"

# ImageNet-pretrained ResNet-101 (same weights as the deprecated
# `pretrained=True`; the `weights=` API needs torchvision >= 0.13).
# eval() makes BatchNorm use its stored running statistics during extraction.
resnet101 = models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V1).eval()
train_imgs, train_resnet_feat, train_labels = extract_features(
    train_path, get_resnet_features, {"model": resnet101}
)

# The three returned lists must stay aligned one-to-one.
assert len(train_imgs) == len(train_resnet_feat) == len(train_labels)


