In [1]:
# @title Start installing and importing the required libraries (session restart may be needed)
import torch
print(f"Is CUDA available: {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda}")
print(f"PyTorch version: {torch.__version__}")

# Ensure the Google Colab runtime is using GPU
!pip install fsspec

!pip install torch
!pip install scikit-learn
!pip install tokenizers
!pip install transformers
!pip install bert-score
!pip install rouge-score
!pip install sacrebleu
!pip install evaluate
!pip install tabulate
!pip install comet_ml
!pip install pyspellchecker
!pip install datasets

Is CUDA available: True
CUDA version: 12.4
PyTorch version: 2.5.1+cu124
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cura

In [2]:
# @title Reorganized imports
import argparse
import csv
import logging
import os
import pickle
import random
import sys
import time
from typing import List, Optional, Tuple, Union, Dict
import comet_ml
import re
from spellchecker import SpellChecker

# Third-party libraries
import evaluate
import numpy as np
import pandas as pd
import rouge_score
import sacrebleu
import sklearn
import tokenizers
import transformers
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from torch import nn, Tensor
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    DistilBertTokenizer,
    DistilBertForSequenceClassification,
    AutoTokenizer,
    DebertaForSequenceClassification,
    DebertaTokenizer,
    get_scheduler,
    Trainer,
    TrainingArguments
)

logging.basicConfig(level=logging.DEBUG)

from sklearn.metrics import classification_report

In [3]:
# @title Mount Google Drive on Colab for persistent storage
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


## Preprocessing for GAN and Classifier model

In [None]:
# @title Preprocessing on Trump and Shakespeare Train Eval and Test

# Initialize spell checker
spell = SpellChecker()

# Function to clean and correct text line by line
def correct_text_line(line):
    # Fix common OCR and formatting errors
    line = re.sub(r'(\s[oO]Jce\s|\soJce\s)', ' once ', line, flags=re.IGNORECASE)  # Fix OCR "oJce" -> "once"
    line = re.sub(r'([a-zA-Z])(\d+)', r'\1', line)  # Remove digits mixed with letters
    line = re.sub(r'\s+', ' ', line).strip()  # Normalize spaces

    # Tokenize while keeping punctuation
    tokens = re.findall(r"[\w']+|[.,!?\'-]", line)

    corrected_tokens = []
    for i, token in enumerate(tokens):
        if token.isalpha():  # Only process alphabetic words
            corrected_word = spell.correction(token)
            corrected_word = corrected_word if corrected_word is not None else token

            # Handle capitalization:
            if i == 0 or (i > 0 and tokens[i - 1] in ".!?"):  # Capitalize first word of a sentence
                corrected_word = corrected_word.capitalize()
            elif token.lower() == "i":  # Ensure 'i' is always capitalized
                corrected_word = "I"

            corrected_tokens.append(corrected_word)
        else:
            corrected_tokens.append(token)  # Keep punctuation as-is

    # Join words carefully, ensuring spacing is correct
    return ''.join(
        corrected_tokens[i] if corrected_tokens[i] in ".,!?'-" else ' ' + corrected_tokens[i]
        for i in range(len(corrected_tokens))
    ).strip()

# Define base paths
base_input_path = "/content/drive/MyDrive/ProjectNLP/20250112_Autori/Processed/"
base_output_path = "/content/drive/MyDrive/ProjectNLP/20250112_Autori/Processed/"

# Define file categories
categories = ["train", "eval", "test"]
authors = ["shakespeare", "trump"]

# Process each file type (train, eval, test) for both authors
for category in categories:
    for author in authors:
        input_file_path = f"{base_input_path}{category}_{author}.txt"
        output_file_path = f"{base_output_path}{category}_{author}_spellchecked.txt"

        with open(input_file_path, "r", encoding="utf-8") as infile, open(output_file_path, "w", encoding="utf-8") as outfile:
            for line in infile:
                corrected_line = correct_text_line(line)  # Apply spell-checking
                outfile.write(corrected_line + "\n")

        print(f"Processing complete. Corrected file saved at: {output_file_path}")



Processing complete. Corrected file saved at: /content/drive/MyDrive/ProjectNLP/20250112_Autori/Processed/train_shakespeare_spellchecked.txt
Processing complete. Corrected file saved at: /content/drive/MyDrive/ProjectNLP/20250112_Autori/Processed/train_trump_spellchecked.txt
Processing complete. Corrected file saved at: /content/drive/MyDrive/ProjectNLP/20250112_Autori/Processed/eval_shakespeare_spellchecked.txt
Processing complete. Corrected file saved at: /content/drive/MyDrive/ProjectNLP/20250112_Autori/Processed/eval_trump_spellchecked.txt
Processing complete. Corrected file saved at: /content/drive/MyDrive/ProjectNLP/20250112_Autori/Processed/test_shakespeare_spellchecked.txt
Processing complete. Corrected file saved at: /content/drive/MyDrive/ProjectNLP/20250112_Autori/Processed/test_trump_spellchecked.txt


In [None]:
#@title copy and rename files

import shutil
import os

# Define source and destination directories
source_dir = "/content/drive/MyDrive/ProjectNLP/20250112_Autori/Processed/"
destination_dir = "/content/drive/MyDrive/ProjectNLP/20250112_Autori/UltimoClassifier/Data/"

# Ensure the destination directory exists
os.makedirs(destination_dir, exist_ok=True)

# Define file categories and labels
categories = {
    "train": "train",
    "eval": "dev",
    "test": "test"
}
authors = {
    "shakespeare": "1",
    "trump": "0"
}

# Copy and rename files
for category, new_category in categories.items():
    for author, label in authors.items():
        src_file = f"{source_dir}{category}_{author}_spellchecked.txt"
        dest_file = f"{destination_dir}{new_category}.{label}.txt"  # Rename format

        shutil.copy(src_file, dest_file)
        print(f"Copied and renamed {src_file} to {dest_file}")

print("✅ All files copied and renamed successfully!")


Copied and renamed /content/drive/MyDrive/ProjectNLP/20250112_Autori/Processed/train_shakespeare_spellchecked.txt to /content/drive/MyDrive/ProjectNLP/20250112_Autori/UltimoClassifier/Data/train.1.txt
Copied and renamed /content/drive/MyDrive/ProjectNLP/20250112_Autori/Processed/train_trump_spellchecked.txt to /content/drive/MyDrive/ProjectNLP/20250112_Autori/UltimoClassifier/Data/train.0.txt
Copied and renamed /content/drive/MyDrive/ProjectNLP/20250112_Autori/Processed/eval_shakespeare_spellchecked.txt to /content/drive/MyDrive/ProjectNLP/20250112_Autori/UltimoClassifier/Data/dev.1.txt
Copied and renamed /content/drive/MyDrive/ProjectNLP/20250112_Autori/Processed/eval_trump_spellchecked.txt to /content/drive/MyDrive/ProjectNLP/20250112_Autori/UltimoClassifier/Data/dev.0.txt
Copied and renamed /content/drive/MyDrive/ProjectNLP/20250112_Autori/Processed/test_shakespeare_spellchecked.txt to /content/drive/MyDrive/ProjectNLP/20250112_Autori/UltimoClassifier/Data/test.1.txt
Copied and rena

# Binary Experiments

## Training the Classifier

In [None]:
!CUDA_VISIBLE_DEVICES=0 python "/content/drive/MyDrive/ProjectNLP/02.Training Parallelo Shakespeare/GAN_Originale/my_utils/train_classifier.py" \
  --dataset_path "/content/drive/MyDrive/ProjectNLP/20250112_Autori/UltimoClassifier/Data/" \
  --max_samples_train 5000 \
  --max_samples_eval 1000 \
  --lowercase \
  --max_sequence_length 32 \
  --batch_size 32 \
  --use_cuda_if_available \
  --learning_rate 2e-5 \
  --epochs 10 \
  --lr_scheduler_type "linear" \
  --model_tag "distilbert/distilbert-base-cased" \
  --save_base_folder "/content/drive/MyDrive/ProjectNLP/20250112_Autori/UltimoClassifier/" \
  --save_steps 1 \
  --eval_strategy "epochs" \
  --eval_steps 1


2025-01-30 19:47:05.826695: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1738266425.850038   73413 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1738266425.856550   73413 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Arguments summary: 
 
	max_samples_train:		5000
	max_samples_eval:		1000
	dataset_path:		/content/drive/MyDrive/ProjectNLP/20250112_Autori/UltimoClassifier/Data/
	lowercase:		True
	max_sequence_length:		32
	batch_size:		32
	use_cuda_if_available:		True
	learning_rate:		2e-05
	epochs:		10
	lr_scheduler_type:		linear
	model_tag:		distilbert/distilbert-base-cased
	save_base_folder:		/content/drive/MyDrive/ProjectNLP/20250112_Autori/Ulti

In [None]:
# @title Manual inference for classifier for visual inspection

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Define model and tokenizer paths
model_path = "/content/drive/MyDrive/ProjectNLP/20250112_Autori/UltimoClassifier/classifiers/Data/distilbert/distilbert-base-cased_10/checkpoints/checkpoint-314/"  # Update if different
tokenizer_path = "/content/drive/MyDrive/ProjectNLP/20250112_Autori/UltimoClassifier/classifiers/Data/distilbert/distilbert-base-cased_10/checkpoints/checkpoint-314/"  # Tokenizer is usually saved with the model

# Load the trained model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForSequenceClassification.from_pretrained(model_path).to(device)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

# Function to classify a sentence
def classify_sentence(sentence):
    # Tokenize input
    inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding="max_length", max_length=32).to(device)

    # Get model prediction
    with torch.no_grad():
        logits = model(**inputs).logits

    # Convert logits to probabilities
    probabilities = torch.nn.functional.softmax(logits, dim=1)
    predicted_class = torch.argmax(probabilities, dim=1).item()

    # Labels
    labels = {0: "Trump", 1: "Shakespeare"}

    return labels[predicted_class], probabilities.cpu().numpy()

# Test the classifier
sentence = "When we're gonna declare war to Ukraine"
prediction, _ = classify_sentence(sentence)
print(f"Sentence: {sentence}")
print(f"Predicted Class: {prediction}")


Sentence: When we're gonna declare war to Ukraine
Predicted Class: Trump


## Train the GAN

### Basic GAN Train

In [None]:
!CUDA_VISIBLE_DEVICES=0 python "/content/drive/MyDrive/ProjectNLP/03.Ultima_Estensione_Shakespeare/GAN_Binaria_shakespeare/train.py"  \
         --style_a=trump \
         --style_b=shakespeare \
         --lang=en \
         --path_mono_A="/content/drive/MyDrive/ProjectNLP/20250112_Autori/Processed/train_trump_spellchecked.txt" \
         --path_mono_B="/content/drive/MyDrive/ProjectNLP/20250112_Autori/Processed/train_shakespeare_spellchecked.txt" \
         --path_mono_A_eval="/content/drive/MyDrive/ProjectNLP/20250112_Autori/Processed/eval_trump_spellchecked.txt" \
	     --path_mono_B_eval="/content/drive/MyDrive/ProjectNLP/20250112_Autori/Processed/eval_shakespeare_spellchecked.txt" \
         --pretrained_classifier_model="/content/drive/MyDrive/ProjectNLP/20250112_Autori/UltimoClassifier/classifiers/Data/distilbert/distilbert-base-cased_10/checkpoints/checkpoint-314/" \
         --pretrained_classifier_eval="/content/drive/MyDrive/ProjectNLP/20250112_Autori/UltimoClassifier/classifiers/Data/distilbert/distilbert-base-cased_10/checkpoints/checkpoint-314/" \
         --shuffle \
         --generator_model_tag="facebook/bart-base" \
         --discriminator_model_tag="distilbert/distilbert-base-cased" \
         --lambdas="10|1|1|1|1" \
         --epochs=20 \
         --learning_rate=1e-4 \
         --max_sequence_length=32 \
         --batch_size=16 \
         --save_base_folder="/content/drive/MyDrive/ProjectNLP/02.Training Parallelo Shakespeare/Ultimi_Checkpoint/" \
         --save_steps=1 \
         --eval_strategy=epochs \
         --eval_steps=1 \
         --pin_memory \
         --use_cuda_if_available


2025-01-30 20:25:46.213464: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1738268746.234139    2081 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1738268746.240605    2081 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Arguments summary: 
 
	style_a:		trump
	style_b:		shakespeare
	lang:		en
	max_samples_train:		None
	max_samples_eval:		None
	nonparal_same_size:		False
	path_mono_A:		/content/drive/MyDrive/ProjectNLP/20250112_Autori/Processed/train_trump_spellchecked.txt
	path_mono_B:		/content/drive/MyDrive/ProjectNLP/20250112_Autori/Processed/train_shakespeare_spellchecked.txt
	path_mono_A_eval:		/content/drive/MyDrive/ProjectNLP/20250112_Autori/P

### Enhanced tokenizer

In [None]:
!CUDA_VISIBLE_DEVICES=0 python "/content/drive/MyDrive/ProjectNLP/03.Ultima_Estensione_Shakespeare/GAN_Tokenizer_Shakespeare/train.py"  \
         --style_a=trump \
         --style_b=shakespeare \
         --lang=en \
         --path_mono_A="/content/drive/MyDrive/ProjectNLP/03.Ultima Estensione Shakespeare/Data_spellchecked/train_trump_spellchecked.txt" \
         --path_mono_B="/content/drive/MyDrive/ProjectNLP/03.Ultima Estensione Shakespeare/Data_spellchecked/train_shakespeare_spellchecked.txt" \
         --path_mono_A_eval="/content/drive/MyDrive/ProjectNLP/03.Ultima Estensione Shakespeare/Data_spellchecked/eval_trump_spellchecked.txt" \
	     --path_mono_B_eval="/content/drive/MyDrive/ProjectNLP/03.Ultima Estensione Shakespeare/Data_spellchecked/eval_shakespeare_spellchecked.txt" \
         --pretrained_classifier_model="/content/drive/MyDrive/ProjectNLP/03.Ultima Estensione Shakespeare/UltimoClassifier/classifiers/Data/distilbert/distilbert-base-cased_10/checkpoints/checkpoint-314/" \
         --pretrained_classifier_eval="/content/drive/MyDrive/ProjectNLP/03.Ultima Estensione Shakespeare/UltimoClassifier/classifiers/Data/distilbert/distilbert-base-cased_10/checkpoints/checkpoint-314/" \
         --shuffle \
         --generator_model_tag="facebook/bart-base" \
         --discriminator_model_tag="distilbert/distilbert-base-cased" \
         --lambdas="10|1|1|1|1" \
         --epochs=20 \
         --learning_rate=1e-4 \
         --max_sequence_length=32 \
         --batch_size=16 \
         --save_base_folder="/content/drive/MyDrive/ProjectNLP/03.Ultima Estensione Shakespeare/Checkpoint_GAN_Enhanced_Tokenizer/" \
         --save_steps=1 \
         --eval_strategy=epochs \
         --eval_steps=1 \
         --pin_memory \
         --use_cuda_if_available


