# Domain Adaptation

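This notebook explores domain adaptation for sentiment analysis: a BERT classifier trained on movie reviews (IMDB, the source domain) is adapted to Yelp reviews (the target domain).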

## Objectives
- Train models on source domain (IMDB)
- Adapt to target domain (Yelp)
- Compare fine-tuning vs adversarial adaptation
- Evaluate domain shift effects


In [None]:
import sys
import os
from pathlib import Path

# Make the project root (the parent of the notebook's working directory)
# importable so the `src.*` imports below resolve. The membership check keeps
# re-runs of this cell from appending a duplicate sys.path entry each time.
PROJECT_ROOT = os.path.dirname(os.getcwd())
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

import torch
import pandas as pd
import numpy as np

from src.data.dataset_loader import load_preprocessed_data
from src.models.transformer_model import BERTForSentiment
from src.models.adapter_layers import DomainAdapter, AdversarialDiscriminator
from src.train.train_transformer import train_transformer
from src.train.trainer_utils import create_dataloader
from src.utils.seed_everything import seed_everything
from src.utils.config_loader import load_config

SEED = 42
seed_everything(SEED)
config = load_config('../config.yaml')

# Pick the compute device in order of preference:
# Apple Silicon GPU (MPS), then NVIDIA GPU (CUDA), otherwise CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("✅ Using Apple Silicon GPU (MPS)")
elif torch.cuda.is_available():
    device = torch.device("cuda")
    print("✅ Using CUDA GPU")
else:
    device = torch.device("cpu")
    print("⚠️ Using CPU (no GPU available)")
print(f"Device: {device}")


## 1. Load Preprocessed Source and Target Domains

Load the preprocessed splits saved by notebook 02. The text is already cleaned, so no further cleaning is needed here.


# Load the preprocessed train/val splits written by notebook 02.
DATA_DIR = '../intermediate/data'

# Target-domain subset sizes, kept small for faster experiments.
N_TARGET_TRAIN = 5000
N_TARGET_VAL = 1000

# Source domain (IMDB): train split for training, val split for validation.
print("Loading preprocessed source domain (IMDB train)...")
source_train_texts, source_train_labels = load_preprocessed_data('imdb_train', data_dir=DATA_DIR)
source_val_texts, source_val_labels = load_preprocessed_data('imdb_val', data_dir=DATA_DIR)

# Target domain (Yelp): train split for fine-tuning, val split for validation.
print("Loading preprocessed target domain (Yelp train)...")
target_train_texts, target_train_labels = load_preprocessed_data('yelp_train', data_dir=DATA_DIR)
target_val_texts, target_val_labels = load_preprocessed_data('yelp_val', data_dir=DATA_DIR)

# Keep only the first N target examples. Texts and labels are sliced in the
# same statement so the pairs cannot fall out of alignment.
target_train_texts, target_train_labels = target_train_texts[:N_TARGET_TRAIN], target_train_labels[:N_TARGET_TRAIN]
target_val_texts, target_val_labels = target_val_texts[:N_TARGET_VAL], target_val_labels[:N_TARGET_VAL]

loaded_splits = {
    "IMDB train": (source_train_texts, source_train_labels),
    "IMDB val": (source_val_texts, source_val_labels),
    "Yelp train": (target_train_texts, target_train_labels),
    "Yelp val": (target_val_texts, target_val_labels),
}
# Every text must have exactly one label.
for split_name, (split_texts, split_labels) in loaded_splits.items():
    assert len(split_texts) == len(split_labels), (
        f"{split_name}: {len(split_texts)} texts vs {len(split_labels)} labels"
    )

print(f"✅ Source (IMDB): Train={len(source_train_texts)}, Val={len(source_val_texts)}")
print(f"✅ Target (Yelp): Train={len(target_train_texts)}, Val={len(target_val_texts)} (subset for faster training)")

# Prefix slices (here and in the training cells) are not random samples. If a
# split happens to be stored ordered by label, a prefix can be heavily
# imbalanced, so show the class counts before training.
for split_name, (_, split_labels) in loaded_splits.items():
    label_values, label_counts = np.unique(np.asarray(split_labels), return_counts=True)
    print(f"   {split_name} label counts: {dict(zip(label_values.tolist(), label_counts.tolist()))}")


In [None]:
## 2. Train on Source Domain


In [None]:
# Build IMDB (source-domain) data loaders and train a BERT classifier on them.
# The splits were produced during preprocessing; only subsets are used here.
from src.train.trainer_utils import create_dataloader

# One checkpoint name shared by the data loaders and the model, so the two
# cannot drift apart. (create_dataloader presumably uses it to select the
# tokenizer — confirm in src/train/trainer_utils.)
MODEL_NAME = 'bert-base-uncased'
BATCH_SIZE = 16
N_SOURCE_TRAIN = 2000  # subset sizes, for faster training
N_SOURCE_VAL = 500

source_train_loader = create_dataloader(source_train_texts[:N_SOURCE_TRAIN], source_train_labels[:N_SOURCE_TRAIN],
                                        MODEL_NAME, batch_size=BATCH_SIZE, shuffle=True)
source_val_loader = create_dataloader(source_val_texts[:N_SOURCE_VAL], source_val_labels[:N_SOURCE_VAL],
                                      MODEL_NAME, batch_size=BATCH_SIZE, shuffle=False)

# Train on the source domain only.
source_model = BERTForSentiment(MODEL_NAME, num_classes=2).to(device)
source_history = train_transformer(source_model, source_train_loader, source_val_loader, device,
                                   num_epochs=2, learning_rate=2e-5)

print("Source domain training complete!")


## 3. Fine-tune on Target Domain


In [None]:
import copy

# Fine-tune on the target domain (Yelp), starting from the source-trained
# weights (transfer learning). The splits were produced during preprocessing.
# The checkpoint name must match the one the source model was built with.
target_train_loader = create_dataloader(target_train_texts[:1000], target_train_labels[:1000],
                                        'bert-base-uncased', batch_size=16, shuffle=True)
target_val_loader = create_dataloader(target_val_texts[:200], target_val_labels[:200],
                                      'bert-base-uncased', batch_size=16, shuffle=False)

# Fine-tune a *copy* of the source model. Training `source_model` in place would
# (a) destroy the source-only baseline needed to measure the effect of domain
# shift, and (b) make this cell non-idempotent, because every re-run would keep
# training weights that are already adapted.
adapted_model = copy.deepcopy(source_model)
adapted_history = train_transformer(adapted_model, target_train_loader, target_val_loader, device,
                                    num_epochs=2, learning_rate=1e-5)  # lower LR than source training (2e-5)

print("Domain adaptation complete!")
