In [None]:
!pip install torch torchvision torchaudio
!pip install transformers datasets sentence-transformers
!pip install scikit_learn
!pip install -q onnx onnxruntime

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
from sklearn.model_selection import train_test_split
from google.colab import files
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import json
import random
from datasets import load_dataset
import torch.onnx
import onnx
import os

In [None]:
# Choose the compute device once for the whole notebook: CUDA when PyTorch sees a GPU, otherwise CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

In [None]:
class SentencePairDataset(Dataset):
    """Map-style dataset that yields tokenized sentence pairs and a float target.

    Each sentence of a pair is tokenized on its own (padded and truncated to
    ``max_length``), so the two towers of the Siamese model get separate inputs.

    Args:
        sentence1_list: first sentence of every pair (indexable).
        sentence2_list: second sentence of every pair (indexable).
        labels: similarity target per pair; its length defines ``len(self)``.
        tokenizer: callable accepting ``max_length``, ``padding``, ``truncation``
            and ``return_tensors`` keywords and returning a mapping with
            ``'input_ids'`` and ``'attention_mask'`` tensors (e.g. a Hugging
            Face tokenizer).
        max_length: pad/truncate length applied to both sentences.
    """

    def __init__(self, sentence1_list, sentence2_list, labels, tokenizer, max_length=128):
        self.sentence1, self.sentence2 = sentence1_list, sentence2_list
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.labels)

    def _encode(self, text):
        """Tokenize one sentence and strip the leading batch dimension of size 1."""
        encoded = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        return encoded['input_ids'].squeeze(0), encoded['attention_mask'].squeeze(0)

    def __getitem__(self, idx):
        ids_first, mask_first = self._encode(self.sentence1[idx])
        ids_second, mask_second = self._encode(self.sentence2[idx])
        target = torch.tensor(self.labels[idx], dtype=torch.float32)
        return {
            'input_ids_1': ids_first,
            'attention_mask_1': mask_first,
            'input_ids_2': ids_second,
            'attention_mask_2': mask_second,
            'label': target,
        }

In [None]:
def load_combined_dataset():
    """Early draft loader: STS-B, a 100K-row Quora slice and the synthetic CS pairs.

    NOTE(review): a later cell in this notebook redefines
    ``load_combined_dataset``. When the notebook runs top to bottom, that later
    definition shadows this one, so this version is only reached if cells run
    out of order. Consider deleting this cell.

    Returns:
        ``(all_sent1, all_sent2, all_labels)`` as three equal-length lists.
        - STS scores are divided by 5.0 (0-5 -> 0-1).
        - Quora duplicates map to 1.0 and non-duplicates to 0.0.
        - CS labels are passed through unchanged from ``generate_cs_dataset()``.
    """
    from datasets import load_dataset

    print("Loading datasets...")

    sts = load_dataset("mteb/stsbenchmark-sts")

    # generate_cs_dataset() is defined in a later cell and returns a
    # (sent1, sent2, labels) tuple, not a dict.
    cs_sent1, cs_sent2, cs_labels = generate_cs_dataset()

    quora = load_dataset("quora", split="train[:100000]")  # 100K samples

    all_sent1 = []
    all_sent2 = []
    all_labels = []

    # Process STS
    for split in ['train', 'validation']:
        for item in sts[split]:
            all_sent1.append(item['sentence1'])
            all_sent2.append(item['sentence2'])
            all_labels.append(item['score'] / 5.0)  # Normalize to 0-1

    # Process Quora; rows without a duplicate label carry no target, so skip them.
    for item in quora:
        if item['is_duplicate'] is None:
            continue
        texts = item['questions']['text']
        all_sent1.append(texts[0])
        all_sent2.append(texts[1])
        all_labels.append(1.0 if item['is_duplicate'] else 0.0)

    # Add CS data
    all_sent1.extend(cs_sent1)
    all_sent2.extend(cs_sent2)
    all_labels.extend(cs_labels)

    print(f"Total training pairs: {len(all_labels):,}")

    return all_sent1, all_sent2, all_labels