2025-01-31 19:50:31.009490: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1738353031.264266    3693 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1738353031.329594    3693 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-01-31 19:50:31.868715: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Arguments summary: 
 
	style_a:		trump
	style_b:		shakespeare
	lang:		en
	max_samples_train:		None
	max_samples_eval:		None
	

# Ternary Experiment

In [None]:
# @title Preprocessing on Lyrics Train Eval and Test

# Initialize spell checker
spell = SpellChecker()

# Function to clean and correct text line by line
def correct_text_line(line):
    # Fix common OCR and formatting errors
    line = re.sub(r'(\s[oO]Jce\s|\soJce\s)', ' once ', line, flags=re.IGNORECASE)  # Fix OCR "oJce" -> "once"
    line = re.sub(r'([a-zA-Z])(\d+)', r'\1', line)  # Remove digits mixed with letters
    line = re.sub(r'\s+', ' ', line).strip()  # Normalize spaces

    # Tokenize while keeping punctuation
    tokens = re.findall(r"[\w']+|[.,!?\'-]", line)

    corrected_tokens = []
    for i, token in enumerate(tokens):
        if token.isalpha():  # Only process alphabetic words
            corrected_word = spell.correction(token)
            corrected_word = corrected_word if corrected_word is not None else token

            # Handle capitalization:
            if i == 0 or (i > 0 and tokens[i - 1] in ".!?"):  # Capitalize first word of a sentence
                corrected_word = corrected_word.capitalize()
            elif token.lower() == "i":  # Ensure 'i' is always capitalized
                corrected_word = "I"

            corrected_tokens.append(corrected_word)
        else:
            corrected_tokens.append(token)  # Keep punctuation as-is

    # Join words carefully, ensuring spacing is correct
    return ''.join(
        corrected_tokens[i] if corrected_tokens[i] in ".,!?'-" else ' ' + corrected_tokens[i]
        for i in range(len(corrected_tokens))
    ).strip()

# Define base paths
base_input_path = "/content/drive/MyDrive/ProjectNLP/20250112_Autori/Processed/"
base_output_path = "/content/drive/MyDrive/ProjectNLP/03.Ultima_Estensione_Shakespeare/Data_spellchecked/"

# Define file categories
categories = ["train", "eval", "test"]
authors = ["lyrics"]

# Process each file type (train, eval, test) for both authors
for category in categories:
    for author in authors:
        input_file_path = f"{base_input_path}{category}_{author}.txt"
        output_file_path = f"{base_output_path}{category}_{author}_spellchecked.txt"

        with open(input_file_path, "r", encoding="utf-8") as infile, open(output_file_path, "w", encoding="utf-8") as outfile:
            for line in infile:
                corrected_line = correct_text_line(line)  # Apply spell-checking
                outfile.write(corrected_line + "\n")

        print(f"Processing complete. Corrected file saved at: {output_file_path}")



Processing complete. Corrected file saved at: /content/drive/MyDrive/ProjectNLP/03.Ultima Estensione Shakespeare/Data_spellchecked/train_lyrics_spellchecked.txt
Processing complete. Corrected file saved at: /content/drive/MyDrive/ProjectNLP/03.Ultima Estensione Shakespeare/Data_spellchecked/eval_lyrics_spellchecked.txt
Processing complete. Corrected file saved at: /content/drive/MyDrive/ProjectNLP/03.Ultima Estensione Shakespeare/Data_spellchecked/test_lyrics_spellchecked.txt


In [None]:
# @title Create aggregated test, train and eval files
splits = ["train", "eval", "test"]
classes = ["trump", "lyrics", "shakespeare"]

# Define source and destination folders
source_folder = "/content/drive/MyDrive/ProjectNLP/03.Ultima_Estensione_Shakespeare/Data_spellchecked"
dest_folder = "/content/drive/MyDrive/ProjectNLP/03.Ultima_Estensione_Shakespeare/Data_Ternary_Classifier"

# Create the destination folder if it doesn't exist
os.makedirs(dest_folder, exist_ok=True)

# Process each split
for split in splits:
    all_rows = []  # To store all the rows for the current split
    for cls in classes:
        # Construct the file name based on split and class
        filename = f"{split}_{cls}_spellchecked.txt"
        file_path = os.path.join(source_folder, filename)

        # Check if the file exists
        if not os.path.exists(file_path):
            print(f"Warning: {file_path} not found. Skipping.")
            continue

        # Read the file (assuming one example per line)
        with open(file_path, "r", encoding="utf-8") as f:
            lines = f.read().splitlines()
            # For each non-empty line, create a row with text and mapped label
            for line in lines:
                if line.strip():  # Ignore empty lines
                    all_rows.append({
                        "text": line.strip(),
                        "label": cls
                    })

    # Create a DataFrame for the current split and save to CSV
    df = pd.DataFrame(all_rows)
    csv_path = os.path.join(dest_folder, f"{split}.csv")
    df=df.sample(frac=1)
    df.to_csv(csv_path, index=False)
    print(f"Saved {split}.csv with {len(df)} records at {csv_path}")


Saved train.csv with 18616 records at /content/drive/MyDrive/ProjectNLP/03.Ultima Estensione Shakespeare/Data_Ternary_Classifier/train.csv
Saved eval.csv with 1035 records at /content/drive/MyDrive/ProjectNLP/03.Ultima Estensione Shakespeare/Data_Ternary_Classifier/eval.csv
Saved test.csv with 1034 records at /content/drive/MyDrive/ProjectNLP/03.Ultima Estensione Shakespeare/Data_Ternary_Classifier/test.csv


## Train the Classifier

In [None]:
from datasets import Dataset  # Import Hugging Face Dataset
# File paths
file_paths = {
    "train": "/content/drive/MyDrive/ProjectNLP/03.Ultima_Estensione_Shakespeare/TernaryClassifier/Data_Ternary_Classifier/train.csv",
    "eval":  "/content/drive/MyDrive/ProjectNLP/03.Ultima_Estensione_Shakespeare/TernaryClassifier/Data_Ternary_Classifier/eval.csv",
    "test":  "/content/drive/MyDrive/ProjectNLP/03.Ultima_Estensione_Shakespeare/TernaryClassifier/Data_Ternary_Classifier/test.csv"
}

# Define label mapping
label_mapping = {"trump": 0, "lyrics": 1, "shakespeare": 2}

# Dictionary to hold both pandas DataFrames and PyTorch Datasets
dfs = {}
datasets = {}

# Loop over each split, load the CSV, map the labels, and convert to a Hugging Face Dataset
for split, path in file_paths.items():
    df = pd.read_csv(path)
    df["label"] = df["label"].map(label_mapping)
    dfs[split] = df

    # Convert to Hugging Face Dataset
    hf_dataset = Dataset.from_pandas(df)

    # ✅ Convert Hugging Face Dataset to PyTorch format
    datasets[split] = hf_dataset.with_format("torch")

# Optionally, extract datasets for convenience
train_dataset = datasets["train"]
eval_dataset = datasets["eval"]
test_dataset = datasets["test"]

print("Esempi di TRAIN:")
print(dfs["train"].head())


Esempi di TRAIN:
                                                text  label
0                          There's bones in the sink      1
1  He's trying to buy his way back into the Democ...      0
2  Under Republican leadership, the economy is bo...      0
3                       I owe my life to my enemies.      2
4          'To your dwarfs here, yes, he is a dwarf.      0


In [None]:
# @title Initialize Classifier
model_name = "distilbert-base-cased"
tokenizer = DistilBertTokenizer.from_pretrained(model_name)

# Funzione di tokenizzazione
def tokenize_function(examples):
    truncation = 'longest_first'
    return tokenizer(examples["text"], padding="max_length", truncation=truncation, max_length=32)

# Tokenizzazione dei dataset
train_dataset = train_dataset.map(tokenize_function, batched=True)
eval_dataset = eval_dataset.map(tokenize_function, batched=True)
test_dataset = test_dataset.map(tokenize_function, batched=True)

# Rimuovere colonne non necessarie
train_dataset = train_dataset.remove_columns(["text"])
eval_dataset = eval_dataset.remove_columns(["text"])
test_dataset = test_dataset.remove_columns(["text"])

# Impostare il tipo per PyTorch
train_dataset.set_format("torch")
eval_dataset.set_format("torch")
test_dataset.set_format("torch")

# Inizializzazione del modello DeBERTa
model = DistilBertForSequenceClassification.from_pretrained(model_name, num_labels=3)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/465 [00:00<?, ?B/s]

Map:   0%|          | 0/18616 [00:00<?, ? examples/s]

Map:   0%|          | 0/1035 [00:00<?, ? examples/s]

Map:   0%|          | 0/1034 [00:00<?, ? examples/s]

model.safetensors:   0%|          | 0.00/263M [00:00<?, ?B/s]

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# @title Ternary Classifier Training and Validation
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = logits.argmax(axis=-1)
    acc = accuracy_score(labels, predictions)
    return {"accuracy": acc}

# Configurazione degli argomenti per il Trainer
training_args = TrainingArguments(
    output_dir="/content/drive/MyDrive/ProjectNLP/03.Ultima_Estensione_Shakespeare/TernaryClassifier/ProvaFra",  # Directory per i risultati
    evaluation_strategy="epoch",    # Valutazione dopo ogni epoca
    learning_rate=5e-5,             # Tasso di apprendimento
    per_device_train_batch_size=8,  # Batch size per il training
    per_device_eval_batch_size=8,   # Batch size per la valutazione
    num_train_epochs=10,            # Numero di epoche
    weight_decay=0.01,              # Decadimento del peso
    logging_dir="./logs",           # Directory per i log
    logging_steps=10,               # Passi di log
    save_strategy="epoch",          # Salvare dopo ogni epoca
    report_to=[]                    # no reporting to wandb, comet_ml, etc.
)

# Inizializzazione del Trainer con compute_metrics
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics  # funzione di calcolo delle metriche
)

