In [5]:
# WARNING(review): destructive. This deletes the numpy package and its
# dist-info metadata directly from this virtualenv's site-packages, bypassing
# pip. The recorded output of the later import cell shows
# `ModuleNotFoundError: No module named 'numpy'`. Both paths are absolute and
# specific to one machine. For a broken install, prefer
# `%pip install --force-reinstall numpy`.
import shutil
import os

numpy_path = '/Users/shaikimran/Documents/Text Analytics/.venv/lib/python3.10/site-packages/numpy'
numpy_dist_info = '/Users/shaikimran/Documents/Text Analytics/.venv/lib/python3.10/site-packages/numpy-1.26.4.dist-info'

# Remove directories if they exist
# shutil.rmtree deletes each tree recursively; paths that don't exist are skipped.
for path in [numpy_path, numpy_dist_info]:
    if os.path.exists(path):
        shutil.rmtree(path)
        print(f"Removed: {path}")

In [6]:

# NOTE(review): `CLIP` on PyPI is very likely not OpenAI's CLIP; OpenAI's is
# installed from git+https://github.com/openai/CLIP.git in a later cell. Verify.
# `deepface` / `tf-keras` pull in TensorFlow, which later cells try to remove.
# No versions are pinned.
%pip install transformers deepface tf-keras CLIP

zsh:1: no such file or directory: /Users/shaikimran/Documents/Text Analytics/.venv/bin/python
Note: you may need to restart the kernel to use updated packages.


In [3]:
# Reinstall numpy after the manual site-packages deletion above; restart the
# kernel afterwards so the fresh install is imported. Consider pinning a version.
%pip install numpy

Note: you may need to restart the kernel to use updated packages.


In [4]:
# =========================================================
# üé® AGE-CONDITIONED FORENSIC SKETCH GENERATOR
# PyTorch Version - Kaggle Compatible
# =========================================================

import os, time, json, torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms
import torchvision.models as models
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

  cpu = _conversion_method_template(device=torch.device("cpu"))


ModuleNotFoundError: No module named 'numpy'

In [None]:
# NOTE(review): this installs TensorFlow for Apple Silicon, but the next cell
# uninstalls tensorflow / tensorflow-macos / tensorflow-metal again, and the
# install cell after it disables TensorFlow for transformers. Keep only one approach.
# First, uninstall existing tensorflow and install Apple Silicon version
%pip uninstall tensorflow tensorflow-macos -y
%pip install tensorflow-macos
%pip install tensorflow-metal  # For GPU acceleration on Apple Silicon

In [None]:
# =========================================================
# üßπ CLEAN UP CURRENT ENVIRONMENT
# =========================================================
import subprocess
import sys
import os

print("üßπ Cleaning up environment...")

# Remove problematic packages one at a time, so one failure doesn't stop the rest.
packages_to_remove = ['tensorflow', 'tensorflow-macos', 'tensorflow-metal', 'deepface']
for package in packages_to_remove:
    try:
        # `sys.executable -m pip` targets the interpreter running this kernel.
        subprocess.check_call([sys.executable, '-m', 'pip', 'uninstall', '-y', package])
        print(f"‚úÖ Removed {package}")
    except (subprocess.CalledProcessError, OSError):
        # CalledProcessError: pip exited non-zero.
        # OSError: the interpreter itself could not be launched.
        # Narrower than a bare `except:`, so Ctrl-C still interrupts the cell.
        print(f"‚ö†Ô∏è Could not remove {package}")

print("Environment cleanup completed!")

In [None]:
# =========================================================
# üöÄ CLEAN INSTALLATION (After Kernel Restart)
# =========================================================
import os
# CRITICAL: Disable TensorFlow completely
# NOTE(review): these only take effect if set before `transformers` is first
# imported in this kernel. transformers documents USE_TF / USE_TORCH for
# backend selection; confirm the installed version honors TRANSFORMERS_NO_TF.
# NO_TF does not appear to be read by anything imported here. Verify.
os.environ['TRANSFORMERS_NO_TF'] = '1'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['NO_TF'] = '1'

print("üì¶ Installing clean packages...")

import subprocess
import sys

# Install only what we need, avoiding TensorFlow dependencies
# NOTE(review): only transformers has a lower bound; everything else is
# unpinned, so re-running this cell can produce a different environment.
packages = [
    'transformers>=4.21.0',
    'torch',
    'torchvision',
    'pillow',
    'matplotlib',
    'tqdm',
    'numpy',
    'opencv-python',
    'git+https://github.com/openai/CLIP.git'
]

for package in packages:
    try:
        print(f"Installing {package}...")
        # `sys.executable -m pip` installs into the interpreter running this kernel.
        subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', package])
    except Exception as e:
        # Best effort: report the failure and continue with the remaining packages.
        print(f"‚ö†Ô∏è Failed to install {package}: {e}")

print("‚úÖ Package installation completed!")

# Now import
import torch 
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms
import torchvision.models as models
from tqdm import tqdm
import warnings
# Silences ALL warnings for the rest of the session, including deprecations.
warnings.filterwarnings('ignore')

# Import transformers - this should work now
try:
    from transformers import BlipProcessor, BlipForConditionalGeneration
    print("‚úÖ BLIP loaded successfully!")
    
    # Test that it works
    # NOTE(review): this downloads a full BLIP model and keeps it in memory under
    # the generic globals `processor` / `model`. The caption cell loads BLIP again
    # as blip_processor / blip_model. Consider `del processor, model` after this check.
    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
    print("‚úÖ BLIP models loaded and working!")
    
except ImportError as e:
    print(f"‚ùå BLIP import failed: {e}")
    BlipProcessor, BlipForConditionalGeneration = None, None
except Exception as e:
    print(f"‚ùå BLIP model loading failed: {e}")
    BlipProcessor, BlipForConditionalGeneration = None, None

# Import CLIP
try:
    import clip
    print("‚úÖ CLIP loaded successfully!")
except ImportError as e:
    print(f"‚ùå CLIP import failed: {e}")
    clip = None  # downstream CLIP helpers check for None and use fallbacks

# Skip DeepFace completely
# analyze_face() checks for None and returns fixed placeholder attributes.
print("‚ö†Ô∏è DeepFace skipped to avoid TensorFlow issues")
DeepFace = None

# Preference order: CUDA, then Apple MPS, then CPU.
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"‚úÖ Using device: {device}")

