In [1]:
!pip install datasets



In [2]:
import torch
import numpy as np
import random
from datasets import load_dataset
from transformers import BertTokenizer, BertForSequenceClassification, BertModel, get_linear_schedule_with_warmup, TrainingArguments, Trainer, EvalPrediction
import torch.nn as nn
from torch.optim import AdamW
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, Dataset, DataLoader, Sampler
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import os
from tqdm import tqdm

In [3]:
# Set random seeds for reproducibility.
# Python's built-in `random` module is seeded as well because
# SimilarLengthBatchSampler shuffles its buckets and batches with random.shuffle.
random.seed(42)
torch.manual_seed(42)
np.random.seed(42)

# Select the compute device: the GPU when CUDA is usable, the CPU otherwise
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Fetch both corpora: SST-2 from the GLUE collection, and IMDB
print("Loading datasets...")
sst2 = load_dataset("glue", "sst2")
imdb = load_dataset("imdb")

# Checkpoint of a BERT model fine-tuned on IMDB.
# Only its tokenizer is loaded here; the model weights are loaded later from
# the same `model_name` when AdversarialBert is constructed.
model_name = "yyammerrrss/imdb-sft-bert"
print(f"Loading model from {model_name}...")
tokenizer = BertTokenizer.from_pretrained(model_name)

Using device: cuda
Loading datasets...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Loading model from yyammerrrss/imdb-sft-bert...


In [4]:
# Define the tokenization functions
def tokenize_sst2(examples):
    """Tokenize a batch of SST-2 rows.

    Reads the "sentence" column of `examples` and passes it to the module-level
    `tokenizer`, truncating and padding every sequence to exactly 128 tokens.
    """
    sentences = examples["sentence"]
    return tokenizer(sentences, truncation=True, padding="max_length", max_length=128)

def tokenize_imdb(examples):
    """Tokenize a batch of IMDB rows.

    Reads the "text" column of `examples` and passes it to the module-level
    `tokenizer`, truncating and padding every sequence to exactly 512 tokens.
    """
    reviews = examples["text"]
    return tokenizer(reviews, truncation=True, padding="max_length", max_length=512)

# Tokenize the datasets and expose the model inputs as torch tensors.
# Both tokenize functions use padding="max_length" with truncation, so every
# SST2 row ends up with 128 tokens and every IMDB row with 512 tokens.
print("Tokenizing datasets...")
tokenized_sst2 = {}
for split in sst2:
    encoded_split = sst2[split].map(tokenize_sst2, batched=True)
    encoded_split.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
    tokenized_sst2[split] = encoded_split

# For IMDB, both the train and the test split are used
tokenized_imdb = {}
for split in ('train', 'test'):
    encoded_split = imdb[split].map(tokenize_imdb, batched=True)
    encoded_split.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
    tokenized_imdb[split] = encoded_split

# Report the number of rows per split
for split_label, split_rows in (
    ("SST2 train", tokenized_sst2['train']),
    ("SST2 validation", tokenized_sst2['validation']),
    ("IMDB train", tokenized_imdb['train']),
    ("IMDB test", tokenized_imdb['test']),
):
    print(f"{split_label} size: {len(split_rows)}")

Tokenizing datasets...
SST2 train size: 67349
SST2 validation size: 872
IMDB train size: 25000
IMDB test size: 25000