# Addestramento del modello
trainer.train()

  trainer = Trainer(


Epoch,Training Loss,Validation Loss,Accuracy
1,0.5373,0.347185,0.867633
2,0.2356,0.370057,0.882126
3,0.3288,0.442159,0.877295
4,0.2876,0.481198,0.875362
5,0.0741,0.525486,0.891787
6,0.1277,0.524305,0.874396
7,0.1246,0.65522,0.879227
8,0.1627,0.677973,0.878261
9,0.114,0.639255,0.88599
10,0.1513,0.67573,0.881159


TrainOutput(global_step=23270, training_loss=0.20870743999773975, metrics={'train_runtime': 1288.0075, 'train_samples_per_second': 144.533, 'train_steps_per_second': 18.067, 'total_flos': 1541285669514240.0, 'train_loss': 0.20870743999773975, 'epoch': 10.0})

In [None]:
# @title Test the Best Model
checkpoint_path = "/content/drive/MyDrive/ProjectNLP/03.Ultima_Estensione_Shakespeare/TernaryClassifier/DistilBertCheckpoint/checkpoint-11635"

# Load the model from the checkpoint
model = DistilBertForSequenceClassification.from_pretrained(checkpoint_path)
model.to(trainer.args.device)
# Update the trainer's model with the loaded checkpoint
trainer.model = model

# Evaluate the model on the test dataset
predictions = trainer.predict(test_dataset)

# Extract predicted labels
preds = torch.argmax(torch.tensor(predictions.predictions), axis=1)

# Get the true labels from the test dataset
true_labels = test_dataset["label"]

# Generate and print the classification report
print("Classification Report:")
print(classification_report(true_labels, preds, target_names=list(label_mapping.keys())))


Classification Report:
              precision    recall  f1-score   support

       trump       0.93      0.86      0.90       387
      lyrics       0.77      0.92      0.84       318
 shakespeare       0.92      0.82      0.87       329

    accuracy                           0.87      1034
   macro avg       0.87      0.87      0.87      1034
weighted avg       0.88      0.87      0.87      1034



## Train Ternary GAN with enhanced Tokenizer

### Class Definition

In [4]:
# @title MonostyleDataset Class
# Configure logging
logging.basicConfig(level=logging.DEBUG)

class MonostyleDataset(Dataset):
    """
    Mono-style dataset:
    Loads textual data from CSV files, line-based files, or a provided list of sentences.
    """

    def __init__(
        self,
        dataset_format: str,
        dataset_path: str = None,
        sentences_list: List[str] = None,
        text_column_name: str = None,
        separator: str = None,
        style: str = None,
        max_dataset_samples: int = None,
        SEED: int = 42
    ):
        super().__init__()

        self.allowed_dataset_formats = ["list", "csv", "line_file"]
        if dataset_format not in self.allowed_dataset_formats:
            raise Exception(
                f"MonostyleDataset: '{dataset_format}' is not supported. "
                f"Allowed formats: {self.allowed_dataset_formats}."
            )

        self.dataset_format = dataset_format
        self.dataset_path = dataset_path
        self.sentences_list = sentences_list
        self.text_column_name = text_column_name
        self.separator = separator
        self.style = style
        self.max_dataset_samples = max_dataset_samples

        # Load data based on the format
        self.load_data(SEED)

    def _load_data_csv(self):
        try:
            df = pd.read_csv(self.dataset_path, sep=self.separator, header=None, encoding='utf-8')
            df.dropna(inplace=True)
            if self.text_column_name is not None:
                self.data = df[self.text_column_name].tolist()
            else:
                self.data = df.iloc[:, 0].tolist()
            logging.debug(
                f"MonostyleDataset, _load_data_csv: parsed {len(self.data)} examples from '{self.dataset_path}'."
            )
        except UnicodeDecodeError as e:
            logging.error(
                f"MonostyleDataset, _load_data_csv: UnicodeDecodeError while reading '{self.dataset_path}': {e}"
            )
            raise
        except FileNotFoundError:
            logging.error(
                f"MonostyleDataset, _load_data_csv: File not found: '{self.dataset_path}'."
            )
            raise
        except Exception as e:
            logging.error(
                f"MonostyleDataset, _load_data_csv: Error loading CSV dataset: {e}"
            )
            raise

    def _load_data_line_file(self):
        try:
            with open(self.dataset_path, 'r', encoding='utf-8') as f:
                self.data = f.read().split(self.separator)
            logging.debug(
                f"MonostyleDataset, _load_data_line_file: parsed {len(self.data)} examples from '{self.dataset_path}'."
            )
        except UnicodeDecodeError as e:
            logging.error(
                f"MonostyleDataset, _load_data_line_file: UnicodeDecodeError while reading '{self.dataset_path}': {e}"
            )
            raise
        except FileNotFoundError:
            logging.error(
                f"MonostyleDataset, _load_data_line_file: File not found: '{self.dataset_path}'."
            )
            raise
        except Exception as e:
            logging.error(
                f"MonostyleDataset, _load_data_line_file: Error loading line_file dataset: {e}"
            )
            raise

    def load_data(self, SEED=42):
        if self.dataset_format == "csv":
            self._load_data_csv()
        elif self.dataset_format == "line_file":
            self._load_data_line_file()
        elif self.dataset_format == "list":
            if self.sentences_list is None:
                raise Exception(
                    "MonostyleDataset: 'list' format specified but 'sentences_list' is None."
                )
            self.data = self.sentences_list
            logging.debug(
                f"MonostyleDataset, load_data: data already loaded, {len(self.data)} examples."
            )
        else:
            raise Exception(
                f"MonostyleDataset, load_data: '{self.dataset_format}' format is not supported."
            )

        # Limit the number of samples if needed
        if self.max_dataset_samples is not None and self.max_dataset_samples < len(self.data):
            random.seed(SEED)
            ix = random.sample(range(len(self.data)), self.max_dataset_samples)
            self.data = [self.data[i] for i in ix]
            logging.debug(f"MonostyleDataset, load_data: reduced data to {len(self.data)} samples.")

        # Shuffle the data
        random.shuffle(self.data)
        logging.debug("MonostyleDataset, load_data: data has been shuffled.")

    def reduce_data(self, n_samples):
        if n_samples < len(self.data):
            self.data = self.data[:n_samples]
            logging.debug(f"MonostyleDataset, reduce_data: reduced data to {n_samples} samples.")
        else:
            logging.debug(
                f"MonostyleDataset, reduce_data: requested {n_samples}, but dataset has {len(self.data)}."
            )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


In [5]:
# @title GeneratorModel Class
class GeneratorModel(nn.Module):
    def __init__(
        self,
        model_name_or_path: str,
        new_style_tokens: List[str] = None,
        pretrained_path: str = None,
        max_seq_length: int = 64,
        truncation: str = "longest_first",
        padding: str = "max_length",
        tokenizer = None
    ):
        super(GeneratorModel, self).__init__()

        self.model_name_or_path = model_name_or_path
        self.max_seq_length = max_seq_length
        self.truncation = truncation
        self.padding = padding

        # If no style tokens are provided, use default ones
        if new_style_tokens is None:
            new_style_tokens = [
                '[pos->neu]', '[pos->neg]',
                '[neu->pos]', '[neg->pos]',
                '[neu->neg]', '[neg->neu]'
            ]

        if pretrained_path is None:
            self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
            if tokenizer:
                self.tokenizer = tokenizer
            else:
                self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        else:
            self.model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_path)
            if tokenizer:
                self.tokenizer = tokenizer
            else:
                self.tokenizer = AutoTokenizer.from_pretrained(f"{pretrained_path}tokenizer/")

        num_added_tokens = self.tokenizer.add_tokens(new_style_tokens)
        print(f"Added {num_added_tokens} new tokens to the tokenizer.")

        # Resizing embeddings to include the new tokens
        self.model.resize_token_embeddings(len(self.tokenizer))
        print(f"New embedding size: {len(self.tokenizer)} tokens.")

    def train(self):
        # Setting the model in training mode
        self.model.train()

    def eval(self):
        # Setting the model in evaluation mode
        self.model.eval()

    def forward(
        self,
        sentences: List[str],
        target_sentences: List[str] = None,
        device=None,
    ):

        inputs = self.tokenizer(
            sentences,
            truncation=self.truncation,
            padding=self.padding,
            max_length=self.max_seq_length,
            return_tensors="pt"
        )

        if target_sentences is not None:
            target = self.tokenizer(
                target_sentences,
                truncation=self.truncation,
                padding=self.padding,
                max_length=self.max_seq_length,
                return_tensors="pt"
            )
            labels = target["input_ids"]
            inputs = inputs.to(device)
            labels = labels.to(device)
            output_supervised = self.model(**inputs, labels=labels)

        inputs = inputs.to(device)
        output = self.model.generate(**inputs, max_length=self.max_seq_length)
        transferred_sentences = self.tokenizer.batch_decode(output, skip_special_tokens=True)

        if target_sentences is not None:
            return output, transferred_sentences, output_supervised.loss
        else:
            return output, transferred_sentences

    def transfer(
        self,
        sentences: List[str],
        device=None
    ):
        inputs = self.tokenizer(
            sentences,
            truncation=self.truncation,
            padding=self.padding,
            max_length=self.max_seq_length,
            return_tensors="pt"
        )

        inputs = inputs.to(device)
        output = self.model.generate(**inputs, max_length=self.max_seq_length)
        transferred_sentences = self.tokenizer.batch_decode(output, skip_special_tokens=True)
        return transferred_sentences

    def save_model(
        self,
        path: Union[str]
    ):
        self.model.save_pretrained(path)
        self.tokenizer.save_pretrained(f"{path}/tokenizer/")


In [6]:
# @title DiscriminatorModel Class
class DiscriminatorModel(nn.Module):
    def __init__(
        self,
        model_name_or_path: str,
        pretrained_path: str = None,
        max_seq_length: int = 64,
        truncation: str = "longest_first",
        padding: str = "max_length",
        tokenizer = None
    ):
        super(DiscriminatorModel, self).__init__()

        self.model_name_or_path = model_name_or_path
        self.max_seq_length = max_seq_length
        self.truncation = truncation
        self.padding = padding

        if pretrained_path is None:
            self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path)
            if tokenizer:
                self.tokenizer = tokenizer
            else:
                self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        else:
            self.model = AutoModelForSequenceClassification.from_pretrained(pretrained_path)
            if tokenizer:
                self.tokenizer = tokenizer
            else:
                self.tokenizer = AutoTokenizer.from_pretrained(f"{pretrained_path}tokenizer/")

    def train(self):
        # Set the model in training mode
        self.model.train()

    def eval(self):
        # Set the model in evaluation mode
        self.model.eval()

    def forward(
        self,
        sentences: List[str],
        target_labels: Tensor,
        return_hidden: bool = False,
        device=None,
    ):
        inputs = self.tokenizer(
            sentences,
            truncation=self.truncation,
            padding=self.padding,
            max_length=self.max_seq_length,
            return_tensors="pt"
        )
        inputs["labels"] = target_labels
        inputs = inputs.to(device)
        output = self.model(**inputs, output_hidden_states=return_hidden)
        return output, output.loss

    def save_model(
        self,
        path: Union[str]
    ):
        self.model.save_pretrained(path)
        self.tokenizer.save_pretrained(f"{path}/tokenizer")


In [7]:
# @title ClassifierModel Class
class ClassifierModel(nn.Module):
    def __init__(
        self,
        pretrained_path: str = None,
        max_seq_length: int = 64,
        truncation: str = "longest_first",
        padding: str = "max_length",
    ):
        super(ClassifierModel, self).__init__()

        self.max_seq_length = max_seq_length
        self.truncation = truncation
        self.padding = padding

        self.model = AutoModelForSequenceClassification.from_pretrained(pretrained_path)
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
        self.model.eval()

    def eval(self):
        # Set the model in evaluation mode
        self.model.eval()

    def forward(
        self,
        sentences: List[str],
        target_labels: Tensor,
        return_hidden: bool = False,
        device=None,
    ):
        inputs = self.tokenizer(
            sentences,
            truncation=self.truncation,
            padding=self.padding,
            max_length=self.max_seq_length,
            return_tensors="pt"
        )
        inputs["labels"] = target_labels
        inputs = inputs.to(device)
        output = self.model(**inputs, output_hidden_states=return_hidden)
        return output, output.loss


In [8]:
# @title CycleGANModel Class
class CycleGANModel(nn.Module):
    def __init__(
        self,
        G_ab: Union['GeneratorModel', None],
        G_ba: Union['GeneratorModel', None],
        D_ab: Union['DiscriminatorModel', None],
        D_ba: Union['DiscriminatorModel', None],
        Cls: Union['ClassifierModel', None],
        device=None,
        label2id: Dict[str, int] = None
    ):
        """
        Initialization method for the CycleGANModel

        Args:
            G_ab (GeneratorModel): Generator model for mapping A->B
            G_ba (GeneratorModel): Generator model for mapping B->A
            D_ab (DiscriminatorModel): Discriminator model for B
            D_ba (DiscriminatorModel): Discriminator model for A
            Cls (ClassifierModel): Style classifier
            label2id (Dict[str,int]): Style-to-integer mapping (e.g., {"neu": 0, "pos": 1, "neg": 2})
        """
        super(CycleGANModel, self).__init__()

        if G_ab is None or G_ba is None or D_ab is None or D_ba is None:
            logging.warning(
                "CycleGANModel: Some models are missing. Please call 'load_models' to load from a previous checkpoint."
            )

        self.G_ab = G_ab
        self.G_ba = G_ba
        self.D_ab = D_ab
        self.D_ba = D_ba
        self.Cls = Cls

        self.device = device
        logging.info(f"Device: {device}")

        # Use default label2id if none is provided
        if label2id is None:
            label2id = {"neu": 0, "pos": 1, "neg": 2}
        self.label2id = label2id

        # Move all models to device
        self.G_ab.model.to(self.device)
        self.G_ba.model.to(self.device)
        self.D_ab.model.to(self.device)
        self.D_ba.model.to(self.device)
        if self.Cls is not None:
            self.Cls.model.to(self.device)

    def train(self):
        self.G_ab.train()
        self.G_ba.train()
        self.D_ab.train()
        self.D_ba.train()

    def eval(self):
        self.G_ab.eval()
        self.G_ba.eval()
        self.D_ab.eval()
        self.D_ba.eval()

    def get_optimizer_parameters(self):
        params = list(self.G_ab.model.parameters())
        params += list(self.G_ba.model.parameters())
        params += list(self.D_ab.model.parameters())
        params += list(self.D_ba.model.parameters())
        return params

    def training_cycle(
        self,
        sentences_a: List[str],
        sentences_b: List[str],
        target_sentences_ab: List[str] = None,
        target_sentences_ba: List[str] = None,
        style_source=str,
        style_target=str,
        lambdas: List[float] = None,
        loss_logging=None,
        training_step: int = None
    ):
        # ----- Cycle A -> B -----
        token_a_b = f"[{style_source}->{style_target}]"
        token_b_a = f"[{style_target}->{style_source}]"

        label2id = self.label2id
        #print(f"Label2id: {label2id}")


        # First half
        mono_a_with_style = [f"{token_a_b} {s}" for s in sentences_a]
        _, transferred_ab = self.G_ab(mono_a_with_style, device=self.device)

        # D_ab fake
        self.D_ab.eval()
        zeros = torch.zeros(len(transferred_ab))
        ones = torch.ones(len(transferred_ab))
        labels_fake_sentences = torch.column_stack((ones, zeros))  # generator side
        _, loss_g_ab = self.D_ab(transferred_ab, labels_fake_sentences, device=self.device)

        if lambdas[4] != 0:
            labels_style_b_sentences = torch.full(
                (len(transferred_ab),),
                label2id[style_target],
                dtype=int
            )
            _, loss_g_ab_cls = self.Cls(transferred_ab, labels_style_b_sentences, device=self.device)

        # Second half
        mono_transferred_ab_with_style = [f"{token_b_a} {s}" for s in transferred_ab]
        _, _, cycle_loss_aba = self.G_ba(mono_transferred_ab_with_style, sentences_a, device=self.device)

        complete_loss_g_ab = lambdas[0] * cycle_loss_aba + lambdas[1] * loss_g_ab

        loss_logging['Cycle Loss A-B-A'].append((lambdas[0] * cycle_loss_aba).item())
        loss_logging['Loss generator  A-B'].append((lambdas[1] * loss_g_ab).item())

        if lambdas[4] != 0:
            complete_loss_g_ab += lambdas[4] * loss_g_ab_cls

            loss_logging['Classifier-guided A-B'].append((lambdas[4] * loss_g_ab_cls).item())

        complete_loss_g_ab.backward()

        # D_ab training
        self.D_ab.train()
        zeros = torch.zeros(len(transferred_ab))
        ones = torch.ones(len(transferred_ab))
        labels_fake_sentences = torch.column_stack((zeros, ones))  # discriminator side
        _, loss_d_ab_fake = self.D_ab(transferred_ab, labels_fake_sentences, device=self.device)

        zeros = torch.zeros(len(transferred_ab))
        ones = torch.ones(len(transferred_ab))
        labels_real_sentences = torch.column_stack((ones, zeros))
        _, loss_d_ab_real = self.D_ab(sentences_b, labels_real_sentences, device=self.device)
        complete_loss_d_ab = lambdas[2] * loss_d_ab_fake + lambdas[3] * loss_d_ab_real


        loss_logging['Loss D(A->B)'].append(complete_loss_d_ab.item())
        complete_loss_d_ab.backward()

        # ----- Cycle B -> A -----
        mono_b_with_style = [f"{token_b_a} {s}" for s in sentences_b]

        # First half
        _, transferred_ba = self.G_ba(mono_b_with_style, device=self.device)

        # D_ba
        self.D_ba.eval()
        zeros = torch.zeros(len(transferred_ba))
        ones = torch.ones(len(transferred_ba))
        labels_fake_sentences = torch.column_stack((ones, zeros))
        _, loss_g_ba = self.D_ba(transferred_ba, labels_fake_sentences, device=self.device)

        if lambdas[4] != 0:
            labels_style_a_sentences = torch.full(
                (len(transferred_ba),),
                label2id[style_source],
                dtype=int
            )
            _, loss_g_ba_cls = self.Cls(transferred_ba, labels_style_a_sentences, device=self.device)

        # Second half
        mono_transferred_ba_with_style = [f"{token_a_b} {s}" for s in transferred_ba]
        _, _, cycle_loss_bab = self.G_ab(mono_transferred_ba_with_style, sentences_b, device=self.device)

        complete_loss_g_ba = lambdas[0] * cycle_loss_bab + lambdas[1] * loss_g_ba

        loss_logging['Cycle Loss B-A-B'].append((lambdas[0] * cycle_loss_bab).item())
        loss_logging['Loss generator  B-A'].append((lambdas[1] * loss_g_ba).item())

        if lambdas[4] != 0:
            complete_loss_g_ba += lambdas[4] * loss_g_ba_cls
            loss_logging['Classifier-guided B-A'].append((lambdas[4] * loss_g_ba_cls).item())

        complete_loss_g_ba.backward()

        # D_ba training
        self.D_ba.train()
        zeros = torch.zeros(len(transferred_ba))
        ones = torch.ones(len(transferred_ba))
        labels_fake_sentences = torch.column_stack((zeros, ones))
        _, loss_d_ba_fake = self.D_ba(transferred_ba, labels_fake_sentences, device=self.device)

        zeros = torch.zeros(len(transferred_ba))
        ones = torch.ones(len(transferred_ba))
        labels_real_sentences = torch.column_stack((ones, zeros))
        _, loss_d_ba_real = self.D_ba(sentences_a, labels_real_sentences, device=self.device)
        complete_loss_d_ba = lambdas[2] * loss_d_ba_fake + lambdas[3] * loss_d_ba_real

        loss_logging['Loss D(B->A)'].append(complete_loss_d_ba.item())
        complete_loss_d_ba.backward()

    def save_models(self, base_path: Union[str]):
        self.G_ab.save_model(base_path + "/G_ab/")
        self.G_ba.save_model(base_path + "/G_ba/")
        self.D_ab.save_model(base_path + "/D_ab/")
        self.D_ba.save_model(base_path + "/D_ba/")

    def transfer(self, sentences: List[str], direction: str):
        if direction == "AB":
            transferred_sentences = self.G_ab.transfer(sentences, device=self.device)
        else:
            transferred_sentences = self.G_ba.transfer(sentences, device=self.device)
        return transferred_sentences


