# Deep Learning Focus: Contrastive Triplet Learning

## 🎯 Learning Objectives
- Master contrastive learning principles and triplet loss formulation
- Understand hard negative mining strategies for Arabic text
- Implement advanced negative sampling techniques
- Explore InfoNCE limitations and solutions
- Develop effective triplet dataset construction methods

## 📚 Paper Context
**From GATE Paper (Section 1):**
> *"At the heart of many highly effective embedding models lies contrastive learning, a paradigm that optimizes the quality of representation by pulling semantically similar (positive) samples closer while pushing dissimilar (negative) samples apart. Despite the versatility and success of contrastive learning, most existing text embedding pipelines rely on InfoNCE loss with in-batch negative samples, achieving robust representations predominantly by using large batch sizes and numerous negative samples."*
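The in-batch setup described in the quote fits in a few lines. For a batch of B (anchor, positive) pairs, each anchor treats the other B-1 positives as its negatives, so the number of negatives grows with batch size at no extra encoding cost. Below is a minimal PyTorch sketch. The helper name `in_batch_infonce` and the temperature of 0.05 are illustrative assumptions, not values from the GATE paper.

```python
import torch
import torch.nn.functional as F


def in_batch_infonce(anchor: torch.Tensor, positive: torch.Tensor, temperature: float = 0.05):
    """InfoNCE with in-batch negatives (hypothetical helper for illustration).

    anchor, positive: [B, dim]; row i of `positive` is the positive for row i of `anchor`.
    Returns the mean loss and the number of negatives each anchor sees.
    """
    a = F.normalize(anchor, dim=1)
    p = F.normalize(positive, dim=1)
    # [B, B] cosine-similarity matrix; the diagonal holds the matching pairs
    logits = a @ p.T / temperature
    targets = torch.arange(anchor.size(0), device=anchor.device)
    loss = F.cross_entropy(logits, targets)
    num_negatives = anchor.size(0) - 1  # every other positive in the batch
    return loss, num_negatives


torch.manual_seed(0)
for batch_size in (8, 64, 512):
    anchors = torch.randn(batch_size, 128)
    positives = anchors + 0.5 * torch.randn(batch_size, 128)  # noisy copies of the anchors
    loss, k = in_batch_infonce(anchors, positives)
    print(f"B={batch_size:4d}  negatives per anchor={k:4d}  loss={loss.item():.3f}")
```

With synthetic inputs the loss values mean little on their own. The point is that `k` grows linearly with the batch, which is why InfoNCE pipelines favor large batches.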

**Key Challenges Identified:**
1. **InfoNCE Limitations**: InfoNCE with in-batch negatives is not, on its own, sufficient for every downstream task
2. **Sentence-level Tasks**: Semantic textual similarity (STS) benefits less from InfoNCE-based training
3. **Fine-grained Similarity**: InfoNCE struggles with subtle semantic differences
4. **Negative Sampling**: Quality of negatives crucial for learning
5. **Arabic Specificity**: Need for Arabic-tailored contrastive learning
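The following illustration offers one hedged way to see why STS can benefit less; it is not an argument made in the quoted passage. InfoNCE depends only on the *differences* between an anchor's positive and negative logits. Adding a constant to all of one anchor's similarities therefore leaves the loss unchanged. STS benchmarks, however, correlate a model's similarity scores across many different sentence pairs with human ratings, so scores for different anchors need to be comparable on one scale. The helper `infonce_from_sims` below is hypothetical.

```python
import torch
import torch.nn.functional as F


def infonce_from_sims(sims: torch.Tensor, temperature: float = 0.05) -> torch.Tensor:
    """sims: [num_anchors, 1 + num_negatives] cosine similarities, positive in column 0."""
    targets = torch.zeros(sims.size(0), dtype=torch.long)
    return F.cross_entropy(sims / temperature, targets)


# Two anchors whose positive beats the negatives by the same gaps
sims = torch.tensor([[0.9, 0.6, 0.4],
                     [0.9, 0.6, 0.4]])
# Shift each anchor's similarities by a different constant (values stay within [-1, 1])
shifted = sims + torch.tensor([[0.0], [-0.5]])

print(infonce_from_sims(sims).item(), infonce_from_sims(shifted).item())
# The two losses match, yet anchor 2's positive now scores 0.4,
# which is lower than anchor 1's negative (0.6).
```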

## 🔑 GATE's Innovation
GATE uses **Arabic NLI triplet datasets** with curated hard negatives, moving beyond simple InfoNCE to capture fine-grained semantic relationships crucial for Arabic text understanding.
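A common way to turn an NLI dataset into triplets, used for example in supervised SimCSE, works as follows:
- The premise serves as the anchor.
- An entailed hypothesis serves as the positive.
- A contradicting hypothesis for the same premise serves as the hard negative.

The sketch below assumes rows with `premise`, `hypothesis` and `label` fields. The field names, label strings and helper name are hypothetical, and this is not a claim about the exact construction used for GATE.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def nli_to_triplets(rows: List[Dict[str, str]]) -> List[Tuple[str, str, str]]:
    """Build (anchor, positive, hard_negative) triplets from NLI rows.

    Assumes labels in {'entailment', 'neutral', 'contradiction'} (hypothetical schema).
    """
    by_premise: Dict[str, Dict[str, List[str]]] = defaultdict(lambda: defaultdict(list))
    for row in rows:
        by_premise[row["premise"]][row["label"]].append(row["hypothesis"])

    triplets = []
    for premise, groups in by_premise.items():
        # Pair every entailed hypothesis with every contradiction of the same premise;
        # neutral hypotheses are not used
        for pos in groups.get("entailment", []):
            for neg in groups.get("contradiction", []):
                triplets.append((premise, pos, neg))
    return triplets


rows = [
    # "The student studies in the library" / "The student reads books"
    {"premise": "الطالب يدرس في المكتبة", "hypothesis": "الطالب يقرأ الكتب", "label": "entailment"},
    # "The student is asleep at home"
    {"premise": "الطالب يدرس في المكتبة", "hypothesis": "الطالب نائم في البيت", "label": "contradiction"},
    # "The student likes mathematics"
    {"premise": "الطالب يدرس في المكتبة", "hypothesis": "الطالب يحب الرياضيات", "label": "neutral"},
]
print(nli_to_triplets(rows))
```

Contradictions make useful hard negatives because they typically share topic and vocabulary with the premise while differing in meaning.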

## Environment Setup for Contrastive Learning

In [None]:
# Core libraries for contrastive learning
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Dict, Tuple, Optional, Union
import warnings
warnings.filterwarnings('ignore')

# Advanced sampling and mining
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.cluster import KMeans
from collections import defaultdict, Counter
import random
import itertools

# Data handling
import pandas as pd
from dataclasses import dataclass

# Set random seeds
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

print("🔥 Contrastive Learning Environment Ready!")
print(f"📱 Device: {'CUDA' if torch.cuda.is_available() else 'CPU'}")
print(f"🎯 Focus: Advanced contrastive and triplet learning")

## 🧮 Mathematical Foundation: Contrastive Learning

### Understanding InfoNCE and Triplet Loss Formulations
Let's implement and analyze the mathematical foundations of contrastive learning.
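It helps to see first how the two objectives relate. Let $s_p = \mathrm{sim}(a,p)$ and $s_n = \mathrm{sim}(a,n)$ be the cosine similarities used in the code below. InfoNCE restricted to a single negative then reduces to a softplus of the same difference that the triplet loss hinges on:

$$
\begin{aligned}
\mathcal{L}_{\text{InfoNCE}} &= -\log \frac{e^{s_p/\tau}}{e^{s_p/\tau} + \sum_{k=1}^{K} e^{s_{n_k}/\tau}}, \\
\mathcal{L}_{\text{InfoNCE}}\big|_{K=1} &= \log\!\left(1 + e^{(s_n - s_p)/\tau}\right) = \operatorname{softplus}\!\left(\frac{s_n - s_p}{\tau}\right), \\
\mathcal{L}_{\text{triplet}} &= \max\left(0,\; m + s_n - s_p\right).
\end{aligned}
$$

Both losses penalize $s_n - s_p$. The triplet loss becomes exactly zero once the gap $s_p - s_n$ exceeds the margin $m$. The single-negative InfoNCE term never reaches zero and keeps a small, $\tau$-controlled gradient on every triplet.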

In [None]:
class ContrastiveLearningFoundation:
    """Mathematical foundation and analysis of contrastive learning"""
    
    def __init__(self, temperature=0.07, margin=0.5):
        self.temperature = temperature
        self.margin = margin
        
    def infonce_loss(self, anchor, positive, negatives, return_components=False):
        """
        Compute InfoNCE loss
        
        InfoNCE: L = -log(exp(sim(a,p)/τ) / (exp(sim(a,p)/τ) + Σ exp(sim(a,n)/τ)))
        
        Args:
            anchor: [batch_size, dim] anchor embeddings
            positive: [batch_size, dim] positive embeddings  
            negatives: [batch_size, num_negatives, dim] negative embeddings
        """
        # Positive logits: [batch_size]
        pos_sim = F.cosine_similarity(anchor, positive, dim=1) / self.temperature
        
        # Negative logits: [batch_size, num_negatives] (anchor broadcast over its negatives)
        neg_sims = F.cosine_similarity(anchor.unsqueeze(1), negatives, dim=2) / self.temperature
        
        # InfoNCE is cross-entropy with the positive in column 0;
        # F.cross_entropy uses log-sum-exp, which is more stable than exp/log by hand
        logits = torch.cat([pos_sim.unsqueeze(1), neg_sims], dim=1)
        targets = torch.zeros(anchor.size(0), dtype=torch.long, device=anchor.device)
        losses = F.cross_entropy(logits, targets, reduction='none')
        final_loss = losses.mean()
        
        if return_components:
            components = {
                'positive_similarities': pos_sim.detach(),
                'negative_similarities': neg_sims.detach(),
                'individual_losses': losses.detach()
            }
            return final_loss, components
        
        return final_loss
    
    def triplet_loss(self, anchor, positive, negative, return_components=False):
        """
        Compute triplet loss with margin
        
        Triplet: L = max(0, margin + sim(a,n) - sim(a,p))
        
        Args:
            anchor: [batch_size, dim] anchor embeddings
            positive: [batch_size, dim] positive embeddings
            negative: [batch_size, dim] negative embeddings
        """
        # Compute similarities
        pos_sim = F.cosine_similarity(anchor, positive, dim=1)
        neg_sim = F.cosine_similarity(anchor, negative, dim=1)
        
        # Triplet loss with margin
        loss = torch.clamp(self.margin + neg_sim - pos_sim, min=0.0).mean()
        
        if return_components:
            components = {
                'positive_similarities': pos_sim.detach(),
                'negative_similarities': neg_sim.detach(),
                'margin_violations': (neg_sim - pos_sim + self.margin).detach()
            }
            return loss, components
        
        return loss
    
    def multiple_negatives_ranking_loss(self, anchor, positive, negatives, return_components=False):
        """
        Multiple Negatives Ranking Loss (as used in sentence-transformers)
        
        Each anchor is scored against every positive in the batch and every hard
        negative. Anchor i's own positive (column i) is the target; all other
        positives act as in-batch negatives alongside the hard negatives.
        
        Args:
            anchor: [batch_size, dim] anchor embeddings
            positive: [batch_size, dim] positive embeddings
            negatives: [batch_size, num_negatives, dim] hard negative embeddings
        """
        batch_size, dim = anchor.shape
        
        # Candidate pool: all positives [B, dim] followed by all hard negatives [B * K, dim]
        candidates = torch.cat([positive, negatives.reshape(-1, dim)], dim=0)
        
        # Temperature-scaled cosine similarities: [B, B + B * K]
        logits = F.normalize(anchor, dim=1) @ F.normalize(candidates, dim=1).T / self.temperature
        
        # The matching positive for anchor i sits in column i
        labels = torch.arange(batch_size, device=anchor.device)
        loss = F.cross_entropy(logits, labels)
        
        if return_components:
            components = {
                'logits': logits.detach(),
                'labels': labels,
                'positive_similarities': logits.diagonal().detach()
            }
            return loss, components
        
        return loss
    
    def demonstrate_loss_behaviors(self):
        """Demonstrate different loss function behaviors"""
        print("🧮 Contrastive Loss Function Analysis")
        print("=" * 40)
        
        # Create sample embeddings
        batch_size, dim, num_negatives = 4, 128, 3
        
        anchor = torch.randn(batch_size, dim)
        positive = torch.randn(batch_size, dim)
        negatives = torch.randn(batch_size, num_negatives, dim)
        
        # Make positives more similar to anchors
        positive = 0.7 * anchor + 0.3 * positive
        
        # Make some negatives harder (more similar to anchor)
        negatives[:, 0] = 0.3 * anchor + 0.7 * negatives[:, 0]  # Hard negatives
        
        print(f"📊 Sample Setup:")
        print(f"   Batch size: {batch_size}")
        print(f"   Embedding dimension: {dim}")
        print(f"   Number of negatives: {num_negatives}")
        print(f"   Temperature: {self.temperature}")
        print(f"   Margin: {self.margin}")
        
        # Test InfoNCE
        print(f"\n🔍 InfoNCE Loss Analysis:")
        infonce_loss, infonce_components = self.infonce_loss(
            anchor, positive, negatives, return_components=True
        )
        print(f"   Loss value: {infonce_loss.item():.4f}")
        print(f"   Positive similarities: {infonce_components['positive_similarities'].tolist()}")
        
        # Test Triplet Loss
        print(f"\n🔍 Triplet Loss Analysis:")
        triplet_loss, triplet_components = self.triplet_loss(
            anchor, positive, negatives[:, 0], return_components=True  # Use first negative
        )
        print(f"   Loss value: {triplet_loss.item():.4f}")
        print(f"   Positive similarities: {triplet_components['positive_similarities'].tolist()}")
        print(f"   Negative similarities: {triplet_components['negative_similarities'].tolist()}")
        print(f"   Margin violations: {triplet_components['margin_violations'].tolist()}")
        
        # Test Multiple Negatives Ranking
        print(f"\n🔍 Multiple Negatives Ranking Loss:")
        mnr_loss, mnr_components = self.multiple_negatives_ranking_loss(
            anchor, positive, negatives, return_components=True
        )
        print(f"   Loss value: {mnr_loss.item():.4f}")
        print(f"   Logits shape: {mnr_components['logits'].shape}")
        
        return {
            'infonce': (infonce_loss, infonce_components),
            'triplet': (triplet_loss, triplet_components),
            'mnr': (mnr_loss, mnr_components)
        }
    
    def analyze_temperature_effects(self):
        """Analyze effects of temperature on contrastive learning"""
        print("\n🌡️ Temperature Effects on Contrastive Learning")
        print("=" * 45)
        
        # Sample similarities
        similarities = torch.tensor([0.9, 0.7, 0.5, 0.3, 0.1, -0.1])
        temperatures = [0.01, 0.05, 0.1, 0.5, 1.0]
        
        print(f"📊 Raw similarities: {similarities.tolist()}")
        print(f"\nTemperature effects on softmax distribution:")
        
        for temp in temperatures:
            scaled = similarities / temp
            softmax_probs = F.softmax(scaled, dim=0)
            
            print(f"\n   τ = {temp:4.2f}:")
            print(f"      Scaled: {scaled.tolist()}")
            print(f"      Softmax: {softmax_probs.tolist()}")
            print(f"      Max prob: {torch.max(softmax_probs).item():.3f}")
            print(f"      Entropy: {-torch.sum(softmax_probs * torch.log(softmax_probs + 1e-8)).item():.3f}")
            
            # Concentration analysis
            concentration = (softmax_probs[0] / torch.sum(softmax_probs[1:])).item()
            print(f"      Concentration ratio: {concentration:.3f}")
        
        return temperatures, similarities

# Initialize foundation and run analysis
contrastive_foundation = ContrastiveLearningFoundation(temperature=0.07, margin=0.5)
loss_analysis = contrastive_foundation.demonstrate_loss_behaviors()
temperature_analysis = contrastive_foundation.analyze_temperature_effects()
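
The temperature sweep above shows the softmax sharpening as τ shrinks. That same softmax also weights the gradient. For InfoNCE with logits $z_k = s_k/\tau$, the derivative with respect to each negative similarity is $\partial \mathcal{L}/\partial s_k = \mathrm{softmax}(z)_k / \tau$. A low temperature therefore concentrates the learning signal on the hardest negatives. The sketch below measures this with autograd; the helper name is hypothetical.

```python
import torch
import torch.nn.functional as F


def negative_gradient_share(sims: torch.Tensor, temperature: float) -> torch.Tensor:
    """Fraction of the InfoNCE gradient (w.r.t. negative similarities) each negative gets.

    sims: [1 + num_negatives] cosine similarities for one anchor, positive at index 0.
    """
    sims = sims.clone().requires_grad_(True)
    loss = F.cross_entropy((sims / temperature).unsqueeze(0), torch.tensor([0]))
    loss.backward()
    neg_grads = sims.grad[1:]  # dL/ds_k = softmax_k / temperature for every negative k
    return neg_grads / neg_grads.sum()


# Positive at 0.8; negatives ordered from hardest (0.7) to easiest (-0.2)
sims = torch.tensor([0.8, 0.7, 0.4, 0.1, -0.2])
for tau in (1.0, 0.1, 0.05):
    share = negative_gradient_share(sims, tau)
    print(f"tau={tau:<5} hardest-negative share={share[0].item():.3f}")
```

This is why small temperatures such as 0.05 to 0.07 behave like an implicit form of hard negative mining.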

## 🎯 Hard Negative Mining Strategies

### Advanced Negative Sampling for Arabic Text
The quality of negative samples is crucial for effective contrastive learning. Let's implement sophisticated mining strategies.
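Pure top-k hard negative mining has a well-known pitfall: false negatives. The most similar "negatives" are sometimes paraphrases of the anchor. One common heuristic is to drop candidates whose similarity exceeds a ceiling before taking the top-k. The ceiling used here is a fraction of the anchor-positive similarity. This is an assumption for illustration; the miner below does not implement it. The function name `mine_hard_negatives` and the `max_ratio` parameter are hypothetical.

```python
import torch
import torch.nn.functional as F
from typing import List


def mine_hard_negatives(anchor: torch.Tensor, positive: torch.Tensor,
                        candidates: torch.Tensor, k: int = 3,
                        max_ratio: float = 0.95) -> List[int]:
    """Return indices of the k hardest candidates, skipping likely false negatives.

    Hypothetical rule: a candidate is a likely false negative when its similarity
    to the anchor exceeds max_ratio * sim(anchor, positive).
    """
    pos_sim = F.cosine_similarity(anchor, positive, dim=0)
    cand_sims = F.cosine_similarity(anchor.unsqueeze(0), candidates, dim=1)
    # Mask suspiciously close candidates so topk never selects them
    masked = cand_sims.masked_fill(cand_sims > max_ratio * pos_sim, float("-inf"))
    num_valid = int(torch.isfinite(masked).sum())
    top = torch.topk(masked, min(k, num_valid))
    return top.indices.tolist()


anchor = torch.tensor([1.0, 0.0])
positive = torch.tensor([0.9, 0.1])
candidates = torch.tensor([[1.0, 0.01],   # near-duplicate of the anchor: likely false negative
                           [0.7, 0.7],
                           [0.0, 1.0],
                           [-1.0, 0.0]])
print(mine_hard_negatives(anchor, positive, candidates, k=2))
```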

In [None]:
@dataclass
class ArabicTriplet:
    """Data structure for Arabic text triplets"""
    anchor: str
    positive: str
    negative: str
    anchor_embedding: Optional[torch.Tensor] = None
    positive_embedding: Optional[torch.Tensor] = None
    negative_embedding: Optional[torch.Tensor] = None
    difficulty: float = 0.0  # Mining difficulty score
    category: str = "general"  # Semantic category

class HardNegativeMiner:
    """Advanced hard negative mining for Arabic text"""
    
    def __init__(self, embedding_dim=768, similarity_threshold=0.7):
        self.embedding_dim = embedding_dim
        self.similarity_threshold = similarity_threshold
        
        # Strategies for negative mining
        self.mining_strategies = {
            'random': self.random_negative_mining,
            'hard': self.hard_negative_mining,
            'semi_hard': self.semi_hard_negative_mining,
            'cluster_based': self.cluster_based_mining,
            'semantic': self.semantic_category_mining
        }
    
    def create_mock_arabic_corpus(self):
        """Create mock Arabic text corpus for demonstration"""
        
        corpus = {
            'education': [
                "الطالب يدرس في المكتبة بجد واجتهاد",
                "المعلم يشرح الدرس للطلاب في الفصل",
                "الجامعة تقدم برامج تعليمية متنوعة ومتميزة",
                "البحث العلمي يساهم في تطوير المعرفة",
                "التعليم الإلكتروني أصبح ضرورة في العصر الحديث"
            ],
            'technology': [
                "الذكاء الاصطناعي يغير مستقبل التكنولوجيا",
                "الحاسوب أداة مهمة في العمل والحياة",
                "الإنترنت يربط العالم ببعضه البعض",
                "البرمجيات تسهل الكثير من المهام اليومية",
                "الهواتف الذكية تحتوي على تقنيات متقدمة"
            ],
            'health': [
                "الرياضة تحافظ على صحة الجسم والعقل",
                "التغذية الصحية أساس الحياة السليمة",
                "الطبيب يعالج المرضى بخبرة ومهارة",
                "المستشفى يقدم خدمات طبية متطورة",
                "الوقاية خير من العلاج في جميع الأحوال"
            ],
            'culture': [
                "الثقافة العربية تحتوي على تراث غني ومتنوع",
                "الشعر العربي له تاريخ طويل ومجيد",
                "الفنون التراثية تعبر عن هوية الشعوب",
                "اللغة العربية لغة الضاد والبيان",
                "الأدب العربي يحتوي على كنوز من المعرفة"
            ]
        }
        
        return corpus
    
    def generate_embeddings(self, texts: List[str]) -> torch.Tensor:
        """Generate mock embeddings for texts"""
        embeddings = torch.randn(len(texts), self.embedding_dim)
        
        # Add some semantic structure (texts with similar words get similar embeddings)
        for i, text1 in enumerate(texts):
            words1 = set(text1.split())
            for j, text2 in enumerate(texts):
                if i != j:
                    words2 = set(text2.split())
                    overlap = len(words1.intersection(words2)) / len(words1.union(words2))
                    
                    # Make embeddings more similar based on word overlap
                    if overlap > 0.2:
                        embeddings[j] = 0.3 * embeddings[i] + 0.7 * embeddings[j]
        
        # Normalize embeddings
        embeddings = F.normalize(embeddings, p=2, dim=1)
        
        return embeddings
    
    def random_negative_mining(self, anchor_text: str, positive_text: str, 
                              corpus: Dict[str, List[str]], num_negatives: int = 5) -> List[str]:
        """Random negative sampling (baseline)"""
        all_texts = []
        for category_texts in corpus.values():
            all_texts.extend(category_texts)
        
        # Remove anchor and positive from candidates
        candidates = [text for text in all_texts if text not in [anchor_text, positive_text]]
        
        return random.sample(candidates, min(num_negatives, len(candidates)))
    
    def hard_negative_mining(self, anchor_embedding: torch.Tensor, positive_embedding: torch.Tensor,
                            candidate_embeddings: torch.Tensor, candidate_texts: List[str],
                            num_negatives: int = 5) -> List[str]:
        """Hard negative mining - select most similar negatives"""
        # Compute similarities with anchor
        similarities = F.cosine_similarity(
            anchor_embedding.unsqueeze(0), candidate_embeddings, dim=1
        )
        
        # Get indices of most similar (hardest) negatives
        _, hard_indices = torch.topk(similarities, min(num_negatives, len(similarities)))
        
        return [candidate_texts[idx] for idx in hard_indices.tolist()]
    
    def semi_hard_negative_mining(self, anchor_embedding: torch.Tensor, positive_embedding: torch.Tensor,
                                 candidate_embeddings: torch.Tensor, candidate_texts: List[str],
                                 num_negatives: int = 5, margin: float = 0.5) -> List[str]:
        """Semi-hard negative mining: negatives less similar to the anchor than the positive,
        but still within the margin, so they produce a non-zero triplet loss"""
        # Compute similarities
        anchor_pos_sim = F.cosine_similarity(anchor_embedding, positive_embedding, dim=0)
        anchor_neg_sims = F.cosine_similarity(
            anchor_embedding.unsqueeze(0), candidate_embeddings, dim=1
        )
        
        # Semi-hard condition: sim(a, p) - margin < sim(a, n) < sim(a, p)
        semi_hard_mask = (anchor_neg_sims < anchor_pos_sim) & (anchor_neg_sims > anchor_pos_sim - margin)
        semi_hard_indices = torch.where(semi_hard_mask)[0]
        
        if len(semi_hard_indices) == 0:
            # Fallback to hard negatives if no semi-hard found
            return self.hard_negative_mining(
                anchor_embedding, positive_embedding, candidate_embeddings, candidate_texts, num_negatives
            )
        
        # Select from semi-hard negatives
        selected_indices = semi_hard_indices[
            torch.randperm(len(semi_hard_indices))[:min(num_negatives, len(semi_hard_indices))]
        ]
        
        return [candidate_texts[idx] for idx in selected_indices.tolist()]
    
    def cluster_based_mining(self, anchor_embedding: torch.Tensor, positive_embedding: torch.Tensor,
                            candidate_embeddings: torch.Tensor, candidate_texts: List[str],
                            num_negatives: int = 5, num_clusters: int = 3) -> List[str]:
        """Cluster-based negative mining for diverse negatives"""
        # Perform clustering on candidate embeddings
        if len(candidate_embeddings) < num_clusters:
            return self.random_negative_mining(None, None, {'all': candidate_texts}, num_negatives)
        
        kmeans = KMeans(n_clusters=num_clusters, random_state=42, n_init=10)
        cluster_labels = kmeans.fit_predict(candidate_embeddings.numpy())
        
        # Sample negatives from different clusters
        negatives = []
        negatives_per_cluster = max(1, num_negatives // num_clusters)
        
        for cluster_id in range(num_clusters):
            cluster_indices = np.where(cluster_labels == cluster_id)[0]
            if len(cluster_indices) > 0:
                selected = np.random.choice(
                    cluster_indices, 
                    min(negatives_per_cluster, len(cluster_indices)), 
                    replace=False
                )
                negatives.extend([candidate_texts[idx] for idx in selected])
        
        return negatives[:num_negatives]
    
    def semantic_category_mining(self, anchor_text: str, positive_text: str,
                                corpus: Dict[str, List[str]], num_negatives: int = 5) -> List[str]:
        """Semantic category-based mining - avoid same category negatives"""
        # Find anchor category
        anchor_category = None
        for category, texts in corpus.items():
            if anchor_text in texts:
                anchor_category = category
                break
        
        # Sample from different categories
        negatives = []
        for category, texts in corpus.items():
            if category != anchor_category:
                available = [text for text in texts if text not in [anchor_text, positive_text]]
                if available:
                    negatives.extend(random.sample(available, min(2, len(available))))
        
        return negatives[:num_negatives]
    
    def create_arabic_triplets(self, corpus: Dict[str, List[str]], 
                              strategy: str = 'hard', num_triplets: int = 20) -> List[ArabicTriplet]:
        """Create Arabic triplets using specified mining strategy"""
        print(f"🎯 Creating Arabic Triplets with {strategy} mining")
        print("=" * 50)
        
        triplets = []
        all_texts = []
        
        # Collect all texts
        for category_texts in corpus.values():
            all_texts.extend(category_texts)
        
        # Generate embeddings
        all_embeddings = self.generate_embeddings(all_texts)
        text_to_embedding = {text: emb for text, emb in zip(all_texts, all_embeddings)}
        
        triplet_count = 0
        
        # Create triplets for each category
        for category, texts in corpus.items():
            if len(texts) < 2:
                continue
                
            # Create positive pairs within category
            for i in range(len(texts)):
                for j in range(i + 1, len(texts)):
                    if triplet_count >= num_triplets:
                        break
                        
                    anchor_text = texts[i]
                    positive_text = texts[j]
                    
                    # Get embeddings
                    anchor_emb = text_to_embedding[anchor_text]
                    positive_emb = text_to_embedding[positive_text]
                    
                    # Mine negatives based on strategy
                    if strategy in ['hard', 'semi_hard', 'cluster_based']:
                        # Get candidate negatives (exclude same category for diversity)
                        candidate_texts = []
                        candidate_embeddings = []
                        
                        for other_category, other_texts in corpus.items():
                            if other_category != category:
                                candidate_texts.extend(other_texts)
                                candidate_embeddings.extend([
                                    text_to_embedding[text] for text in other_texts
                                ])
                        
                        if candidate_embeddings:
                            candidate_embeddings = torch.stack(candidate_embeddings)
                            negatives = self.mining_strategies[strategy](
                                anchor_emb, positive_emb, candidate_embeddings, candidate_texts, 1
                            )
                        else:
                            negatives = self.random_negative_mining(anchor_text, positive_text, corpus, 1)
                    else:
                        negatives = self.mining_strategies[strategy](anchor_text, positive_text, corpus, 1)
                    
                    if negatives:
                        negative_text = negatives[0]
                        negative_emb = text_to_embedding[negative_text]
                        
                        # Calculate difficulty score
                        pos_sim = F.cosine_similarity(anchor_emb, positive_emb, dim=0).item()
                        neg_sim = F.cosine_similarity(anchor_emb, negative_emb, dim=0).item()
                        difficulty = neg_sim - pos_sim + 1.0  # Range [-1, 3]; > 1.0 means the negative is closer than the positive
                        
                        triplet = ArabicTriplet(
                            anchor=anchor_text,
                            positive=positive_text,
                            negative=negative_text,
                            anchor_embedding=anchor_emb,
                            positive_embedding=positive_emb,
                            negative_embedding=negative_emb,
                            difficulty=difficulty,
                            category=category
                        )
                        
                        triplets.append(triplet)
                        triplet_count += 1
                        
                        print(f"   Triplet {triplet_count}: {category} (difficulty: {difficulty:.3f})")
                        print(f"      Anchor: {anchor_text[:50]}...")
                        print(f"      Positive: {positive_text[:50]}...")
                        print(f"      Negative: {negative_text[:50]}...")
                        print(f"      Pos sim: {pos_sim:.3f}, Neg sim: {neg_sim:.3f}")
                        print()
                
                if triplet_count >= num_triplets:
                    break
            
            if triplet_count >= num_triplets:
                break
        
        print(f"✅ Created {len(triplets)} Arabic triplets using {strategy} mining")
        return triplets
    
    def analyze_mining_strategies(self, corpus: Dict[str, List[str]]) -> Dict[str, List[ArabicTriplet]]:
        """Compare different mining strategies"""
        print("\n🔍 Comparing Mining Strategies")
        print("=" * 35)
        
        strategies_results = {}
        
        for strategy in ['random', 'hard', 'semi_hard', 'semantic']:
            print(f"\n📊 Testing {strategy} strategy:")
            triplets = self.create_arabic_triplets(corpus, strategy, num_triplets=5)
            strategies_results[strategy] = triplets
            
            # Analyze strategy characteristics
            if triplets:
                difficulties = [t.difficulty for t in triplets]
                avg_difficulty = np.mean(difficulties)
                std_difficulty = np.std(difficulties)
                
                print(f"   Average difficulty: {avg_difficulty:.3f}")
                print(f"   Difficulty std: {std_difficulty:.3f}")
                print(f"   Difficulty range: {min(difficulties):.3f} - {max(difficulties):.3f}")
        
        return strategies_results

# Initialize miner and create corpus
hard_negative_miner = HardNegativeMiner()
arabic_corpus = hard_negative_miner.create_mock_arabic_corpus()

print("📚 Arabic Corpus Created:")
for category, texts in arabic_corpus.items():
    print(f"   {category}: {len(texts)} texts")

# Test different mining strategies
mining_comparison = hard_negative_miner.analyze_mining_strategies(arabic_corpus)
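
The strategies above mine negatives offline from a fixed corpus. An alternative often used during training is online mining inside each batch. The variant sketched below is in the spirit of batch-hard mining (Hermans et al., 2017): each anchor keeps its own positive and takes as its negative the most similar in-batch positive from a *different* category. The function name, the use of category labels to define negatives, and the margin value are assumptions for illustration.

```python
import torch
import torch.nn.functional as F


def batch_hard_triplet_loss(anchors: torch.Tensor, positives: torch.Tensor,
                            labels: torch.Tensor, margin: float = 0.5) -> torch.Tensor:
    """Triplet loss where each anchor's negative is the hardest in-batch positive
    from a different category (hypothetical helper).

    anchors, positives: [B, dim]; labels: [B] integer category ids.
    """
    a = F.normalize(anchors, dim=1)
    p = F.normalize(positives, dim=1)
    sim = a @ p.T                                  # [B, B] anchor-vs-positive cosine similarities
    pos_sim = sim.diagonal()                       # matching pairs
    same_label = labels.unsqueeze(0) == labels.unsqueeze(1)
    # Exclude same-category columns (including the diagonal) from negative selection
    neg_sim = sim.masked_fill(same_label, float("-inf")).max(dim=1).values
    # Rows with no other category get -inf, which relu maps to a zero loss
    return F.relu(margin + neg_sim - pos_sim).mean()


torch.manual_seed(0)
emb_a, emb_p = torch.randn(8, 16), torch.randn(8, 16)
category_ids = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3])
print(batch_hard_triplet_loss(emb_a, emb_p, category_ids).item())
```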

## 🏗️ Advanced Triplet Training Framework

### Complete Implementation for Arabic Text
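The `ArabicTripletTrainer` class in the next cell evaluates losses on fixed, precomputed embeddings and never runs an optimizer step. For contrast, here is a minimal sketch of actual optimization. A small trainable projection head is updated with the plain triplet loss through `torch.optim.Adam`; it is a hypothetical stand-in for fine-tuning the encoder itself. The sizes, learning rate, and step count are arbitrary choices.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical frozen sentence embeddings for 32 triplets (stand-ins for encoder outputs)
dim = 64
anchors = torch.randn(32, dim)
positives = anchors + 1.0 * torch.randn(32, dim)   # noisy copies of the anchors
negatives = torch.randn(32, dim)                   # unrelated vectors

# Trainable projection head: the only parameters updated here
head = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, 32))
optimizer = torch.optim.Adam(head.parameters(), lr=1e-3)
margin = 0.5


def triplet_loss(a: torch.Tensor, p: torch.Tensor, n: torch.Tensor) -> torch.Tensor:
    pos_sim = F.cosine_similarity(a, p, dim=1)
    neg_sim = F.cosine_similarity(a, n, dim=1)
    return torch.clamp(margin + neg_sim - pos_sim, min=0.0).mean()


losses = []
for step in range(100):
    optimizer.zero_grad()
    loss = triplet_loss(head(anchors), head(positives), head(negatives))
    loss.backward()      # gradients flow into the projection head
    optimizer.step()     # parameters change, unlike a loss-only evaluation
    losses.append(loss.item())

print(f"initial loss {losses[0]:.3f} -> final loss {losses[-1]:.3f}")
```

A real setup would instead backpropagate into the encoder and iterate over mini-batches drawn from the curriculum stage currently in use.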

In [None]:
class ArabicTripletTrainer:
    """Advanced triplet training framework for Arabic text embeddings"""
    
    def __init__(self, embedding_dim=768, margin=0.5, temperature=0.07, 
                 mining_strategy='hard', curriculum_learning=True):
        self.embedding_dim = embedding_dim
        self.margin = margin
        self.temperature = temperature
        self.mining_strategy = mining_strategy
        self.curriculum_learning = curriculum_learning
        
        # Training history
        self.training_history = {
            'losses': [],
            'difficulties': [],
            'positive_similarities': [],
            'negative_similarities': [],
            'margin_violations': []
        }
        
        # Curriculum learning parameters
        self.curriculum_stages = {
            'easy': {'difficulty_threshold': 0.3, 'epochs': 2},
            'medium': {'difficulty_threshold': 0.6, 'epochs': 3},
            'hard': {'difficulty_threshold': 1.0, 'epochs': 5}
        }
        
    def triplet_loss_with_hard_negatives(self, anchor, positive, negative, 
                                        return_components=False):
        """
        Triplet loss (one negative per anchor) with an adaptive margin and an
        extra penalty on negatives that are very close to the anchor
        """
        # Compute similarities
        pos_sim = F.cosine_similarity(anchor, positive, dim=1)
        neg_sim = F.cosine_similarity(anchor, negative, dim=1)
        
        # Adaptive margin based on positive similarity
        adaptive_margin = self.margin * (1.0 - pos_sim.detach())
        
        # Triplet loss with adaptive margin
        loss = torch.clamp(adaptive_margin + neg_sim - pos_sim, min=0.0)
        
        # Extra penalty (weight 0.1) on negatives whose similarity to the anchor exceeds 0.8
        hard_negative_penalty = torch.clamp(neg_sim - 0.8, min=0.0) * 0.1
        
        total_loss = (loss + hard_negative_penalty).mean()
        
        if return_components:
            components = {
                'base_loss': loss.mean().item(),
                'hard_penalty': hard_negative_penalty.mean().item(),
                'positive_similarities': pos_sim.detach(),
                'negative_similarities': neg_sim.detach(),
                'adaptive_margins': adaptive_margin.detach(),
                'margin_violations': (neg_sim - pos_sim + adaptive_margin).detach()
            }
            return total_loss, components
        
        return total_loss
    
    def contrastive_loss_with_temperature(self, anchor, positive, negatives,
                                         return_components=False):
        """
        InfoNCE-style loss with a fixed temperature (self.temperature; not learned here)
        """
        batch_size = anchor.size(0)
        
        # Compute positive similarities
        pos_sim = F.cosine_similarity(anchor, positive, dim=1) / self.temperature
        
        # Compute negative similarities
        neg_sims = []
        for i in range(batch_size):
            if negatives.dim() == 3:  # [batch_size, num_negatives, dim]
                anchor_expanded = anchor[i:i+1].expand(negatives.size(1), -1)
                neg_sim = F.cosine_similarity(anchor_expanded, negatives[i], dim=1) / self.temperature
            else:  # [batch_size, dim] - single negative per sample
                neg_sim = F.cosine_similarity(anchor[i:i+1], negatives[i:i+1], dim=1) / self.temperature
            neg_sims.append(neg_sim)
        
        # Compute contrastive loss
        losses = []
        for i in range(batch_size):
            numerator = torch.exp(pos_sim[i])
            denominator = numerator + torch.sum(torch.exp(neg_sims[i]))
            loss = -torch.log(numerator / (denominator + 1e-8))
            losses.append(loss)
        
        total_loss = torch.stack(losses).mean()
        
        if return_components:
            components = {
                'positive_similarities': pos_sim.detach(),
                'negative_similarities': [neg_sim.detach() for neg_sim in neg_sims],
                'individual_losses': torch.stack(losses).detach(),
                'temperature': self.temperature
            }
            return total_loss, components
        
        return total_loss
    
    def curriculum_training_step(self, triplets: List[ArabicTriplet], 
                                current_stage: str) -> Dict[str, float]:
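        """
        Compute the triplet loss for the triplets admitted by the current curriculum stage.
        
        Note: this evaluates the loss on precomputed embeddings; no parameters are updated.
        """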
        stage_config = self.curriculum_stages[current_stage]
        
        # Filter triplets by difficulty
        filtered_triplets = [
            t for t in triplets 
            if t.difficulty <= stage_config['difficulty_threshold']
        ]
        
        if not filtered_triplets:
            print(f"⚠️ No triplets found for {current_stage} stage")
            return {'loss': 0.0, 'num_samples': 0}
        
        print(f"\n📚 Training Stage: {current_stage}")
        print(f"   Difficulty threshold: {stage_config['difficulty_threshold']}")
        print(f"   Available triplets: {len(filtered_triplets)}")
        
        # Prepare batch data
        anchors = torch.stack([t.anchor_embedding for t in filtered_triplets])
        positives = torch.stack([t.positive_embedding for t in filtered_triplets])
        negatives = torch.stack([t.negative_embedding for t in filtered_triplets])
        
        # Compute loss
        loss, components = self.triplet_loss_with_hard_negatives(
            anchors, positives, negatives, return_components=True
        )
        
        # Record training history
        self.training_history['losses'].append(loss.item())
        self.training_history['difficulties'].extend([t.difficulty for t in filtered_triplets])
        self.training_history['positive_similarities'].extend(
            components['positive_similarities'].tolist()
        )
        self.training_history['negative_similarities'].extend(
            components['negative_similarities'].tolist()
        )
        self.training_history['margin_violations'].extend(
            components['margin_violations'].tolist()
        )
        
        # Calculate metrics
        pos_sim_mean = components['positive_similarities'].mean().item()
        neg_sim_mean = components['negative_similarities'].mean().item()
        margin_violations = (components['margin_violations'] > 0).float().mean().item()
        
        results = {
            'loss': loss.item(),
            'base_loss': components['base_loss'],
            'hard_penalty': components['hard_penalty'],
            'num_samples': len(filtered_triplets),
            'pos_sim_mean': pos_sim_mean,
            'neg_sim_mean': neg_sim_mean,
            'margin_violations': margin_violations,
            'avg_difficulty': np.mean([t.difficulty for t in filtered_triplets])
        }
        
        print(f"   Loss: {loss.item():.4f}")
        print(f"   Positive similarity: {pos_sim_mean:.3f}")
        print(f"   Negative similarity: {neg_sim_mean:.3f}")
        print(f"   Margin violations: {margin_violations:.1%}")
        
        return results
    
    def full_curriculum_training(self, triplets: List[ArabicTriplet]) -> Dict[str, List[Dict]]:
        """
        Complete curriculum learning training
        """
        print("🎓 Starting Curriculum Learning Training")
        print("=" * 45)
        
        all_results = {}
        
        for stage in ['easy', 'medium', 'hard']:
            stage_results = []
            
            for epoch in range(self.curriculum_stages[stage]['epochs']):
                print(f"\n🔄 Epoch {epoch + 1}/{self.curriculum_stages[stage]['epochs']}")
                
                epoch_result = self.curriculum_training_step(triplets, stage)
                epoch_result['stage'] = stage
                epoch_result['epoch'] = epoch + 1
                
                stage_results.append(epoch_result)
            
            all_results[stage] = stage_results
        
        return all_results
    
    def analyze_training_dynamics(self) -> Dict[str, float]:
        """
        Analyze training dynamics and convergence
        """
        print("\n📊 Training Dynamics Analysis")
        print("=" * 30)
        
        if not self.training_history['losses']:
            print("⚠️ No training history available")
            return {}
        
        losses = self.training_history['losses']
        pos_sims = self.training_history['positive_similarities']
        neg_sims = self.training_history['negative_similarities']
        violations = self.training_history['margin_violations']
        
        # Std of the last five losses (or of all losses if fewer were recorded),
        # computed separately so stability = 1 / (1 + std) stays within (0, 1]
        recent_std = np.std(losses[-5:]) if len(losses) >= 5 else np.std(losses)
        
        analysis = {
            'loss_trend': 'decreasing' if losses[-1] < losses[0] else 'non-decreasing',
            'loss_reduction': (losses[0] - losses[-1]) / losses[0] * 100 if losses[0] != 0 else 0.0,
            'avg_loss': np.mean(losses),
            'loss_std': np.std(losses),
            'final_loss': losses[-1],
            'avg_pos_similarity': np.mean(pos_sims),
            'avg_neg_similarity': np.mean(neg_sims),
            'similarity_gap': np.mean(pos_sims) - np.mean(neg_sims),
            'violation_rate': np.mean([v > 0 for v in violations]),
            'training_stability': 1.0 / (1.0 + recent_std)
        }
        
        print(f"📈 Loss Trend: {analysis['loss_trend']}")
        print(f"📉 Loss Reduction: {analysis['loss_reduction']:.1f}%")
        print(f"🎯 Final Loss: {analysis['final_loss']:.4f}")
        print(f"✅ Positive Similarity: {analysis['avg_pos_similarity']:.3f}")
        print(f"❌ Negative Similarity: {analysis['avg_neg_similarity']:.3f}")
        print(f"📏 Similarity Gap: {analysis['similarity_gap']:.3f}")
        print(f"⚠️ Violation Rate: {analysis['violation_rate']:.1%}")
        print(f"🔒 Training Stability: {analysis['training_stability']:.3f}")
        
        return analysis

# Test triplet training with different strategies
def test_triplet_training_strategies():
    """Test different triplet training strategies"""
    print("\n🧪 Testing Triplet Training Strategies")
    print("=" * 40)
    
    strategies = ['hard', 'semi_hard', 'random']
    results = {}
    
    for strategy in strategies:
        print(f"\n🎯 Testing {strategy} strategy:")
        
        # Create triplets with this strategy
        triplets = hard_negative_miner.create_arabic_triplets(
            arabic_corpus, strategy=strategy, num_triplets=10
        )
        
        if triplets:
            # Initialize trainer
            trainer = ArabicTripletTrainer(
                mining_strategy=strategy,
                curriculum_learning=True
            )
            
            # Run curriculum training
            training_results = trainer.full_curriculum_training(triplets)
            
            # Analyze training dynamics
            dynamics = trainer.analyze_training_dynamics()
            
            results[strategy] = {
                'training_results': training_results,
                'dynamics': dynamics,
                'num_triplets': len(triplets)
            }
    
    return results

# Run comprehensive testing
training_strategy_results = test_triplet_training_strategies()
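The `curriculum_training_step` above evaluates the triplet objective on each difficulty subset but never calls `backward()`, so with precomputed embeddings the epochs within a stage see the same loss. The following is a minimal sketch of a curriculum loop that does update a model, assuming a toy `nn.Linear` encoder over random features in place of a real Arabic text encoder. The stage thresholds, margin, learning rate, and clipping norm are illustrative values, not GATE's settings.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy data: 64 triplets of 32-dim features plus a difficulty score
# in [0, 1) per triplet. A real pipeline would encode Arabic sentences instead.
num_triplets, in_dim, emb_dim = 64, 32, 16
anchor_x = torch.randn(num_triplets, in_dim)
positive_x = anchor_x + 0.1 * torch.randn(num_triplets, in_dim)
negative_x = torch.randn(num_triplets, in_dim)
difficulty = torch.rand(num_triplets)

encoder = nn.Linear(in_dim, emb_dim)  # hypothetical stand-in for a text encoder
optimizer = torch.optim.AdamW(encoder.parameters(), lr=1e-2)


def cosine_triplet_loss(a, p, n, margin=0.2):
    """Mean of max(0, margin - (cos(a, p) - cos(a, n)))."""
    pos_sim = F.cosine_similarity(a, p, dim=-1)
    neg_sim = F.cosine_similarity(a, n, dim=-1)
    return F.relu(margin - (pos_sim - neg_sim)).mean()


# Illustrative stages: each admits every triplet at or below its threshold.
stages = [('easy', 0.3, 5), ('medium', 0.6, 5), ('hard', 1.0, 5)]
history = []
for stage, threshold, epochs in stages:
    mask = difficulty <= threshold
    if not mask.any():
        continue  # nothing to train on at this stage
    for _ in range(epochs):
        a, p, n = (encoder(x[mask]) for x in (anchor_x, positive_x, negative_x))
        loss = cosine_triplet_loss(a, p, n)
        optimizer.zero_grad()
        loss.backward()  # gradients flow into the encoder weights
        torch.nn.utils.clip_grad_norm_(encoder.parameters(), max_norm=1.0)
        optimizer.step()
        history.append((stage, int(mask.sum()), loss.item()))

for stage, n_used, loss_value in history[::5]:
    print(f"{stage:>6}: {n_used} triplets, first-epoch loss={loss_value:.4f}")
```

Because each stage keeps every easier triplet, the number of training triplets never shrinks from one stage to the next, and gradient clipping bounds the update size once harder negatives enter.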

## 📊 Advanced Visualization and Analysis

### Understanding Contrastive Learning Dynamics
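The temperature panel in the visualization cell below divides similarities by τ before applying the softmax. As a small worked illustration, assuming five made-up cosine similarities with the positive at index 0, this sketch shows that lowering τ concentrates probability on the most similar candidate, and that the InfoNCE term for a single anchor is the cross-entropy of that softmax with the positive as the target class.

```python
import torch
import torch.nn.functional as F

# Illustrative cosine similarities for one anchor; index 0 is the positive.
similarities = torch.tensor([0.80, 0.65, 0.40, 0.10, -0.20])
target = torch.tensor([0])  # class index of the positive

p_positive, infonce = {}, {}
for tau in (1.0, 0.1, 0.05):
    logits = similarities / tau
    probs = F.softmax(logits, dim=0)
    # InfoNCE for a single anchor = -log softmax probability of the positive
    infonce[tau] = F.cross_entropy(logits.unsqueeze(0), target).item()
    p_positive[tau] = probs[0].item()
    print(f"tau={tau:<4}  p(positive)={p_positive[tau]:.3f}  InfoNCE={infonce[tau]:.3f}")
```

Using `F.cross_entropy` on the scaled logits, rather than taking the log of the softmax by hand, relies on PyTorch's internal log-softmax and is numerically more stable at small τ.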

In [None]:
def visualize_contrastive_learning_analysis():
    """Create comprehensive visualizations of contrastive learning"""
    
    # Create comprehensive subplot layout
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    
    # Plot 1: Mining Strategy Comparison
    ax1 = axes[0, 0]
    if mining_comparison:
        strategies = list(mining_comparison.keys())
        avg_difficulties = []
        
        for strategy in strategies:
            triplets = mining_comparison[strategy]
            if triplets:
                avg_difficulties.append(np.mean([t.difficulty for t in triplets]))
            else:
                avg_difficulties.append(0)
        
        bars = ax1.bar(strategies, avg_difficulties, alpha=0.8, 
                      color=['skyblue', 'orange', 'green', 'red'])
        ax1.set_title('Average Difficulty by Mining Strategy', fontweight='bold')
        ax1.set_ylabel('Average Difficulty Score')
        ax1.tick_params(axis='x', rotation=45)
        ax1.grid(True, alpha=0.3)
        
        # Add value labels on bars
        for bar, value in zip(bars, avg_difficulties):
            ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                    f'{value:.3f}', ha='center', va='bottom')
    
    # Plot 2: Loss Function Comparison
    ax2 = axes[0, 1]
    if loss_analysis:
        loss_types = ['InfoNCE', 'Triplet', 'MNR']
        loss_values = [
            loss_analysis['infonce'][0].item(),
            loss_analysis['triplet'][0].item(),
            loss_analysis['mnr'][0].item()
        ]
        
        ax2.bar(loss_types, loss_values, alpha=0.8, color=['blue', 'green', 'purple'])
        ax2.set_title('Loss Function Comparison', fontweight='bold')
        ax2.set_ylabel('Loss Value')
        ax2.grid(True, alpha=0.3)
    
    # Plot 3: Training Strategy Performance
    ax3 = axes[0, 2]
    if training_strategy_results:
        strategies = list(training_strategy_results.keys())
        final_losses = []
        
        for strategy in strategies:
            dynamics = training_strategy_results[strategy]['dynamics']
            if 'final_loss' in dynamics:
                final_losses.append(dynamics['final_loss'])
            else:
                final_losses.append(0)
        
        ax3.bar(strategies, final_losses, alpha=0.8, color=['lightblue', 'lightgreen', 'lightcoral'])
        ax3.set_title('Final Training Loss by Strategy', fontweight='bold')
        ax3.set_ylabel('Final Loss')
        ax3.grid(True, alpha=0.3)
    
    # Plot 4: Similarity Gap Analysis
    ax4 = axes[1, 0]
    if training_strategy_results:
        strategies = list(training_strategy_results.keys())
        similarity_gaps = []
        
        for strategy in strategies:
            dynamics = training_strategy_results[strategy]['dynamics']
            if 'similarity_gap' in dynamics:
                similarity_gaps.append(dynamics['similarity_gap'])
            else:
                similarity_gaps.append(0)
        
        # 'bronze' is not a named matplotlib color, so its hex code is used instead
        ax4.bar(strategies, similarity_gaps, alpha=0.8, color=['gold', 'silver', '#CD7F32'])
        ax4.set_title('Positive-Negative Similarity Gap', fontweight='bold')
        ax4.set_ylabel('Similarity Gap')
        ax4.grid(True, alpha=0.3)
    
    # Plot 5: Temperature Effects Visualization
    ax5 = axes[1, 1]
    if temperature_analysis:
        temperatures, similarities = temperature_analysis
        
        # Show softmax distribution for different temperatures
        for i, temp in enumerate(temperatures[:3]):  # Show first 3 temperatures
            scaled = similarities / temp
            softmax_probs = F.softmax(scaled, dim=0)
            
            ax5.plot(range(len(softmax_probs)), softmax_probs.numpy(), 
                    'o-', label=f'τ={temp}', alpha=0.8)
        
        ax5.set_title('Temperature Effects on Softmax', fontweight='bold')
        ax5.set_xlabel('Similarity Rank')
        ax5.set_ylabel('Softmax Probability')
        ax5.legend()
        ax5.grid(True, alpha=0.3)
    
    # Plot 6: Training Dynamics
    ax6 = axes[1, 2]
    if training_strategy_results:
        for strategy, results in training_strategy_results.items():
            if 'dynamics' in results and 'training_stability' in results['dynamics']:
                stability = results['dynamics']['training_stability']
                violation_rate = results['dynamics']['violation_rate']
                
                ax6.scatter(violation_rate, stability, s=100, alpha=0.7, label=strategy)
        
        ax6.set_title('Training Stability vs Violation Rate', fontweight='bold')
        ax6.set_xlabel('Margin Violation Rate')
        ax6.set_ylabel('Training Stability')
        ax6.legend()
        ax6.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

def generate_contrastive_learning_insights():
    """Generate insights from contrastive learning analysis"""
    print("\n💡 Contrastive Learning Insights")
    print("=" * 35)
    
    # Mining strategy insights
    if mining_comparison:
        print("\n🎯 Mining Strategy Analysis:")
        
        for strategy, triplets in mining_comparison.items():
            if triplets:
                avg_difficulty = np.mean([t.difficulty for t in triplets])
                difficulty_std = np.std([t.difficulty for t in triplets])
                
                print(f"   {strategy:>12}: avg_difficulty={avg_difficulty:.3f}, std={difficulty_std:.3f}")
                
                # Strategy characteristics
                if strategy == 'hard':
                    print(f"                 → Best for challenging models, may cause training instability")
                elif strategy == 'semi_hard':
                    print(f"                 → Balanced approach, good convergence properties")
                elif strategy == 'random':
                    print(f"                 → Baseline approach, stable but slower learning")
                elif strategy == 'semantic':
                    print(f"                 → Category-aware, good for structured datasets")
    
    # Training strategy insights
    if training_strategy_results:
        print("\n🚂 Training Strategy Performance:")
        
        best_strategy = None
        best_score = -float('inf')
        
        for strategy, results in training_strategy_results.items():
            dynamics = results['dynamics']
            if 'similarity_gap' in dynamics:
                score = dynamics['similarity_gap'] * dynamics['training_stability']
                
                print(f"   {strategy:>12}: gap={dynamics['similarity_gap']:.3f}, stability={dynamics['training_stability']:.3f}, score={score:.3f}")
                
                if score > best_score:
                    best_score = score
                    best_strategy = strategy
        
        if best_strategy:
            print(f"\n🏆 Best Strategy: {best_strategy} (score: {best_score:.3f})")
    
    # Loss function insights
    if loss_analysis:
        print("\n🧮 Loss Function Characteristics:")
        
        infonce_loss = loss_analysis['infonce'][0].item()
        triplet_loss = loss_analysis['triplet'][0].item()
        mnr_loss = loss_analysis['mnr'][0].item()
        
        print(f"   InfoNCE: {infonce_loss:.4f} → Good for large-scale contrastive learning")
        print(f"   Triplet: {triplet_loss:.4f} → Direct optimization of similarity relationships")
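        print(f"   MNR: {mnr_loss:.4f} → Ranks the positive above multiple negatives (typically in-batch)")
        
        # Raw values are not on a common scale (a cross-entropy over candidates vs.
        # a margin hinge), so the smallest number does not identify the best objective
        losses = {'InfoNCE': infonce_loss, 'Triplet': triplet_loss, 'MNR': mnr_loss}
        lowest_loss = min(losses, key=losses.get)
        print(f"\n💡 Lowest raw loss: {lowest_loss} ({losses[lowest_loss]:.4f}); "
              "pick an objective by downstream metrics (e.g. STS correlation), not raw loss")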

def practical_implementation_recommendations():
    """Provide practical recommendations for contrastive learning"""
    print("\n🛠️ Practical Implementation Recommendations")
    print("=" * 50)
    
    recommendations = {
        "🎯 Negative Mining Strategy": [
            "Start with semi-hard negatives for stable training",
            "Use hard negatives only after initial convergence",
            "Implement dynamic difficulty adjustment",
            "Mix strategies during training for robustness"
        ],
        "🌡️ Temperature Optimization": [
            "Start with τ=0.07 and adjust based on loss behavior",
            "Lower temperature (0.01-0.05) for sharper distributions",
            "Higher temperature (0.1-0.5) for smoother learning",
            "Use learnable temperature when possible"
        ],
        "📚 Curriculum Learning": [
            "Begin with easy negatives (low similarity to anchor)",
            "Gradually increase difficulty as model learns",
            "Monitor margin violation rates as difficulty indicator",
            "Adjust stage thresholds based on convergence"
        ],
        "🔧 Training Optimization": [
            "Use batch sizes of 64-256 for good negative sampling",
            "Implement gradient clipping for training stability",
            "Monitor positive/negative similarity gap",
            "Apply data augmentation for Arabic text variation"
        ],
        "📊 Evaluation Metrics": [
            "Track similarity gap as primary metric",
            "Monitor margin violation rates for difficulty",
            "Use retrieval metrics (MRR, nDCG) for validation",
            "Test on diverse Arabic dialects for robustness"
        ]
    }
    
    for category, tips in recommendations.items():
        print(f"\n{category}:")
        for tip in tips:
            print(f"   • {tip}")

# Run comprehensive analysis
visualize_contrastive_learning_analysis()
generate_contrastive_learning_insights()
practical_implementation_recommendations()
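The loss comparison above reports raw InfoNCE, triplet, and MNR values side by side, but these objectives live on different scales. A minimal sketch, assuming an MNR-style in-batch InfoNCE (each anchor's own positive is the target, the other positives act as negatives, τ = 0.05 chosen for illustration) and a cosine triplet hinge with margin 0.2: on fully collapsed embeddings, where every text maps to the same vector, InfoNCE equals log(B) for batch size B while the triplet loss equals the margin. Which number is "lower" therefore depends on B and the margin, not on embedding quality.

```python
import math

import torch
import torch.nn.functional as F


def in_batch_infonce(anchors, positives, tau=0.05):
    """MNR/InfoNCE-style loss: positive i is the target for anchor i, and the
    other positives in the batch act as negatives."""
    a = F.normalize(anchors, dim=-1)
    p = F.normalize(positives, dim=-1)
    logits = a @ p.T / tau                  # (B, B) scaled cosine similarities
    targets = torch.arange(a.size(0))       # diagonal entries are the positives
    return F.cross_entropy(logits, targets)


def cosine_triplet(anchors, positives, negatives, margin=0.2):
    """Mean of max(0, margin - (cos(a, p) - cos(a, n)))."""
    pos = F.cosine_similarity(anchors, positives, dim=-1)
    neg = F.cosine_similarity(anchors, negatives, dim=-1)
    return F.relu(margin - (pos - neg)).mean()


# A fully collapsed encoder: every text maps to the same vector.
batch_size, dim = 32, 8
collapsed = torch.ones(batch_size, dim)

info = in_batch_infonce(collapsed, collapsed).item()
trip = cosine_triplet(collapsed, collapsed, collapsed).item()
print(f"InfoNCE on collapsed embeddings: {info:.4f} (log B = {math.log(batch_size):.4f})")
print(f"Triplet on collapsed embeddings: {trip:.4f} (margin = 0.2)")
```

Changing `batch_size` moves the InfoNCE value while the triplet value stays at the margin, which is why a like-for-like comparison needs a downstream metric such as STS correlation.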

## 🎯 Key Insights and Learning Takeaways

### Mastering Contrastive Learning for Arabic Text
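Several takeaways below contrast hard and semi-hard negatives. Here is a minimal sketch of both selection rules for a single anchor, written in cosine-similarity terms (the FaceNet semi-hard rule is usually stated with distances): the hard negative is the most similar candidate, while a semi-hard negative is less similar than the positive but still inside the margin, sim(a, p) - m < sim(a, n) < sim(a, p). `mine_negatives` and `unit_vector` are hypothetical helpers, not part of the notebook's `hard_negative_miner`, and falling back to the hard negative when the band is empty is an assumption of this sketch.

```python
import torch
import torch.nn.functional as F


def mine_negatives(anchor, positive, candidates, margin=0.2):
    """Return (hard_idx, semi_hard_idx) among `candidates` for one anchor."""
    pos_sim = F.cosine_similarity(anchor, positive, dim=0)
    neg_sims = F.cosine_similarity(anchor.unsqueeze(0), candidates, dim=-1)
    hard_idx = int(torch.argmax(neg_sims))  # most similar candidate overall

    # Semi-hard band: less similar than the positive, but inside the margin
    in_band = (neg_sims < pos_sim) & (neg_sims > pos_sim - margin)
    if not in_band.any():
        return hard_idx, hard_idx  # assumed fallback when the band is empty
    # Hardest candidate that still lies inside the band
    band_sims = neg_sims.masked_fill(~in_band, float('-inf'))
    return hard_idx, int(torch.argmax(band_sims))


def unit_vector(cos_to_x_axis):
    """2-D unit vector whose cosine with [1, 0] equals `cos_to_x_axis`."""
    return torch.tensor([cos_to_x_axis, (1 - cos_to_x_axis ** 2) ** 0.5])


anchor = torch.tensor([1.0, 0.0])
positive = unit_vector(0.9)
# Candidate similarities to the anchor: 0.95, 0.80, 0.75, 0.30
candidates = torch.stack([unit_vector(c) for c in (0.95, 0.80, 0.75, 0.30)])

hard_idx, semi_idx = mine_negatives(anchor, positive, candidates)
print(f"hard negative: index {hard_idx}, semi-hard negative: index {semi_idx}")
```

In this toy setup the hard negative (similarity 0.95) is more similar to the anchor than the positive itself. With real NLI data such candidates are worth inspecting, since they may be unlabeled paraphrases (false negatives), which is one reason semi-hard selection tends to train more stably.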

In [None]:
def summarize_contrastive_learning_mastery():
    """Comprehensive summary of contrastive learning mastery"""
    
    insights = {
        "🧮 Mathematical Mastery": [
            "InfoNCE optimizes ranking through temperature-scaled softmax",
            "Triplet loss directly enforces a margin between positive and negative similarity",
            "Temperature scaling controls distribution sharpness and learning speed",
            "Multiple negatives ranking (MNR) ranks each positive above several negatives, typically the other in-batch positives",
            "Adaptive margins can improve training stability and convergence"
        ],
        "🎯 Mining Strategy Excellence": [
            "Hard negatives accelerate learning but may cause instability",
            "Semi-hard negatives often balance challenge and stability well",
            "Cluster-based mining promotes negative diversity across the semantic space",
            "Semantic category mining leverages domain structure effectively",
            "Dynamic difficulty adjustment enables curriculum learning"
        ],
        "📚 Curriculum Learning Innovation": [
            "Progressive difficulty increases model robustness gradually",
            "Starting with easy negatives can reduce early-stage representation collapse",
            "Difficulty thresholds should align with model capacity",
            "Margin violation rates help signal when to advance to harder stages",
            "Multi-stage schedules can improve final performance; confirm with downstream evaluation"
        ],
        "🚀 Training Optimization": [
            "Batch size affects negative sampling quality and diversity",
            "Gradient clipping helps stabilize training on hard negatives",
            "Learning rate scheduling improves convergence in later stages",
            "Mixed precision training enables larger batch sizes efficiently",
            "Regular evaluation prevents overfitting to training negatives"
        ],
        "🔬 Arabic-Specific Adaptations": [
            "Morphological variations require careful negative selection",
            "Dialectal differences enhance negative diversity naturally",
            "Root-pattern relationships inform semantic similarity mining",
            "Cross-dialectal triplets improve generalization",
            "Arabic-specific augmentations increase training robustness"
        ]
    }
    
    print("🎓 Contrastive Learning Mastery")
    print("=" * 50)
    
    for category, points in insights.items():
        print(f"\n{category}:")
        for point in points:
            print(f"   • {point}")
    
    return insights

def connection_to_gate_innovation():
    """Connect insights to GATE's contrastive learning innovations"""
    print("\n🔗 Connection to GATE's Innovation")
    print("=" * 35)
    
    print("📋 How GATE Advances Contrastive Learning:")
    
    gate_innovations = {
        "Beyond InfoNCE Limitations": [
            "Identifies InfoNCE insufficiency for sentence-level tasks",
            "Implements task-specific loss functions for better performance",
            "Uses Arabic NLI triplets for fine-grained similarity learning"
        ],
        "Arabic-Tailored Negative Mining": [
            "Curates hard negatives from Arabic NLI datasets",
            "Leverages linguistic structure for better negative selection",
            "Addresses Arabic-specific semantic challenges"
        ],
        "Hybrid Training Integration": [
            "Combines contrastive learning with classification objectives",
            "Balances similarity learning with semantic understanding",
            "Integrates with Matryoshka representation learning"
        ],
        "Performance Achievements": [
            "20-25% improvement over larger models (OpenAI)",
            "State-of-the-art Arabic STS performance on MTEB",
            "Robust performance across multiple embedding dimensions"
        ]
    }
    
    for innovation, details in gate_innovations.items():
        print(f"\n🎯 {innovation}:")
        for detail in details:
            print(f"   ✓ {detail}")
    
    print(f"\n🏆 Result: GATE demonstrates that thoughtful contrastive learning")
    print(f"   design can outperform larger, more resource-intensive models")

def advanced_research_directions():
    """Outline advanced research directions in contrastive learning"""
    print("\n🔬 Advanced Research Directions")
    print("=" * 35)
    
    directions = {
        "🧬 Meta-Learning for Mining": [
            "Learn optimal negative sampling strategies automatically",
            "Adapt mining difficulty based on model performance",
            "Personalize negative selection for different domains"
        ],
        "🌍 Cross-Lingual Contrastive Learning": [
            "Multi-lingual triplet construction and evaluation",
            "Zero-shot transfer through cross-lingual negatives",
            "Language-agnostic similarity learning frameworks"
        ],
        "🎭 Multi-Modal Integration": [
            "Text-image contrastive learning for Arabic content",
            "Audio-text alignment for Arabic speech processing",
            "Multi-modal negative mining strategies"
        ],
        "⚡ Efficiency Optimizations": [
            "Approximate negative sampling for large-scale training",
            "Hierarchical negative caching mechanisms",
            "Dynamic batch construction for optimal mining"
        ],
        "🎯 Task-Specific Adaptations": [
            "Question-answering focused contrastive objectives",
            "Document retrieval optimized negative mining",
            "Conversational AI contrastive learning strategies"
        ]
    }
    
    for direction, aspects in directions.items():
        print(f"\n{direction}:")
        for aspect in aspects:
            print(f"   • {aspect}")

# Generate comprehensive insights
contrastive_insights = summarize_contrastive_learning_mastery()
connection_to_gate_innovation()
advanced_research_directions()

print("\n🎓 Learning Completion Summary")
print("=" * 35)
print("✅ Contrastive learning mathematics thoroughly mastered")
print("✅ Advanced negative mining strategies implemented")
print("✅ Curriculum learning framework developed")
print("✅ Arabic-specific adaptations understood")
print("✅ Training optimization techniques learned")
print("✅ Connection to GATE's innovations established")

print("\n🚀 Next Steps:")
print("   • Apply contrastive learning to your domain")
print("   • Experiment with novel negative mining strategies")
print("   • Integrate with multi-modal learning systems")
print("   • Contribute to Arabic NLP research advancement")

print("\n🎉 Congratulations! You have completed all GATE focused learning notebooks:")
print("   1. ✅ Matryoshka Representation Learning")
print("   2. ✅ Hybrid Loss Architecture")
print("   3. ✅ Arabic NLP Challenges")
print("   4. ✅ Contrastive Triplet Learning")
print("\n🌟 You are now equipped to implement and extend GATE's innovations!")