In [1]:
import os
%pwd

'c:\\Users\\49179\\Desktop\\Food_image_classification\\research'

In [2]:
os.chdir("../")

In [3]:
%pwd

'c:\\Users\\49179\\Desktop\\Food_image_classification'

In [4]:
from dataclasses import dataclass
from pathlib import Path
from torch.utils.data import DataLoader

@dataclass
class data_transformation_config:
    """Settings bundle handed to the data-transformation stage.

    The fields fall into three groups: filesystem locations, batching
    options, and keyword dictionaries that parameterise the image
    transforms and the DataLoaders.
    """

    # --- filesystem locations ---
    root_dir: Path    # directory owned by this stage
    train_dir: Path   # root of the training image folders
    test_dir: Path    # root of the test image folders

    # --- batching options ---
    batch_size: int
    shuffle: bool

    # --- keyword dictionaries ---
    color_transform: dict       # colour-augmentation settings
    spatial_transform: dict     # geometric settings, looked up by key
    normalize_transform: dict   # normalisation settings
    data_loader_params: dict    # extra DataLoader keyword arguments
    

In [5]:
from Food_Classification.utils.common import read_yaml, create_directory
from Food_Classification.constants import *
from Food_Classification.entity.artifact_entity import DataTransformationArtifact
class ConfigurationManager:
    """Load the config and params YAML files and build stage configurations."""

    def __init__(self,
                 config_file_path = CONFIG_FILE_PATH,
                 params_file_path = PARAMS_FILE_PATH
                 ):
        """Read both YAML files and create the artifacts root directory.

        Args:
            config_file_path: Path of the configuration YAML file.
            params_file_path: Path of the parameters YAML file.
        """
        self.config = read_yaml(config_file_path)
        self.params = read_yaml(params_file_path)

        create_directory([self.config.artifacts_root])

    def get_data_transform_config(self) -> data_transformation_config:
        """Assemble the configuration for the data-transformation stage.

        Creates the stage's root directory, then combines values from the
        config file (directories), the params file (batch size, shuffle) and
        the module-level constants (transform and DataLoader settings).

        Returns:
            data_transformation_config: every path field is a ``Path``.
        """
        config = self.config.data_transforms

        # The train/test splits live under <unzip_dir>/food_40_percent/.
        dataset_dir = Path(self.config.data_ingestion.unzip_dir) / 'food_40_percent'

        create_directory([config.root_dir])

        return data_transformation_config(
            # Coerced to Path so it matches the field annotation and the
            # train_dir/test_dir siblings (the YAML value is not a Path).
            root_dir=Path(config.root_dir),
            train_dir=dataset_dir / 'train',
            test_dir=dataset_dir / 'test',
            batch_size=self.params.BATCH_SIZE,
            shuffle=self.params.SHUFFLE,
            color_transform={'brightness': BRIGHTNESS,
                             'contrast': CONTRAST,
                             'saturation': SATURATION,
                             'hue': HUE},
            spatial_transform={'vertical_flip': VERTICLE_FLIP,
                               'resize': RESIZE,
                               'center_crop': CENTER_CROP,
                               'rotation': RANDOMROTATION},
            normalize_transform={'mean': NORMALIZE_MEAN,
                                 'std': NORMALIZE_STD},
            data_loader_params={"num_workers": NUM_WORKERS,
                                "pin_memory": PIN_MEMORY},
        )
    


In [6]:
#%%writefile src\Food_Classification\components\data_transformation.py
import os
from typing import Tuple
import joblib
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
from Food_Classification import logger
from Food_Classification.entity.config_entity import data_transformation_config
from Food_Classification.entity.artifact_entity import DataTransformationArtifact
#from Food_Classification.config.configuration import ConfigurationManager