In [None]:
def generate_cs_dataset(seed=None):
    """Generate hand-written CS concept pairs plus sampled cross-concept negatives.

    Args:
        seed: optional seed for the negative-pair sampler. ``None`` (the default)
            draws from the global ``random`` module state, as before.

    Returns:
        ``(sent1, sent2, labels)`` as three equal-length lists.
        - The hand-written pairs come first, with their hand-assigned scores.
        - The synthesized negatives follow, each labelled ``0.1``.
    """

    concepts = {
        # Programming Fundamentals
        "variable": [
            ("What is a variable?", "A container that stores data values", 1.0),
            ("What is a variable?", "A named memory location for storing values", 0.95),
            ("What is a variable?", "Something that holds information in a program", 0.85),
            ("What is a variable?", "A storage unit for data", 0.9),
            ("Explain variable", "A container that stores data values", 0.95),
            ("Define variable in programming", "A named memory location for storing values", 0.9),
            ("What is a variable?", "A function", 0.2),
            ("What is a variable?", "A loop", 0.1),
        ],

        "function": [
            ("What is a function?", "A reusable block of code that performs a specific task", 1.0),
            ("What is a function?", "A named section of code that can be called", 0.95),
            ("What is a function?", "A way to organize code into reusable pieces", 0.9),
            ("Explain function", "A reusable block of code that performs a specific task", 0.95),
            ("What is a function?", "A variable", 0.2),
            ("What is a function?", "A data type", 0.1),
        ],

        "loop": [
            ("What is a for loop?", "A control structure that repeats code a specific number of times", 1.0),
            ("What is a for loop?", "Iteration over a sequence of elements", 0.9),
            ("What is a for loop?", "A way to repeat instructions multiple times", 0.85),
            ("Explain for loop", "A control structure that repeats code a specific number of times", 0.95),
            ("What is a for loop?", "A conditional statement", 0.2),
            ("What is a for loop?", "A function", 0.15),
        ],

        "while_loop": [
            ("What is a while loop?", "A loop that repeats as long as a condition is true", 1.0),
            ("What is a while loop?", "A control structure that continues until a condition becomes false", 0.95),
            ("What is a while loop?", "Iteration based on a boolean condition", 0.9),
            ("What is a while loop?", "A for loop", 0.3),
        ],

        # Data Structures
        "array": [
            ("What is an array?", "A data structure that stores elements in contiguous memory", 1.0),
            ("What is an array?", "A collection of elements of the same type", 0.95),
            ("What is an array?", "An indexed list of items", 0.9),
            ("What is an array?", "A fixed-size sequential collection", 0.9),
            ("Explain array", "A data structure that stores elements in contiguous memory", 0.95),
            ("What is an array?", "A variable", 0.2),
            ("What is an array?", "A function", 0.1),
        ],

        "linked_list": [
            ("What is a linked list?", "A linear data structure with nodes containing data and pointers", 1.0),
            ("What is a linked list?", "A sequence where each element points to the next", 0.95),
            ("What is a linked list?", "A chain of nodes connected by references", 0.9),
            ("Explain linked list", "A linear data structure with nodes containing data and pointers", 0.95),
            ("What is a linked list?", "An array", 0.3),
            ("What is a linked list?", "A tree", 0.4),
        ],

        "stack": [
            ("What is a stack?", "A LIFO data structure where elements are added and removed from the top", 1.0),
            ("What is a stack?", "A last-in-first-out collection", 0.95),
            ("What is a stack?", "A data structure with push and pop operations", 0.9),
            ("Explain stack", "A LIFO data structure where elements are added and removed from the top", 0.95),
            ("What is a stack?", "A queue", 0.3),
            ("What is a stack?", "An array", 0.4),
        ],

        "queue": [
            ("What is a queue?", "A FIFO data structure where elements are added at the rear and removed from the front", 1.0),
            ("What is a queue?", "A first-in-first-out collection", 0.95),
            ("What is a queue?", "A data structure with enqueue and dequeue operations", 0.9),
            ("What is a queue?", "A stack", 0.3),
        ],

        "hash_table": [
            ("What is a hash table?", "A data structure that maps keys to values using a hash function", 1.0),
            ("What is a hash table?", "A collection using key-value pairs with O(1) lookup", 0.95),
            ("What is a hash table?", "A dictionary implementation using hashing", 0.9),
            ("Explain hash table", "A data structure that maps keys to values using a hash function", 0.95),
            ("What is a hash table?", "An array", 0.3),
        ]
    }

    sent1, sent2, labels = [], [], []
    concept_of = []  # concept name of each hand-written pair, parallel to sent1

    for concept_name, pairs in concepts.items():
        for s1, s2, label in pairs:
            sent1.append(s1)
            sent2.append(s2)
            labels.append(label)
            concept_of.append(concept_name)

    rng = random if seed is None else random.Random(seed)

    # Generate additional negative pairs: the question of one positive pair
    # combined with the answer of a positive pair from a *different* concept.
    # Two pairs of the same concept must not be combined: they would turn a
    # correct definition into a 0.1 "negative" (e.g. "What is a variable?" /
    # "A named memory location for storing values").
    # NOTE(review): closely related concepts (for vs. while loop, stack vs.
    # queue) still yield 0.1 negatives.
    print("Generating negative pairs...")
    num_base = len(sent1)  # sample only from hand-written pairs, not appended negatives
    for _ in range(num_base):
        i, j = rng.sample(range(num_base), 2)
        if labels[i] < 0.5 or labels[j] < 0.5:
            continue
        if concept_of[i] == concept_of[j]:
            continue
        sent1.append(sent1[i])
        sent2.append(sent2[j])
        labels.append(0.1)

    return sent1, sent2, labels

