# Dependencies

In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pandas as pd
import os
import numpy as np
import matplotlib.pyplot as plt

# Custom Dataset

In [2]:
class SimpleVeinDataset(Dataset):
    """Paired grayscale image / mask dataset driven by a CSV file.

    Every CSV row names one sample through its ``ImageID`` column. The image
    is read from ``<image_dir>/<ImageID>.png`` and the mask from
    ``<mask_dir>/<ImageID>.jpg``.
    """

    def __init__(self, csv_file, image_dir, mask_dir):
        """
        Args:
            csv_file (str): Path to CSV with BatID and ImageID columns
            image_dir (str): Path to image folder
            mask_dir (str): Path to mask folder
        """
        self.data = pd.read_csv(csv_file)
        self.image_dir = image_dir
        self.mask_dir = mask_dir

    def __len__(self):
        """Return the number of samples, i.e. the number of CSV rows."""
        return len(self.data)

    def _sample_paths(self, idx):
        """Return ``(image_path, mask_path)`` for the sample at position ``idx``."""
        # Samplers may hand over a 0-dim tensor; pandas wants a plain integer.
        if torch.is_tensor(idx):
            idx = int(idx)

        # Index the column, not the row: ``iloc[idx]`` on the frame builds a
        # row Series with one common dtype, so an integer ImageID sitting next
        # to a float column would come back as ``7.0`` and produce "7.0.png".
        image_id = self.data['ImageID'].iloc[idx]

        img_path = os.path.join(self.image_dir, f"{image_id}.png")
        mask_path = os.path.join(self.mask_dir, f"{image_id}.jpg")
        return img_path, mask_path

    @staticmethod
    def _load_grayscale(path):
        """Load ``path`` as a float tensor with a leading channel dimension.

        The picture is converted to 8-bit grayscale (mode ``'L'``), divided by
        255 and given a channel axis in front.
        """
        # The context manager closes the underlying file deterministically.
        with Image.open(path) as picture:
            pixels = np.array(picture.convert('L'))
        return (torch.from_numpy(pixels).float() / 255.0).unsqueeze(0)

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` tensor pair for sample ``idx``.

        Args:
            idx (int | torch.Tensor): Position of the sample in the CSV.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Image and mask, each produced
            by ``_load_grayscale``.
        """
        img_path, mask_path = self._sample_paths(idx)
        image = self._load_grayscale(img_path)
        mask = self._load_grayscale(mask_path)
        return image, mask

# Data Loading

In [3]:
# Build the dataset from the project-relative CSV and image/mask folders.
dataset = SimpleVeinDataset(
    '../Dataset/dataset.csv',  # csv_file
    '../Dataset/Images',       # image_dir
    '../Dataset/Masks',        # mask_dir
)

# Wrap the dataset in a DataLoader: one sample per batch, in CSV order.
# TODO: switch to batch_size=2 and shuffle=True once the training loop is built.
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

# Validating Dataset

In [None]:
def group_and_display_all_bat_images(csv_file, image_dir, mask_dir):
    """
    Display every image/mask pair, one figure per bat.

    The CSV rows are grouped by their ``BatID`` column. For each bat a figure
    is shown with one subplot row per image: the image on the left and its
    mask on the right.

    Args:
        csv_file (str): Path to CSV with BatID and ImageID columns
        image_dir (str): Path to image folder
        mask_dir (str): Path to mask folder
    """
    # The dataset already parses the CSV; reuse its frame instead of reading
    # the file a second time.
    dataset = SimpleVeinDataset(csv_file, image_dir, mask_dir)
    df = dataset.data

    for bat_id, bat_images in df.groupby('BatID'):
        num_images = len(bat_images)

        # squeeze=False always yields a 2-D axes array, so a bat with a single
        # image needs no special handling.
        fig, axes = plt.subplots(
            num_images, 2, figsize=(10, 5 * num_images), squeeze=False
        )
        fig.suptitle(f'Bat ID: {bat_id}', fontsize=16)

        # Walk the ImageID column rather than iterrows(): the column keeps its
        # dtype (no 7 -> 7.0 upcast in titles). The frame comes straight from
        # read_csv with a default index, so each index label is the row's
        # position in the dataset and no per-image search is needed.
        for row, (dataset_idx, image_id) in enumerate(bat_images['ImageID'].items()):
            image, mask = dataset[dataset_idx]

            panels = (
                (image, f'Image {image_id}'),
                (mask, f'Mask {image_id}'),
            )
            for ax, (tensor, title) in zip(axes[row], panels):
                ax.imshow(tensor.squeeze(), cmap='gray')
                ax.set_title(title)
                ax.axis('off')

        plt.tight_layout()
        plt.show()

# Example usage:
"""
group_and_display_all_bat_images(
    csv_file='your_data.csv',
    image_dir='path/to/Images',
    mask_dir='path/to/Masks'
)
"""

def analyze_bat_distribution(csv_file):
    """
    Summarise how many images each bat contributes.

    Prints overall statistics, shows a histogram of the per-bat image counts
    and finally prints the count for every individual bat.

    Args:
        csv_file (str): Path to CSV with BatID and ImageID columns
    """
    records = pd.read_csv(csv_file)

    # Number of rows (images) per BatID
    per_bat = records.groupby('BatID').size()

    # Headline statistics
    summary = (
        "\nBat Image Distribution:",
        f"Total number of bats: {len(per_bat)}",
        f"Total number of images: {len(records)}",
        f"Average images per bat: {per_bat.mean():.2f}",
        f"Min images per bat: {per_bat.min()}",
        f"Max images per bat: {per_bat.max()}",
    )
    for line in summary:
        print(line)

    # Histogram of the per-bat counts
    _, ax = plt.subplots(figsize=(10, 5))
    ax.hist(per_bat, bins='auto')
    ax.set_title('Distribution of Images per Bat')
    ax.set_xlabel('Number of Images')
    ax.set_ylabel('Number of Bats')
    plt.show()

    # One line per bat
    print("\nDetailed image counts per bat:")
    for bat_id, n_images in per_bat.items():
        print(f"Bat {bat_id}: {n_images} images")

"""
**Uncomment this to view all the wing images for each bat!**
group_and_display_all_bat_images(
    csv_file='../Dataset/dataset.csv',
    image_dir='../Dataset/Images',
    mask_dir='../Dataset/Masks'
)
"""

"\ngroup_and_display_all_bat_images(\n    csv_file='../Dataset/dataset.csv',\n    image_dir='../Dataset/Images',\n    mask_dir='../Dataset/Masks'\n)\n"