In [None]:
# =========================================================
# üîß DEVICE SETUP
# =========================================================
# Check CUDA first so NVIDIA machines (e.g. Kaggle GPUs) are not forced onto
# the CPU, then Apple-Silicon MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"\n‚úÖ Using device: {device}")
print(f"‚úÖ PyTorch version: {torch.__version__}")
print(f"‚úÖ CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"‚úÖ GPU: {torch.cuda.get_device_name(0)}")

In [None]:

# =========================================================
# üìÅ Configuration
# =========================================================

# Local macOS path
# Local macOS path by default. Set the SKETCH_DATASET_PATH environment variable
# to run on another machine without editing this cell.
dataset_path = os.environ.get(
    "SKETCH_DATASET_PATH",
    "/Users/shaikimran/Documents/Text Analytics/archive",
)
train_A = os.path.join(dataset_path, "photo")
train_B = os.path.join(dataset_path, "cropped_sketch")

# Model hyperparameters
IMG_SIZE = 256
BATCH_SIZE = 8
LAMBDA_L1 = 100
EPOCHS = 10
TEXT_EMB_DIM = 512
AGE_EMB_DIM = 64
LEARNING_RATE = 2e-4

# Create output directories (for local storage)
os.makedirs("./checkpoints", exist_ok=True)
os.makedirs("./age_progression_results", exist_ok=True)

print(f"\nüìÅ Dataset paths:")
print(f"  Photos: {train_A}")
print(f"  Exists: {os.path.exists(train_A)}")
print(f"  Sketches: {train_B}")
print(f"  Exists: {os.path.exists(train_B)}")


def list_image_files(directory):
    """Return the .jpg/.jpeg/.png filenames in `directory`.

    The extension is matched case-insensitively, the same rule SketchDataset
    applies, so the counts printed here match what the dataset will load.
    """
    return [f for f in os.listdir(directory) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]


# Verify dataset structure (isdir: a stray file with the same name must not crash listdir)
if os.path.isdir(train_A):
    photo_files = list_image_files(train_A)
    photo_count = len(photo_files)
    print(f"\nüì∏ Found {photo_count} photos")
    if photo_count > 0:
        print(f"  Sample files: {photo_files[:3]}")
else:
    print("  ‚ö†Ô∏è Photo directory not found!")

if os.path.isdir(train_B):
    sketch_files = list_image_files(train_B)
    sketch_count = len(sketch_files)
    print(f"\n‚úèÔ∏è Found {sketch_count} sketches")
    if sketch_count > 0:
        print(f"  Sample files: {sketch_files[:3]}")
else:
    print("  ‚ö†Ô∏è Sketch directory not found!")

print(f"\n‚öôÔ∏è Model Configuration:")
print(f"  Image Size: {IMG_SIZE}x{IMG_SIZE}")
print(f"  Batch Size: {BATCH_SIZE}")
print(f"  Epochs: {EPOCHS}")
print(f"  Lambda L1: {LAMBDA_L1}")
print(f"  Text Embedding Dim: {TEXT_EMB_DIM}")
print(f"  Age Embedding Dim: {AGE_EMB_DIM}")
print(f"  Learning Rate: {LEARNING_RATE}")
print(f"\nüíæ Output directories:")
print(f"  Checkpoints: ./checkpoints")
print(f"  Results: ./age_progression_results")

In [None]:
# =========================================================
# üéØ Load CLIP Model (with fallback)
# =========================================================
# Load OpenAI CLIP ViT-B/32 if the `clip` package imported. Otherwise, or on
# any load error, clip_model stays None and the helpers below use fallbacks.
clip_model, clip_preprocess = None, None

if clip is not None:
    try:
        print("\nüîÑ Loading CLIP model...")
        clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
    except Exception as e:
        print(f"‚ö†Ô∏è CLIP loading failed: {e}")
        clip_model = None
    else:
        print("‚úÖ CLIP model loaded")

def hashed_text_embedding(text, dim=512):
    """Deterministic pseudo-random float32 vector used in place of a CLIP text embedding.

    The seed is a CRC32 of the UTF-8 text, so the same text maps to the same
    vector in every Python process. The built-in ``hash`` of a str is salted
    per process unless PYTHONHASHSEED is fixed. A private Generator is used,
    so the global NumPy RNG state is left untouched.
    """
    import zlib

    seed = zlib.crc32(str(text).encode("utf-8"))
    return np.random.default_rng(seed).standard_normal(dim).astype(np.float32)


def clip_text_embedding(text):
    """Return a 512-d, L2-normalized CLIP text embedding as a float32 NumPy array.

    If CLIP is not loaded, or encoding raises, returns
    ``hashed_text_embedding(text)`` instead (stable per text, not normalized).
    The tensor is moved to CPU and cast to float32 before ``.numpy()``, so
    half-precision or MPS outputs convert cleanly.
    """
    if clip_model is None:
        return hashed_text_embedding(text)

    try:
        tokens = clip.tokenize([text]).to(device)
        with torch.no_grad():
            embedding = clip_model.encode_text(tokens)
        embedding = embedding / embedding.norm(dim=-1, keepdim=True)
        return embedding.cpu().float().numpy().flatten()
    except Exception as e:
        print(f"‚ö†Ô∏è Text embedding error: {e}")
        return hashed_text_embedding(text)

def clip_image_embedding(image_tensor):
    """Return a 512-d, L2-normalized CLIP image embedding as a float32 NumPy array.

    `image_tensor` may be a PIL image or a NumPy array; arrays are cast to
    uint8 and wrapped with ``Image.fromarray`` before CLIP preprocessing.
    Returns standard-normal noise (from the global NumPy RNG) when CLIP is not
    loaded or encoding fails. The result is moved to CPU and cast to float32
    before conversion (MPS compatible).
    """
    if clip_model is None or clip_preprocess is None:
        return np.random.randn(512).astype(np.float32)

    try:
        picture = image_tensor
        if isinstance(picture, np.ndarray):
            picture = Image.fromarray(picture.astype(np.uint8))

        batch = clip_preprocess(picture).unsqueeze(0).to(device)
        with torch.no_grad():
            features = clip_model.encode_image(batch)
        unit = features / features.norm(dim=-1, keepdim=True)
        return unit.cpu().float().numpy().flatten()
    except Exception as e:
        print(f"‚ö†Ô∏è Image embedding error: {e}")
        return np.random.randn(512).astype(np.float32)

def clip_similarity(text, image_array):
    """Cosine similarity between the CLIP embeddings of `text` and `image_array`.

    Both embeddings are re-normalized (zero vectors are left as-is), and the
    dot product is clipped to [-1, 1] and returned as a Python float.
    Returns 0.0 (neutral) if anything raises.
    """
    try:
        unit_vectors = []
        for vector in (clip_text_embedding(text), clip_image_embedding(image_array)):
            length = np.linalg.norm(vector)
            unit_vectors.append(vector / length if length > 0 else vector)

        cosine = float(np.dot(unit_vectors[0], unit_vectors[1]))
        return float(np.clip(cosine, -1.0, 1.0))
    except Exception as e:
        print(f"‚ö†Ô∏è CLIP similarity error: {e}")
        return 0.0

In [None]:
# =========================================================
# üë§ DeepFace Analysis (with fallback)
# =========================================================
def analyze_face(image_path):
    """Estimate age, gender, race and emotion for the face in `image_path`.

    Uses DeepFace when available. If DeepFace is None or analysis raises, it
    returns the fixed placeholder
    ``{'age': 30, 'gender': 'Man', 'race': 'white', 'emotion': 'neutral'}``.
    """
    placeholder = {'age': 30, 'gender': 'Man', 'race': 'white', 'emotion': 'neutral'}
    if DeepFace is None:
        return placeholder

    try:
        result = DeepFace.analyze(
            img_path=image_path,
            actions=['age', 'gender', 'race', 'emotion'],
            enforce_detection=False,
            silent=True,
        )
        # Newer DeepFace versions return one dict per detected face.
        face = result[0] if isinstance(result, list) else result
        return {
            'age': int(face['age']),
            'gender': face.get('dominant_gender', 'Man'),
            'race': face.get('dominant_race', 'white'),
            'emotion': face.get('dominant_emotion', 'neutral'),
        }
    except Exception as e:
        print(f"‚ö†Ô∏è Face analysis failed: {e}")
        return placeholder

In [None]:
# =========================================================
# ‚è∞ Age Encoding
# =========================================================
def parse_temporal_description(description):
    """Extract a described age and an elapsed-time offset from free text.

    The described age is taken from the first of these that matches:
      1. an explicit "<N> years old" / "<N>-year-old" phrase,
      2. "about / around / age / aged <N>",
      3. any other standalone number that is not the count in "<N> years ago".
    The number in "<N> years ago" only sets ``years_ago``. If none of these
    matches, the keywords 'elderly' (70), 'middle-aged' (45) and 'young' (25)
    are tried, and 30 is the final default.

    Returns:
        dict with 'described_age', 'years_ago', 'target_age'
        (= described_age + years_ago) and 'raw_description'.
    """
    import re

    text = description.lower()

    years_ago_match = re.search(r'(\d+)\s+years?\s+ago', text)
    age_match = (
        re.search(r'\b(\d+)\s*-?\s*years?\s*-?\s*old\b', text)
        or re.search(r'\b(?:about|around|aged?)\s+(\d+)', text)
        # Fallback kept for plain numbers ("person 42"), but never the "N years ago" count.
        or re.search(r'\b(\d+)\b(?!\s+years?\s+ago)', text)
    )

    current_age = int(age_match.group(1)) if age_match else None
    years_elapsed = int(years_ago_match.group(1)) if years_ago_match else 0

    if current_age is None:
        if 'elderly' in text:
            current_age = 70
        elif 'middle-aged' in text:
            current_age = 45
        elif 'young' in text:
            current_age = 25

    described_age = 30 if current_age is None else current_age

    return {
        'described_age': described_age,
        'years_ago': years_elapsed,
        'target_age': described_age + years_elapsed,
        'raw_description': description,
    }

def encode_age(age, dim=None):
    """Encode an age in years as a float32 tensor of length `dim`.

    The age is scaled to x = clip(age / 100, 0, 1) and expanded into five
    features, [x, sin(pi*x), cos(pi*x), x**2, sqrt(x)], which are tiled
    cyclically and truncated to `dim` entries.

    Args:
        age: age in years; values outside [0, 100] are clipped.
        dim: output length; defaults to the notebook-level ``AGE_EMB_DIM``.

    Returns:
        1-D float32 torch tensor of shape (dim,).
    """
    if dim is None:
        dim = AGE_EMB_DIM
    age_normalized = np.clip(age / 100.0, 0, 1)
    age_vector = np.array([
        age_normalized,
        np.sin(age_normalized * np.pi),
        np.cos(age_normalized * np.pi),
        age_normalized ** 2,
        np.sqrt(age_normalized),
    ], dtype=np.float32)
    # Enough repetitions to cover `dim`, then cut to exact length.
    age_embedding = np.tile(age_vector, dim // len(age_vector) + 1)[:dim]
    return torch.tensor(age_embedding, dtype=torch.float32)

In [None]:
# =========================================================
# üì¶ Dataset Class
# =========================================================
class SketchDataset(Dataset):
    """Paired photo/sketch dataset that also yields a text description and an age.

    Each item is ``(photo, sketch, description, age, filename)``:
      * photo and sketch are RGB tensors, resized to img_size x img_size and
        normalized with mean = std = 0.5;
      * description and age are looked up by the photo's filename
        (defaults: "a person's face" and 30).

    Pairing: if every photo has a sketch with the same file stem, pairs are
    matched by stem, so extra or missing files cannot shift them. Otherwise
    photos and sketches are paired by sorted position. That assumes both
    folders list the same subjects in the same order; a count mismatch is reported.
    """

    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, photo_dir, sketch_dir, descriptions, ages, img_size=None):
        self.photo_files, self.sketch_files = self._pair_files(
            self._list_images(photo_dir), self._list_images(sketch_dir)
        )
        self.descriptions = descriptions
        self.ages = ages

        size = IMG_SIZE if img_size is None else img_size
        self.transform = transforms.Compose([
            transforms.Resize((size, size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])

    @classmethod
    def _list_images(cls, directory):
        """Sorted full paths of the image files in `directory` (extension matched case-insensitively)."""
        return sorted(os.path.join(directory, f) for f in os.listdir(directory)
                      if f.lower().endswith(cls.IMAGE_EXTENSIONS))

    @staticmethod
    def _stem(path):
        """File name without its directory or extension."""
        return os.path.splitext(os.path.basename(path))[0]

    @staticmethod
    def _pair_files(photos, sketches):
        """Return (photo_files, sketch_files) aligned for index-based lookup."""
        stem = SketchDataset._stem
        sketch_by_stem = {stem(p): p for p in sketches}
        if photos and all(stem(p) in sketch_by_stem for p in photos):
            return list(photos), [sketch_by_stem[stem(p)] for p in photos]
        if len(photos) != len(sketches):
            print(f"‚ö†Ô∏è {len(photos)} photos vs {len(sketches)} sketches and file names do not match: "
                  f"pairing by sorted position, using the first {min(len(photos), len(sketches))} of each.")
        return list(photos), list(sketches)

    def __len__(self):
        return min(len(self.photo_files), len(self.sketch_files))

    def __getitem__(self, idx):
        photo_path = self.photo_files[idx]
        sketch_path = self.sketch_files[idx]
        filename = os.path.basename(photo_path)

        # Context managers close the image files as soon as they are decoded.
        with Image.open(photo_path) as img:
            photo = self.transform(img.convert('RGB'))
        with Image.open(sketch_path) as img:
            sketch = self.transform(img.convert('RGB'))

        description = self.descriptions.get(filename, "a person's face")
        age = self.ages.get(filename, 30)

        return photo, sketch, description, age, filename


In [None]:
import json
import os
import torch
from tqdm import tqdm
from PIL import Image

# =========================================================
# üî§ FIXED CAPTION GENERATION - macOS M1 Compatible
# =========================================================
print("\nüî§ Generating captions...")

DETAILED_COUNT = 50                  # Process first 50 images with BLIP
DEFAULT_CAPTION = "a person's face"  # used when BLIP is unavailable or fails
DEFAULT_AGE = 30                     # used when no real age estimate exists

# Sort the names so "the first DETAILED_COUNT images" is the same set on every
# run (os.listdir order is arbitrary). A missing folder gives an empty list
# instead of crashing the cell.
if os.path.isdir(train_A):
    image_files = sorted(f for f in os.listdir(train_A) if f.lower().endswith(('.jpg', '.jpeg', '.png')))
else:
    print(f"‚ö†Ô∏è Photo directory not found: {train_A}")
    image_files = []
desc = {}
face_attributes = {}
estimated_ages = {}

# =========================================================
# üöÄ FIXED BLIP LOADING - macOS M1 Compatible
# =========================================================
# Choose the device outside the BLIP try-block. A BLIP import or download
# failure must not reset the notebook-wide `device` (used later for the
# generator and discriminator) to CPU.
if torch.backends.mps.is_available():
    device = "mps"  # Apple Silicon GPU
    print("üçé Using Apple Silicon GPU (MPS)")
elif torch.cuda.is_available():
    device = "cuda"  # NVIDIA GPU
    print("üéÆ Using CUDA GPU")
else:
    device = "cpu"
    print("üíª Using CPU")

blip_processor = None
blip_model = None

try:
    from transformers import BlipProcessor, BlipForConditionalGeneration

    print("üì¶ Loading BLIP model...")
    blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    blip_model = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    )
    # Move the model to the device after loading, then switch it to inference mode.
    blip_model = blip_model.to(device)
    blip_model.eval()

    print(f"‚úÖ BLIP model loaded successfully on {device}")
    print(f"   This will create UNIQUE captions for each face!")

except ImportError as e:
    print(f"‚ùå BLIP import failed: {e}")
    print("   Install with: pip install transformers")
    blip_processor = None
    blip_model = None

except Exception as e:
    print(f"‚ö†Ô∏è BLIP loading failed: {e}")
    print("   Continuing with basic captions...")
    blip_processor = None
    blip_model = None


def _blip_caption(img_path):
    """Caption one image with BLIP; returns DEFAULT_CAPTION if BLIP is unavailable or fails."""
    if blip_processor is None or blip_model is None:
        return DEFAULT_CAPTION
    try:
        with Image.open(img_path) as img:
            rgb = img.convert("RGB")
        inputs = blip_processor(images=rgb, return_tensors="pt")
        # Inputs must live on the same device as the model.
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():  # inference only
            output_ids = blip_model.generate(**inputs, max_length=50)
        return blip_processor.decode(output_ids[0], skip_special_tokens=True)
    except Exception as e:
        print(f"\n   ‚ö†Ô∏è Caption generation failed for {os.path.basename(img_path)}: {e}")
        return DEFAULT_CAPTION


# When DeepFace is None, analyze_face() returns the same placeholders
# (age 30, 'Man', 'white') for every image. Appending them would stamp the same
# fabricated age/gender/race onto every training caption, so use the caption alone.
face_analysis_available = DeepFace is not None


# =========================================================
# üé® PROCESS IMAGES WITH BLIP
# =========================================================
print(f"\nüñºÔ∏è Processing {len(image_files)} images...")
print(f"   Detailed BLIP captions: {min(DETAILED_COUNT, len(image_files))} images")
print(f"   Simple captions: {max(0, len(image_files) - DETAILED_COUNT)} images")

for idx, name in enumerate(tqdm(image_files, desc="Analyzing images")):
    try:
        img_path = os.path.join(train_A, name)

        if idx < DETAILED_COUNT:
            try:
                attrs = analyze_face(img_path)
            except Exception:
                attrs = {'age': DEFAULT_AGE, 'gender': 'unknown', 'race': 'unknown'}

            caption = _blip_caption(img_path)
            if face_analysis_available:
                description = f"{caption}, approximately {attrs['age']} years old, {attrs['gender']}, {attrs['race']} appearance"
            else:
                description = caption

            estimated_ages[name] = attrs['age']
            desc[name] = description
            face_attributes[name] = attrs

        else:
            # Simple description for the remaining images
            desc[name] = DEFAULT_CAPTION
            estimated_ages[name] = DEFAULT_AGE
            face_attributes[name] = None

    except Exception as e:
        print(f"\n‚ö†Ô∏è Error processing {name}: {e}")
        desc[name] = DEFAULT_CAPTION
        estimated_ages[name] = DEFAULT_AGE
        face_attributes[name] = None


# =========================================================
# üíæ SAVE RESULTS
# =========================================================
os.makedirs("./outputs", exist_ok=True)

with open("./outputs/descriptions.json", "w") as f:
    json.dump(desc, f, indent=2)

with open("./outputs/estimated_ages.json", "w") as f:
    json.dump(estimated_ages, f, indent=2)

with open("./outputs/face_attributes.json", "w") as f:
    # Convert None values to string for JSON serialization
    face_attrs_serializable = {k: v if v is not None else "none" for k, v in face_attributes.items()}
    json.dump(face_attrs_serializable, f, indent=2)

print(f"\n‚úÖ Processing complete!")
print(f"   Total images: {len(image_files)}")
print(f"   Detailed BLIP analysis: {min(DETAILED_COUNT, len(image_files))}")
print(f"   Descriptions saved: {len(desc)}")
print(f"   Output location: ./outputs/")


# =========================================================
# üìä DISPLAY SAMPLE RESULTS
# =========================================================
print("\nüìä Sample descriptions:")
for i, (name, description) in enumerate(list(desc.items())[:5]):
    print(f"\n{i+1}. {name}")
    print(f"   {description}")
    print(f"   Estimated age: {estimated_ages[name]}")


# =========================================================
# üîç VERIFY CAPTION DIVERSITY
# =========================================================
print("\nüîç Caption Diversity Check:")
print("="*60)

# The caption is the text before the first comma (any attributes follow it).
blip_captions = [full_desc.split(',')[0] for full_desc in list(desc.values())[:DETAILED_COUNT]]
unique_captions = set(blip_captions)
print(f"Total BLIP captions: {len(blip_captions)}")
print(f"Unique captions: {len(unique_captions)}")

if not blip_captions:
    print("‚ö†Ô∏è No captions were generated; nothing to check.")
else:
    print(f"Diversity: {len(unique_captions)/len(blip_captions)*100:.1f}%")
    if len(unique_captions) < len(blip_captions) * 0.5:
        print("‚ö†Ô∏è WARNING: Low caption diversity! BLIP may not be working properly.")
        print("   Expected: Each face should have a different description")
    else:
        print("‚úÖ Good caption diversity! Each face has unique features.")

print("\nSample unique captions:")
for i, caption in enumerate(list(unique_captions)[:10]):
    print(f"   {i+1}. {caption}")


# =========================================================
# üí° NEXT STEPS
# =========================================================
print("\nüí° Next Steps:")
print("="*60)

if blip_processor is not None and blip_model is not None:
    print("‚úÖ BLIP is working! Your captions are diverse.")
    print("   Now you can:")
    print("   1. Retrain the model with these diverse captions")
    print("   2. Train for 50-100 epochs")
    print("   3. The model will learn to generate different faces for different descriptions")
else:
    print("‚ùå BLIP failed to load.")
    print("   Fix this by:")
    print("   1. Install transformers: pip install transformers")
    print("   2. Ensure you have a stable internet connection (to download BLIP)")
    print("   3. Check you have enough RAM (BLIP needs ~2GB)")

print("\nüìù To use these captions for training:")
print("   dataset = SketchDataset(train_A, train_B, desc, estimated_ages)")
print("   dataloader = DataLoader(dataset, batch_size=8, shuffle=True)")

In [None]:
# =========================================================
# üèóÔ∏è Generator Architecture (PyTorch)
# =========================================================
class AgeConditionedGenerator(nn.Module):
    """Decode a (text embedding, age embedding) pair into a 3-channel image.

    An MLP maps the concatenated embeddings to an 8x8x512 seed map. Five
    stride-2 transposed-conv blocks then upsample it to 256x256. The age
    embedding also gates the channels (sigmoid scales) on the seed map and
    after each of the first three upsampling blocks.

    NOTE(review): `forward` takes no photo, so the output depends only on the
    text and age embeddings. This is a text/age-to-sketch decoder, not a
    photo-to-sketch translator.
    """
    def __init__(self, text_dim=512, age_dim=64):
        super().__init__()
        
        # Initial projection
        self.fc1 = nn.Linear(text_dim + age_dim, 512)
        self.fc2 = nn.Linear(512, 1024)
        self.fc3 = nn.Linear(1024, 8 * 8 * 512)  # reshaped to (B, 512, 8, 8) in forward
        
        # Age modulation
        self.age_mod = nn.Linear(age_dim, 512)  # per-channel gate for the 8x8 seed map
        
        # Upsampling blocks
        # Each block doubles H and W: 8 -> 16 -> 32 -> 64 -> 128 -> 256.
        self.up1 = self._upsample_block(512, 512, dropout=True)
        self.up2 = self._upsample_block(512, 512, dropout=True)
        self.up3 = self._upsample_block(512, 512, dropout=True)
        self.up4 = self._upsample_block(512, 256)
        self.up5 = self._upsample_block(256, 128)
        
        # Age conditioning layers
        # Per-channel gates applied after up1, up2 and up3 (512 channels each).
        self.age_scale1 = nn.Linear(age_dim, 512)
        self.age_scale2 = nn.Linear(age_dim, 512)
        self.age_scale3 = nn.Linear(age_dim, 512)
        
        # Final conv
        self.final = nn.Conv2d(128, 3, kernel_size=3, padding=1)  # 128 -> 3 channels, size preserved
        
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
    
    def _upsample_block(self, in_ch, out_ch, dropout=False):
        """ConvTranspose2d(k=4, s=2, p=1) -> BatchNorm2d [-> Dropout2d(0.5)] -> ReLU.

        The kernel/stride/padding combination exactly doubles H and W.
        """
        layers = [
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
        ]
        if dropout:
            layers.append(nn.Dropout2d(0.5))
        layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)
    
    def forward(self, text_emb, age_emb):
        """Generate images from the conditioning embeddings.

        Args:
            text_emb: (batch, text_dim) tensor, concatenated with `age_emb` on dim 1.
            age_emb: (batch, age_dim) tensor; also drives the per-channel gates.

        Returns:
            (batch, 3, 256, 256) tensor in [-1, 1] (tanh output).
        """
        # Combine embeddings
        x = torch.cat([text_emb, age_emb], dim=1)
        
        # Project
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.relu(self.fc3(x))
        x = x.view(-1, 512, 8, 8)
        
        # Age modulation
        # Sigmoid gate in (0, 1), broadcast over H and W as (B, 512, 1, 1).
        age_mod = self.sigmoid(self.age_mod(age_emb)).unsqueeze(-1).unsqueeze(-1)
        x = x * age_mod
        
        # Upsample with age injection
        x = self.up1(x)
        age_s1 = self.sigmoid(self.age_scale1(age_emb)).unsqueeze(-1).unsqueeze(-1)
        x = x * age_s1
        
        x = self.up2(x)
        age_s2 = self.sigmoid(self.age_scale2(age_emb)).unsqueeze(-1).unsqueeze(-1)
        x = x * age_s2
        
        x = self.up3(x)
        age_s3 = self.sigmoid(self.age_scale3(age_emb)).unsqueeze(-1).unsqueeze(-1)
        x = x * age_s3
        
        x = self.up4(x)
        x = self.up5(x)
        
        x = self.final(x)
        return self.tanh(x)

In [None]:
# =========================================================
# üèóÔ∏è Discriminator Architecture
# =========================================================
class AgeAwareDiscriminator(nn.Module):
    """Score a sketch for realism and for agreement with an age embedding.

    `forward` returns ``(validity, age_match)``:
      * `validity` is a raw logit (`fc_validity` has no activation);
      * `age_match` has already passed through a Sigmoid, so it is a
        probability in (0, 1).
    Score `age_match` with a probability-space loss (BCELoss /
    F.binary_cross_entropy). BCEWithLogitsLoss would apply a second sigmoid.
    """
    def __init__(self, age_dim=64):
        super().__init__()
        
        # Convolutional layers
        # Four stride-2 blocks halve H and W each time
        # (256 -> 128 -> 64 -> 32 -> 16) while widening to 512 channels.
        self.conv1 = self._conv_block(3, 64, normalize=False)
        self.conv2 = self._conv_block(64, 128)
        self.conv3 = self._conv_block(128, 256)
        self.conv4 = self._conv_block(256, 512)
        
        # Age processing
        self.age_fc = nn.Sequential(
            nn.Linear(age_dim, 256),
            nn.ReLU()
        )
        
        # Final layers
        # Input width 512*16*16 + 256 hard-codes a 16x16 final feature map,
        # so only 256x256 sketches are accepted.
        self.fc_validity = nn.Linear(512 * 16 * 16 + 256, 1)
        self.fc_age_match = nn.Sequential(
            nn.Linear(512 * 16 * 16 + 256, 1),
            nn.Sigmoid()
        )
    
    def _conv_block(self, in_ch, out_ch, normalize=True):
        """Conv2d(k=4, s=2, p=1) [-> BatchNorm2d] -> LeakyReLU(0.2); halves H and W."""
        layers = [nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False)]
        if normalize:
            layers.append(nn.BatchNorm2d(out_ch))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        return nn.Sequential(*layers)
    
    def forward(self, sketch, age_emb):
        """Return (validity logit, age-match probability), each of shape (batch, 1).

        Args:
            sketch: (batch, 3, 256, 256) image tensor.
            age_emb: (batch, age_dim) age embedding.
        """
        x = self.conv1(sketch)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        
        x = x.view(x.size(0), -1)
        age_features = self.age_fc(age_emb)
        
        combined = torch.cat([x, age_features], dim=1)
        
        validity = self.fc_validity(combined)
        age_match = self.fc_age_match(combined)
        
        return validity, age_match