class data_transformation:
    """Build the train/test image transforms and wrap the image folders in DataLoaders.

    Every setting is read from the ``data_transformation_config`` object
    passed to the constructor.
    """

    def __init__(self, config: data_transformation_config):
        """Store the stage configuration.

        Args:
            config: Paths, batching options and transform keyword dictionaries.
        """
        self.config = config

    def tranform_train_data(self) -> transforms.Compose:
        """Build the training transform pipeline.

        Order: resize, centre crop, random rotation, random vertical flip,
        colour jitter, tensor conversion, normalisation. The misspelled
        method name is kept because existing callers use it.

        Returns:
            transforms.Compose: the composed training transforms.
        """
        logger.info("Transforming train data")
        spatial = self.config.spatial_transform
        return transforms.Compose([
            transforms.Resize(spatial['resize']),
            transforms.CenterCrop(spatial['center_crop']),
            transforms.RandomRotation(spatial['rotation']),
            transforms.RandomVerticalFlip(spatial['vertical_flip']),
            transforms.ColorJitter(**self.config.color_transform),
            transforms.ToTensor(),
            transforms.Normalize(**self.config.normalize_transform),
        ])

    def test_transform(self) -> transforms.Compose:
        """Build the test transform pipeline.

        Same resize, centre crop, tensor conversion and normalisation as the
        training pipeline, but without the random augmentations.

        Returns:
            transforms.Compose: the composed test transforms.
        """
        logger.info("Transforming test data")
        spatial = self.config.spatial_transform
        return transforms.Compose([
            transforms.Resize(spatial['resize']),
            transforms.CenterCrop(spatial['center_crop']),
            transforms.ToTensor(),
            transforms.Normalize(**self.config.normalize_transform),
        ])

    def create_dataloaders(self, train_transform: transforms.Compose, test_transform: transforms.Compose) -> Tuple[DataLoader, DataLoader, list]:
        """Create DataLoaders over the train and test image folders.

        Args:
            train_transform: Transforms applied to images under ``train_dir``.
            test_transform: Transforms applied to images under ``test_dir``.

        Returns:
            A 3-tuple ``(train_dataloader, test_dataloader, class_names)``;
            ``class_names`` is the ``classes`` list of the training ImageFolder.
        """
        logger.info("creating_dataloaders")

        train_data: Dataset = ImageFolder(root=self.config.train_dir,
                                          transform=train_transform)
        test_data: Dataset = ImageFolder(root=self.config.test_dir,
                                         transform=test_transform)

        class_names: list = train_data.classes

        train_dataloader: DataLoader = DataLoader(dataset=train_data,
                                                  batch_size=self.config.batch_size,
                                                  shuffle=self.config.shuffle,
                                                  **self.config.data_loader_params)

        # The test loader is never shuffled, whatever the config says.
        test_dataloader: DataLoader = DataLoader(dataset=test_data,
                                                 batch_size=self.config.batch_size,
                                                 shuffle=False,
                                                 **self.config.data_loader_params)

        logger.info('DataLoaders created')
        return train_dataloader, test_dataloader, class_names

    def initiate_data_transformation(self) -> DataTransformationArtifact:
        """Run the stage: build transforms, persist them, create DataLoaders.

        Both transform pipelines are pickled with joblib into ``root_dir`` as
        ``train_transforms.pkl`` and ``test_transforms.pkl``.

        Returns:
            DataTransformationArtifact: a single object holding the train and
            test DataLoaders. It is not a tuple, so callers read its
            ``transformed_train_object`` / ``transformed_test_object`` fields.
        """
        logger.info("Initiating data transformation")

        train_transform: transforms.Compose = self.tranform_train_data()
        test_transform: transforms.Compose = self.test_transform()

        # Make sure the output directory exists before writing into it.
        os.makedirs(self.config.root_dir, exist_ok=True)
        train_transform_filename = os.path.join(self.config.root_dir, "train_transforms.pkl")
        test_transform_filename = os.path.join(self.config.root_dir, "test_transforms.pkl")

        joblib.dump(train_transform, train_transform_filename)
        joblib.dump(test_transform, test_transform_filename)

        # The class names are not stored on the artifact, so they are discarded here.
        train_dataloader, test_dataloader, _ = self.create_dataloaders(train_transform=train_transform,
                                                                       test_transform=test_transform)

        return DataTransformationArtifact(transformed_train_object=train_dataloader,
                                          transformed_test_object=test_dataloader)


In [8]:
# Assuming CONFIG_FILE_PATH points to the correct location of 'config.yaml'
config = ConfigurationManager(config_file_path=CONFIG_FILE_PATH)
transformation_config = config.get_data_transform_config()
transform = data_transformation(config=transformation_config)

# initiate_data_transformation() returns one DataTransformationArtifact, not a
# 3-tuple, so the loaders are read from its fields instead of being unpacked.
data_transformation_artifact = transform.initiate_data_transformation()
train_dataloader = data_transformation_artifact.transformed_train_object
test_dataloader = data_transformation_artifact.transformed_test_object

# The training loader wraps an ImageFolder, whose `classes` attribute lists the class names.
class_names = train_dataloader.dataset.classes

[2024-05-07 19:19:10,987: INFO: common: yaml file config\config.yaml loaded successfully]
[2024-05-07 19:19:10,990: INFO: common: yaml file params.yaml loaded successfully]
[2024-05-07 19:19:10,992: INFO: common: directory artifacts created successfully]
[2024-05-07 19:19:10,993: INFO: common: directory artifacts/data_transforms created successfully]


TypeError: __init__() missing 2 required positional arguments: 'transformed_train_object' and 'transformed_test_object'

In [None]:
%%writefile src\Food_Classification\pipeline\stage_3_data_transformation_pipeline.py
from Food_Classification.config.configuration import ConfigurationManager
from Food_Classification.components.data_transformation import data_transformation
from Food_Classification import logger

# Stage label shown in this pipeline's start/completion log banners.
STAGE_NAME = "Data Transformation"

class DataTransformationPipeline:
    """Pipeline stage that builds the transformation config and runs the component."""

    def __init__(self):
        pass

    def main(self):
        """Run the data-transformation stage.

        Returns:
            Whatever ``initiate_data_transformation()`` returns. The component
            defined in this notebook returns a single
            DataTransformationArtifact, so the result is passed through as one
            object rather than unpacked into three names.
        """
        config = ConfigurationManager()
        transformation_config = config.get_data_transform_config()
        transform = data_transformation(config=transformation_config)
        return transform.initiate_data_transformation()
        

if __name__ == "__main__":
    # Script entry point: log a start banner, run the stage, log a completion banner.
    try:
        logger.info(f">>>>>>> stage : {STAGE_NAME} <<<<<<<<")
        stage = DataTransformationPipeline()
        stage.main()
        logger.info(f">>>>>>> stage : {STAGE_NAME} completed <<<<<<<< \n\nx========x")

    except Exception as error:
        # Record the traceback in the log, then propagate the failure.
        logger.exception(error)
        raise error


Writing src\Food_Classification\pipeline\stage_3_data_transformation_pipeline.py