In [9]:
# @title Evaluator Class
class Evaluator():
    def __init__(self, cycleGAN, args, experiment=None, label2id=None):
        """
        Class for evaluation
        """
        super(Evaluator, self).__init__()

        self.cycleGAN = cycleGAN
        self.args = args
        self.experiment = experiment

        # If label2id is not provided, use a default mapping
        if label2id is None:
            label2id = {"neu": 0, "pos": 1, "neg": 2}
        self.label2id = label2id

        self.bleu = evaluate.load('sacrebleu')
        self.rouge = evaluate.load('rouge')
        # if args.bertscore: self.bertscore = evaluate.load('bertscore')


    def __compute_metric__(self, predictions, references, metric_name, direction=None):
        # predictions = list | references = list of lists
        scores = []
        if metric_name in ['bleu', 'rouge']:
            for pred, ref in zip(predictions, references):
                if metric_name == 'bleu':
                    res = self.bleu.compute(predictions=[pred], references=[ref])
                    scores.append(res['score'])
                elif metric_name == 'rouge':
                    tmp_rouge1, tmp_rouge2, tmp_rougeL = [], [], []
                    for r in ref:
                        res = self.rouge.compute(predictions=[pred], references=[r], use_aggregator=False)
                        tmp_rouge1.append(res['rouge1'][0])
                        tmp_rouge2.append(res['rouge2'][0])
                        tmp_rougeL.append(res['rougeL'][0])
                    scores.append([max(tmp_rouge1), max(tmp_rouge2), max(tmp_rougeL)])
        else:
            raise Exception(f"Metric {metric_name} is not supported.")
        return scores

    def __compute_classif_metrics__(self, pred_A, pred_B, style_A, style_B):
        # Using self.label2id
        label2id = self.label2id

        device = self.cycleGAN.device
        truncation, padding = 'longest_first', 'max_length'

        # If certain conditions are met, load an external classifier instead of using self.cycleGAN.Cls
        if ('lambdas' not in vars(self.args)
            or self.args.lambdas[4] == 0
            or self.args.pretrained_classifier_eval != self.args.pretrained_classifier_model):
            classifier = AutoModelForSequenceClassification.from_pretrained(self.args.pretrained_classifier_eval)
            classifier_tokenizer = AutoTokenizer.from_pretrained(f'{self.args.pretrained_classifier_eval}tokenizer/')
            classifier.to(device)
        else:
            classifier = self.cycleGAN.Cls.model
            classifier_tokenizer = self.cycleGAN.Cls.tokenizer
        classifier.eval()

        y_pred, y_true = [], np.concatenate([
            np.full(len(pred_A), label2id[style_A]),
            np.full(len(pred_B), label2id[style_B])
        ])

        for i in range(0, len(pred_A), self.args.batch_size):
            batch_a = pred_A[i:i+self.args.batch_size]
            inputs = classifier_tokenizer(
                batch_a,
                truncation=truncation,
                padding=padding,
                max_length=self.args.max_sequence_length,
                return_tensors="pt"
            )
            inputs = inputs.to(device)
            with torch.no_grad():
                output = classifier(**inputs)
            y_pred.extend(np.argmax(output.logits.cpu().numpy(), axis=1))

        for i in range(0, len(pred_B), self.args.batch_size):
            batch_b = pred_B[i:i+self.args.batch_size]
            inputs = classifier_tokenizer(
                batch_b,
                truncation=truncation,
                padding=padding,
                max_length=self.args.max_sequence_length,
                return_tensors="pt"
            )
            inputs = inputs.to(device)
            with torch.no_grad():
                output = classifier(**inputs)
            y_pred.extend(np.argmax(output.logits.cpu().numpy(), axis=1))

        acc = accuracy_score(y_true, y_pred)
        prec = precision_score(y_true, y_pred, average='macro', zero_division=0)
        rec = recall_score(y_true, y_pred, average='macro', zero_division=0)
        f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
        return acc, prec, rec, f1

    def run_eval_mono(self, epoch, current_training_step, phase, dl_source, dl_target, style_source, style_target):
        print(f'Start {phase}...')
        self.cycleGAN.eval()

        real_A, real_B = [], []
        pred_A, pred_B = [], []
        scores_AB_bleu_self, scores_BA_bleu_self = [], []
        scores_AB_r1_self, scores_BA_r1_self = [], []
        scores_AB_r2_self, scores_BA_r2_self = [], []
        scores_AB_rL_self, scores_BA_rL_self = [], []

        # Define the style token for style B
        style_token_A = f"[{dl_target.dataset.style}->{dl_source.dataset.style}]"
        style_token_B = f"[{dl_source.dataset.style}->{dl_target.dataset.style}]"

        for batch in dl_source:
            mono_a = list(batch)
            mono_a_with_style = [f"{style_token_B} {sentence}" for sentence in mono_a]
            with torch.no_grad():
                transferred = self.cycleGAN.transfer(sentences=mono_a_with_style, direction='AB')
            real_A.extend(mono_a)
            pred_B.extend(transferred)
            mono_a = [[s] for s in mono_a]
            scores_AB_bleu_self.extend(self.__compute_metric__(transferred, mono_a, 'bleu'))
            scores_rouge_self = np.array(self.__compute_metric__(transferred, mono_a, 'rouge'))
            scores_AB_r1_self.extend(scores_rouge_self[:, 0].tolist())
            scores_AB_r2_self.extend(scores_rouge_self[:, 1].tolist())
            scores_AB_rL_self.extend(scores_rouge_self[:, 2].tolist())

        avg_AB_bleu_self = np.mean(scores_AB_bleu_self)
        avg_AB_r1_self = np.mean(scores_AB_r1_self)
        avg_AB_r2_self = np.mean(scores_AB_r2_self)
        avg_AB_rL_self = np.mean(scores_AB_rL_self)

        for batch in dl_target:
            mono_b = list(batch)
            mono_b_with_style = [f"{style_token_A} {sentence}" for sentence in mono_b]
            with torch.no_grad():
                transferred = self.cycleGAN.transfer(sentences=mono_b_with_style, direction='BA')
            real_B.extend(mono_b)
            pred_A.extend(transferred)
            mono_b = [[s] for s in mono_b]
            scores_BA_bleu_self.extend(self.__compute_metric__(transferred, mono_b, 'bleu'))
            scores_rouge_self = np.array(self.__compute_metric__(transferred, mono_b, 'rouge'))
            scores_BA_r1_self.extend(scores_rouge_self[:, 0].tolist())
            scores_BA_r2_self.extend(scores_rouge_self[:, 1].tolist())
            scores_BA_rL_self.extend(scores_rouge_self[:, 2].tolist())

        avg_BA_bleu_self = np.mean(scores_BA_bleu_self)
        avg_BA_r1_self = np.mean(scores_BA_r1_self)
        avg_BA_r2_self = np.mean(scores_BA_r2_self)
        avg_BA_rL_self = np.mean(scores_BA_rL_self)
        avg_2dir_bleu_self = (avg_AB_bleu_self + avg_BA_bleu_self) / 2

        acc, _, _, _ = self.__compute_classif_metrics__(pred_A, pred_B, style_source, style_target)
        acc_scaled = acc * 100
        avg_acc_bleu_self = (avg_2dir_bleu_self + acc_scaled) / 2
        avg_acc_bleu_self_geom = (avg_2dir_bleu_self * acc_scaled) ** 0.5
        avg_acc_bleu_self_h = (2 * avg_2dir_bleu_self * acc_scaled) / (avg_2dir_bleu_self + acc_scaled + 1e-6)

        metrics = {
            'epoch': epoch,
            'step': current_training_step,
            'self-BLEU A->B': avg_AB_bleu_self,
            'self-BLEU B->A': avg_BA_bleu_self,
            'self-BLEU avg': avg_2dir_bleu_self,
            'self-ROUGE-1 A->B': avg_AB_r1_self,
            'self-ROUGE-1 B->A': avg_BA_r1_self,
            'self-ROUGE-2 A->B': avg_AB_r2_self,
            'self-ROUGE-2 B->A': avg_BA_r2_self,
            'self-ROUGE-L A->B': avg_AB_rL_self,
            'self-ROUGE-L B->A': avg_BA_rL_self,
            'style accuracy': acc,
            'acc-BLEU': avg_acc_bleu_self,
            'g-acc-BLEU': avg_acc_bleu_self_geom,
            'h-acc-BLEU': avg_acc_bleu_self_h
        }

        if phase[:10] == 'validation':
            base_path = f"{self.args.save_base_folder}epoch_{epoch}/"
            suffix = f'epoch{epoch}'
        else:
            if self.args.from_pretrained is not None:
                if self.args.save_base_folder is not None:
                    base_path = f"{self.args.save_base_folder}"
                else:
                    base_path = f"{self.args.from_pretrained}epoch_{epoch}/"
            else:
                base_path = f"{self.args.save_base_folder}test/epoch_{epoch}/"
            suffix = f'epoch{epoch}_test'

        os.makedirs(os.path.dirname(base_path), exist_ok=True)
        pickle.dump(metrics, open(f"{base_path}metrics_{suffix}.pickle", 'wb'))

        for m, v in metrics.items():
            if m not in ['epoch', 'step']:
                print(f'{m}: {v}')

        df_AB = pd.DataFrame()
        df_AB['A (source)'] = real_A
        df_AB['B (generated)'] = pred_B
        df_AB.to_csv(f"{base_path}{style_source}_{style_target}_{suffix}.csv", sep=',', header=True)

        df_BA = pd.DataFrame()
        df_BA['B (source)'] = real_B
        df_BA['A (generated)'] = pred_A
        df_BA.to_csv(f"{base_path}{style_target}_{style_source}_{suffix}.csv", sep=',', header=True)

        del df_AB, df_BA
        print(f'End {phase}...')

    def dummy_classif(self):
        pred_A = [
            'wake up or you are going to lose your business .',
            'this place has none of them .',
            'it is april and there are no grass tees yet .',
            'there is no grass on the range .',
            'bottom line , this place sucks .',
            'someone should buy this place .',
            'very disappointed in the customer service .',
            'we will not be back .'
        ]
        pred_B = [
            'huge sandwich !',
            'i added mushrooms , it was very flavorful .',
            'he enjoyed it as well .',
            'fast and friendly service .',
            'will definitely be back .',
            "my dad 's favorite .",
            'huge burgers , fish sandwiches , salads .',
            'decent service .'
        ]
        acc, _, _, _ = self.__compute_classif_metrics__(pred_A, pred_B, 'neg', 'pos')
        print('Dummy classification metrics computation end')


