In [1]:
import os
import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights
from torchvision.transforms.autoaugment import AutoAugmentPolicy

from diffusers import StableDiffusionPipeline

from torch.utils.data import Dataset, DataLoader, BatchSampler, random_split
from torchvision import transforms
from PIL import Image

In [2]:
!unzip data.zip

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: data/train_images/283.jpg  
  inflating: __MACOSX/data/train_images/._283.jpg  
  inflating: data/train_images/4647.jpg  
  inflating: __MACOSX/data/train_images/._4647.jpg  
  inflating: data/train_images/3128.jpg  
  inflating: __MACOSX/data/train_images/._3128.jpg  
  inflating: data/train_images/2236.jpg  
  inflating: __MACOSX/data/train_images/._2236.jpg  
  inflating: data/train_images/5559.jpg  
  inflating: __MACOSX/data/train_images/._5559.jpg  
  inflating: data/train_images/6050.jpg  
  inflating: __MACOSX/data/train_images/._6050.jpg  
  inflating: data/train_images/1059.jpg  
  inflating: __MACOSX/data/train_images/._1059.jpg  
  inflating: data/train_images/3896.jpg  
  inflating: __MACOSX/data/train_images/._3896.jpg  
  inflating: data/train_images/2550.jpg  
  inflating: __MACOSX/data/train_images/._2550.jpg  
  inflating: data/train_images/4121.jpg  
  inflating: __MACOSX/data/train_images/

In [3]:
# Generate novel training data
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe = pipe.to("cuda")  # Use GPU for speed

novel_super_class_count = 1000
novel_sub_class_count = 20

# Create a prompt that excludes dogs, birds, and reptiles
for i in range(int(novel_super_class_count/10)):
  prompt = f"batch {i}: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off"
  negative_prompt = "dog, bird, reptile, no animal"
  print(prompt)
  # Generate the image
  images = pipe(
      prompt=prompt,
      negative_prompt=negative_prompt,
      # height=64,
      # width=64,
      num_inference_steps=30,
      guidance_scale=12,
      height=512,
      width=512,
      num_images_per_prompt=10
      ).images
  for j in range(10):
    image = images[j]
  # Save the image
    image.save(f"data/train_images/{j+i*10+6288}.jpg")

existing_subclasses = pd.read_csv("data/subclass_mapping.csv")["class"].tolist()
subclass_negative_prompt = f"{', '.join(existing_subclasses)}, dog, reptile, no animal, low quality, blurry"

for i, super_class in enumerate(["bird", "dog", "reptile"]):
  for m in range(int(novel_sub_class_count/10)):
    subclass_prompt = f"batch {m}: nature photograph of a {super_class}, centered composition, photorealistic, the {super_class} is more than 50% of the image, the {super_class} is fully visible within frame and not cut off"
        # Generate the image
    images = pipe(
            prompt=subclass_prompt,
            negative_prompt=subclass_negative_prompt,
            num_inference_steps=30,
            guidance_scale=12,
            height=512,
            width=512,
            num_images_per_prompt=10
        ).images

    for j in range(10):
        image = images[j]
        # Save with continuing index from previous loop
        image.save(f"data/train_images/{j+m*10+novel_sub_class_count*i+6288+novel_super_class_count}.jpg")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model_index.json:   0%|          | 0.00/541 [00:00<?, ?B/s]

Fetching 15 files:   0%|          | 0/15 [00:00<?, ?it/s]

