# Redes prototípicas y few-shot learning

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
import os
from torchvision.models import resnet18, ResNet18_Weights
import torch.nn as nn
import torch.optim as optim

class PrototypicalNetwork(nn.Module):
    """Embedding network for prototypical few-shot classification.

    Wraps a ResNet-18 loaded with ``ResNet18_Weights.DEFAULT`` and replaces its
    final fully-connected layer with ``nn.Flatten()``, so ``forward`` returns
    the backbone's features instead of classification logits.
    """

    def __init__(self):
        super().__init__()
        backbone = resnet18(weights=ResNet18_Weights.DEFAULT)
        # Drop the classification head: keep the features that used to feed fc.
        backbone.fc = nn.Flatten()
        self.encoder = backbone

    def forward(self, x):
        """Return one embedding vector per input image."""
        return self.encoder(x)

def train_episode(model, support_images, support_labels, query_images, query_labels, optimizer):
    """Perform one optimisation step of the prototypical network on one episode.

    Args:
        model: Embedding network. All tensors are moved to the device of its
            parameters.
        support_images, support_labels: Support set as yielded by a
            ``DataLoader`` with ``batch_size=1``; the leading size-1 batch
            dimension is squeezed away.
        query_images, query_labels: Query set, same layout as the support set.
        optimizer: Optimizer over ``model``'s parameters; stepped once.

    Returns:
        ``(loss, accuracy)`` on the query set as Python floats. The accuracy
        uses the predictions computed before the optimizer step.

    Raises:
        ValueError: if a query label has no example in the support set.
    """
    model.train()
    optimizer.zero_grad()

    # Follow the model's device instead of relying on a module-level ``device``.
    target_device = next(model.parameters()).device
    support_images = support_images.to(target_device).squeeze(0)
    query_images = query_images.to(target_device).squeeze(0)
    support_labels = support_labels.to(target_device).squeeze(0)
    query_labels = query_labels.to(target_device).squeeze(0)

    support_embeddings = model(support_images)
    query_embeddings = model(query_images)

    # One prototype per class: the mean support embedding. torch.unique returns
    # the labels sorted, so prototype row j belongs to unique_labels[j].
    unique_labels = torch.unique(support_labels)
    prototypes = torch.stack([support_embeddings[support_labels == label].mean(0) for label in unique_labels])

    # Convert each query label into the row index of its prototype so that the
    # loss and accuracy stay correct even when labels are not exactly 0..N-1.
    matches = query_labels.unsqueeze(1) == unique_labels.unsqueeze(0)
    if not bool(matches.any(dim=1).all()):
        raise ValueError("Todas las etiquetas de consulta deben aparecer en el conjunto de soporte.")
    targets = matches.long().argmax(dim=1)

    # Class scores are negative Euclidean distances to the prototypes.
    distances = torch.cdist(query_embeddings, prototypes)
    log_p_y = torch.log_softmax(-distances, dim=1)
    loss = nn.NLLLoss()(log_p_y, targets)
    loss.backward()
    optimizer.step()

    y_hat = log_p_y.argmax(1)
    correct_pred = torch.eq(y_hat, targets).sum().item()
    accuracy = correct_pred / targets.size(0)

    return loss.item(), accuracy




def train(model, train_dataset, optimizer, n_way, k_shot, q_query, epochs=20):
    """Train ``model`` episodically for ``epochs`` passes over ``train_dataset``.

    Every dataset item is one episode. After each epoch the mean query loss
    and accuracy over that epoch's episodes are printed. ``n_way``, ``k_shot``
    and ``q_query`` are accepted but not used here: the episode layout comes
    from the dataset.
    """
    loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    num_episodes = len(loader)

    for epoch_idx in range(1, epochs + 1):
        loss_sum = 0
        acc_sum = 0
        for s_imgs, s_lbls, q_imgs, q_lbls in loader:
            ep_loss, ep_acc = train_episode(model, s_imgs, s_lbls, q_imgs, q_lbls, optimizer)
            loss_sum += ep_loss
            acc_sum += ep_acc

        print(f"Epoch {epoch_idx}, Loss: {loss_sum / num_episodes:.4f}, Accuracy: {acc_sum / num_episodes:.4f}")

