<a href="https://colab.research.google.com/github/reagan13/gpt2-distilbert-thesis-files/blob/main/notebook/Hybrid_Model_Parallel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Multitask Learning with a Hybrid GPT-2 + DistilBERT Model

## Import Libraries

In [None]:
# Cell 1: Imports and Setup
# (Add DistilBERT imports)
import json
import os
import time
from typing import List, Dict, Optional
from collections import Counter, defaultdict
import sys

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import (
    GPT2Model, GPT2Config, GPT2TokenizerFast,
    DistilBertModel, DistilBertTokenizerFast,  # Added DistilBERT imports
    AdamW, get_linear_schedule_with_warmup
)
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from datetime import datetime

# Choose the compute device once; later cells read the module-level `device`.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Selected device: {device}")

# Report GPU details (and release cached blocks) only when a GPU was selected.
if device.type != "cuda":
    print("No GPU detected. Running on CPU.")
else:
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    print(f"Initial GPU Memory Allocated: {torch.cuda.memory_allocated(0) / 1024**2:.2f} MB")
    torch.cuda.empty_cache()

# check_device and setup_logging functions remain unchanged
def check_device(item, name="Item"):
    """Print which device a module or tensor currently lives on.

    Args:
        item: A ``torch.nn.Module`` or ``torch.Tensor``; anything else is
            reported as unsupported.
        name: Label used at the start of the printed message.

    For a module, the device of its first parameter is reported; a module
    without parameters gets a dedicated message.
    """
    if isinstance(item, torch.nn.Module):
        first_param = next(item.parameters(), None)
        if first_param is None:
            print(f"{name} has no parameters to check")
        else:
            print(f"{name} is on: {first_param.device}")
        return

    if isinstance(item, torch.Tensor):
        print(f"{name} is on: {item.device}")
        return

    print(f"{name} is not a tensor or model: {type(item)}")

def setup_logging(save_path: str, filename: str = "training_log.txt"):
    """Mirror everything written to ``sys.stdout`` into a log file.

    Creates ``save_path`` if needed, opens ``save_path/filename`` for writing
    (truncating any existing file) and replaces ``sys.stdout`` with a tee that
    writes each message to both the file and the previous stdout.

    If ``sys.stdout`` is already a tee installed by an earlier call, that tee
    is closed first so that repeated calls (e.g. re-running a notebook cell)
    do not stack tees or leak file handles.

    Args:
        save_path: Directory that will contain the log file.
        filename: Name of the log file inside ``save_path``.

    Returns:
        The tee object now installed as ``sys.stdout``. Call its ``close()``
        method to restore the previous stdout and close the log file.
    """
    os.makedirs(save_path, exist_ok=True)
    log_path = os.path.join(save_path, filename)

    class Logger:
        """Tee that forwards writes to a log file and the wrapped stdout."""

        # Marker used to recognise a tee installed by a previous call.
        is_tee_logger = True

        def __init__(self, file_handle, original_stdout):
            self.file = file_handle
            self.stdout = original_stdout

        def write(self, message):
            # Keep printing to the console even after the log file was closed.
            if not self.file.closed:
                self.file.write(message)
            self.stdout.write(message)

        def flush(self):
            if not self.file.closed:
                self.file.flush()
            self.stdout.flush()

        def close(self):
            """Restore the wrapped stdout (if still installed) and close the file."""
            if sys.stdout is self:
                sys.stdout = self.stdout
            if not self.file.closed:
                self.file.close()

        def __getattr__(self, name):
            # Only reached for attributes the tee itself lacks (encoding,
            # isatty, ...): delegate them to the wrapped stream.
            if name in ("file", "stdout"):
                raise AttributeError(name)
            return getattr(self.stdout, name)

    # Unwind a tee left over from an earlier call before installing a new one.
    previous = sys.stdout
    if getattr(previous, "is_tee_logger", False):
        previous.close()

    # The handle must outlive this function, so it is closed by Logger.close().
    log_file = open(log_path, "w", encoding="utf-8")
    tee = Logger(log_file, sys.stdout)
    sys.stdout = tee
    print(f"Logging started at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Log file: {log_path}")
    return tee

# Fixed run-date stamp (a literal string, not read from the clock).
print("Current date: March 06, 2025")


## Dataset Loading Functions

In [None]:

def load_dataset(json_file: str) -> List[Dict]:
    """Parse the UTF-8 encoded JSON file at ``json_file`` and return its content."""
    with open(json_file, mode="r", encoding="utf-8") as handle:
        records = json.load(handle)
    return records