model.safetensors:   0%|          | 0.00/492M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.22G [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/342 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/4.72k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/472 [00:00<?, ?B/s]

scheduler_config.json:   0%|          | 0.00/308 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/617 [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/3.44G [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.06M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/743 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/547 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/806 [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/335M [00:00<?, ?B/s]

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

batch 0: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 1: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 2: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 3: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 4: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 5: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 6: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 7: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 8: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 9: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 10: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 11: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 12: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 13: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 14: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 15: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 16: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 17: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 18: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 19: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 20: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 21: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 22: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 23: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 24: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

Potential NSFW content was detected in one or more images. A black image will be returned instead. Try again with a different prompt and/or seed.


batch 25: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 26: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 27: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 28: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

Potential NSFW content was detected in one or more images. A black image will be returned instead. Try again with a different prompt and/or seed.


batch 29: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 30: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 31: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 32: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 33: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 34: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 35: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 36: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

Potential NSFW content was detected in one or more images. A black image will be returned instead. Try again with a different prompt and/or seed.


batch 37: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 38: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 39: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 40: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

Potential NSFW content was detected in one or more images. A black image will be returned instead. Try again with a different prompt and/or seed.


batch 41: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 42: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 43: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 44: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

Potential NSFW content was detected in one or more images. A black image will be returned instead. Try again with a different prompt and/or seed.


batch 45: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

Potential NSFW content was detected in one or more images. A black image will be returned instead. Try again with a different prompt and/or seed.


batch 46: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 47: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 48: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 49: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 50: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 51: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 52: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 53: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 54: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 55: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 56: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 57: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 58: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 59: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 60: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 61: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 62: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 63: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 64: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 65: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 66: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 67: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 68: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 69: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 70: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

Potential NSFW content was detected in one or more images. A black image will be returned instead. Try again with a different prompt and/or seed.


batch 71: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

Potential NSFW content was detected in one or more images. A black image will be returned instead. Try again with a different prompt and/or seed.


batch 72: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 73: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 74: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 75: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 76: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 77: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 78: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 79: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 80: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 81: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 82: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 83: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 84: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 85: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 86: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 87: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 88: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 89: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 90: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 91: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 92: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 93: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 94: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 95: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 96: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 97: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 98: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

batch 99: nature photograph of a sea animal, like fish, shark, etc, centered composition, photorealistic, 8k, the animal is fully visible within frame and not cut off


  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

In [4]:
# Create Dataset class for multilabel classification
class MultiClassImageDataset(Dataset):
    def __init__(self, ann_df, super_map_df, sub_map_df, img_dir, transform=None):
        self.ann_df = ann_df
        self.super_map_df = super_map_df
        self.sub_map_df = sub_map_df
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.ann_df)

    def __getitem__(self, idx):
        img_name = self.ann_df['image'][idx]
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert('RGB')

        super_idx = self.ann_df['superclass_index'][idx]
        super_label = self.super_map_df['class'][super_idx]

        sub_idx = self.ann_df['subclass_index'][idx]
        sub_label = self.sub_map_df['class'][sub_idx]

        if self.transform:
            image = self.transform(image)

        return image, super_idx, super_label, sub_idx, sub_label

class MultiClassImageTestDataset(Dataset):
    def __init__(self, super_map_df, sub_map_df, img_dir, transform=None):
        self.super_map_df = super_map_df
        self.sub_map_df = sub_map_df
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self): # Count files in img_dir
        return len([fname for fname in os.listdir(self.img_dir)])

    def __getitem__(self, idx):
        img_name = str(idx) + '.jpg'
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, img_name

In [5]:
def calculate_weights(train_ann_df):
    # Assuming you have the counts for both superclass and subclass
    superclass_counts = train_ann_df['superclass_index'].value_counts().sort_index()
    subclass_counts = train_ann_df['subclass_index'].value_counts().sort_index()

    # Calculate weights for both
    superclass_weights = 1.0 / superclass_counts
    superclass_weights = superclass_weights / superclass_weights.sum()
    superclass_weights = torch.tensor(superclass_weights.values, dtype=torch.float32)

    subclass_weights = 1.0 / subclass_counts
    subclass_weights = subclass_weights / subclass_weights.sum()
    subclass_weights = torch.tensor(subclass_weights.values, dtype=torch.float32)

    # superclass_weights = torch.cat([superclass_weights, torch.tensor([0.0], dtype=torch.float32)])
    # subclass_weights = torch.cat([subclass_weights, torch.tensor([0.0], dtype=torch.float32)])

    print(superclass_weights)
    print(subclass_weights)
    return superclass_weights, subclass_weights

In [6]:
train_ann_df = pd.read_csv('data/train_data.csv')

# Add new rows with superclass 3 and subclass 87
new_rows_super = pd.DataFrame({
    'image': [f"{i}.jpg" for i in range(6288, 6288+novel_super_class_count)],  # Generate image paths from 6288.JPG to 6387.JPG
    'superclass_index': [3] * novel_super_class_count,
    'subclass_index': [87] * novel_super_class_count
})
train_ann_df = pd.concat([train_ann_df, new_rows_super], ignore_index=True)
new_rows_bird = pd.DataFrame({
    'image': [f"{i}.jpg" for i in range(6288+novel_super_class_count, 6288+novel_super_class_count+novel_sub_class_count)],  # Generate image paths from 6288.JPG to 6387.JPG
    'superclass_index': [0] * novel_sub_class_count,
    'subclass_index': [87] * novel_sub_class_count
})
train_ann_df = pd.concat([train_ann_df, new_rows_bird], ignore_index=True)
new_rows_dog = pd.DataFrame({
    'image': [f"{i}.jpg" for i in range(6288+novel_super_class_count+novel_sub_class_count, 6288+novel_super_class_count+novel_sub_class_count*2)],  # Generate image paths from 6288.JPG to 6387.JPG
    'superclass_index': [1] * novel_sub_class_count,
    'subclass_index': [87] * novel_sub_class_count
})
train_ann_df = pd.concat([train_ann_df, new_rows_dog], ignore_index=True)
new_rows_reptile = pd.DataFrame({
    'image': [f"{i}.jpg" for i in range(6288+novel_super_class_count+novel_sub_class_count*2, 6288+novel_super_class_count+novel_sub_class_count*3)],  # Generate image paths from 6288.JPG to 6387.JPG
    'superclass_index': [2] * novel_sub_class_count,
    'subclass_index': [87] * novel_sub_class_count
})
train_ann_df = pd.concat([train_ann_df, new_rows_reptile], ignore_index=True)
# print(train_ann_df)

# test_ann_df = pd.read_csv('data/test_data.csv')
super_map_df = pd.read_csv('data/superclass_mapping.csv')
sub_map_df = pd.read_csv('data/subclass_mapping.csv')

train_img_dir = 'data/train_images'
test_img_dir = 'data/test_images'

augmentation_setups = {
    "baseline": transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]),
    "autoaugment": transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.AutoAugment(policy=AutoAugmentPolicy.IMAGENET),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]),
    "manual_combo": transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
}