class FewShotDataset(Dataset):
    """Episodic image dataset for N-way K-shot few-shot learning.

    ``root_dir`` must contain one sub-directory per class holding that class's
    image files. Each item is a freshly sampled episode (``idx`` is ignored):
    ``n_way`` random classes, with ``k_shot`` support and ``q_query`` query
    images drawn without replacement from each of them.

    Labels are episode-relative (``0..n_way-1``). The sampled classes are
    sorted by folder name before labelling, so when ``root_dir`` has exactly
    ``n_way`` classes, label ``i`` always denotes the ``i``-th class in
    alphabetical order. This keeps per-class metrics (confusion matrix,
    recall) meaningful across episodes.
    """

    def __init__(self, root_dir, n_way, k_shot, q_query, transform=None):
        self.root_dir = root_dir
        # Only sub-directories count as classes (stray files such as .DS_Store
        # are skipped); sorted because os.listdir order is arbitrary.
        self.class_folders = sorted(
            name for name in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, name))
        )
        self.n_way = n_way
        self.k_shot = k_shot
        self.q_query = q_query
        self.transform = transform

        if min(self.n_way, self.k_shot, self.q_query) < 1:
            raise ValueError("n_way, k_shot y q_query deben ser al menos 1.")
        if len(self.class_folders) < self.n_way:
            raise ValueError(
                f"Se requieren al menos {self.n_way} clases, pero {root_dir} solo tiene {len(self.class_folders)}."
            )

        # Check that every class has enough images, and cache its file list so
        # that each episode does not need to list the directories again.
        self._class_files = {}
        for class_name in self.class_folders:
            class_dir = os.path.join(self.root_dir, class_name)
            image_files = sorted(
                name for name in os.listdir(class_dir)
                if os.path.isfile(os.path.join(class_dir, name))
            )
            if len(image_files) < (self.k_shot + self.q_query):
                raise ValueError(f"La clase {class_name} no tiene suficientes imágenes. Requiere al menos {self.k_shot + self.q_query} imágenes.")
            self._class_files[class_name] = image_files

    def __len__(self):
        # Nominal number of episodes per epoch; adjust to the desired number of iterations.
        return 1000

    def __getitem__(self, idx):
        """Sample and return a new episode; ``idx`` is ignored.

        Returns:
            ``(support_images, support_labels, query_images, query_labels)``,
            with the images stacked along dim 0 (``transform`` must therefore
            return tensors).

        Raises:
            ValueError: if no episode could be built after 10 attempts.
        """
        last_error = None
        for attempt in range(10):  # number of attempts to build a valid episode
            try:
                return self._sample_episode()
            except (ValueError, OSError) as e:
                # OSError covers files that cannot be opened/decoded; resample and retry.
                last_error = e
                print(f"No se pudo formar un episodio válido en el intento {attempt+1}: {e}")
        raise ValueError("No se pudo formar un episodio válido después de varios intentos.") from last_error

    def _sample_episode(self):
        """Draw one episode; labels follow the alphabetical order of the sampled classes."""
        episode_classes = sorted(
            str(name) for name in np.random.choice(self.class_folders, self.n_way, replace=False)
        )
        support_images, support_labels = [], []
        query_images, query_labels = [], []

        for label, class_name in enumerate(episode_classes):
            class_dir = os.path.join(self.root_dir, class_name)
            selected_files = np.random.choice(
                self._class_files[class_name], self.k_shot + self.q_query, replace=False
            )
            for file_name in selected_files[:self.k_shot]:
                support_images.append(self._load_image(os.path.join(class_dir, file_name)))
                support_labels.append(label)
            for file_name in selected_files[self.k_shot:]:
                query_images.append(self._load_image(os.path.join(class_dir, file_name)))
                query_labels.append(label)

        return (
            torch.stack(support_images),
            torch.tensor(support_labels),
            torch.stack(query_images),
            torch.tensor(query_labels),
        )

    def _load_image(self, img_path):
        """Open ``img_path`` as RGB (closing the file) and apply ``transform`` if set."""
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img

# Initialise and train the model.

# Episode / training hyperparameters
n_way = 5
k_shot = 5
q_query = 15
epochs = 12

# Run on the GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model and optimizer
model = PrototypicalNetwork().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Preprocessing: resize to 224x224, convert to a tensor and normalise per channel
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Training episodes are sampled from one sub-folder per class (adjust the path as needed)
train_dataset = FewShotDataset(
    root_dir='directorio_datos_train',
    n_way=n_way,
    k_shot=k_shot,
    q_query=q_query,
    transform=transform,
)

train(model, train_dataset, optimizer, n_way, k_shot, q_query, epochs)


# Matriz de confusión


In [None]:
import torch
from sklearn.metrics import confusion_matrix
from torch.utils.data import DataLoader
import numpy as np


# Episode configuration: 5-way, 5-shot, 15 query images per class.
n_way, k_shot, q_query = 5, 5, 15
epochs = 12

