In [1]:
import os
# Show the directory the kernel started in (the output below is the research/ folder).
%pwd

'c:\\Users\\49179\\Desktop\\Food_image_classification\\research'

In [2]:
# Step up from research/ to the project root so project-relative paths resolve.
# The guard keeps this cell idempotent: re-running it must not keep climbing
# the directory tree (which the unconditional chdir did).
if os.path.basename(os.getcwd()) == "research":
    os.chdir("../")

In [3]:
# Confirm the working directory is now the project root (see output below).
%pwd

'c:\\Users\\49179\\Desktop\\Food_image_classification'

In [4]:
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

from torch.utils.data import DataLoader


@dataclass
class data_transformation_config:
    """Settings consumed by the data-transformation stage.

    Holds the folders and file paths the stage works with, the DataLoader
    settings, and the keyword dictionaries used to build the torchvision
    transform pipelines.
    """
    root_dir: Path                      # working directory of this stage
    train_dir: Path                     # root of the training image folder
    test_dir: Path                      # root of the test image folder
    train_transform_pkl: Path           # presumably where the train transform pipeline is serialized -- confirm
    test_transform_pkl: Path            # presumably where the test transform pipeline is serialized -- confirm
    # ConfigurationManager builds the config with None for both loaders, so the
    # annotation must admit None.
    transformed_train_data: Optional[DataLoader]
    transformed_test_data: Optional[DataLoader]
    batch_size: int                     # batch size for both DataLoaders
    shuffle: bool                       # whether the training DataLoader shuffles
    color_transform: dict               # ColorJitter kwargs (brightness/contrast/saturation/hue)
    spatial_transform: dict             # keys: vertical_flip, resize, center_crop, rotation, zoom
    normalize_transform: dict           # Normalize kwargs (mean/std)
    data_loader_params: dict            # extra DataLoader kwargs (num_workers, pin_memory)
    

In [5]:
from Food_Classification.utils.common import read_yaml, create_directory
from Food_Classification.constants import *
#from Food_Classification.entity.config_entity import data_transformation_config
class ConfigurationManager:
    """Read the project's config/params YAML files and build per-stage config entities."""

    def __init__(self,
                 config_file_path=CONFIG_FILE_PATH,
                 params_file_path=PARAMS_FILE_PATH):
        """Load both YAML files and make sure the artifacts root directory exists.

        Args:
            config_file_path: path of the main config YAML (defaults to the
                project constant ``CONFIG_FILE_PATH``).
            params_file_path: path of the params YAML (defaults to the
                project constant ``PARAMS_FILE_PATH``).
        """
        # read_yaml is a project helper; the attribute access used below
        # (self.config.artifacts_root) assumes it returns a dot-accessible
        # mapping -- TODO confirm.
        self.config = read_yaml(config_file_path)
        self.params = read_yaml(params_file_path)
        create_directory([self.config.artifacts_root])

    # Defined at class level: it used to be nested inside __init__, which made
    # it a local function that no caller could ever reach as a method.
    def get_data_transform_config(self) -> data_transformation_config:
        """Build the ``data_transformation_config`` for the data-transformation stage.

        Reads the ``data_transform`` section of the config YAML, creates its
        ``root_dir``, and fills the transform dictionaries from the
        project-level constants.

        Returns:
            A ``data_transformation_config`` whose ``transformed_train_data`` and
            ``transformed_test_data`` are ``None``.
        """
        config = self.config.data_transform
        create_directory([config.root_dir])

        # NOTE(review): batch size / shuffle come from the config YAML while
        # self.params is never used -- confirm they don't belong in params.
        return data_transformation_config(
            root_dir=Path(config.root_dir),  # Path, consistent with the other path fields
            train_dir=Path(config.train_dir),
            test_dir=Path(config.test_dir),
            train_transform_pkl=Path(config.train_transform_pkl),
            test_transform_pkl=Path(config.test_transform_pkl),
            transformed_train_data=None,
            transformed_test_data=None,
            batch_size=config.BATCH_SIZE,
            shuffle=config.SHUFFLE,
            color_transform={'brightness': BRIGHTNESS,
                             'contrast': CONTRAST,
                             'saturation': SATURATION,
                             'hue': HUE},
            spatial_transform={'vertical_flip': VERTICLE_FLIP,
                               'resize': RESIZE,
                               'center_crop': CENTER_CROP,
                               'rotation': RANDOMROTATION,
                               'zoom': RANDOMZOOM},
            normalize_transform={'mean': NORMALIZE_MEAN,
                                 'std': NORMALIZE_STD},
            data_loader_params={"num_workers": NUM_WORKERS,
                                "pin_memory": PIN_MEMORY},
        )
        


In [12]:
import os
from typing import Tuple
import joblib
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
from Food_Classification import logger
#from Food_Classification.config.configuration import ConfigurationManager

