<a href="https://colab.research.google.com/github/zaviruuu/Naga--ML-Based-Snake-Identifier-for-Sri-Lanka-/blob/snake_identification_model/Snake_Identification_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## NĀGA - SNAKE IDENTIFICATION MODEL

In [None]:
# Mount Google Drive at /content/drive; every dataset/output path below lives under it.
from google.colab import drive

drive.mount(mountpoint="/content/drive")

In [None]:
# Project folder on Google Drive. "Dataset" holds one sub-folder per class;
# "Output" receives the train/val/test copies made by the split cell.
PROJECT_DIR = "/content/drive/MyDrive/DSGP_Group_32/NĀGA/Snake Identification Model"
SRC_DIR = PROJECT_DIR + "/Dataset"
OUT_DIR = PROJECT_DIR + "/Output"

In [None]:
#Split ratios (simple + standard)
import math

TRAIN_RATIO = 0.70
VAL_RATIO   = 0.15
TEST_RATIO  = 0.15  # not read by the split cell: test receives whatever remains after train + val

# The split cell only uses TRAIN_RATIO and VAL_RATIO, so guard that the three
# ratios agree; otherwise TEST_RATIO could silently disagree with the real test share.
if not math.isclose(TRAIN_RATIO + VAL_RATIO + TEST_RATIO, 1.0):
    raise ValueError(
        f"Split ratios must sum to 1.0, got "
        f"{TRAIN_RATIO} + {VAL_RATIO} + {TEST_RATIO} = {TRAIN_RATIO + VAL_RATIO + TEST_RATIO}"
    )

#Image settings
IMG_SIZE = (224, 224)  # (height, width) target size for the model input
BATCH_SIZE = 16
EPOCHS = 15
SEED = 42  # seeds the shuffle in the split cell

In [None]:
##EDA
#Class distribution
from pathlib import Path
import matplotlib.pyplot as plt

# Reuse the dataset path from the config cell instead of repeating the literal,
# so the EDA always describes the same folder that the split cell copies from.
DATA_DIR = SRC_DIR

# One bar per class sub-folder; only .jpg/.jpeg/.png files are counted.
classes = sorted(d.name for d in Path(DATA_DIR).iterdir() if d.is_dir())
counts = [
    sum(1 for p in Path(DATA_DIR, c).glob("*") if p.suffix.lower() in (".jpg", ".jpeg", ".png"))
    for c in classes
]

plt.figure()
plt.bar(classes, counts)
plt.xticks(rotation=45, ha="right")
plt.title("Class Distribution (Images per Class)")
plt.xlabel("Class")
plt.ylabel("Count")
plt.tight_layout()
plt.show()

In [None]:
#Sample images
import random
from pathlib import Path

import matplotlib.pyplot as plt
from PIL import Image


def show_samples(data_dir, classes, n_per_class=3):
    """Plot up to ``n_per_class`` randomly chosen images from each class folder.

    The grid has one row per class and one column per sample. A class with
    fewer than ``n_per_class`` images leaves the rest of its row empty rather
    than pushing later classes' images into the wrong row.

    Args:
        data_dir: Root folder containing one sub-folder per class.
        classes: Class folder names to plot, in row order.
        n_per_class: Maximum number of images shown per class.
    """
    exts = {".jpg", ".jpeg", ".png"}
    n_rows = len(classes)
    plt.figure(figsize=(n_per_class * 3, n_rows * 3))
    for row, c in enumerate(classes):
        imgs = [p for p in Path(data_dir, c).glob("*") if p.suffix.lower() in exts]
        pick = random.sample(imgs, min(n_per_class, len(imgs)))
        for col, p in enumerate(pick):
            # Index from (row, col) rather than a running counter so short
            # classes do not shift the remaining classes out of their rows.
            plt.subplot(n_rows, n_per_class, row * n_per_class + col + 1)
            # convert() returns a new in-memory image, so the file can be closed here.
            with Image.open(p) as im:
                img = im.convert("RGB")
            plt.imshow(img)
            plt.axis("off")
            plt.title(c)
    plt.tight_layout()
    plt.show()


