# In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import torch
from transformers import AutoModel, AutoTokenizer
from sklearn.manifold import TSNE
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Plot appearance: apply the matplotlib style sheet, then the seaborn palette.
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

# ----------------------------------------
# Configuration
# ----------------------------------------
# Number of texts sent through a model per forward pass.
BATCH_SIZE = 32

# Run on the GPU when torch reports one, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")

# ----------------------------------------
# Model Paths
# ----------------------------------------
# Base checkpoint identifier, followed by the relative directories of the
# other checkpoints compared in this analysis.
BASE_MODEL_PATH = "microsoft/deberta-v3-small"
BGE_MODEL_PATH = "../models/DeBERTa/bge"
SBERT_MODEL_PATH = "../models/DeBERTa/sbert"
SIMCSE_MODEL_PATH = "../models/DeBERTa/simcse"
MLM_MODEL_PATH = "../models/DeBERTa/mlm"


# ----------------------------------------
# Loading Tokenizer and Models
# ----------------------------------------
def _load_encoder(model_path):
    """Load the checkpoint at *model_path*, move it to DEVICE and set eval mode."""
    return AutoModel.from_pretrained(model_path).to(DEVICE).eval()


print("Loading models...")
# A single tokenizer is shared by every model. It is loaded through
# BASE_MODEL_PATH so that it always follows the base checkpoint constant.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)

base_model = _load_encoder(BASE_MODEL_PATH)
bge_model = _load_encoder(BGE_MODEL_PATH)
sbert_model = _load_encoder(SBERT_MODEL_PATH)
simcse_model = _load_encoder(SIMCSE_MODEL_PATH)
mlm_model = _load_encoder(MLM_MODEL_PATH)

print("All models loaded successfully!")

# ----------------------------------------
# Load and Prepare the Datasets
# ----------------------------------------
def _normalized_label(labels, key):
    """Return ``labels[key]`` lower-cased and stripped, or "" when unusable.

    "" is returned when *labels* is not a dict, when *key* is absent, or when
    the stored value is not a string (for example a JSON ``null``), so such
    rows are filtered out below instead of raising ``AttributeError``.
    """
    if not isinstance(labels, dict):
        return ""
    value = labels.get(key)
    if not isinstance(value, str):
        return ""
    return value.lower().strip()


print("Loading port vessel dataset...")
base_df = pd.read_json("../data/email_datasets/synthetic/attrprompting/claude/aggregated/aggregated.json")

# Flatten the three entity fields of each row's `labels` value into columns.
base_df["vessel_label"] = base_df.labels.apply(lambda labels: _normalized_label(labels, "vessel"))
base_df["lp_label"] = base_df.labels.apply(lambda labels: _normalized_label(labels, "load_port"))
base_df["dp_label"] = base_df.labels.apply(lambda labels: _normalized_label(labels, "discharge_port"))

# Keep only rows where all three labels are present and non-empty.
base_df_clean = base_df[
    (base_df["vessel_label"] != "")
    & (base_df["lp_label"] != "")
    & (base_df["dp_label"] != "")
].copy()

print(f"Dataset shape after cleaning: {base_df_clean.shape}")
print(f"Unique vessels: {base_df_clean['vessel_label'].nunique()}")
print(f"Unique load ports: {base_df_clean['lp_label'].nunique()}")
print(f"Unique discharge ports: {base_df_clean['dp_label'].nunique()}")