# image_preprocessing = transforms.Compose([
#     transforms.Resize((384, 384)),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
# ])

superclass_weights, subclass_weights = calculate_weights(train_ann_df)

# Create train and val split
train_dataset = MultiClassImageDataset(train_ann_df, super_map_df, sub_map_df, train_img_dir, transform=augmentation_setups["manual_combo"])
train_dataset, val_dataset = random_split(train_dataset, [0.9, 0.1])

# Create test dataset
test_dataset = MultiClassImageTestDataset(super_map_df, sub_map_df, test_img_dir, transform=augmentation_setups["baseline"])

# Create dataloaders
batch_size = 32
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True)

val_loader = DataLoader(val_dataset,
                        batch_size=batch_size,
                        shuffle=True)

test_loader = DataLoader(test_dataset,
                         batch_size=1,
                         shuffle=False)

tensor([0.2200, 0.1955, 0.1733, 0.4113])
tensor([0.0150, 0.0074, 0.0150, 0.0147, 0.0074, 0.0147, 0.0074, 0.0074, 0.0147,
        0.0150, 0.0150, 0.0147, 0.0150, 0.0150, 0.0147, 0.0150, 0.0147, 0.0150,
        0.0074, 0.0147, 0.0147, 0.0074, 0.0074, 0.0150, 0.0074, 0.0150, 0.0147,
        0.0074, 0.0074, 0.0073, 0.0074, 0.0074, 0.0150, 0.0144, 0.0150, 0.0072,
        0.0074, 0.0074, 0.0150, 0.0150, 0.0147, 0.0147, 0.0147, 0.0074, 0.0074,
        0.0150, 0.0074, 0.0074, 0.0150, 0.0074, 0.0074, 0.0147, 0.0074, 0.0150,
        0.0150, 0.0147, 0.0147, 0.0073, 0.0074, 0.0147, 0.0147, 0.0073, 0.0074,
        0.0072, 0.0074, 0.0074, 0.0072, 0.0150, 0.0150, 0.0073, 0.0074, 0.0074,
        0.0074, 0.0147, 0.0150, 0.0074, 0.0074, 0.0150, 0.0147, 0.0150, 0.0147,
        0.0073, 0.0147, 0.0147, 0.0074, 0.0150, 0.0147, 0.0007])


