In [1]:
import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
from datasets import load_dataset
from transformers import BertTokenizer, BertModel, BertForSequenceClassification
import nltk
from nltk.translate.bleu_score import SmoothingFunction

# Ensure nltk and sacrebleu are installed and download necessary resources
import os
import subprocess
import sys

def install(package):
    """Install *package* with pip, using the interpreter running this script.

    Raises ``subprocess.CalledProcessError`` if pip exits with a non-zero status.
    """
    pip_command = [sys.executable, "-m", "pip", "install", package]
    subprocess.check_call(pip_command)

# Import nltk, installing it first if it is missing. Only the import sits in
# the ``try`` so that an ImportError raised by anything else (for example
# inside ``nltk.download``) is not mistaken for a missing package.
# NOTE(review): nltk is already imported unconditionally near the top of this
# file, so this fallback can only trigger if that earlier import is removed.
try:
    import nltk
except ImportError:
    install('nltk')
    import nltk

# Fetch the 'punkt' resource once, whichever import path succeeded.
nltk.download('punkt')

# Same import-or-install fallback for sacrebleu.
try:
    import sacrebleu
except ImportError:
    install('sacrebleu')
    import sacrebleu

print("NLTK and SacreBLEU installed")

# Generator and Discriminator Models
class Generator(nn.Module):
    """BERT encoder followed by a linear projection onto the vocabulary.

    ``forward`` returns one score per vocabulary entry for every input position.
    """

    def __init__(self):
        super().__init__()
        # The attribute names below become the state_dict key prefixes, so
        # they must stay stable for saved checkpoints to load.
        self.model = BertModel.from_pretrained('bert-base-uncased')
        bert_config = self.model.config
        self.linear = nn.Linear(bert_config.hidden_size, bert_config.vocab_size)

    def forward(self, input_ids, attention_mask):
        """Return vocabulary logits computed from the encoder's last hidden state."""
        encoded = self.model(input_ids=input_ids, attention_mask=attention_mask)
        return self.linear(encoded.last_hidden_state)
    
class Discriminator(nn.Module):
    """BERT sequence classifier configured with a single output label."""

    def __init__(self):
        super().__init__()
        # ``self.model`` is the state_dict key prefix; keep the name stable.
        self.model = BertForSequenceClassification.from_pretrained(
            'bert-base-uncased',
            num_labels=1,
        )

    def forward(self, input_ids, attention_mask):
        """Run the classifier and return its raw logits."""
        return self.model(input_ids=input_ids, attention_mask=attention_mask).logits

print("LOADING DATASETS")
# Load the WMT14 French-English train and validation splits
dataset = load_dataset('wmt14', 'fr-en', split='train')
val_dataset = load_dataset('wmt14', 'fr-en', split='validation')

# Work on fixed-size shuffled subsets (25,000 train / 1,000 validation rows);
# the fixed seed keeps the selection reproducible between runs.
small_dataset = dataset.shuffle(seed=42).select(range(25000))
small_val_dataset = val_dataset.shuffle(seed=42).select(range(1000))

# Tokenization: this single tokenizer is used for both the English inputs and
# the French targets in preprocess(), and again for decoding in translate_texts().
# NOTE(review): 'bert-base-uncased' is presumably an English, lowercasing
# vocabulary - confirm it is intended for the French side as well.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

print("STARTED PREPROCESSING")
def preprocess(examples):
    """Tokenize one translation record into fixed-length id sequences.

    The English text becomes the model input and the French text becomes the
    target. Both are padded/truncated to 128 tokens with the module-level
    ``tokenizer``.

    Args:
        examples: Record whose ``'translation'`` entry holds ``'en'`` and
            ``'fr'`` text. Presumably a single example (non-batched ``map``);
            verify against the caller.

    Returns:
        Dict with ``'input_ids'``, ``'attention_mask'`` (both from the English
        encoding) and ``'target_ids'`` (from the French encoding), each with
        size-1 dimensions squeezed out.
    """
    pair = examples['translation']
    encode_options = dict(return_tensors='pt', padding='max_length', truncation=True, max_length=128)

    source_encoding = tokenizer(pair['en'], **encode_options)
    target_encoding = tokenizer(pair['fr'], **encode_options)

    features = {}
    features['input_ids'] = source_encoding['input_ids'].squeeze()
    features['attention_mask'] = source_encoding['attention_mask'].squeeze()
    features['target_ids'] = target_encoding['input_ids'].squeeze()
    return features

