In [None]:
# @title Start installing and importing the required libraries

# Print the runtime's CUDA / PyTorch setup for diagnostics. This only reports what
# is available; it does not enforce a GPU runtime.
import torch
print(f"Is CUDA available: {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda}")
print(f"PyTorch version: {torch.__version__}")

# Install libraries
# NOTE(review): these installs are unpinned, so re-running the notebook later may
# resolve different versions; consider pinning (e.g. `%pip install -q pkg==X.Y.Z`).
# NOTE(review): `torch` is already imported above; if this install changes the
# installed torch build, the running kernel keeps the already-imported version
# until the runtime is restarted.
!pip install fsspec
!pip install torch
!pip install scikit-learn
!pip install tokenizers
!pip install transformers
!pip install bert-score
!pip install rouge-score
!pip install sacrebleu
!pip install evaluate
!pip install tabulate




Is CUDA available: True
CUDA version: 12.4
PyTorch version: 2.5.1+cu124
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cura

In [None]:
# @title Reorganized imports
import argparse
import csv
import logging
import os
import pickle
import random
import sys
import time
from typing import List, Optional, Tuple, Union, Dict

# Third-party libraries
import evaluate
import numpy as np
import pandas as pd
import rouge_score
import sacrebleu
import sklearn
import tokenizers
import transformers
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from torch import nn, Tensor
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DebertaForSequenceClassification,
    DebertaTokenizer,
    get_scheduler
)


In [None]:
# @title Mount Google Drive on Colab for persistent storage
from google.colab import drive

# Mount Google Drive so checkpoints and metrics written under this path survive
# the (ephemeral) Colab VM.
_DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(_DRIVE_MOUNT_POINT)


Mounted at /content/drive


# Class Definition

In [None]:
# @title MonostyleDataset Class, used to represent a single-style text dataset
# Configure root logging at DEBUG so the dataset's load/sample/shuffle messages are shown.
# NOTE(review): DEBUG on the root logger also surfaces debug output from third-party libraries.
logging.basicConfig(level="DEBUG")

class MonostyleDataset(Dataset):
    """
    Mono-style dataset:
    Loads textual data from CSV files, line-based files, or a provided list of sentences.

    Args:
        dataset_format: one of "list", "csv", "line_file".
        dataset_path: path of the CSV / line file (unused for "list").
        sentences_list: sentences used when dataset_format == "list". The list is
            copied, so the caller's list is never reordered by the shuffle.
        text_column_name: CSV column holding the text. A string is treated as a header
            name (the first CSV row is read as the header); an int or None selects a
            column by position with no header row (None -> first column).
        separator: CSV delimiter, or the record separator for "line_file"
            (a newline when None for "line_file").
        style: style label attached to this dataset (read by callers).
        max_dataset_samples: if set and smaller than the dataset, keep a random subset
            of this size.
        SEED: seed for the subset sampling and the final shuffle.

    Raises:
        ValueError: unsupported dataset_format, or "list" format without sentences_list.
    """

    def __init__(
        self,
        dataset_format: str,
        dataset_path: str = None,
        sentences_list: List[str] = None,
        text_column_name: Union[str, int] = None,
        separator: str = None,
        style: str = None,
        max_dataset_samples: int = None,
        SEED: int = 42
    ):
        super().__init__()

        self.allowed_dataset_formats = ["list", "csv", "line_file"]
        if dataset_format not in self.allowed_dataset_formats:
            raise ValueError(
                f"MonostyleDataset: '{dataset_format}' is not supported. "
                f"Allowed formats: {self.allowed_dataset_formats}."
            )

        self.dataset_format = dataset_format
        self.dataset_path = dataset_path
        self.sentences_list = sentences_list
        self.text_column_name = text_column_name
        self.separator = separator
        self.style = style
        self.max_dataset_samples = max_dataset_samples

        # Load data based on the format
        self.load_data(SEED)

    def _load_data_csv(self):
        """Read the text column of a CSV file into ``self.data`` (rows with missing text dropped)."""
        # A string column name only works if the first row is a header: with header=None
        # pandas labels the columns 0..n-1 and a name lookup raises KeyError.
        header = 0 if isinstance(self.text_column_name, str) else None
        try:
            df = pd.read_csv(self.dataset_path, sep=self.separator, header=header, encoding='utf-8')
            if self.text_column_name is not None:
                column = df[self.text_column_name]
            else:
                column = df.iloc[:, 0]
            # Only rows whose *text* is missing are dropped; NaNs in other columns are irrelevant.
            self.data = column.dropna().tolist()
            logging.debug(
                f"MonostyleDataset, _load_data_csv: parsed {len(self.data)} examples from '{self.dataset_path}'."
            )
        except UnicodeDecodeError as e:
            logging.error(
                f"MonostyleDataset, _load_data_csv: UnicodeDecodeError while reading '{self.dataset_path}': {e}"
            )
            raise
        except FileNotFoundError:
            logging.error(
                f"MonostyleDataset, _load_data_csv: File not found: '{self.dataset_path}'."
            )
            raise
        except Exception as e:
            logging.error(
                f"MonostyleDataset, _load_data_csv: Error loading CSV dataset: {e}"
            )
            raise

    def _load_data_line_file(self):
        """Split a UTF-8 text file on ``self.separator`` into ``self.data``, skipping blank records."""
        # str.split(None) would split on any whitespace (i.e. into words), so a line file
        # without an explicit separator is split on newlines.
        separator = self.separator if self.separator is not None else '\n'
        try:
            with open(self.dataset_path, 'r', encoding='utf-8') as f:
                records = f.read().split(separator)
            # Drop empty / whitespace-only records, e.g. the '' produced by a trailing newline.
            self.data = [record for record in records if record.strip()]
            logging.debug(
                f"MonostyleDataset, _load_data_line_file: parsed {len(self.data)} examples from '{self.dataset_path}'."
            )
        except UnicodeDecodeError as e:
            logging.error(
                f"MonostyleDataset, _load_data_line_file: UnicodeDecodeError while reading '{self.dataset_path}': {e}"
            )
            raise
        except FileNotFoundError:
            logging.error(
                f"MonostyleDataset, _load_data_line_file: File not found: '{self.dataset_path}'."
            )
            raise
        except Exception as e:
            logging.error(
                f"MonostyleDataset, _load_data_line_file: Error loading line_file dataset: {e}"
            )
            raise

    def load_data(self, SEED=42):
        """Load ``self.data`` for the configured format, optionally subsample, then shuffle.

        Sampling and shuffling use a private ``random.Random(SEED)``, so the result depends
        only on SEED and the global ``random`` state is neither read nor reseeded.
        """
        if self.dataset_format == "csv":
            self._load_data_csv()
        elif self.dataset_format == "line_file":
            self._load_data_line_file()
        elif self.dataset_format == "list":
            if self.sentences_list is None:
                raise ValueError(
                    "MonostyleDataset: 'list' format specified but 'sentences_list' is None."
                )
            # Copy so the in-place shuffle below does not reorder the caller's list.
            self.data = list(self.sentences_list)
            logging.debug(
                f"MonostyleDataset, load_data: data already loaded, {len(self.data)} examples."
            )
        else:
            raise ValueError(
                f"MonostyleDataset, load_data: '{self.dataset_format}' format is not supported."
            )

        rng = random.Random(SEED)

        # Limit the number of samples if needed
        if self.max_dataset_samples is not None and self.max_dataset_samples < len(self.data):
            ix = rng.sample(range(len(self.data)), self.max_dataset_samples)
            self.data = [self.data[i] for i in ix]
            logging.debug(f"MonostyleDataset, load_data: reduced data to {len(self.data)} samples.")

        # Shuffle the data
        rng.shuffle(self.data)
        logging.debug("MonostyleDataset, load_data: data has been shuffled.")

    def reduce_data(self, n_samples):
        """Keep only the first ``n_samples`` items (no-op if the dataset is not larger)."""
        if n_samples < len(self.data):
            self.data = self.data[:n_samples]
            logging.debug(f"MonostyleDataset, reduce_data: reduced data to {n_samples} samples.")
        else:
            logging.debug(
                f"MonostyleDataset, reduce_data: requested {n_samples}, but dataset has {len(self.data)}."
            )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