In [None]:
def _load_sts_pairs():
    """Return STS Benchmark train+validation pairs as (s1, s2, score), with scores rescaled 0-5 -> 0-1."""
    sts = load_dataset("mteb/stsbenchmark-sts")
    return [
        (item['sentence1'], item['sentence2'], item['score'] / 5.0)
        for split in ('train', 'validation')
        for item in sts[split]
    ]


def _load_quora_pairs(num_samples):
    """Return labelled Quora question pairs from the first ``num_samples`` training rows.

    Rows whose ``is_duplicate`` is ``None`` are skipped. Duplicates score 1.0
    and non-duplicates score 0.0.
    """
    quora = load_dataset("quora", split=f"train[:{num_samples}]")
    pairs = []
    for item in quora:
        if item['is_duplicate'] is None:
            continue
        texts = item['questions']['text']
        pairs.append((texts[0], texts[1], 1.0 if item['is_duplicate'] else 0.0))
    return pairs


def _load_snli_pairs(num_samples):
    """Return SNLI premise/hypothesis pairs from the first ``num_samples`` training rows.

    Class ids map to soft similarity targets:
    - entailment (0) -> 0.9
    - neutral (1) -> 0.5
    - contradiction (2) -> 0.1

    Any other id, such as the -1 used for unlabeled rows, is skipped.
    """
    label_scores = {0: 0.9, 1: 0.5, 2: 0.1}
    snli = load_dataset("snli", split=f"train[:{num_samples}]")
    return [
        (item['premise'], item['hypothesis'], label_scores[item['label']])
        for item in snli
        if item['label'] in label_scores
    ]