In [None]:
# =========================================================
# üß† Perceptual Loss (VGG)
# =========================================================
class PerceptualLoss(nn.Module):
    """Mean absolute difference between VGG16 feature maps (first 16 layers of `features`).

    Inputs are expected in [-1, 1], the generator's tanh range. They are
    mapped to [0, 1] and then standardized with the ImageNet mean/std the
    pretrained VGG16 weights expect. If the pretrained weights cannot be
    loaded, the loss is disabled and always returns 0.
    """

    def __init__(self):
        super().__init__()
        # Non-persistent buffers follow .to(device) but stay out of state_dict.
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), persistent=False)
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), persistent=False)
        try:
            vgg = self._load_vgg16_features()[:16]
            self.vgg = vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False
            self.enabled = True
        except Exception as e:
            # Best effort (e.g. no network for the weights): disable instead of
            # crashing, but report why.
            self.enabled = False
            print("‚ö†Ô∏è VGG model not available, perceptual loss disabled")
            print(f"   reason: {e}")

    @staticmethod
    def _load_vgg16_features():
        """Load pretrained VGG16 conv features, preferring the torchvision >= 0.13 weights API."""
        weights_enum = getattr(models, "VGG16_Weights", None)
        if weights_enum is not None:
            return models.vgg16(weights=weights_enum.DEFAULT).features
        return models.vgg16(pretrained=True).features

    def forward(self, pred, target):
        if not self.enabled:
            return torch.tensor(0.0, device=pred.device)

        pred = ((pred + 1) / 2 - self.mean) / self.std
        target = ((target + 1) / 2 - self.mean) / self.std
        pred_features = self.vgg(pred)
        target_features = self.vgg(target)
        return torch.mean(torch.abs(pred_features - target_features))

