In [1]:

from datasets import load_dataset
import evaluate
import os
import torch
import dill
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from tqdm import tqdm
# Get user's home directory
import os
home = os.path.expanduser("~")

# Define the path of the cache directory
cache_dir = os.path.join(home, ".cache", "huggingface", "datasets")

# Define the name and configuration of the dataset
dataset_name = "wmt14"
config_name = "fr-en"

# Build the path for the specific dataset configuration
dataset_config_path = os.path.join(cache_dir, dataset_name, config_name)

print(f"Checking cache at: {dataset_config_path}")

# Check if the dataset configuration is already cached
if os.path.exists(dataset_config_path) and len(os.listdir(dataset_config_path)) > 0:
    print("Dataset already downloaded, loading from cache.")
    # If the dataset is already downloaded, load it from the cache directory
    dataset = load_dataset(dataset_name, config_name, cache_dir=cache_dir)
else:
    print("Downloading the dataset.")
    # Download the dataset and specify the cache directory
    dataset = load_dataset(dataset_name, config_name, cache_dir=cache_dir)

# Keep the full valid and test datasets
valid_dataset = dataset["validation"]
test_dataset = dataset["test"]

texts =[]
labels = []
for element in test_dataset["translation"]:
        # print("element: ", element)
        texts.append(element["en"])
        labels.append(element["fr"])

metric = evaluate.load("sacrebleu")
getpwd = os.getcwd()


# file_path_en = os.path.join(getpwd, "original_english_mm_v2.txt")
# # file_path = "/path/to/translations.txt"

# # Open the file in write mode
# with open(file_path_en, "w") as file:
#     # Write each translation to the file
#     for text in texts:
#         file.write(text + "\n")


# file_path_fr = os.path.join(getpwd, "original_french_mm_v2.txt")
# # file_path = "/path/to/translations.txt"

# Open the file in write mode
# with open(file_path_fr, "w") as file:
#     # Write each translation to the file
#     for label in labels:
#         file.write(label + "\n")

# checkpoint_path_generator = '/home/paperspace/google_drive_v1/Research_Thesis/2024/git_repo/checkpoints/bert_dualG/wmt14_en_fr_1mil_pg_kd_loss_MarianMT_unfreeze_lmlayer_dcd_tp_2_1000sents_debug_Normalkd_2_save_open_direct_pretrained/best_generator_dill_open_at_3.pt'
# checkpoint_path_tokenizer = "/home/paperspace/google_drive_v1/Research_Thesis/2024/git_repo/checkpoints/bert_dualG/wmt14_en_fr_1mil_pg_kd_loss_MarianMT_unfreeze_lmlayer_dcd_tp_2_1000sents_debug_Normalkd_2_save_open_direct_pretrained/best_generator_tokenizer_save_pretrained_at_3"
# translations_generated_filename_batch = "translated_french_by_MarianMT_FT_1000sents_kd_3_dcd_2.txt"

checkpoint_path_generator = '/home/paperspace/google_drive_v3/Research_Thesis/2024/git_repo/checkpoints/bert_dualG/1mil_checkpoints/best_generator_dill_open_format_at_2.pt'
checkpoint_path_tokenizer = '/home/paperspace/google_drive_v3/Research_Thesis/2024/git_repo/checkpoints/bert_dualG/1mil_checkpoints/best_generator_tokenizer_save_pretrained_at_2'

# Load the entire model directly
generator2_checkpoint = torch.load(open(checkpoint_path_generator, "rb"), pickle_module=dill)
# generator2_checkpoint= generator2_checkpoint.from_pretrained(checkpoint_path_generator)

# generator2_train # Extract the underlying model from the DataParallel wrapper
generator2_checkpoint = generator2_checkpoint.module if isinstance(generator2_checkpoint, torch.nn.DataParallel) else generator2_checkpoint

# Check if CUDA is available and then set the default device to GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

tokenizer = AutoTokenizer.from_pretrained(checkpoint_path_tokenizer)


batch_size = 8  # Adjust this based on your GPU's memory capacity

translations_batch = []

if torch.cuda.device_count() > 1:
    generator2_checkpoint = torch.nn.DataParallel(generator2_checkpoint).cuda()
else:
    generator2_checkpoint.cuda()

generator2_checkpoint = generator2_checkpoint.module if hasattr(generator2_checkpoint, 'module') else generator2_checkpoint

generator2_checkpoint.eval()
# generator2_checkpoint.to(device)

# Process texts in batches
for i in tqdm(range(0, len(texts), batch_size), desc="Translating batches"):
    batch = texts[i:i + batch_size]
    inputs = tokenizer(batch, truncation=True, padding="max_length", max_length=128, return_tensors="pt").input_ids.to(device)
    # print("inputs shape: ", inputs.shape)
    # Generate outputs for the entire batch
    outputs = generator2_checkpoint.generate(inputs, max_length=60, num_beams=5, early_stopping=True)
    # print("outputs shape", outputs.shape)
    
    # Decode all outputs in the batch
    # batch_translations = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
    batch_translation = tokenizer.batch_decode(outputs , skip_special_tokens=True)
    translations_batch.extend(batch_translation)

# Save the translations to a text file
# import os
# file_path = os.path.join(os.getcwd(), translations_generated_filename_batch)
# with open(file_path, "w") as file:
#     for translation in translations_batch:
#         file.write(translation + "\n")


result_batch = metric.compute(predictions=translations_batch, references=labels)
result_batch = {"bleu": result_batch["score"]}
result_batch


  from .autonotebook import tqdm as notebook_tqdm


Checking cache at: /home/paperspace/.cache/huggingface/datasets/wmt14/fr-en
Dataset already downloaded, loading from cache.
Using device: cuda


Translating batches: 100%|██████████| 376/376 [04:45<00:00,  1.32it/s]


{'bleu': 0.0011027631663308342}