<a href="https://colab.research.google.com/github/yashwantherukulla/SnakeSpotter/blob/main/Snake_Bg_Forkthis24.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

 # Snake vs Background Classifier using PyTorch

 This notebook trains a binary classifier that distinguishes snake images from background images, using a pretrained VGG16 model with a modified classification head.

 ## Imports and Setup

In [None]:
import torch
import torchvision
import matplotlib.pyplot as plt
from torch import nn
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import numpy as np
import os
import zipfile
import shutil
from pathlib import Path
import requests
from typing import Dict, List, Tuple
from tqdm.auto import tqdm
from timeit import default_timer as timer
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, precision_score, recall_score, f1_score
import itertools
import random
from PIL import Image
import pandas as pd
import seaborn as sns

# Import torchinfo's `summary` helper, installing the package on first use.
# NOTE(review): `summary` is not called anywhere in the visible notebook.
try:
    from torchinfo import summary
except ImportError:
    print("[INFO] Couldn't find torchinfo... installing it.")
    # IPython shell escape: this line only works inside a notebook/IPython session.
    !pip install -q torchinfo
    from torchinfo import summary

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Seed PyTorch's global RNG.
# NOTE(review): Python's `random` and NumPy are not seeded here, so code that
# relies on them (e.g. `random.sample`) is not made reproducible by this call.
torch.manual_seed(4709471861038091579)

 ## Data Preparation

In [None]:
# Download and unpack the dataset once; later runs reuse the extracted folder.
data_path = Path("data/")
dataset_root = data_path / "Data"
if dataset_root.is_dir():
    print("Data directory exists")
else:
    print(f"Did not find {data_path} directory, creating one...")
    data_path.mkdir(parents=True, exist_ok=True)
    archive_path = data_path / "BaseDataset.zip"

    try:
        print("Downloading data...")
        # Stream to disk in chunks instead of holding the whole archive in
        # memory, bound the wait with a timeout, and fail loudly on an HTTP
        # error rather than saving an error page as a ".zip".
        with requests.get(
            "https://figshare.com/ndownloader/files/35784053",
            stream=True,
            timeout=60,
        ) as response:
            response.raise_for_status()
            with open(archive_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)

        with zipfile.ZipFile(archive_path, "r") as zip_ref:
            print("Unzipping data...")
            zip_ref.extractall(data_path)
    finally:
        # Remove the archive on success AND on failure so that a partial or
        # corrupt download is never left behind for the next run.
        archive_path.unlink(missing_ok=True)

train_dir = dataset_root / "train"
valid_dir = dataset_root / "valid"
test_dir = dataset_root / "test"

 ## Data Transforms and Loaders

In [None]:
target_size = (150, 150)
# ImageNet channel statistics. The pretrained VGG16 weights used by this
# notebook expect inputs scaled to [0, 1] and then normalized with these
# values (the previous all-zero placeholders were never applied, and a zero
# std could not be used for normalization anyway).
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# All three splits share the same deterministic preprocessing.
data_transforms = {
    split: transforms.Compose([
        transforms.Resize(target_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean.tolist(), std=std.tolist()),
    ])
    for split in ("train", "valid", "test")
}

batch_size = 32
# os.cpu_count() may return None; DataLoader needs an int (0 = load in the main process).
num_workers = os.cpu_count() or 0
# Pinned host memory only helps host-to-GPU copies; skip it on CPU-only machines.
pin_memory = torch.cuda.is_available()

train_data = datasets.ImageFolder(train_dir, transform=data_transforms["train"])
valid_data = datasets.ImageFolder(valid_dir, transform=data_transforms["valid"])
test_data = datasets.ImageFolder(test_dir, transform=data_transforms["test"])

class_names = train_data.classes

# Only the training split is shuffled; validation/test keep a fixed order.
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
valid_dataloader = DataLoader(valid_data, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)

print(f"Number of training samples: {len(train_data)}")
print(f"Number of validation samples: {len(valid_data)}")
print(f"Number of test samples: {len(test_data)}")
print(f"Classes: {class_names}")


 ## Visualize Sample Image