In [None]:
# =========================================================
# üéì Initialize Models
# =========================================================
print("\nüèóÔ∏è Building models...")
generator = AgeConditionedGenerator(TEXT_EMB_DIM, AGE_EMB_DIM).to(device)
discriminator = AgeAwareDiscriminator(AGE_EMB_DIM).to(device)
# If its pretrained VGG cannot be loaded, PerceptualLoss disables itself and returns 0.
perceptual_loss_fn = PerceptualLoss().to(device)

print(f"‚úÖ Generator params: {sum(p.numel() for p in generator.parameters()):,}")
print(f"‚úÖ Discriminator params: {sum(p.numel() for p in discriminator.parameters()):,}")

# Optimizers
# Adam with beta1 = 0.5 (lower first-moment momentum than the 0.9 default).
gen_opt = optim.Adam(generator.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
disc_opt = optim.Adam(discriminator.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))

# Loss functions
# bce_loss takes raw logits (it applies the sigmoid itself). That is right for
# the discriminator's validity output, but NOT for its Sigmoid-activated
# age_match output, which needs a probability-space BCE.
bce_loss = nn.BCEWithLogitsLoss()
l1_loss = nn.L1Loss()

In [None]:
# =========================================================
# ‚öñÔ∏è Loss Functions (FIXED for MPS)
# =========================================================
def compute_clip_loss(generated_sketches, text_descriptions):
    """Compute CLIP similarity loss (with fallback) - MPS COMPATIBLE.

    Returns the batch mean of ``max(0, 1 - cos_sim)``, where cos_sim is the
    clip_similarity of each sketch with its description. cos_sim is clipped
    to [-1, 1], so the max(0, ...) never changes the value.

    Args:
        generated_sketches: generator output. Each item is transposed
            (1, 2, 0), so it is presumably (batch, 3, H, W), assumed in
            [-1, 1] (TODO confirm); it is mapped to uint8 via (x + 1) * 127.5.
        text_descriptions: sequence of strings, indexed in parallel with the batch.

    Returns:
        0-d float32 tensor on `device`; 0.0 when CLIP is unavailable or on error.

    NOTE(review): the sketches are detached, scored through NumPy, and the
    result is built with requires_grad=False. This term therefore contributes
    no gradient to the generator; it only changes the reported loss value.
    It also runs CLIP one sample at a time. A differentiable version would feed
    the tensors (resized and CLIP-normalized on-device) straight into
    clip_model.encode_image.
    """
    if clip_model is None:
        return torch.tensor(0.0, device=device, dtype=torch.float32)
    
    try:
        similarities = []
        # Convert to numpy on CPU to avoid MPS dtype issues
        gen_np = generated_sketches.detach().cpu().float().numpy()  # Ensure float32
        gen_np = ((gen_np + 1) * 127.5).astype(np.uint8)
        
        for i, text in enumerate(text_descriptions):
            # CHW -> HWC, the layout Image.fromarray expects inside clip_image_embedding.
            sketch = gen_np[i].transpose(1, 2, 0)
            sim = clip_similarity(text, sketch)
            # Ensure similarity is in valid range, then compute loss
            sim = np.clip(sim, -1.0, 1.0)
            loss = max(0.0, 1.0 - sim)
            similarities.append(loss)
        
        # Create tensor with explicit float32 dtype
        return torch.tensor(
            np.mean(similarities), 
            device=device, 
            dtype=torch.float32,  # Explicit float32
            requires_grad=False
        )
    except Exception as e:
        print(f"‚ö†Ô∏è CLIP loss error: {e}")
        return torch.tensor(0.0, device=device, dtype=torch.float32)

