In [None]:
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import networkx as nx
from collections import defaultdict
import re
import spacy
import torch
import torch.nn.functional as F
from torch_geometric.transforms import RandomLinkSplit, ToUndirected
from torch_geometric.loader import LinkNeighborLoader
from torch_geometric.data import HeteroData
import torch_geometric.transforms as T
from sklearn.metrics import roc_auc_score, average_precision_score
import numpy as np
from tqdm import tqdm

class JokeGraphBuilder:
    def __init__(self, jokes_df):
        self.jokes_df = jokes_df
        self.edges = []
        self.node_features = {}
        
    def create_content_similarity_edges(self, threshold=0.3, method='sentence_bert'):
        """Create edges based on content similarity using various methods"""
        
        if method == 'sentence_bert':
            # Use Sentence-BERT for semantic similarity
            model = SentenceTransformer('all-MiniLM-L6-v2')
            embeddings = model.encode(self.jokes_df['text'].tolist())
            similarity_matrix = cosine_similarity(embeddings)
            
        elif method == 'tfidf':
            # Use TF-IDF for lexical similarity
            vectorizer = TfidfVectorizer(stop_words='english', max_features=1000)
            tfidf_matrix = vectorizer.fit_transform(self.jokes_df['text'])
            similarity_matrix = cosine_similarity(tfidf_matrix)
        
        # Create edges for similar jokes
        for i in range(len(similarity_matrix)):
            for j in range(i+1, len(similarity_matrix)):
                if similarity_matrix[i][j] > threshold:
                    self.edges.append({
                        'source': self.jokes_df.iloc[i]['joke_id'],
                        'target': self.jokes_df.iloc[j]['joke_id'],
                        'weight': similarity_matrix[i][j],
                        'edge_type': f'content_similarity_{method}'
                    })

    def create_category_edges(self, max_connections_per_joke=15):
        """Connect jokes within same category (SPARSE VERSION)"""
        import random
        
        # Same category edges - but limited per joke
        category_groups = self.jokes_df.groupby('category')
        for category, group in category_groups:
            joke_ids = group['joke_id'].tolist()
            
            # Instead of connecting ALL to ALL, limit connections per joke
            for joke_id in joke_ids:
                other_jokes = [jid for jid in joke_ids if jid != joke_id]
                
                # Limit how many jokes each joke connects to
                num_connections = min(max_connections_per_joke, len(other_jokes))
                if num_connections > 0:
                    connected_jokes = random.sample(other_jokes, num_connections)
                    
                    for connected_joke in connected_jokes:
                        self.edges.append({
                            'source': joke_id,
                            'target': connected_joke,
                            'weight': 1.0,
                            'edge_type': 'same_category'
                        })
        
        # Related category edges (keep this part as is)
        related_categories = {
            'Programming': ['programming'],
            'Dad Joke': ['Pun'],
            'Misc': ['general'],
        }
        
        for main_cat, related_cats in related_categories.items():
            main_jokes = self.jokes_df[self.jokes_df['category'] == main_cat]['joke_id'].tolist()
            for related_cat in related_cats:
                related_jokes = self.jokes_df[self.jokes_df['category'] == related_cat]['joke_id'].tolist()
                for main_joke in main_jokes:
                    for related_joke in related_jokes:
                        self.edges.append({
                            'source': main_joke,
                            'target': related_joke,
                            'weight': 0.7,
                            'edge_type': 'related_category'
                        })
    def create_source_edges(self, weight=0.5):
        """Connect jokes from the same source"""
        source_groups = self.jokes_df.groupby('source')
        for source, group in source_groups:
            joke_ids = group['joke_id'].tolist()
            # Create edges between jokes from same source
            for i in range(len(joke_ids)):
                for j in range(i+1, min(i+6, len(joke_ids))):  # Limit to avoid too many edges
                    self.edges.append({
                        'source': joke_ids[i],
                        'target': joke_ids[j],
                        'weight': weight,
                        'edge_type': 'same_source'
                    })
    
    def create_keyword_entity_edges(self):
        """Connect jokes that share keywords or named entities"""
        nlp = spacy.load('en_core_web_sm')
        
        joke_keywords = {}
        joke_entities = {}
        
        for idx, row in self.jokes_df.iterrows():
            doc = nlp(row['text'])
            
            # Extract keywords (nouns, verbs, adjectives)
            keywords = [token.lemma_.lower() for token in doc 
                       if token.pos_ in ['NOUN', 'VERB', 'ADJ'] and len(token.text) > 3]
            joke_keywords[row['joke_id']] = set(keywords)
            
            # Extract named entities
            entities = [ent.text.lower() for ent in doc.ents]
            joke_entities[row['joke_id']] = set(entities)
        
        # Create edges based on shared keywords
        joke_ids = list(joke_keywords.keys())
        for i in range(len(joke_ids)):
            for j in range(i+1, len(joke_ids)):
                id1, id2 = joke_ids[i], joke_ids[j]
                
                # Keyword overlap
                keyword_overlap = len(joke_keywords[id1] & joke_keywords[id2])
                if keyword_overlap >= 2:  # At least 2 shared keywords
                    self.edges.append({
                        'source': id1,
                        'target': id2,
                        'weight': min(keyword_overlap * 0.2, 1.0),
                        'edge_type': 'shared_keywords'
                    })
                
                # Entity overlap
                entity_overlap = len(joke_entities[id1] & joke_entities[id2])
                if entity_overlap >= 1:  # At least 1 shared entity
                    self.edges.append({
                        'source': id1,
                        'target': id2,
                        'weight': entity_overlap * 0.3,
                        'edge_type': 'shared_entities'
                    })
    
    def create_structural_similarity_edges(self):
        """Connect jokes with similar structure (length, punctuation, etc.)"""
        
        # Calculate structural features
        structural_features = []
        for text in self.jokes_df['text']:
            features = {
                'length': len(text),
                'word_count': len(text.split()),
                'question_marks': text.count('?'),
                'exclamation_marks': text.count('!'),
                'periods': text.count('.'),
                'has_dialogue': '"' in text or "'" in text,
                'has_numbers': bool(re.search(r'\d', text))
            }
            structural_features.append(features)
        
        # Create edges for structurally similar jokes
        for i in range(len(structural_features)):
            for j in range(i+1, len(structural_features)):
                f1, f2 = structural_features[i], structural_features[j]
                
                # Calculate structural similarity
                similarity = 0
                
                # Length similarity
                length_diff = abs(f1['length'] - f2['length'])
                if length_diff < 50:  # Similar length
                    similarity += 0.3
                
                # Similar punctuation patterns
                if f1['question_marks'] > 0 and f2['question_marks'] > 0:
                    similarity += 0.2
                if f1['exclamation_marks'] > 0 and f2['exclamation_marks'] > 0:
                    similarity += 0.2
                if f1['has_dialogue'] == f2['has_dialogue']:
                    similarity += 0.2
                if f1['has_numbers'] == f2['has_numbers']:
                    similarity += 0.1
                
                if similarity > 0.4:  # Threshold for structural similarity
                    self.edges.append({
                        'source': self.jokes_df.iloc[i]['joke_id'],
                        'target': self.jokes_df.iloc[j]['joke_id'],
                        'weight': similarity,
                        'edge_type': 'structural_similarity'
                    })
    
    
    def create_topic_modeling_edges(self, n_topics=10, threshold=0.3):  # Add threshold param
        """Connect jokes that share topics discovered through topic modeling"""
        from sklearn.decomposition import LatentDirichletAllocation
        from sklearn.feature_extraction.text import CountVectorizer
        
        # Prepare text data
        vectorizer = CountVectorizer(max_features=1000, stop_words='english')
        doc_term_matrix = vectorizer.fit_transform(self.jokes_df['text'])
        
        # Fit LDA model
        lda = LatentDirichletAllocation(n_components=n_topics, random_state=42)
        topic_distributions = lda.fit_transform(doc_term_matrix)
        
        # Create edges between jokes with similar topic distributions
        similarity_matrix = cosine_similarity(topic_distributions)
        
        for i in range(len(similarity_matrix)):
            for j in range(i+1, len(similarity_matrix)):
                if similarity_matrix[i][j] > threshold:  # Use parameter instead of hardcoded 0.3
                    self.edges.append({
                        'source': self.jokes_df.iloc[i]['joke_id'],
                        'target': self.jokes_df.iloc[j]['joke_id'],
                        'weight': similarity_matrix[i][j],
                        'edge_type': 'topic_similarity'
                    })
    
    def create_difficulty_complexity_edges(self):
        """Connect jokes with similar complexity/difficulty levels"""
        
        complexity_scores = []
        for text in self.jokes_df['text']:
            # Simple complexity metrics
            avg_word_length = np.mean([len(word) for word in text.split()])
            sentence_count = len([s for s in text.split('.') if s.strip()])
            vocabulary_richness = len(set(text.lower().split())) / len(text.split())
            
            complexity = (avg_word_length * 0.4 + 
                         sentence_count * 0.3 + 
                         vocabulary_richness * 0.3)
            complexity_scores.append(complexity)
        
        # Group jokes by complexity level
        complexity_quartiles = np.percentile(complexity_scores, [25, 50, 75])
        
        for i in range(len(complexity_scores)):
            for j in range(i+1, len(complexity_scores)):
                # Connect jokes in same complexity quartile
                score1, score2 = complexity_scores[i], complexity_scores[j]
                
                if abs(score1 - score2) < 0.5:  # Similar complexity
                    self.edges.append({
                        'source': self.jokes_df.iloc[i]['joke_id'],
                        'target': self.jokes_df.iloc[j]['joke_id'],
                        'weight': 1.0 - abs(score1 - score2),
                        'edge_type': 'complexity_similarity'
                    })
    
    def generate_node_features(self):
        """Generate comprehensive node features"""
        model = SentenceTransformer('all-MiniLM-L6-v2')
        embeddings = model.encode(self.jokes_df['text'].tolist())
        
        for idx, row in self.jokes_df.iterrows():
            text = row['text']
            
            # Basic features
            features = {
                'joke_id': row['joke_id'],
                'category': row['category'],
                'source': row['source'],
                'text_length': len(text),
                'word_count': len(text.split()),
                'sentence_count': len([s for s in text.split('.') if s.strip()]),
                'question_count': text.count('?'),
                'exclamation_count': text.count('!'),
                'has_dialogue': '"' in text or "'" in text,
                'has_numbers': bool(re.search(r'\d', text)),
                'embedding': embeddings[idx].tolist()  # Sentence-BERT embedding
            }
            
            self.node_features[row['joke_id']] = features
    
    def build_complete_graph(self):
        """Build the complete joke recommendation graph"""
        print("Generating node features...")
        self.generate_node_features()
        
        print("Creating content similarity edges...")
        self.create_content_similarity_edges(threshold=0.4)
        
        print("Creating category edges...")
        self.create_category_edges(max_connections_per_joke=15)
        
        print("Creating source edges...")
        self.create_source_edges()
        
        print("Creating keyword/entity edges...")
        self.create_keyword_entity_edges()
        
        
        print("Creating topic modeling edges...")
        self.create_topic_modeling_edges(threshold = 0.6 ,n_topics=7)
        
        # print("Creating complexity edges...")
        # self.create_difficulty_complexity_edges()
        
        print(f"Graph built with {len(self.node_features)} nodes and {len(self.edges)} edges")
        
        return self.node_features, self.edges

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Usage example:
jokes_df = pd.read_csv('jokes_dataset_v2.csv')

