In [1]:
import os
os.chdir('../')
!pwd

/Users/shrey/Desktop/github/human_pose_estimation


In [2]:
from dataclasses import dataclass
from pathlib import Path

@dataclass(frozen=True)
class PoseEstimationConfig:
    root_dir: Path
    base_model_path: Path
    updated_base_model_path: Path
    params_image_size: list
    params_learning_rate: float
    params_weights: str
    params_confidence_threshold: float
    params_batch_size: int
    params_augmentation: bool
    params_model_type: str
    params_keypoints: int

In [3]:
from cnnEstimation.constants import *
from cnnEstimation.utils.common import read_yaml, create_dir


class ConfigurationManager:
    def __init__(self,
                 config_filepath = CONFIG_FILE_PATH,
                 params_filepath = PARAMS_FILE_PATH):
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)
        create_dir([self.config.artifacts_root])


    def get_pose_estimation_config(self) -> PoseEstimationConfig:
        config = self.config.pose_estimation

        create_dir([config.root_dir])

        prepare_base_model_config = PoseEstimationConfig(
            root_dir = Path(config.root_dir),
            base_model_path = Path(config.base_model_path),
            updated_base_model_path = Path(config.updated_base_model_path),
            params_image_size = self.params.IMAGE_SIZE,
            params_learning_rate = self.params.LEARNING_RATE,
            params_weights = self.params.WEIGHTS,
            params_confidence_threshold = self.params.CONFIDENCE_THRESHOLD,
            params_batch_size = self.params.BATCH_SIZE,
            params_augmentation = self.params.AUGMENTATION,
            params_model_type = self.params.MODEL_TYPE,
            params_keypoints = self.params.KEYPOINTS
        )
        return prepare_base_model_config

In [4]:
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow.lite as tflite
from pathlib import Path
from cnnEstimation import logger
import requests


class PreparePoseEstimationModel:
    def __init__(self, config):
        self.config = config
        self.model = None


    def get_base_model(self):
        """
        Loads the MoveNet model from TensorFlow Lite.
        If the model does not exist or is a directory, it downloads the correct file.
        """
        model_type = self.config.params_model_type.lower()
        model_path = Path(self.config.base_model_path)

        if model_path.is_dir():
            logger.warning(f"❌ {model_path} is a directory. Removing it...")
            os.rmdir(model_path)

        if not model_path.exists():
            logger.info(f"Downloading MoveNet model: {model_type}")
            self.download_tflite_model(model_type, model_path)

        try:
            self.model = tflite.Interpreter(model_path=str(model_path))
            self.model.allocate_tensors()
            logger.info(f"✅ Loaded MoveNet TFLite model from {model_path}")
        except Exception as e:
            logger.error(f"❌ Error loading MoveNet model: {str(e)}")
            raise RuntimeError("Failed to load MoveNet model.")

    def download_tflite_model(self, model_type, save_path):
        model_urls = {
            "movenet_lightning": "https://storage.googleapis.com/movenet_models/movenet_lightning.tflite",
            "movenet_thunder": "https://storage.googleapis.com/movenet_models/movenet_thunder.tflite"
        }

        url = model_urls.get(model_type)
        if url is None:
            raise ValueError("Invalid MoveNet model type. Choose 'movenet_lightning' or 'movenet_thunder'.")

        response = requests.get(url, stream=True)
        
        # ✅ Write file in binary mode
        with open(save_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)

        # ✅ Validate file size (MoveNet models are around ~4MB)
        if os.path.getsize(save_path) < 1000:
            print(f"❌ Error: {save_path} is too small! The download might be incomplete.")
            os.remove(save_path)  # Delete corrupted file
            raise RuntimeError("Downloaded MoveNet model is invalid.")

        print(f"✅ MoveNet {model_type} model downloaded successfully to {save_path}")

    def load_saved_model(self):
        """
        Loads a previously saved MoveNet model.
        """
        model_path = Path(self.config.base_model_path)

        if not model_path.exists():
            raise FileNotFoundError(f"Saved model not found at {model_path}")

        try:
            if model_path.suffix == ".tflite":
                self.model = tflite.Interpreter(model_path=str(model_path))
                self.model.allocate_tensors()
                logger.info(f"TFLite model loaded from {model_path}")
            else:
                self.model = hub.load(str(model_path))
                logger.info(f"MoveNet model loaded from {model_path}")
        except Exception as e:
            logger.error(f"Error loading saved model: {str(e)}")
            raise RuntimeError("Failed to load MoveNet model.")


In [5]:
try:
    config = ConfigurationManager()
    prepare_pose_estimation_model_config = config.get_pose_estimation_config()
    prepare_pose_estimation_model = PreparePoseEstimationModel(prepare_pose_estimation_model_config)
    prepare_pose_estimation_model.get_base_model()
except Exception as e:
    raise e

[2025-01-31 13:18:58,745 - INFO - common - YAML file loaded successfully: config/config.yaml]
[2025-01-31 13:18:58,747 - INFO - common - YAML file loaded successfully: params.yaml]
[2025-01-31 13:18:58,747 - INFO - common - Directory created successfully: artifacts]
[2025-01-31 13:18:58,748 - INFO - common - Directory created successfully: artifacts/pose_estimation]
[2025-01-31 13:18:58,749 - INFO - 2900733022 - Downloading MoveNet model: movenet_lightning]


❌ Error: artifacts/pose_estimation/movenet_model.tflite is too small! The download might be incomplete.


RuntimeError: Downloaded MoveNet model is invalid.