In [7]:
from re import sub
class EfficientNetV2MultiHead(nn.Module):
    def __init__(self, num_super=4, num_sub=88, pretrained=True):
        super().__init__()
        weights = EfficientNet_V2_S_Weights.DEFAULT if pretrained else None
        pre_trained = efficientnet_v2_s(weights=weights)
        self.backbone = pre_trained.features
        self.avgpool = pre_trained.avgpool
        self.fc = nn.Linear(1280, 256)
        self.dropout = nn.Dropout(0.3)
        self.classifer_super = nn.Linear(256, num_super)
        self.classifer_sub = nn.Linear(256 + num_super, num_sub)

    def forward(self, x):
        x = self.backbone(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc(x))
        x = self.dropout(x)
        # print(x.size())
        super_out = self.classifer_super(x)
        # print(super_out)

        super_probs = F.softmax(super_out, dim=1)
        x_with_super = torch.cat([x, super_out], dim=1)

        sub_out = self.classifer_sub(x_with_super)
        sub_probs = F.softmax(sub_out, dim=1)
        return super_out, sub_out

class Trainer():
    def __init__(self, model, super_criterion, sub_criterion, optimizer, train_loader, val_loader, test_loader=None, device='cuda'):
        self.model = model
        self.super_criterion = super_criterion
        self.sub_criterion = sub_criterion
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader

    def train_epoch(self):
        running_loss = 0.0
        for i, data in enumerate(self.train_loader):
            inputs, super_labels, sub_labels = data[0].to(device), data[1].to(device), data[3].to(device)

            self.optimizer.zero_grad()
            super_outputs, sub_outputs = self.model(inputs)
            loss = self.super_criterion(super_outputs, super_labels) + self.sub_criterion(sub_outputs, sub_labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f'Training loss: {running_loss/i:.3f}')

    def validate_epoch(self):
        super_correct = 0
        sub_correct = 0
        total = 0
        running_loss = 0.0
        running_super_loss = 0.0
        running_sub_loss = 0.0
        with torch.no_grad():
            for i, data in enumerate(self.val_loader):
                inputs, super_labels, sub_labels = data[0].to(device), data[1].to(device), data[3].to(device)

                super_outputs, sub_outputs = self.model(inputs)
                super_loss = self.super_criterion(super_outputs, super_labels)
                sub_loss = self.sub_criterion(sub_outputs, sub_labels)
                loss = super_loss + sub_loss
                _, super_predicted = torch.max(super_outputs.data, 1)
                _, sub_predicted = torch.max(sub_outputs.data, 1)

                total += super_labels.size(0)
                super_correct += (super_predicted == super_labels).sum().item()
                sub_correct += (sub_predicted == sub_labels).sum().item()
                running_loss += loss.item()
                running_super_loss += super_loss.item()
                running_sub_loss += sub_loss.item()

        print(f'Validation loss: {running_loss/i:.3f}')
        print(f'Validation superclass loss: {running_super_loss/i:.3f}')
        print(f'Validation subclass loss: {running_sub_loss/i:.3f}')
        print(f'Validation superclass acc: {100 * super_correct / total:.2f} %')
        print(f'Validation subclass acc: {100 * sub_correct / total:.2f} %')

    def test(self, save_to_csv=False, return_predictions=False):
        # threshold = 0.4
        if not self.test_loader:
            raise NotImplementedError('test_loader not specified')

        # Evaluate on test set, in this simple demo no special care is taken for novel/unseen classes
        test_predictions = {'image': [], 'superclass_index': [], 'subclass_index': []}
        with torch.no_grad():
            for i, data in enumerate(self.test_loader):
                inputs, img_name = data[0].to(device), data[1]

                super_outputs, sub_outputs = self.model(inputs)
                super_probs = F.softmax(super_outputs, dim=1)
                sub_probs = F.softmax(sub_outputs, dim=1)

                super_max_probs, super_predicted = torch.max(super_probs, 1)
                sub_max_probs, sub_predicted = torch.max(sub_probs, 1)

                test_predictions['image'].append(img_name[0])
                test_predictions['superclass_index'].append(super_predicted.item())
                test_predictions['subclass_index'].append(sub_predicted.item())

        test_predictions = pd.DataFrame(data=test_predictions)

        if save_to_csv:
            test_predictions.to_csv('example_test_predictions.csv', index=False)

        if return_predictions:
            return test_predictions

