### Artifact Expert model training

In [12]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset
import datasets
from helpers import compute_accuracy

# Load the model
artifact_expert = AutoModelForSequenceClassification.from_pretrained('google/electra-small-discriminator', num_labels=3)

# Initialize the tokenizer
tokenizer = AutoTokenizer.from_pretrained('google/electra-small-discriminator', use_fast=True)

# Load the dataset
dataset = load_dataset('veluribharath/snli')

Some weights of ElectraForSequenceClassification were not initialized from the model checkpoint at google/electra-small-discriminator and are newly initialized: ['classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight', 'classifier.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Found cached dataset parquet (/Users/velurib/.cache/huggingface/datasets/veluribharath___parquet/veluribharath--snli-62105ac6e2fb9a2a/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/3 [00:00<?, ?it/s]

In [13]:
def prepare_dataset_for_artifact_expert(examples, tokenizer, max_seq_length=None):
    """
    Tokenize just the hypothesis using the given tokenizer.
    """

    max_seq_length = tokenizer.model_max_length if max_seq_length is None else max_seq_length

    # Tokenize only the hypothesis part
    tokenized_examples = tokenizer(
        examples['hypothesis'],
        padding='max_length',
        truncation=True,
        max_length=max_seq_length
    )

    # Include labels as before
    tokenized_examples['label'] = examples['label']
    return tokenized_examples


In [14]:
# Prepare the dataset
prepared_dataset = dataset['train'].map(
    lambda examples: prepare_dataset_for_artifact_expert(examples, tokenizer, max_seq_length=tokenizer.model_max_length),
    batched=True,
    num_proc=30,
    remove_columns=dataset['train'].column_names  # This removes the original columns, only keeping the processed ones
)

prepared_dataset

Loading cached processed dataset at /Users/velurib/.cache/huggingface/datasets/veluribharath___parquet/veluribharath--snli-62105ac6e2fb9a2a/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-0fe64601b7adae6c_*_of_00030.arrow


Dataset({
    features: ['label', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 549367
})

In [15]:
train_dataset = prepared_dataset.select(range(0,100))

In [16]:
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir='./artifact_expert_model',  # directory for model output
    num_train_epochs=3,                    # number of training epochs
    per_device_train_batch_size=16,        # batch size for training
    per_device_eval_batch_size=64,         # batch size for evaluation
    warmup_steps=500,                      # number of warmup steps for learning rate scheduler
    weight_decay=0.01,                     # strength of weight decay
    logging_dir='./logs',                  # directory for storing logs
    logging_steps=10,
    save_strategy='epoch',
)

trainer = Trainer(
    model=artifact_expert,
    args=training_args,
    train_dataset=train_dataset,
    compute_metrics=compute_accuracy
)

In [18]:
trainer.train()
trainer.save_model()
# # Save the model and the tokenizer
# model_save_path = "./model_output/final_model"
# trainer.model.save_pretrained(model_save_path)
# trainer.tokenizer.save_pretrained(model_save_path)

# trainer.save_state()

Step,Training Loss
10,1.098
20,1.1011


In [19]:
# evaluate the model

eval_dataset = datasets.load_dataset('veluribharath/snli', split='validation')
eval_dataset_featurized = eval_dataset.map(
    lambda exs: prepare_dataset_for_artifact_expert(exs, tokenizer, max_seq_length=tokenizer.model_max_length),
    batched=True,
    remove_columns=eval_dataset.column_names
)

eval_dataset_featurized

Found cached dataset parquet (/Users/velurib/.cache/huggingface/datasets/veluribharath___parquet/veluribharath--snli-62105ac6e2fb9a2a/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


Map:   0%|          | 0/9842 [00:00<?, ? examples/s]

Dataset({
    features: ['label', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 9842
})

In [20]:
eval_results = trainer.evaluate(eval_dataset_featurized)
print('Evaluation results:', eval_results)

Evaluation results: {'eval_loss': 1.097949504852295, 'eval_accuracy': 0.3503352999687195, 'eval_runtime': 81.9243, 'eval_samples_per_second': 120.135, 'eval_steps_per_second': 1.88, 'epoch': 3.0}


In [22]:
# import torch
# # Save the main model
# torch.save(artifact_expert, 'models/artifact_expert.pth')

### Retraining the main model using weighted output to remove bias artifact

In [23]:
# Load the main model
main_model = AutoModelForSequenceClassification.from_pretrained('google/electra-small-discriminator', num_labels=3)

Some weights of ElectraForSequenceClassification were not initialized from the model checkpoint at google/electra-small-discriminator and are newly initialized: ['classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight', 'classifier.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [24]:
import torch
if isinstance(artifact_expert, torch.nn.Module) and isinstance(main_model, torch.nn.Module):
    print("This is a PyTorch model.")

This is a PyTorch model.


In [25]:
from helpers import prepare_dataset_nli

# Prepare the dataset
prepared_main_dataset = dataset['train'].map(
    lambda examples: prepare_dataset_nli(examples, tokenizer, max_seq_length=tokenizer.model_max_length),
    batched=True,
    num_proc=30,
    remove_columns=dataset['train'].column_names  # This removes the original columns, only keeping the processed ones
)

prepared_main_dataset

Map (num_proc=30):   0%|          | 0/549367 [00:00<?, ? examples/s]

Dataset({
    features: ['label', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 549367
})

In [28]:
prepared_main_dataset = prepared_main_dataset.select(range(0,100))

In [29]:
from torch.utils.data import DataLoader
from transformers import DefaultDataCollator

# Convert the dataset to PyTorch format
prepared_main_dataset.set_format('torch', columns=['label', 'input_ids', 'attention_mask', 'token_type_ids'])

# Data collator used for dynamically padding the inputs and labels
data_collator = DefaultDataCollator()

# Creating a DataLoader to iterate over the training dataset
train_dataloader = DataLoader(
    prepared_main_dataset, 
    batch_size=32,  # TODO: batch size
    shuffle=True,  # Shuffling the data for each epoch
    collate_fn=data_collator  # Collator will handle dynamic padding
)


In [30]:
import numpy as np
from tqdm import tqdm

# Generating predictions from artifact expert
artifact_predictions = []
with torch.no_grad():
    for batch in tqdm(train_dataloader):
        inputs = {k: v.to(artifact_expert.device) for k, v in batch.items() if k != 'labels'}
        outputs = artifact_expert(**inputs)
        artifact_predictions.append(outputs.logits.detach().cpu().numpy())
artifact_predictions = np.concatenate(artifact_predictions, axis=0)

100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  3.05it/s]


In [32]:
def adjust_weights(true_labels, artifact_preds, threshold=0.5):
    """
    Adjust the sample weights based on artifact expert predictions.

    Args:
    - true_labels (Tensor): The true labels for the training data.
    - artifact_preds (ndarray): The predictions from the artifact expert model.
    - threshold (float): A threshold to decide the weight of each sample.

    Returns:
    - Tensor: Weights for each sample.
    """
    # Convert artifact predictions to probabilities
    artifact_probs = torch.softmax(torch.from_numpy(artifact_preds), dim=-1)

    # Get the max probability and corresponding label for each prediction
    max_probs, artifact_labels = torch.max(artifact_probs, dim=-1)

    # Check if the artifact prediction agrees with the true label
    agreement = artifact_labels == true_labels

    # Calculate weights: Lower weight if artifact model is too confident and agrees with the true label
    weights = torch.where((max_probs > threshold) & agreement, 0.5, 1.0)  # Half the weight if too confident and in agreement

    return weights

In [33]:
import torch
from transformers import AdamW
from tqdm import tqdm

# Optimizer
optimizer = AdamW(main_model.parameters(), lr=5e-5)

# Training loop
num_epochs = 3  # TODO: number of epochs

for epoch in tqdm(range(num_epochs)):
    main_model.train()
    for i, batch in enumerate(train_dataloader):

        # Extract labels and adjust the weights
        true_labels = batch['labels'].numpy()
        batch_artifact_preds = artifact_predictions[i * batch['labels'].size(0):(i + 1) * batch['labels'].size(0)]
        weights = adjust_weights(true_labels, batch_artifact_preds)

        # Prepare batch for training
        batch = {k: v.to(main_model.device) for k, v in batch.items()}
        
        # Forward pass
        outputs = main_model(**batch)
        loss = outputs.loss

        # Modify loss based on weights and compute weighted loss
        weighted_loss = (loss * weights.to(main_model.device)).mean()

        # Backward pass and optimization
        optimizer.zero_grad()
        weighted_loss.backward()
        optimizer.step()

        print(f"Epoch: {epoch}, Loss: {weighted_loss.item()}")

  0%|                                                     | 0/3 [00:00<?, ?it/s]

Epoch: 0, Loss: 1.1014829874038696
Epoch: 0, Loss: 1.0941085815429688
Epoch: 0, Loss: 1.1036146879196167


 33%|███████████████                              | 1/3 [00:20<00:41, 20.64s/it]

Epoch: 0, Loss: 1.0908246040344238
Epoch: 1, Loss: 1.0922794342041016
Epoch: 1, Loss: 1.0995546579360962
Epoch: 1, Loss: 1.0951451063156128


 67%|██████████████████████████████               | 2/3 [00:38<00:19, 19.21s/it]

Epoch: 1, Loss: 1.1114813089370728
Epoch: 2, Loss: 1.091253399848938
Epoch: 2, Loss: 1.0976595878601074
Epoch: 2, Loss: 1.0903254747390747


100%|█████████████████████████████████████████████| 3/3 [00:58<00:00, 19.40s/it]

Epoch: 2, Loss: 1.1074068546295166





In [34]:
# evaluate the model
# dummy trainer object arguments
main_model_training_args = TrainingArguments(
    output_dir='./main_model',             # directory for model output
    num_train_epochs=3,                    # number of training epochs
    per_device_train_batch_size=16,        # batch size for training
    per_device_eval_batch_size=64,         # batch size for evaluation
    warmup_steps=500,                      # number of warmup steps for learning rate scheduler
    weight_decay=0.01,                     # strength of weight decay
    logging_dir='./main_logs',             # directory for storing logs
    logging_steps=10,
)

# dummy trainer object
main_trainer = Trainer(
    model=main_model,
    args=main_model_training_args,
    train_dataset=prepared_main_dataset,
    compute_metrics=compute_accuracy
)

#main model evaluation
eval_results = main_trainer.evaluate(eval_dataset_featurized)
print('Evaluation results:', eval_results)

Evaluation results: {'eval_loss': 1.0997998714447021, 'eval_accuracy': 0.336212158203125, 'eval_runtime': 81.041, 'eval_samples_per_second': 121.445, 'eval_steps_per_second': 1.9}


In [None]:
# Save the main model
torch.save(main_model, 'models/main_model.pth')