In [None]:
# Pull one batch from the training loader and display its first image.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0]
label = int(train_labels[0])

# If the dataset's transform pipeline contains a Normalize step, invert it so
# the picture is shown in its natural colours; otherwise the tensor is
# already in [0, 1] and is displayed as-is.
sample_pipeline = getattr(train_dataloader.dataset, "transform", None)
sample_normalize = next(
    (
        step
        for step in getattr(sample_pipeline, "transforms", [sample_pipeline])
        if isinstance(step, transforms.Normalize)
    ),
    None,
)
if sample_normalize is not None:
    channel_std = torch.as_tensor(sample_normalize.std, dtype=img.dtype).view(-1, 1, 1)
    channel_mean = torch.as_tensor(sample_normalize.mean, dtype=img.dtype).view(-1, 1, 1)
    img = img * channel_std + channel_mean

plt.figure(figsize=(8, 8))
# imshow wants (H, W, C) floats in [0, 1]; clamp guards against rounding overshoot.
plt.imshow(img.permute(1, 2, 0).clamp(0, 1))
plt.title(f"Label: {class_names[label]}")
plt.axis('off')
plt.show()

 ## Model Setup

In [None]:
# Start from ImageNet-pretrained VGG16 and freeze every pretrained parameter.
weights = torchvision.models.VGG16_Weights.DEFAULT
model = torchvision.models.vgg16(weights=weights)

for param in model.parameters():
    param.requires_grad = False

# Replace only the final 1000-way layer of VGG16's classifier with a small
# binary head. Replacing the whole classifier with a Linear(4096, ...) does not
# work: the classifier's input is the flattened 512x7x7 feature map (25088
# values), and only its last layer takes 4096 inputs.
# NOTE(review): checkpoints saved from this model must be reloaded into the
# same layout (head stored at classifier[6]).
head_in_features = model.classifier[-1].in_features
model.classifier[-1] = nn.Sequential(
    nn.Linear(head_in_features, 256),
    nn.ReLU(),
    # p is the probability of *dropping* a unit; the previous 0.98 / 0.75
    # removed almost every activation during training.
    nn.Dropout(0.5),
    nn.Linear(256, 64),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(64, 1),
    # Single output squashed to (0, 1): interpreted as P(class index 1).
    nn.Sigmoid(),
)

# Move the complete model (frozen backbone + new trainable head) in one go.
model = model.to(device)

In [None]:
# The model ends in a Sigmoid and emits one probability per image, so the
# matching criterion is binary cross-entropy on probabilities.
# CrossEntropyLoss expects one logit per class and does not fit a
# single-output head.
loss_fn = nn.BCELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

 ## Training and Validation Functions

In [None]:
def train_step(model: nn.Module,
               dataloader: torch.utils.data.DataLoader,
               loss_fn: nn.Module,
               optimizer: torch.optim.Optimizer,
               device: torch.device) -> Tuple[float, float]:
    """Train ``model`` for one pass over ``dataloader``.

    The model is expected to output one probability per sample (shape
    ``(N, 1)`` or ``(N,)``) and ``loss_fn`` to accept probabilities together
    with float targets of the same shape (e.g. ``nn.BCELoss``).

    Args:
        model: Network to optimise; switched to training mode.
        dataloader: Iterable yielding ``(inputs, integer_labels)`` batches.
        loss_fn: Criterion called as ``loss_fn(probabilities, targets)``.
        optimizer: Optimizer that updates ``model``'s parameters.
        device: Device every batch is moved to before the forward pass.

    Returns:
        ``(mean_loss, mean_accuracy)`` averaged over the batches seen, or
        ``(0.0, 0.0)`` when the dataloader yields no batches.
    """
    model.train()
    train_loss, train_acc = 0.0, 0.0
    num_batches = 0

    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        # Flatten (N, 1) -> (N,) so predictions line up with the label vector.
        y_prob = model(X).reshape(-1)
        loss = loss_fn(y_prob, y.to(y_prob.dtype))

        # Gradients accumulate across backward() calls unless cleared first.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        # Threshold probabilities at 0.5 before comparing with the labels.
        y_label = torch.round(y_prob.detach())
        train_acc += (y_label == y).sum().item() / len(y)
        num_batches += 1

    if num_batches == 0:
        return 0.0, 0.0
    return train_loss / num_batches, train_acc / num_batches