show_samples(DATA_DIR, classes, n_per_class=3)

In [None]:
#Brightness distribution
import random
from pathlib import Path

import numpy as np
from PIL import Image

# Every .jpg/.jpeg/.png under each class folder, in class order.
all_imgs = [
    p
    for c in classes
    for p in Path(DATA_DIR, c).glob("*")
    if p.suffix.lower() in [".jpg", ".jpeg", ".png"]
]

# At most 400 random images, to keep this cell quick on large datasets.
sample = random.sample(all_imgs, min(400, len(all_imgs)))

# Mean grayscale ("L" mode) intensity of each sampled image.
brightness = []
for path in sample:
    with Image.open(path) as im:
        brightness.append(np.asarray(im.convert("L")).mean())

fig, ax = plt.subplots()
ax.hist(brightness, bins=30)
ax.set_title("Brightness Distribution (Sample)")
ax.set_xlabel("Mean brightness (0–255)")
ax.set_ylabel("Count")
plt.tight_layout()
plt.show()

In [None]:
##Split train/val/test
import os
import random
import shutil
from pathlib import Path

random.seed(SEED)

# Must match the extensions the EDA cells count; the previous {".jpg"} silently
# dropped every .jpeg/.png image from all three splits.
IMG_EXTS = {".jpg", ".jpeg", ".png"}

src_root = Path(SRC_DIR)
out_root = Path(OUT_DIR)

#Detect classes (folder names). Sorted so the class order, and therefore the
#seeded shuffle below, does not depend on the filesystem's listing order.
classes = sorted(d.name for d in src_root.iterdir() if d.is_dir())
print("Classes found:", classes)

#Create output folder structure. Split folders left by an earlier run are
#removed first; copying into them would keep stale files and could put the
#same image into both train and test.
for split in ["train", "val", "test"]:
    split_dir = out_root / split
    if split_dir.exists():
        shutil.rmtree(split_dir)
    for c in classes:
        (split_dir / c).mkdir(parents=True, exist_ok=True)

#Copy files into splits
for c in classes:
    # Sort before shuffling so the same SEED always yields the same split.
    files = sorted(
        f for f in (src_root / c).glob("*")
        if f.is_file() and f.suffix.lower() in IMG_EXTS
    )
    random.shuffle(files)

    n = len(files)
    n_train = int(n * TRAIN_RATIO)
    n_val   = int(n * VAL_RATIO)

    # Test takes the remainder, so any rounding leftovers land in test.
    train_files = files[:n_train]
    val_files   = files[n_train:n_train + n_val]
    test_files  = files[n_train + n_val:]

    for split, split_files in (("train", train_files), ("val", val_files), ("test", test_files)):
        for f in split_files:
            shutil.copy2(f, out_root / split / c / f.name)

    print(f"{c}: total={n} | train={len(train_files)} | val={len(val_files)} | test={len(test_files)}")

print("\nSplit complete ->", OUT_DIR)

In [None]:
#Load data (with simple preprocessing)
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Multiplier applied to every pixel by all generators below
# (maps 0-255 values into [0, 1], assuming 8-bit images).
PIXEL_SCALE = 1.0 / 255

#Train: rescale + light augmentation
train_gen = ImageDataGenerator(
    rescale=PIXEL_SCALE,
    rotation_range=10,     # random rotations within +/-10 degrees
    zoom_range=0.1,        # random zoom in [0.9, 1.1]
    horizontal_flip=True,  # random left-right mirroring
)

#Val/Test: rescale only (no augmentation), so evaluation sees unmodified images
val_gen, test_gen = (ImageDataGenerator(rescale=PIXEL_SCALE) for _ in range(2))