In [10]:
# @title Main function (training pipeline)
def main(args):
    # List of required attributes
    required_attrs = [
        "epochs", "style_a", "style_b", "style_c",
        "path_mono_A", "path_mono_B", "path_mono_C",
        "path_mono_A_eval", "path_mono_B_eval", "path_mono_C_eval",
        "batch_size", "max_samples_train", "max_samples_eval",
        "nonparal_same_size", "generator_model_tag", "discriminator_model_tag",
        "pretrained_classifier_model", "pretrained_classifier_eval",
        "from_pretrained", "save_base_folder", "save_steps",
        "lambdas", "learning_rate", "max_sequence_length",
        "lr_scheduler_type", "warmup_ratio", "use_cuda_if_available"
    ]

    # Check for missing attributes
    missing_attrs = [attr for attr in required_attrs if not hasattr(args, attr)]
    if missing_attrs:
        raise AttributeError(f"Args object is missing: {', '.join(missing_attrs)}")

    # Seeding
    SEED = 42
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)

    # Print paths for debugging
    print(f"Loading dataset A from: {args.path_mono_A}")
    print(f"Loading dataset B from: {args.path_mono_B}")
    print(f"Loading dataset C from: {args.path_mono_C}")

    # ----- Load datasets -----
    mono_ds_a = MonostyleDataset(
        dataset_format="line_file",
        dataset_path=args.path_mono_A,
        style=args.style_a,
        separator='\n',
        max_dataset_samples=args.max_samples_train,
    )
    mono_ds_b = MonostyleDataset(
        dataset_format="line_file",
        style=args.style_b,
        dataset_path=args.path_mono_B,
        separator='\n',
        max_dataset_samples=args.max_samples_train,
    )
    mono_ds_c = MonostyleDataset(
        dataset_format="line_file",
        style=args.style_c,
        dataset_path=args.path_mono_C,
        separator='\n',
        max_dataset_samples=args.max_samples_train,
    )

    # Parse lambdas
    lambdas = [float(l) for l in args.lambdas.split('|')]

    # Print all args for clarity
    hyper_params = {}
    print("\nArguments summary:")
    for key, value in vars(args).items():
        hyper_params[key] = value
        print(f"  {key}:\t{value}")

    # If specified, reduce all datasets to the same size
    if args.nonparal_same_size:
        min_len = min(len(mono_ds_a), len(mono_ds_b), len(mono_ds_c))
        mono_ds_a.reduce_data(min_len)
        mono_ds_b.reduce_data(min_len)
        mono_ds_c.reduce_data(min_len)

    # Create eval datasets
    mono_ds_a_eval = MonostyleDataset(
        dataset_format="line_file",
        style=args.style_a,
        dataset_path=args.path_mono_A_eval,
        separator='\n',
        max_dataset_samples=args.max_samples_eval
    )
    mono_ds_b_eval = MonostyleDataset(
        dataset_format="line_file",
        style=args.style_b,
        dataset_path=args.path_mono_B_eval,
        separator='\n',
        max_dataset_samples=args.max_samples_eval
    )
    mono_ds_c_eval = MonostyleDataset(
        dataset_format="line_file",
        style=args.style_c,
        dataset_path=args.path_mono_C_eval,
        separator='\n',
        max_dataset_samples=args.max_samples_eval
    )

    # Dataloaders
    mono_dl_a = DataLoader(mono_ds_a, batch_size=args.batch_size, shuffle=True)
    mono_dl_b = DataLoader(mono_ds_b, batch_size=args.batch_size, shuffle=True)
    mono_dl_c = DataLoader(mono_ds_c, batch_size=args.batch_size, shuffle=True)

    mono_dl_a_eval = DataLoader(mono_ds_a_eval, batch_size=args.batch_size, shuffle=False)
    mono_dl_b_eval = DataLoader(mono_ds_b_eval, batch_size=args.batch_size, shuffle=False)
    mono_dl_c_eval = DataLoader(mono_ds_c_eval, batch_size=args.batch_size, shuffle=False)

    # Optional: free memory
    del mono_ds_a, mono_ds_b, mono_ds_c
    del mono_ds_a_eval, mono_ds_b_eval, mono_ds_c_eval

    '''
    ----- ----- ----- ----- ----- ----- ----- -----
              Load Tokenizers to add New Tokens
    ----- ----- ----- ----- ----- ----- ----- -----
    '''

    gen_tokenizer_AB = AutoTokenizer.from_pretrained("facebook/bart-base")  # For Trump → ALL
    gen_tokenizer_BA = AutoTokenizer.from_pretrained("facebook/bart-base")  # For ALL → Trump

    disc_tokenizer_AB = AutoTokenizer.from_pretrained("distilbert-base-cased")  # Discriminator for Trump → ALL
    disc_tokenizer_BA = AutoTokenizer.from_pretrained("distilbert-base-cased")  # Discriminator for ALL → Trump

    def extract_new_tokens(dataset_paths, tokenizer, max_word_length=20):
        existing_vocab = set(tokenizer.get_vocab().keys())
        new_tokens = set()

        # Regex pattern to match only valid words (no punctuation, no digits)
        valid_word_pattern = re.compile(r"^[a-z]+$")

        for path in dataset_paths:
            with open(path, 'r', encoding='utf-8') as file:
                for line in file:
                    words = line.strip().split()
                    for word in words:
                        word = word.lower().strip()  # Convert to lowercase and trim spaces
                        if (
                            word not in existing_vocab and  # Not already in vocab
                            valid_word_pattern.match(word) and  # No symbols or digits
                            len(word) > 1 and len(word) <= max_word_length and  # Length constraint
                            "_" not in word # Remove placeholders
                        ):
                            new_tokens.add(word)
        return list(new_tokens)

    # Extract words separately for each style
    trump_words = extract_new_tokens([args.path_mono_A], gen_tokenizer_AB)  # Words in Trump dataset, etc
    shakespeare_words = extract_new_tokens([args.path_mono_B], gen_tokenizer_BA)
    lyrics_words = extract_new_tokens([args.path_mono_C], gen_tokenizer_BA)

    print(f"Trump-specific tokens: {len(trump_words)}, Shakespeare-specific tokens: {len(shakespeare_words)}, Lyrics-specific tokens: {len(lyrics_words)}")

     # Update tokenizers accordingly
    gen_tokenizer_AB.add_tokens(shakespeare_words)  # G_AB must generate Shakespeare and Trump
    gen_tokenizer_AB.add_tokens(trump_words)
    gen_tokenizer_BA.add_tokens(lyrics_words)       # G_BA must generate lyrics

    disc_tokenizer_AB.add_tokens(shakespeare_words)  # D_AB must classify Trump and Lyrics
    disc_tokenizer_AB.add_tokens(trump_words)
    disc_tokenizer_BA.add_tokens(lyrics_words)       # D_BA must classify lyrics style

    print(f"{len(shakespeare_words+trump_words)} new tokens added to G_AB & D_AB (Trump → Shakespeare and Lyrics)")
    print(f"{len(lyrics_words)} new tokens added to G_BA & D_BA (Shakespeare/Lyrics → Trump)")

    # ----- Instantiate G, D, Cls -----
    if args.from_pretrained:
        G_ab = GeneratorModel(
            model_name_or_path=args.generator_model_tag,
            new_style_tokens=args.style_token_list,
            pretrained_path=f"{args.from_pretrained}G_ab/",
            max_seq_length=args.max_sequence_length, tokenizer=gen_tokenizer_AB
        )
        G_ba = GeneratorModel(
            model_name_or_path=args.generator_model_tag,
            new_style_tokens=args.style_token_list,
            pretrained_path=f"{args.from_pretrained}G_ba/",
            max_seq_length=args.max_sequence_length, tokenizer=gen_tokenizer_BA
        )
        D_ab = DiscriminatorModel(
            args.discriminator_model_tag,
            f"{args.from_pretrained}D_ab/",
            max_seq_length=args.max_sequence_length, tokenizer=disc_tokenizer_AB
        )
        D_ba = DiscriminatorModel(
            args.discriminator_model_tag,
            f"{args.from_pretrained}D_ba/",
            max_seq_length=args.max_sequence_length, tokenizer=disc_tokenizer_BA
        )
        print("[INFO] Loaded pretrained G_ab, G_ba, D_ab, D_ba")
    else:
        G_ab = GeneratorModel(
            model_name_or_path=args.generator_model_tag,
            new_style_tokens=args.style_token_list,
            max_seq_length=args.max_sequence_length, tokenizer=gen_tokenizer_AB
        )
        G_ba = GeneratorModel(
            model_name_or_path=args.generator_model_tag,
            new_style_tokens=args.style_token_list,
            max_seq_length=args.max_sequence_length, tokenizer=gen_tokenizer_BA
        )
        D_ab = DiscriminatorModel(
            args.discriminator_model_tag,
            max_seq_length=args.max_sequence_length, tokenizer=disc_tokenizer_AB
        )
        #tokenizer=disc_tokenizer_AB
        D_ba = DiscriminatorModel(
            args.discriminator_model_tag,
            max_seq_length=args.max_sequence_length,tokenizer=disc_tokenizer_BA
        )
        #tokenizer=disc_tokenizer_BA
        print("[INFO] Using fresh G_ab, G_ba, D_ab, D_ba")

    '''
    ----- ----- ----- ----- ----- ----- ----- -----
             Resize Token Embeddings
    ----- ----- ----- ----- ----- ----- ----- -----
    '''
    G_ab.model.resize_token_embeddings(len(gen_tokenizer_AB))
    G_ba.model.resize_token_embeddings(len(gen_tokenizer_BA))

    D_ab.model.resize_token_embeddings(len(disc_tokenizer_AB))
    D_ba.model.resize_token_embeddings(len(disc_tokenizer_BA))


    # If we need the classifier
    if lambdas[4] != 0 and args.pretrained_classifier_model:
        Cls = ClassifierModel(args.pretrained_classifier_model, max_seq_length=args.max_sequence_length)
        print("[INFO] Loaded pretrained classifier")
    else:
        Cls = None

    # Device
    if args.use_cuda_if_available and torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    print(f"Device: {device}")

    # Create CycleGAN
    cycleGAN = CycleGANModel(
        G_ab=G_ab,
        G_ba=G_ba,
        D_ab=D_ab,
        D_ba=D_ba,
        Cls=Cls,
        device=device,
        label2id=args.label2id
    )

    # Calculate total training steps
    n_batch_ab = min(len(mono_dl_a), len(mono_dl_b))
    n_batch_ac = min(len(mono_dl_a), len(mono_dl_c))
    steps_per_epoch = n_batch_ab + n_batch_ac
    total_training_steps = args.epochs * steps_per_epoch

    # Optimizer
    optimizer = AdamW(cycleGAN.get_optimizer_parameters(), lr=args.learning_rate)

    # Scheduler
    scheduler = CosineAnnealingWarmRestarts(
        optimizer,
        T_0=int(steps_per_epoch / 2),
        T_mult=1,
        eta_min=0
    )

    current_training_step = 0
    start_epoch = 0

    # Resume checkpoint if available
    if args.from_pretrained and os.path.exists(f"{args.from_pretrained}checkpoint.pth"):
        ckpt = torch.load(f"{args.from_pretrained}checkpoint.pth", map_location="cpu")
        optimizer.load_state_dict(ckpt["optimizer"])
        scheduler.load_state_dict(ckpt["lr_scheduler"])
        current_training_step = ckpt["training_step"]
        del ckpt

    # Evaluator
    evaluator = Evaluator(cycleGAN, args, label2id=args.label2id)

    # Training subphase function
    def train_subphase(dataloader_a, dataloader_x, style_src, style_tgt, loss_log):
        nonlocal current_training_step
        n_batch = min(len(dataloader_a), len(dataloader_x))
        progress_bar = tqdm(range(n_batch), desc=f"{style_src}->{style_tgt}")

        cycleGAN.train()
        for batch_a, batch_x in zip(dataloader_a, dataloader_x):
            # Ensure batch_a and batch_x have the same size
            len_a, len_x = len(batch_a), len(batch_x)
            if len_a > len_x:
                batch_a = batch_a[:len_x]
            elif len_x > len_a:
                batch_x = batch_x[:len_a]

            cycleGAN.training_cycle(
                sentences_a=batch_a,
                sentences_b=batch_x,
                style_source=style_src,
                style_target=style_tgt,
                lambdas=lambdas,
                loss_logging=loss_log,
                training_step=current_training_step
            )

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            current_training_step += 1
            progress_bar.update(1)

        progress_bar.close()

    # Loss logging
    loss_logging = {
        'Cycle Loss A-B-A': [],
        'Loss generator  A-B': [],
        'Classifier-guided A-B': [],
        'Loss D(A->B)': [],
        'Cycle Loss B-A-B': [],
        'Loss generator  B-A': [],
        'Classifier-guided B-A': [],
        'Loss D(B->A)': []
    }
    loss_logging['hyper_params'] = hyper_params

    # ----- Training loop -----
    for epoch_idx in range(start_epoch, args.epochs):
        print(f"\n=== EPOCH {epoch_idx} ===")

        # (1) A->B
        train_subphase(mono_dl_a, mono_dl_b, style_src=args.style_a, style_tgt=args.style_b, loss_log=loss_logging)
        # (2) A->C
        train_subphase(mono_dl_a, mono_dl_c, style_src=args.style_a, style_tgt=args.style_c, loss_log=loss_logging)

        # (3) End-of-epoch evaluation
        evaluator.run_eval_mono(
            epoch_idx,
            current_training_step,
            phase="validation_AB_epoch",
            dl_source=mono_dl_a_eval,
            dl_target=mono_dl_b_eval,
            style_source=args.style_a,
            style_target=args.style_b
        )
        evaluator.run_eval_mono(
            epoch_idx,
            current_training_step,
            phase="validation_AC_epoch",
            dl_source=mono_dl_a_eval,
            dl_target=mono_dl_c_eval,
            style_source=args.style_a,
            style_target=args.style_c
        )

        # (4) Checkpoint saving
        if epoch_idx % args.save_steps == 0:
            cycleGAN.save_models(f"{args.save_base_folder}epoch_{epoch_idx}/")

            checkpoint = {
                'epoch': epoch_idx + 1,
                'training_step': current_training_step,
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': scheduler.state_dict()
            }
            torch.save(checkpoint, f"{args.save_base_folder}/checkpoint.pth")

            # Remove old loss file if needed
            if epoch_idx > 0:
                prev_loss_file = f"{args.save_base_folder}loss.pickle"
                if os.path.exists(prev_loss_file):
                    os.remove(prev_loss_file)

            # Save training loss
            pickle.dump(loss_logging, open(f"{args.save_base_folder}loss.pickle", "wb"))

        cycleGAN.train()

    print("\n=== Training completed ===")