def valid_step(model: nn.Module,
               dataloader: torch.utils.data.DataLoader,
               loss_fn: nn.Module,
               device: torch.device) -> Tuple[float, float]:
    """Evaluate ``model`` on ``dataloader`` without updating any weights.

    The model is expected to output one probability per sample (shape
    ``(N, 1)`` or ``(N,)``) and ``loss_fn`` to accept probabilities together
    with float targets of the same shape (e.g. ``nn.BCELoss``).

    Args:
        model: Network to evaluate; switched to evaluation mode.
        dataloader: Iterable yielding ``(inputs, integer_labels)`` batches.
        loss_fn: Criterion called as ``loss_fn(probabilities, targets)``.
        device: Device every batch is moved to before the forward pass.

    Returns:
        ``(mean_loss, mean_accuracy)`` averaged over the batches seen, or
        ``(0.0, 0.0)`` when the dataloader yields no batches.
    """
    # Evaluation mode disables dropout so the metrics are deterministic.
    model.eval()
    valid_loss, valid_acc = 0.0, 0.0
    num_batches = 0

    with torch.inference_mode():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)

            # Flatten (N, 1) -> (N,) so predictions line up with the label vector.
            valid_prob = model(X).reshape(-1)
            valid_loss += loss_fn(valid_prob, y.to(valid_prob.dtype)).item()

            # Threshold probabilities at 0.5 before comparing with the labels.
            valid_label = torch.round(valid_prob)
            valid_acc += (valid_label == y).sum().item() / len(y)
            num_batches += 1

    if num_batches == 0:
        return 0.0, 0.0
    return valid_loss / num_batches, valid_acc / num_batches


In [None]:
def train(model: nn.Module,
          train_dataloader: torch.utils.data.DataLoader,
          valid_dataloader: torch.utils.data.DataLoader,
          optimizer: torch.optim.Optimizer,
          loss_fn: nn.Module,
          epochs: int,
          device: torch.device) -> Dict[str, List]:
    """Run the train/validate cycle ``epochs`` times and collect the metrics.

    Each epoch calls ``train_step`` on ``train_dataloader`` and then
    ``valid_step`` on ``valid_dataloader``, prints a one-line summary and
    appends the four returned values to the history.

    Returns:
        Dict with keys ``train_loss``, ``train_acc``, ``valid_loss`` and
        ``valid_acc``; each maps to a list holding one entry per epoch.
    """
    metric_names = ("train_loss", "train_acc", "valid_loss", "valid_acc")
    history = {name: [] for name in metric_names}

    for epoch_idx in tqdm(range(epochs)):
        # Training runs first, then validation; the tuple order matches metric_names.
        epoch_metrics = (
            *train_step(model, train_dataloader, loss_fn, optimizer, device),
            *valid_step(model, valid_dataloader, loss_fn, device),
        )

        fields = [f"Epoch: {epoch_idx+1}"]
        fields.extend(
            f"{name}: {value:.4f}"
            for name, value in zip(metric_names, epoch_metrics)
        )
        print(" | ".join(fields))

        for name, value in zip(metric_names, epoch_metrics):
            history[name].append(value)

    return history


 ## Model Training

In [None]:
# Train for a fixed number of epochs and report the elapsed wall-clock time.
epochs = 15
start_time = timer()
# `results` holds the per-epoch loss/accuracy lists returned by train().
results = train(model, train_dataloader, valid_dataloader, optimizer, loss_fn, epochs, device)
end_time = timer()
print(f"[INFO] Total training time: {end_time-start_time:.3f} seconds")

 ## Loading Pretrained Weights


