In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
from sklearn.decomposition import PCA
import os
import pickle
from tqdm import tqdm
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional

@dataclass
class ImageOnlyConfig:
    """Settings for the image-only feature-extraction pipeline.

    All fields can be overridden through the constructor. ``VIEWS`` defaults
    to ``["front", "top", "wrist"]`` when left as ``None``, and
    ``TOTAL_COMPRESSED_DIM`` is always re-derived in ``__post_init__`` as
    ``COMPRESSED_DIM * len(VIEWS)`` so it cannot drift out of sync with the
    values it depends on. Creating an instance prints a short summary.
    """

    # ===== PATHS =====
    # Root folder that is walked recursively for ``.png`` images.
    IMAGE_FOLDER: str = "success_traj_img"

    # Destination of the compressed feature archive and of the pickled PCA models.
    OUTPUT_PATH: str = "image_features.npz"
    PCA_MODEL_PATH: str = "image_pca_models.pkl"

    # ===== IMAGE PROCESSING =====
    RESNET_FEATURE_DIM: int = 512  # ResNet18 final layer per view
    # Camera view names looked for in file names; None selects the default list.
    VIEWS: Optional[List[str]] = None

    # ===== PCA COMPRESSION =====
    COMPRESSED_DIM: int = 64  # Final compressed dimension per view
    # Derived value (COMPRESSED_DIM * number of views); 192 with the defaults.
    TOTAL_COMPRESSED_DIM: int = 192

    # ===== MODEL =====
    # Evaluated once, when the class body is executed.
    DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"
    BATCH_SIZE: int = 32

    def __post_init__(self):
        """Fill in the default views, derive the total dimension, print a summary."""
        if self.VIEWS is None:
            self.VIEWS = ["front", "top", "wrist"]

        # Keep the advertised total consistent with the per-view size and the
        # number of views, even when either of them was overridden.
        self.TOTAL_COMPRESSED_DIM = self.COMPRESSED_DIM * len(self.VIEWS)

        print(f"Image-Only Processor Config")
        print(f"Views: {self.VIEWS}")
        print(f"ResNet Features: {self.RESNET_FEATURE_DIM} per view")
        print(f"Compressed Features: {self.COMPRESSED_DIM} per view")
        print(f"Total Compressed: {self.TOTAL_COMPRESSED_DIM}")
        print(f"Device: {self.DEVICE}")