# ----------------------------------------
# Embedding Generation
# ----------------------------------------
def get_embeddings_batch(texts, model, tokenizer, batch_size=32, pooling_method='mean', max_length=128):
    """Encode *texts* with *model* and return one pooled vector per text.

    Args:
        texts: Strings to embed. Any iterable is accepted; it is copied into a
            list before batching.
        model: Encoder called as ``model(**inputs)`` whose output exposes
            ``last_hidden_state``.
        tokenizer: Tokenizer called on each batch with padding, truncation and
            ``return_tensors="pt"``.
        batch_size: Number of texts per forward pass.
        pooling_method: ``'cls'`` takes the hidden state at position 0 of each
            sequence; any other value averages the token states weighted by
            the attention mask.
        max_length: Truncation length handed to the tokenizer.

    Returns:
        ``np.ndarray`` with one row per input text.

    Raises:
        ValueError: If *texts* is empty or *batch_size* is smaller than 1.
    """
    texts = list(texts)
    if not texts:
        raise ValueError("texts must contain at least one item")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Send inputs to wherever the model lives; fall back to the module-level
    # DEVICE for models that do not expose a `device` attribute.
    device = getattr(model, "device", DEVICE)

    embeddings = []
    for start in tqdm(range(0, len(texts), batch_size), desc="Generating embeddings"):
        inputs = tokenizer(
            texts[start:start + batch_size],
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt"
        ).to(device)

        with torch.no_grad():
            hidden = model(**inputs).last_hidden_state

            if pooling_method == 'cls':
                pooled = hidden[:, 0]
            else:
                # Masked mean: padding positions contribute nothing, and the
                # clamp guards against division by zero for an all-zero mask.
                mask = inputs['attention_mask'].unsqueeze(-1).float()
                summed = (hidden * mask).sum(dim=1)
                counts = mask.sum(dim=1).clamp(min=1e-9)
                pooled = summed / counts

        embeddings.append(pooled.cpu().numpy())

    return np.vstack(embeddings)

def prepare_combined_entity_dataset(df, min_frequency=2):
    """Collect frequently occurring vessel and port names with a category each.

    Vessel names are counted from ``df['vessel_label']``; port names are
    counted over ``df['lp_label']`` and ``df['dp_label']`` together. Names seen
    at least ``min_frequency`` times are kept, and a summary is printed.

    Returns:
        ``(entity_texts, entity_categories)``: two parallel lists with the
        vessels first (category ``'Vessel'``) followed by the ports
        (category ``'Port'``).
    """
    def names_meeting_threshold(counts):
        return counts[counts >= min_frequency].index.tolist()

    vessel_counts = df['vessel_label'].value_counts()
    port_counts = pd.Series(
        df['lp_label'].tolist() + df['dp_label'].tolist()
    ).value_counts()

    frequent_vessels = names_meeting_threshold(vessel_counts)
    frequent_ports = names_meeting_threshold(port_counts)

    entity_texts = [*frequent_vessels, *frequent_ports]
    entity_categories = (
        ['Vessel' for _ in frequent_vessels] + ['Port' for _ in frequent_ports]
    )

    summary_lines = [
        "Entity Analysis:",
        f"Total unique vessels: {len(vessel_counts)}",
        f"Frequent vessels (min_freq={min_frequency}): {len(frequent_vessels)}",
        f"Total unique ports: {len(port_counts)}",
        f"Frequent ports (min_freq={min_frequency}): {len(frequent_ports)}",
        f"Combined entities for analysis: {len(entity_texts)}",
        f"\nTop vessels by frequency: {dict(vessel_counts.head(10))}",
        f"Top ports by frequency: {dict(port_counts.head(10))}",
    ]
    for line in summary_lines:
        print(line)

    return entity_texts, entity_categories

def _padded_limits(values, pad_fraction=0.05):
    """Return (low, high) axis limits that enclose *values* with a margin on both sides."""
    low, high = float(values.min()), float(values.max())
    # Pad by a fraction of the span so the margin always widens the range,
    # whatever the sign of the extremes; fall back to 1.0 for a zero span.
    pad = (high - low) * pad_fraction or 1.0
    return low - pad, high + pad


def _draw_entity_panel(ax, coords, title, entity_texts, entity_categories,
                       category_colors, x_limits, y_limits, annotate):
    """Scatter one model's 2-D projection on *ax*, coloured by entity category."""
    categories = np.asarray(entity_categories)
    # One scatter call per category, in order of first appearance, yields
    # exactly one legend entry per category.
    for category in dict.fromkeys(entity_categories):
        selected = categories == category
        ax.scatter(coords[selected, 0], coords[selected, 1],
                   color=category_colors[category],
                   label=category,
                   alpha=0.7,
                   s=50)

    if annotate:
        for entity_text, (x, y) in zip(entity_texts, coords):
            ax.annotate(entity_text[:12], (x, y), fontsize=7, alpha=0.7)

    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xlabel('t-SNE 1')
    ax.set_ylabel('t-SNE 2')
    ax.set_xlim(x_limits)
    ax.set_ylim(y_limits)
    ax.grid(True, alpha=0.3)
    ax.legend()