In [None]:
def load_with_pretrained_weights(PATH: str):
  """Rebuild the snake/background VGG16 and restore its weights from ``PATH``.

  Args:
    PATH: Path to a ``state_dict`` file written with ``torch.save``.

  Returns:
    The model on the notebook's ``device``, in evaluation mode.

  Raises:
    RuntimeError: If the checkpoint's keys or shapes do not match the model
      built here (``load_state_dict`` is strict by default).
  """
  # No ImageNet weights are fetched: the strict load below overwrites every
  # parameter, so downloading them first would be wasted work.
  model = torchvision.models.vgg16(weights=None)
  for param in model.parameters():
    param.requires_grad = False

  # VGG16 has no `fc` attribute; assigning one would only attach an unused
  # module and leave the real classifier untouched. The binary head replaces
  # the last classifier layer, the only one that takes 4096 inputs.
  # NOTE(review): the checkpoint must have been saved from a model with this
  # same layout (head stored at classifier[6]) - confirm against the file.
  head_in_features = model.classifier[-1].in_features
  model.classifier[-1] = nn.Sequential(
      nn.Linear(head_in_features, 256),
      nn.ReLU(),
      # Dropout holds no parameters, so its rate does not affect loading and
      # it is inactive in evaluation mode.
      nn.Dropout(0.5),
      nn.Linear(256, 64),
      nn.ReLU(),
      nn.Dropout(0.5),
      nn.Linear(64, 1),
      nn.Sigmoid()
  )

  # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
  model.load_state_dict(torch.load(PATH, map_location=device))
  model = model.to(device)
  # Inference is the use case here: disable dropout.
  model.eval()
  return model

In [None]:
# Replace the in-memory model with one restored from the checkpoint file.
# NOTE(review): this is the same path the save cell under "Misc" writes to,
# so the file must already exist when this cell runs.
model = load_with_pretrained_weights("/content/updated-snake-bg-model-weights-v2-ft.pth")

 ## Visualize Results

In [None]:
def plot_loss_curves(results: Dict[str, List]):
    """Plot training and validation loss and accuracy side by side.

    Args:
        results: Mapping with the keys ``train_loss``, ``valid_loss``,
            ``train_acc`` and ``valid_acc``; each value is a sequence with
            one entry per epoch.
    """
    plt.figure(figsize=(15, 7))

    # Left panel: loss curves.
    plt.subplot(1, 2, 1)
    plt.plot(results["train_loss"], label="train_loss")
    plt.plot(results["valid_loss"], label="valid_loss")
    plt.title("Loss")
    plt.xlabel("Epochs")
    plt.legend()

    # Right panel: accuracy curves.
    plt.subplot(1, 2, 2)
    plt.plot(results["train_acc"], label="train_acc")
    plt.plot(results["valid_acc"], label="valid_acc")
    plt.title("Accuracy")
    plt.xlabel("Epochs")
    plt.legend()

    plt.tight_layout()
    # `plt.show` without parentheses only references the function and never
    # renders the figure.
    plt.show()

In [None]:
# Plot the per-epoch history returned by train().
# NOTE(review): the evaluation cell later rebinds `results` to the test
# metrics dict, which lacks the keys this function reads - run this cell first.
plot_loss_curves(results)