def detect_labels(data: List[Dict]) -> Dict[str, Dict]:
    """Build label-to-index encoders for categories, intents and NER tags.

    Each sample is expected to provide ``"category"``, ``"intent"`` and
    ``"ner_labels_only"`` (a list of dicts with ``"label"`` and ``"text"``).
    Every entity label ``X`` contributes the tags ``B-X`` and ``I-X``; the tag
    ``"O"`` is always present. Samples missing one of the three keys are
    counted and reported in a warning instead of aborting the scan.

    Args:
        data: List of sample dictionaries.

    Returns:
        Dict with ``"category_encoder"``, ``"intent_encoder"`` and
        ``"ner_label_encoder"``; each maps a label to its index in sorted order.

    Raises:
        TypeError: If ``data`` is not a list.
        ValueError: If ``ner_labels_only`` is not a list, or one of its entries
            is not a dict carrying both ``"label"`` and ``"text"``.
    """
    started_at = time.time()
    if not isinstance(data, list):
        raise TypeError("Input 'data' must be a list of dictionaries")
    if not data:
        return {"category_encoder": {}, "intent_encoder": {}, "ner_label_encoder": {"O": 0}}

    def index_sorted(values):
        # Deterministic ids: position of each label in sorted order.
        return {value: position for position, value in enumerate(sorted(values))}

    categories, intents, ner_tags = set(), set(), {"O"}
    category_freq, intent_freq, ner_freq = Counter(), Counter(), Counter()
    absent_keys = defaultdict(int)

    for position, record in enumerate(data):
        try:
            cat, intent = record["category"], record["intent"]
            categories.add(cat)
            intents.add(intent)
            category_freq[cat] += 1
            intent_freq[intent] += 1

            entities = record["ner_labels_only"]
            if not isinstance(entities, list):
                raise ValueError(f"'ner_labels_only' must be a list at sample {position}")
            for entity in entities:
                well_formed = isinstance(entity, dict) and "label" in entity and "text" in entity
                if not well_formed:
                    raise ValueError(f"NER label must have 'label' and 'text' fields at sample {position}")
                # Each entity type yields a begin tag and an inside tag.
                for prefix in ("B-", "I-"):
                    tag = f"{prefix}{entity['label']}"
                    ner_tags.add(tag)
                    ner_freq[tag] += 1
        except KeyError as missing:
            # str(KeyError) is the repr of the key; strip the quotes for display.
            absent_keys[str(missing).strip("'")] += 1

    if absent_keys:
        print("Warning: Missing fields detected:")
        for field, count in absent_keys.items():
            print(f"  - '{field}' missing in {count} samples")

    encoders = {
        "category_encoder": index_sorted(categories),
        "intent_encoder": index_sorted(intents),
        "ner_label_encoder": index_sorted(ner_tags),
    }
    category_encoder = encoders["category_encoder"]
    intent_encoder = encoders["intent_encoder"]
    ner_label_encoder = encoders["ner_label_encoder"]

    print(f"Dataset summary:\n  - {len(data)} samples\n  - {len(category_encoder)} categories\n  - {len(intent_encoder)} intents\n  - {len(ner_label_encoder)} NER tags")
    print("Category distribution:", dict(category_freq))
    print("Intent distribution:", dict(intent_freq))
    print("NER tag distribution (non-O):", dict(ner_freq))
    print(f"Processing time: {time.time() - started_at:.3f} seconds")

    return encoders


## Tokenization and NER Alignment

In [None]:

# Cell 3: Tokenization and NER Alignment (Updated for Dual Tokenizers)
def tokenize_text_hybrid(text: str, gpt2_tokenizer, distilbert_tokenizer, max_length: int) -> Dict[str, torch.Tensor]:
    """Encode ``text`` with the GPT-2 and the DistilBERT tokenizer.

    Both tokenizers are called with the same settings (pad to ``max_length``,
    truncate, return PyTorch tensors). The leading dimension of each returned
    ``input_ids`` / ``attention_mask`` tensor is squeezed away.

    Returns:
        Dict with ``gpt2_input_ids``, ``gpt2_attention_mask``,
        ``distilbert_input_ids`` and ``distilbert_attention_mask``.
    """
    shared_kwargs = dict(max_length=max_length, padding="max_length", truncation=True, return_tensors="pt")
    # Run both tokenizers (GPT-2 first) before touching their outputs.
    gpt2_encoded, bert_encoded = (
        tokenizer(text, **shared_kwargs) for tokenizer in (gpt2_tokenizer, distilbert_tokenizer)
    )

    features = {}
    for prefix, encoded in (("gpt2", gpt2_encoded), ("distilbert", bert_encoded)):
        for field in ("input_ids", "attention_mask"):
            features[f"{prefix}_{field}"] = encoded[field].squeeze(0)
    return features