def generator_loss(disc_validity, disc_age_match, gen_sketch, real_sketch, texts):
    """Combined generator loss - MPS COMPATIBLE.

    Args:
        disc_validity: discriminator validity logits for the generated sketches.
        disc_age_match: discriminator age-match probabilities for the generated
            sketches (AgeAwareDiscriminator ends that head with a Sigmoid).
        gen_sketch: generator output, compared with `real_sketch` by L1 and VGG.
        real_sketch: ground-truth sketches.
        texts: one description per sketch, used by the CLIP term.

    Returns:
        (total, adv_loss, l1, perc, clip_loss, age_loss), all float32 tensors, where
        total = adv + LAMBDA_L1 * l1 + 10 * perc + 5 * clip + 3 * age.
    """
    # Adversarial loss: validity is a logit, so BCEWithLogitsLoss fits.
    adv_loss = bce_loss(disc_validity, torch.ones_like(disc_validity))

    # Age matching loss: disc_age_match is already a probability.
    # BCEWithLogitsLoss would apply a second sigmoid, squeezing it into
    # (0.5, 0.73), so the loss could never drop below ~0.31 and its gradient shrinks.
    age_loss = nn.functional.binary_cross_entropy(
        disc_age_match, torch.ones_like(disc_age_match)
    )

    # L1 reconstruction loss
    l1 = l1_loss(gen_sketch, real_sketch)

    # Perceptual loss
    perc = perceptual_loss_fn(gen_sketch, real_sketch)

    # CLIP alignment loss. As compute_clip_loss is written, it builds its value
    # from detached NumPy data (requires_grad=False), so this term shifts the
    # total but adds no gradient.
    clip_loss = compute_clip_loss(gen_sketch, texts)

    # Ensure all losses are float32
    adv_loss = adv_loss.float()
    age_loss = age_loss.float()
    l1 = l1.float()
    perc = perc.float()
    clip_loss = clip_loss.float()

    # Combined loss with weights
    total = (
        1.0 * adv_loss +        # Adversarial
        LAMBDA_L1 * l1 +        # L1 reconstruction (weight: 100)
        10.0 * perc +           # Perceptual
        5.0 * clip_loss +       # CLIP alignment
        3.0 * age_loss          # Age consistency
    )

    return total, adv_loss, l1, perc, clip_loss, age_loss

