In [12]:
import os
from tqdm import tqdm
import pandas as pd
import numpy as np
from PIL import Image
import torch
from transformers import ViTModel, ViTImageProcessor, AutoModel, CLIPImageProcessor
import timm
import open_clip
from torchvision import transforms
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform


# Compute device shared by every model and tensor in this script:
# the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


def load_models(model_name):
    """Load a pretrained backbone and its matching image preprocessor.

    Args:
        model_name: Short identifier of the backbone. Accepted values are the
            keys of the lookup tables below, ``'openclip_vitg14'``, or one of
            the timm names listed in ``timm_native``.

    Returns:
        A ``(model, feature_extractor)`` tuple. The model is switched to eval
        mode and moved to the module-level ``device``. The feature extractor
        is a Hugging Face image processor for the ViT entries and a
        transform callable for the open_clip and timm entries.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    # Hugging Face checkpoints of the ViT family, keyed by short name.
    hf_checkpoints = {
        'ViT_huge': 'google/vit-huge-patch14-224-in21k',
        'ViT_large': 'google/vit-large-patch16-224-in21k',
        'ViT_base': 'google/vit-base-patch16-224-in21k',
        'ViT_small': 'WinKawaks/vit-small-patch16-224',
    }

    # Short aliases that stand for a differently named timm architecture.
    timm_aliases = {
        'VITAmin': 'vitamin_large_384',
        'mambaout': 'mambaout_base_plus_rw.sw_e150_r384_in12k_ft_in1k',
    }

    # Names handed to timm verbatim; make sure they match identifiers that
    # are actually available in the installed timm version.
    timm_native = (
        'resnet18', 'squeezenet1_0', 'resnet50', 'resnet101', 'efficientnet_b0',
        'inception_resnet_v2', 'nasnetalarge', 'inception_v3', 'xception', 'darknet53',
        'vit_so400m_patch14_siglip_378.webli_ft_in1k',
        'mobilenetv4_conv_aa_large.e230_r448_in12k_ft_in1k',
        'mobilenetv4_hybrid_large.ix_e600_r384_in1k',
        'convnextv2_huge.fcmae_ft_in22k_in1k_384',
    )

    # Disabled experiments, kept for reference:
    #   * 'VITAmin' through Hugging Face:
    #       AutoModel.from_pretrained('jienengchen/ViTamin-XL-384px', trust_remote_code=True)
    #       CLIPImageProcessor.from_pretrained('jienengchen/ViTamin-XL-384px')
    #   * 'UNI' (MahmoodLab/UNI, ViT-L/16 via DINOv2, pretrained on Mass-100K);
    #     requires huggingface_hub.login() with HF_TOKEN first:
    #       timm.create_model('hf-hub:MahmoodLab/uni', pretrained=True,
    #                         init_values=1e-5, dynamic_img_size=True)
    #       create_transform(**resolve_data_config(model.pretrained_cfg, model=model))
    #   * 'prov_gigapath' (prov-gigapath tile encoder):
    #       timm.create_model("hf_hub:prov-gigapath/prov-gigapath", pretrained=True)
    #       transforms.Compose([
    #           transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
    #           transforms.CenterCrop(224),
    #           transforms.ToTensor(),
    #           transforms.Normalize(mean=(0.485, 0.456, 0.406),
    #                                std=(0.229, 0.224, 0.225)),
    #       ])

    if model_name in hf_checkpoints:
        checkpoint = hf_checkpoints[model_name]
        # The small checkpoint is loaded through AutoModel; the Google
        # checkpoints use ViTModel directly.
        backbone_cls = AutoModel if model_name == 'ViT_small' else ViTModel
        model = backbone_cls.from_pretrained(checkpoint)
        feature_extractor = ViTImageProcessor.from_pretrained(checkpoint)
    elif model_name == 'openclip_vitg14':
        # The second returned element is discarded; the third is used as the
        # preprocessing transform.
        model, _, feature_extractor = open_clip.create_model_and_transforms(
            'ViT-g-14', pretrained='laion2b_s12b_b42k'
        )
    elif model_name in timm_aliases or model_name in timm_native:
        timm_id = timm_aliases.get(model_name, model_name)
        # num_classes=0 is requested for every timm backbone.
        model = timm.create_model(timm_id, pretrained=True, num_classes=0)
        data_config = timm.data.resolve_model_data_config(model)
        feature_extractor = timm.data.create_transform(**data_config, is_training=False)
    else:
        raise ValueError(f"Modelo {model_name} não suportado.")

    model.eval()
    return model.to(device), feature_extractor


def feature_extraction(image_path, model, feature_extractor, model_name):
    """Extract a flat feature vector from a single image.

    Args:
        image_path: Path of the image file to read.
        model: Backbone returned by ``load_models``.
        feature_extractor: Preprocessor returned by ``load_models`` for the
            same ``model_name``.
        model_name: Short model identifier; it selects how the model is
            called (Hugging Face ViT, open_clip, or timm-style).

    Returns:
        A 1-D NumPy array with the features of the image.

    Raises:
        ValueError: If a transform on the timm-style path does not return a
            tensor.
    """
    # Open inside a context manager so the file handle is released as soon
    # as the pixels are decoded; convert() returns an in-memory copy that
    # stays valid after the file is closed.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')

    # Hugging Face ViT models: the processor builds a dict of tensors.
    if model_name in ('ViT_huge', 'ViT_large', 'ViT_base', 'ViT_small'):
        inputs = feature_extractor(images=image, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = model(**inputs)
        # Average the last hidden state over dim 1 (presumably the token
        # axis, so this pools all token embeddings -- TODO confirm).
        feats = outputs.last_hidden_state.mean(dim=1)

    # open_clip models expose a dedicated image encoder.
    elif 'openclip' in model_name:
        image_tensor = feature_extractor(image).unsqueeze(0).to(device)
        with torch.no_grad():
            feats = model.encode_image(image_tensor)

    # Everything else is treated as a timm-style model with a transform that
    # maps a PIL image to a tensor.
    else:
        input_tensor = feature_extractor(image)
        if not torch.is_tensor(input_tensor):
            raise ValueError("A transformação do timm não retornou um tensor.")
        if input_tensor.ndim == 3:
            # Add the batch dimension expected by the model.
            input_tensor = input_tensor.unsqueeze(0)
        input_tensor = input_tensor.to(device)
        with torch.no_grad():
            # Only take the forward_features/forward_head route when BOTH
            # methods exist; otherwise fall back to a plain forward pass
            # instead of failing with AttributeError.
            if hasattr(model, 'forward_features') and hasattr(model, 'forward_head'):
                feats = model.forward_features(input_tensor)
                feats = model.forward_head(feats, pre_logits=True)
            else:
                feats = model(input_tensor)

    # reshape(-1) always yields a 1-D vector. A bare squeeze() would give a
    # 0-d array for a single feature or keep extra dimensions for unpooled
    # outputs, both of which break the per-feature columns built downstream.
    return feats.reshape(-1).cpu().numpy()


def features_to_df(folder_path, model, feature_extractor, model_name):
    """Extract features for every image one level below ``folder_path``.

    The function walks each immediate subfolder of ``folder_path``
    (presumably one subfolder per class -- verify against the dataset) and
    runs ``feature_extraction`` on every file with an image extension.
    Files directly inside ``folder_path`` are ignored.

    Args:
        folder_path: Root directory that contains the image subfolders.
        model: Backbone passed through to ``feature_extraction``.
        feature_extractor: Preprocessor passed through to
            ``feature_extraction``.
        model_name: Short model identifier passed through to
            ``feature_extraction``.

    Returns:
        A DataFrame with an ``image_path`` column followed by
        ``feature_0 .. feature_{n-1}``. When no image could be processed the
        DataFrame is empty and only has the ``image_path`` column.
    """
    # Every suffix carries its dot so names such as "foopng" are not
    # mistaken for images.
    image_suffixes = ('.png', '.jpg', '.jpeg', '.bmp')

    data = []
    # os.listdir order is arbitrary; sorting makes the row order of the
    # resulting table reproducible across runs and machines.
    for subfolder in sorted(os.listdir(folder_path)):
        subfolder_path = os.path.join(folder_path, subfolder)
        if not os.path.isdir(subfolder_path):
            continue
        for image_file in tqdm(sorted(os.listdir(subfolder_path)), desc=f"Processando {subfolder}"):
            if not image_file.lower().endswith(image_suffixes):
                continue
            image_path = os.path.join(subfolder_path, image_file)
            try:
                feats = feature_extraction(image_path, model, feature_extractor, model_name)
                data.append([image_path, *feats])
            except Exception as e:
                # Deliberately best-effort: one unreadable image must not
                # abort the whole run, so report it and move on.
                print(f"Erro ao processar a imagem {image_path}: {e}")

    # The first row fixes the feature count used for the column names.
    num_features = len(data[0]) - 1 if data else 0
    columns = ['image_path'] + [f'feature_{i}' for i in range(num_features)]
    return pd.DataFrame(data, columns=columns)


def save_dataframe_to_csv(df, save_path, file_name, model_name):
    """Write ``df`` to ``<save_path>/<file_name>.csv`` without the index.

    Args:
        df: DataFrame to persist.
        save_path: Destination directory; created (with parents) if missing.
        file_name: File name without the ``.csv`` extension.
        model_name: Model identifier, used only in the confirmation message.
    """
    # exist_ok=True replaces the check-then-create pattern, which could fail
    # if the directory appeared between the check and the creation.
    os.makedirs(save_path, exist_ok=True)
    full_path = os.path.join(save_path, file_name + '.csv')
    df.to_csv(full_path, index=False)
    print(f"Arquivos do modelo {model_name} salvos em: {full_path}")


if __name__ == "__main__":
    # Dataset root (one subfolder per group of images) and output directory.
    fonte_folder = 'D:/trabFinalIA/simpsons/Train'
    salvar_folder = 'D://trabFinalIA'

    # Locations used in earlier runs, kept for reference:
    #fonte_folder = '/home/diego/Documents/posdoc_2025/database/FMD/images/'
    #fonte_folder = '/home/diego/code/features/posdoc_2025/FMD/'
    #salvar_folder = '/home/diego/code/features/posdoc_2025/feature_vector/FMD/'
    #fonte_folder = 'home/diego/code/features/SOYPR/Fold1/'
    #salvar_folder ='home/diego/code/features/posdoc_2025/python_code/experimentosSOYPR/'

    # Backbones to run. Uncomment entries below to extract features with
    # additional models; each one produces its own CSV file.
    model_choices = [
        'ViT_large',
        #'ViT_huge', 'ViT_large', 'ViT_base', 'ViT_small', 'VITAmin',
        #'mambaout', 'openclip_vitg14',
        #'resnet18', 'squeezenet1_0', 'resnet50', 'resnet101', 'efficientnet_b0',
        #'inception_resnet_v2', 'nasnetalarge', 'inception_v3', 'xception', 'darknet53',
        #'vit_so400m_patch14_siglip_378.webli_ft_in1k',
        #'mobilenetv4_conv_aa_large.e230_r448_in12k_ft_in1k',
        #'mobilenetv4_hybrid_large.ix_e600_r384_in1k', 'convnextv2_huge.fcmae_ft_in22k_in1k_384',
    ]

    for model_choice in model_choices:
        # A failure with one backbone is reported and the loop moves on to
        # the next one instead of aborting the whole batch.
        try:
            print(f"Processando com o modelo: {model_choice}")
            model, feat_ext = load_models(model_choice)
            df = features_to_df(fonte_folder, model, feat_ext, model_choice)
            file_name = f'result_final_{model_choice}'
            save_dataframe_to_csv(df, salvar_folder, file_name, model_choice)
        except Exception as e:
            print(f"Erro ao processar o modelo {model_choice}: {e}")


Processando com o modelo: ViT_large


Processando bart: 100%|██████████| 78/78 [00:29<00:00,  2.68it/s]
Processando homer: 100%|██████████| 61/61 [00:20<00:00,  2.96it/s]
Processando lisa: 100%|██████████| 33/33 [00:10<00:00,  3.09it/s]
Processando maggie: 100%|██████████| 30/30 [00:09<00:00,  3.09it/s]
Processando marge: 100%|██████████| 24/24 [00:09<00:00,  2.62it/s]


Arquivos do modelo ViT_large salvos em: D://trabFinalIA\result_final_ViT_large.csv