In [None]:
# @title GeneratorModel Class, it represents the Generator of the GAN
class GeneratorModel(nn.Module):
    """Seq2seq generator used for one transfer direction of the text CycleGAN.

    Wraps a Hugging Face ``AutoModelForSeq2SeqLM`` plus tokenizer, adds style-transfer
    control tokens (e.g. ``[pos->neg]``) to the vocabulary and resizes the embeddings.
    """

    def __init__(
        self,
        model_name_or_path: str,
        new_style_tokens: List[str] = None,
        pretrained_path: str = None,
        max_seq_length: int = 64,
        truncation: str = "longest_first",
        padding: str = "max_length",
    ):
        """
        Args:
            model_name_or_path: hub id / path of the base model (used when pretrained_path is None).
            new_style_tokens: control tokens to add; defaults to the six pos/neu/neg transfer tokens.
            pretrained_path: directory written by ``save_model``; the tokenizer is read from
                its ``tokenizer`` sub-directory (with or without a trailing slash).
            max_seq_length: max tokens for inputs, labels and generated outputs.
            truncation, padding: forwarded to the tokenizer.
        """
        super().__init__()

        self.model_name_or_path = model_name_or_path
        self.max_seq_length = max_seq_length
        self.truncation = truncation
        self.padding = padding

        # If no style tokens are provided, use default ones
        if new_style_tokens is None:
            new_style_tokens = [
                '[pos->neu]', '[pos->neg]',
                '[neu->pos]', '[neg->pos]',
                '[neu->neg]', '[neg->neu]'
            ]

        if pretrained_path is None:
            self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
            self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        else:
            self.model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_path)
            # save_model writes the tokenizer to "<path>/tokenizer/"; join so that paths
            # with or without a trailing slash resolve to that same directory.
            self.tokenizer = AutoTokenizer.from_pretrained(os.path.join(pretrained_path, "tokenizer"))

        num_added_tokens = self.tokenizer.add_tokens(new_style_tokens)
        print(f"Added {num_added_tokens} new tokens to the tokenizer.")

        # Resizing embeddings to include the new tokens
        self.model.resize_token_embeddings(len(self.tokenizer))
        print(f"New embedding size: {len(self.tokenizer)} tokens.")

    def train(self, mode: bool = True):
        """Set train (``mode=True``) or eval mode on the wrapped model; returns ``self``.

        Keeps ``nn.Module.train``'s signature, so ``train(False)`` works as in PyTorch.
        """
        return super().train(mode)

    def eval(self):
        """Set the wrapped model in evaluation mode; returns ``self``."""
        return self.train(False)

    def _tokenize(self, sentences: List[str]):
        """Tokenize a batch with this generator's truncation/padding/length settings."""
        return self.tokenizer(
            sentences,
            truncation=self.truncation,
            padding=self.padding,
            max_length=self.max_seq_length,
            return_tensors="pt"
        )

    @staticmethod
    def _to_device(batch, device):
        """Move ``batch`` to ``device``; leave it where it is when no device is given."""
        return batch if device is None else batch.to(device)

    def forward(
        self,
        sentences: List[str],
        target_sentences: List[str] = None,
        device=None,
    ):
        """Generate transfers for ``sentences``; optionally compute a supervised loss.

        Returns:
            ``(generated_ids, decoded_sentences)``, or
            ``(generated_ids, decoded_sentences, loss)`` when ``target_sentences`` is given,
            where ``loss`` is the model's loss for predicting ``target_sentences``.
        """
        inputs = self._to_device(self._tokenize(sentences), device)

        supervised_loss = None
        if target_sentences is not None:
            labels = self._tokenize(target_sentences)["input_ids"]
            # Labels are padded to max_length; HF seq2seq losses ignore index -100, so
            # mask pad positions or the loss is dominated by predicting padding.
            # NOTE(review): assumes pad_token_id differs from eos_token_id (true for T5/BART).
            pad_id = self.tokenizer.pad_token_id
            if pad_id is not None:
                labels = labels.masked_fill(labels == pad_id, -100)
            labels = self._to_device(labels, device)
            supervised_loss = self.model(**inputs, labels=labels).loss

        output = self.model.generate(**inputs, max_length=self.max_seq_length)
        transferred_sentences = self.tokenizer.batch_decode(output, skip_special_tokens=True)

        if target_sentences is not None:
            return output, transferred_sentences, supervised_loss
        return output, transferred_sentences

    def transfer(
        self,
        sentences: List[str],
        device=None
    ):
        """Generate and decode transfers for ``sentences`` (no loss computation)."""
        inputs = self._to_device(self._tokenize(sentences), device)
        output = self.model.generate(**inputs, max_length=self.max_seq_length)
        return self.tokenizer.batch_decode(output, skip_special_tokens=True)

    def save_model(
        self,
        path: str
    ):
        """Save the model to ``path`` and the tokenizer to ``path/tokenizer``."""
        self.model.save_pretrained(path)
        self.tokenizer.save_pretrained(os.path.join(path, "tokenizer"))


In [None]:
# @title DiscriminatorModel Class, it represents the discriminator
class DiscriminatorModel(nn.Module):
    """Real/fake text discriminator wrapping a Hugging Face sequence classifier and tokenizer."""

    def __init__(
        self,
        model_name_or_path: str,
        pretrained_path: str = None,
        max_seq_length: int = 64,
        truncation: str = "longest_first",
        padding: str = "max_length",
    ):
        """
        Args:
            model_name_or_path: hub id / path of the base classifier (used when pretrained_path is None).
            pretrained_path: directory written by ``save_model``; the tokenizer is read from
                its ``tokenizer`` sub-directory (with or without a trailing slash).
            max_seq_length: max tokens per input sentence.
            truncation, padding: forwarded to the tokenizer.
        """
        super().__init__()

        self.model_name_or_path = model_name_or_path
        self.max_seq_length = max_seq_length
        self.truncation = truncation
        self.padding = padding

        if pretrained_path is None:
            self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path)
            self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        else:
            self.model = AutoModelForSequenceClassification.from_pretrained(pretrained_path)
            # save_model writes the tokenizer to "<path>/tokenizer"; join handles both
            # trailing-slash and no-trailing-slash paths.
            self.tokenizer = AutoTokenizer.from_pretrained(os.path.join(pretrained_path, "tokenizer"))

    def train(self, mode: bool = True):
        """Set train (``mode=True``) or eval mode on the wrapped model; returns ``self``.

        Keeps ``nn.Module.train``'s signature, so ``train(False)`` works as in PyTorch.
        """
        return super().train(mode)

    def eval(self):
        """Set the wrapped model in evaluation mode; returns ``self``."""
        return self.train(False)

    def forward(
        self,
        sentences: List[str],
        target_labels: Optional[Tensor] = None,
        return_hidden: bool = False,
        device=None,
    ):
        """Score ``sentences``; when ``target_labels`` is given the model also returns a loss.

        Returns:
            ``(model_output, model_output.loss)``; the loss is None without labels.
        """
        inputs = self.tokenizer(
            sentences,
            truncation=self.truncation,
            padding=self.padding,
            max_length=self.max_seq_length,
            return_tensors="pt"
        )
        if target_labels is not None:
            inputs["labels"] = target_labels
        if device is not None:
            inputs = inputs.to(device)
        output = self.model(**inputs, output_hidden_states=return_hidden)
        return output, output.loss

    def save_model(
        self,
        path: str
    ):
        """Save the model to ``path`` and the tokenizer to ``path/tokenizer``."""
        self.model.save_pretrained(path)
        self.tokenizer.save_pretrained(os.path.join(path, "tokenizer"))


