In [None]:
import os
import numpy as np
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras.preprocessing import image
from tensorflow.keras.models import Model
from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds

# -------------------------------
# 1. Download the dataset from TensorFlow Datasets
# -------------------------------
pasta_imagens = "imagens"
os.makedirs(pasta_imagens, exist_ok=True)

# Only entries with an image extension count as "already downloaded".
# Checking for *any* directory entry would let a stray non-image file or
# folder (e.g. ".ipynb_checkpoints") suppress the download and leave the
# folder without a single usable image.
extensoes_imagem = (".jpg", ".png", ".jpeg")
ja_tem_imagens = any(
    nome.lower().endswith(extensoes_imagem)
    for nome in os.listdir(pasta_imagens)
)

if not ja_tem_imagens:
    # First 6 examples of the "train" split, as (image, label) pairs.
    dataset, info = tfds.load("cats_vs_dogs", split="train[:6]", with_info=True, as_supervised=True)
    for i, (img, label) in enumerate(tfds.as_numpy(dataset)):
        # Files are named img1.jpg, img2.jpg, ... (1-based).
        caminho = os.path.join(pasta_imagens, f"img{i+1}.jpg")
        plt.imsave(caminho, img.astype("uint8"))
    print("✅ 6 imagens salvas na pasta 'imagens/'")
else:
    print("📂 Imagens já existem na pasta.")

# -------------------------------
# 2. Load the pre-trained model
# -------------------------------
# ResNet50 with ImageNet weights, without the classification head
# (include_top=False) and with pooling="avg".
base_model = ResNet50(
    include_top=False,
    pooling="avg",
    weights="imagenet",
)
# `model` exposes the same input and output tensors as `base_model`.
model = Model(outputs=base_model.output, inputs=base_model.input)

# -------------------------------
# 3. Embedding extraction function
# -------------------------------
def extrair_embedding(caminho_img):
    """Return the flattened output of the module-level ``model`` for one image.

    The image at ``caminho_img`` is loaded resized to 224x224, converted to
    an array, given a leading batch axis, run through ``preprocess_input``
    and fed to ``model.predict``.

    Args:
        caminho_img: Path of the image file to load.

    Returns:
        The prediction for the single-image batch, flattened to 1-D.
    """
    pixels = image.img_to_array(
        image.load_img(caminho_img, target_size=(224, 224))
    )
    # The model is fed a batch, so add a leading axis of size 1.
    lote = preprocess_input(pixels[np.newaxis, ...])
    return model.predict(lote, verbose=0).flatten()

# -------------------------------
# 4. Process every image
# -------------------------------
# os.listdir() returns entries in arbitrary order, so the list is sorted to
# make row/column i of the similarity matrix refer to the same file on every
# run. The extension test is case-insensitive so "photo.JPG" is not skipped.
# NOTE: the sort is lexicographic ("img10" sorts before "img2").
arquivos = sorted(
    os.path.join(pasta_imagens, f)
    for f in os.listdir(pasta_imagens)
    if f.lower().endswith((".jpg", ".png", ".jpeg"))
)

if not arquivos:
    # FileNotFoundError is a subclass of Exception, so existing
    # `except Exception` handlers keep working.
    raise FileNotFoundError("❌ Nenhuma imagem encontrada na pasta 'imagens/'.")

print(f"📷 {len(arquivos)} imagens encontradas.")

# One embedding vector per file, in the same order as `arquivos`.
embeddings = [extrair_embedding(img) for img in arquivos]

# -------------------------------
# 5. Compute similarity
# -------------------------------
# Pairwise cosine similarity: entry [i, j] compares arquivos[i] with
# arquivos[j].
similaridades = cosine_similarity(embeddings)

# -------------------------------
# 6. Show the similarity matrix
# -------------------------------
# Tick labels come from the real file names (without extension) instead of
# synthetic positional names, so each row/column is labelled with the file it
# was actually computed from, whatever order `arquivos` is in.
rotulos = [
    os.path.splitext(os.path.basename(caminho_arquivo))[0]
    for caminho_arquivo in arquivos
]
posicoes = range(len(arquivos))

plt.figure(figsize=(8, 6))
plt.imshow(similaridades, cmap="viridis")
plt.colorbar()
plt.title("Matriz de Similaridade (Cosine)", fontsize=14)
plt.xticks(posicoes, rotulos, rotation=45)
plt.yticks(posicoes, rotulos)
plt.show()


Mounted at /content/drive