def discriminator_loss(disc_real_val, disc_real_age, disc_fake_val, disc_fake_age):
    """Discriminator loss - MPS COMPATIBLE.

    validity head: real sketches -> 1, generated sketches -> 0 (averaged).
    age head:      real sketch with its age -> 1, generated sketch -> 0 (averaged).

    Generated samples are used as negatives for the age head on purpose: a head
    that only ever sees target 1 can minimise its loss by answering "match" for
    every input, and the generator's age term (which asks for 1 on fakes) then
    receives no useful signal.

    Args:
        disc_real_val, disc_real_age: discriminator outputs on real sketches.
        disc_fake_val, disc_fake_age: discriminator outputs on generated sketches
            (the training loop passes them computed from detached fakes).

    Returns:
        Scalar float32 tensor ``validity_loss + age_loss``.
    """
    # Validity: real -> 1, fake -> 0.
    real_val_loss = bce_loss(disc_real_val, torch.ones_like(disc_real_val))
    fake_val_loss = bce_loss(disc_fake_val, torch.zeros_like(disc_fake_val))
    validity_loss = 0.5 * (real_val_loss + fake_val_loss)

    # Age matching: real sketch with its age -> 1, generated sketch -> 0.
    real_age_loss = bce_loss(disc_real_age, torch.ones_like(disc_real_age))
    fake_age_loss = bce_loss(disc_fake_age, torch.zeros_like(disc_fake_age))
    age_loss = 0.5 * (real_age_loss + fake_age_loss)

    return (validity_loss + age_loss).float()

In [None]:
# =========================================================
# üì¶ Create DataLoader
# =========================================================
import sys

import torch

# Two worker processes are kept for Linux hosts (e.g. Kaggle). Elsewhere (macOS,
# Windows) DataLoader workers are started with 'spawn' and typically cannot unpickle
# a Dataset class defined inside the notebook, so batches are loaded in the main
# process there. Pinned host memory only speeds up CUDA host-to-device copies.
dataset = SketchDataset(train_A, train_B, desc, estimated_ages)
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2 if sys.platform.startswith("linux") else 0,
    pin_memory=torch.cuda.is_available(),
)

print(f"\n‚úÖ Dataset size: {len(dataset)}")
print(f"‚úÖ Batches per epoch: {len(dataloader)}")

In [None]:
import time
import numpy as np
import torch
from tqdm import tqdm

# =========================================================
# üì¶ Create DataLoader
# =========================================================
dataset = SketchDataset(train_A, train_B, desc, estimated_ages)

# macOS notebook setup: load batches in the main process (no worker processes to
# re-import notebook-defined classes) and skip pinned host memory, which only
# helps CUDA transfers.
dataloader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=BATCH_SIZE,
    pin_memory=False,
    num_workers=0,
)

print(f"\n‚úÖ Dataset size: {len(dataset)}",
      f"‚úÖ Batches per epoch: {len(dataloader)}",
      sep="\n")

# =========================================================
# üöÇ Training Loop
# =========================================================
# No `if __name__ == '__main__':` guard: inside a notebook kernel __name__ is always
# '__main__', so such a guard is a no-op, and the training code below was never
# inside it anyway.
print("\n" + "="*60)
print("üöÄ STARTING TRAINING")
print("="*60)

import os
import time

# torch.save() does not create missing folders, so make sure the periodic
# checkpoints (./checkpoints/checkpoint_epoch_N.pth) have somewhere to go.
os.makedirs("./checkpoints", exist_ok=True)

# Wall-clock reference for per-epoch elapsed/ETA reporting and the final total.
start_time = time.time()

# Training history for plotting: one list per tracked loss, one entry per epoch.
history = {
    'g_loss': [],
    'd_loss': [],
    'g_adv_loss': [],
    'g_l1_loss': [],
    'g_perceptual_loss': [],
    'g_clip_loss': [],
    'g_age_loss': []
}

# Main GAN loop: for every batch, one discriminator update followed by one
# generator update. Per-epoch mean losses are appended to `history`, and a full
# checkpoint (both networks, both optimizers, history) is written every 5 epochs.
for epoch in range(EPOCHS):
    # Back to train mode every epoch (relevant if any layer behaves differently in eval).
    generator.train()
    discriminator.train()
    
    # Per-batch loss values for this epoch; averaged below for logging/history.
    g_losses = []
    d_losses = []
    epoch_metrics = {
        'adv': [], 'l1': [], 'perc': [], 'clip': [], 'age': []
    }
    
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    
    # Each batch unpacks as (photos, sketches, texts, ages, filenames).
    # batch_idx, batch_size and filenames are unused; photos are moved to the
    # device but never fed to either network (the generator only sees text/age).
    for batch_idx, (photos, sketches, texts, ages, filenames) in enumerate(progress_bar):
        batch_size = photos.size(0)
        photos = photos.to(device)
        sketches = sketches.to(device)
        
        # Get text embeddings
        # One CLIP text embedding per description; if any of them fails, the
        # whole batch is skipped (no D or G update for it).
        try:
            text_embs = torch.stack([
                torch.FloatTensor(clip_text_embedding(text)) 
                for text in texts
            ]).to(device)
        except Exception as e:
            print(f"\n‚ö†Ô∏è Text embedding error: {e}")
            continue
        
        # Get age embeddings
        # age.item() assumes `ages` is a tensor batch from default collation — TODO confirm.
        age_embs = torch.stack([encode_age(age.item()) for age in ages]).to(device)
        
        # =====================================
        # Train Discriminator
        # =====================================
        disc_opt.zero_grad()
        
        # Fakes for the D step are produced without an autograd graph (no_grad,
        # then detach), so this update cannot push gradients into the generator.
        with torch.no_grad():
            gen_sketches = generator(text_embs, age_embs)
        
        disc_real_val, disc_real_age = discriminator(sketches, age_embs)
        disc_fake_val, disc_fake_age = discriminator(gen_sketches.detach(), age_embs)
        
        d_loss = discriminator_loss(disc_real_val, disc_real_age, disc_fake_val, disc_fake_age)
        d_loss.backward()
        disc_opt.step()
        
        # =====================================
        # Train Generator
        # =====================================
        gen_opt.zero_grad()
        
        # Fresh forward pass WITH a graph so gradients flow from the discriminator's
        # verdict back into the generator. backward() below also fills the
        # discriminator's .grad buffers; disc_opt.zero_grad() at the top of the next
        # D step clears them (assuming disc_opt owns the discriminator's parameters).
        gen_sketches = generator(text_embs, age_embs)
        disc_fake_val, disc_fake_age = discriminator(gen_sketches, age_embs)
        
        g_loss, adv, l1, perc, clip_l, age_l = generator_loss(
            disc_fake_val, disc_fake_age, gen_sketches, sketches, texts
        )
        g_loss.backward()
        gen_opt.step()
        
        # Record losses
        # .item() converts each scalar tensor to a Python float, so no autograd
        # graph is kept alive across batches.
        g_losses.append(g_loss.item())
        d_losses.append(d_loss.item())
        epoch_metrics['adv'].append(adv.item())
        epoch_metrics['l1'].append(l1.item())
        epoch_metrics['perc'].append(perc.item())
        epoch_metrics['clip'].append(clip_l.item())
        epoch_metrics['age'].append(age_l.item())
        
        # Update progress bar
        # Shows running means over the epoch so far.
        progress_bar.set_postfix({
            'G': f'{np.mean(g_losses):.3f}',
            'D': f'{np.mean(d_losses):.3f}',
            'L1': f'{np.mean(epoch_metrics["l1"]):.3f}'
        })
    
    # Record epoch history
    # If every batch of the epoch was skipped, np.mean([]) yields nan here.
    history['g_loss'].append(np.mean(g_losses))
    history['d_loss'].append(np.mean(d_losses))
    history['g_adv_loss'].append(np.mean(epoch_metrics['adv']))
    history['g_l1_loss'].append(np.mean(epoch_metrics['l1']))
    history['g_perceptual_loss'].append(np.mean(epoch_metrics['perc']))
    history['g_clip_loss'].append(np.mean(epoch_metrics['clip']))
    history['g_age_loss'].append(np.mean(epoch_metrics['age']))
    
    # Save checkpoint every 5 epochs
    # torch.save does not create directories: ./checkpoints must already exist.
    if (epoch + 1) % 5 == 0:
        checkpoint_path = f"./checkpoints/checkpoint_epoch_{epoch+1}.pth"
        torch.save({
            'epoch': epoch,
            'generator': generator.state_dict(),
            'discriminator': discriminator.state_dict(),
            'gen_opt': gen_opt.state_dict(),
            'disc_opt': disc_opt.state_dict(),
            'history': history
        }, checkpoint_path)
        print(f"\nüíæ Checkpoint saved: {checkpoint_path}")
    
    # Wall-clock time since start_time; the ETA linearly extrapolates the mean
    # epoch duration so far over the remaining epochs.
    elapsed = time.time() - start_time
    eta = (elapsed / (epoch + 1)) * (EPOCHS - epoch - 1)
    
    print(f"\n‚úÖ Epoch {epoch+1}/{EPOCHS} completed:")
    print(f"   Generator Loss: {np.mean(g_losses):.4f}")
    print(f"   Discriminator Loss: {np.mean(d_losses):.4f}")
    print(f"   L1 Loss: {np.mean(epoch_metrics['l1']):.4f}")
    print(f"   Perceptual Loss: {np.mean(epoch_metrics['perc']):.4f}")
    print(f"   CLIP Loss: {np.mean(epoch_metrics['clip']):.4f}")
    print(f"   Age Loss: {np.mean(epoch_metrics['age']):.4f}")
    print(f"   ‚è± Time: {elapsed/60:.1f}m | ETA: {eta/60:.1f}m")