class data_transformation:
    """Build image transforms and DataLoaders for the food-classification dataset.

    Every setting is read from the ``data_transformation_config`` given to the
    constructor:
    - the spatial, colour and normalisation parameters for the transforms;
    - the ``train_dir``/``test_dir`` ImageFolder roots;
    - the batch size and shuffle flag;
    - the extra DataLoader kwargs in ``data_loader_params``.
    """

    def __init__(self, config: data_transformation_config):
        self.config = config

    # Convert the zoom setting into the (min, max) scale range of RandomAffine.
    @staticmethod
    def _zoom_scale(zoom) -> Tuple[float, float]:
        """Map a zoom setting to a ``RandomAffine`` scale range.

        A single number ``z`` is read as "zoom in or out by up to ``z``",
        i.e. ``(1 - z, 1 + z)``. A two-element sequence is used as the range
        directly.

        NOTE(review): the format of the RANDOMZOOM constant is not visible
        here -- confirm it matches one of these forms.
        """
        if isinstance(zoom, (int, float)):
            return (1.0 - float(zoom), 1.0 + float(zoom))
        low, high = zoom
        return (float(low), float(high))

    # Convert the vertical-flip setting into a probability for RandomVerticalFlip.
    @staticmethod
    def _flip_probability(flip) -> float:
        """Map the vertical-flip setting to a flip probability.

        Booleans mean "flip half the time" / "never flip". Without this, a
        literal True would become p=1.0 and flip every image. Any other number
        is used as the probability itself.
        """
        if isinstance(flip, bool):
            return 0.5 if flip else 0.0
        return float(flip)

    # Resize + center-crop steps shared by the train and test pipelines.
    def _base_transforms(self) -> list:
        spatial = self.config.spatial_transform
        return [transforms.Resize(spatial['resize']),
                transforms.CenterCrop(spatial['center_crop'])]

    # Final tensor conversion + normalisation shared by both pipelines.
    def _tensor_transforms(self) -> list:
        return [transforms.ToTensor(),
                transforms.Normalize(**self.config.normalize_transform)]

    def transform_train_data(self) -> transforms.Compose:
        """Build the training pipeline.

        The pipeline is: resize/crop, random zoom, rotation, vertical flip and
        colour jitter, then tensor conversion and normalisation.
        """
        logger.info("Transforming train data")
        spatial = self.config.spatial_transform
        augmentations = [
            # torchvision has no RandomZoom; a scale-only RandomAffine is the equivalent.
            transforms.RandomAffine(degrees=0, scale=self._zoom_scale(spatial['zoom'])),
            transforms.RandomRotation(spatial['rotation']),
            # The torchvision class is RandomVerticalFlip (there is no VerticalFlip).
            transforms.RandomVerticalFlip(p=self._flip_probability(spatial['vertical_flip'])),
            transforms.ColorJitter(**self.config.color_transform),
        ]
        return transforms.Compose(self._base_transforms() + augmentations + self._tensor_transforms())

    # Backward-compatible alias for the original (misspelled) method name.
    tranform_train_data = transform_train_data

    def test_transform(self) -> transforms.Compose:
        """Build the evaluation pipeline: resize/crop, tensor conversion and normalisation (no augmentation)."""
        logger.info("Transforming test data")
        return transforms.Compose(self._base_transforms() + self._tensor_transforms())

    def create_dataloaders(self, train_transform=None, test_transform=None) -> Tuple[DataLoader, DataLoader, list]:
        """Create train/test DataLoaders over ImageFolder datasets.

        Args:
            train_transform: optional pre-built training pipeline. When omitted,
                it is built with ``transform_train_data()``.
            test_transform: optional pre-built evaluation pipeline. When omitted,
                it is built with ``test_transform()``.

        Returns:
            ``(train_dataloader, test_dataloader, class_names)``, where
            ``class_names`` is the training ImageFolder's ``classes`` list.
        """
        logger.info("creating_dataloaders")

        # ImageFolder needs the transform *object*. Passing the bound method
        # would call it with each image and fail on the extra argument.
        if train_transform is None:
            train_transform = self.transform_train_data()
        if test_transform is None:
            test_transform = self.test_transform()

        train_data: Dataset = ImageFolder(root=self.config.train_dir, transform=train_transform)
        test_data: Dataset = ImageFolder(root=self.config.test_dir, transform=test_transform)
        class_names: list = train_data.classes

        # num_workers / pin_memory live in data_loader_params, not on the config
        # itself. The fallbacks are DataLoader's own defaults.
        loader_params = self.config.data_loader_params or {}
        num_workers = loader_params.get("num_workers", 0)
        pin_memory = loader_params.get("pin_memory", False)

        train_dataloader: DataLoader = DataLoader(dataset=train_data,
                                                  batch_size=self.config.batch_size,
                                                  shuffle=self.config.shuffle,
                                                  num_workers=num_workers,
                                                  pin_memory=pin_memory)
        # Evaluation order should be deterministic, so the test loader never shuffles.
        test_dataloader: DataLoader = DataLoader(dataset=test_data,
                                                 batch_size=self.config.batch_size,
                                                 shuffle=False,
                                                 num_workers=num_workers,
                                                 pin_memory=pin_memory)

        logger.info('DataLoaders created')
        return train_dataloader, test_dataloader, class_names

    def initiate_data_transformation(self) -> Tuple[DataLoader, DataLoader, list]:
        """Build both pipelines, save them with joblib, and create the DataLoaders.

        NOTE(review): the original cell stopped mid-statement (``create_dire``).
        This completion is inferred from the ``train_transform_pkl`` /
        ``test_transform_pkl`` config fields and the ``joblib`` import --
        confirm the intended behaviour.

        Returns:
            ``(train_dataloader, test_dataloader, class_names)``, as from
            ``create_dataloaders``.
        """
        logger.info("Initiating data transformation")

        train_transform: transforms.Compose = self.transform_train_data()
        test_transform: transforms.Compose = self.test_transform()

        for pipeline, pkl_path in ((train_transform, self.config.train_transform_pkl),
                                   (test_transform, self.config.test_transform_pkl)):
            # Make sure the destination folder exists; "." covers a bare filename.
            os.makedirs(os.path.dirname(os.fspath(pkl_path)) or ".", exist_ok=True)
            joblib.dump(pipeline, pkl_path)

        # Reuse the pipelines that were just saved, so the loaders match the pickles.
        return self.create_dataloaders(train_transform, test_transform)

            

            



            



Note: you may need to restart the kernel to use updated packages.


ERROR: Invalid requirement: 'torch,'