def load_combined_dataset(cs_weight=3, quora_samples=50000, snli_samples=30000):
    """Load and combine the CS, STS-B, Quora and SNLI sentence-pair datasets.

    Each external source is loaded completely before anything is appended. A
    source that fails part-way (download error, malformed row) therefore
    contributes nothing, rather than leaving the three output lists with
    different lengths. Failures are printed and that source is skipped; this is
    deliberately best-effort loading.

    Args:
        cs_weight: how many times the synthetic CS pairs are repeated
            (oversampling). NOTE(review): the repeats are made before
            ``train_model`` splits train/validation, so identical CS pairs can
            land on both sides of the split and flatter the validation loss.
        quora_samples: number of leading Quora training rows to read.
        snli_samples: number of leading SNLI training rows to read.

    Returns:
        ``(all_sent1, all_sent2, all_labels)`` as equal-length lists.
    """

    print("=" * 50)
    print("LOADING DATASETS")
    print("=" * 50)

    all_sent1 = []
    all_sent2 = []
    all_labels = []

    def add_pairs(pairs):
        # Append (s1, s2, label) triples to the three parallel output lists.
        for s1, s2, label in pairs:
            all_sent1.append(s1)
            all_sent2.append(s2)
            all_labels.append(label)

    # 1. CS-specific dataset (highest priority, oversampled cs_weight times)
    print("\n1. Generating CS concept dataset...")
    cs_sent1, cs_sent2, cs_labels = generate_cs_dataset()
    cs_pairs = list(zip(cs_sent1, cs_sent2, cs_labels))
    for _ in range(cs_weight):
        add_pairs(cs_pairs)
    print(f"   ✓ CS pairs (weighted {cs_weight}x): {len(cs_pairs):,} → {len(cs_pairs) * cs_weight:,}")

    # 2-4. External datasets; each is either added in full or skipped.
    external_sources = [
        ("2. Loading STS Benchmark...", "STS", _load_sts_pairs),
        ("3. Loading Quora dataset...", "Quora", lambda: _load_quora_pairs(quora_samples)),
        ("4. Loading SNLI dataset...", "SNLI", lambda: _load_snli_pairs(snli_samples)),
    ]
    for banner, short_name, loader in external_sources:
        print("\n" + banner)
        try:
            pairs = loader()
        except Exception as e:
            print(f"   ⚠ Could not load {short_name}: {e}")
            continue
        add_pairs(pairs)
        # Report the number of pairs actually kept, not rows requested.
        print(f"   ✓ {short_name} pairs: {len(pairs):,}")

    print("\n" + "=" * 50)
    print(f"TOTAL TRAINING PAIRS: {len(all_labels):,}")
    print("=" * 50)

    return all_sent1, all_sent2, all_labels

# NOTE(review): this cell only defines the loaders; nothing is loaded or tested
# here. Data is actually loaded inside train_model().
print("Testing dataset loading...")

