# Imports

In [None]:
import os
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import joblib
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from PIL import Image
from transformers import ViTModel, ViTFeatureExtractor


In [None]:
df = pd.read_excel('Dados para estudo.xlsx')




In [None]:
df['image_paths'] = df['Imagens'].apply(
    lambda x: [os.path.join(r'..\data\extracted_images', img.strip()) for img in x.split(',')]
)

VIT

In [None]:
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")

Codificação das labels


In [None]:
le_regiao = LabelEncoder()
le_especie = LabelEncoder()

df['regiao_label'] = le_regiao.fit_transform(df['regiao_corpo'])
df['especie_label'] = le_especie.fit_transform(df['Espécie'])

train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)

# Configurando Dataset

In [None]:
class MultilabelDataset(Dataset):
    def __init__(self, dataframe, feature_extractor, transform=None, max_images=10):
        self.data = dataframe
        self.feature_extractor = feature_extractor
        self.transform = transform
        self.max_images = max_images

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        image_paths = self.data.iloc[idx]['image_paths']
        
        images = []
        for img_path in image_paths[:self.max_images]:  # Truncar se exceder
            image = Image.open(img_path).convert('RGB')
            if self.transform:
                image = self.transform(image)
            else:
                image = self.feature_extractor(images=image, return_tensors="pt")['pixel_values'].squeeze()
            images.append(image)
        
        n_imgs = len(images)

        # Padding com imagens "em branco"
        if n_imgs < self.max_images:
            padding = self.max_images - n_imgs
            pad_image = torch.zeros_like(images[0])
            images += [pad_image] * padding
        
        images = torch.stack(images)

        regiao = self.data.iloc[idx]['regiao_label']
        especie = self.data.iloc[idx]['especie_label']
        
        return images, torch.tensor(regiao), torch.tensor(especie)


# Treinando DataSet

In [4]:
train_dataset = MultilabelDataset(train_df, feature_extractor)
val_dataset = MultilabelDataset(val_df, feature_extractor)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=4)


# Multilabel

In [None]:
class ViTMultilabel(nn.Module):
    def __init__(self, num_regioes, num_especies):
        super(ViTMultilabel, self).__init__()
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224')
        hidden_size = self.vit.config.hidden_size
        self.classifier_regiao = nn.Linear(hidden_size, num_regioes)
        self.classifier_especie = nn.Linear(hidden_size, num_especies)
    
    def forward(self, x):
        batch_size, n_imgs, C, H, W = x.shape
        x = x.view(-1, C, H, W) 
        
        outputs = self.vit(pixel_values=x)
        pooled_output = outputs.pooler_output
        
        # Volta a separar por exame
        pooled_output = pooled_output.view(batch_size, n_imgs, -1) 
        
        # Agrega: média das features
        agg_output = pooled_output.mean(dim=1)  
        
        regiao_logits = self.classifier_regiao(agg_output)
        especie_logits = self.classifier_especie(agg_output)
        
        return regiao_logits, especie_logits


Device

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:

model = ViTMultilabel(
    num_regioes=len(le_regiao.classes_), 
    num_especies=len(le_especie.classes_)
).to(device)

Optimizer

In [None]:
criterion = nn.CrossEntropyLoss()

In [None]:

optimizer = optim.AdamW(model.parameters(), lr=5e-5)


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


# Treino de Modelo

In [7]:
epochs = 10

for epoch in range(epochs):
    model.train()
    total_loss = 0

    for imgs, regiao_labels, especie_labels in train_loader:
        imgs = imgs.to(device)
        regiao_labels = regiao_labels.to(device)
        especie_labels = especie_labels.to(device)
        
        optimizer.zero_grad()
        regiao_out, especie_out = model(imgs)
        
        loss_regiao = criterion(regiao_out, regiao_labels)
        loss_especie = criterion(especie_out, especie_labels)
        
        loss = loss_regiao + loss_especie
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(train_loader):.4f}")


Epoch 1/10, Loss: 1.4996
Epoch 2/10, Loss: 0.7190
Epoch 3/10, Loss: 0.4514
Epoch 4/10, Loss: 0.2941
Epoch 5/10, Loss: 0.2202
Epoch 6/10, Loss: 0.1762
Epoch 7/10, Loss: 0.2831
Epoch 8/10, Loss: 0.1304
Epoch 9/10, Loss: 0.1034
Epoch 10/10, Loss: 0.0831


# Validando modelo

In [8]:
model.eval()
correct_regiao = 0
correct_especie = 0
total = 0

with torch.no_grad():
    for imgs, regiao_labels, especie_labels in val_loader:
        imgs = imgs.to(device)
        regiao_labels = regiao_labels.to(device)
        especie_labels = especie_labels.to(device)
        
        regiao_out, especie_out = model(imgs)
        
        _, pred_regiao = torch.max(regiao_out, 1)
        _, pred_especie = torch.max(especie_out, 1)
        
        correct_regiao += (pred_regiao == regiao_labels).sum().item()
        correct_especie += (pred_especie == especie_labels).sum().item()
        total += regiao_labels.size(0)

print(f"Val Acc Região: {correct_regiao/total:.2f}, Espécie: {correct_especie/total:.2f}")


Val Acc Região: 0.96, Espécie: 0.96


# Salvando modelo


In [9]:
torch.save(model.state_dict(), 'vit_multilabel_model.pth')

In [10]:
joblib.dump(le_regiao, 'le_regiao.pkl')
joblib.dump(le_especie, 'le_especie.pkl')


['le_especie.pkl']