In [None]:
# =========================
# Install Dependencies
# =========================
# IPython/Colab shell escape (not plain Python syntax): quietly installs the
# extra libraries this notebook imports. torch / torchvision are not listed,
# presumably because the runtime ships them preinstalled — TODO confirm when
# running outside Colab.
!pip install -q diffusers transformers accelerate safetensors seaborn scikit-learn

# =========================
# HuggingFace Login
# =========================
# Interactive Hugging Face Hub login prompt. Authentication is recommended but
# optional for downloading public models/datasets, so an empty/skipped login
# still lets the public Stable Diffusion checkpoint below download.
from huggingface_hub import login
login()

# =========================
# Imports
# =========================
import os
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms, models

from diffusers import StableDiffusionPipeline

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

# =========================
# Device
# =========================
# Select the compute device once for the whole notebook: the default CUDA GPU
# when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

# =========================
# Dog Breeds
# =========================
# The 40 classes to synthesize. Each name is used verbatim in the text prompt
# and, with spaces turned into underscores, as the class folder name.
DOG_BREEDS = [
    "Labrador Retriever",
    "Golden Retriever",
    "German Shepherd",
    "French Bulldog",
    "Bulldog",
    "Poodle (Standard)",
    "Beagle",
    "Rottweiler",
    "Dachshund",
    "Yorkshire Terrier",
    "Boxer",
    "Doberman Pinscher",
    "Siberian Husky",
    "Great Dane",
    "Shih Tzu",
    "Chihuahua",
    "Pug",
    "Border Collie",
    "Australian Shepherd",
    "Belgian Malinois",
    "Akita",
    "Alaskan Malamute",
    "Samoyed",
    "Bernese Mountain Dog",
    "Saint Bernard",
    "Newfoundland",
    "Cane Corso",
    "Greyhound",
    "Whippet",
    "Bloodhound",
    "Basset Hound",
    "Rhodesian Ridgeback",
    "Weimaraner",
    "Vizsla",
    "Cocker Spaniel",
    "Cavalier King Charles Spaniel",
    "Pomeranian",
    "Chow Chow",
    "Shiba Inu",
    "Basenji",
]

# =========================
# Stable Diffusion Loader
# =========================
def load_pipeline(model_id="runwayml/stable-diffusion-v1-5", target_device=None):
    """Load a Stable Diffusion text-to-image pipeline and move it to a device.

    Args:
        model_id: Hugging Face Hub repo id (or local path) of the checkpoint.
        target_device: Device (or device string) to place the pipeline on.
            Defaults to the notebook-wide ``device`` chosen at start-up, so the
            pipeline follows the same CUDA/CPU decision as the classifier.

    Returns:
        The ``StableDiffusionPipeline`` placed on ``target_device``.
    """
    if target_device is None:
        target_device = device
    target_device = torch.device(target_device)

    # Half precision only on CUDA. The previous hard-coded float16 + .to("cuda")
    # failed outright on CPU-only hosts; on CPU use float32 instead.
    dtype = torch.float16 if target_device.type == "cuda" else torch.float32

    # Recent diffusers releases warn that `torch_dtype` is deprecated in favour
    # of `dtype`; `torch_dtype` is kept so older releases keep working.
    pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype)
    return pipe.to(target_device)