In [None]:
class SiameseTransformer(nn.Module):
    """Siamese network with transformer encoder"""

    def __init__(self, model_name='distilbert-base-uncased'):
        super().__init__()

        print(f"Loading base model: {model_name}")
        self.encoder = AutoModel.from_pretrained(model_name)
        hidden_size = self.encoder.config.hidden_size

        # Projection head (map to shared embedding space)
        self.projection = nn.Sequential(
            nn.Linear(hidden_size, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 128)
        )

        # Similarity predictor
        self.similarity_head = nn.Sequential(
            nn.Linear(128 * 2, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

    def forward(self, input_ids_1, attention_mask_1, input_ids_2, attention_mask_2):
        # Encode both sentences
        outputs1 = self.encoder(input_ids=input_ids_1, attention_mask=attention_mask_1)
        outputs2 = self.encoder(input_ids=input_ids_2, attention_mask=attention_mask_2)

        # Use [CLS] token representation
        emb1 = outputs1.last_hidden_state[:, 0, :]
        emb2 = outputs2.last_hidden_state[:, 0, :]

        # Project to shared space
        proj1 = self.projection(emb1)
        proj2 = self.projection(emb2)

        # Concatenate and predict similarity
        combined = torch.cat([proj1, proj2], dim=1)
        similarity = self.similarity_head(combined)

        return similarity

In [None]:
def _batch_to_device(batch, device):
    """Move one collated SentencePairDataset batch onto ``device``.

    Returns:
        ``(input_ids_1, attention_mask_1, input_ids_2, attention_mask_2, labels)``.
        ``labels`` gets a trailing dimension of size 1 (``unsqueeze(1)``) so it
        matches the model's ``(batch, 1)`` output for ``MSELoss``.
    """
    return (
        batch['input_ids_1'].to(device),
        batch['attention_mask_1'].to(device),
        batch['input_ids_2'].to(device),
        batch['attention_mask_2'].to(device),
        batch['label'].to(device).unsqueeze(1),
    )


def _plot_training_curves(train_losses, val_losses, path='training_curves.png'):
    """Plot per-epoch train/val loss on linear and log y-axes, save to ``path`` and show it."""
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    titles = ('Training Progress', 'Training Progress (Log Scale)')
    for ax, title in zip(axes, titles):
        ax.plot(train_losses, label='Train Loss', linewidth=2)
        ax.plot(val_losses, label='Val Loss', linewidth=2)
        ax.set_xlabel('Epoch', fontsize=12)
        ax.set_ylabel('Loss', fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3)
    axes[1].set_yscale('log')
    fig.tight_layout()
    fig.savefig(path, dpi=300, bbox_inches='tight')
    plt.show()


def train_model(num_epochs=100, batch_size=32, learning_rate=2e-5):
    """Train the Siamese transformer on the combined sentence-pair data.

    Steps:
    1. Load pairs via ``load_combined_dataset()`` and hold out 10% for
       validation (``random_state=42``).
    2. Fine-tune ``SiameseTransformer`` with AdamW (weight decay 0.01) and
       MSE loss, using a cosine warm-restart LR schedule with period 10 epochs.
    3. After each epoch, checkpoint the best-validation weights to
       ``best_siamese_model.pt``.
    4. Stop early after 15 epochs without validation improvement.

    Args:
        num_epochs: maximum number of epochs.
        batch_size: batch size for the train and validation loaders.
        learning_rate: AdamW base learning rate.

    Returns:
        ``(model, tokenizer)``:
        - ``model`` holds the best checkpointed weights, or the last-epoch
          weights if no checkpoint was ever written (e.g. NaN validation loss).
        - ``tokenizer`` is the ``distilbert-base-uncased`` tokenizer.

    Side effects:
        Writes ``best_siamese_model.pt`` and ``training_curves.png`` to the
        current working directory.
    """

    # Device setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\n{'='*50}")
    print("TRAINING CONFIGURATION")
    print(f"{'='*50}")
    print(f"Device: {device}")
    print(f"Epochs: {num_epochs}")
    print(f"Batch size: {batch_size}")
    print(f"Learning rate: {learning_rate}")
    print(f"{'='*50}\n")

    # Load data
    sent1, sent2, labels = load_combined_dataset()

    # Train/validation split.
    # NOTE(review): load_combined_dataset oversamples the CS pairs before this
    # split, so identical CS pairs can appear in both train and validation.
    train_s1, val_s1, train_s2, val_s2, train_labels, val_labels = train_test_split(
        sent1, sent2, labels, test_size=0.1, random_state=42
    )

    print(f"\nTraining samples: {len(train_labels):,}")
    print(f"Validation samples: {len(val_labels):,}")

    # Tokenizer
    print("\nLoading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')

    # Create datasets
    print("Creating dataloaders...")
    train_dataset = SentencePairDataset(train_s1, train_s2, train_labels, tokenizer)
    val_dataset = SentencePairDataset(val_s1, val_s2, val_labels, tokenizer)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    # Model
    print("\nInitializing model...")
    model = SiameseTransformer().to(device)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    # Optimizer
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)

    # Learning-rate scheduler. T_0 / T_mult are in the units passed to
    # scheduler.step(); the loop passes fractional epochs, so the first cosine
    # cycle lasts 10 epochs, the next 20, and so on.
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=10, T_mult=2
    )

    # Loss function
    criterion = nn.MSELoss()

    # Training tracking
    best_val_loss = float('inf')
    best_saved = False
    patience = 15
    patience_counter = 0
    train_losses = []
    val_losses = []

    print(f"\n{'='*50}")
    print("STARTING TRAINING")
    print(f"{'='*50}\n")

    num_batches = len(train_loader)

    for epoch in range(num_epochs):
        # TRAINING PHASE
        model.train()
        train_loss = 0.0
        train_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]")

        for step, batch in enumerate(train_bar):
            *inputs, labels_batch = _batch_to_device(batch, device)

            # Forward pass
            optimizer.zero_grad()
            outputs = model(*inputs)
            loss = criterion(outputs, labels_batch)

            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            # Documented per-batch usage of CosineAnnealingWarmRestarts. A bare
            # step() per batch would make T_0=10 mean 10 *batches* and restart
            # the LR schedule almost immediately.
            scheduler.step(epoch + step / num_batches)

            train_loss += loss.item()
            train_bar.set_postfix({'loss': f'{loss.item():.4f}', 'lr': f'{scheduler.get_last_lr()[0]:.2e}'})

        avg_train_loss = train_loss / len(train_loader)
        train_losses.append(avg_train_loss)

        # VALIDATION PHASE
        model.eval()
        val_loss = 0.0

        with torch.no_grad():
            val_bar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Val]  ")
            for batch in val_bar:
                *inputs, labels_batch = _batch_to_device(batch, device)
                outputs = model(*inputs)
                loss = criterion(outputs, labels_batch)
                val_loss += loss.item()
                val_bar.set_postfix({'loss': f'{loss.item():.4f}'})

        avg_val_loss = val_loss / len(val_loader)
        val_losses.append(avg_val_loss)

        # Print epoch summary
        print(f"\n{'─'*50}")
        print(f"Epoch {epoch+1}/{num_epochs} Summary:")
        print(f"  Train Loss: {avg_train_loss:.4f}")
        print(f"  Val Loss:   {avg_val_loss:.4f}")
        print(f"  LR:         {scheduler.get_last_lr()[0]:.2e}")

        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            torch.save(model.state_dict(), 'best_siamese_model.pt')
            best_saved = True
            print(f"  ✓ New best model saved! (Val Loss: {best_val_loss:.4f})")
        else:
            patience_counter += 1
            print(f"  No improvement ({patience_counter}/{patience})")

        print(f"{'─'*50}\n")

        # Early stopping
        if patience_counter >= patience:
            print(f"\n⚠ Early stopping triggered at epoch {epoch+1}")
            print(f"Best validation loss: {best_val_loss:.4f}")
            break

    if best_saved:
        # Return the best-validation weights rather than whatever the last epoch produced.
        model.load_state_dict(torch.load('best_siamese_model.pt', map_location=device))

    # Plot training curves
    _plot_training_curves(train_losses, val_losses)

    print(f"\n{'='*50}")
    print("TRAINING COMPLETE!")
    print(f"{'='*50}")
    print(f"Best validation loss: {best_val_loss:.4f}")
    print("Model saved as: best_siamese_model.pt")
    print("Training curves saved as: training_curves.png")

    return model, tokenizer