def align_ner_labels(text: str, ner_labels: List[Dict], tokenizer, ner_label_encoder: Dict, max_length: int) -> torch.Tensor:
    """Project character-level entity annotations onto tokenizer positions.

    ``text`` is tokenized with offset mapping; every occurrence of an entity's
    ``"text"`` inside ``text`` is located with ``str.find`` and each token whose
    character span overlaps that occurrence gets ``B-<label>`` (first
    overlapping token) or ``I-<label>`` (following tokens). All other
    positions, including tokens with offsets ``(0, 0)``, stay ``"O"``.

    Entries that are not dicts, lack ``"text"``/``"label"``, or whose text is
    empty or not a string are skipped with a warning. An empty text would
    otherwise make the occurrence search loop forever, and a missing ``"text"``
    would crash the length sort before any validation could run.

    Args:
        text: The raw input string.
        ner_labels: Entity dicts with ``"text"`` and ``"label"``; may be empty
            or ``None``.
        tokenizer: Callable tokenizer supporting ``return_offsets_mapping``.
        ner_label_encoder: Mapping from tag string to id; must contain ``"O"``.
            Unknown ``B-``/``I-`` tags fall back to the ``"O"`` id.
        max_length: Padded/truncated sequence length.

    Returns:
        ``torch.long`` tensor of length ``max_length`` with one tag id per token.

    Raises:
        KeyError: If ``ner_label_encoder`` has no ``"O"`` entry.
    """
    encoding = tokenizer(
        text, max_length=max_length, padding="max_length", truncation=True, return_offsets_mapping=True, return_tensors="pt"
    )
    token_to_char_map = encoding["offset_mapping"][0].tolist()
    outside_id = ner_label_encoder["O"]
    ner_aligned = [outside_id] * max_length

    # Validate BEFORE sorting: the sort key needs a usable "text" value.
    valid_labels = []
    for label in ner_labels or []:
        if not isinstance(label, dict) or "text" not in label or "label" not in label:
            print(f"Warning: Skipping invalid NER entry {label} (missing 'text' or 'label')")
            continue
        if not isinstance(label["text"], str) or not label["text"]:
            print(f"Warning: Skipping NER entry {label} with empty or non-string 'text'")
            continue
        valid_labels.append(label)

    # Longest entity text first (stable for ties).
    # NOTE(review): later (shorter) entities overwrite tokens already tagged by
    # longer ones - confirm this priority is intended.
    valid_labels.sort(key=lambda entry: len(entry["text"]), reverse=True)

    for label in valid_labels:
        label_text, label_type = label["text"], label["label"]
        begin_id = ner_label_encoder.get(f"B-{label_type}", outside_id)
        inside_id = ner_label_encoder.get(f"I-{label_type}", outside_id)

        search_from = 0
        while True:
            label_start = text.find(label_text, search_from)
            if label_start == -1:
                break
            label_end = label_start + len(label_text)
            # label_text is non-empty, so the search position always advances.
            search_from = label_end

            first_token = True
            for i, (start, end) in enumerate(token_to_char_map):
                if start == 0 and end == 0:
                    continue  # skip tokens with an empty (0, 0) character span
                if max(start, label_start) < min(end, label_end):
                    ner_aligned[i] = begin_id if first_token else inside_id
                    first_token = False

    return torch.tensor(ner_aligned, dtype=torch.long)


## Dataset and Dataloader

In [None]:

# Cell 4: Dataset and DataLoader (Updated)
class HybridMultiTaskDataset(Dataset):
    """Dataset yielding dual-tokenizer inputs plus category, intent and NER targets.

    Each sample dict must provide ``"instruction"``, ``"ner_labels_only"``,
    ``"category"`` and ``"intent"``. NER targets are aligned against the GPT-2
    tokenization.
    """

    def __init__(self, data: List[Dict], gpt2_tokenizer, distilbert_tokenizer, label_encoders, max_length: int):
        # Samples and helpers are stored as-is; all encoding happens lazily
        # in __getitem__.
        self.data = data
        self.gpt2_tokenizer = gpt2_tokenizer
        self.distilbert_tokenizer = distilbert_tokenizer
        self.label_encoders = label_encoders
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Encode sample ``idx`` into the tensors the hybrid model consumes."""
        record = self.data[idx]
        instruction = record["instruction"]

        features = tokenize_text_hybrid(
            instruction, self.gpt2_tokenizer, self.distilbert_tokenizer, self.max_length
        )
        aligned_ner = align_ner_labels(
            instruction,
            record["ner_labels_only"],
            self.gpt2_tokenizer,
            self.label_encoders["ner_label_encoder"],
            self.max_length,
        )

        input_keys = (
            "gpt2_input_ids",
            "gpt2_attention_mask",
            "distilbert_input_ids",
            "distilbert_attention_mask",
        )
        item = {key: features[key] for key in input_keys}
        item["category_labels"] = torch.tensor(
            self.label_encoders["category_encoder"][record["category"]], dtype=torch.long
        )
        item["intent_labels"] = torch.tensor(
            self.label_encoders["intent_encoder"][record["intent"]], dtype=torch.long
        )
        item["ner_labels"] = aligned_ner
        return item

def get_dataloaders(train_data, val_data, test_data, gpt2_tokenizer, distilbert_tokenizer, label_encoders, batch_size, num_workers, max_length):
    """Wrap the three data splits in ``HybridMultiTaskDataset`` and ``DataLoader``.

    Only the training loader shuffles. ``pin_memory`` is enabled when the
    module-level ``device`` is a CUDA device.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    loader_kwargs = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": device.type == "cuda",
    }

    train_set, val_set, test_set = [
        HybridMultiTaskDataset(split, gpt2_tokenizer, distilbert_tokenizer, label_encoders, max_length)
        for split in (train_data, val_data, test_data)
    ]

    train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_set, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_set, shuffle=False, **loader_kwargs)
    return train_loader, val_loader, test_loader