class ImageOnlyProcessor:
    """Turn multi-view trajectory images into PCA-compressed ResNet18 features.

    The public methods form a pipeline that is normally driven in this order:
    ``build_image_index`` -> ``extract_all_image_features`` ->
    ``fit_pca_models`` -> ``compress_all_features`` -> ``save_image_features``.

    Attributes:
        config: The configuration object passed to the constructor.
        device: ``torch.device`` built from ``config.DEVICE``.
        model: ResNet18 without its final layer, in eval mode on ``device``.
        transform: Resize / tensor conversion / normalisation applied per image.
        image_index: ``{traj_key: {timestep: {view: path}}}`` filled by
            ``build_image_index``.
        pca_models: ``{view: fitted PCA}`` filled by ``fit_pca_models``.
    """

    def __init__(self, config: ImageOnlyConfig):
        """Load the pretrained ResNet18 backbone and set up preprocessing."""
        self.config = config
        self.device = torch.device(config.DEVICE)

        # Newer torchvision releases deprecate ``pretrained=True`` in favour of
        # the ``weights`` enum; fall back to the old flag when the enum is absent.
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is not None:
            backbone = models.resnet18(weights=weights_enum.IMAGENET1K_V1)
        else:
            backbone = models.resnet18(pretrained=True)

        # Drop the last child module (the classifier) to expose pooled features.
        self.model = nn.Sequential(*list(backbone.children())[:-1])
        self.model = self.model.to(self.device)
        self.model.eval()

        # Image preprocessing
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        # Storage
        self.image_index = {}
        self.pca_models = {}

        print(f"ResNet18 feature extractor initialized")

    def parse_filename(self, filename: str) -> Optional[Tuple[str, str, int]]:
        """Split an image file name into ``(traj_key, view, timestep)``.

        The name (minus a trailing ``.png``) is split on ``_``. The first part
        that equals one of ``config.VIEWS`` is the view, everything before it
        is the trajectory key, and the last part must parse as an integer
        timestep.

        Returns:
            The parsed triple, or ``None`` when no view name is present or the
            last part is not an integer.
        """
        # Strip only the extension; a blanket ``replace`` would also remove
        # ".png" occurring in the middle of the name.
        name = filename[:-len('.png')] if filename.endswith('.png') else filename
        parts = name.split('_')

        view_idx = next(
            (i for i, part in enumerate(parts) if part in self.config.VIEWS),
            None,
        )
        if view_idx is None:
            return None

        try:
            timestep = int(parts[-1])
        except ValueError:
            return None

        return '_'.join(parts[:view_idx]), parts[view_idx], timestep

    def build_image_index(self):
        """Walk ``config.IMAGE_FOLDER`` and index every parseable ``.png`` file.

        Populates ``self.image_index`` as ``{traj_key: {timestep: {view: path}}}``.
        Directories and files are visited in sorted order so the index does
        not depend on the order the file system happens to return. When two
        files map to the same (trajectory, timestep, view) the later one wins,
        as before, but the number of such overwrites is now reported.

        Returns:
            Number of (trajectory, timestep) entries that have every view.
        """
        print(f"Building image index from: {self.config.IMAGE_FOLDER}")
        if not os.path.isdir(self.config.IMAGE_FOLDER):
            print(f"Warning: image folder not found: {self.config.IMAGE_FOLDER}")

        image_index = defaultdict(lambda: defaultdict(dict))
        total_images, parsed_images, duplicates = 0, 0, 0

        for root, dirs, files in os.walk(self.config.IMAGE_FOLDER):
            dirs.sort()  # in-place sort steers os.walk's traversal order
            for filename in sorted(files):
                if not filename.endswith('.png'):
                    continue
                total_images += 1
                parse_result = self.parse_filename(filename)
                if not parse_result:
                    continue
                traj_key, view, timestep = parse_result
                views_at_step = image_index[traj_key][timestep]
                if view in views_at_step:
                    duplicates += 1
                views_at_step[view] = os.path.join(root, filename)
                parsed_images += 1

        n_views = len(self.config.VIEWS)
        complete_triplets = sum(
            len(views_at_step) == n_views
            for per_timestep in image_index.values()
            for views_at_step in per_timestep.values()
        )

        print(f"Total images: {total_images}, Parsed: {parsed_images}, Complete triplets: {complete_triplets}")
        if duplicates:
            print(f"Warning: {duplicates} images overwrote an earlier file with the same trajectory/timestep/view")
        self.image_index = dict(image_index)
        return complete_triplets

    def _forward_image(self, image_path: str) -> np.ndarray:
        """Run one image through the backbone; errors propagate to the caller."""
        # The context manager closes the file handle as soon as the pixels
        # have been copied by ``convert``.
        with Image.open(image_path) as handle:
            image = handle.convert('RGB')
        image_tensor = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            features = self.model(image_tensor)

        return features.reshape(-1).cpu().numpy()

    def extract_features(self, image_path: str) -> np.ndarray:
        """Extract a flat ResNet18 feature vector from a single image.

        Best-effort: any failure is printed and a zero vector of length
        ``config.RESNET_FEATURE_DIM`` is returned instead of raising. The zero
        vector is float32 so it matches the dtype of real features.
        """
        try:
            return self._forward_image(image_path)
        except Exception as e:
            print(f"Error processing {image_path}: {e}")
            return np.zeros(self.config.RESNET_FEATURE_DIM, dtype=np.float32)

    def extract_all_image_features(self) -> Dict[str, Dict[int, np.ndarray]]:
        """Extract features for every timestep that has all configured views.

        For each complete timestep the per-view vectors are concatenated in
        ``config.VIEWS`` order. Timesteps are processed in sorted order. A
        timestep in which any image fails to load or process is skipped
        entirely, so placeholder zero vectors never reach the PCA fit.

        Returns:
            ``{traj_key: {timestep: concatenated feature vector}}``. Every
            indexed trajectory gets an entry, possibly an empty one.
        """
        print("Extracting multiview image features...")

        features_dict = {}
        total_processed = 0
        skipped = 0
        required_views = set(self.config.VIEWS)

        for traj_key in tqdm(self.image_index, desc="Processing trajectories"):
            per_timestep = self.image_index[traj_key]
            features_dict[traj_key] = {}

            for timestep in sorted(per_timestep):
                view_paths = per_timestep[timestep]
                if set(view_paths) != required_views:
                    continue

                try:
                    view_features = [
                        self._forward_image(view_paths[view])
                        for view in self.config.VIEWS
                    ]
                except Exception as e:
                    print(f"Error processing {traj_key} timestep {timestep}: {e}")
                    skipped += 1
                    continue

                # Concatenate all view features
                features_dict[traj_key][timestep] = np.concatenate(view_features)
                total_processed += 1

        print(f"Extracted features for {total_processed} complete image triplets")
        if skipped:
            print(f"Skipped {skipped} triplets because an image could not be processed")
        return features_dict

    def fit_pca_models(self, features_dict: Dict) -> Dict[str, PCA]:
        """Fit one PCA model per view on that view's slice of the features.

        Each stored vector is assumed to be the per-view features concatenated
        in ``config.VIEWS`` order, ``config.RESNET_FEATURE_DIM`` values each
        (which is what ``extract_all_image_features`` produces). The number of
        components is capped at what the data supports, because PCA cannot
        produce more components than ``min(n_samples, n_features)``.

        Returns:
            ``{view: fitted PCA}``; also stored on ``self.pca_models``. Empty
            when ``features_dict`` holds no feature vectors.
        """
        print("Fitting PCA compression models...")

        rows = [
            combined
            for per_timestep in features_dict.values()
            for combined in per_timestep.values()
        ]

        pca_models = {}
        if rows:
            matrix = np.array(rows)
            dim = self.config.RESNET_FEATURE_DIM

            for i, view in enumerate(self.config.VIEWS):
                features_array = matrix[:, i * dim:(i + 1) * dim]
                n_components = min(self.config.COMPRESSED_DIM, *features_array.shape)
                if n_components < 1:
                    print(f"  {view} view: no data to fit, skipping")
                    continue
                if n_components < self.config.COMPRESSED_DIM:
                    print(f"  {view} view: only {n_components} components available "
                          f"(requested {self.config.COMPRESSED_DIM})")

                pca = PCA(n_components=n_components)
                pca.fit(features_array)

                explained_var = pca.explained_variance_ratio_.sum()
                print(f"  {view} view: {explained_var:.3f} variance explained")

                pca_models[view] = pca

        self.pca_models = pca_models
        return pca_models

    def compress_all_features(self, features_dict: Dict) -> Dict[str, Dict[int, np.ndarray]]:
        """Project every feature vector through the fitted per-view PCA models.

        All samples are transformed in one call per view instead of one call
        per sample. A view without a fitted model contributes zeros. A model
        that yields fewer than ``config.COMPRESSED_DIM`` components is
        zero-padded, so every output vector has length
        ``COMPRESSED_DIM * len(VIEWS)``.

        Returns:
            ``{traj_key: {timestep: compressed vector}}`` with the same keys
            as ``features_dict``.
        """
        print("Compressing features with PCA...")

        compressed_dict = {traj_key: {} for traj_key in features_dict}
        keys = [
            (traj_key, timestep)
            for traj_key, per_timestep in features_dict.items()
            for timestep in per_timestep
        ]
        if not keys:
            return compressed_dict

        matrix = np.array([features_dict[traj_key][timestep] for traj_key, timestep in keys])
        dim = self.config.RESNET_FEATURE_DIM
        out_dim = self.config.COMPRESSED_DIM

        # Compress each view separately
        view_blocks = []
        for i, view in enumerate(self.config.VIEWS):
            pca = self.pca_models.get(view)
            if pca is None:
                view_blocks.append(np.zeros((len(keys), out_dim)))
                continue

            reduced = np.asarray(pca.transform(matrix[:, i * dim:(i + 1) * dim]))
            if reduced.shape[1] < out_dim:
                padded = np.zeros((len(keys), out_dim), dtype=reduced.dtype)
                padded[:, :reduced.shape[1]] = reduced
                reduced = padded
            view_blocks.append(reduced)

        # Combine compressed features from all views
        compressed = np.concatenate(view_blocks, axis=1)
        for (traj_key, timestep), row in zip(keys, compressed):
            compressed_dict[traj_key][timestep] = row

        return compressed_dict

    def save_image_features(self, compressed_features: Dict):
        """Write the compressed features and the PCA models to disk.

        The archive at ``config.OUTPUT_PATH`` contains ``features`` (one row
        per sample), ``metadata`` (list of dicts) and ``config`` (dict), both
        stored as pickled objects, plus the plain arrays ``traj_keys`` and
        ``timesteps``, row-aligned with ``features``, which can be loaded
        without ``allow_pickle``. ``self.pca_models`` is pickled to
        ``config.PCA_MODEL_PATH``.
        """
        print(f"Saving image features to: {self.config.OUTPUT_PATH}")

        # Convert to arrays with metadata
        feature_list = []
        metadata_list = []

        for traj_key, per_timestep in compressed_features.items():
            for timestep, feature_vector in per_timestep.items():
                feature_list.append(feature_vector)
                metadata_list.append({
                    'traj_key': traj_key,
                    'timestep': timestep,
                    'feature_dim': len(feature_vector)
                })

        feature_array = np.array(feature_list)

        # Save features
        np.savez_compressed(
            self.config.OUTPUT_PATH,
            features=feature_array,
            metadata=metadata_list,
            config=dict(vars(self.config)),
            traj_keys=np.array([m['traj_key'] for m in metadata_list], dtype=str),
            timesteps=np.array([m['timestep'] for m in metadata_list], dtype=np.int64),
        )

        # numpy appends ".npz" when the given name lacks it, so resolve the
        # name of the file that was actually written before inspecting it.
        saved_path = self.config.OUTPUT_PATH
        if not saved_path.endswith('.npz'):
            saved_path += '.npz'

        # Save PCA models
        with open(self.config.PCA_MODEL_PATH, 'wb') as f:
            pickle.dump(self.pca_models, f)

        # An empty feature list yields a 1-D array with no second axis.
        feature_dim = feature_array.shape[1] if feature_array.ndim == 2 else 0

        print(f"Image features saved:")
        print(f"  Features: {saved_path}")
        print(f"  PCA models: {self.config.PCA_MODEL_PATH}")
        print(f"  Total features: {len(feature_array)}")
        print(f"  Feature dimension: {feature_dim}")
        print(f"  File size: {os.path.getsize(saved_path)/1024/1024:.1f} MB")