In [None]:
# @title ClassifierModel Class, it represents the classifier
class ClassifierModel(nn.Module):
    """Style classifier wrapping a pretrained Hugging Face sequence classifier.

    The wrapped model is put in eval mode at construction.
    """

    def __init__(
        self,
        pretrained_path: str = None,
        max_seq_length: int = 64,
        truncation: str = "longest_first",
        padding: str = "max_length",
    ):
        """
        Args:
            pretrained_path: directory / hub id holding both model and tokenizer (required).
            max_seq_length: max tokens per input sentence.
            truncation, padding: forwarded to the tokenizer.

        Raises:
            ValueError: if ``pretrained_path`` is None.
        """
        super().__init__()

        if pretrained_path is None:
            raise ValueError("ClassifierModel: 'pretrained_path' is required.")

        self.max_seq_length = max_seq_length
        self.truncation = truncation
        self.padding = padding

        self.model = AutoModelForSequenceClassification.from_pretrained(pretrained_path)
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
        self.model.eval()

    def eval(self):
        """Set the wrapped model in evaluation mode; returns ``self``."""
        self.model.eval()
        self.training = False
        return self

    def forward(
        self,
        sentences: List[str],
        target_labels: Optional[Tensor] = None,
        return_hidden: bool = False,
        device=None,
    ):
        """Classify ``sentences``; when ``target_labels`` is given the model also returns a loss.

        Returns:
            ``(model_output, model_output.loss)``; the loss is None without labels.
        """
        inputs = self.tokenizer(
            sentences,
            truncation=self.truncation,
            padding=self.padding,
            max_length=self.max_seq_length,
            return_tensors="pt"
        )
        if target_labels is not None:
            inputs["labels"] = target_labels
        if device is not None:
            inputs = inputs.to(device)
        output = self.model(**inputs, output_hidden_states=return_hidden)
        return output, output.loss


In [None]:
# @title CycleGANModel Class, it represents the gan
class CycleGANModel(nn.Module):
    """Text-style-transfer CycleGAN.

    Holds two generators (G_ab: A->B, G_ba: B->A), one discriminator per target style
    (D_ab judges B-style text, D_ba judges A-style text) and an optional style classifier
    (Cls) for a classifier-guided generator loss. The transfer direction is selected by
    prefixing each sentence with a control token such as ``[neg->pos]``.
    """

    def __init__(
        self,
        G_ab: Union['GeneratorModel', None],
        G_ba: Union['GeneratorModel', None],
        D_ab: Union['DiscriminatorModel', None],
        D_ba: Union['DiscriminatorModel', None],
        Cls: Union['ClassifierModel', None],
        device=None,
        label2id: Dict[str, int] = None
    ):
        """
        Initialization method for the CycleGANModel

        Args:
            G_ab (GeneratorModel): Generator model for mapping A->B
            G_ba (GeneratorModel): Generator model for mapping B->A
            D_ab (DiscriminatorModel): Discriminator model for B
            D_ba (DiscriminatorModel): Discriminator model for A
            Cls (ClassifierModel): Style classifier (may be None when the classifier-guided
                loss weight is 0)
            device: device every available sub-model is moved to.
            label2id (Dict[str,int]): Style-to-integer mapping; defaults to
                {"neu": 0, "pos": 1, "neg": 2}
        """
        super().__init__()

        if G_ab is None or G_ba is None or D_ab is None or D_ba is None:
            logging.warning(
                "CycleGANModel: Some models are missing. Please call 'load_models' to load from a previous checkpoint."
            )

        self.G_ab = G_ab
        self.G_ba = G_ba
        self.D_ab = D_ab
        self.D_ba = D_ba
        self.Cls = Cls

        self.device = device
        logging.info(f"Device: {device}")

        # Use default label2id if none is provided
        if label2id is None:
            label2id = {"neu": 0, "pos": 1, "neg": 2}
        self.label2id = label2id

        # Move the sub-models that are present; missing ones are expected to be loaded
        # later (see the warning above), so skipping them avoids an AttributeError here.
        for sub_model in (self.G_ab, self.G_ba, self.D_ab, self.D_ba, self.Cls):
            if sub_model is not None:
                sub_model.model.to(self.device)

    def _adversarial_models(self):
        """Generators and discriminators in a fixed order (the classifier is excluded)."""
        return (self.G_ab, self.G_ba, self.D_ab, self.D_ba)

    def train(self, mode: bool = True):
        """Put generators and discriminators in train (or eval) mode; Cls is left untouched.

        Keeps ``nn.Module.train``'s signature and returns ``self``.
        """
        self.training = mode
        for sub_model in self._adversarial_models():
            if sub_model is None:
                continue
            if mode:
                sub_model.train()
            else:
                sub_model.eval()
        return self

    def eval(self):
        """Put generators and discriminators in evaluation mode; returns ``self``."""
        return self.train(False)

    def get_optimizer_parameters(self):
        """Parameters of G_ab, G_ba, D_ab, D_ba (in that order); Cls is not included."""
        params = []
        for sub_model in self._adversarial_models():
            if sub_model is not None:
                params += list(sub_model.model.parameters())
        return params

    @staticmethod
    def _real_fake_labels(n: int, real: bool) -> Tensor:
        """Two-column float targets for ``n`` sentences: (1, 0) = real, (0, 1) = fake."""
        ones = torch.ones(n)
        zeros = torch.zeros(n)
        return torch.column_stack((ones, zeros) if real else (zeros, ones))

    def training_cycle(
        self,
        sentences_a: List[str],
        sentences_b: List[str],
        target_sentences_ab: List[str] = None,
        target_sentences_ba: List[str] = None,
        style_source: str = None,
        style_target: str = None,
        lambdas: List[float] = None,
        loss_logging=None,
        training_step: int = None
    ):
        """Run both cycle directions on one batch and accumulate gradients via ``backward()``.

        No optimizer step is taken here.

        Args:
            sentences_a, sentences_b: mono-style batches of styles A and B.
            target_sentences_ab, target_sentences_ba: accepted for API compatibility; unused.
            style_source, style_target: style names of A and B (keys of ``label2id``).
            lambdas: weights [cycle, generator-adversarial, D-fake, D-real, classifier-guided].
            loss_logging: mapping from loss name to a list that losses are appended to;
                a throw-away ``defaultdict(list)`` is used when None.
            training_step: accepted for API compatibility; unused.

        Raises:
            ValueError: missing styles, fewer than 5 lambdas, or a non-zero classifier
                weight without a classifier.
        """
        import collections

        if style_source is None or style_target is None:
            raise ValueError("CycleGANModel.training_cycle: 'style_source' and 'style_target' are required.")
        if lambdas is None or len(lambdas) < 5:
            raise ValueError("CycleGANModel.training_cycle: 'lambdas' must contain 5 weights.")
        if lambdas[4] != 0 and self.Cls is None:
            raise ValueError("CycleGANModel.training_cycle: lambdas[4] != 0 requires a classifier (Cls).")
        if loss_logging is None:
            loss_logging = collections.defaultdict(list)

        token_a_b = f"[{style_source}->{style_target}]"
        token_b_a = f"[{style_target}->{style_source}]"

        # ----- Cycle A -> B -> A, then D_ab update -----
        self._train_direction(
            G_fwd=self.G_ab, G_bwd=self.G_ba, D=self.D_ab,
            sentences_src=sentences_a, sentences_other=sentences_b,
            token_fwd=token_a_b, token_bwd=token_b_a,
            style_fwd=style_target, lambdas=lambdas, loss_logging=loss_logging,
            log_keys=('Cycle Loss A-B-A', 'Loss generator  A-B', 'Classifier-guided A-B', 'Loss D(A->B)'),
        )

        # ----- Cycle B -> A -> B, then D_ba update -----
        self._train_direction(
            G_fwd=self.G_ba, G_bwd=self.G_ab, D=self.D_ba,
            sentences_src=sentences_b, sentences_other=sentences_a,
            token_fwd=token_b_a, token_bwd=token_a_b,
            style_fwd=style_source, lambdas=lambdas, loss_logging=loss_logging,
            log_keys=('Cycle Loss B-A-B', 'Loss generator  B-A', 'Classifier-guided B-A', 'Loss D(B->A)'),
        )

    def _train_direction(self, G_fwd, G_bwd, D, sentences_src, sentences_other,
                         token_fwd, token_bwd, style_fwd, lambdas, loss_logging, log_keys):
        """One half of ``training_cycle``: generator losses + backward, then discriminator loss + backward.

        ``sentences_other`` are real sentences of the style ``G_fwd`` produces; ``D`` judges that style.
        """
        key_cycle, key_gen, key_cls, key_disc = log_keys
        use_cls = lambdas[4] != 0

        # Forward transfer into the other style.
        # NOTE(review): `transferred` are decoded strings, so no gradient flows from D / Cls
        # back into G_fwd through them. `loss_g` (and `loss_g_cls`) therefore backprop only
        # into D's (and Cls's) parameters, here nudging D towards labelling fakes as real;
        # Cls is not in get_optimizer_parameters, so its grads just accumulate. Confirm intended.
        _, transferred = G_fwd([f"{token_fwd} {s}" for s in sentences_src], device=self.device)
        n_fake = len(transferred)

        # Generator-side adversarial loss: D scores the fakes against "real" targets.
        D.eval()
        _, loss_g = D(transferred, self._real_fake_labels(n_fake, real=True), device=self.device)

        if use_cls:
            style_labels = torch.full((n_fake,), self.label2id[style_fwd], dtype=int)
            _, loss_g_cls = self.Cls(transferred, style_labels, device=self.device)

        # Back-transfer and reconstruct the source sentences (supervised cycle loss on G_bwd).
        _, _, cycle_loss = G_bwd([f"{token_bwd} {s}" for s in transferred], sentences_src, device=self.device)

        complete_loss_g = lambdas[0] * cycle_loss + lambdas[1] * loss_g
        loss_logging[key_cycle].append((lambdas[0] * cycle_loss).item())
        loss_logging[key_gen].append((lambdas[1] * loss_g).item())
        if use_cls:
            complete_loss_g = complete_loss_g + lambdas[4] * loss_g_cls
            loss_logging[key_cls].append((lambdas[4] * loss_g_cls).item())
        complete_loss_g.backward()

        # Discriminator update: fakes -> "fake", real sentences of that style -> "real".
        # Real labels are sized by the real batch, which may differ from the fake batch.
        D.train()
        _, loss_d_fake = D(transferred, self._real_fake_labels(n_fake, real=False), device=self.device)
        _, loss_d_real = D(sentences_other, self._real_fake_labels(len(sentences_other), real=True),
                           device=self.device)
        complete_loss_d = lambdas[2] * loss_d_fake + lambdas[3] * loss_d_real
        loss_logging[key_disc].append(complete_loss_d.item())
        complete_loss_d.backward()

    def save_models(self, base_path: str):
        """Save each generator/discriminator under ``base_path`` (G_ab, G_ba, D_ab, D_ba)."""
        self.G_ab.save_model(base_path + "/G_ab/")
        self.G_ba.save_model(base_path + "/G_ba/")
        self.D_ab.save_model(base_path + "/D_ab/")
        self.D_ba.save_model(base_path + "/D_ba/")

    def transfer(self, sentences: List[str], direction: str):
        """Transfer with G_ab when ``direction == "AB"``, otherwise with G_ba."""
        if direction == "AB":
            return self.G_ab.transfer(sentences, device=self.device)
        return self.G_ba.transfer(sentences, device=self.device)