In [None]:
# Short demonstration run: only two epochs (train_model defaults to 100 with early stopping).
TRAIN_CONFIG = dict(num_epochs=2, batch_size=32, learning_rate=2e-5)
model, tokenizer = train_model(**TRAIN_CONFIG)


In [None]:
print("Saving tokenizer configuration...")

# Dump the vocabulary and special tokens to JSON next to the ONNX model. The
# consumer of this file is not visible here; presumably it re-tokenizes outside
# Python. max_length mirrors SentencePairDataset's default padding length.
vocab = tokenizer.get_vocab()
special_tokens = {
    attr: getattr(tokenizer, attr)
    for attr in ('pad_token', 'unk_token', 'cls_token', 'sep_token')
}
config = {'vocab': vocab, 'max_length': 128, **special_tokens}

with open('tokenizer_config.json', 'w') as handle:
    json.dump(config, handle, indent=2)

print("Tokenizer config saved as: tokenizer_config.json")
print(f"Vocabulary size: {len(vocab):,}")

In [None]:


print("Exporting model to ONNX format...")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the best checkpoint written by train_model. map_location lets this cell
# run even if the checkpoint was saved on a different device than the one
# available now.
model.load_state_dict(torch.load('best_siamese_model.pt', map_location=device))
model.eval()
model.to(device)