def create_tsne_visualization(base_emb, bge_emb, sbert_emb, simcse_emb, mlm_emb, entity_texts, entity_categories):
    """Plot t-SNE projections of five embedding sets side by side and save the figure.

    Each embedding matrix is projected to 2-D with its own t-SNE fit, so
    positions are not comparable between panels; the panels only share axis
    limits. Points are coloured by category and, when there are at most 25
    entities, annotated with the first 12 characters of their text.

    Args:
        base_emb, bge_emb, sbert_emb, simcse_emb, mlm_emb: Embedding matrices
            with one row per entry of *entity_texts*.
        entity_texts: Entity names, in the same order as the embedding rows.
        entity_categories: Category of each entity (``'Vessel'`` or ``'Port'``).

    Returns:
        None. The figure is written to
        ``../output/embeddings_tsne/port_vessel_embeddings_with_mlm.png``
        (the directory is created if needed) and then shown. Nothing is
        plotted when fewer than three entities are given.
    """
    print(f"Computing t-SNE projections for maritime entities...")

    # t-SNE needs perplexity < n_samples; cap it for small inputs.
    perplexity = min(30, len(entity_texts) - 1)
    if perplexity < 2:
        print(f"Skipping t-SNE: too few entities")
        return

    panels = [
        ('DeBERTa-v3-small (Base)', base_emb),
        ('DeBERTa + BGE (bge-large-en-v1.5)', bge_emb),
        ('DeBERTa + SBERT (all-MiniLM-L6-v2)', sbert_emb),
        ('DeBERTa + SimCSE Finetuned', simcse_emb),
        ('DeBERTa + MLM (Maritime Domain)', mlm_emb),
    ]

    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42)
    projections = [tsne.fit_transform(embedding) for _, embedding in panels]

    # Shared limits across panels, computed over every projected point.
    stacked = np.vstack(projections)
    x_limits = _padded_limits(stacked[:, 0])
    y_limits = _padded_limits(stacked[:, 1])

    # 2x3 grid: five panels plus one unused cell.
    fig, axes = plt.subplots(2, 3, figsize=(24, 16), dpi=300)
    axes = axes.ravel()

    category_colors = {
        'Vessel': '#1f77b4',
        'Port': '#ff7f0e'
    }
    annotate = len(entity_texts) <= 25

    for ax, (title, _), coords in zip(axes, panels, projections):
        _draw_entity_panel(ax, coords, title, entity_texts, entity_categories,
                           category_colors, x_limits, y_limits, annotate)

    # Hide grid cells that have no panel.
    for ax in axes[len(panels):]:
        ax.axis('off')

    plt.tight_layout()
    plt.subplots_adjust(bottom=0.02)

    output_dir = "../output/embeddings_tsne"
    # savefig does not create missing directories.
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, "port_vessel_embeddings_with_mlm.png")
    plt.savefig(output_path, bbox_inches='tight', dpi=300)
    print(f"Visualization saved as: {output_path}")
    plt.show()

# ============================================================================
# MAIN ANALYSIS
# ============================================================================

banner = "=" * 80
print(banner)
print("MARITIME ENTITY EMBEDDINGS ANALYSIS")
print(banner)

entity_texts, entity_categories = prepare_combined_entity_dataset(base_df_clean, min_frequency=2)

if len(entity_texts) >= 4:
    print(f"\nGenerating embeddings for {len(entity_texts)} maritime entities...")

    # Each model is paired with the pooling method used to embed with it.
    encoder_plan = [
        (base_model, 'mean'),
        (bge_model, 'cls'),
        (sbert_model, 'cls'),
        (simcse_model, 'cls'),
        (mlm_model, 'mean'),
    ]

    # Embed the same entity list with every model, in the order listed above.
    (
        base_embeddings,
        bge_embeddings,
        sbert_embeddings,
        simcse_embeddings,
        mlm_embeddings,
    ) = [
        get_embeddings_batch(entity_texts, encoder, tokenizer, BATCH_SIZE, pooling_method=pooling)
        for encoder, pooling in encoder_plan
    ]

    print("Embeddings generated successfully!")

    create_tsne_visualization(
        base_embeddings,
        bge_embeddings,
        sbert_embeddings,
        simcse_embeddings,
        mlm_embeddings,
        entity_texts,
        entity_categories,
    )
else:
    print("Not enough entities for analysis")