In [1]:
import torch
from model_attention import Encoder, Decoder, Seq2Seq
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from data_loader import load_data, create_tokenizers, TranslationDataset
from torch.utils.data import DataLoader
import os

Using GPU: NVIDIA GeForce RTX 4060 Laptop GPU
GPU Memory: 8.00 GB


In [2]:
def plot_attention(attention, source, target, src_vocab, trg_vocab):
    # Convert attention weights to numpy array
    attention = attention.squeeze(1).cpu().numpy()
    
    # Get source and target tokens
    src_tokens = [list(src_vocab.keys())[list(src_vocab.values()).index(i)] for i in source if i not in [src_vocab['<sos>'], src_vocab['<eos>'], src_vocab['<pad>']]]
    trg_tokens = [list(trg_vocab.keys())[list(trg_vocab.values()).index(i)] for i in target if i not in [trg_vocab['<sos>'], trg_vocab['<eos>'], trg_vocab['<pad>']]]
    
    # Create figure and axis
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111)
    
    # Plot attention matrix
    cax = ax.matshow(attention, cmap='viridis')
    fig.colorbar(cax)
    
    # Set up axes
    ax.set_xticklabels([''] + src_tokens, rotation=90)
    ax.set_yticklabels([''] + trg_tokens)
    
    # Show every tick
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
    
    plt.title('Attention Map')
    plt.xlabel('Source Tokens')
    plt.ylabel('Target Tokens')
    plt.tight_layout()
    plt.savefig('attention_map.png')
    plt.close()


In [3]:
def evaluate_model(model, iterator, target_tokenizer, device):
    model.eval()
    translations = []
    references = []
    attention_maps = []
    
    with torch.no_grad():
        for batch in iterator:
            src = batch['source'].to(device)
            trg = batch['target'].to(device)
            
            output = model(src, trg, 0)  # Turn off teacher forcing
            
            # Get the predicted tokens
            pred_tokens = output.argmax(2)
            
            # Convert to text
            for i in range(len(pred_tokens)):
                # Get the tokens
                pred_seq = pred_tokens[i].cpu().numpy()
                ref_seq = trg[i].cpu().numpy()
                
                # Convert to text
                pred_text = target_tokenizer.decode(pred_seq)
                ref_text = target_tokenizer.decode(ref_seq)
                
                translations.append(pred_text)
                references.append(ref_text)
                
                # Store attention maps if available
                if hasattr(model, 'attention_weights'):
                    attention_maps.append((src[i].cpu().numpy(), pred_seq, model.attention_weights[i].cpu().numpy()))
    
    return translations, references, attention_maps

In [4]:
def calculate_bleu(references, translations):
    smoothie = SmoothingFunction().method1
    bleu_scores = []
    
    for ref, trans in zip(references, translations):
        score = sentence_bleu([ref.split()], trans.split(), smoothing_function=smoothie)
        bleu_scores.append(score)
    
    return np.mean(bleu_scores)

if __name__ == "__main__":
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Hyperparameters (must match training)
    HIDDEN_DIM = 256
    N_LAYERS = 2
    DROPOUT = 0.3
    VOCAB_SIZE = 8000  # Must match training vocabulary size
    BATCH_SIZE = 64
    
    # Load data
    train_df, val_df = load_data('english_assamese.csv')
    source_tokenizer, target_tokenizer = create_tokenizers(train_df, vocab_size=VOCAB_SIZE)
    
    # Get vocabulary sizes
    INPUT_DIM = source_tokenizer.get_piece_size()
    OUTPUT_DIM = target_tokenizer.get_piece_size()
    
    # Initialize model
    enc = Encoder(INPUT_DIM, HIDDEN_DIM, N_LAYERS, DROPOUT)
    dec = Decoder(OUTPUT_DIM, HIDDEN_DIM, N_LAYERS, DROPOUT)
    model = Seq2Seq(enc, dec)
    model = model.to(device)
    
    # Load model state
    checkpoint_path = os.path.join('checkpoints', 'best_model_attention.pth')
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, weights_only=True)
        model.load_state_dict(checkpoint['state_dict'])
        print("Model loaded successfully from checkpoint")
    else:
        raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}")
    
    # Create validation dataset and loader
    val_dataset = TranslationDataset(
        val_df['eng'].tolist(),
        val_df['asm'].tolist(),
        source_tokenizer,
        target_tokenizer
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        pin_memory=True if device.type == 'cuda' else False
    )
    
    # Evaluate model
    translations, references, attention_maps = evaluate_model(model, val_loader, target_tokenizer, device)
    
    # Calculate BLEU score
    bleu_score = calculate_bleu(references, translations)
    print(f"Average BLEU score: {bleu_score:.4f}")
    
    # Print some example translations
    print("\nExample Translations:")
    for i in range(min(5, len(translations))):
        print(f"Source: {val_df['eng'].iloc[i]}")
        print(f"Reference: {references[i]}")
        print(f"Translation: {translations[i]}")
        print()
    
    # Plot attention maps for examples
    if attention_maps:
        for i, (source, target, attention) in enumerate(attention_maps[:5]):
            plot_attention(torch.tensor(attention), source, target, source_tokenizer, target_tokenizer)
            print(f"Saved attention map for example {i+1}") 

Using device: cuda
Loading data from english_assamese.csv
Total samples: 87849
Training samples: 70279
Validation samples: 17570
Creating tokenizers...
Training English tokenizer...
Training Assamese tokenizer...
Tokenizers created successfully!
Model loaded successfully from checkpoint


ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()