In [3]:
def balance_dataset(jokes_df, max_per_category=40):
    """Create a balanced subset for better learning"""
    
    print("BALANCING DATASET:")
    print("=" * 30)
    
    balanced_jokes = []
    
    for category in jokes_df['category'].unique():
        cat_jokes = jokes_df[jokes_df['category'] == category]
        
        # Take at most max_per_category jokes per category
        if len(cat_jokes) > max_per_category:
            cat_jokes = cat_jokes.sample(max_per_category, random_state=42)
        
        balanced_jokes.append(cat_jokes)
        print(f"{category}: {len(cat_jokes)} jokes (was {len(jokes_df[jokes_df['category'] == category])})")
    
    balanced_df = pd.concat(balanced_jokes, ignore_index=True)
    print(f"\nTotal: {len(balanced_df)} jokes (was {len(jokes_df)})")
    
    return balanced_df

def clean_categories(jokes_df):
    """Merge similar categories"""
    
    print("\nCLEANING CATEGORIES:")
    print("=" * 20)
    
    # Fix the obvious duplicates
    jokes_df = jokes_df.copy()
    jokes_df['category'] = jokes_df['category'].replace({
        'programming': 'Programming',  # Merge case variants
        'general': 'Misc'              # Merge similar categories
    })
    
    print("Category counts after cleaning:")
    print(jokes_df['category'].value_counts())
    
    return jokes_df

