In [1]:
import os
# Show the working directory the kernel starts in (the research/ notebook folder).
%pwd

'c:\\Users\\49179\\Desktop\\Food_image_classification\\research'

In [2]:
# Move from research/ up to the project root so relative paths such as
# config/config.yaml and artifacts/ resolve. Guarded so that re-running this
# cell (or Run All after a partial run) does not climb past the project root.
if os.path.basename(os.getcwd()) == "research":
    os.chdir("../")
os.getcwd()

'c:\\Users\\49179\\Desktop\\Food_image_classification'

In [3]:
from dataclasses import dataclass
from pathlib import Path
from torch.utils.data import DataLoader
@dataclass
class DataTransformConfig:
    """Settings for the data-transformation stage.

    Built by ConfigurationManager.get_DataTransformConfig. The field order below
    is the positional order of the generated __init__. data_transformation.py
    imports its DataTransformConfig from Food_Classification.entity.config_entity,
    so keep this notebook copy in sync with that one.
    """
    root_dir: Path                # stage artifact directory; created by the config manager
    train_dir: Path               # root folder handed to ImageFolder for the training split
    test_dir: Path                # root folder handed to ImageFolder for the test split
    color_transform: dict         # kwargs for transforms.ColorJitter: brightness, contrast, saturation, hue
    spatial_transform: dict       # keys: vertical_flip, resize, center_crop, rotation
    data_loader_params: dict      # kwargs for DataLoader: batch_size, shuffle, num_workers, pin_memory
    normalize: dict               # kwargs for transforms.Normalize: mean, std
    train_transforms_file: Path   # joblib dump target for the training transform pipeline
    test_transforms_file: Path    # joblib dump target for the test transform pipeline


In [4]:
from torch.utils.data import DataLoader

from Food_Classification.constants import *
from Food_Classification.utils.common import read_yaml, create_directory

class ConfigurationManager:
    """Read the project YAML files and build configuration entities for pipeline stages.

    On construction both YAML files are loaded with ``read_yaml`` and the
    ``artifacts_root`` directory named in the config file is created.
    """

    def __init__(self,
                 config_file_path=CONFIG_FILE_PATH,
                 params_file_path=PARAMS_FILE_PATH):
        """Load configuration and parameters.

        Args:
            config_file_path: YAML read into ``self.config``; this class uses its
                ``artifacts_root``, ``data_transforms`` and ``data_ingestion`` sections.
            params_file_path: YAML read into ``self.params``; this class uses its
                ``BATCH_SIZE`` and ``SHUFFLE`` entries.
        """
        self.config = read_yaml(config_file_path)
        self.params = read_yaml(params_file_path)

        create_directory([self.config.artifacts_root])

    def get_DataTransformConfig(self) -> DataTransformConfig:
        """Build the configuration for the data-transformation stage.

        Creates ``data_transforms.root_dir`` and resolves the train/test image
        folders under ``<data_ingestion.unzip_dir>/food_40_percent``. Batch size
        and shuffle come from the params YAML; the augmentation, normalization
        and DataLoader worker settings are upper-case module-level constants
        (NOTE(review): presumably supplied by the
        ``Food_Classification.constants`` star import — confirm).

        Returns:
            DataTransformConfig whose path fields are all ``pathlib.Path``.
        """
        stage_cfg = self.config.data_transforms
        create_directory([stage_cfg.root_dir])

        dataset_root = Path(self.config.data_ingestion.unzip_dir) / 'food_40_percent'

        return DataTransformConfig(
            # Coerce to Path so root_dir matches its annotation and the other path fields.
            root_dir=Path(stage_cfg.root_dir),
            train_dir=dataset_root / 'train',
            test_dir=dataset_root / 'test',
            train_transforms_file=Path(stage_cfg.TRAIN_TRANSFORMS_FILE),
            test_transforms_file=Path(stage_cfg.TEST_TRANSFORMS_FILE),
            color_transform={
                "brightness": BRIGHTNESS,
                "contrast": CONTRAST,
                "saturation": SATURATION,
                "hue": HUE,
            },
            spatial_transform={
                'vertical_flip': VERTICLE_FLIP,
                'resize': RESIZE,
                'center_crop': CENTER_CROP,
                'rotation': RANDOMROTATION,
            },
            normalize={'mean': NORMALIZE_MEAN, 'std': NORMALIZE_STD},
            data_loader_params={
                'batch_size': self.params.BATCH_SIZE,
                'shuffle': self.params.SHUFFLE,
                'num_workers': NUM_WORKERS,
                'pin_memory': PIN_MEMORY,
            },
        )


In [13]:
%%writefile src\Food_Classification\components\data_transformation.py
from Food_Classification.config.configuration import ConfigurationManager
from Food_Classification.entity.config_entity import DataTransformConfig
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from Food_Classification.entity.artifact_entity import DataTransformationArtifact
from Food_Classification import logger
from typing import Tuple
import joblib
import os