In [11]:
# @title Args Class to configure the GAN
class Args:
    def __init__(
        self,
        epochs,
        style_a,
        style_b,
        style_c,
        path_mono_A,
        path_mono_B,
        path_mono_C,
        path_mono_A_eval,
        path_mono_B_eval,
        path_mono_C_eval,
        generator_model_tag,
        discriminator_model_tag,
        label2id,
        style_token_list,
        batch_size=8,
        max_samples_train=None,
        max_samples_eval=None,
        nonparal_same_size=False,
        pretrained_classifier_model=None,
        pretrained_classifier_eval=None,
        from_pretrained=None,
        save_base_folder="./checkpoints/",
        save_steps=1,
        lambdas="10|1|1|1|1|1",
        learning_rate=5e-5,
        max_sequence_length=32,
        lr_scheduler_type="cosine_with_restarts",
        warmup_ratio=0.0,
        use_cuda_if_available=False
    ):
        """
        Class to store all training/testing arguments.
        """
        self.epochs = epochs
        self.style_a = style_a
        self.style_b = style_b
        self.style_c = style_c
        self.path_mono_A = path_mono_A
        self.path_mono_B = path_mono_B
        self.path_mono_C = path_mono_C
        self.path_mono_A_eval = path_mono_A_eval
        self.path_mono_B_eval = path_mono_B_eval
        self.path_mono_C_eval = path_mono_C_eval
        self.generator_model_tag = generator_model_tag
        self.discriminator_model_tag = discriminator_model_tag
        self.batch_size = batch_size
        self.max_samples_train = max_samples_train
        self.max_samples_eval = max_samples_eval
        self.nonparal_same_size = nonparal_same_size
        self.pretrained_classifier_model = pretrained_classifier_model
        self.pretrained_classifier_eval = pretrained_classifier_eval
        self.from_pretrained = from_pretrained
        self.save_base_folder = save_base_folder
        self.save_steps = save_steps
        self.lambdas = lambdas
        self.learning_rate = learning_rate
        self.max_sequence_length = max_sequence_length
        self.lr_scheduler_type = lr_scheduler_type
        self.warmup_ratio = warmup_ratio
        self.use_cuda_if_available = use_cuda_if_available
        self.label2id = label2id
        self.style_token_list = style_token_list


### Training and Validation

In [None]:
# Create the label2id map based on the defined styles
label2id = {
    "tru": 0,
    "lyr": 1,
    "sha": 2
}

# list of style tokens
style_token_list = [
    "[tru->lyr]", "[tru->sha]",
    "[lyr->tru]", "[lyr->sha]",
    "[sha->tru]", "[sha->lyr]"
]

# Define args directly in the notebook
args = Args(
      epochs=30,
    style_a="lyr",
    style_b="tru",
    style_c="sha",
    path_mono_A="/content/drive/MyDrive/ProjectNLP/03.Ultima_Estensione_Shakespeare/Data_spellchecked/train_lyrics_spellchecked.txt",
    path_mono_B="/content/drive/MyDrive/ProjectNLP/03.Ultima_Estensione_Shakespeare/Data_spellchecked/train_trump_spellchecked.txt",
    path_mono_C="/content/drive/MyDrive/ProjectNLP/03.Ultima_Estensione_Shakespeare/Data_spellchecked/train_shakespeare_spellchecked.txt",
    path_mono_A_eval="/content/drive/MyDrive/ProjectNLP/03.Ultima_Estensione_Shakespeare/Data_spellchecked/eval_lyrics_spellchecked.txt",
    path_mono_B_eval="/content/drive/MyDrive/ProjectNLP/03.Ultima_Estensione_Shakespeare/Data_spellchecked/eval_trump_spellchecked.txt",
    path_mono_C_eval="/content/drive/MyDrive/ProjectNLP/03.Ultima_Estensione_Shakespeare/Data_spellchecked/eval_shakespeare_spellchecked.txt",
    generator_model_tag="facebook/bart-base",
    discriminator_model_tag="distilbert/distilbert-base-cased",
    pretrained_classifier_model="/content/drive/MyDrive/ProjectNLP/03.Ultima_Estensione_Shakespeare/TernaryClassifier/DistilBertCheckpoint/checkpoint-11635",
    pretrained_classifier_eval="/content/drive/MyDrive/ProjectNLP/03.Ultima_Estensione_Shakespeare/TernaryClassifier/DistilBertCheckpoint/checkpoint-11635",
    lambdas="10|1|1|1|1",
    learning_rate=5e-5,
    max_sequence_length=64,
    save_base_folder="/content/drive/MyDrive/ProjectNLP/03.Ultima_Estensione_Shakespeare/Checkpoint_GAN_Ternaria_Tokenizer_Shakespeare/",
    save_steps=1,
    use_cuda_if_available=1,
    label2id=label2id,           # Explicitly pass the label mapping
    style_token_list=style_token_list  # Explicitly pass the style tokens
)

# Call the main function with the updated arguments
main(args)



Loading dataset A from: /content/drive/MyDrive/ProjectNLP/03.Ultima_Estensione_Shakespeare/Data_spellchecked/train_lyrics_spellchecked.txt
Loading dataset B from: /content/drive/MyDrive/ProjectNLP/03.Ultima_Estensione_Shakespeare/Data_spellchecked/train_trump_spellchecked.txt
Loading dataset C from: /content/drive/MyDrive/ProjectNLP/03.Ultima_Estensione_Shakespeare/Data_spellchecked/train_shakespeare_spellchecked.txt

Arguments summary:
  epochs:	30
  style_a:	lyr
  style_b:	tru
  style_c:	sha
  path_mono_A:	/content/drive/MyDrive/ProjectNLP/03.Ultima_Estensione_Shakespeare/Data_spellchecked/train_lyrics_spellchecked.txt
  path_mono_B:	/content/drive/MyDrive/ProjectNLP/03.Ultima_Estensione_Shakespeare/Data_spellchecked/train_trump_spellchecked.txt
  path_mono_C:	/content/drive/MyDrive/ProjectNLP/03.Ultima_Estensione_Shakespeare/Data_spellchecked/train_shakespeare_spellchecked.txt
  path_mono_A_eval:	/content/drive/MyDrive/ProjectNLP/03.Ultima_Estensione_Shakespeare/Data_spellchecked/ev

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/1.72k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/465 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]

Trump-specific tokens: 2902, Shakespeare-specific tokens: 3819, Lyrics-specific tokens: 3467
6721 new tokens added to G_AB & D_AB (Trump → Shakespeare and Lyrics)
3467 new tokens added to G_BA & D_BA (Shakespeare/Lyrics → Trump)


model.safetensors:   0%|          | 0.00/558M [00:00<?, ?B/s]

Added 6 new tokens to the tokenizer.


The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


New embedding size: 55760 tokens.
Added 6 new tokens to the tokenizer.
New embedding size: 53738 tokens.


config.json:   0%|          | 0.00/465 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/263M [00:00<?, ?B/s]

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert/distilbert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert/distilbert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[INFO] Using fresh G_ab, G_ba, D_ab, D_ba
[INFO] Loaded pretrained classifier
Device: cuda


Downloading builder script:   0%|          | 0.00/8.15k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/6.27k [00:00<?, ?B/s]


=== EPOCH 0 ===


lyr->tru: 100%|██████████| 717/717 [34:07<00:00,  2.86s/it]
lyr->sha: 100%|██████████| 717/717 [30:28<00:00,  2.55s/it]


Start validation_AB_epoch...
self-BLEU A->B: 64.93104133142941
self-BLEU B->A: 73.06517641189879
self-BLEU avg: 68.9981088716641
self-ROUGE-1 A->B: 0.8808362740669065
self-ROUGE-1 B->A: 0.913437634426684
self-ROUGE-2 A->B: 0.7838126647245209
self-ROUGE-2 B->A: 0.8416429121626757
self-ROUGE-L A->B: 0.8808362740669065
self-ROUGE-L B->A: 0.9131877125272775
style accuracy: 0.2584745762711864
acc-BLEU: 47.42278324939137
g-acc-BLEU: 42.23062508904713
h-acc-BLEU: 37.606938151001
End validation_AB_epoch...
Start validation_AC_epoch...
self-BLEU A->B: 64.91705811545476
self-BLEU B->A: 62.312928876387694
self-BLEU avg: 63.614993495921226
self-ROUGE-1 A->B: 0.8805390398715041
self-ROUGE-1 B->A: 0.8527408730109001
self-ROUGE-2 A->B: 0.7834489081107644
self-ROUGE-2 B->A: 0.7301715995049621
self-ROUGE-L A->B: 0.8805390398715041
self-ROUGE-L B->A: 0.8525999286839092
style accuracy: 0.5784615384615385
acc-BLEU: 60.73057367103753
g-acc-BLEU: 60.66203673375248
h-acc-BLEU: 60.593576644333375
End validati




=== EPOCH 1 ===


lyr->tru: 100%|██████████| 717/717 [34:46<00:00,  2.91s/it]
lyr->sha: 100%|██████████| 717/717 [30:49<00:00,  2.58s/it]


Start validation_AB_epoch...
self-BLEU A->B: 64.20291823862179
self-BLEU B->A: 72.22680600071361
self-BLEU avg: 68.2148621196677
self-ROUGE-1 A->B: 0.8768353316483118
self-ROUGE-1 B->A: 0.8994493157790309
self-ROUGE-2 A->B: 0.7790843840747816
self-ROUGE-2 B->A: 0.8286015313717856
self-ROUGE-L A->B: 0.8768353316483118
self-ROUGE-L B->A: 0.8989919343128985
style accuracy: 0.23163841807909605
acc-BLEU: 45.68935196378865
g-acc-BLEU: 39.750701567247184
h-acc-BLEU: 34.58394986754691
End validation_AB_epoch...
Start validation_AC_epoch...
self-BLEU A->B: 64.20449003740025
self-BLEU B->A: 60.023843559392105
self-BLEU avg: 62.11416679839618
self-ROUGE-1 A->B: 0.8766799953910016
self-ROUGE-1 B->A: 0.8249698809340412
self-ROUGE-2 A->B: 0.7791711896303373
self-ROUGE-2 B->A: 0.7020592483057313
self-ROUGE-L A->B: 0.8766799953910016
self-ROUGE-L B->A: 0.82449197627009
style accuracy: 0.58
acc-BLEU: 60.057083399198085
g-acc-BLEU: 60.02184330980662
h-acc-BLEU: 59.98662339906054
End validation_AC_epoch.

lyr->tru: 100%|██████████| 717/717 [36:10<00:00,  3.03s/it]
lyr->sha: 100%|██████████| 717/717 [30:44<00:00,  2.57s/it]


Start validation_AB_epoch...
self-BLEU A->B: 63.388996594143144
self-BLEU B->A: 73.14201819277297
self-BLEU avg: 68.26550739345805
self-ROUGE-1 A->B: 0.8667658984834571
self-ROUGE-1 B->A: 0.9018313009235644
self-ROUGE-2 A->B: 0.7683229728662295
self-ROUGE-2 B->A: 0.8324096139373304
self-ROUGE-L A->B: 0.8667658984834571
self-ROUGE-L B->A: 0.9014813077735833
style accuracy: 0.2175141242937853
acc-BLEU: 45.00845991141829
g-acc-BLEU: 38.53402660008288
h-acc-BLEU: 32.99093531400143
End validation_AB_epoch...
Start validation_AC_epoch...
self-BLEU A->B: 63.388996594143144
self-BLEU B->A: 59.99626785630341
self-BLEU avg: 61.69263222522328
self-ROUGE-1 A->B: 0.8667658984834571
self-ROUGE-1 B->A: 0.8263920903845088
self-ROUGE-2 A->B: 0.7683229728662295
self-ROUGE-2 B->A: 0.7046342340948994
self-ROUGE-L A->B: 0.8667658984834571
self-ROUGE-L B->A: 0.8255356017097146
style accuracy: 0.5092307692307693
acc-BLEU: 56.3078545741501
g-acc-BLEU: 56.04978730015075
h-acc-BLEU: 55.79290229143146
End valida

lyr->tru: 100%|██████████| 717/717 [35:33<00:00,  2.98s/it]
lyr->sha: 100%|██████████| 717/717 [30:22<00:00,  2.54s/it]


Start validation_AB_epoch...
self-BLEU A->B: 63.7595669258787
self-BLEU B->A: 73.56509248995015
self-BLEU avg: 68.66232970791442
self-ROUGE-1 A->B: 0.8657451114059797
self-ROUGE-1 B->A: 0.90303311688328
self-ROUGE-2 A->B: 0.772438966921165
self-ROUGE-2 B->A: 0.8293485816536531
self-ROUGE-L A->B: 0.8657451114059797
self-ROUGE-L B->A: 0.9029875006073927
style accuracy: 0.22598870056497175
acc-BLEU: 45.6305998822058
g-acc-BLEU: 39.39151008587414
h-acc-BLEU: 34.00549310875656
End validation_AB_epoch...
Start validation_AC_epoch...
self-BLEU A->B: 63.7595669258787
self-BLEU B->A: 59.83287541247771
self-BLEU avg: 61.79622116917821
self-ROUGE-1 A->B: 0.8657451114059797
self-ROUGE-1 B->A: 0.8245652156372916
self-ROUGE-2 A->B: 0.772438966921165
self-ROUGE-2 B->A: 0.6983590657479993
self-ROUGE-L A->B: 0.8657451114059797
self-ROUGE-L B->A: 0.8240490348028052
style accuracy: 0.4938461538461538
acc-BLEU: 55.5904182768968
g-acc-BLEU: 55.242941763292194
h-acc-BLEU: 54.897636711648666
End validation_A

lyr->tru: 100%|██████████| 717/717 [34:52<00:00,  2.92s/it]
lyr->sha: 100%|██████████| 717/717 [30:06<00:00,  2.52s/it]