In [5]:
# Define the adversarial model architecture
class AdversarialBert(nn.Module):
    """BERT encoder with a sentiment head and an adversarial domain head.

    Both heads are single linear layers applied to the encoder's pooled output
    (after dropout with p=0.1). The domain head always has two outputs.

    Args:
        bert_model_name: Name or path passed to ``BertModel.from_pretrained``.
        num_labels: Number of outputs of the sentiment (task) head.
        lambda_param: Weight of the domain loss in the combined training loss.
    """

    def __init__(self, bert_model_name, num_labels=2, lambda_param=0.1):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.dropout = nn.Dropout(0.1)
        hidden_size = self.bert.config.hidden_size
        self.classifier = nn.Linear(hidden_size, num_labels)
        self.domain_classifier = nn.Linear(hidden_size, 2)
        self.lambda_param = lambda_param

    def forward(self, input_ids, attention_mask, labels=None, domain_labels=None, alpha=1.0):
        """Run the encoder and the head(s).

        In training mode with ``domain_labels`` given, the domain head is also run
        (behind a gradient reversal scaled by ``alpha``) and the returned dict holds
        'loss' (task_loss + lambda_param * domain_loss), 'task_logits',
        'domain_logits', 'task_loss' and 'domain_loss'. ``labels`` must be provided
        in that case.

        Otherwise only the task head is used and the dict holds 'loss' (cross-entropy
        against ``labels``, or None when ``labels`` is None) and 'task_logits'.
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        features = self.dropout(encoded.pooler_output)

        # Task classifier (sentiment)
        logits = self.classifier(features)

        adversarial_step = self.training and domain_labels is not None
        if not adversarial_step:
            # Evaluation (or training without domain labels): task head only
            if labels is None:
                plain_loss = None
            else:
                plain_loss = F.cross_entropy(logits, labels)
            return {
                'loss': plain_loss,
                'task_logits': logits
            }

        # Domain classifier behind the gradient reversal layer.
        # GradientReversalFunction is defined further down in this file and is
        # looked up when forward() runs.
        reversed_features = GradientReversalFunction.apply(features, alpha)
        domain_logits = self.domain_classifier(reversed_features)

        task_loss = F.cross_entropy(logits, labels)
        domain_loss = F.cross_entropy(domain_logits, domain_labels)
        combined_loss = task_loss + self.lambda_param * domain_loss

        return {
            'loss': combined_loss,
            'task_logits': logits,
            'domain_logits': domain_logits,
            'task_loss': task_loss,
            'domain_loss': domain_loss
        }

# Define gradient reversal layer for adversarial training
class GradientReversalFunction(torch.autograd.Function):
    """Identity in the forward pass; scales the gradient by -alpha in the backward pass."""

    @staticmethod
    def forward(ctx, x, alpha):
        # Keep the scale on the context so backward() can read it
        ctx.alpha = alpha
        return torch.clone(x)

    @staticmethod
    def backward(ctx, grad_output):
        reversed_grad = grad_output * (-ctx.alpha)
        # Second return value is the gradient for `alpha`, which gets none
        return reversed_grad, None

# Function to create dataloaders for adversarial training
def create_adversarial_dataloader(source_dataset, target_dataset, batch_size, max_length=128):
    """Build one shuffled training DataLoader over source and target examples.

    Each batch yields ``(input_ids, attention_mask, labels, domain_labels)`` where
    the domain label is 0 for source rows (IMDB) and 1 for target rows (SST2).

    Args:
        source_dataset: Indexable dataset whose rows are dicts with tensor fields
            'input_ids', 'attention_mask' and 'label'.
        target_dataset: Same row layout as ``source_dataset``.
        batch_size: Batch size of the returned DataLoader.
        max_length: Sequences are cut to their first ``max_length`` tokens so that
            source and target rows can be stacked into one tensor. Rows shorter
            than ``max_length`` are not padded, so all rows must end up with the
            same length after slicing.

    Returns:
        A ``DataLoader`` (shuffle=True) over a ``TensorDataset`` with the source
        rows first and the target rows after them.
    """
    # Random subset of the source, capped at the size of the target dataset.
    # When the source is the smaller dataset this selects every source row.
    n_source = min(len(source_dataset), len(target_dataset))
    sampled = np.random.choice(len(source_dataset), n_source, replace=False)
    source_indices = [int(idx) for idx in sampled]  # numpy.int64 -> Python int

    def _gather(dataset, indices):
        """Fetch each requested row once and stack its (truncated) fields."""
        ids, masks, labels = [], [], []
        for i in indices:
            row = dataset[i]  # a single fetch per row instead of one per field
            ids.append(row['input_ids'][:max_length])
            masks.append(row['attention_mask'][:max_length])
            labels.append(row['label'])
        return torch.stack(ids), torch.stack(masks), torch.stack(labels)

    src_ids, src_mask, src_labels = _gather(source_dataset, source_indices)
    tgt_ids, tgt_mask, tgt_labels = _gather(target_dataset, range(len(target_dataset)))

    # Domain labels: 0 for every source row, 1 for every target row
    domain_labels = torch.cat([
        torch.zeros(len(source_indices), dtype=torch.long),
        torch.ones(len(target_dataset), dtype=torch.long),
    ], dim=0)

    dataset = TensorDataset(
        torch.cat([src_ids, tgt_ids], dim=0),
        torch.cat([src_mask, tgt_mask], dim=0),
        torch.cat([src_labels, tgt_labels], dim=0),
        domain_labels,
    )
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [6]:
# Define evaluation function
def evaluate_model(model, eval_dataloader):
    """Score `model` on a labelled dataloader.

    The model is switched to eval mode (and left in it). Every batch must unpack
    to ``(input_ids, attention_mask, labels)``; tensors are moved to the
    module-level `device`. Predictions are the argmax of the model's
    'task_logits' output.

    Returns:
        dict with 'accuracy', 'precision', 'recall' and 'f1' (binary averaging).
    """
    model.eval()
    predictions = []
    references = []

    with torch.no_grad():
        for input_ids, attention_mask, labels in eval_dataloader:
            labels = labels.to(device)
            outputs = model(
                input_ids=input_ids.to(device),
                attention_mask=attention_mask.to(device),
                labels=labels,
            )
            batch_predictions = outputs['task_logits'].argmax(dim=1)
            predictions.extend(batch_predictions.cpu().numpy())
            references.extend(labels.cpu().numpy())

    # Aggregate metrics over the whole dataloader
    accuracy = accuracy_score(references, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(
        references, predictions, average='binary'
    )

    return dict(accuracy=accuracy, precision=precision, recall=recall, f1=f1)

# Create evaluation dataloaders
def create_eval_dataloader(dataset, batch_size, max_length=128):
    """Build an unshuffled evaluation DataLoader of (input_ids, attention_mask, labels).

    Args:
        dataset: Indexable dataset whose rows are dicts with tensor fields
            'input_ids', 'attention_mask' and 'label'.
        batch_size: Batch size of the returned DataLoader.
        max_length: Every sequence is cut to its first ``max_length`` tokens
            (this is what shortens the 512-token IMDB rows). Rows that are already
            that short are returned unchanged, because slicing past the end of a
            tensor is a no-op.

    Returns:
        A ``DataLoader`` over a ``TensorDataset`` in the original row order.

    Raises:
        ValueError: If ``dataset`` is empty.
    """
    if len(dataset) == 0:
        raise ValueError("Cannot build an evaluation dataloader from an empty dataset")

    input_ids, attention_mask, labels = [], [], []
    for i in range(len(dataset)):
        row = dataset[i]  # a single fetch per row instead of one per field
        input_ids.append(row['input_ids'][:max_length])
        attention_mask.append(row['attention_mask'][:max_length])
        labels.append(row['label'])

    eval_dataset = TensorDataset(
        torch.stack(input_ids),
        torch.stack(attention_mask),
        torch.stack(labels),
    )
    return DataLoader(eval_dataset, batch_size=batch_size)

In [7]:
# Main training function
def train_adversarial(source_train, target_train, target_val, target_test, lambda_param=0.1,
                      batch_size=16, num_epochs=3, learning_rate=2e-5, weight_decay=0.01,
                      source_test=None):
    """Train AdversarialBert on source + target data and evaluate the best checkpoint.

    After every epoch the model is scored on ``target_val``; whenever the
    validation accuracy improves, the weights are saved to
    ``./results/adversarial_finetuned/model.bin``. After training, the best saved
    weights are reloaded and evaluated on the target validation set, the target
    test set and the source test set.

    Args:
        source_train: Source-domain training dataset (IMDB).
        target_train: Target-domain training dataset (SST2).
        target_val: Target-domain dataset used for model selection.
        target_test: Target-domain dataset used for the final "test" metrics.
        lambda_param: Weight of the domain loss, forwarded to AdversarialBert.
        batch_size: Batch size for training and evaluation.
        num_epochs: Number of passes over the training dataloader.
        learning_rate: AdamW learning rate.
        weight_decay: AdamW weight decay.
        source_test: Source-domain dataset for the final evaluation. Defaults to
            the module-level ``tokenized_imdb['test']``.

    Returns:
        dict with 'model', 'target_val_metrics', 'target_test_metrics' and
        'source_test_metrics'.
    """
    # Initialize model
    model = AdversarialBert(model_name, lambda_param=lambda_param)
    model.to(device)

    # Create dataloaders
    train_dataloader = create_adversarial_dataloader(source_train, target_train, batch_size)
    target_val_dataloader = create_eval_dataloader(target_val, batch_size)
    target_test_dataloader = create_eval_dataloader(target_test, batch_size)

    # Initialize optimizer and a linear-decay scheduler without warmup
    optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    total_steps = len(train_dataloader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=total_steps
    )

    # The checkpoint location is bound before the loop so the reload below can
    # never hit an unbound name, even if no checkpoint is ever written.
    output_dir = "./results/adversarial_finetuned"
    checkpoint_path = os.path.join(output_dir, "model.bin")
    checkpoint_saved = False

    # Training loop
    print("Starting adversarial training...")
    best_val_accuracy = 0
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0
        epoch_task_loss = 0
        epoch_domain_loss = 0

        # Gradient-reversal strength for this epoch. It depends on the epoch index
        # only, so it is computed once per epoch; it is exactly 0 during the first
        # epoch (p = 0) and grows towards 1 in later epochs.
        p = float(epoch) / num_epochs
        alpha = 2. / (1. + np.exp(-10 * p)) - 1

        progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
        # An explicit step counter is used for the running averages: tqdm only
        # refreshes `progress_bar.n` when it redraws, so it is not a reliable
        # per-step divisor.
        for step, batch in enumerate(progress_bar, start=1):
            input_ids, attention_mask, labels, domain_labels = [b.to(device) for b in batch]

            # Clear gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
                domain_labels=domain_labels,
                alpha=alpha
            )

            # Backward pass
            outputs['loss'].backward()
            optimizer.step()
            scheduler.step()

            # Update progress bar with the running means of this epoch
            epoch_loss += outputs['loss'].item()
            epoch_task_loss += outputs['task_loss'].item()
            epoch_domain_loss += outputs['domain_loss'].item()
            progress_bar.set_postfix({
                'loss': epoch_loss / step,
                'task_loss': epoch_task_loss / step,
                'domain_loss': epoch_domain_loss / step
            })

        # Evaluate on target validation set
        print(f"Evaluating on target validation set (epoch {epoch+1})...")
        val_metrics = evaluate_model(model, target_val_dataloader)
        print(f"Validation metrics: {val_metrics}")

        # Save best model
        if val_metrics['accuracy'] > best_val_accuracy:
            best_val_accuracy = val_metrics['accuracy']
            print(f"New best model with validation accuracy: {best_val_accuracy:.4f}")
            os.makedirs(output_dir, exist_ok=True)
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True

    # Load best model for final evaluation
    if checkpoint_saved:
        print("Loading best model for final evaluation...")
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    else:
        # Happens when num_epochs is 0 or the validation accuracy never exceeded 0
        print("No checkpoint was saved; evaluating the model in its current state.")

    # Evaluate on target validation and test sets
    print("Final evaluation on target validation set...")
    final_val_metrics = evaluate_model(model, target_val_dataloader)
    print(f"Final validation metrics: {final_val_metrics}")

    print("Evaluating on target test set...")
    final_test_metrics = evaluate_model(model, target_test_dataloader)
    print(f"Target test metrics: {final_test_metrics}")

    # Evaluate on source test set (IMDB)
    print("Evaluating on source test set (IMDB)...")
    if source_test is None:
        source_test = tokenized_imdb['test']
    source_test_dataloader = create_eval_dataloader(source_test, batch_size)
    source_test_metrics = evaluate_model(model, source_test_dataloader)
    print(f"Source test metrics: {source_test_metrics}")

    return {
        'model': model,
        'target_val_metrics': final_val_metrics,
        'target_test_metrics': final_test_metrics,
        'source_test_metrics': source_test_metrics
    }

# Run the adversarial training
print("Starting adversarial finetuning experiment...")

# The SST2 validation split is passed as both target_val and target_test, so the
# "target test" metrics reported by train_adversarial repeat the validation metrics.
sst2_validation = tokenized_sst2['validation']

results = train_adversarial(
    source_train=tokenized_imdb['train'],
    target_train=tokenized_sst2['train'],
    target_val=sst2_validation,
    target_test=sst2_validation,
    batch_size=16,
    num_epochs=3,
    learning_rate=2e-5,
    weight_decay=0.01,
)

Starting adversarial finetuning experiment...
Starting adversarial training...


Epoch 1/3: 100%|██████████| 5772/5772 [07:55<00:00, 12.14it/s, loss=0.243, task_loss=0.203, domain_loss=0.398]


Evaluating on target validation set (epoch 1)...
Validation metrics: {'accuracy': 0.926605504587156, 'precision': 0.9377880184331797, 'recall': 0.9166666666666666, 'f1': 0.9271070615034168}
New best model with validation accuracy: 0.9266


Epoch 2/3: 100%|██████████| 5772/5772 [07:54<00:00, 12.15it/s, loss=0.259, task_loss=0.114, domain_loss=1.44]


Evaluating on target validation set (epoch 2)...
Validation metrics: {'accuracy': 0.9231651376146789, 'precision': 0.929384965831435, 'recall': 0.918918918918919, 'f1': 0.9241223103057757}


Epoch 3/3: 100%|██████████| 5772/5772 [07:55<00:00, 12.15it/s, loss=0.125, task_loss=0.058, domain_loss=0.665]


Evaluating on target validation set (epoch 3)...
Validation metrics: {'accuracy': 0.9243119266055045, 'precision': 0.9108695652173913, 'recall': 0.9436936936936937, 'f1': 0.9269911504424779}
Loading best model for final evaluation...
Final evaluation on target validation set...
Final validation metrics: {'accuracy': 0.926605504587156, 'precision': 0.9377880184331797, 'recall': 0.9166666666666666, 'f1': 0.9271070615034168}
Evaluating on target test set...
Target test metrics: {'accuracy': 0.926605504587156, 'precision': 0.9377880184331797, 'recall': 0.9166666666666666, 'f1': 0.9271070615034168}
Evaluating on source test set (IMDB)...
Source test metrics: {'accuracy': 0.89308, 'precision': 0.8874694424729911, 'recall': 0.90032, 'f1': 0.8938485365950518}


In [8]:
# Display final results: a banner, then the same four metrics for each evaluated split
banner = "=" * 80
print("\n\n" + banner)
print("ADVERSARIAL DOMAIN ADAPTATION RESULTS")
print(banner)

for heading, metrics_key in (
    ("Target validation (SST2) results:", 'target_val_metrics'),
    ("\nSource test (IMDB) results:", 'source_test_metrics'),
):
    print(heading)
    for metric_label, metric_key in (
        ("Accuracy", 'accuracy'),
        ("F1 Score", 'f1'),
        ("Precision", 'precision'),
        ("Recall", 'recall'),
    ):
        print(f"{metric_label}: {results[metrics_key][metric_key]:.4f}")




ADVERSARIAL DOMAIN ADAPTATION RESULTS
Target validation (SST2) results:
Accuracy: 0.9266
F1 Score: 0.9271
Precision: 0.9378
Recall: 0.9167

Source test (IMDB) results:
Accuracy: 0.8931
F1 Score: 0.8938
Precision: 0.8875
Recall: 0.9003


It appears that the IMDB results have dropped slightly. This could be due to truncating the IMDB inputs to match the sequence length of the SST-2 dataset. Let's try a dynamic padding approach to reduce this effect and see how the results change.

In [9]:
from collections import defaultdict

# Custom dataset that handles different sequence lengths
class DynamicPaddingDataset(Dataset):
    """Concatenation of a source and a target dataset with per-item domain tags.

    Indices ``0 .. source_size - 1`` map to source rows and the remaining indices
    to target rows. Sequences are returned as stored, without truncation.

    Args:
        source_dataset: Indexable dataset of dicts with 'input_ids',
            'attention_mask' and 'label'.
        target_dataset: Same row layout as ``source_dataset``.
        source_domain_label: Domain label attached to source rows.
        target_domain_label: Domain label attached to target rows.
        balance: If true and the source is larger than the target, a random
            subset of the source with as many rows as the target is used.
    """

    def __init__(self, source_dataset, target_dataset, source_domain_label=0, target_domain_label=1, balance=True):
        self.source_dataset = source_dataset
        self.target_dataset = target_dataset
        self.source_domain_label = source_domain_label
        self.target_domain_label = target_domain_label

        needs_subsampling = balance and len(source_dataset) > len(target_dataset)
        if needs_subsampling:
            # Draw as many distinct source rows as there are target rows
            sampled = np.random.choice(len(source_dataset), len(target_dataset), replace=False)
            self.source_indices = [int(i) for i in sampled]
        else:
            self.source_indices = list(range(len(source_dataset)))

        self.source_size = len(self.source_indices)
        self.target_size = len(target_dataset)
        self.total_size = self.source_size + self.target_size

    def __len__(self):
        return self.total_size

    def __getitem__(self, idx):
        # Source examples occupy the first `source_size` positions
        if idx < self.source_size:
            item = self.source_dataset[self.source_indices[idx]]
            domain, is_source = self.source_domain_label, True
        else:
            item = self.target_dataset[idx - self.source_size]
            domain, is_source = self.target_domain_label, False

        return {
            'input_ids': item['input_ids'],
            'attention_mask': item['attention_mask'],
            'label': item['label'],
            'domain_label': torch.tensor(domain, dtype=torch.long),
            'is_source': is_source
        }

# Custom collate function for dynamic padding
def dynamic_padding_collate_fn(batch, pad_token_id=0):
    """Collate DynamicPaddingDataset items into padded batch tensors.

    Every sequence is right-padded to the longest sequence in the batch, and the
    rows of all four returned tensors follow the order of ``batch``. Keeping that
    order matters: grouping the inputs by domain while leaving the labels in
    batch order would pair inputs with the wrong labels whenever source and
    target items are interleaved in one batch.

    Args:
        batch: Non-empty list of dicts with 1-D tensors 'input_ids' and
            'attention_mask' plus scalar tensors 'label' and 'domain_label'.
        pad_token_id: Value used to pad 'input_ids'. The attention mask is always
            padded with 0.

    Returns:
        Tuple ``(input_ids, attention_mask, labels, domain_labels)``.

    Raises:
        ValueError: If ``batch`` is empty.
    """
    if not batch:
        raise ValueError("dynamic_padding_collate_fn received an empty batch")

    batch_max_len = max(len(ex['input_ids']) for ex in batch)

    def _pad(sequence, fill_value):
        """Right-pad a 1-D tensor to batch_max_len with fill_value."""
        missing = batch_max_len - len(sequence)
        if missing == 0:
            return sequence
        return torch.cat([sequence, sequence.new_full((missing,), fill_value)])

    input_ids_tensor = torch.stack([_pad(ex['input_ids'], pad_token_id) for ex in batch])
    attention_mask_tensor = torch.stack([_pad(ex['attention_mask'], 0) for ex in batch])
    labels_tensor = torch.stack([ex['label'] for ex in batch])
    domain_labels_tensor = torch.stack([ex['domain_label'] for ex in batch])

    return input_ids_tensor, attention_mask_tensor, labels_tensor, domain_labels_tensor

# Custom batch sampler that creates mini-batches with similar sequence lengths
class SimilarLengthBatchSampler(Sampler):
    """Batch sampler that only puts examples of similar sequence length into one batch.

    Examples are grouped into buckets by their 'input_ids' length, rounded down to
    a multiple of ``bucket_width``. Each batch is drawn from a single bucket, so
    source and target examples share a batch only when their lengths fall into the
    same bucket.

    Args:
        dataset: Object exposing ``source_dataset``, ``target_dataset``,
            ``source_indices``, ``source_size`` and ``target_size``
            (as DynamicPaddingDataset does).
        batch_size: Maximum number of indices per batch.
        shuffle: If true, shuffle within each bucket and shuffle the batch order
            on every iteration (uses Python's ``random`` module).
        drop_last: If true, drop each bucket's final batch when it is incomplete.
        bucket_width: Width of a length bucket in tokens.
    """

    def __init__(self, dataset, batch_size, shuffle=True, drop_last=False, bucket_width=32):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.bucket_width = bucket_width

        # Maps bucket start length -> list of indices into `dataset`
        self.length_buckets = defaultdict(list)

        # Source examples: dataset index `idx` refers to source row source_indices[idx]
        for idx in range(dataset.source_size):
            source_idx = dataset.source_indices[idx]
            length = len(dataset.source_dataset[source_idx]['input_ids'])
            self.length_buckets[self._bucket_of(length)].append(idx)

        # Target examples follow the source examples in the dataset's index space
        for idx in range(dataset.target_size):
            length = len(dataset.target_dataset[idx]['input_ids'])
            self.length_buckets[self._bucket_of(length)].append(idx + dataset.source_size)

        # Number of batches __iter__ will yield: full batches per bucket, plus one
        # partial batch per bucket with a remainder unless drop_last is set.
        self.num_batches = 0
        for indices in self.length_buckets.values():
            full_batches, remainder = divmod(len(indices), batch_size)
            self.num_batches += full_batches
            if remainder and not drop_last:
                self.num_batches += 1

    def _bucket_of(self, length):
        """Round a sequence length down to the start of its bucket."""
        return (length // self.bucket_width) * self.bucket_width

    def __iter__(self):
        # Shuffle within each bucket if required (in place, so the order persists)
        if self.shuffle:
            for bucket in self.length_buckets.values():
                random.shuffle(bucket)

        # Cut every bucket into consecutive batches
        batches = []
        for indices in self.length_buckets.values():
            for start in range(0, len(indices), self.batch_size):
                is_full = start + self.batch_size <= len(indices)
                if is_full or not self.drop_last:
                    batches.append(indices[start:start + self.batch_size])

        # Shuffle the batch order so buckets are interleaved during an epoch
        if self.shuffle:
            random.shuffle(batches)

        yield from batches

    def __len__(self):
        return self.num_batches

# Function to create dataloaders with dynamic padding
def create_dynamic_padding_dataloaders(source_train, target_train, target_val, target_test, batch_size=16):
    """Build the train / val / test dataloaders for the dynamic-padding experiment.

    The training loader combines source and target data through
    DynamicPaddingDataset, batches examples of similar length with
    SimilarLengthBatchSampler and pads each batch with dynamic_padding_collate_fn.
    The evaluation loaders carry no domain labels and keep the sequences at their
    stored length.

    Args:
        source_train: Source-domain training dataset.
        target_train: Target-domain training dataset.
        target_val: Target-domain validation dataset.
        target_test: Target-domain test dataset.
        batch_size: Batch size for all three loaders.

    Returns:
        dict with the keys 'train', 'val' and 'test'.
    """
    # Training data: source rows (domain label 0) followed by target rows (domain label 1)
    train_dataset = DynamicPaddingDataset(
        source_train, target_train,
        source_domain_label=0, target_domain_label=1,
        balance=True
    )

    # Batches are formed per length bucket
    train_sampler = SimilarLengthBatchSampler(
        train_dataset, batch_size=batch_size,
        shuffle=True, drop_last=False
    )

    train_dataloader = DataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        collate_fn=dynamic_padding_collate_fn,
        num_workers=0
    )

    def _build_eval_loader(dataset):
        """Stack a dataset into a TensorDataset of (input_ids, attention_mask, label)."""
        input_ids, attention_mask, labels = [], [], []
        for i in range(len(dataset)):
            row = dataset[i]  # a single fetch per row instead of one per field
            input_ids.append(row['input_ids'])
            attention_mask.append(row['attention_mask'])
            labels.append(row['label'])
        eval_dataset = TensorDataset(
            torch.stack(input_ids),
            torch.stack(attention_mask),
            torch.stack(labels)
        )
        return DataLoader(eval_dataset, batch_size=batch_size)

    return {
        'train': train_dataloader,
        'val': _build_eval_loader(target_val),
        'test': _build_eval_loader(target_test)
    }



In [10]:
# Modified training function to use dynamic padding
def train_adversarial_with_dynamic_padding(source_train, target_train, target_val, target_test, lambda_param=0.1,
                                          batch_size=16, num_epochs=3, learning_rate=2e-5, weight_decay=0.01,
                                          source_test=None):
    """Train AdversarialBert with length-bucketed batches and evaluate the best checkpoint.

    Unlike train_adversarial, the source sequences are not truncated: the training
    batches come from create_dynamic_padding_dataloaders and the source test set
    is evaluated at its stored sequence length.

    After every epoch the model is scored on ``target_val``; whenever the
    validation accuracy improves, the weights are saved to
    ``./results/adversarial_dynamic_padding/model.bin``. After training, the best
    saved weights are reloaded for the final evaluations.

    Args:
        source_train: Source-domain training dataset (IMDB).
        target_train: Target-domain training dataset (SST2).
        target_val: Target-domain dataset used for model selection.
        target_test: Target-domain dataset used for the final "test" metrics.
        lambda_param: Weight of the domain loss, forwarded to AdversarialBert.
        batch_size: Batch size for training and evaluation.
        num_epochs: Number of passes over the training dataloader.
        learning_rate: AdamW learning rate.
        weight_decay: AdamW weight decay.
        source_test: Source-domain dataset for the final evaluation. Defaults to
            the module-level ``tokenized_imdb['test']``.

    Returns:
        dict with 'model', 'target_val_metrics', 'target_test_metrics',
        'source_test_metrics' and 'training_time'. 'training_time' is the
        wall-clock duration in seconds of this whole call, including dataloader
        construction and all evaluations.
    """
    # Track training time
    import time
    start_time = time.time()

    # Initialize model
    model = AdversarialBert(model_name, lambda_param=lambda_param)
    model.to(device)

    # Create dataloaders with dynamic padding
    dataloaders = create_dynamic_padding_dataloaders(
        source_train, target_train, target_val, target_test, batch_size
    )

    # Initialize optimizer and a linear-decay scheduler without warmup
    optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    total_steps = len(dataloaders['train']) * num_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=total_steps
    )

    # The checkpoint location is bound before the loop so the reload below can
    # never hit an unbound name, even if no checkpoint is ever written.
    output_dir = "./results/adversarial_dynamic_padding"
    checkpoint_path = os.path.join(output_dir, "model.bin")
    checkpoint_saved = False

    # Training loop
    print("Starting adversarial training with dynamic padding...")
    best_val_accuracy = 0

    for epoch in range(num_epochs):
        epoch_start_time = time.time()
        model.train()
        epoch_loss = 0
        epoch_task_loss = 0
        epoch_domain_loss = 0

        # Gradient-reversal strength for this epoch. It depends on the epoch index
        # only, so it is computed once per epoch; it is exactly 0 during the first
        # epoch (p = 0) and grows towards 1 in later epochs.
        p = float(epoch) / num_epochs
        alpha = 2. / (1. + np.exp(-10 * p)) - 1

        progress_bar = tqdm(dataloaders['train'], desc=f"Epoch {epoch+1}/{num_epochs}")
        # An explicit step counter is used for the running averages: tqdm only
        # refreshes `progress_bar.n` when it redraws, so it is not a reliable
        # per-step divisor.
        for step, batch in enumerate(progress_bar, start=1):
            input_ids, attention_mask, labels, domain_labels = [b.to(device) for b in batch]

            # Clear gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
                domain_labels=domain_labels,
                alpha=alpha
            )

            # Backward pass
            outputs['loss'].backward()
            optimizer.step()
            scheduler.step()

            # Update progress bar with the running means of this epoch
            epoch_loss += outputs['loss'].item()
            epoch_task_loss += outputs['task_loss'].item()
            epoch_domain_loss += outputs['domain_loss'].item()
            progress_bar.set_postfix({
                'loss': epoch_loss / step,
                'task_loss': epoch_task_loss / step,
                'domain_loss': epoch_domain_loss / step
            })

        # Time spent on the training batches of this epoch (evaluation excluded)
        epoch_time = time.time() - epoch_start_time

        # Evaluate on target validation set
        print(f"Evaluating on target validation set (epoch {epoch+1})...")
        val_metrics = evaluate_model(model, dataloaders['val'])
        print(f"Validation metrics: {val_metrics}")
        print(f"Epoch {epoch+1} time: {epoch_time:.2f} seconds")

        # Save best model
        if val_metrics['accuracy'] > best_val_accuracy:
            best_val_accuracy = val_metrics['accuracy']
            print(f"New best model with validation accuracy: {best_val_accuracy:.4f}")
            os.makedirs(output_dir, exist_ok=True)
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True

    # Load best model for final evaluation
    if checkpoint_saved:
        print("Loading best model for final evaluation...")
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    else:
        # Happens when num_epochs is 0 or the validation accuracy never exceeded 0
        print("No checkpoint was saved; evaluating the model in its current state.")

    # Evaluate on target validation and test sets
    print("Final evaluation on target validation set...")
    final_val_metrics = evaluate_model(model, dataloaders['val'])
    print(f"Final validation metrics: {final_val_metrics}")

    print("Evaluating on target test set...")
    final_test_metrics = evaluate_model(model, dataloaders['test'])
    print(f"Target test metrics: {final_test_metrics}")

    # Evaluate on source test set (IMDB)
    print("Evaluating on source test set (IMDB)...")

    def pad_sequences(batch):
        """Collate rows into (input_ids, attention_mask, labels), right-padded with zeros to the batch maximum."""
        max_len = max(len(item['input_ids']) for item in batch)

        def _pad(sequence):
            missing = max_len - len(sequence)
            if missing == 0:
                return sequence
            return torch.cat([sequence, torch.zeros(missing, dtype=torch.long)])

        return (
            torch.stack([_pad(item['input_ids']) for item in batch]),
            torch.stack([_pad(item['attention_mask']) for item in batch]),
            torch.stack([item['label'] for item in batch])
        )

    class SimpleDataset(Dataset):
        """Thin map-style wrapper that hands the wrapped dataset's rows to the DataLoader one by one."""

        def __init__(self, dataset):
            self.dataset = dataset

        def __len__(self):
            return len(self.dataset)

        def __getitem__(self, idx):
            return self.dataset[idx]

    if source_test is None:
        source_test = tokenized_imdb['test']

    imdb_test_dataloader = DataLoader(
        SimpleDataset(source_test),
        batch_size=batch_size,
        shuffle=False,
        collate_fn=pad_sequences
    )

    source_test_metrics = evaluate_model(model, imdb_test_dataloader)
    print(f"Source test metrics: {source_test_metrics}")

    # Total wall-clock time of this call
    training_time = time.time() - start_time

    return {
        'model': model,
        'target_val_metrics': final_val_metrics,
        'target_test_metrics': final_test_metrics,
        'source_test_metrics': source_test_metrics,
        'training_time': training_time
    }

In [11]:
# Run the adversarial training with dynamic padding
print("Starting adversarial finetuning with dynamic padding...")

# The SST2 validation split is passed as both target_val and target_test, so the
# "target test" metrics reported by the training function repeat the validation metrics.
sst2_validation = tokenized_sst2['validation']

results = train_adversarial_with_dynamic_padding(
    source_train=tokenized_imdb['train'],
    target_train=tokenized_sst2['train'],
    target_val=sst2_validation,
    target_test=sst2_validation,
    batch_size=16,
    num_epochs=3,
    learning_rate=2e-5,
    weight_decay=0.01,
)

Starting adversarial finetuning with dynamic padding...
Starting adversarial training with dynamic padding...


Epoch 1/3: 100%|██████████| 5773/5773 [14:18<00:00,  6.72it/s, loss=0.203, task_loss=0.159, domain_loss=0.44]


Evaluating on target validation set (epoch 1)...
Validation metrics: {'accuracy': 0.9311926605504587, 'precision': 0.9343891402714932, 'recall': 0.9301801801801802, 'f1': 0.9322799097065463}
Epoch 1 time: 858.60 seconds
New best model with validation accuracy: 0.9312


Epoch 2/3: 100%|██████████| 5773/5773 [14:18<00:00,  6.73it/s, loss=0.226, task_loss=0.0836, domain_loss=1.42]


Evaluating on target validation set (epoch 2)...
Validation metrics: {'accuracy': 0.9231651376146789, 'precision': 0.927437641723356, 'recall': 0.9211711711711712, 'f1': 0.9242937853107345}
Epoch 2 time: 858.08 seconds


Epoch 3/3: 100%|██████████| 5773/5773 [14:19<00:00,  6.71it/s, loss=0.114, task_loss=0.045, domain_loss=0.692]


Evaluating on target validation set (epoch 3)...
Validation metrics: {'accuracy': 0.9277522935779816, 'precision': 0.9261744966442953, 'recall': 0.9324324324324325, 'f1': 0.9292929292929293}
Epoch 3 time: 860.00 seconds
Loading best model for final evaluation...
Final evaluation on target validation set...
Final validation metrics: {'accuracy': 0.9311926605504587, 'precision': 0.9343891402714932, 'recall': 0.9301801801801802, 'f1': 0.9322799097065463}
Evaluating on target test set...
Target test metrics: {'accuracy': 0.9311926605504587, 'precision': 0.9343891402714932, 'recall': 0.9301801801801802, 'f1': 0.9322799097065463}
Evaluating on source test set (IMDB)...
Source test metrics: {'accuracy': 0.9388, 'precision': 0.9381690365873143, 'recall': 0.93952, 'f1': 0.9388440322967463}


In [12]:
# Display final results: a banner, then the same four metrics for each evaluated split
banner = "=" * 80
print("\n\n" + banner)
print("ADVERSARIAL DOMAIN ADAPTATION RESULTS WITH DYNAMIC PADDING")
print(banner)

for heading, metrics_key in (
    ("Target validation (SST2) results:", 'target_val_metrics'),
    ("\nSource test (IMDB) results:", 'source_test_metrics'),
):
    print(heading)
    for metric_label, metric_key in (
        ("Accuracy", 'accuracy'),
        ("F1 Score", 'f1'),
        ("Precision", 'precision'),
        ("Recall", 'recall'),
    ):
        print(f"{metric_label}: {results[metrics_key][metric_key]:.4f}")



ADVERSARIAL DOMAIN ADAPTATION RESULTS WITH DYNAMIC PADDING
Target validation (SST2) results:
Accuracy: 0.9312
F1 Score: 0.9323
Precision: 0.9344
Recall: 0.9302

Source test (IMDB) results:
Accuracy: 0.9388
F1 Score: 0.9388
Precision: 0.9382
Recall: 0.9395