# Apply the fixes
balanced_df = balance_dataset(jokes_df, max_per_category=40)
balanced_df = clean_categories(balanced_df)

BALANCING DATASET:
Programming: 39 jokes (was 39)
Dark: 10 jokes (was 10)
Pun: 11 jokes (was 11)
Misc: 13 jokes (was 13)
Dad Joke: 40 jokes (was 552)
programming: 27 jokes (was 27)
general: 40 jokes (was 65)

Total: 180 jokes (was 717)

CLEANING CATEGORIES:
Category counts after cleaning:
category
Programming    66
Misc           53
Dad Joke       40
Pun            11
Dark           10
Name: count, dtype: int64


In [4]:
graph_builder = JokeGraphBuilder(balanced_df)
node_features, edges = graph_builder.build_complete_graph()

Generating node features...
Creating content similarity edges...
Creating category edges...
Creating source edges...
Creating keyword/entity edges...
Creating topic modeling edges...
Graph built with 180 nodes and 6455 edges


In [None]:
def convert_to_pytorch_geometric(node_features, edges):
    """Convert with extensive debugging"""
    
    # Create node mapping
    node_ids = list(node_features.keys())
    node_to_idx = {node_id: idx for idx, node_id in enumerate(node_ids)}
    print(f"Number of nodes: {len(node_ids)}")
    print(f"Sample node IDs: {node_ids[:5]}")
    
    # Extract node features (combine embedding + other features)
    x = []
    for node_id in node_ids:
        features = node_features[node_id]
        
        # Start with embedding
        feature_vector = features['embedding'].copy()
        
        # Add other numerical features
        feature_vector.extend([
            features.get('length', 0),
            features.get('word_count', 0),
            features.get('sentence_count', 0),
            features.get('question_count', 0),
            features.get('exclamation_count', 0),
            float(features.get('has_dialogue', 0)),
        ])
        
        x.append(feature_vector)
    
    x = torch.tensor(x, dtype=torch.float)
    print(f"Node features tensor shape: {x.shape}")
    
    # Group edges by type
    edge_types_dict = {}
    for edge in edges:
        edge_type = edge['edge_type']
        if edge_type not in edge_types_dict:
            edge_types_dict[edge_type] = {'source': [], 'target': [], 'weight': []}
        
        # Check if nodes exist in mapping
        if edge['source'] not in node_to_idx:
            print(f"Skipping edge: source node {edge['source']} not found")
            continue
        if edge['target'] not in node_to_idx:
            print(f"Skipping edge: target node {edge['target']} not found")
            continue
            
        src_idx = node_to_idx[edge['source']]
        tgt_idx = node_to_idx[edge['target']]
        
        # Add both directions for undirected graph
        edge_types_dict[edge_type]['source'].extend([src_idx, tgt_idx])
        edge_types_dict[edge_type]['target'].extend([tgt_idx, src_idx])
        edge_types_dict[edge_type]['weight'].extend([edge['weight'], edge['weight']])
    
    print(f"Edge types processed: {list(edge_types_dict.keys())}")
    for edge_type, edge_data in edge_types_dict.items():
        print(f"  {edge_type}: {len(edge_data['source'])} directed edges")
    
    # Create HeteroData
    data = HeteroData()
    data['joke'].x = x
    
    for edge_type, edge_data in edge_types_dict.items():
        if len(edge_data['source']) == 0:
            print(f"⚠️ No valid edges for type: {edge_type}")
            continue
            
        edge_index = torch.tensor([edge_data['source'], edge_data['target']], dtype=torch.long)
        edge_attr = torch.tensor(edge_data['weight'], dtype=torch.float)
        
        print(f"Creating edge type '{edge_type}': edge_index shape {edge_index.shape}")
        
        data['joke', edge_type, 'joke'].edge_index = edge_index
        data['joke', edge_type, 'joke'].edge_attr = edge_attr
    
    return data