# =========================
# Generate Dataset
# =========================
def generate_dataset(
    pipe,
    breeds,
    output_dir,
    images_per_class=10,
    *,
    num_inference_steps=30,
    guidance_scale=7.5,
    prompt_template="a high quality photo of a {breed}, ultra realistic, cinematic lighting, 4k",
    seed=None,
    skip_existing=False,
):
    """Synthesize an ImageFolder-style dataset with a text-to-image pipeline.

    Writes ``<output_dir>/<Breed_Name>/<Breed_Name>_<i>.png`` for each breed,
    where spaces in the breed name are replaced by underscores.

    Args:
        pipe: Callable pipeline; ``pipe(prompt, ...)`` must return an object
            whose ``images[0]`` has a ``save(path)`` method.
        breeds: Class names; each is substituted into ``prompt_template``.
        output_dir: Root directory of the generated dataset (created if missing).
        images_per_class: Number of images to generate per breed.
        num_inference_steps: Denoising steps passed to the pipeline.
        guidance_scale: Classifier-free guidance scale passed to the pipeline.
        prompt_template: ``str.format`` template with a ``{breed}`` field.
        seed: If given, each image gets its own deterministic generator seeded
            with ``seed + class_position * images_per_class + i``, so reruns
            reproduce the same images. ``None`` keeps unseeded sampling.
        skip_existing: If True, files that already exist are left untouched,
            which makes an interrupted generation run resumable.
    """
    os.makedirs(output_dir, exist_ok=True)

    for class_pos, breed in enumerate(breeds):
        folder_name = breed.replace(" ", "_")
        class_dir = os.path.join(output_dir, folder_name)
        os.makedirs(class_dir, exist_ok=True)
        # The prompt depends only on the breed, so build it once per class.
        prompt = prompt_template.format(breed=breed)

        for i in range(images_per_class):
            path = os.path.join(class_dir, f"{folder_name}_{i}.png")
            if skip_existing and os.path.exists(path):
                continue

            generator = None
            if seed is not None:
                # A CPU generator gives reproducible noise regardless of the
                # device the pipeline runs on.
                generator = torch.Generator("cpu").manual_seed(
                    seed + class_pos * images_per_class + i
                )

            result = pipe(
                prompt,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                generator=generator,
            )
            image = result.images[0]

            # Stable Diffusion pipelines report safety-checker hits here; a
            # flagged image is typically blanked, which would silently pollute
            # the training set, so make it visible.
            flagged = getattr(result, "nsfw_content_detected", None)
            if flagged and flagged[0]:
                print("Warning: safety checker flagged image (may be blank):", path)

            image.save(path)
            print("Saved:", path)

# =========================
# Dataloaders
# =========================
def get_loaders(data_dir, batch_size=16, train_fraction=0.8, seed=42):
    """Build train/validation DataLoaders from an ImageFolder-style directory.

    The training subset uses random horizontal flips; the validation subset
    uses a deterministic transform so accuracy is comparable across epochs.

    Args:
        data_dir: Root directory containing one sub-folder per class.
        batch_size: Batch size for both loaders.
        train_fraction: Fraction of images assigned to training
            (``int(train_fraction * N)``); the remainder is validation.
        seed: Seed for the split permutation, making it reproducible.
            ``None`` uses PyTorch's global RNG.

    Returns:
        ``(dataset, train_loader, val_loader)``. ``dataset`` is the
        training-transform ``ImageFolder``; use ``dataset.classes`` for the
        label names.
    """
    from torch.utils.data import Subset

    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    # No random augmentation at evaluation time.
    eval_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])

    # Two views of the same folder. ImageFolder enumerates classes and files
    # in sorted order, so index k refers to the same image in both views.
    dataset = datasets.ImageFolder(root=data_dir, transform=train_transform)
    eval_view = datasets.ImageFolder(root=data_dir, transform=eval_transform)

    train_size = int(train_fraction * len(dataset))
    generator = None if seed is None else torch.Generator().manual_seed(seed)
    order = torch.randperm(len(dataset), generator=generator).tolist()

    train_ds = Subset(dataset, order[:train_size])
    val_ds = Subset(eval_view, order[train_size:])

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

    return dataset, train_loader, val_loader

# =========================
# Model
# =========================
def build_model(num_classes, pretrained=True):
    """Create a ResNet-18 classifier with a ``num_classes``-way output head.

    Args:
        num_classes: Number of output logits for the new final layer.
        pretrained: Start from torchvision's default ImageNet weights when
            True; random initialisation when False.

    Returns:
        The model, moved to the notebook-wide ``device``.
    """
    # `pretrained=True` is deprecated since torchvision 0.13; the `weights=`
    # enum API loads the same ImageNet checkpoint without the warning.
    weights = models.ResNet18_Weights.DEFAULT if pretrained else None
    model = models.resnet18(weights=weights)
    # Replace the ImageNet classification head with one sized for our classes.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model.to(device)