# Dummy inputs only fix shapes and dtypes for tracing; their values are arbitrary.
batch_size = 1
seq_length = 128  # matches SentencePairDataset's default padding length used in training
vocab_size = model.encoder.config.vocab_size  # read from the encoder instead of hard-coding 30522

# Ids and masks are int64 (torch.long), matching what the Hugging Face
# tokenizer returns with return_tensors='pt'. Float masks here would bake a
# float input type into the ONNX graph, so runtimes feeding tokenizer output
# would hit a type mismatch.
dummy_input_ids_1 = torch.randint(0, vocab_size, (batch_size, seq_length), device=device)
dummy_attention_mask_1 = torch.ones(batch_size, seq_length, dtype=torch.long, device=device)
dummy_input_ids_2 = torch.randint(0, vocab_size, (batch_size, seq_length), device=device)
dummy_attention_mask_2 = torch.ones(batch_size, seq_length, dtype=torch.long, device=device)

# Export (only the batch dimension is dynamic; sequences must be padded to seq_length)
torch.onnx.export(
    model,
    (dummy_input_ids_1, dummy_attention_mask_1, dummy_input_ids_2, dummy_attention_mask_2),
    'siamese_transformer.onnx',
    input_names=['input_ids_1', 'attention_mask_1', 'input_ids_2', 'attention_mask_2'],
    output_names=['similarity_score'],
    dynamic_axes={
        'input_ids_1': {0: 'batch_size'},
        'attention_mask_1': {0: 'batch_size'},
        'input_ids_2': {0: 'batch_size'},
        'attention_mask_2': {0: 'batch_size'},
        'similarity_score': {0: 'batch_size'}
    },
    opset_version=14
)

print("✓ Model exported as: siamese_transformer.onnx")

# Structural validation of the exported graph (does not check numerical parity).
onnx_model = onnx.load('siamese_transformer.onnx')
onnx.checker.check_model(onnx_model)
print("✓ ONNX model verified successfully")

# Check file size
import os
file_size_mb = os.path.getsize('siamese_transformer.onnx') / (1024 * 1024)
print(f"✓ Model size: {file_size_mb:.2f} MB")

In [None]:
# Report artifact sizes, list them, then trigger Colab browser downloads.
print("Preparing files for download...")
print("\nFile sizes:")
size_report = [
    ('siamese_transformer.onnx', 1024 * 1024, 'MB'),
    ('tokenizer_config.json', 1024, 'KB'),
]
for artifact_path, divisor, unit in size_report:
    print(f"  {artifact_path}: {os.path.getsize(artifact_path) / divisor:.2f} {unit}")

rule = "=" * 50
download_list = ['siamese_transformer.onnx', 'tokenizer_config.json']

print("\n" + rule)
print("DOWNLOAD THESE FILES:")
print(rule)
for position, artifact_path in enumerate(download_list, start=1):
    print(f"{position}. {artifact_path}")
print(rule)

for artifact_path in download_list:
    files.download(artifact_path)


print("\nAll files ready for download.")