In [None]:
data = convert_to_pytorch_geometric(node_features, edges)

=== CONVERSION DEBUG ===
Number of nodes: 180
Sample node IDs: ['23', '54', '36', '40', '28']
Node features tensor shape: torch.Size([180, 390])
Edge types processed: ['content_similarity_sentence_bert', 'same_category', 'related_category', 'same_source', 'shared_keywords', 'shared_entities', 'topic_similarity']
  content_similarity_sentence_bert: 208 directed edges
  same_category: 5170 directed edges
  related_category: 880 directed edges
  same_source: 1710 directed edges
  shared_keywords: 266 directed edges
  shared_entities: 182 directed edges
  topic_similarity: 4494 directed edges
Creating edge type 'content_similarity_sentence_bert': edge_index shape torch.Size([2, 208])
Creating edge type 'same_category': edge_index shape torch.Size([2, 5170])
Creating edge type 'related_category': edge_index shape torch.Size([2, 880])
Creating edge type 'same_source': edge_index shape torch.Size([2, 1710])
Creating edge type 'shared_keywords': edge_index shape torch.Size([2, 266])
Creating e

In [7]:
from torch_geometric.nn import HGTConv, Linear
import torch.nn as nn

class JokeRecommendationHGT(torch.nn.Module):
    def __init__(self, input_dim, hidden_channels=128, embedding_dim=64, 
                 num_heads=4, num_layers=2, data=None):
        super().__init__()
        
        self.embedding_dim = embedding_dim
        
        # Node type projections (just 'joke' in your case)
        self.lin_dict = torch.nn.ModuleDict()
        for node_type in data.node_types:
            self.lin_dict[node_type] = Linear(input_dim, hidden_channels)
        
        # HGT layers to process different edge types
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HGTConv(
                hidden_channels, 
                hidden_channels, 
                data.metadata(),
                num_heads, 
                # group='sum'
            )
            self.convs.append(conv)
        
        # Final embedding projection
        self.embedding_proj = Linear(hidden_channels, embedding_dim)
        
        # Optional: Add a classifier head for categories
        self.category_classifier = nn.Sequential(
            Linear(embedding_dim, hidden_channels // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            Linear(hidden_channels // 2, 7)  # 7 categories in your data
        )
        
    def forward(self, x_dict, edge_index_dict, return_embeddings=True):
        # Initial transformation
        for node_type, x in x_dict.items():
            x_dict[node_type] = F.relu(self.lin_dict[node_type](x))
        
        # Apply HGT convolutions
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
        
        # Get final joke embeddings
        joke_embeddings = self.embedding_proj(x_dict['joke'])
        
        if return_embeddings:
            return joke_embeddings
        else:
            # Return category predictions
            return self.category_classifier(joke_embeddings)
    
    def get_joke_similarities(self, x_dict, edge_index_dict, joke_indices=None):
        """Get similarity scores between jokes"""
        embeddings = self.forward(x_dict, edge_index_dict, return_embeddings=True)
        
        if joke_indices is not None:
            embeddings = embeddings[joke_indices]
        
        # Compute pairwise similarities
        similarities = torch.mm(embeddings, embeddings.t())
        return similarities
    
    def recommend_jokes(self, x_dict, edge_index_dict, source_joke_idx, top_k=5):
        """Recommend top-k similar jokes to a source joke"""
        embeddings = self.forward(x_dict, edge_index_dict, return_embeddings=True)
        
        # Get source joke embedding
        source_embedding = embeddings[source_joke_idx].unsqueeze(0)
        
        # Compute similarities with all jokes
        similarities = torch.mm(source_embedding, embeddings.t()).squeeze()
        
        # Get top-k most similar (excluding the source joke itself)
        similarities[source_joke_idx] = -float('inf')  # Exclude self
        top_k_indices = similarities.topk(top_k).indices
        top_k_scores = similarities.topk(top_k).values
        
        return top_k_indices, top_k_scores

In [9]:
def simple_split(data, train_ratio=0.7, val_ratio=0.15):
    """Dead simple splitting - no fancy PyG needed"""
    
    # Just use your full graph for all splits (small graph = no memory issues)
    train_data = data
    val_data = data  
    test_data = data
    
    # Create random masks for evaluation
    num_nodes = data['joke'].x.shape[0]
    indices = torch.randperm(num_nodes)
    
    train_end = int(train_ratio * num_nodes)
    val_end = int((train_ratio + val_ratio) * num_nodes)
    
    train_mask = torch.zeros(num_nodes, dtype=torch.bool)
    val_mask = torch.zeros(num_nodes, dtype=torch.bool)
    test_mask = torch.zeros(num_nodes, dtype=torch.bool)
    
    train_mask[indices[:train_end]] = True
    val_mask[indices[train_end:val_end]] = True
    test_mask[indices[val_end:]] = True
    
    return train_data, val_data, test_data, train_mask, val_mask, test_mask

# 2. Simple training loop
def simple_train(model, data, train_mask, val_mask, num_epochs=100):
    """Simple full-batch training - no loaders needed"""
    
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    
    for epoch in range(num_epochs):
        model.train()
        optimizer.zero_grad()
        
        # Forward pass on ENTIRE graph (it's small!)
        embeddings = model(data.x_dict, data.edge_index_dict)
        
        # Simple contrastive loss on training nodes
        train_embeddings = embeddings[train_mask]
        loss = simple_contrastive_loss(train_embeddings)
        
        loss.backward()
        optimizer.step()
        
        # Validation every 10 epochs
        if epoch % 10 == 0:
            model.eval()
            with torch.no_grad():
                val_embeddings = embeddings[val_mask]
                val_loss = simple_contrastive_loss(val_embeddings)
            
            print(f"Epoch {epoch}: Train={loss:.4f}, Val={val_loss:.4f}")

def simple_contrastive_loss(embeddings):
    """Super simple loss - just make embeddings diverse"""
    # Random positive and negative pairs
    n = embeddings.shape[0]
    if n < 2:
        return torch.tensor(0.0)
    
    # Sample some pairs
    num_pairs = min(100, n//2)
    idx1 = torch.randint(0, n, (num_pairs,))
    idx2 = torch.randint(0, n, (num_pairs,))
    
    # Simple distance loss
    distances = F.pairwise_distance(embeddings[idx1], embeddings[idx2])
    return distances.mean()

In [10]:
def create_proper_evaluation(data):
    """Create proper train/val/test edges for recommendation evaluation"""
    
    # Get actual edges from your graph (jokes that are connected)
    main_edge_type = list(data.edge_types)[0]  # Use your main edge type
    edge_index = data[main_edge_type].edge_index
    
    # Convert to edge list and remove duplicates
    edges = edge_index.t().numpy()
    unique_edges = []
    seen = set()
    
    for edge in edges:
        edge_tuple = tuple(sorted([edge[0], edge[1]]))
        if edge_tuple not in seen and edge[0] != edge[1]:
            seen.add(edge_tuple)
            unique_edges.append([edge[0], edge[1]])
    
    unique_edges = np.array(unique_edges)
    print(f"Total unique edges: {len(unique_edges)}")
    
    # Split edges into train/val/test
    train_size = int(0.7 * len(unique_edges))
    val_size = int(0.15 * len(unique_edges))
    
    train_edges = unique_edges[:train_size]
    val_edges = unique_edges[train_size:train_size + val_size]
    test_edges = unique_edges[train_size + val_size:]
    
    print(f"Train edges: {len(train_edges)}, Val edges: {len(val_edges)}, Test edges: {len(test_edges)}")
    
    return train_edges, val_edges, test_edges

def create_negative_edges(positive_edges, num_nodes, num_negatives):
    """Create negative edge samples"""
    existing_edges = set()
    for edge in positive_edges:
        existing_edges.add((edge[0], edge[1]))
        existing_edges.add((edge[1], edge[0]))
    
    negative_edges = []
    while len(negative_edges) < num_negatives:
        src = np.random.randint(0, num_nodes)
        dst = np.random.randint(0, num_nodes)
        
        if src != dst and (src, dst) not in existing_edges:
            negative_edges.append([src, dst])
            existing_edges.add((src, dst))
            existing_edges.add((dst, src))
    
    return np.array(negative_edges)

def evaluate_link_prediction(model, data, test_edges, test_negatives):
    """Properly evaluate link prediction performance"""
    model.eval()
    
    with torch.no_grad():
        # Get embeddings
        embeddings = model(data.x_dict, data.edge_index_dict)
        embeddings = F.normalize(embeddings, p=2, dim=1)
        
        # Positive predictions
        pos_scores = []
        for edge in test_edges:
            src_emb = embeddings[edge[0]]
            dst_emb = embeddings[edge[1]]
            score = torch.dot(src_emb, dst_emb).item()
            pos_scores.append(score)
        
        # Negative predictions  
        neg_scores = []
        for edge in test_negatives:
            src_emb = embeddings[edge[0]]
            dst_emb = embeddings[edge[1]]
            score = torch.dot(src_emb, dst_emb).item()
            neg_scores.append(score)
        
        # Combine for evaluation
        all_scores = pos_scores + neg_scores
        all_labels = [1] * len(pos_scores) + [0] * len(neg_scores)
        
        # Calculate AUC
        from sklearn.metrics import roc_auc_score
        auc = roc_auc_score(all_labels, all_scores)
        
        # Calculate accuracy with threshold 0
        pred_labels = [1 if score > 0 else 0 for score in all_scores]
        accuracy = sum([p == l for p, l in zip(pred_labels, all_labels)]) / len(all_labels)
        
        return {
            'auc': auc,
            'accuracy': accuracy,
            'pos_score_mean': np.mean(pos_scores),
            'neg_score_mean': np.mean(neg_scores)
        }

def proper_train_with_link_prediction(model, data, train_edges, val_edges, num_epochs=100):
    """Train with proper link prediction objective"""
    
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    num_nodes = data['joke'].x.shape[0]
    
    # Create validation negatives once
    val_negatives = create_negative_edges(val_edges, num_nodes, len(val_edges))
    
    best_val_auc = 0
    patience = 0
    
    for epoch in range(num_epochs):
        model.train()
        
        # Sample training batch
        batch_size = min(256, len(train_edges))
        batch_indices = np.random.choice(len(train_edges), batch_size, replace=False)
        batch_edges = train_edges[batch_indices]
        
        # Create negative edges for this batch
        batch_negatives = create_negative_edges(batch_edges, num_nodes, len(batch_edges))
        
        # Combine positive and negative edges
        all_edges = np.vstack([batch_edges, batch_negatives])
        labels = torch.cat([
            torch.ones(len(batch_edges)),
            torch.zeros(len(batch_negatives))
        ]).float()
        
        optimizer.zero_grad()
        
        # Forward pass
        embeddings = model(data.x_dict, data.edge_index_dict)
        
        # Compute edge predictions
        edge_scores = []
        for edge in all_edges:
            src_emb = embeddings[edge[0]]
            dst_emb = embeddings[edge[1]]
            score = torch.dot(src_emb, dst_emb)
            edge_scores.append(score)
        
        edge_scores = torch.stack(edge_scores)
        
        # Binary cross-entropy loss (this is what you suggested!)
        loss = F.binary_cross_entropy_with_logits(edge_scores, labels)
        
        loss.backward()
        optimizer.step()
        
        # Validation every 10 epochs
        if epoch % 10 == 0:
            val_metrics = evaluate_link_prediction(model, data, val_edges, val_negatives)
            
            print(f"Epoch {epoch:3d} | "
                  f"Loss: {loss:.4f} | "
                  f"Val AUC: {val_metrics['auc']:.4f} | "
                  f"Val Acc: {val_metrics['accuracy']:.4f} | "
                  f"Pos: {val_metrics['pos_score_mean']:.3f} | "
                  f"Neg: {val_metrics['neg_score_mean']:.3f}")
            
            # Early stopping
            if val_metrics['auc'] > best_val_auc:
                best_val_auc = val_metrics['auc']
                patience = 0
            else:
                patience += 1
                
            if patience >= 10:
                print("Early stopping!")
                break
    
    return model

In [None]:
train_edges, val_edges, test_edges = create_proper_evaluation(data)

print(f"Training edges: {len(train_edges)}") 
print(f"Validation edges: {len(val_edges)}")
print(f"Test edges: {len(test_edges)}")

# Initialize model
model = JokeRecommendationHGT(
    input_dim=data['joke'].x.shape[1],
    hidden_channels=128,
    embedding_dim=64,
    num_heads=4,
    num_layers=2,
    data=data
)

# Train with the improved data
model = proper_train_with_link_prediction(model, data, train_edges, val_edges)

# Final evaluation
num_nodes = data['joke'].x.shape[0]
test_negatives = create_negative_edges(test_edges, num_nodes, len(test_edges))

print(f"AUC: {test_metrics['auc']:.4f}")
print(f"Accuracy: {test_metrics['accuracy']:.4f}")  
print(f"Positive scores (mean): {test_metrics['pos_score_mean']:.3f}")
print(f"Negative scores (mean): {test_metrics['neg_score_mean']:.3f}")

✅ DATA QUALITY LOOKS GOOD - PROCEEDING WITH TRAINING
Total unique edges: 104
Train edges: 72, Val edges: 15, Test edges: 17
Training edges: 72
Validation edges: 15
Test edges: 17

🔥 TRAINING WITH IMPROVED DATA:
Epoch   0 | Loss: 0.6959 | Val AUC: 0.7822 | Val Acc: 0.5000 | Pos: 0.989 | Neg: 0.944
Epoch  10 | Loss: 0.5916 | Val AUC: 0.6889 | Val Acc: 0.7333 | Pos: 0.847 | Neg: 0.245
Epoch  20 | Loss: 0.4579 | Val AUC: 0.6978 | Val Acc: 0.7000 | Pos: 0.797 | Neg: 0.206
Epoch  30 | Loss: 0.4527 | Val AUC: 0.6578 | Val Acc: 0.6667 | Pos: 0.722 | Neg: 0.059
Epoch  40 | Loss: 0.4282 | Val AUC: 0.6089 | Val Acc: 0.6333 | Pos: 0.671 | Neg: 0.165
Epoch  50 | Loss: 0.4466 | Val AUC: 0.6800 | Val Acc: 0.6000 | Pos: 0.693 | Neg: 0.241
Epoch  60 | Loss: 0.4042 | Val AUC: 0.6133 | Val Acc: 0.6000 | Pos: 0.431 | Neg: 0.200
Epoch  70 | Loss: 0.4420 | Val AUC: 0.7289 | Val Acc: 0.6000 | Pos: 0.707 | Neg: 0.301
Epoch  80 | Loss: 0.4551 | Val AUC: 0.6978 | Val Acc: 0.6000 | Pos: 0.725 | Neg: 0.327
Epoch 

In [13]:
def test_recommendation_quality(model, data, balanced_df):
    """Test if recommendations are actually meaningful"""
    
    model.eval()
    with torch.no_grad():
        embeddings = model(data.x_dict, data.edge_index_dict)
        embeddings = F.normalize(embeddings, p=2, dim=1)
    
        # Test different categories
        test_categories = ['Programming', 'Dad Joke', 'Misc']
        
        for category in test_categories:
            cat_jokes = balanced_df[balanced_df['category'] == category]
            if len(cat_jokes) > 0:
                # Pick random joke from category
                test_joke = cat_jokes.sample(1).iloc[0]
                joke_idx = balanced_df[balanced_df['joke_id'] == test_joke['joke_id']].index[0]
                
                # Get recommendations
                similarities = torch.mm(embeddings[joke_idx:joke_idx+1], embeddings.t()).squeeze()
                top_5_indices = similarities.topk(6)[1][1:]  # Skip self
                
                print(f"source ({category}):")
                print(f"   {test_joke['text'][:100]}...")
                
                print(f"top 5 recommendations:")
                for i, idx in enumerate(top_5_indices):
                    rec_joke = balanced_df.iloc[idx.item()]
                    score = similarities[idx].item()
                    print(f"   {i+1}. ({rec_joke['category']}) [{score:.3f}] {rec_joke['text'][:80]}...")
                
                # Check category consistency
                rec_categories = [balanced_df.iloc[idx.item()]['category'] for idx in top_5_indices]
                same_category_count = sum([1 for cat in rec_categories if cat == category])
                print(f"{same_category_count}/5 recommendations from same category")

In [14]:
test_recommendation_quality(model, data, balanced_df)

source (Programming):
   Why did the programmer's wife leave him? He didn't know how to commit....
top 5 recommendations:
   1. (Programming) [0.999] Why did the developer go broke buying Bitcoin? He kept calling it bytecoin and d...
   2. (Programming) [0.999] What goes after USA? USB....
   3. (Programming) [0.999] What is the most used language in programming? Profanity....
   4. (Programming) [0.999] Why did the developer go to therapy? They had too many unresolved issues....
   5. (Programming) [0.999] What do you get when you cross a React developer with a mathematician? A functio...
5/5 recommendations from same category
source (Dad Joke):
   I cut my finger cutting cheese. I know it may be a cheesy story but I feel grate now....
top 5 recommendations:
   1. (Misc) [1.000] My parents raised me as an only child, which really annoyed my younger brother....
   2. (Dad Joke) [0.999] What is worse then finding a worm in your Apple? Finding half a worm in your App...
   3. (Dad Joke) 