# Wall-clock duration of the whole run, measured from start_time.
total_time = time.time() - start_time
print("\n" + "=" * 60,
      f"‚úÖ TRAINING COMPLETE! Total time: {total_time/60:.1f} minutes",
      "=" * 60,
      sep="\n")

# Final artifacts: weights of both networks plus the per-epoch loss history.
os.makedirs("./models", exist_ok=True)
torch.save(generator.state_dict(), "./models/generator_final.pth")
torch.save(discriminator.state_dict(), "./models/discriminator_final.pth")
torch.save(history, "./models/training_history.pth")

print("\nüíæ Final models saved:",
      "   Generator: ./models/generator_final.pth",
      "   Discriminator: ./models/discriminator_final.pth",
      "   History: ./models/training_history.pth",
      sep="\n")

# =========================================================
# üìä Plot Training History
# =========================================================
def _plot_training_history(hist, save_path='./models/training_history.png'):
    """Draw the six per-epoch loss curves in a 2x3 grid, save them and show the figure.

    Args:
        hist: dict of per-epoch loss lists with the keys used by the training
            ``history`` dict ('g_loss', 'd_loss', 'g_adv_loss', 'g_l1_loss',
            'g_perceptual_loss', 'g_clip_loss', 'g_age_loss').
        save_path: PNG destination, written at 300 dpi; its folder is created if missing.

    Returns:
        The matplotlib Figure.
    """
    import os

    # (row, col), panel title, [(history key, line color, legend label or None), ...]
    panels = [
        ((0, 0), 'Generator vs Discriminator Loss',
         [('g_loss', 'blue', 'Generator'), ('d_loss', 'red', 'Discriminator')]),
        ((0, 1), 'Adversarial Loss', [('g_adv_loss', 'purple', None)]),
        ((0, 2), 'L1 Reconstruction Loss', [('g_l1_loss', 'green', None)]),
        ((1, 0), 'Perceptual Loss', [('g_perceptual_loss', 'orange', None)]),
        ((1, 1), 'CLIP Alignment Loss', [('g_clip_loss', 'cyan', None)]),
        ((1, 2), 'Age Consistency Loss', [('g_age_loss', 'magenta', None)]),
    ]

    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    fig.suptitle('Training Progress', fontsize=16, fontweight='bold')

    for (row, col), title, curves in panels:
        ax = axes[row, col]
        has_labels = False
        for key, color, label in curves:
            line_style = {'color': color, 'linewidth': 2}
            if label is not None:
                line_style['label'] = label
                has_labels = True
            ax.plot(hist[key], **line_style)
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        if has_labels:  # only the combined G-vs-D panel needs a legend
            ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
    return fig


print("\nüìä Plotting training history...")
_plot_training_history(history)
print("‚úÖ Training history plot saved: ./models/training_history.png")

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

# =========================================================
# üé® FIXED GENERATION FUNCTIONS
# =========================================================

def create_truly_unique_latent(description, seed=None, text_weight=0.1, dim=None):
    """
    Build a unit-norm latent vector for the generator from a text description.

    Most of the vector is seeded Gaussian noise; that noise is what makes
    different descriptions produce visibly different faces. A small share comes
    from the CLIP text embedding of ``description`` when that embedding is available.

    Args:
        description: Free-text description of the face.
        seed: Integer seed in [0, 2**32) for the noise. When None, a seed is derived
            from ``description`` with CRC32, which (unlike the salted built-in
            ``hash()``) is stable across kernel restarts.
        text_weight: Share of the unit-normalized text embedding mixed into the
            unit-normalized noise. The default 0.1 gives "90% noise + 10% text".
        dim: Latent size; defaults to the notebook-level ``TEXT_EMB_DIM``.

    Returns:
        ``(latent, seed)``: a float32 vector of shape ``(dim,)`` with L2 norm 1, and
        the seed that was used.
    """
    import zlib

    if dim is None:
        dim = TEXT_EMB_DIM
    if seed is None:
        # crc32 is deterministic across processes and already lies in [0, 2**32).
        seed = zlib.crc32(description.encode("utf-8"))

    # A private RandomState draws exactly what np.random.seed(seed); np.random.randn()
    # would, without clobbering the global NumPy RNG for the rest of the notebook.
    noise = np.random.RandomState(seed).randn(dim).astype(np.float32)
    noise_norm = np.linalg.norm(noise)
    latent = noise / noise_norm if noise_norm > 0 else noise

    # Optional: mix in semantic features from the text. Both parts are unit-normalized
    # first, so text_weight really is the text share. Raw randn has norm ~sqrt(dim)
    # and would otherwise drown the text term.
    try:
        text_features = np.asarray(clip_text_embedding(description), dtype=np.float32).reshape(-1)
        text_norm = np.linalg.norm(text_features)
        if text_features.shape == latent.shape and text_norm > 0:
            latent = (1.0 - text_weight) * latent + text_weight * (text_features / text_norm)
    except Exception:
        pass  # deliberate best-effort: fall back to pure noise if CLIP is unavailable/fails

    norm = np.linalg.norm(latent)
    if norm > 0:
        latent = latent / norm

    return latent.astype(np.float32), seed


def generate_sketch_at_age_fixed(description, target_age, variation_seed=None):
    """
    Generate one sketch for ``description`` conditioned on ``target_age``.

    The identity comes from ``create_truly_unique_latent``. The same description
    and seed give the same latent at every age; different descriptions give
    different latents.

    Args:
        description: Free-text description of the face.
        target_age: Age passed to ``encode_age``.
        variation_seed: Optional explicit seed, forwarded to the latent builder.

    Returns:
        uint8 image array ``(H, W, C)`` in [0, 255]. Assumes the generator emits
        ``(N, C, H, W)`` values in [-1, 1] — TODO confirm against the generator.
    """
    latent_vector, used_seed = create_truly_unique_latent(description, variation_seed)

    print(f"üîç Generating: '{description[:50]}...' at age {target_age} (seed: {used_seed})")

    # Run inference in eval mode. In train mode (where the training loop leaves the
    # model), any Dropout/BatchNorm layers would make same-seed outputs
    # non-deterministic and BatchNorm would overwrite its running statistics.
    # The caller's mode is restored afterwards so training can resume unchanged.
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            text_emb = torch.FloatTensor(latent_vector).unsqueeze(0).to(device)
            age_emb = encode_age(target_age).unsqueeze(0).to(device)
            generated = generator(text_emb, age_emb)
            # .float() guards against half-precision outputs before the NumPy conversion.
            sketch = generated[0].cpu().float().numpy().transpose(1, 2, 0)
    finally:
        generator.train(was_training)

    return ((sketch + 1.0) * 127.5).clip(0, 255).astype(np.uint8)


