In [3]:
def write_code(flag=True):
    """Return a one-line Python snippet reporting whether ``flag`` is truthy.

    The returned text is source code intended for ``exec``. Each snippet ends
    with a newline followed by eight spaces, carried over from the original
    triple-quoted literals.
    """
    snippets = {
        True: 'print("Flag is true")\n        ',
        False: 'print("Flag is false")\n        ',
    }
    # bool() applies the same truthiness test as a plain `if flag:`.
    return snippets[bool(flag)]


# Run the generated source; with False this prints "Flag is false".
exec(write_code(False))

Flag is false


In [8]:
import torch
from torch.utils.data import DataLoader
from pathlib import Path

def loaddata(dataset_path: Path, **loader_kwargs) -> DataLoader:
    """Load a pickled ``torch.utils.data.Dataset`` object and wrap it in a DataLoader.

    Args:
        dataset_path: File written by ``torch.save(dataset, ...)`` containing a
            complete Dataset object. Files produced by
            ``ImageDataset.export_dataset`` contain a metadata *dict* instead;
            load those with ``ImageDataset.load_dataset``.
        **loader_kwargs: Forwarded verbatim to ``DataLoader`` (e.g.
            ``batch_size``, ``shuffle``, ``num_workers``). With none given,
            DataLoader's own defaults apply, exactly as before.

    Returns:
        A DataLoader over the loaded dataset.
    """
    # weights_only=False is required to unpickle an arbitrary Dataset object,
    # but it can execute code embedded in the file: only load trusted files.
    ds = torch.load(dataset_path, weights_only=False)
    return DataLoader(dataset=ds, **loader_kwargs)

In [1]:
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
from pathlib import Path
from typing import Tuple, Optional, Literal
import os
import pandas as pd
from CNNClassifier.logger import logger


class ImageDataset(Dataset):
    """Map-style image classification dataset indexed by a CSV file.

    The CSV at ``data_path`` must have an ``images`` column (file names
    relative to ``images_path``) and a ``label`` column. Distinct labels are
    sorted and assigned contiguous indices ``0..num_classes()-1``. Each sample
    is stored as ``(image_name, label_index)``, and ``__getitem__`` returns
    ``(transformed_image, label_index)``.

    Instances compare by value (see ``__eq__``) and are therefore unhashable.
    """

    def __init__(
        self,
        data_path: str,
        images_path: str,
        transform: Optional[transforms.Compose] = None,
        type: Literal["train", "val", "test"] = "train"
    ):
        """Read the CSV index and build the label mapping.

        Args:
            data_path: Path (str or ``pathlib.Path``) to the CSV index.
            images_path: Directory that image file names are relative to.
            transform: Applied to each RGB PIL image. ``None`` selects the
                default: resize to 224x224, convert to a tensor, and normalize
                with the ImageNet channel mean/std.
            type: Split tag stored on the instance. The name shadows the
                builtin but is kept for API compatibility.
        """
        self.data_path = data_path
        self.images_path = images_path
        # Use `is None` rather than `or`, so a user transform that happens to be
        # falsy (e.g. an empty container) is not silently replaced.
        if transform is None:
            transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])
        self.transform = transform
        self.type = type

        data_df = pd.read_csv(data_path)
        # Column-wise extraction avoids building a pandas Series per row
        # (iterrows); tolist() yields plain Python scalars.
        image_names = data_df['images'].tolist()
        labels = data_df['label'].tolist()

        self._set_class_mapping(
            {cls_name: i for i, cls_name in enumerate(sorted(set(labels)))}
        )
        self.data = [
            (img_name, self.class_to_idx[label])
            for img_name, label in zip(image_names, labels)
        ]
        logger.info(f"Created dataset with {len(self.data)} images and {len(self.classes)} classes")

    def _set_class_mapping(self, class_to_idx: dict) -> None:
        """Install ``class_to_idx`` and derive ``idx_to_class`` and ``classes`` from it."""
        self.class_to_idx = dict(class_to_idx)
        self.idx_to_class = {idx: name for name, idx in self.class_to_idx.items()}
        # Order `classes` by index so that classes[i] names label index i.
        self.classes = [self.idx_to_class[idx] for idx in sorted(self.idx_to_class)]

    def num_classes(self) -> int:
        """Number of distinct labels in the mapping."""
        return len(self.classes)

    def __len__(self) -> int:
        """Number of samples."""
        return len(self.data)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """Load the image at position ``idx``, convert it to RGB, and apply the transform.

        Negative indices follow normal Python sequence semantics.

        Returns:
            ``(transformed_image, label_index)``.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
            Exception: any error from opening or transforming the image is
                logged and re-raised unchanged.
        """
        # Accept tensor indices; convert before the bounds check so the
        # comparison is done on a plain Python value.
        if torch.is_tensor(idx):
            idx = idx.tolist()

        if not -len(self) <= idx < len(self):
            logger.error(f"Index {idx} is out of range")
            raise IndexError(f"Index {idx} is out of range")

        img_name, label = self.data[idx]
        img_path = os.path.join(self.images_path, img_name)

        try:
            # The context manager guarantees the file handle is released;
            # convert() decodes the pixels while it is still open.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
            return self.transform(image), label
        except Exception as e:
            logger.error(f"Error loading image {img_path}: {e}")
            raise

    def export_dataset(self, save_path: Path) -> None:
        """Save the dataset's metadata (not the image files) with ``torch.save``.

        The saved dict contains ``data``, ``data_path``, ``images_path``,
        ``transform`` (pickled object), ``class_to_idx`` and ``type``.
        ``load_dataset`` reads it back.
        """
        try:
            metadata = {
                'data': self.data,  # List of (image_name, label_index) tuples
                'data_path': self.data_path,
                'images_path': self.images_path,
                'transform': self.transform,
                'class_to_idx': self.class_to_idx,
                'type': self.type
            }

            torch.save(metadata, save_path)
            logger.info(f"Dataset metadata exported to {save_path}")
        except Exception as e:
            logger.error(f"Error exporting dataset metadata: {e}")
            raise

    def __eq__(self, other: object) -> bool:
        """Value equality: same samples, paths, split tag and label mapping."""
        if not isinstance(other, ImageDataset):
            return NotImplemented
        # Equal `data` implies equal length, and equal `class_to_idx` implies
        # equal num_classes(), so those need no separate comparison.
        return (
            self.data == other.data
            and self.images_path == other.images_path
            and self.data_path == other.data_path
            and self.type == other.type
            and self.class_to_idx == other.class_to_idx
        )

    @classmethod
    def load_dataset(cls, metadata_path: Path) -> 'ImageDataset':
        """Rebuild a dataset from a file written by ``export_dataset``.

        The constructor still reads the CSV at the stored ``data_path``, so
        that file must exist. Afterwards the exported samples, transform and
        label mapping replace what was derived from the CSV. This keeps the
        label indices in ``data`` consistent with ``class_to_idx`` even if the
        CSV changed after export.
        """
        try:
            # weights_only=False is needed because the metadata pickles a
            # transform object; it can execute code embedded in the file, so
            # only load metadata files you trust.
            metadata = torch.load(metadata_path, weights_only=False)
            dataset = cls(
                data_path=metadata['data_path'],
                images_path=metadata['images_path'],
                transform=metadata.get('transform'),
                type=metadata['type']
            )
            dataset.data = metadata['data']
            if 'class_to_idx' in metadata:
                dataset._set_class_mapping(metadata['class_to_idx'])
            logger.info(f"Load dataset from {metadata_path} with {len(dataset)} images and {dataset.num_classes()} classes")
            return dataset
        except Exception as e:
            logger.error(f"Error loading dataset: {e}")
            raise


