# Imports

In [None]:
import torch
import torch.nn as nn
from transformers import ViTModel, ViTFeatureExtractor
import joblib
from PIL import Image
from torchvision import transforms

# Define a classe do modelo

In [None]:
class ViTMultilabel(nn.Module):
    """ViT backbone with two classification heads that share one feature vector.

    Each sample is a group of images. Every image of the group is encoded by
    the ViT backbone, the pooled outputs of the group are averaged, and the
    averaged vector feeds two independent linear heads: one producing region
    logits and one producing species logits.

    Args:
        num_regioes: Number of output classes of the region head.
        num_especies: Number of output classes of the species head.
    """

    def __init__(self, num_regioes, num_especies):
        super().__init__()
        # Pretrained backbone; both heads take its hidden size as input width.
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224')
        hidden_size = self.vit.config.hidden_size
        self.classifier_regiao = nn.Linear(hidden_size, num_regioes)
        self.classifier_especie = nn.Linear(hidden_size, num_especies)

    def forward(self, x):
        """Compute region and species logits for a batch of image groups.

        Args:
            x: Tensor of shape (batch_size, n_imgs, C, H, W).

        Returns:
            Tuple ``(regiao_logits, especie_logits)`` with shapes
            (batch_size, num_regioes) and (batch_size, num_especies).

        Raises:
            ValueError: If ``x`` is not 5-dimensional.
        """
        if x.dim() != 5:
            raise ValueError(
                "expected input of shape (batch_size, n_imgs, C, H, W), "
                f"got {tuple(x.shape)}"
            )
        batch_size, n_imgs, C, H, W = x.shape
        # Fold the image axis into the batch axis so the backbone sees plain
        # 4-D input. reshape (unlike view) also accepts non-contiguous tensors,
        # e.g. the result of a transpose/permute on a stacked batch.
        x = x.reshape(batch_size * n_imgs, C, H, W)
        outputs = self.vit(pixel_values=x)
        pooled_output = outputs.pooler_output
        # Restore the (batch, image) axes and average over the images of each
        # sample to get a single feature vector per sample.
        pooled_output = pooled_output.reshape(batch_size, n_imgs, pooled_output.shape[-1])
        agg_output = pooled_output.mean(dim=1)
        regiao_logits = self.classifier_regiao(agg_output)
        especie_logits = self.classifier_especie(agg_output)
        return regiao_logits, especie_logits


  from .autonotebook import tqdm as notebook_tqdm


# Carrega os encoders

In [None]:
# Load the fitted label encoders for the region and species targets.
# NOTE: joblib files are pickles; only load files from a trusted source.
le_regiao, le_especie = (
    joblib.load(encoder_file)
    for encoder_file in ('le_regiao.pkl', 'le_especie.pkl')
)

# Instancia e carrega o modelo


In [None]:
# Build the architecture with one output per class known to each encoder,
# so the head sizes match the label encoders loaded above.
model = ViTMultilabel(
    num_regioes=len(le_regiao.classes_),
    num_especies=len(le_especie.classes_),
)

# map_location='cpu' lets a checkpoint saved from a GPU load on a machine
# without CUDA; the model is moved to the target device in a later cell.
# weights_only=True restricts unpickling to tensors/state-dict containers.
state_dict = torch.load(
    'vit_multilabel_model.pth',
    map_location='cpu',
    weights_only=True,
)
model.load_state_dict(state_dict)

# Switch to inference mode. The trailing semicolon keeps the notebook from
# printing the full module repr as the cell output.
model.eval();


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViTMultilabel(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
            (intermediate_act_fn)

# Testando a previsão

Setup


In [None]:
# Prefer the GPU when one is available, otherwise run on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

# Preprocessor for the same checkpoint the backbone was initialized from.
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')

# Move the model parameters to the selected device.
model = model.to(device)

Carrega e prepara a imagem


In [None]:
# Path of the new image to classify (fill this in before running the cell).
image_path = r''
if not image_path:
    # Fail with an explicit message instead of Pillow's error for an empty path.
    raise FileNotFoundError("image_path is empty: set it to the image to classify")

# The context manager closes the underlying file handle; convert() returns an
# independent in-memory copy, so `image` stays usable after the file is closed.
with Image.open(image_path) as source_image:
    image = source_image.convert('RGB')

inputs = feature_extractor(images=image, return_tensors="pt")
# unsqueeze(1) inserts the n_imgs axis the model's forward unpacks as its
# second dimension (batch_size, n_imgs, C, H, W); then move to the model's device.
pixel_values = inputs['pixel_values'].unsqueeze(1).to(device)


Predição

In [None]:

# Inference only: run the forward pass with gradient tracking disabled.
with torch.no_grad():
    regiao_logits, especie_logits = model(pixel_values)

# Highest-scoring class index of each head as a plain Python int.
# .item() requires a single element, i.e. a batch containing one sample.
regiao_pred = regiao_logits.argmax(dim=1).item()
especie_pred = especie_logits.argmax(dim=1).item()




Região: ['CORPO_INTEIRO'], Espécie: AVE


Converte para labels legíveis

In [None]:
# Map each predicted class index back to the label its encoder was fitted on.
regiao_label, especie_label = (
    encoder.inverse_transform([class_index])[0]
    for encoder, class_index in (
        (le_regiao, regiao_pred),
        (le_especie, especie_pred),
    )
)

print(f"Região: {regiao_label}, Espécie: {especie_label}")