def generate_age_progression_fixed(description, start_age=25, end_age=65, steps=5):
    """
    Generate, display and save an age progression that keeps one identity.

    Every frame uses the same seed (derived from ``description``), so only the age
    conditioning changes between frames.

    Args:
        description: Free-text description of the face.
        start_age, end_age: First and last age of the progression.
        steps: Number of evenly spaced ages (rounded to whole years); must be >= 1.

    Returns:
        List of ``(age, sketch)`` tuples in increasing-step order.

    Raises:
        ValueError: if ``steps < 1``.
    """
    import os
    import zlib

    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")

    print(f"\n‚è∞ Age progression: '{description}' ({start_age} ‚Üí {end_age})")

    # Round rather than truncate: linspace(..., dtype=int) floors, which can turn
    # e.g. 34.9999 into 34 and skews intermediate ages downwards.
    ages = np.rint(np.linspace(start_age, end_age, steps)).astype(int)
    sketches = []

    # CRITICAL: Use SAME seed for all ages to maintain identity.
    # crc32 is stable across kernel restarts, unlike the salted built-in hash().
    base_seed = zlib.crc32(description.encode("utf-8"))
    print(f"   Using seed: {base_seed}")

    for age in ages:
        sketch = generate_sketch_at_age_fixed(description, int(age), variation_seed=base_seed)
        sketches.append((int(age), sketch))

    # Display
    fig, axes = plt.subplots(1, len(sketches), figsize=(4*len(sketches), 4))
    if len(sketches) == 1:
        axes = [axes]  # a single subplot comes back as a bare Axes, not an array

    for ax, (age, sketch) in zip(axes, sketches):
        ax.imshow(sketch)
        ax.set_title(f"Age {age}", fontsize=12, fontweight='bold')
        ax.axis('off')

    plt.suptitle(f"Age Progression: {description}", fontsize=14, y=1.02)
    plt.tight_layout()

    # Save (savefig does not create missing folders)
    os.makedirs("./age_progression_results", exist_ok=True)
    safe_desc = "".join(c for c in description[:30] if c.isalnum() or c in (' ', '_')).strip()
    filename = f"./age_progression_results/fixed_progression_{safe_desc}_{start_age}_{end_age}.png"
    plt.savefig(filename, dpi=150, bbox_inches='tight')
    plt.show()
    plt.close(fig)  # this is called once per description in a loop; free each figure

    print(f"‚úÖ Saved: {filename}")

    return sketches


# =========================================================
# üß™ TEST SUITE - Verify Different Faces
# =========================================================

def test_different_descriptions_create_different_faces():
    """
    Visual check: four different descriptions at the same age should give four
    clearly different faces.

    Shows the faces side by side, saves them to
    ./age_progression_results/test_different_faces.png and prints the smallest
    pairwise mean pixel difference as a numeric companion (0 means two faces are
    identical).
    """
    import os

    print("\n" + "="*60)
    print("üî¨ TEST: Different Descriptions ‚Üí Different Faces")
    print("="*60)

    descriptions = [
        "woman with long blonde hair",
        "man with short dark hair and beard",
        "elderly person with glasses",
        "young person with curly hair"
    ]

    age = 35
    fig, axes = plt.subplots(1, len(descriptions), figsize=(5*len(descriptions), 5))

    sketches = []
    for ax, description in zip(axes, descriptions):
        sketch = generate_sketch_at_age_fixed(description, age)
        sketches.append(sketch)
        ax.imshow(sketch)
        ax.set_title(f"{description}\nAge {age}", fontsize=10)
        ax.axis('off')

    plt.suptitle("‚úÖ These Should Be 4 DIFFERENT Faces!",
                 fontsize=14, fontweight='bold', color='green')
    plt.tight_layout()
    os.makedirs("./age_progression_results", exist_ok=True)  # savefig won't create it
    plt.savefig("./age_progression_results/test_different_faces.png", dpi=150)
    plt.show()
    plt.close(fig)

    # int16 avoids uint8 wrap-around when subtracting.
    min_pair_diff = min(
        float(np.abs(first.astype(np.int16) - second.astype(np.int16)).mean())
        for i, first in enumerate(sketches)
        for second in sketches[i + 1:]
    )
    print(f"   Smallest pairwise mean pixel difference: {min_pair_diff:.2f} (0 = identical faces)")

    print("‚úÖ Test complete! Check that all 4 faces look different.")


def test_same_description_maintains_identity():
    """
    Visual check: one description rendered at five ages with one fixed seed should
    look like the same person getting older.

    Shows the row of faces and saves it to
    ./age_progression_results/test_same_identity.png.
    """
    import os
    import zlib

    print("\n" + "="*60)
    print("üî¨ TEST: Same Description ‚Üí Same Identity Across Ages")
    print("="*60)

    description = "woman with long hair"
    ages = [25, 35, 45, 55, 65]

    fig, axes = plt.subplots(1, len(ages), figsize=(4*len(ages), 4))

    # Use consistent seed: only the age conditioning differs between panels.
    # crc32 is stable across kernel restarts, unlike the salted built-in hash().
    seed = zlib.crc32(description.encode("utf-8"))

    for ax, age in zip(axes, ages):
        sketch = generate_sketch_at_age_fixed(description, age, variation_seed=seed)
        ax.imshow(sketch)
        ax.set_title(f"Age {age}", fontsize=12, fontweight='bold')
        ax.axis('off')

    plt.suptitle("‚úÖ These Should Be the SAME Person Aging!",
                 fontsize=14, fontweight='bold', color='green')
    plt.tight_layout()
    os.makedirs("./age_progression_results", exist_ok=True)  # savefig won't create it
    plt.savefig("./age_progression_results/test_same_identity.png", dpi=150)
    plt.show()
    plt.close(fig)

    print("‚úÖ Test complete! Check that all faces look like the same person.")




In [None]:
# =========================================================
# üöÄ RUN TESTS
# =========================================================
import traceback

print("\n" + "="*60)
print("üé® RUNNING FIXED TESTS")
print("="*60)

# Names of checks that raised, reported in the closing summary.
failed_checks = []

# Tests 1 and 2: different descriptions -> different faces; same description -> same identity.
for visual_check in (test_different_descriptions_create_different_faces,
                     test_same_description_maintains_identity):
    try:
        visual_check()
    except Exception as e:
        print(f"‚ùå Test failed: {e}")
        traceback.print_exc()
        failed_checks.append(visual_check.__name__)

# Test 3: Multiple age progressions
print("\n" + "="*60)
print("üî¨ TEST: Multiple Different Age Progressions")
print("="*60)

descriptions = [
    "a young woman with short hair",
    "a man with a mustache",
    "a person with glasses"
]

# The loop variable is deliberately NOT named `desc`. That notebook global holds the
# dataset descriptions passed to SketchDataset, and re-running the DataLoader cell
# after this one would otherwise hand it a single string.
for progression_desc in descriptions:
    try:
        generate_age_progression_fixed(progression_desc, 25, 65, 5)
    except Exception as e:
        print(f"‚ùå Failed for '{progression_desc}': {e}")
        failed_checks.append(f"age progression: {progression_desc}")

print("\n" + "="*60)
print("‚úÖ ALL TESTS COMPLETE!")
print("="*60)
if failed_checks:
    print(f"WARNING: {len(failed_checks)} check(s) raised an error: " + ", ".join(failed_checks))
print("\nCheck ./age_progression_results/ for results:")
print("  - test_different_faces.png (should show 4 DIFFERENT people)")
print("  - test_same_identity.png (should show 1 person aging)")
print("  - fixed_progression_*.png (age progression sequences)")