# Use multiprocessing to speed up the preprocessing: each map() call runs
# preprocess() with 6 worker processes and drops the raw 'translation' column.
print("train preprocess")
# NOTE(review): train_dataset is not referenced again in the visible part of
# this file - confirm whether tokenizing the training subset is still needed.
train_dataset = small_dataset.map(preprocess, remove_columns=['translation'], num_proc=6)
print("val preprocess")
# val_dataset is rebound here: from the raw validation split to the tokenized 1,000-row subset.
val_dataset = small_val_dataset.map(preprocess, remove_columns=['translation'], num_proc=6)

# Define your dataset class
class CustomDataset(Dataset):
    """Torch ``Dataset`` adapter over an indexable collection of dict rows.

    Each row is expected to map field names to values that ``torch.tensor``
    can convert (for example lists of ints).
    """

    def __init__(self, dataset):
        # Keep a reference to the wrapped collection; nothing is copied.
        self.dataset = dataset

    def __len__(self):
        """Return the number of rows in the wrapped collection."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return row *idx* with every field converted to a tensor."""
        row = self.dataset[idx]
        tensors = {}
        for field_name, values in row.items():
            tensors[field_name] = torch.tensor(values)
        return tensors

# Data collation function for efficient batching
def collate_fn(batch):
    """Stack per-sample tensors into batch tensors.

    Args:
        batch: Sequence of dicts, each holding ``'input_ids'``,
            ``'attention_mask'`` and ``'target_ids'`` tensors of matching shape.

    Returns:
        Dict with the same three keys, each stacked along a new leading
        dimension.
    """
    field_names = ('input_ids', 'attention_mask', 'target_ids')
    return {
        field: torch.stack([sample[field] for sample in batch])
        for field in field_names
    }