In [25]:
def _build_split(split: Literal["train", "val", "test"]) -> ImageDataset:
    """Construct the ImageDataset for one split laid out as ../data/<split>/.

    Expects ``<split>_data.csv`` and an ``images`` folder inside that directory.
    """
    split_dir = Path("../data") / split
    return ImageDataset(
        data_path=split_dir / f"{split}_data.csv",
        images_path=split_dir / "images",
        type=split,
    )


train_og = _build_split("train")
val_og = _build_split("val")


[32m2025-05-19 01:23:40.499[0m | [1mINFO    [0m | [36m__main__[0m:[36m__init__[0m:[36m42[0m - [1mCreated dataset with 5646 images and 4 classes[0m
[32m2025-05-19 01:23:40.515[0m | [1mINFO    [0m | [36m__main__[0m:[36m__init__[0m:[36m42[0m - [1mCreated dataset with 806 images and 4 classes[0m


In [26]:
# Write each split's metadata file next to its CSV, train first, then val.
for dataset, metadata_path in [
    (train_og, Path("../data/train/train_metadata.pt")),
    (val_og, Path("../data/val/val_metadata.pt")),
]:
    dataset.export_dataset(metadata_path)

[32m2025-05-19 01:23:42.703[0m | [1mINFO    [0m | [36m__main__[0m:[36mexport_dataset[0m:[36m81[0m - [1mDataset metadata exported to ../data/train/train_metadata.pt[0m
[32m2025-05-19 01:23:42.709[0m | [1mINFO    [0m | [36m__main__[0m:[36mexport_dataset[0m:[36m81[0m - [1mDataset metadata exported to ../data/val/val_metadata.pt[0m


In [2]:
# Rebuild both splits (train, then val) from the exported metadata files.
train_new, val_new = (
    ImageDataset.load_dataset(Path(metadata_file))
    for metadata_file in (
        "../data/train/train_metadata.pt",
        "../data/val/val_metadata.pt",
    )
)

[32m2025-05-19 02:13:14.144[0m | [1mINFO    [0m | [36m__main__[0m:[36m__init__[0m:[36m42[0m - [1mCreated dataset with 5646 images and 4 classes[0m
[32m2025-05-19 02:13:14.145[0m | [1mINFO    [0m | [36m__main__[0m:[36mload_dataset[0m:[36m100[0m - [1mLoad dataset from ../data/train/train_metadata.pt with 5646 images and 4 classes[0m
[32m2025-05-19 02:13:14.160[0m | [1mINFO    [0m | [36m__main__[0m:[36m__init__[0m:[36m42[0m - [1mCreated dataset with 806 images and 4 classes[0m
[32m2025-05-19 02:13:14.160[0m | [1mINFO    [0m | [36m__main__[0m:[36mload_dataset[0m:[36m100[0m - [1mLoad dataset from ../data/val/val_metadata.pt with 806 images and 4 classes[0m


In [28]:
# Round-trip check: a dataset reloaded from its exported metadata must equal
# the original. Print the result for the reader, and fail loudly on mismatch.
train_roundtrip_ok = train_og == train_new
val_roundtrip_ok = val_og == val_new
print(f"Train: {train_roundtrip_ok}")
print(f"Val: {val_roundtrip_ok}")
assert train_roundtrip_ok and val_roundtrip_ok, "Reloaded dataset differs from the original"

Train: True
Val: True