In [None]:
# @title Evaluator Class, it is used to evaluate model performance using a ternary classifier
class Evaluator():
    """Evaluate a CycleGAN text-style-transfer model.

    For each direction, every sentence of a mono-style dataloader is transferred and
    scored against its own input (self-BLEU, self-ROUGE-1/2/L). Style accuracy of all
    transfers is measured with a sequence classifier, and the metrics plus source /
    transfer pairs are written under the run's save folder.
    """

    # Tokenizer settings used when feeding transferred sentences to the style classifier.
    _CLS_TRUNCATION = 'longest_first'
    _CLS_PADDING = 'max_length'

    def __init__(self, cycleGAN, args, experiment=None, label2id=None):
        """
        Args:
            cycleGAN: model exposing ``transfer``, ``eval``, ``device`` and ``Cls``.
            args: run configuration (batch_size, max_sequence_length, lambdas,
                pretrained_classifier_eval / pretrained_classifier_model,
                save_base_folder, from_pretrained).
            experiment: optional experiment handle (stored only).
            label2id: style-to-integer mapping; defaults to {"neu": 0, "pos": 1, "neg": 2}.
        """
        self.cycleGAN = cycleGAN
        self.args = args
        self.experiment = experiment

        # If label2id is not provided, use a default mapping
        if label2id is None:
            label2id = {"neu": 0, "pos": 1, "neg": 2}
        self.label2id = label2id

        self.bleu = evaluate.load('sacrebleu')
        self.rouge = evaluate.load('rouge')

    def __compute_metric__(self, predictions, references, metric_name, direction=None):
        """Per-sentence metric scores.

        Args:
            predictions: list of predicted sentences.
            references: list of reference lists (one list per prediction).
            metric_name: 'bleu' (sacreBLEU score per sentence) or 'rouge'
                ([rouge1, rouge2, rougeL] per sentence, each the max over its references).
            direction: unused.

        Raises:
            ValueError: unsupported metric_name.
        """
        if metric_name not in ('bleu', 'rouge'):
            raise ValueError(f"Metric {metric_name} is not supported.")
        scores = []
        for pred, ref in zip(predictions, references):
            if metric_name == 'bleu':
                scores.append(self.bleu.compute(predictions=[pred], references=[ref])['score'])
            else:
                per_ref = [
                    self.rouge.compute(predictions=[pred], references=[r], use_aggregator=False)
                    for r in ref
                ]
                scores.append([max(res[key][0] for res in per_ref) for key in ('rouge1', 'rouge2', 'rougeL')])
        return scores

    def _classifier_lambda(self):
        """Classifier-guidance weight (5th lambda) from ``args``, or None if unavailable.

        ``args.lambdas`` may be the raw '|'-separated string (main parses it with
        ``split('|')``) or an already-parsed sequence; indexing the raw string would
        return a character, not a weight.
        """
        lambdas = getattr(self.args, 'lambdas', None)
        if lambdas is None:
            return None
        if isinstance(lambdas, str):
            lambdas = [float(l) for l in lambdas.split('|')]
        return float(lambdas[4]) if len(lambdas) > 4 else None

    def _select_classifier(self):
        """Return ``(classifier, tokenizer)`` in eval mode.

        The cycleGAN's own classifier is reused only when it was used in training
        (non-zero weight) and is the same checkpoint as the evaluation classifier;
        otherwise the evaluation classifier is loaded from ``args.pretrained_classifier_eval``.
        """
        cls_lambda = self._classifier_lambda()
        cls_wrapper = getattr(self.cycleGAN, 'Cls', None)
        use_external = (
            cls_lambda is None
            or cls_lambda == 0
            or cls_wrapper is None
            or self.args.pretrained_classifier_eval != getattr(self.args, 'pretrained_classifier_model', None)
        )
        if use_external:
            eval_path = self.args.pretrained_classifier_eval
            classifier = AutoModelForSequenceClassification.from_pretrained(eval_path)
            classifier_tokenizer = AutoTokenizer.from_pretrained(os.path.join(eval_path, 'tokenizer'))
            classifier.to(self.cycleGAN.device)
        else:
            classifier = cls_wrapper.model
            classifier_tokenizer = cls_wrapper.tokenizer
        classifier.eval()
        return classifier, classifier_tokenizer

    def _predict_labels(self, classifier, classifier_tokenizer, sentences):
        """Argmax class id for each sentence, classified in batches of ``args.batch_size``."""
        device = self.cycleGAN.device
        batch_size = self.args.batch_size
        predictions = []
        for start in range(0, len(sentences), batch_size):
            inputs = classifier_tokenizer(
                sentences[start:start + batch_size],
                truncation=self._CLS_TRUNCATION,
                padding=self._CLS_PADDING,
                max_length=self.args.max_sequence_length,
                return_tensors="pt"
            )
            if device is not None:
                inputs = inputs.to(device)
            with torch.no_grad():
                logits = classifier(**inputs).logits
            predictions.extend(np.argmax(logits.cpu().numpy(), axis=1))
        return predictions

    def __compute_classif_metrics__(self, pred_A, pred_B, style_A, style_B):
        """Classify transferred sentences and score them against their intended style.

        Args:
            pred_A: sentences that should be in style ``style_A``.
            pred_B: sentences that should be in style ``style_B``.

        Returns:
            (accuracy, macro precision, macro recall, macro F1).
        """
        classifier, classifier_tokenizer = self._select_classifier()

        y_true = np.concatenate([
            np.full(len(pred_A), self.label2id[style_A]),
            np.full(len(pred_B), self.label2id[style_B])
        ])
        y_pred = (self._predict_labels(classifier, classifier_tokenizer, pred_A)
                  + self._predict_labels(classifier, classifier_tokenizer, pred_B))

        acc = accuracy_score(y_true, y_pred)
        prec = precision_score(y_true, y_pred, average='macro', zero_division=0)
        rec = recall_score(y_true, y_pred, average='macro', zero_division=0)
        f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
        return acc, prec, rec, f1

    def _transfer_and_score(self, dataloader, style_token, direction):
        """Transfer every batch and score each output against its own input.

        Returns:
            (sources, transfers, (mean self-BLEU, mean ROUGE-1, mean ROUGE-2, mean ROUGE-L)).
        """
        sources, transfers = [], []
        bleu, rouge1, rouge2, rougeL = [], [], [], []
        for batch in dataloader:
            sentences = list(batch)
            with torch.no_grad():
                transferred = self.cycleGAN.transfer(
                    sentences=[f"{style_token} {sentence}" for sentence in sentences],
                    direction=direction
                )
            sources.extend(sentences)
            transfers.extend(transferred)
            references = [[s] for s in sentences]
            bleu.extend(self.__compute_metric__(transferred, references, 'bleu'))
            rouge = np.array(self.__compute_metric__(transferred, references, 'rouge'))
            rouge1.extend(rouge[:, 0].tolist())
            rouge2.extend(rouge[:, 1].tolist())
            rougeL.extend(rouge[:, 2].tolist())
        return sources, transfers, (np.mean(bleu), np.mean(rouge1), np.mean(rouge2), np.mean(rougeL))

    def _output_location(self, epoch, phase):
        """Output folder (always ending in '/') and file-name suffix for this phase."""
        if phase.startswith('validation'):
            base_path = f"{self.args.save_base_folder}epoch_{epoch}/"
            suffix = f'epoch{epoch}'
        else:
            if self.args.from_pretrained is not None:
                if self.args.save_base_folder is not None:
                    base_path = f"{self.args.save_base_folder}"
                else:
                    base_path = f"{self.args.from_pretrained}epoch_{epoch}/"
            else:
                base_path = f"{self.args.save_base_folder}test/epoch_{epoch}/"
            suffix = f'epoch{epoch}_test'
        # File paths below are built by concatenation, so the folder must end with '/'.
        if base_path and not base_path.endswith('/'):
            base_path += '/'
        return base_path, suffix

    def run_eval_mono(self, epoch, current_training_step, phase, dl_source, dl_target, style_source, style_target):
        """Evaluate both transfer directions, print the metrics and save them plus the transfers."""
        print(f'Start {phase}...')
        self.cycleGAN.eval()

        # Control tokens: A->B for source sentences, B->A for target sentences.
        style_token_A = f"[{dl_target.dataset.style}->{dl_source.dataset.style}]"
        style_token_B = f"[{dl_source.dataset.style}->{dl_target.dataset.style}]"

        real_A, pred_B, (avg_AB_bleu_self, avg_AB_r1_self, avg_AB_r2_self, avg_AB_rL_self) = \
            self._transfer_and_score(dl_source, style_token_B, 'AB')
        real_B, pred_A, (avg_BA_bleu_self, avg_BA_r1_self, avg_BA_r2_self, avg_BA_rL_self) = \
            self._transfer_and_score(dl_target, style_token_A, 'BA')
        avg_2dir_bleu_self = (avg_AB_bleu_self + avg_BA_bleu_self) / 2

        acc, _, _, _ = self.__compute_classif_metrics__(pred_A, pred_B, style_source, style_target)
        acc_scaled = acc * 100
        avg_acc_bleu_self = (avg_2dir_bleu_self + acc_scaled) / 2
        avg_acc_bleu_self_geom = (avg_2dir_bleu_self * acc_scaled) ** 0.5
        avg_acc_bleu_self_h = (2 * avg_2dir_bleu_self * acc_scaled) / (avg_2dir_bleu_self + acc_scaled + 1e-6)

        metrics = {
            'epoch': epoch,
            'step': current_training_step,
            'self-BLEU A->B': avg_AB_bleu_self,
            'self-BLEU B->A': avg_BA_bleu_self,
            'self-BLEU avg': avg_2dir_bleu_self,
            'self-ROUGE-1 A->B': avg_AB_r1_self,
            'self-ROUGE-1 B->A': avg_BA_r1_self,
            'self-ROUGE-2 A->B': avg_AB_r2_self,
            'self-ROUGE-2 B->A': avg_BA_r2_self,
            'self-ROUGE-L A->B': avg_AB_rL_self,
            'self-ROUGE-L B->A': avg_BA_rL_self,
            'style accuracy': acc,
            'acc-BLEU': avg_acc_bleu_self,
            'g-acc-BLEU': avg_acc_bleu_self_geom,
            'h-acc-BLEU': avg_acc_bleu_self_h
        }

        base_path, suffix = self._output_location(epoch, phase)
        os.makedirs(base_path or '.', exist_ok=True)
        with open(f"{base_path}metrics_{suffix}.pickle", 'wb') as metrics_file:
            pickle.dump(metrics, metrics_file)

        for m, v in metrics.items():
            if m not in ['epoch', 'step']:
                print(f'{m}: {v}')

        pd.DataFrame({'A (source)': real_A, 'B (generated)': pred_B}).to_csv(
            f"{base_path}{style_source}_{style_target}_{suffix}.csv", sep=',', header=True
        )
        pd.DataFrame({'B (source)': real_B, 'A (generated)': pred_A}).to_csv(
            f"{base_path}{style_target}_{style_source}_{suffix}.csv", sep=',', header=True
        )
        print(f'End {phase}...')

    def dummy_classif(self):
        """Smoke-test the classification-metric pipeline on a few fixed neg/pos sentences."""
        pred_A = [
            'wake up or you are going to lose your business .',
            'this place has none of them .',
            'it is april and there are no grass tees yet .',
            'there is no grass on the range .',
            'bottom line , this place sucks .',
            'someone should buy this place .',
            'very disappointed in the customer service .',
            'we will not be back .'
        ]
        pred_B = [
            'huge sandwich !',
            'i added mushrooms , it was very flavorful .',
            'he enjoyed it as well .',
            'fast and friendly service .',
            'will definitely be back .',
            "my dad 's favorite .",
            'huge burgers , fish sandwiches , salads .',
            'decent service .'
        ]
        self.__compute_classif_metrics__(pred_A, pred_B, 'neg', 'pos')
        print('Dummy classification metrics computation end')


