# Gemma Shield 2B Model Training

This notebook fine-tunes the pretrained ShieldGemma 2B model (`google/shieldgemma-2b`) on an instruction dataset for evaluation purposes.

In [None]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from transformers import DataCollatorForLanguageModeling
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm

## Dataset Preparation

In [None]:
class InstructionDataset(Dataset):
    """Map-style dataset that turns instruction rows into causal-LM examples.

    Every row of ``dataframe`` must provide the columns ``system_prompt``,
    ``user_prompt`` and ``response``. The three fields are joined into a single
    text, tokenized to exactly ``max_length`` tokens (truncated or padded), and
    returned together with ``labels`` for next-token prediction.

    Args:
        dataframe: pandas DataFrame holding one example per row.
        tokenizer: Callable tokenizer that accepts ``truncation``, ``padding``,
            ``max_length`` and ``return_tensors`` and returns a mapping of
            tensors with a leading batch dimension of 1.
        max_length: Fixed sequence length of every returned example.
    """

    def __init__(self, dataframe, tokenizer, max_length=1024):
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of rows (examples) in the dataframe."""
        return len(self.data)

    def __getitem__(self, idx):
        """Build the tokenized example for the row at position ``idx``.

        Returns:
            dict mapping each tokenizer output name (plus ``labels``) to a
            tensor with the batch dimension removed.
        """
        row = self.data.iloc[idx]

        # Format the input as a conversation followed by the target response
        prompt = f"System: {row['system_prompt']}\n\nUser: {row['user_prompt']}"
        text = f"{prompt}\n\nResponse: {row['response']}"

        encoding = self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )

        # For causal language modeling the labels are the input ids themselves,
        # except on padding: -100 is the index ignored by the cross-entropy
        # loss, so padded positions no longer contribute to the training loss.
        labels = encoding["input_ids"].clone()
        if "attention_mask" in encoding:
            labels[encoding["attention_mask"] == 0] = -100
        encoding["labels"] = labels

        # Drop only the batch axis; a bare squeeze() would also collapse the
        # sequence axis when max_length is 1.
        return {key: val.squeeze(0) for key, val in encoding.items()}

## Load Data from Parquet

In [None]:
# Directory containing the parquet files to train on (edit for your machine)
data_path = "/home/eduardo/Desktop/Others/Adapta/vizeval/synthetic_data2/"

# Read the parquet data at that location into a single DataFrame
df = pd.read_parquet(path=data_path)

# Preview the first five rows
df.head(5)

In [None]:
# Report the frame's (rows, columns) shape, then list every column name
print(f"Dataset shape: {df.shape}", "\nColumns:", sep="\n")
for col in df.columns:
    print(f"- {col}")

## Prepare Training and Validation Sets

In [None]:
# Hold out 15% of the rows for validation; random_state pins the shuffle
train_df, val_df = train_test_split(df, random_state=42, test_size=0.15)

print(
    f"Training set size: {len(train_df)}",
    f"Validation set size: {len(val_df)}",
    sep="\n",
)

# Tokenizer matching the checkpoint that is fine-tuned below
tokenizer = AutoTokenizer.from_pretrained("google/shieldgemma-2b")

# Wrap each split so it yields tokenized examples (default max_length)
train_dataset = InstructionDataset(dataframe=train_df, tokenizer=tokenizer)
val_dataset = InstructionDataset(dataframe=val_df, tokenizer=tokenizer)

## Load Model

In [None]:
# Fetch the pretrained causal-LM checkpoint
model = AutoModelForCausalLM.from_pretrained("google/shieldgemma-2b")

# Prefer CUDA when it is present, otherwise stay on the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
model.to(device=device)

## Training Setup

In [None]:
# Define training arguments
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy` — confirm against the installed version.
training_args = TrainingArguments(
    output_dir="./gemma_shield_results",
    overwrite_output_dir=True,
    num_train_epochs=1,
    per_device_train_batch_size=1,  # Adjust based on your GPU memory
    per_device_eval_batch_size=1,
    # Evaluate and checkpoint on the same 500-step cadence so that
    # load_best_model_at_end can pair each evaluation with a saved checkpoint
    eval_steps=500,
    save_steps=500,
    warmup_steps=100,
    logging_dir="./logs",
    logging_steps=100,
    evaluation_strategy="steps",
    load_best_model_at_end=True,
    save_total_limit=2,  # Keep at most two checkpoints on disk
    # Mixed precision only when CUDA is present: requesting fp16 without a
    # supported accelerator makes TrainingArguments raise, which broke the
    # CPU fallback that the rest of the notebook allows for.
    fp16=torch.cuda.is_available(),
    # With a per-device batch of 1 this gives an effective batch of 4 examples
    # per optimizer step on each device
    gradient_accumulation_steps=4,
)

# Initialize trainer with the train/validation datasets built earlier
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

## Train the Model

In [None]:
# Launch training with the `trainer` built in the Training Setup cell
trainer.train()

## Save the Model

In [None]:
# Save the trained model through the Trainer, then write the tokenizer into
# the same directory so both can be found under one path
trainer.save_model("./gemma_shield_model")
tokenizer.save_pretrained("./gemma_shield_model")

## Export to TorchScript

In [None]:
def export_to_torchscript(model, tokenizer, save_path="gemma_shield_model.pt"):
    """Trace ``model`` on an example prompt and save it as a TorchScript file.

    The model is wrapped so that the traced module takes
    ``(input_ids, attention_mask)`` and returns only the logits tensor. The
    model is left in evaluation mode afterwards.

    Args:
        model: ``torch.nn.Module`` accepting ``input_ids`` and
            ``attention_mask`` keyword arguments and returning an object with
            a ``logits`` attribute.
        tokenizer: Callable that tokenizes the example prompt and returns a
            mapping with ``input_ids`` and ``attention_mask`` tensors.
        save_path: Destination file for the traced module.

    Returns:
        The ``save_path`` the traced module was written to.
    """
    # Set model to evaluation mode so training-only behavior (e.g. dropout)
    # is not baked into the trace
    model.eval()

    class ModelWrapper(torch.nn.Module):
        """Adapter exposing a positional, tensor-only forward for tracing."""

        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, input_ids, attention_mask):
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
            return outputs.logits

    wrapped_model = ModelWrapper(model)

    # Place the example inputs on the device the model actually lives on,
    # rather than relying on a notebook-global `device` that may be undefined
    # or out of sync with the model. Parameter-free models fall back to CPU.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        target_device = first_param.device
    else:
        target_device = torch.device("cpu")

    # Create example inputs for tracing (the tokenizer is asked to pad to 128)
    example_text = "System: You are a helpful assistant.\n\nUser: What is the capital of France?"
    encoded = tokenizer(example_text, return_tensors="pt", padding="max_length", max_length=128)
    example_input_ids = encoded["input_ids"].to(target_device)
    example_attention_mask = encoded["attention_mask"].to(target_device)

    # Trace without autograd bookkeeping; tracing records the operations run
    # for this example input
    with torch.no_grad():
        traced_model = torch.jit.trace(wrapped_model, (example_input_ids, example_attention_mask))

    # Save the traced model
    torch.jit.save(traced_model, save_path)
    print(f"Model exported to TorchScript format and saved at {save_path}")

    return save_path

In [None]:
# Export the trained model to TorchScript format; `model_path` receives the
# file path returned by export_to_torchscript
model_path = export_to_torchscript(model, tokenizer, "gemma_shield_model.pt")