## Model Architecture

In [None]:

# Cell 5: Model Definition (Hybrid Model)
class FusionLayer(nn.Module):
    """Fuse GPT-2 and DistilBERT features into one representation.

    Each input is linearly projected to ``output_dim``; the two projections
    are concatenated on the last dimension, passed through
    Linear -> Tanh -> Dropout, and finally layer-normalised.
    """

    def __init__(self, gpt2_dim: int, bert_dim: int, output_dim: int, dropout_rate: float = 0.4):
        super().__init__()
        # Submodules are created in the same order as their attribute names
        # appear in the state dict.
        self.gpt2_proj = nn.Linear(gpt2_dim, output_dim)
        self.bert_proj = nn.Linear(bert_dim, output_dim)
        fusion_stages = [
            nn.Linear(2 * output_dim, output_dim),
            nn.Tanh(),
            nn.Dropout(dropout_rate),
        ]
        self.fusion = nn.Sequential(*fusion_stages)
        self.layer_norm = nn.LayerNorm(output_dim)

    def forward(self, gpt2_features: torch.Tensor, bert_features: torch.Tensor) -> torch.Tensor:
        """Return the normalised fusion of the two feature tensors.

        The projections are concatenated along the last dimension, so both
        inputs must agree on every other dimension.
        """
        projected = (self.gpt2_proj(gpt2_features), self.bert_proj(bert_features))
        mixed = self.fusion(torch.cat(projected, dim=-1))
        return self.layer_norm(mixed)