Start validation_AB_epoch...
self-BLEU A->B: 63.77654079753442
self-BLEU B->A: 71.48947289585544
self-BLEU avg: 67.63300684669494
self-ROUGE-1 A->B: 0.8658476685533568
self-ROUGE-1 B->A: 0.8783331906842424
self-ROUGE-2 A->B: 0.773595100215155
self-ROUGE-2 B->A: 0.8027034998261824
self-ROUGE-L A->B: 0.8658476685533568
self-ROUGE-L B->A: 0.8783331906842424
style accuracy: 0.2401129943502825
acc-BLEU: 45.822153140861595
g-acc-BLEU: 40.298342138198535
h-acc-BLEU: 35.44042019096972
End validation_AB_epoch...
Start validation_AC_epoch...
self-BLEU A->B: 63.81123720857833
self-BLEU B->A: 58.76962782465006
self-BLEU avg: 61.290432516614196
self-ROUGE-1 A->B: 0.8659518352200234
self-ROUGE-1 B->A: 0.8077374736368409
self-ROUGE-2 A->B: 0.7736815626657478
self-ROUGE-2 B->A: 0.6784626849088652
self-ROUGE-L A->B: 0.8659518352200234
self-ROUGE-L B->A: 0.807308305104224
style accuracy: 0.46153846153846156
acc-BLEU: 53.722139335230175
g-acc-BLEU: 53.186362848708704
h-acc-BLEU: 52.65592922676963
End val

lyr->tru: 100%|██████████| 717/717 [33:41<00:00,  2.82s/it]
lyr->sha: 100%|██████████| 717/717 [29:17<00:00,  2.45s/it]


Start validation_AB_epoch...
self-BLEU A->B: 62.273003286321526
self-BLEU B->A: 69.70345004401403
self-BLEU avg: 65.98822666516779
self-ROUGE-1 A->B: 0.8538153218937865
self-ROUGE-1 B->A: 0.8665562277692369
self-ROUGE-2 A->B: 0.7562107951544557
self-ROUGE-2 B->A: 0.7878461246679523
self-ROUGE-L A->B: 0.8538153218937865
self-ROUGE-L B->A: 0.8663805014337166
style accuracy: 0.2627118644067797
acc-BLEU: 46.129706552922876
g-acc-BLEU: 41.636390400830145
h-acc-BLEU: 37.58075037462213
End validation_AB_epoch...
Start validation_AC_epoch...
self-BLEU A->B: 62.23582307142483
self-BLEU B->A: 55.46112965988315
self-BLEU avg: 58.848476365653994
self-ROUGE-1 A->B: 0.8534547449707095
self-ROUGE-1 B->A: 0.7860875307948123
self-ROUGE-2 A->B: 0.7555259865420154
self-ROUGE-2 B->A: 0.6515139084599213
self-ROUGE-L A->B: 0.8534547449707095
self-ROUGE-L B->A: 0.785589918135223
style accuracy: 0.5184615384615384
acc-BLEU: 55.34731510590392
g-acc-BLEU: 55.23646584698776
h-acc-BLEU: 55.12583809827964
End vali

lyr->tru: 100%|██████████| 717/717 [33:50<00:00,  2.83s/it]
lyr->sha: 100%|██████████| 717/717 [29:14<00:00,  2.45s/it]


Start validation_AB_epoch...
self-BLEU A->B: 60.89116375625237
self-BLEU B->A: 69.40237417908423
self-BLEU avg: 65.1467689676683
self-ROUGE-1 A->B: 0.8465292653152178
self-ROUGE-1 B->A: 0.8616507351768528
self-ROUGE-2 A->B: 0.7446381007428208
self-ROUGE-2 B->A: 0.7773724578241972
self-ROUGE-L A->B: 0.8462167653152178
self-ROUGE-L B->A: 0.8616507351768528
style accuracy: 0.268361581920904
acc-BLEU: 45.991463579879344
g-acc-BLEU: 41.8125459368347
h-acc-BLEU: 38.01333644615948
End validation_AB_epoch...
Start validation_AC_epoch...
self-BLEU A->B: 60.76102885355554
self-BLEU B->A: 54.82513604355268
self-BLEU avg: 57.79308244855411
self-ROUGE-1 A->B: 0.8452113557283625
self-ROUGE-1 B->A: 0.7766593871498073
self-ROUGE-2 A->B: 0.7441910899590642
self-ROUGE-2 B->A: 0.6443132346213316
self-ROUGE-L A->B: 0.8448988557283623
self-ROUGE-L B->A: 0.7761590919904213
style accuracy: 0.5261538461538462
acc-BLEU: 55.204233531969365
g-acc-BLEU: 55.1434969977359
h-acc-BLEU: 55.08282678785736
End validatio

lyr->tru: 100%|██████████| 717/717 [33:20<00:00,  2.79s/it]
lyr->sha: 100%|██████████| 717/717 [28:55<00:00,  2.42s/it]


Start validation_AB_epoch...
self-BLEU A->B: 59.487004527276596
self-BLEU B->A: 63.308363355817484
self-BLEU avg: 61.397683941547044
self-ROUGE-1 A->B: 0.8340192116749222
self-ROUGE-1 B->A: 0.827277565018828
self-ROUGE-2 A->B: 0.7332054505212273
self-ROUGE-2 B->A: 0.7328247756154577
self-ROUGE-L A->B: 0.833769211674922
self-ROUGE-L B->A: 0.8271871327525954
style accuracy: 0.3107344632768362
acc-BLEU: 46.23556513461533
g-acc-BLEU: 43.67880076881398
h-acc-BLEU: 41.2634215763416
End validation_AB_epoch...
Start validation_AC_epoch...
self-BLEU A->B: 59.349586940475774
self-BLEU B->A: 52.767940962871414
self-BLEU avg: 56.0587639516736
self-ROUGE-1 A->B: 0.8332419073819928
self-ROUGE-1 B->A: 0.7563327975845695
self-ROUGE-2 A->B: 0.7323836695061898
self-ROUGE-2 B->A: 0.6204633333099094
self-ROUGE-L A->B: 0.8329919073819928
self-ROUGE-L B->A: 0.7561644474162194
style accuracy: 0.5307692307692308
acc-BLEU: 54.56784351429834
g-acc-BLEU: 54.5474720042127
h-acc-BLEU: 54.527107599683085
End valida

lyr->tru: 100%|██████████| 717/717 [33:10<00:00,  2.78s/it]
lyr->sha: 100%|██████████| 717/717 [29:07<00:00,  2.44s/it]


Start validation_AB_epoch...
self-BLEU A->B: 56.64657617435231
self-BLEU B->A: 59.18641058342876
self-BLEU avg: 57.916493378890536
self-ROUGE-1 A->B: 0.8101477574015915
self-ROUGE-1 B->A: 0.804005676293448
self-ROUGE-2 A->B: 0.7038169291522702
self-ROUGE-2 B->A: 0.6983264388027259
self-ROUGE-L A->B: 0.8098636664925006
self-ROUGE-L B->A: 0.8033576966864909
style accuracy: 0.3418079096045198
acc-BLEU: 46.048642169671254
g-acc-BLEU: 44.493050618565796
h-acc-BLEU: 42.99000879455002
End validation_AB_epoch...
Start validation_AC_epoch...
self-BLEU A->B: 56.57416025523334
self-BLEU B->A: 49.90256188816078
self-BLEU avg: 53.23836107169706
self-ROUGE-1 A->B: 0.8106881891920233
self-ROUGE-1 B->A: 0.732040055880001
self-ROUGE-2 A->B: 0.7049707579623491
self-ROUGE-2 B->A: 0.5876448404460429
self-ROUGE-L A->B: 0.8104040982829325
self-ROUGE-L B->A: 0.7306785084845809
style accuracy: 0.5646153846153846
acc-BLEU: 54.849949766617755
g-acc-BLEU: 54.82626898922537
h-acc-BLEU: 54.80259793614305
End valid

lyr->tru: 100%|██████████| 717/717 [32:48<00:00,  2.75s/it]
lyr->sha: 100%|██████████| 717/717 [28:51<00:00,  2.42s/it]


Start validation_AB_epoch...
self-BLEU A->B: 55.801143266818634
self-BLEU B->A: 55.94997450869129
self-BLEU avg: 55.87555888775496
self-ROUGE-1 A->B: 0.7980166065874006
self-ROUGE-1 B->A: 0.7869697671367597
self-ROUGE-2 A->B: 0.6945296146909861
self-ROUGE-2 B->A: 0.668982883580861
self-ROUGE-L A->B: 0.7977666065874007
self-ROUGE-L B->A: 0.7866665530676269
style accuracy: 0.3347457627118644
acc-BLEU: 44.675067579470706
g-acc-BLEU: 43.24824456187006
h-acc-BLEU: 41.866990652507205
End validation_AB_epoch...
Start validation_AC_epoch...
self-BLEU A->B: 55.691468794618366
self-BLEU B->A: 47.06792494568633
self-BLEU avg: 51.379696870152344
self-ROUGE-1 A->B: 0.7989148036418477
self-ROUGE-1 B->A: 0.7074038594162801
self-ROUGE-2 A->B: 0.6952006495739775
self-ROUGE-2 B->A: 0.5569347418475531
self-ROUGE-L A->B: 0.7986648036418477
self-ROUGE-L B->A: 0.7058393135026331
style accuracy: 0.563076923076923
acc-BLEU: 53.843694588922325
g-acc-BLEU: 53.787286250814326
h-acc-BLEU: 53.73093650889227
End va

lyr->tru: 100%|██████████| 717/717 [33:11<00:00,  2.78s/it]
lyr->sha: 100%|██████████| 717/717 [28:39<00:00,  2.40s/it]


Start validation_AB_epoch...
self-BLEU A->B: 54.202459840719165
self-BLEU B->A: 53.98718333589026
self-BLEU avg: 54.09482158830471
self-ROUGE-1 A->B: 0.7882462935586242
self-ROUGE-1 B->A: 0.7678225427495026
self-ROUGE-2 A->B: 0.6798726580378742
self-ROUGE-2 B->A: 0.6509900170626335
self-ROUGE-L A->B: 0.7882462935586242
self-ROUGE-L B->A: 0.7673988366888124
style accuracy: 0.3531073446327684
acc-BLEU: 44.70277802579078
g-acc-BLEU: 43.70500979227625
h-acc-BLEU: 42.72951131753333
End validation_AB_epoch...
Start validation_AC_epoch...
self-BLEU A->B: 54.37245549462007
self-BLEU B->A: 45.04364977234448
self-BLEU avg: 49.708052633482275
self-ROUGE-1 A->B: 0.7900392738626338
self-ROUGE-1 B->A: 0.6845739266676745
self-ROUGE-2 A->B: 0.6807542417319578
self-ROUGE-2 B->A: 0.5416260238085896
self-ROUGE-L A->B: 0.7900392738626338
self-ROUGE-L B->A: 0.6834405540331998
style accuracy: 0.5815384615384616
acc-BLEU: 53.930949393664214
g-acc-BLEU: 53.76536473841515
h-acc-BLEU: 53.60028798221511
End vali

lyr->tru: 100%|██████████| 717/717 [32:14<00:00,  2.70s/it]
lyr->sha: 100%|██████████| 717/717 [28:46<00:00,  2.41s/it]


Start validation_AB_epoch...
self-BLEU A->B: 53.257764772383645
self-BLEU B->A: 49.13168144524062
self-BLEU avg: 51.194723108812134
self-ROUGE-1 A->B: 0.7783368118462143
self-ROUGE-1 B->A: 0.7375963508200611
self-ROUGE-2 A->B: 0.6682014234603938
self-ROUGE-2 B->A: 0.6146282761950931
self-ROUGE-L A->B: 0.7783368118462143
self-ROUGE-L B->A: 0.7373470296944512
style accuracy: 0.3912429378531073
acc-BLEU: 45.159508447061434
g-acc-BLEU: 44.7544119296277
h-acc-BLEU: 44.35294877795912
End validation_AB_epoch...
Start validation_AC_epoch...
self-BLEU A->B: 53.492653007729565
self-BLEU B->A: 41.60053132223724
self-BLEU avg: 47.5465921649834
self-ROUGE-1 A->B: 0.7809788031255137
self-ROUGE-1 B->A: 0.6551016422187295
self-ROUGE-2 A->B: 0.6699398048706502
self-ROUGE-2 B->A: 0.5068524082059308
self-ROUGE-L A->B: 0.7809788031255137
self-ROUGE-L B->A: 0.6532184437122971
style accuracy: 0.6215384615384615
acc-BLEU: 54.85021915941478
g-acc-BLEU: 54.36178413703919
h-acc-BLEU: 53.87769808237766
End valid

lyr->tru: 100%|██████████| 717/717 [32:12<00:00,  2.69s/it]
lyr->sha: 100%|██████████| 717/717 [28:41<00:00,  2.40s/it]


Start validation_AB_epoch...
self-BLEU A->B: 51.53721373621668
self-BLEU B->A: 45.07136807296184
self-BLEU avg: 48.30429090458926
self-ROUGE-1 A->B: 0.7609685246713267
self-ROUGE-1 B->A: 0.7087103259630676
self-ROUGE-2 A->B: 0.6486815757561526
self-ROUGE-2 B->A: 0.5816883524634232
self-ROUGE-L A->B: 0.7609685246713267
self-ROUGE-L B->A: 0.7074422349163387
style accuracy: 0.4138418079096045
acc-BLEU: 44.844235847774854
g-acc-BLEU: 44.71055253264791
h-acc-BLEU: 44.57726723835402
End validation_AB_epoch...
Start validation_AC_epoch...
self-BLEU A->B: 51.66763140505816
self-BLEU B->A: 39.68807605507806
self-BLEU avg: 45.67785373006811
self-ROUGE-1 A->B: 0.7622191951180112
self-ROUGE-1 B->A: 0.6416506802060032
self-ROUGE-2 A->B: 0.6511115278506268
self-ROUGE-2 B->A: 0.4975191253649399
self-ROUGE-L A->B: 0.7622191951180112
self-ROUGE-L B->A: 0.6406678556353298
style accuracy: 0.6461538461538462
acc-BLEU: 55.14661917272636
g-acc-BLEU: 54.32763649537528
h-acc-BLEU: 53.520816048011355
End valid

lyr->tru: 100%|██████████| 717/717 [32:07<00:00,  2.69s/it]
lyr->sha: 100%|██████████| 717/717 [28:12<00:00,  2.36s/it]