In [8]:
# Init model and trainer
device = 'cuda'
num_super_classes = 4
num_sub_classes = 88
model = EfficientNetV2MultiHead(num_super=num_super_classes, num_sub=num_sub_classes, pretrained=True).to(device)
# print(model)

superclass_weights = superclass_weights.to(device)
subclass_weights = subclass_weights.to(device)
super_criterion = nn.CrossEntropyLoss(weight=superclass_weights)
sub_criterion = nn.CrossEntropyLoss(weight=subclass_weights)

optimizer = optim.Adam(model.parameters(), lr=1e-4)
trainer = Trainer(model, super_criterion, sub_criterion, optimizer, train_loader, val_loader, test_loader)

Downloading: "https://download.pytorch.org/models/efficientnet_v2_s-dd5fe13b.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_v2_s-dd5fe13b.pth
100%|██████████| 82.7M/82.7M [00:00<00:00, 237MB/s]


In [None]:
!ls data/train_images/

In [None]:
# Training loop
for param in model.backbone.parameters():
    param.requires_grad = False
last_blocks = list(model.backbone)[-2:]  # Unfreeze last 3 blocks
for block in last_blocks:
    for param in block.parameters():
        param.requires_grad = True

print(f'Start training:')

for epoch in range(20):
    print(f'Epoch {epoch+1}')
    trainer.train_epoch()
    trainer.validate_epoch()
    print('')

print('Finished Training')

Start training:
Epoch 1


In [42]:
test_predictions = trainer.test(save_to_csv=True, return_predictions=True)

In [43]:
# Quick script for evaluating generated csv files with ground truth

super_correct = 0
sub_correct = 0
seen_super_correct = 0
seen_sub_correct = 0
unseen_super_correct = 0
unseen_sub_correct = 0

total = 0
seen_super_total = 0
unseen_super_total = 0
seen_sub_total = 0
unseen_sub_total = 0

for i in range(100):
    super_pred = test_predictions['superclass_index'][i]
    sub_pred = test_predictions['subclass_index'][i]

    super_gt = test_ann_df['superclass_index'][i]
    sub_gt = test_ann_df['subclass_index'][i]

    # Total setting
    if super_pred == super_gt:
        super_correct += 1
    if sub_pred == sub_gt:
        sub_correct += 1
    total += 1

    # Unseen superclass setting
    if super_gt == 3:
        if super_pred == super_gt:
            unseen_super_correct += 1
        if sub_pred == sub_gt:
            unseen_sub_correct += 1
        unseen_super_total += 1
        unseen_sub_total += 1

    # Seen superclass, unseen subclass setting
    if super_gt != 3 and sub_gt == 87:
        if super_pred == super_gt:
            seen_super_correct += 1
        if sub_pred == sub_gt:
            unseen_sub_correct += 1
        seen_super_total += 1
        unseen_sub_total += 1

    # Seen superclass and subclass setting
    if super_gt != 3 and sub_gt != 87:
        if super_pred == super_gt:
            seen_super_correct += 1
        if sub_pred == sub_gt:
            seen_sub_correct += 1
        seen_super_total += 1
        seen_sub_total += 1

print('Superclass Accuracy')
print(f'Overall: {100*super_correct/total:.2f} %')
print(f'Seen: {100*seen_super_correct/seen_super_total:.2f} %')
print(f'Unseen: {100*unseen_super_correct/unseen_super_total:.2f} %')

print('\nSubclass Accuracy')
print(f'Overall: {100*sub_correct/total:.2f} %')
print(f'Seen: {100*seen_sub_correct/seen_sub_total:.2f} %')
print(f'Unseen: {100*unseen_sub_correct/unseen_sub_total:.2f} %')

Superclass Accuracy
Overall: 37.00 %
Seen: 52.86 %
Unseen: 0.00 %

Subclass Accuracy
Overall: 2.00 %
Seen: 1.43 %
Unseen: 3.33 %