In [None]:
# @title Main function: it instantiates Generators, Discriminators and classifiers and controls the training flow
def _load_mono_dataset(dataset_path, style, max_samples):
    """Build one ``MonostyleDataset`` read in "line_file" format with newline-separated samples.

    Args:
        dataset_path: path of the text file to read.
        style: style label for this dataset (e.g. ``args.style_a``).
        max_samples: forwarded as ``max_dataset_samples`` (``None`` presumably
            means "no cap" -- confirm in ``MonostyleDataset``).
    """
    return MonostyleDataset(
        dataset_format="line_file",
        dataset_path=dataset_path,
        style=style,
        separator='\n',
        max_dataset_samples=max_samples,
    )


def main(args):
    """Train the three-style CycleGAN (A->B and A->C sub-phases) and validate every epoch.

    Flow:
    1. Validate ``args`` and parse the loss weights.
    2. Seed the RNGs and build the train/eval dataloaders for styles A, B and C.
    3. Instantiate G_ab, G_ba, D_ab and D_ba, either fresh or loaded from
       ``args.from_pretrained``. Add a classifier when lambda[4] != 0.
    4. Optionally resume the optimizer, scheduler, step and epoch from
       ``{from_pretrained}checkpoint.pth``.
    5. Each epoch, train A->B and then A->C, evaluate both directions, and save
       models, checkpoint and loss log every ``args.save_steps`` epochs.

    Args:
        args: configuration object (e.g. ``Args``) exposing every attribute in
            ``required_attrs``.

    Raises:
        AttributeError: if ``args`` lacks a required attribute.
        ValueError: if ``args.lambdas`` is not a '|'-separated list of at least
            5 numbers.
    """
    # List of required attributes
    required_attrs = [
        "epochs", "style_a", "style_b", "style_c",
        "path_mono_A", "path_mono_B", "path_mono_C",
        "path_mono_A_eval", "path_mono_B_eval", "path_mono_C_eval",
        "batch_size", "max_samples_train", "max_samples_eval",
        "nonparal_same_size", "generator_model_tag", "discriminator_model_tag",
        "pretrained_classifier_model", "pretrained_classifier_eval",
        "from_pretrained", "save_base_folder", "save_steps",
        "lambdas", "learning_rate", "max_sequence_length",
        "lr_scheduler_type", "warmup_ratio", "use_cuda_if_available",
        # Both are read below (generators, CycleGANModel, Evaluator).
        "label2id", "style_token_list",
    ]

    # Check for missing attributes
    missing_attrs = [attr for attr in required_attrs if not hasattr(args, attr)]
    if missing_attrs:
        raise AttributeError(f"Args object is missing: {', '.join(missing_attrs)}")

    # Parse the loss weights before any expensive loading, so a malformed value
    # fails fast. Index 4 gates the classifier further down.
    lambdas = [float(l) for l in args.lambdas.split('|')]
    if len(lambdas) < 5:
        raise ValueError(
            f"args.lambdas must hold at least 5 '|'-separated weights, "
            f"got {len(lambdas)}: {args.lambdas!r}"
        )

    # Seeding
    SEED = 42
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)

    # Print paths for debugging
    print(f"Loading dataset A from: {args.path_mono_A}")
    print(f"Loading dataset B from: {args.path_mono_B}")
    print(f"Loading dataset C from: {args.path_mono_C}")

    # ----- Load training datasets -----
    mono_ds_a = _load_mono_dataset(args.path_mono_A, args.style_a, args.max_samples_train)
    mono_ds_b = _load_mono_dataset(args.path_mono_B, args.style_b, args.max_samples_train)
    mono_ds_c = _load_mono_dataset(args.path_mono_C, args.style_c, args.max_samples_train)

    # Print all args for clarity. They are also stored in the loss log as hyper_params.
    hyper_params = {}
    print("\nArguments summary:")
    for key, value in vars(args).items():
        hyper_params[key] = value
        print(f"  {key}:\t{value}")

    # If specified, reduce all datasets to the same size
    if args.nonparal_same_size:
        min_len = min(len(mono_ds_a), len(mono_ds_b), len(mono_ds_c))
        mono_ds_a.reduce_data(min_len)
        mono_ds_b.reduce_data(min_len)
        mono_ds_c.reduce_data(min_len)

    # ----- Load eval datasets -----
    mono_ds_a_eval = _load_mono_dataset(args.path_mono_A_eval, args.style_a, args.max_samples_eval)
    mono_ds_b_eval = _load_mono_dataset(args.path_mono_B_eval, args.style_b, args.max_samples_eval)
    mono_ds_c_eval = _load_mono_dataset(args.path_mono_C_eval, args.style_c, args.max_samples_eval)

    # Dataloaders: only the training data is shuffled, so eval order stays fixed.
    mono_dl_a = DataLoader(mono_ds_a, batch_size=args.batch_size, shuffle=True)
    mono_dl_b = DataLoader(mono_ds_b, batch_size=args.batch_size, shuffle=True)
    mono_dl_c = DataLoader(mono_ds_c, batch_size=args.batch_size, shuffle=True)

    mono_dl_a_eval = DataLoader(mono_ds_a_eval, batch_size=args.batch_size, shuffle=False)
    mono_dl_b_eval = DataLoader(mono_ds_b_eval, batch_size=args.batch_size, shuffle=False)
    mono_dl_c_eval = DataLoader(mono_ds_c_eval, batch_size=args.batch_size, shuffle=False)

    # Drop the local names; the dataloaders keep their own references.
    del mono_ds_a, mono_ds_b, mono_ds_c
    del mono_ds_a_eval, mono_ds_b_eval, mono_ds_c_eval

    # ----- Instantiate G, D, Cls -----
    if args.from_pretrained:
        G_ab = GeneratorModel(
            model_name_or_path=args.generator_model_tag,
            new_style_tokens=args.style_token_list,
            pretrained_path=f"{args.from_pretrained}G_ab/",
            max_seq_length=args.max_sequence_length
        )
        G_ba = GeneratorModel(
            model_name_or_path=args.generator_model_tag,
            new_style_tokens=args.style_token_list,
            pretrained_path=f"{args.from_pretrained}G_ba/",
            max_seq_length=args.max_sequence_length
        )
        D_ab = DiscriminatorModel(
            args.discriminator_model_tag,
            f"{args.from_pretrained}D_ab/",
            max_seq_length=args.max_sequence_length
        )
        D_ba = DiscriminatorModel(
            args.discriminator_model_tag,
            f"{args.from_pretrained}D_ba/",
            max_seq_length=args.max_sequence_length
        )
        print("[INFO] Loaded pretrained G_ab, G_ba, D_ab, D_ba")
    else:
        G_ab = GeneratorModel(
            model_name_or_path=args.generator_model_tag,
            new_style_tokens=args.style_token_list,
            max_seq_length=args.max_sequence_length
        )
        G_ba = GeneratorModel(
            model_name_or_path=args.generator_model_tag,
            new_style_tokens=args.style_token_list,
            max_seq_length=args.max_sequence_length
        )
        D_ab = DiscriminatorModel(
            args.discriminator_model_tag,
            max_seq_length=args.max_sequence_length
        )
        D_ba = DiscriminatorModel(
            args.discriminator_model_tag,
            max_seq_length=args.max_sequence_length
        )
        print("[INFO] Using fresh G_ab, G_ba, D_ab, D_ba")

    # The classifier is only needed when its loss weight (lambdas[4]) is active.
    if lambdas[4] != 0 and args.pretrained_classifier_model:
        Cls = ClassifierModel(args.pretrained_classifier_model, max_seq_length=args.max_sequence_length)
        print("[INFO] Loaded pretrained classifier")
    else:
        Cls = None

    # Device
    if args.use_cuda_if_available and torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # Create CycleGAN
    cycleGAN = CycleGANModel(
        G_ab=G_ab,
        G_ba=G_ba,
        D_ab=D_ab,
        D_ba=D_ba,
        Cls=Cls,
        device=device,
        label2id=args.label2id
    )

    # Each sub-phase stops at its shorter dataloader (zip in train_subphase).
    n_batch_ab = min(len(mono_dl_a), len(mono_dl_b))
    n_batch_ac = min(len(mono_dl_a), len(mono_dl_c))
    steps_per_epoch = n_batch_ab + n_batch_ac

    # Optimizer
    optimizer = AdamW(cycleGAN.get_optimizer_parameters(), lr=args.learning_rate)

    # NOTE(review): args.lr_scheduler_type and args.warmup_ratio are required but
    # currently unused. The schedule is always cosine with restarts every half epoch.
    # max(1, ...) because CosineAnnealingWarmRestarts rejects T_0 < 1, which
    # happens when an epoch has fewer than 2 steps.
    scheduler = CosineAnnealingWarmRestarts(
        optimizer,
        T_0=max(1, steps_per_epoch // 2),
        T_mult=1,
        eta_min=0
    )

    current_training_step = 0
    start_epoch = 0

    # Resume optimizer/scheduler/progress if a checkpoint sits next to the pretrained models.
    # NOTE(review): checkpoint.pth is saved at the save root, while models go to
    # epoch_N/. Resuming needs both under args.from_pretrained -- confirm layout.
    if args.from_pretrained and os.path.exists(f"{args.from_pretrained}checkpoint.pth"):
        ckpt = torch.load(f"{args.from_pretrained}checkpoint.pth", map_location="cpu")
        optimizer.load_state_dict(ckpt["optimizer"])
        scheduler.load_state_dict(ckpt["lr_scheduler"])
        current_training_step = ckpt["training_step"]
        # 'epoch' is stored as (last finished epoch + 1), so training continues
        # where it stopped instead of restarting at epoch 0.
        start_epoch = ckpt.get("epoch", 0)
        print(f"[INFO] Resumed at epoch {start_epoch}, step {current_training_step}")
        del ckpt

    # Evaluator
    evaluator = Evaluator(cycleGAN, args, label2id=args.label2id)

    def train_subphase(dataloader_a, dataloader_x, style_src, style_tgt, loss_log):
        """Run one pass over zip(dataloader_a, dataloader_x), stepping optimizer and scheduler per batch."""
        nonlocal current_training_step
        n_batch = min(len(dataloader_a), len(dataloader_x))
        progress_bar = tqdm(range(n_batch), desc=f"{style_src}->{style_tgt}")

        cycleGAN.train()
        for batch_a, batch_x in zip(dataloader_a, dataloader_x):
            # Trim the longer batch so both sides have the same size (last batches may differ).
            len_a, len_x = len(batch_a), len(batch_x)
            if len_a > len_x:
                batch_a = batch_a[:len_x]
            elif len_x > len_a:
                batch_x = batch_x[:len_a]

            cycleGAN.training_cycle(
                sentences_a=batch_a,
                sentences_b=batch_x,
                style_source=style_src,
                style_target=style_tgt,
                lambdas=lambdas,
                loss_logging=loss_log,
                training_step=current_training_step
            )

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            current_training_step += 1
            progress_bar.update(1)

        progress_bar.close()

    # Loss logging
    loss_logging = {
        'Cycle Loss A-B-A': [],
        'Loss generator  A-B': [],
        'Classifier-guided A-B': [],
        'Loss D(A->B)': [],
        'Cycle Loss B-A-B': [],
        'Loss generator  B-A': [],
        'Classifier-guided B-A': [],
        'Loss D(B->A)': []
    }
    loss_logging['hyper_params'] = hyper_params

    # Guarantee a trailing separator so f"{save_root}..." always lands inside the folder.
    save_root = os.path.join(args.save_base_folder, "")

    # ----- Training loop -----
    for epoch_idx in range(start_epoch, args.epochs):
        print(f"\n=== EPOCH {epoch_idx} ===")

        # (1) A->B
        train_subphase(mono_dl_a, mono_dl_b, style_src=args.style_a, style_tgt=args.style_b, loss_log=loss_logging)
        # (2) A->C
        train_subphase(mono_dl_a, mono_dl_c, style_src=args.style_a, style_tgt=args.style_c, loss_log=loss_logging)

        # (3) End-of-epoch evaluation
        evaluator.run_eval_mono(
            epoch_idx,
            current_training_step,
            phase="validation_AB_epoch",
            dl_source=mono_dl_a_eval,
            dl_target=mono_dl_b_eval,
            style_source=args.style_a,
            style_target=args.style_b
        )
        evaluator.run_eval_mono(
            epoch_idx,
            current_training_step,
            phase="validation_AC_epoch",
            dl_source=mono_dl_a_eval,
            dl_target=mono_dl_c_eval,
            style_source=args.style_a,
            style_target=args.style_c
        )

        # (4) Checkpoint saving. save_steps <= 0 disables saving instead of
        # dividing by zero.
        if args.save_steps > 0 and epoch_idx % args.save_steps == 0:
            cycleGAN.save_models(f"{save_root}epoch_{epoch_idx}/")
            os.makedirs(save_root, exist_ok=True)

            checkpoint = {
                'epoch': epoch_idx + 1,
                'training_step': current_training_step,
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': scheduler.state_dict()
            }
            torch.save(checkpoint, f"{save_root}checkpoint.pth")

            # "wb" truncates any previous loss file; `with` guarantees the handle is closed.
            with open(f"{save_root}loss.pickle", "wb") as loss_file:
                pickle.dump(loss_logging, loss_file)

        cycleGAN.train()

    print("\n=== Training completed ===")


In [None]:
# @title Args Class, it is used to configure the GAN
class Args:
    """Plain container for every setting that ``main`` reads.

    Attributes are assigned in a fixed order on purpose: ``main`` iterates
    ``vars(args)`` to print the argument summary and to build ``hyper_params``,
    so the insertion order is visible in the output.
    """

    def __init__(
        self,
        epochs,
        style_a,
        style_b,
        style_c,
        path_mono_A,
        path_mono_B,
        path_mono_C,
        path_mono_A_eval,
        path_mono_B_eval,
        path_mono_C_eval,
        generator_model_tag,
        discriminator_model_tag,
        label2id,
        style_token_list,
        batch_size=8,
        max_samples_train=None,
        max_samples_eval=None,
        nonparal_same_size=False,
        pretrained_classifier_model=None,
        pretrained_classifier_eval=None,
        from_pretrained=None,
        save_base_folder="./checkpoints/",
        save_steps=1,
        lambdas="10|1|1|1|1|1",
        learning_rate=5e-5,
        max_sequence_length=32,
        lr_scheduler_type="cosine_with_restarts",
        warmup_ratio=0.0,
        use_cuda_if_available=False
    ):
        """Store each constructor argument as an attribute of the same name.

        ``lambdas`` is a '|'-separated string of loss weights (parsed by ``main``).
        ``save_base_folder`` and ``from_pretrained`` are used as path prefixes,
        so they are expected to end with '/'.
        """
        # Run length and the three style labels (A is the source of both sub-phases).
        self.epochs = epochs
        self.style_a, self.style_b, self.style_c = style_a, style_b, style_c
        # Training corpora, one per style.
        self.path_mono_A, self.path_mono_B, self.path_mono_C = path_mono_A, path_mono_B, path_mono_C
        # Evaluation corpora, one per style.
        self.path_mono_A_eval, self.path_mono_B_eval, self.path_mono_C_eval = (
            path_mono_A_eval, path_mono_B_eval, path_mono_C_eval
        )
        # Backbone model identifiers.
        self.generator_model_tag, self.discriminator_model_tag = generator_model_tag, discriminator_model_tag
        # Data loading.
        self.batch_size, self.max_samples_train, self.max_samples_eval = (
            batch_size, max_samples_train, max_samples_eval
        )
        self.nonparal_same_size = nonparal_same_size
        # Classifier checkpoints (training guidance / evaluation).
        self.pretrained_classifier_model, self.pretrained_classifier_eval = (
            pretrained_classifier_model, pretrained_classifier_eval
        )
        # Loading and checkpointing.
        self.from_pretrained, self.save_base_folder, self.save_steps = from_pretrained, save_base_folder, save_steps
        # Optimisation.
        self.lambdas, self.learning_rate, self.max_sequence_length = lambdas, learning_rate, max_sequence_length
        self.lr_scheduler_type, self.warmup_ratio = lr_scheduler_type, warmup_ratio
        self.use_cuda_if_available = use_cuda_if_available
        # Style vocabulary: label ids and the "[src->tgt]" tokens added to the generators.
        self.label2id, self.style_token_list = label2id, style_token_list


# Experiments

In [None]:
# @title Amazon runs: set RUN_AMAZON_EXPERIMENT = True to launch it

# A flag replaces the former bare string literal. That string was the cell's
# last expression, so Jupyter echoed its whole repr as output, and the code
# inside was never syntax-checked.
RUN_AMAZON_EXPERIMENT = False

if RUN_AMAZON_EXPERIMENT:
    # Create the label2id map based on the defined styles
    label2id = {
        "pos": 0,
        "neu": 1,
        "neg": 2
    }

    # Dynamically create the list of style tokens
    style_token_list = [
        "[pos->neu]", "[pos->neg]",
        "[neu->pos]", "[neu->neg]",
        "[neg->pos]", "[neg->neu]"
    ]

    # Define args directly in the notebook
    args = Args(
        epochs=7,
        style_a="pos",
        style_b="neu",
        style_c="neg",
        path_mono_A="/content/drive/MyDrive/ProjectNLP/00.Amazon_Project/data/AMAZON/Train_split/positive_train.txt",
        path_mono_B="/content/drive/MyDrive/ProjectNLP/00.Amazon_Project/data/AMAZON/Train_split/neutral_train.txt",
        path_mono_C="/content/drive/MyDrive/ProjectNLP/00.Amazon_Project/data/AMAZON/Train_split/negative_train.txt",
        path_mono_A_eval="/content/drive/MyDrive/ProjectNLP/00.Amazon_Project/data/AMAZON/Eval_split/positive_eval.txt",
        path_mono_B_eval="/content/drive/MyDrive/ProjectNLP/00.Amazon_Project/data/AMAZON/Eval_split/neutral_eval.txt",
        path_mono_C_eval="/content/drive/MyDrive/ProjectNLP/00.Amazon_Project/data/AMAZON/Eval_split/negative_eval.txt",
        generator_model_tag="facebook/bart-base",
        discriminator_model_tag="distilbert/distilbert-base-cased",
        pretrained_classifier_model="/content/drive/MyDrive/ProjectNLP/00.Amazon_Project/CLASSIFIER_CHECKPOINT",
        pretrained_classifier_eval="/content/drive/MyDrive/ProjectNLP/00.Amazon_Project/CLASSIFIER_CHECKPOINT",
        lambdas="10|1|1|1|1",
        learning_rate=1e-4,
        max_sequence_length=64,
        save_base_folder="/content/drive/MyDrive/ProjectNLP/00.Amazon_Project/GAN_CHECKPOINT_POS_PROVA/",
        save_steps=1,
        use_cuda_if_available=True,
        label2id=label2id,           # Explicitly pass the label mapping
        style_token_list=style_token_list  # Explicitly pass the style tokens
    )

    # Call the main function with the updated arguments
    main(args)

# To load the models of a previous run, pass this to Args:
# from_pretrained="/content/drive/MyDrive/ProjectNLP/00.Amazon_Project/GAN_CHECKPOINT_POS/"

'\n# Create the label2id map based on the defined styles\nlabel2id = {\n    "pos": 0,\n    "neu": 1,\n    "neg": 2\n}\n\n# Dynamically create the list of style tokens\nstyle_token_list = [\n    "[pos->neu]", "[pos->neg]",\n    "[neu->pos]", "[neu->neg]",\n    "[neg->pos]", "[neg->neu]"\n]\n\n# Define args directly in the notebook\nargs = Args(\n    epochs=7,\n    style_a="pos",\n    style_b="neu",\n    style_c="neg",\n    path_mono_A="/content/drive/MyDrive/ProjectNLP/00.Amazon_Project/data/AMAZON/Train_split/positive_train.txt",\n    path_mono_B="/content/drive/MyDrive/ProjectNLP/00.Amazon_Project/data/AMAZON/Train_split/neutral_train.txt",\n    path_mono_C="/content/drive/MyDrive/ProjectNLP/00.Amazon_Project/data/AMAZON/Train_split/negative_train.txt",\n    path_mono_A_eval="/content/drive/MyDrive/ProjectNLP/00.Amazon_Project/data/AMAZON/Eval_split/positive_eval.txt",\n    path_mono_B_eval="/content/drive/MyDrive/ProjectNLP/00.Amazon_Project/data/AMAZON/Eval_split/neutral_eval.txt",\

In [None]:
# @title Author runs
# Styles in A/B/C order (A = Trump, B = lyrics, C = Shakespeare, per the data files below).
# Label ids follow this order: tru=0, lyr=1, sha=2.
AUTHOR_STYLES = ["tru", "lyr", "sha"]

# Create the label2id map based on the defined styles
label2id = {style: label_id for label_id, style in enumerate(AUTHOR_STYLES)}

# Dynamically create the list of style tokens: one "[src->tgt]" token per ordered
# pair of distinct styles, grouped by source style.
style_token_list = [
    f"[{src}->{tgt}]"
    for src in AUTHOR_STYLES
    for tgt in AUTHOR_STYLES
    if tgt != src
]

# Base folders on Google Drive for the spell-checked corpora and for this run's artifacts.
AUTHOR_DATA_DIR = "/content/drive/MyDrive/ProjectNLP/03.Ultima_Estensione_Shakespeare/Data_spellchecked"
AUTHOR_RUN_DIR = "/content/drive/MyDrive/ProjectNLP/20250112_Autori"
AUTHOR_CLASSIFIER_CKPT = f"{AUTHOR_RUN_DIR}/ClassifierCheckpoint/checkpoint-4654"

# Define args directly in the notebook
args = Args(
    epochs=7,
    style_a=AUTHOR_STYLES[0],
    style_b=AUTHOR_STYLES[1],
    style_c=AUTHOR_STYLES[2],
    path_mono_A=f"{AUTHOR_DATA_DIR}/train_trump_spellchecked.txt",
    path_mono_B=f"{AUTHOR_DATA_DIR}/train_lyrics_spellchecked.txt",
    path_mono_C=f"{AUTHOR_DATA_DIR}/train_shakespeare_spellchecked.txt",
    path_mono_A_eval=f"{AUTHOR_DATA_DIR}/eval_trump_spellchecked.txt",
    path_mono_B_eval=f"{AUTHOR_DATA_DIR}/eval_lyrics_spellchecked.txt",
    path_mono_C_eval=f"{AUTHOR_DATA_DIR}/eval_shakespeare_spellchecked.txt",
    generator_model_tag="facebook/bart-base",
    discriminator_model_tag="distilbert/distilbert-base-cased",
    pretrained_classifier_model=AUTHOR_CLASSIFIER_CKPT,
    pretrained_classifier_eval=AUTHOR_CLASSIFIER_CKPT,
    lambdas="10|1|1|1|1",
    learning_rate=1e-4,
    max_sequence_length=64,
    save_base_folder=f"{AUTHOR_RUN_DIR}/GANPROVA/",
    save_steps=1,
    use_cuda_if_available=True,
    label2id=label2id,                  # explicit label mapping
    style_token_list=style_token_list,  # explicit style tokens
)

# Launch training with the configuration above
main(args)