# =========================
# Train & Evaluate
# =========================
def train(model, train_loader, val_loader, epochs=10, lr=1e-4):
    """Fine-tune ``model`` with Adam + cross-entropy, validating every epoch.

    Args:
        model: Classifier already placed on ``device``.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.

    Returns:
        A list with one dict per epoch: ``{"epoch", "loss", "val_acc"}``, where
        ``loss`` is the sample-weighted mean training loss (nan if the loader
        produced no samples).
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    history = []

    for epoch in range(epochs):
        # evaluate() puts the model in eval mode, so switch back each epoch.
        model.train()
        loss_sum = 0.0
        seen = 0

        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # The criterion returns the batch mean; weight it by batch size so a
            # short final batch does not skew the epoch average.
            batch = labels.size(0)
            loss_sum += loss.item() * batch
            seen += batch

        epoch_loss = loss_sum / seen if seen else float("nan")
        acc = evaluate(model, val_loader)
        history.append({"epoch": epoch + 1, "loss": epoch_loss, "val_acc": acc})
        print(
            f"Epoch [{epoch+1}/{epochs}] "
            f"Loss: {epoch_loss:.4f} "
            f"Val Acc: {acc:.2f}%"
        )

    return history

# =========================
# Evaluation
# =========================
def evaluate(model, loader):
    """Return the top-1 accuracy of ``model`` over ``loader``, in percent.

    Puts the model in eval mode and runs without gradient tracking. Batches
    are moved to the device holding the model's parameters, falling back to
    CPU for a parameter-free model.

    Args:
        model: Classifier mapping an image batch to per-class logits.
        loader: Iterable of ``(images, labels)`` batches.

    Returns:
        ``100 * correct / total`` as a float, or ``nan`` when the loader yields
        no samples (e.g. an empty validation split) instead of raising
        ``ZeroDivisionError``.
    """
    first_param = next(model.parameters(), None)
    target = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(target)
            labels = labels.to(target)

            preds = torch.argmax(model(images), 1)

            correct += (preds == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        return float("nan")
    return 100 * correct / total

# =========================
# Confusion Matrix
# =========================
def plot_confusion(model, loader, class_names):
    """Plot a confusion-matrix heatmap of ``model``'s predictions on ``loader``.

    Rows are true classes and columns are predicted classes, both ordered as
    the label indices ``0 .. len(class_names) - 1``.

    Args:
        model: Classifier on ``device`` producing per-class logits.
        loader: Iterable of ``(images, labels)`` batches.
        class_names: Display names, where entry ``i`` names label index ``i``.
    """
    y_true, y_pred = [], []

    model.eval()
    with torch.no_grad():
        for images, labels in loader:
            outputs = model(images.to(device))
            preds = torch.argmax(outputs, 1)

            y_true.extend(labels.cpu().tolist())
            y_pred.extend(preds.cpu().tolist())

    # Pin the label set so the matrix is always len(class_names) square.
    # Without `labels=`, sklearn uses only the labels that actually occur. With
    # a small random validation split some classes are absent, so the matrix
    # shrinks and the tick labels no longer line up with its rows/columns.
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))

    plt.figure(figsize=(14, 12))
    sns.heatmap(
        cm,
        xticklabels=class_names,
        yticklabels=class_names,
        cmap="Blues",
        cbar=True
    )
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.xticks(rotation=90)
    plt.title("Confusion Matrix – ResNet18")
    plt.tight_layout()
    plt.show()

# =========================
# Run Pipeline
# =========================
DATASET_DIR = "40_dog_dataset"

# Step 1: synthesize the training images with Stable Diffusion.
pipe = load_pipeline()
generate_dataset(pipe=pipe, breeds=DOG_BREEDS, output_dir=DATASET_DIR, images_per_class=10)

# Step 2: build the loaders and a classifier sized to the discovered classes.
dataset, train_loader, val_loader = get_loaders(data_dir=DATASET_DIR)
model = build_model(num_classes=len(dataset.classes))

# Step 3: fine-tune, persist the weights, then inspect per-class errors.
# The checkpoint holds only the state_dict; output index i corresponds to
# dataset.classes[i].
train(model=model, train_loader=train_loader, val_loader=val_loader, epochs=10)
torch.save(model.state_dict(), "resnet18_40_dog_breeds.pth")
plot_confusion(model=model, loader=val_loader, class_names=dataset.classes)


VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

Flax classes are deprecated and will be removed in Diffusers v1.0.0. We recommend migrating to PyTorch classes or pinning your version of Diffusers.
Flax classes are deprecated and will be removed in Diffusers v1.0.0. We recommend migrating to PyTorch classes or pinning your version of Diffusers.


Using device: cuda


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model_index.json:   0%|          | 0.00/541 [00:00<?, ?B/s]

Fetching 15 files:   0%|          | 0/15 [00:00<?, ?it/s]

scheduler_config.json:   0%|          | 0.00/308 [00:00<?, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/617 [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/342 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/472 [00:00<?, ?B/s]

safety_checker/model.safetensors:   0%|          | 0.00/1.22G [00:00<?, ?B/s]

text_encoder/model.safetensors:   0%|          | 0.00/492M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/806 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/743 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/547 [00:00<?, ?B/s]

unet/diffusion_pytorch_model.safetensors:   0%|          | 0.00/3.44G [00:00<?, ?B/s]

vae/diffusion_pytorch_model.safetensors:   0%|          | 0.00/335M [00:00<?, ?B/s]

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

`torch_dtype` is deprecated! Use `dtype` instead!


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/Labrador_Retriever/Labrador_Retriever_0.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/Labrador_Retriever/Labrador_Retriever_1.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/Labrador_Retriever/Labrador_Retriever_2.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/Labrador_Retriever/Labrador_Retriever_3.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/Labrador_Retriever/Labrador_Retriever_4.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/Labrador_Retriever/Labrador_Retriever_5.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/Labrador_Retriever/Labrador_Retriever_6.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/Labrador_Retriever/Labrador_Retriever_7.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/Labrador_Retriever/Labrador_Retriever_8.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/Labrador_Retriever/Labrador_Retriever_9.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/Golden_Retriever/Golden_Retriever_0.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/Golden_Retriever/Golden_Retriever_1.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/Golden_Retriever/Golden_Retriever_2.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/Golden_Retriever/Golden_Retriever_3.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/Golden_Retriever/Golden_Retriever_4.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/Golden_Retriever/Golden_Retriever_5.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/Golden_Retriever/Golden_Retriever_6.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/Golden_Retriever/Golden_Retriever_7.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/Golden_Retriever/Golden_Retriever_8.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/Golden_Retriever/Golden_Retriever_9.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/German_Shepherd/German_Shepherd_0.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/German_Shepherd/German_Shepherd_1.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/German_Shepherd/German_Shepherd_2.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/German_Shepherd/German_Shepherd_3.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/German_Shepherd/German_Shepherd_4.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/German_Shepherd/German_Shepherd_5.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/German_Shepherd/German_Shepherd_6.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/German_Shepherd/German_Shepherd_7.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/German_Shepherd/German_Shepherd_8.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/German_Shepherd/German_Shepherd_9.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/French_Bulldog/French_Bulldog_0.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/French_Bulldog/French_Bulldog_1.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/French_Bulldog/French_Bulldog_2.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/French_Bulldog/French_Bulldog_3.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/French_Bulldog/French_Bulldog_4.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/French_Bulldog/French_Bulldog_5.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/French_Bulldog/French_Bulldog_6.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/French_Bulldog/French_Bulldog_7.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/French_Bulldog/French_Bulldog_8.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/French_Bulldog/French_Bulldog_9.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/Bulldog/Bulldog_0.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/Bulldog/Bulldog_1.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/Bulldog/Bulldog_2.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/Bulldog/Bulldog_3.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/Bulldog/Bulldog_4.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/Bulldog/Bulldog_5.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/Bulldog/Bulldog_6.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/Bulldog/Bulldog_7.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/Bulldog/Bulldog_8.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/Bulldog/Bulldog_9.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/Poodle_(Standard)/Poodle_(Standard)_0.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/Poodle_(Standard)/Poodle_(Standard)_1.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/Poodle_(Standard)/Poodle_(Standard)_2.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/Poodle_(Standard)/Poodle_(Standard)_3.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/Poodle_(Standard)/Poodle_(Standard)_4.png


  0%|          | 0/30 [00:00<?, ?it/s]

Saved: 40_dog_dataset/Poodle_(Standard)/Poodle_(Standard)_5.png


  0%|          | 0/30 [00:00<?, ?it/s]