Start validation_AB_epoch...
self-BLEU A->B: 51.895760732727275
self-BLEU B->A: 42.486452217751875
self-BLEU avg: 47.19110647523958
self-ROUGE-1 A->B: 0.7660605819076018
self-ROUGE-1 B->A: 0.6841673507114346
self-ROUGE-2 A->B: 0.6564933912525324
self-ROUGE-2 B->A: 0.5567116988062926
self-ROUGE-L A->B: 0.7660605819076018
self-ROUGE-L B->A: 0.6826296698911819
style accuracy: 0.4519774011299435
acc-BLEU: 46.194423294116966
g-acc-BLEU: 46.183669907365775
h-acc-BLEU: 46.17291852407866
End validation_AB_epoch...
Start validation_AC_epoch...
self-BLEU A->B: 51.515452400859104
self-BLEU B->A: 38.01890979012856
self-BLEU avg: 44.76718109549383
self-ROUGE-1 A->B: 0.7621768352756935
self-ROUGE-1 B->A: 0.6206328750891428
self-ROUGE-2 A->B: 0.6517093835935247
self-ROUGE-2 B->A: 0.47200868566230453
self-ROUGE-L A->B: 0.7621768352756935
self-ROUGE-L B->A: 0.6184743918857792
style accuracy: 0.6538461538461539
acc-BLEU: 55.07589824005461
g-acc-BLEU: 54.10254077011808
h-acc-BLEU: 53.14638498406439
End v

lyr->tru: 100%|██████████| 717/717 [31:46<00:00,  2.66s/it]
lyr->sha: 100%|██████████| 717/717 [28:38<00:00,  2.40s/it]


Start validation_AB_epoch...
self-BLEU A->B: 51.62334217387438
self-BLEU B->A: 40.20314277747333
self-BLEU avg: 45.913242475673854
self-ROUGE-1 A->B: 0.7585827370364803
self-ROUGE-1 B->A: 0.6559488661755737
self-ROUGE-2 A->B: 0.6484432806103023
self-ROUGE-2 B->A: 0.5251767060280509
self-ROUGE-L A->B: 0.7585827370364803
self-ROUGE-L B->A: 0.6544677115280979
style accuracy: 0.4533898305084746
acc-BLEU: 45.62611276326066
g-acc-BLEU: 45.625209286249046
h-acc-BLEU: 45.624305327147674
End validation_AB_epoch...
Start validation_AC_epoch...
self-BLEU A->B: 51.4675031007114
self-BLEU B->A: 36.2543149998738
self-BLEU avg: 43.8609090502926
self-ROUGE-1 A->B: 0.7578462757853132
self-ROUGE-1 B->A: 0.599020347449596
self-ROUGE-2 A->B: 0.6479745406514308
self-ROUGE-2 B->A: 0.45112826941849554
self-ROUGE-L A->B: 0.7578462757853132
self-ROUGE-L B->A: 0.5974876274505244
style accuracy: 0.6323076923076923
acc-BLEU: 53.545839140530916
g-acc-BLEU: 52.6626909529964
h-acc-BLEU: 51.79410831969747
End validat

lyr->tru: 100%|██████████| 717/717 [32:00<00:00,  2.68s/it]
lyr->sha: 100%|██████████| 717/717 [28:02<00:00,  2.35s/it]


Start validation_AB_epoch...
self-BLEU A->B: 49.94699255884448
self-BLEU B->A: 39.27710878826423
self-BLEU avg: 44.61205067355435
self-ROUGE-1 A->B: 0.7395953795929735
self-ROUGE-1 B->A: 0.6404428680487769
self-ROUGE-2 A->B: 0.6301152641003647
self-ROUGE-2 B->A: 0.518489490115291
self-ROUGE-L A->B: 0.7393112886838826
self-ROUGE-L B->A: 0.6393259309087492
style accuracy: 0.4788135593220339
acc-BLEU: 46.24670330287887
g-acc-BLEU: 46.21780476359678
h-acc-BLEU: 46.1889237829957
End validation_AB_epoch...
Start validation_AC_epoch...
self-BLEU A->B: 49.93415026528061
self-BLEU B->A: 35.091376388175625
self-BLEU avg: 42.51276332672812
self-ROUGE-1 A->B: 0.7374681214774221
self-ROUGE-1 B->A: 0.5834927452791251
self-ROUGE-2 A->B: 0.6270048032785068
self-ROUGE-2 B->A: 0.44329559088693615
self-ROUGE-L A->B: 0.7371840305683313
self-ROUGE-L B->A: 0.5821257634974354
style accuracy: 0.6615384615384615
acc-BLEU: 54.33330474028713
g-acc-BLEU: 53.03190364951313
h-acc-BLEU: 51.761673475478446
End valida

lyr->tru: 100%|██████████| 717/717 [32:02<00:00,  2.68s/it]
lyr->sha: 100%|██████████| 717/717 [28:15<00:00,  2.36s/it]


Start validation_AB_epoch...
self-BLEU A->B: 49.48722348768943
self-BLEU B->A: 39.08454501172929
self-BLEU avg: 44.285884249709355
self-ROUGE-1 A->B: 0.7317703241341151
self-ROUGE-1 B->A: 0.6327459520405515
self-ROUGE-2 A->B: 0.6217353364244693
self-ROUGE-2 B->A: 0.509988406818698
self-ROUGE-L A->B: 0.7314862332250243
self-ROUGE-L B->A: 0.6307262875901128
style accuracy: 0.481638418079096
acc-BLEU: 46.22486302880948
g-acc-BLEU: 46.18417827921589
h-acc-BLEU: 46.14352883912845
End validation_AB_epoch...
Start validation_AC_epoch...
self-BLEU A->B: 49.41071178263177
self-BLEU B->A: 33.367357382817566
self-BLEU avg: 41.38903458272467
self-ROUGE-1 A->B: 0.7298287193907778
self-ROUGE-1 B->A: 0.5720773254795711
self-ROUGE-2 A->B: 0.6197731053939166
self-ROUGE-2 B->A: 0.4290076538572197
self-ROUGE-L A->B: 0.72912796181502
self-ROUGE-L B->A: 0.570228517358077
style accuracy: 0.6553846153846153
acc-BLEU: 53.463748060593105
g-acc-BLEU: 52.08237370852019
h-acc-BLEU: 50.73669026479992
End validatio

lyr->tru: 100%|██████████| 717/717 [32:12<00:00,  2.69s/it]
lyr->sha: 100%|██████████| 717/717 [27:57<00:00,  2.34s/it]


Start validation_AB_epoch...
self-BLEU A->B: 45.85371684532421
self-BLEU B->A: 36.394200205375675
self-BLEU avg: 41.12395852534994
self-ROUGE-1 A->B: 0.6959385156141544
self-ROUGE-1 B->A: 0.6125547616238232
self-ROUGE-2 A->B: 0.5845227830121964
self-ROUGE-2 B->A: 0.4892597864064899
self-ROUGE-L A->B: 0.6950010156141545
self-ROUGE-L B->A: 0.6108876764915218
style accuracy: 0.501412429378531
acc-BLEU: 45.63260073160152
g-acc-BLEU: 45.40932057392806
h-acc-BLEU: 45.18713243017548
End validation_AB_epoch...
Start validation_AC_epoch...
self-BLEU A->B: 46.03385376729942
self-BLEU B->A: 31.535978040881496
self-BLEU avg: 38.784915904090454
self-ROUGE-1 A->B: 0.6966615060292405
self-ROUGE-1 B->A: 0.5351015128207824
self-ROUGE-2 A->B: 0.585296134660548
self-ROUGE-2 B->A: 0.3933515499774254
self-ROUGE-L A->B: 0.6962448393625738
self-ROUGE-L B->A: 0.5335259979349227
style accuracy: 0.6923076923076923
acc-BLEU: 54.00784256742984
g-acc-BLEU: 51.81804282864104
h-acc-BLEU: 49.71703015872095
End valida

lyr->tru: 100%|██████████| 717/717 [31:54<00:00,  2.67s/it]
lyr->sha: 100%|██████████| 717/717 [27:50<00:00,  2.33s/it]


Start validation_AB_epoch...
self-BLEU A->B: 46.524100852401055
self-BLEU B->A: 34.74837965870938
self-BLEU avg: 40.63624025555522
self-ROUGE-1 A->B: 0.7066790360409534
self-ROUGE-1 B->A: 0.5825547994554414
self-ROUGE-2 A->B: 0.5919129607622422
self-ROUGE-2 B->A: 0.4621827011786293
self-ROUGE-L A->B: 0.7066790360409534
self-ROUGE-L B->A: 0.5802755065147185
style accuracy: 0.5112994350282486
acc-BLEU: 45.88309187919004
g-acc-BLEU: 45.58210908277233
h-acc-BLEU: 45.2831001725615
End validation_AB_epoch...
Start validation_AC_epoch...
self-BLEU A->B: 46.253358574027175
self-BLEU B->A: 28.552025076432432
self-BLEU avg: 37.4026918252298
self-ROUGE-1 A->B: 0.7056750438762123
self-ROUGE-1 B->A: 0.508858313749714
self-ROUGE-2 A->B: 0.5891593688680252
self-ROUGE-2 B->A: 0.36840383643065105
self-ROUGE-L A->B: 0.7056750438762123
self-ROUGE-L B->A: 0.5070490709948987
style accuracy: 0.7246153846153847
acc-BLEU: 54.93211514338414
g-acc-BLEU: 52.060124781438624
h-acc-BLEU: 49.338288913789086
End vali

lyr->tru: 100%|██████████| 717/717 [33:10<00:00,  2.78s/it]
lyr->sha: 100%|██████████| 717/717 [27:49<00:00,  2.33s/it]


Start validation_AB_epoch...
self-BLEU A->B: 45.67524025886218
self-BLEU B->A: 30.60878462176866
self-BLEU avg: 38.14201244031542
self-ROUGE-1 A->B: 0.6924370051143107
self-ROUGE-1 B->A: 0.5422084177798391
self-ROUGE-2 A->B: 0.577489839007458
self-ROUGE-2 B->A: 0.4208850338847633
self-ROUGE-L A->B: 0.6920269522042577
self-ROUGE-L B->A: 0.5395418417693674
style accuracy: 0.5268361581920904
acc-BLEU: 45.412814129762225
g-acc-BLEU: 44.826991087703725
h-acc-BLEU: 44.24872464654787
End validation_AB_epoch...
Start validation_AC_epoch...
self-BLEU A->B: 45.469933966806025
self-BLEU B->A: 26.750230685897296
self-BLEU avg: 36.11008232635166
self-ROUGE-1 A->B: 0.6898545360445536
self-ROUGE-1 B->A: 0.48198044753177477
self-ROUGE-2 A->B: 0.5750995265786761
self-ROUGE-2 B->A: 0.3439962163817303
self-ROUGE-L A->B: 0.6891603922254098
self-ROUGE-L B->A: 0.4796274739999395
style accuracy: 0.7353846153846154
acc-BLEU: 54.8242719324066
g-acc-BLEU: 51.531348714225324
h-acc-BLEU: 48.43620868076379
End val

lyr->tru: 100%|██████████| 717/717 [32:48<00:00,  2.74s/it]
lyr->sha: 100%|██████████| 717/717 [27:55<00:00,  2.34s/it]


Start validation_AB_epoch...
self-BLEU A->B: 43.87545225177326
self-BLEU B->A: 31.5057795453828
self-BLEU avg: 37.69061589857803
self-ROUGE-1 A->B: 0.6782597067827657
self-ROUGE-1 B->A: 0.5467329543324085
self-ROUGE-2 A->B: 0.5620225444766361
self-ROUGE-2 B->A: 0.42571359609979714
self-ROUGE-L A->B: 0.6774753930572756
self-ROUGE-L B->A: 0.5450551021639336
style accuracy: 0.5127118644067796
acc-BLEU: 44.480901169628
g-acc-BLEU: 43.95955635353905
h-acc-BLEU: 43.44432154619443
End validation_AB_epoch...
Start validation_AC_epoch...
self-BLEU A->B: 43.78113105225062
self-BLEU B->A: 25.069527556540706
self-BLEU avg: 34.42532930439566
self-ROUGE-1 A->B: 0.6752468491481977
self-ROUGE-1 B->A: 0.47090039839504055
self-ROUGE-2 A->B: 0.5579348719797144
self-ROUGE-2 B->A: 0.3273575021594348
self-ROUGE-L A->B: 0.6744625354227074
self-ROUGE-L B->A: 0.4680044250158782
style accuracy: 0.7353846153846154
acc-BLEU: 53.981895421428604
g-acc-BLEU: 50.314866143120895
h-acc-BLEU: 46.8969403869203
End valida

lyr->tru: 100%|██████████| 717/717 [32:58<00:00,  2.76s/it]
lyr->sha: 100%|██████████| 717/717 [28:05<00:00,  2.35s/it]


Start validation_AB_epoch...
self-BLEU A->B: 43.19261411428304
self-BLEU B->A: 30.1105946477635
self-BLEU avg: 36.65160438102327
self-ROUGE-1 A->B: 0.6699032494481625
self-ROUGE-1 B->A: 0.5277535405145598
self-ROUGE-2 A->B: 0.5567585895311838
self-ROUGE-2 B->A: 0.40645347075485033
self-ROUGE-L A->B: 0.6699032494481625
self-ROUGE-L B->A: 0.5258545582326073
style accuracy: 0.5324858757062146
acc-BLEU: 44.950095975822364
g-acc-BLEU: 44.17743955331376
h-acc-BLEU: 43.41806399762539
End validation_AB_epoch...
Start validation_AC_epoch...
self-BLEU A->B: 43.1869082249939
self-BLEU B->A: 23.60241537258346
self-BLEU avg: 33.39466179878868
self-ROUGE-1 A->B: 0.6677739897267718
self-ROUGE-1 B->A: 0.4451393421540703
self-ROUGE-2 A->B: 0.5525721798680866
self-ROUGE-2 B->A: 0.3084987421401981
self-ROUGE-L A->B: 0.6675425082452904
self-ROUGE-L B->A: 0.4438409211639043
style accuracy: 0.7292307692307692
acc-BLEU: 53.1588693609328
g-acc-BLEU: 49.34816603657329
h-acc-BLEU: 45.810633249803125
End validat

lyr->tru:  53%|█████▎    | 377/717 [17:32<15:03,  2.66s/it]