In [None]:
def pred_and_plot_image(img_path: str,
                        model: nn.Module =model,
                        class_names: List[str] = class_names,
                        img_size: Tuple[int, int] = (224, 224),
                        transform: torchvision.transforms = None,
                        device: torch.device=device):
  """Classify one image file and show it with the prediction in the title.

  Args:
    img_path: Path to the image; the name of its parent directory is shown
      as the true label.
    model: Network returning a single probability for a batch of one image.
    class_names: Class names indexed by the rounded probability (0 or 1).
    img_size: Size the image is resized to when ``transform`` is not given.
    transform: Preprocessing applied to the PIL image. When ``None``, the
      notebook's test-time pipeline is reused with its resize step replaced
      by ``Resize(img_size)``.
    device: Device used for the forward pass.
  """
  from PIL import Image

  # Force three channels so grayscale / RGBA / palette files fit the model.
  img = Image.open(img_path).convert("RGB")

  if transform is not None:
    img_transform = transform
  else:
    # Mirror the preprocessing of the test split so inference sees inputs
    # prepared the same way; only the target size comes from `img_size`.
    test_steps = [
        step for step in data_transforms["test"].transforms
        if not isinstance(step, transforms.Resize)
    ]
    img_transform = transforms.Compose([transforms.Resize(img_size), *test_steps])

  model.to(device)
  # Evaluation mode disables dropout so the prediction is deterministic.
  model.eval()
  with torch.inference_mode():
    # Add the batch dimension and run on the same device as the model.
    trans_img = img_transform(img).unsqueeze(dim=0).to(device)
    target_img_prob = model(trans_img)

  # Works for both str and Path inputs and any path separator.
  actual_label = Path(img_path).parent.name
  target_img_label = int(torch.round(target_img_prob).item())
  target_img_prob = target_img_prob.item()

  plt.figure()
  plt.imshow(img)
  plt.title(f"True: {actual_label} | Pred: {class_names[target_img_label]} | Prob: {target_img_prob:.3f}")
  plt.axis(False)

In [None]:
def pred_and_plot_images_dataloader(dataloader: DataLoader = test_dataloader,
                         model: nn.Module = model,
                         class_names: List[str] = class_names,
                         img_size: Tuple[int, int] = (224, 224),
                         device: torch.device = device,
                         num_images: int = 5):
    """Predict and display randomly chosen samples from a dataloader's dataset.

    Args:
        dataloader: Loader whose ``dataset`` yields ``(image_tensor, label)``.
        model: Network returning a single probability for a batch of one image.
        class_names: Class names indexed by label (0 or 1).
        img_size: Unused; kept so existing calls keep working.
        device: Device used for the forward pass.
        num_images: Number of samples to show; capped at the dataset size.
    """
    model.to(device)
    model.eval()

    dataset = dataloader.dataset
    total_images = len(dataset)

    # random.sample raises ValueError when asked for more items than exist.
    sample_count = max(0, min(num_images, total_images))
    random_indices = random.sample(range(total_images), sample_count)

    # Only undo normalization that the dataset actually applies. Blindly
    # inverting ImageNet statistics distorts images that were never normalized.
    pipeline = getattr(dataset, "transform", None)
    normalize = next(
        (
            step
            for step in getattr(pipeline, "transforms", [pipeline])
            if isinstance(step, transforms.Normalize)
        ),
        None,
    )

    with torch.inference_mode():
        for idx in random_indices:
            img, actual_label = dataset[idx]

            target_img_prob = model(img.unsqueeze(0).to(device))
            target_img_label = int(torch.round(target_img_prob).item())
            target_img_prob = target_img_prob.item()

            display_img = img.to("cpu")
            if normalize is not None:
                channel_std = torch.as_tensor(normalize.std, dtype=display_img.dtype).view(-1, 1, 1)
                channel_mean = torch.as_tensor(normalize.mean, dtype=display_img.dtype).view(-1, 1, 1)
                display_img = display_img * channel_std + channel_mean

            plt.figure()
            # imshow wants (H, W, C) floats in [0, 1].
            plt.imshow(display_img.permute(1, 2, 0).clamp(0, 1))
            # For samples whose true class is "background" the complement of
            # the model output is reported.
            if class_names[actual_label] == "background":
                target_img_prob = 1 - target_img_prob
            plt.title(f"True: {class_names[actual_label]} | Pred: {class_names[target_img_label]} | Prob: {target_img_prob:.3f}")
            plt.axis(False)
            plt.show()

In [None]:
# Predict one test image; the parent folder name ("snake") is displayed as the true label.
pred_and_plot_image("/content/data/Data/test/snake/180 (649).JPG")

In [None]:
# Show predictions for random samples using the function's defaults
# (test dataloader, current model, 5 images).
pred_and_plot_images_dataloader()

## Evaluation Metrics

In [None]:
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, precision_score, recall_score, f1_score
import numpy as np