class DataTransformation:
    """Build torchvision transform pipelines and DataLoaders for the food dataset.

    The train/test transform pipelines are dumped with joblib to the paths in the
    config, and the resulting DataLoaders are returned in a
    DataTransformationArtifact.
    """

    def __init__(self, config: DataTransformConfig):
        """
        Args:
            config: Stage configuration. This class reads train_dir, test_dir,
                spatial_transform, color_transform, normalize,
                data_loader_params, train_transforms_file and test_transforms_file.
        """
        self.config = config

    def train_data_transform(self) -> transforms.Compose:
        """Return the training pipeline.

        Resize -> center crop -> random rotation -> random vertical flip ->
        color jitter -> tensor -> normalize.
        """
        logger.info("Data Transformation started")
        spatial = self.config.spatial_transform
        return transforms.Compose([
            transforms.Resize(spatial['resize']),
            transforms.CenterCrop(spatial['center_crop']),
            transforms.RandomRotation(spatial['rotation']),
            # RandomVerticalFlip's positional argument is the flip probability p.
            transforms.RandomVerticalFlip(spatial['vertical_flip']),
            transforms.ColorJitter(**self.config.color_transform),
            transforms.ToTensor(),
            transforms.Normalize(**self.config.normalize),
        ])

    def test_data_transform(self) -> transforms.Compose:
        """Return the evaluation pipeline.

        It uses the same resize/center-crop/normalize settings as training but
        no random augmentation.
        """
        spatial = self.config.spatial_transform
        return transforms.Compose([
            transforms.Resize(spatial['resize']),
            transforms.CenterCrop(spatial['center_crop']),
            transforms.ToTensor(),
            transforms.Normalize(**self.config.normalize),
        ])

    def create_dataloaders(self, train_data: transforms.Compose,
                           test_data: transforms.Compose) -> Tuple[DataLoader, DataLoader]:
        """Wrap the train/test ImageFolder datasets in DataLoaders.

        Args:
            train_data: Transform pipeline applied to images under config.train_dir.
            test_data: Transform pipeline applied to images under config.test_dir.

        Returns:
            (train_dataloader, test_dataloader). Both loaders share
            config.data_loader_params, except that the test loader never shuffles.
        """
        logger.info("Data Loading started")
        train_dataset = ImageFolder(root=self.config.train_dir, transform=train_data)
        test_dataset = ImageFolder(root=self.config.test_dir, transform=test_data)

        loader_params = dict(self.config.data_loader_params)
        train_dataloader = DataLoader(train_dataset, **loader_params)
        # Keep evaluation order deterministic regardless of the configured
        # training 'shuffle' flag.
        test_dataloader = DataLoader(test_dataset, **{**loader_params, 'shuffle': False})

        return train_dataloader, test_dataloader

    def initiate_datatransfrom(self) -> DataTransformationArtifact:
        """Run the stage: build transforms, persist them, and build DataLoaders.

        The transform pipelines are written with joblib (pickle-based) to
        config.train_transforms_file / config.test_transforms_file; only reload
        those files from trusted locations.

        Returns:
            DataTransformationArtifact holding both DataLoaders and the two dump paths.
        """
        train_transforms = self.train_data_transform()
        test_transforms = self.test_data_transform()

        joblib.dump(train_transforms, self.config.train_transforms_file)
        joblib.dump(test_transforms, self.config.test_transforms_file)
        logger.info("transform pickle file created")

        train_dataloader, test_dataloader = self.create_dataloaders(train_transforms, test_transforms)

        transforms_artifact = DataTransformationArtifact(
            transformed_train_object=train_dataloader,
            transformed_test_object=test_dataloader,
            trained_transformed_file=self.config.train_transforms_file,
            test_transformed_file=self.config.test_transforms_file,
        )
        logger.info("Data Transformation completed")
        return transforms_artifact


Overwriting src\Food_Classification\components\data_transformation.py


In [10]:
# `%%writefile` above only writes data_transformation.py to disk; it does not
# execute it. Import the class explicitly so this cell works on a fresh kernel
# (Restart & Run All) instead of relying on a definition left over from an
# earlier run.
from Food_Classification.components.data_transformation import DataTransformation

config = ConfigurationManager()
tran = config.get_DataTransformConfig()
transfom = DataTransformation(tran)
transfom_artifact = transfom.initiate_datatransfrom()

[2024-05-13 23:59:38,941: INFO: common: yaml file config\config.yaml loaded successfully]
[2024-05-13 23:59:38,943: INFO: common: yaml file params.yaml loaded successfully]
[2024-05-13 23:59:38,944: INFO: common: directory artifacts created successfully]
[2024-05-13 23:59:38,945: INFO: common: directory artifacts/data_transforms created successfully]


In [11]:
# Display the artifact from the previous cell: train/test DataLoaders plus the
# paths of the joblib-dumped train/test transform pipelines.
transfom_artifact

DataTransformationArtifact(transformed_train_object=<torch.utils.data.dataloader.DataLoader object at 0x000001E70D5EC460>, transformed_test_object=<torch.utils.data.dataloader.DataLoader object at 0x000001E70D5ECFD0>, trained_transformed_file=WindowsPath('artifacts/data_transforms/train_transforms.pkl'), test_transformed_file=WindowsPath('artifacts/data_transforms/test_transforms.pkl'))