class HybridGPT2DistilBERTMultiTask(nn.Module):
    """Multi-task model on top of frozen GPT-2 and DistilBERT backbones.

    The last hidden states of both backbones are combined by ``FusionLayer``.
    Intent and category are classified from the fused feature at the last
    attended GPT-2 position of each sequence; NER tags are predicted for every
    position. All backbone parameters are frozen in ``__init__``, so only the
    fusion layer and the three heads are trainable.
    """

    def __init__(self, num_intents: int, num_categories: int, num_ner_labels: int, dropout_rate: float = 0.4):
        """Load the pretrained backbones and build the fusion layer and heads.

        Args:
            num_intents: Number of intent classes.
            num_categories: Number of category classes.
            num_ner_labels: Number of NER tag classes.
            dropout_rate: Dropout probability used in the fusion layer and heads.
        """
        super().__init__()
        self.gpt2_config = GPT2Config.from_pretrained('gpt2')
        self.gpt2 = GPT2Model.from_pretrained('gpt2')
        self.distilbert = DistilBertModel.from_pretrained('distilbert-base-uncased')

        # Freeze both backbones.
        for param in self.gpt2.parameters():
            param.requires_grad = False
        for param in self.distilbert.parameters():
            param.requires_grad = False

        gpt2_dim = self.gpt2_config.n_embd
        bert_dim = self.distilbert.config.hidden_size
        hidden_size = gpt2_dim  # fused features use the GPT-2 width

        self.fusion_layer = FusionLayer(gpt2_dim, bert_dim, hidden_size, dropout_rate)

        # Three structurally identical heads: Linear -> Tanh -> Dropout -> Linear.
        self.intent_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_size, num_intents)
        )
        self.category_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_size, num_categories)
        )
        self.ner_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_size, num_ner_labels)
        )

        self.intent_loss_fn = nn.CrossEntropyLoss()
        self.category_loss_fn = nn.CrossEntropyLoss()
        self.ner_loss_fn = nn.CrossEntropyLoss()

    def forward(self, gpt2_input_ids: torch.Tensor, gpt2_attention_mask: torch.Tensor,
                distilbert_input_ids: torch.Tensor, distilbert_attention_mask: torch.Tensor,
                intent_labels: Optional[torch.Tensor] = None,
                category_labels: Optional[torch.Tensor] = None,
                ner_labels: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Run both backbones, fuse their features and apply the task heads.

        The two hidden-state tensors are concatenated on the last dimension
        inside the fusion layer, so both tokenizations must be padded to the
        same sequence length.

        Returns:
            Dict with ``intent_logits``, ``category_logits`` and ``ner_logits``.
            When all three label tensors are given it also contains ``loss``
            (the unweighted sum) plus ``intent_loss``, ``category_loss`` and
            ``ner_loss``. The NER loss only covers positions where
            ``gpt2_attention_mask == 1``.
        """
        gpt2_outputs = self.gpt2(input_ids=gpt2_input_ids, attention_mask=gpt2_attention_mask)
        distilbert_outputs = self.distilbert(input_ids=distilbert_input_ids, attention_mask=distilbert_attention_mask)

        gpt2_features = gpt2_outputs.last_hidden_state
        bert_features = distilbert_outputs.last_hidden_state

        fused_features = self.fusion_layer(gpt2_features, bert_features)

        # Index of the last attended GPT-2 token per sequence.
        # NOTE(review): this assumes right-side padding - confirm the tokenizer's padding_side.
        # An all-zero mask (e.g. empty input text) would give -1, which
        # torch.gather rejects, so fall back to position 0 in that case.
        sequence_lengths = (gpt2_attention_mask.sum(dim=1) - 1).clamp(min=0)
        last_token_indexes = sequence_lengths.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, fused_features.shape[-1])
        sequence_repr = torch.gather(fused_features, 1, last_token_indexes).squeeze(1)

        intent_logits = self.intent_head(sequence_repr)
        category_logits = self.category_head(sequence_repr)
        ner_logits = self.ner_head(fused_features)

        output_dict = {
            'intent_logits': intent_logits,
            'category_logits': category_logits,
            'ner_logits': ner_logits
        }

        if all(label is not None for label in [intent_labels, category_labels, ner_labels]):
            intent_loss = self.intent_loss_fn(intent_logits, intent_labels)
            category_loss = self.category_loss_fn(category_logits, category_labels)
            # Token-level loss on attended positions only.
            active_loss = gpt2_attention_mask.view(-1) == 1
            active_logits = ner_logits.view(-1, ner_logits.size(-1))[active_loss]
            active_labels = ner_labels.view(-1)[active_loss]
            ner_loss = self.ner_loss_fn(active_logits, active_labels)

            output_dict.update({
                'loss': intent_loss + category_loss + ner_loss,
                'intent_loss': intent_loss,
                'category_loss': category_loss,
                'ner_loss': ner_loss
            })

        return output_dict

## Training Loop

In [None]:

# Cell 6: Training Function (Updated)
from tqdm import tqdm

def train_model(model, train_loader, val_loader, num_epochs, learning_rate):
    """Train ``model`` for ``num_epochs`` epochs, validating after each epoch.

    Only parameters with ``requires_grad=True`` are passed to AdamW. Batches
    are moved to the module-level ``device``. After every epoch the average
    loss, intent accuracy, macro category F1 and macro NER F1 are recorded for
    both the training and the validation split and printed.

    NER F1 is computed only over positions where ``gpt2_attention_mask`` is 1,
    the same positions the model's NER loss covers. Padded positions carry no
    supervision, so counting them would distort the score.

    Args:
        model: Module whose forward accepts the batch dict as keyword arguments
            and returns a dict with ``loss`` and the three ``*_logits`` entries.
        train_loader: Iterable of training batches (dicts of tensors).
        val_loader: Iterable of validation batches.
        num_epochs: Number of passes over ``train_loader``.
        learning_rate: AdamW learning rate.

    Returns:
        Dict mapping each metric name (``train_loss``, ``val_loss``,
        ``train_intent_acc``, ...) to a list with one value per epoch.
    """
    optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate)
    history = {
        "train_loss": [], "val_loss": [],
        "train_intent_acc": [], "val_intent_acc": [],
        "train_category_f1": [], "val_category_f1": [],
        "train_ner_f1": [], "val_ner_f1": []
    }

    model.to(device)
    check_device(model, "Model")

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        all_train_intent_preds, all_train_intent_labels = [], []
        all_train_category_preds, all_train_category_labels = [], []
        all_train_ner_preds, all_train_ner_labels = [], []

        with tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Training]", leave=False) as train_loop:
            for i, batch in enumerate(train_loop):
                optimizer.zero_grad()
                inputs = {k: v.to(device) for k, v in batch.items()}
                if i == 0 and epoch == 0:
                    # One-off sanity check that the inputs reached the device.
                    check_device(inputs["gpt2_input_ids"], "GPT2 Input IDs")
                    check_device(inputs["distilbert_input_ids"], "DistilBERT Input IDs")
                    # GPU memory is only meaningful when actually running on CUDA.
                    if device.type == "cuda":
                        print(f"GPU Memory Allocated After Data Load: {torch.cuda.memory_allocated(0) / 1024**2:.2f} MB")

                outputs = model(**inputs)
                loss = outputs["loss"]
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                train_loop.set_postfix(loss=loss.item())

                intent_preds = torch.argmax(outputs["intent_logits"], dim=-1).cpu().numpy()
                category_preds = torch.argmax(outputs["category_logits"], dim=-1).cpu().numpy()
                ner_preds = torch.argmax(outputs["ner_logits"], dim=-1).cpu().numpy()

                all_train_intent_preds.extend(intent_preds)
                all_train_intent_labels.extend(batch["intent_labels"].cpu().numpy())
                all_train_category_preds.extend(category_preds)
                all_train_category_labels.extend(batch["category_labels"].cpu().numpy())
                # Keep only attended tokens; boolean indexing flattens to 1-D.
                token_mask = batch["gpt2_attention_mask"].cpu().numpy().astype(bool)
                all_train_ner_preds.extend(ner_preds[token_mask])
                all_train_ner_labels.extend(batch["ner_labels"].cpu().numpy()[token_mask])

        train_intent_acc = accuracy_score(all_train_intent_labels, all_train_intent_preds)
        train_category_f1 = precision_recall_fscore_support(all_train_category_labels, all_train_category_preds, average="macro", zero_division=0)[2]
        train_ner_f1 = precision_recall_fscore_support(all_train_ner_labels, all_train_ner_preds, average="macro", zero_division=0)[2]

        history["train_loss"].append(total_loss / len(train_loader))
        history["train_intent_acc"].append(train_intent_acc)
        history["train_category_f1"].append(train_category_f1)
        history["train_ner_f1"].append(train_ner_f1)

        model.eval()
        val_loss = 0
        all_val_intent_preds, all_val_intent_labels = [], []
        all_val_category_preds, all_val_category_labels = [], []
        all_val_ner_preds, all_val_ner_labels = [], []

        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Validation]", leave=False):
                inputs = {k: v.to(device) for k, v in batch.items()}
                outputs = model(**inputs)
                val_loss += outputs["loss"].item()

                intent_preds = torch.argmax(outputs["intent_logits"], dim=-1).cpu().numpy()
                category_preds = torch.argmax(outputs["category_logits"], dim=-1).cpu().numpy()
                ner_preds = torch.argmax(outputs["ner_logits"], dim=-1).cpu().numpy()

                all_val_intent_preds.extend(intent_preds)
                all_val_intent_labels.extend(batch["intent_labels"].cpu().numpy())
                all_val_category_preds.extend(category_preds)
                all_val_category_labels.extend(batch["category_labels"].cpu().numpy())
                # Same padding exclusion as for the training metrics.
                token_mask = batch["gpt2_attention_mask"].cpu().numpy().astype(bool)
                all_val_ner_preds.extend(ner_preds[token_mask])
                all_val_ner_labels.extend(batch["ner_labels"].cpu().numpy()[token_mask])

        val_intent_acc = accuracy_score(all_val_intent_labels, all_val_intent_preds)
        val_category_f1 = precision_recall_fscore_support(all_val_category_labels, all_val_category_preds, average="macro", zero_division=0)[2]
        val_ner_f1 = precision_recall_fscore_support(all_val_ner_labels, all_val_ner_preds, average="macro", zero_division=0)[2]

        history["val_loss"].append(val_loss / len(val_loader))
        history["val_intent_acc"].append(val_intent_acc)
        history["val_category_f1"].append(val_category_f1)
        history["val_ner_f1"].append(val_ner_f1)

        print(f"Epoch {epoch+1}/{num_epochs}:")
        print(f"  Train Loss:      {history['train_loss'][-1]:.4f}")
        print(f"  Val Loss:       {history['val_loss'][-1]:.4f}\n")
        print(f"  Train Intent Acc: {train_intent_acc:.4f}")
        print(f"  Val Intent Acc:  {val_intent_acc:.4f}\n")
        print(f"  Train Category F1:{train_category_f1:.4f}")
        print(f"  Val Category F1: {val_category_f1:.4f}\n")
        print(f"  Train NER F1:     {train_ner_f1:.4f}")
        print(f"  Val NER F1:      {val_ner_f1:.4f}\n")

    return history

## Evaluation

In [None]:

# Cell 7: Evaluation Function (Updated)
def evaluate_model(model, test_loader):
    """Evaluate ``model`` on ``test_loader`` and print/return aggregate metrics.

    Batches are moved to the device of the model's first parameter. NER F1 is
    computed only over positions where ``gpt2_attention_mask`` is 1, the same
    positions the model's NER loss covers. Padded positions carry no
    supervision, so counting them would distort the score.

    Args:
        model: Module whose forward accepts the batch dict as keyword arguments
            and returns a dict with ``loss`` and the three ``*_logits`` entries.
        test_loader: Iterable of batches (dicts of tensors).

    Returns:
        Dict with ``loss`` (mean over batches), ``intent_accuracy``,
        ``category_f1`` and ``ner_f1`` (both macro-averaged).
    """
    model.eval()
    all_intent_preds, all_intent_labels = [], []
    all_category_preds, all_category_labels = [], []
    all_ner_preds, all_ner_labels = [], []
    total_loss = 0

    # Follow the model rather than the module-level device setting.
    device = next(model.parameters()).device

    test_loop = tqdm(test_loader, desc="Evaluation", leave=True)
    with torch.no_grad():
        for batch in test_loop:
            inputs = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**inputs)
            batch_loss = outputs["loss"].item()
            total_loss += batch_loss
            test_loop.set_postfix(loss=batch_loss)

            intent_preds = torch.argmax(outputs["intent_logits"], dim=-1).cpu().numpy()
            category_preds = torch.argmax(outputs["category_logits"], dim=-1).cpu().numpy()
            ner_preds = torch.argmax(outputs["ner_logits"], dim=-1).cpu().numpy()

            all_intent_preds.extend(intent_preds)
            all_intent_labels.extend(batch["intent_labels"].cpu().numpy())
            all_category_preds.extend(category_preds)
            all_category_labels.extend(batch["category_labels"].cpu().numpy())
            # Keep only attended tokens; boolean indexing flattens to 1-D.
            token_mask = batch["gpt2_attention_mask"].cpu().numpy().astype(bool)
            all_ner_preds.extend(ner_preds[token_mask])
            all_ner_labels.extend(batch["ner_labels"].cpu().numpy()[token_mask])

    intent_acc = accuracy_score(all_intent_labels, all_intent_preds)
    category_f1 = precision_recall_fscore_support(all_category_labels, all_category_preds, average="macro", zero_division=0)[2]
    ner_f1 = precision_recall_fscore_support(all_ner_labels, all_ner_preds, average="macro", zero_division=0)[2]
    avg_loss = total_loss / len(test_loader)

    print(f"Test Results:")
    print(f"  Loss:            {avg_loss:.4f}")
    print(f"  Intent Acc:      {intent_acc:.4f}")
    print(f"  Category F1:     {category_f1:.4f}")
    print(f"  NER F1:          {ner_f1:.4f}")

    return {"loss": avg_loss, "intent_accuracy": intent_acc, "category_f1": category_f1, "ner_f1": ner_f1}


## Save Artifacts

In [None]:

# Cell 8: Save Functions (Unchanged)
def save_training_artifacts(model, gpt2_tokenizer, label_encoders, metrics, test_results, save_path):
    """Persist model weights, the GPT-2 tokenizer and JSON result files.

    Writes into ``save_path`` (created if missing):
    ``model.pth`` (the model's ``state_dict``), ``gpt2_tokenizer/`` (via
    ``save_pretrained``), ``label_encoders.json``, ``training_metrics.json``
    and ``test_results.json``.
    """
    os.makedirs(save_path, exist_ok=True)

    weights_file = os.path.join(save_path, "model.pth")
    torch.save(model.state_dict(), weights_file)

    tokenizer_dir = os.path.join(save_path, "gpt2_tokenizer")
    gpt2_tokenizer.save_pretrained(tokenizer_dir)

    # The three JSON artifacts share identical serialisation settings.
    json_outputs = (
        ("label_encoders.json", label_encoders),
        ("training_metrics.json", metrics),
        ("test_results.json", test_results),
    )
    for json_name, payload in json_outputs:
        with open(os.path.join(save_path, json_name), "w", encoding="utf-8") as handle:
            json.dump(payload, handle, ensure_ascii=False, indent=4)

    print(f"Artifacts saved to {save_path}")

## Saving training config

In [None]:
import json
import os

def save_training_config(config: dict, save_path: str, filename: str = "training_config.json"):
    """Write the training configuration to ``save_path/filename`` as JSON.

    The directory is created when it does not exist yet. The file is written
    as UTF-8 with 4-space indentation and non-ASCII characters kept verbatim.

    Args:
        config (dict): Training hyperparameters to persist.
        save_path (str): Target directory.
        filename (str): File name inside ``save_path``
            (default: "training_config.json").
    """
    os.makedirs(save_path, exist_ok=True)
    target = os.path.join(save_path, filename)

    with open(target, mode="w", encoding="utf-8") as handle:
        json.dump(config, handle, ensure_ascii=False, indent=4)

    print(f"Training configuration saved to {target}")

### Save Model

In [None]:
# Cell 8: Save Model Functions
import json
import os
import torch

def save_full_model(model, gpt2_tokenizer, label_encoders, metrics, test_results, save_path):
    """Save the whole model object together with tokenizer, encoders and results.

    Writes into ``save_path`` (created if missing): ``full_model.pt`` (the
    model object itself, not just its weights), ``tokenizer/`` (via
    ``save_pretrained``), ``label_encoders.json``, ``training_metrics.json``
    and ``test_results.json``.
    """
    os.makedirs(save_path, exist_ok=True)

    def in_save_dir(name):
        return os.path.join(save_path, name)

    # NOTE: torch.save on a module object pickles it; only load such files
    # from sources you trust.
    torch.save(model, in_save_dir("full_model.pt"))

    gpt2_tokenizer.save_pretrained(in_save_dir("tokenizer"))

    # Label encoders, training metrics and test results, in that order.
    for json_name, payload in (
        ("label_encoders.json", label_encoders),
        ("training_metrics.json", metrics),
        ("test_results.json", test_results),
    ):
        with open(in_save_dir(json_name), "w", encoding="utf-8") as handle:
            json.dump(payload, handle, ensure_ascii=False, indent=4)

    print(f"Full model and artifacts saved to {save_path}")


## Main Execution

### Paths and Hyperparameters

In [None]:
# Cell 9: Main Execution
# Data paths and hyperparameters
# Input files, one JSON list of samples per split.
train_file = "train.json"
val_file = "val.json"
test_file = "test.json"

# Optimisation and data-loading hyperparameters.
batch_size = 16
num_epochs = 1
learning_rate = 2e-5
max_length = 128
num_workers = 2
save_path = "saved_models"


# Snapshot of the run settings, written to disk next to the model artifacts.
training_config = {
    "train_file": train_file,
    "val_file": val_file,
    "test_file": test_file,
    "batch_size": batch_size,
    "num_epochs": num_epochs,
    "learning_rate": learning_rate,
    "max_length": max_length,
    "num_workers": num_workers,
    # Record the model this notebook actually defines and both checkpoints it loads.
    "model_name": "HybridGPT2DistilBERTMultiTask",
    "gpt2_base": "gpt2",
    "distilbert_base": "distilbert-base-uncased",
    "dropout_rate": 0.4,
    "device": str(device),
    "date": "March 06, 2025"
}



### Initialization

In [None]:

# Load datasets
print("Loading datasets...\n")
train_data = load_dataset(train_file)[:100]  # Limited for demo
val_data = load_dataset(val_file)[:20]
test_data = load_dataset(test_file)[:20]

print("*" * 30)
print(f"""Dataset Summary:
Training samples: {len(train_data)}
Validation samples: {len(val_data)}
Test samples: {len(test_data)}""")

# Detect labels
# NOTE(review): encoders are built from the (truncated) training split only; a
# category/intent that appears only in val/test will raise KeyError in the
# dataset - confirm the splits share one label set.
label_encoders = detect_labels(train_data)


# Initialize tokenizers (the hybrid dataset needs one per backbone)
gpt2_tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
if gpt2_tokenizer.pad_token is None:
    # Register a pad token when the tokenizer reports none.
    gpt2_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
distilbert_tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

# Create data loaders
train_loader, val_loader, test_loader = get_dataloaders(
    train_data, val_data, test_data, gpt2_tokenizer, distilbert_tokenizer,
    label_encoders, batch_size, num_workers, max_length
)

# Initialize model (the hybrid class defined in this notebook)
model = HybridGPT2DistilBERTMultiTask(
    num_intents=len(label_encoders["intent_encoder"]),
    num_categories=len(label_encoders["category_encoder"]),
    num_ner_labels=len(label_encoders["ner_label_encoder"])
)
if gpt2_tokenizer.pad_token_id is not None:
    # Make room for the added pad token in the GPT-2 embedding matrix.
    model.gpt2.resize_token_embeddings(len(gpt2_tokenizer))
    # Resizing may create new embedding parameters; re-apply the freeze the
    # model sets up in __init__ so the GPT-2 backbone stays frozen.
    for param in model.gpt2.parameters():
        param.requires_grad = False

# Setup logging to save print statements
setup_logging(save_path)

model.to(device)  # Ensure model is moved to GPU
check_device(model, "Model before training")  # Verify

# Save training config before training
save_training_config(training_config, save_path)

# Train model
print("*" * 30)
print("Starting training...")

start_time = time.time()


metrics = train_model(model, train_loader, val_loader, num_epochs, learning_rate)
print(f"Training completed in {(time.time() - start_time) / 60:.2f} minutes")

print("*" * 30)

# Evaluate model
print("Evaluating on test set...")
test_results = evaluate_model(model, test_loader)

# Save artifacts
save_training_artifacts(model, gpt2_tokenizer, label_encoders, metrics, test_results, save_path)

# Save Full model
save_full_model(model, gpt2_tokenizer, label_encoders, metrics, test_results, save_path)