def test_step(model: nn.Module,
              dataloader: torch.utils.data.DataLoader,
              loss_fn: nn.Module,
              device: torch.device) -> Tuple[float, float, np.ndarray, np.ndarray]:
  test_loss = 0
  all_predictions = []
  all_labels = []

  with torch.inference_mode():
    for batch, (X, y) in enumerate(dataloader):
      test_pred = model(X)
      loss = loss_fn(test_pred, y)
      test_loss += loss.item()


      all_predictions.extend(y_pred)
      all_labels.extend(y)

  test_loss /= len(dataloader)

  return test_loss, all_predictions

def test(model: nn.Module,
         dataloader: torch.utils.data.DataLoader,
         loss_fn: nn.Module,
         device: torch.device) -> Dict[str, float]:
  """Run ``test_step`` and summarise its output with scikit-learn metrics.

  Returns:
    Dict with the keys ``test_loss``, ``accuracy``, ``precision``, ``recall``,
    ``f1_score``, ``classification_report`` (the ``output_dict=True`` form)
    and ``confusion_matrix``.
  """
  test_loss, predicted, actual = test_step(model, dataloader, loss_fn, device)

  # Entries are evaluated top to bottom, i.e. in the order the metrics are listed.
  return {
    "test_loss": test_loss,
    "accuracy": accuracy_score(actual, predicted),
    "precision": precision_score(actual, predicted, zero_division=0),
    "recall": recall_score(actual, predicted, zero_division=0),
    "f1_score": f1_score(actual, predicted, zero_division=0),
    "classification_report": classification_report(actual, predicted, output_dict=True),
    "confusion_matrix": confusion_matrix(actual, predicted),
  }

def visualize_results(results: Dict[str, float]):
  """Draw the confusion matrix and classification report, then print scalars.

  Args:
    results: Mapping with the keys ``confusion_matrix``,
      ``classification_report`` (dict form), ``test_loss``, ``accuracy``,
      ``precision``, ``recall`` and ``f1_score``. It is not modified.
  """
  fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))

  sns.heatmap(results['confusion_matrix'], annot=True, fmt='d', cmap='Blues', ax=ax1)
  ax1.set_title('Confusion Matrix')
  ax1.set_xlabel('Predicted')
  ax1.set_ylabel('Actual')

  # Build a filtered copy instead of deleting from the caller's dict, so the
  # function can be called again on the same results. The scalar 'accuracy'
  # entry is left out because it does not fit the per-class table layout.
  class_report = {
      key: value
      for key, value in results['classification_report'].items()
      if key != 'accuracy'
  }

  # Rows become classes/averages, columns become metrics; 'support' is a
  # count rather than a score, so it is dropped from the heatmap.
  report_df = pd.DataFrame(class_report).T
  report_df = report_df.drop('support', axis=1)

  sns.heatmap(report_df, annot=True, cmap='YlGnBu', ax=ax2)
  ax2.set_title('Classification Report')
  ax2.set_xlabel('Metrics')
  ax2.set_ylabel('Classes')

  plt.tight_layout()
  plt.show()

  print(f"Test Loss: {results['test_loss']:.4f}")
  print(f"Accuracy: {results['accuracy']:.4f}")
  print(f"Precision: {results['precision']:.4f}")
  print(f"Recall: {results['recall']:.4f}")
  print(f"F1 Score: {results['f1_score']:.4f}")

In [None]:
# Evaluate on the test split and visualise the metrics.
# NOTE(review): this rebinds `results`, replacing the training history that
# plot_loss_curves() reads; consider a separate name if both are needed.
results = test(model, test_dataloader, loss_fn, device)
visualize_results(results)

## Misc


In [None]:
# Display the seed of PyTorch's default generator; the trailing comment
# records the value passed to torch.manual_seed() at the top of the notebook.
torch.random.initial_seed() #4709471861038091579

In [None]:
# Persist the current model's parameters (state_dict only, not the module
# object) to the path that load_with_pretrained_weights() is called with.
PATH = "/content/updated-snake-bg-model-weights-v2-ft.pth"
torch.save(model.state_dict(), PATH)