In [7]:
# Pick the first CUDA GPU when one is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cpu


In [37]:
class Birddataset(Dataset):
    """
    Custom dataset class for bird image classification.

    The image directory is expected to contain one subdirectory per class.
    For every allowed class, a fixed number of images is held out for the
    test split and the rest go to the training split. Images found in a
    subdirectory named "unlabeled" are always added to the training split.
    """

    # Name of the subdirectory whose images carry no class label.
    UNLABELED_CLASS = "unlabeled"
    # Class index returned for unlabeled images when "unlabeled" is not
    # itself listed in allowed_classes.
    UNLABELED_CLASS_ID = -1

    def __init__(self, image_dir: str, allowed_classes: List[str], dataset_type: str = "train",
                 do_transform: bool = True, test_samples_per_class: int = 3):
        """
        Scan the image directory and build the train/test sample lists.

        Args:
            image_dir (str): Path to the root directory containing class subdirectories with images.
            allowed_classes (List[str]): List of allowed class names to include.
            dataset_type (str): "train" selects the training split; any other value selects the test split.
            do_transform (bool): Whether to apply transformations to images.
            test_samples_per_class (int): Number of images per labelled class held out for the test split.

        Attributes:
            train_samples (List[Tuple[str, str]]): List of training samples as (image_path, class_name) tuples.
            test_samples (List[Tuple[str, str]]): List of test samples as (image_path, class_name) tuples.
            transform: Transformations to be applied to images.

        Raises:
            ValueError: If test_samples_per_class is negative.
        """
        super().__init__()
        if test_samples_per_class < 0:
            raise ValueError("test_samples_per_class must be non-negative")

        self.image_dir = image_dir
        self.allowed_classes = allowed_classes
        self.dataset_type = dataset_type
        self.do_transform = do_transform
        self.test_samples_per_class = test_samples_per_class

        # Predefined image transformations (Normalization for ImageNet-pretrained models)
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        # Initialize empty lists for train and test samples
        self.train_samples = []
        self.test_samples = []

        # Preload file paths in parallel for faster processing
        self._load_samples()

    def _load_samples(self) -> None:
        """
        Loads image file paths and splits them into training and test sets.
        Uses multithreading for efficient file scanning.

        Each labelled class is shuffled with a fixed seed and its last
        `test_samples_per_class` images become test samples. A class with
        no more images than that contributes only test samples. Empty
        class directories are skipped.
        """
        allowed = set(self.allowed_classes)
        holdout = self.test_samples_per_class

        with ThreadPoolExecutor() as executor:
            pending = []

            # Sorted so the result does not depend on the filesystem's listing order.
            for class_name in sorted(os.listdir(self.image_dir)):
                class_path = os.path.join(self.image_dir, class_name)
                if not os.path.isdir(class_path):
                    continue
                if class_name not in allowed and class_name != self.UNLABELED_CLASS:
                    continue
                # Keep the class name next to the future so an empty result
                # can still be attributed to its class.
                pending.append((class_name,
                                executor.submit(self._get_class_samples, class_path, class_name)))

            for class_name, future in pending:
                class_samples = future.result()
                if not class_samples:
                    # Empty directory: nothing to add to either split.
                    continue

                # 'unlabeled' images are always part of the training set
                if class_name == self.UNLABELED_CLASS:
                    self.train_samples.extend(class_samples)
                    continue

                # A private, freshly seeded generator per class keeps the split
                # reproducible without reseeding the global `random` state.
                random.Random(42).shuffle(class_samples)
                split = max(len(class_samples) - holdout, 0)
                self.train_samples.extend(class_samples[:split])
                self.test_samples.extend(class_samples[split:])

    def _get_class_samples(self, class_dir: str, class_name: str) -> List[Tuple[str, str]]:
        """
        Gets all image file paths for a particular class.

        Args:
            class_dir (str): Directory path for the class.
            class_name (str): Name of the class.

        Returns:
            List[Tuple[str, str]]: List of (image_path, class_name) tuples, sorted by path.
        """
        # The context manager closes the directory iterator promptly.
        with os.scandir(class_dir) as entries:
            samples = [(os.path.join(class_dir, entry.name), class_name)
                       for entry in entries if entry.is_file()]
        # Sorted so the seeded shuffle starts from the same order on every filesystem.
        samples.sort()
        return samples

    def __len__(self) -> int:
        """
        Returns the total number of images in the dataset split (train or test).

        Returns:
            int: Length of the dataset split (number of images).
        """
        if self.dataset_type == "train":
            return len(self.train_samples)
        return len(self.test_samples)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
        """
        Retrieves an image and its corresponding class index from the dataset.

        Args:
            index (int): Index of the sample to retrieve.

        Returns:
            Tuple[torch.Tensor, int]:
                - Transformed image tensor (the untransformed RGB PIL image
                  when do_transform is False).
                - Position of the class in allowed_classes, or
                  UNLABELED_CLASS_ID for an unlabeled image that is not
                  listed in allowed_classes.
        """
        # Select sample based on dataset type
        if self.dataset_type == "train":
            img_path, class_name = self.train_samples[index]
        else:
            img_path, class_name = self.test_samples[index]

        # Lazy loading of the image; the context manager releases the file
        # handle as soon as the pixel data has been converted.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")

        # Apply transformations if enabled
        if self.do_transform:
            image = self.transform(image)

        # Get the class index from allowed_classes
        try:
            class_id = self.allowed_classes.index(class_name)
        except ValueError:
            # Unlabeled images are loaded even when "unlabeled" is not an
            # allowed class; give them the sentinel index instead of failing.
            if class_name != self.UNLABELED_CLASS:
                raise
            class_id = self.UNLABELED_CLASS_ID

        return image, class_id

In [43]:
from encoder import *

In [52]:
# Hyper-parameters for the encoder; EncoderConfig comes from the star import of `encoder`.
# NOTE(review): the field semantics are defined in EncoderConfig, which is not visible
# here; the hedged notes below are read off the field names and need confirming.
encoder_config = EncoderConfig(
    image_size=128,
    hidden_size=512,
    intermediate_size=512 * 3,  # 1536, i.e. three times hidden_size
    num_hidden_layers=8,
    num_attention_heads=8,
    num_channels=3,
    patch_size=8,  # presumably 128 / 8 = 16 patches per image side; TODO confirm
    layer_norm_eps=1e-6,
    attention_dropout=0.0,
    num_image_tokens=None,  # presumably derived elsewhere when None; TODO confirm
    do_random_mask=True,
    mask_ratio=0.75  # presumably the fraction of patches masked; TODO confirm
)

In [55]:
# Build the encoder from the configuration defined in the previous cell.
# NOTE(review): the model is not moved to `device` in this cell.
model = Encoder(encoder_config)