# Esta función ejecutará el modelo en modo de evaluación y recolectará las predicciones
def test_model(model, test_dataset, n_way, k_shot, q_query):
    """Evaluate ``model`` on all episodes of ``test_dataset`` and collect predictions.

    Each query example is assigned to the nearest prototype (mean support
    embedding of a class). ``n_way``, ``k_shot`` and ``q_query`` are not used:
    the episode layout comes from the dataset.

    Returns:
        ``(all_preds, all_labels)``: flat lists with one entry per query
        example. Predictions are expressed as label values (the same values as
        the support/query labels), not as prototype row indices.
    """
    model.eval()  # switch layers such as BatchNorm/Dropout to inference behaviour
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)
    # Follow the model's device instead of relying on a module-level ``device``.
    target_device = next(model.parameters()).device
    all_preds = []
    all_labels = []

    with torch.no_grad():  # no gradients are needed for evaluation
        for support_images, support_labels, query_images, query_labels in test_loader:
            support_images = support_images.to(target_device).squeeze(0)
            query_images = query_images.to(target_device).squeeze(0)
            support_labels = support_labels.to(target_device).squeeze(0)
            query_labels = query_labels.to(target_device).squeeze(0)

            support_embeddings = model(support_images)
            query_embeddings = model(query_images)

            # torch.unique returns sorted labels: prototype row j is class unique_labels[j].
            unique_labels = torch.unique(support_labels)
            prototypes = torch.stack([support_embeddings[support_labels == label].mean(0) for label in unique_labels])

            distances = torch.cdist(query_embeddings, prototypes)
            log_p_y = torch.log_softmax(-distances, dim=1)
            # Map the winning prototype row back to its label value so that
            # predictions and labels are comparable even for non 0..N-1 labels.
            y_hat = unique_labels[log_p_y.argmax(1)]

            all_preds.extend(y_hat.cpu().numpy())
            all_labels.extend(query_labels.cpu().numpy())

    return all_preds, all_labels

# Test preprocessing: resize to 224x224, convert to a tensor and normalise per channel.
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

test_dataset = FewShotDataset(
    root_dir='directorio_datos_test',
    n_way=n_way,
    k_shot=k_shot,
    q_query=q_query,
    transform=test_transform,
)

# Run the model over the test episodes and gather query predictions and labels.
predictions, true_labels = test_model(model, test_dataset, n_way, k_shot, q_query)

# Rows are true labels, columns are predicted labels.
conf_matrix = confusion_matrix(true_labels, predictions)
print(conf_matrix)


In [None]:
import PIL
import matplotlib.pyplot as plt
import seaborn as sns

class_names = [f'clase_{i}' for i in range(5)]

# Heatmap of the confusion matrix; adjust figsize to change the figure size.
fig, ax = plt.subplots(figsize=(8, 8))
sns.heatmap(
    conf_matrix,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=class_names,
    yticklabels=class_names,
    ax=ax,
)
ax.set_xlabel('Predicted labels')
ax.set_ylabel('True labels')
fig.savefig('matriz.jpg')
plt.show()

# Guardar el modelo

In [None]:

# Guardar el modelo
def save_model(model, optimizer, epoch, file_path="model.pth"):
    """Write a training checkpoint to ``file_path`` using ``torch.save``.

    The checkpoint is a dict with the keys ``'epoch'``, ``'model_state'``
    (the model's ``state_dict()``) and ``'optimizer_state'`` (the optimizer's
    ``state_dict()``).
    """
    torch.save(
        {
            'epoch': epoch,
            'model_state': model.state_dict(),
            'optimizer_state': optimizer.state_dict(),
        },
        file_path,
    )



# Example: write a checkpoint of the trained model and its optimizer.
save_model(model, optimizer, epoch=epochs, file_path="prototipicas.pth")

# Example of restoring it. 'model' and 'optimizer' must first be recreated
# with the same configuration before calling load_model:
#epoch = load_model(model, optimizer, "/mnt/data/my_prototypical_network.pth")


# Cargar el modelo


In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
import os
from torchvision.models import resnet18, ResNet18_Weights
import torch.nn as nn
import torch.optim as optim

class PrototypicalNetwork(nn.Module):
    """Embedding network for prototypical few-shot classification.

    Wraps a ResNet-18 loaded with ``ResNet18_Weights.DEFAULT`` and replaces its
    final fully-connected layer with ``nn.Flatten()``, so ``forward`` returns
    the backbone's features instead of classification logits.
    """

    def __init__(self):
        super().__init__()
        backbone = resnet18(weights=ResNet18_Weights.DEFAULT)
        # Drop the classification head: keep the features that used to feed fc.
        backbone.fc = nn.Flatten()
        self.encoder = backbone

    def forward(self, x):
        """Return one embedding vector per input image."""
        return self.encoder(x)