# Create DataLoaders: only the validation loader is built here. shuffle=False
# keeps the batches in dataset order.
val_custom_dataset = CustomDataset(val_dataset)
validation_dataloader = DataLoader(val_custom_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

# Initialize models on the GPU when one is available, otherwise on the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Load the trained models.
# map_location remaps the stored tensors onto the device selected above, so a
# checkpoint that was saved from a GPU still loads when this runs on a
# CPU-only machine (without it torch.load tries to restore to the saved device).
generator.load_state_dict(torch.load('/kaggle/input/foreval/generator_model.pth', map_location=device))
discriminator.load_state_dict(torch.load('/kaggle/input/foreval/discriminator_model.pth', map_location=device))

# Translation and BLEU computation functions
def translate_texts(generator, dataloader):
    """Decode the generator's predictions and the reference targets per batch.

    The generator is switched to eval mode and run without gradient tracking.
    For every position the highest-scoring vocabulary id is taken (argmax).
    Predicted ids and ``batch['target_ids']`` are both decoded to text with
    the module-level ``tokenizer``, dropping special tokens.

    Args:
        generator: Model called as ``generator(input_ids, attention_mask)``
            that returns per-position vocabulary scores.
        dataloader: Iterable of batches holding ``'input_ids'``,
            ``'attention_mask'`` and ``'target_ids'`` tensors.

    Returns:
        A ``(translations, references)`` tuple of two lists of strings, in
        dataloader order.
    """
    generator.eval()
    all_translations = []
    all_references = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            logits = generator(input_ids, attention_mask)
            predicted_ids = logits.argmax(dim=-1).tolist()
            # The targets are only decoded, never fed to the model, so they
            # stay on the host instead of making a device round trip.
            reference_ids = batch['target_ids'].tolist()

            for predicted, reference in zip(predicted_ids, reference_ids):
                all_translations.append(tokenizer.decode(predicted, skip_special_tokens=True))
                all_references.append(tokenizer.decode(reference, skip_special_tokens=True))

    return all_translations, all_references

# Run the loaded (already trained) generator over the validation loader to
# collect the decoded translations and their references
validation_translations, validation_references = translate_texts(generator, validation_dataloader)

# Compute average sentence-level BLEU score using nltk
def compute_avg_bleu_nltk(references, translations):
    """Return the mean sentence-level BLEU computed with NLTK.

    Sentences are split on whitespace and paired positionally. Each
    translation is scored against its single reference using
    ``SmoothingFunction().method1``.

    Args:
        references: Reference sentences (strings).
        translations: Hypothesis sentences (strings), aligned with *references*.

    Returns:
        The arithmetic mean of the per-sentence scores, or ``0.0`` when there
        are no sentence pairs to score.
    """
    sentence_pairs = list(zip(references, translations))
    if not sentence_pairs:
        # Nothing to average; avoid dividing by zero.
        return 0.0

    smoothing_function = SmoothingFunction().method1
    sentence_bleu_scores = [
        nltk.translate.bleu_score.sentence_bleu(
            [ref.split()], trans.split(), smoothing_function=smoothing_function
        )
        for ref, trans in sentence_pairs
    ]
    return sum(sentence_bleu_scores) / len(sentence_bleu_scores)

# Compute average sacreBLEU score
def compute_avg_sacrebleu(references, translations):
    """Return the mean sentence-level sacreBLEU score.

    Sentences are paired positionally. Each translation is scored against its
    single reference with ``sacrebleu.sentence_bleu`` and the ``.score``
    values are averaged.

    Args:
        references: Reference sentences (strings).
        translations: Hypothesis sentences (strings), aligned with *references*.

    Returns:
        The arithmetic mean of the per-sentence scores, or ``0.0`` when there
        are no sentence pairs to score.
    """
    sentence_pairs = list(zip(references, translations))
    if not sentence_pairs:
        # Nothing to average; avoid dividing by zero.
        return 0.0

    sacrebleu_scores = [
        sacrebleu.sentence_bleu(trans, [ref]).score
        for ref, trans in sentence_pairs
    ]
    return sum(sacrebleu_scores) / len(sacrebleu_scores)

# Score the validation translations with both BLEU implementations
validation_bleu = compute_avg_bleu_nltk(validation_references, validation_translations)
validation_sacrebleu = compute_avg_sacrebleu(validation_references, validation_translations)

# The NLTK average is multiplied by 100 before printing; the sacreBLEU average
# is printed as returned.
# NOTE(review): presumably this puts both on a 0-100 scale (NLTK scores in
# 0-1, sacreBLEU scores in 0-100) - confirm against the library docs.
print(f"Average Validation BLEU score: {validation_bleu * 100:.2f}")
print(f"Average Validation sacreBLEU score: {validation_sacrebleu:.2f}")

[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
Collecting sacrebleu
  Downloading sacrebleu-2.4.2-py3-none-any.whl.metadata (58 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 58.0/58.0 kB 1.8 MB/s eta 0:00:00
Collecting portalocker (from sacrebleu)
  Downloading portalocker-2.10.1-py3-none-any.whl.metadata (8.5 kB)
Downloading sacrebleu-2.4.2-py3-none-any.whl (106 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 106.7/106.7 kB 3.5 MB/s eta 0:00:00
Downloading portalocker-2.10.1-py3-none-any.whl (18 kB)
Installing collected packages: portalocker, sacrebleu
Successfully installed portalocker-2.10.1 sacrebleu-2.4.2
NLTK and SacreBLEU installed
LOADING DATASETS


Downloading readme:   0%|          | 0.00/10.5k [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

Downloading data:   0%|          | 0/30 [00:00<?, ?files/s]

Downloading data:   0%|          | 0.00/475k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/536k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/40836715 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/3003 [00:00<?, ? examples/s]

Loading dataset shards:   0%|          | 0/30 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

STARTED PREPROCESSING
train preprocess


Map (num_proc=6):   0%|          | 0/25000 [00:00<?, ? examples/s]

val preprocess


Map (num_proc=6):   0%|          | 0/1000 [00:00<?, ? examples/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Average Validation BLEU score: 1.30
Average Validation sacreBLEU score: 1.85