def process_images_only(config: ImageOnlyConfig = None) -> str:
    """
    Run the full image-only pipeline: index, extract, fit PCA, compress, save.

    Args:
        config: Pipeline settings; a default ImageOnlyConfig is created when
            this is None.

    Returns:
        Path to generated image features file (config.OUTPUT_PATH)

    Raises:
        ValueError: when the image index contains no complete triplets.
        Exception: any error from a pipeline step is printed, then re-raised.
    """
    rule = "=" * 60

    if config is None:
        config = ImageOnlyConfig()

    print(rule)
    print("Image-Only Processing Pipeline")
    print("Extracting latent vectors from multiview images")
    print(rule)

    try:
        pipeline = ImageOnlyProcessor(config)

        # Step 1: index the images and bail out early if nothing is usable
        print("\n1. Building image index...")
        if pipeline.build_image_index() == 0:
            raise ValueError("No complete image triplets found!")

        # Step 2: raw backbone features
        print("\n2. Extracting ResNet18 features...")
        raw_features = pipeline.extract_all_image_features()

        # Step 3: fit the per-view PCA models
        print("\n3. Fitting PCA compression...")
        pipeline.fit_pca_models(raw_features)

        # Step 4: apply the fitted models
        print("\n4. Compressing features...")
        reduced_features = pipeline.compress_all_features(raw_features)

        # Step 5: persist features and models
        print("\n5. Saving image features...")
        pipeline.save_image_features(reduced_features)

        print("\n" + rule)
        print("Image-Only Processing Completed Successfully!")
        print(rule)
        print(f"✅ Generated: {config.OUTPUT_PATH}")
        print(f"✅ Feature dimension: {config.TOTAL_COMPRESSED_DIM}")
        print(f"✅ Views processed: {config.VIEWS}")

        return config.OUTPUT_PATH

    except Exception as e:
        # Report the failure, then hand the same exception back to the caller.
        print(f"\n Processing failed: {e}")
        raise e

# Entry point when run as a script or notebook cell: build a default
# configuration, run the whole pipeline, and report where the features went.
if __name__ == "__main__":
    config = ImageOnlyConfig()
    output_path = process_images_only(config)
    print(f"\nImage latent vectors ready: {output_path}")


Image-Only Processor Config
Views: ['front', 'top', 'wrist']
ResNet Features: 512 per view
Compressed Features: 64 per view
Total Compressed: 192
Device: cuda
Image-Only Processing Pipeline
Extracting latent vectors from multiview images
총 이미지: 0
ResNet18 feature extractor initialized

1. Building image index...
Building image index from: success_traj_img
Total images: 28654, Parsed: 28654, Complete triplets: 9551

2. Extracting ResNet18 features...
Extracting multiview image features...


Processing trajectories:  62%|██████▏   | 61/98 [00:46<00:31,  1.17it/s]