def train_episode(model, support_images, support_labels, query_images, query_labels, optimizer):
    """Perform one optimisation step of the prototypical network on one episode.

    Args:
        model: Embedding network. All tensors are moved to the device of its
            parameters.
        support_images, support_labels: Support set as yielded by a
            ``DataLoader`` with ``batch_size=1``; the leading size-1 batch
            dimension is squeezed away.
        query_images, query_labels: Query set, same layout as the support set.
        optimizer: Optimizer over ``model``'s parameters; stepped once.

    Returns:
        ``(loss, accuracy)`` on the query set as Python floats. The accuracy
        uses the predictions computed before the optimizer step.

    Raises:
        ValueError: if a query label has no example in the support set.
    """
    model.train()
    optimizer.zero_grad()

    # Follow the model's device instead of relying on a module-level ``device``.
    target_device = next(model.parameters()).device
    support_images = support_images.to(target_device).squeeze(0)
    query_images = query_images.to(target_device).squeeze(0)
    support_labels = support_labels.to(target_device).squeeze(0)
    query_labels = query_labels.to(target_device).squeeze(0)

    support_embeddings = model(support_images)
    query_embeddings = model(query_images)

    # One prototype per class: the mean support embedding. torch.unique returns
    # the labels sorted, so prototype row j belongs to unique_labels[j].
    unique_labels = torch.unique(support_labels)
    prototypes = torch.stack([support_embeddings[support_labels == label].mean(0) for label in unique_labels])

    # Convert each query label into the row index of its prototype so that the
    # loss and accuracy stay correct even when labels are not exactly 0..N-1.
    matches = query_labels.unsqueeze(1) == unique_labels.unsqueeze(0)
    if not bool(matches.any(dim=1).all()):
        raise ValueError("Todas las etiquetas de consulta deben aparecer en el conjunto de soporte.")
    targets = matches.long().argmax(dim=1)

    # Class scores are negative Euclidean distances to the prototypes.
    distances = torch.cdist(query_embeddings, prototypes)
    log_p_y = torch.log_softmax(-distances, dim=1)
    loss = nn.NLLLoss()(log_p_y, targets)
    loss.backward()
    optimizer.step()

    y_hat = log_p_y.argmax(1)
    correct_pred = torch.eq(y_hat, targets).sum().item()
    accuracy = correct_pred / targets.size(0)

    return loss.item(), accuracy




def train(model, train_dataset, optimizer, n_way, k_shot, q_query, epochs=20):
    """Train ``model`` episodically for ``epochs`` passes over ``train_dataset``.

    Every dataset item is one episode. After each epoch the mean query loss
    and accuracy over that epoch's episodes are printed. ``n_way``, ``k_shot``
    and ``q_query`` are accepted but not used here: the episode layout comes
    from the dataset.
    """
    loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    num_episodes = len(loader)

    for epoch_idx in range(1, epochs + 1):
        loss_sum = 0
        acc_sum = 0
        for s_imgs, s_lbls, q_imgs, q_lbls in loader:
            ep_loss, ep_acc = train_episode(model, s_imgs, s_lbls, q_imgs, q_lbls, optimizer)
            loss_sum += ep_loss
            acc_sum += ep_acc

        print(f"Epoch {epoch_idx}, Loss: {loss_sum / num_episodes:.4f}, Accuracy: {acc_sum / num_episodes:.4f}")

class FewShotDataset(Dataset):
    """Episodic image dataset for N-way K-shot few-shot learning.

    ``root_dir`` must contain one sub-directory per class holding that class's
    image files. Each item is a freshly sampled episode (``idx`` is ignored):
    ``n_way`` random classes, with ``k_shot`` support and ``q_query`` query
    images drawn without replacement from each of them.

    Labels are episode-relative (``0..n_way-1``). The sampled classes are
    sorted by folder name before labelling, so when ``root_dir`` has exactly
    ``n_way`` classes, label ``i`` always denotes the ``i``-th class in
    alphabetical order. This keeps per-class metrics (confusion matrix,
    recall) meaningful across episodes.
    """

    def __init__(self, root_dir, n_way, k_shot, q_query, transform=None):
        self.root_dir = root_dir
        # Only sub-directories count as classes (stray files such as .DS_Store
        # are skipped); sorted because os.listdir order is arbitrary.
        self.class_folders = sorted(
            name for name in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, name))
        )
        self.n_way = n_way
        self.k_shot = k_shot
        self.q_query = q_query
        self.transform = transform

        if min(self.n_way, self.k_shot, self.q_query) < 1:
            raise ValueError("n_way, k_shot y q_query deben ser al menos 1.")
        if len(self.class_folders) < self.n_way:
            raise ValueError(
                f"Se requieren al menos {self.n_way} clases, pero {root_dir} solo tiene {len(self.class_folders)}."
            )

        # Check that every class has enough images, and cache its file list so
        # that each episode does not need to list the directories again.
        self._class_files = {}
        for class_name in self.class_folders:
            class_dir = os.path.join(self.root_dir, class_name)
            image_files = sorted(
                name for name in os.listdir(class_dir)
                if os.path.isfile(os.path.join(class_dir, name))
            )
            if len(image_files) < (self.k_shot + self.q_query):
                raise ValueError(f"La clase {class_name} no tiene suficientes imágenes. Requiere al menos {self.k_shot + self.q_query} imágenes.")
            self._class_files[class_name] = image_files

    def __len__(self):
        # Nominal number of episodes per epoch; adjust to the desired number of iterations.
        return 1000

    def __getitem__(self, idx):
        """Sample and return a new episode; ``idx`` is ignored.

        Returns:
            ``(support_images, support_labels, query_images, query_labels)``,
            with the images stacked along dim 0 (``transform`` must therefore
            return tensors).

        Raises:
            ValueError: if no episode could be built after 10 attempts.
        """
        last_error = None
        for attempt in range(10):  # number of attempts to build a valid episode
            try:
                return self._sample_episode()
            except (ValueError, OSError) as e:
                # OSError covers files that cannot be opened/decoded; resample and retry.
                last_error = e
                print(f"No se pudo formar un episodio válido en el intento {attempt+1}: {e}")
        raise ValueError("No se pudo formar un episodio válido después de varios intentos.") from last_error

    def _sample_episode(self):
        """Draw one episode; labels follow the alphabetical order of the sampled classes."""
        episode_classes = sorted(
            str(name) for name in np.random.choice(self.class_folders, self.n_way, replace=False)
        )
        support_images, support_labels = [], []
        query_images, query_labels = [], []

        for label, class_name in enumerate(episode_classes):
            class_dir = os.path.join(self.root_dir, class_name)
            selected_files = np.random.choice(
                self._class_files[class_name], self.k_shot + self.q_query, replace=False
            )
            for file_name in selected_files[:self.k_shot]:
                support_images.append(self._load_image(os.path.join(class_dir, file_name)))
                support_labels.append(label)
            for file_name in selected_files[self.k_shot:]:
                query_images.append(self._load_image(os.path.join(class_dir, file_name)))
                query_labels.append(label)

        return (
            torch.stack(support_images),
            torch.tensor(support_labels),
            torch.stack(query_images),
            torch.tensor(query_labels),
        )

    def _load_image(self, img_path):
        """Open ``img_path`` as RGB (closing the file) and apply ``transform`` if set."""
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img

# Restaurar el modelo
def load_model(model, optimizer, file_path="model.pth"):
    """Restore a checkpoint (dict with 'epoch', 'model_state', 'optimizer_state').

    Args:
        model: Module with the same architecture as the saved one. Checkpoint
            tensors are mapped onto the device of its parameters.
        optimizer: Optimizer built over ``model``'s parameters, or ``None`` to
            restore only the model weights (e.g. for inference).
        file_path: Path of the checkpoint file.

    Returns:
        The ``'epoch'`` value stored in the checkpoint.

    Raises:
        FileNotFoundError: if ``file_path`` does not exist (raised by ``torch.load``).
    """
    # Map onto the model's own device instead of relying on a module-level ``device``.
    target_device = next(model.parameters()).device
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    state = torch.load(file_path, map_location=target_device)

    model.load_state_dict(state['model_state'])
    if optimizer is not None:
        optimizer.load_state_dict(state['optimizer_state'])

    return state['epoch']


# Rebuild the model and optimizer with the same configuration used for
# training; their state is then overwritten from the checkpoint.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = PrototypicalNetwork()
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)

epoch = load_model(model, optimizer, "/kaggle/input/modeloo/prototipicas.pth")


# Recall


In [None]:
from sklearn.metrics import recall_score

# Per-class recall of the collected test predictions: average=None returns one
# score per label instead of a single averaged value.
recall = recall_score(y_true=true_labels, y_pred=predictions, average=None)
print(recall)