In [4]:
import torch
from torch import nn
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments
from datasets import load_dataset
from transformers.modeling_outputs import SequenceClassifierOutput

In [3]:
model_name = "bert-base-uncased"
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
tokenizer = AutoTokenizer.from_pretrained(model_name)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
dataset = load_dataset("glue", "mrpc")

In [5]:
# Define the soft prompt (random initialization for demonstration)
soft_prompt_length = 20 # Length of the soft prompt (number of tokens)
soft_prompt_dim = model.config.hidden_size  # Dimension of model's hidden layer
soft_prompt = nn.Parameter(torch.randn(soft_prompt_length, soft_prompt_dim))

In [23]:
for param in model.bert.parameters():
    param.requires_grad = False  # Freeze all layers

for param in model.bert.encoder.layer[-1:].parameters():
    param.requires_grad = True   # Unfreeze the last two layers

In [24]:
class SoftPromptingModel(nn.Module):
    def __init__(self, model, soft_prompt):
        super(SoftPromptingModel, self).__init__()
        self.model = model
        self.soft_prompt = soft_prompt

    def forward(self, input_ids, attention_mask=None, labels=None):
        # Get input embeddings from the original model
        input_embeddings = self.model.bert.embeddings.word_embeddings(input_ids)

        # Repeat soft prompt for each input in the batch
        batch_size = input_ids.size(0)
        repeated_soft_prompt = self.soft_prompt.unsqueeze(0).expand(batch_size, -1, -1)

        # Concatenate soft prompt to the input embeddings
        combined_embeddings = torch.cat((repeated_soft_prompt, input_embeddings), dim=1)

        # Adjust attention mask to include soft prompt
        if attention_mask is not None:
            prompt_attention_mask = torch.ones(batch_size, self.soft_prompt.size(0), device=input_ids.device)
            attention_mask = torch.cat((prompt_attention_mask, attention_mask), dim=1)

        # Forward pass through the model with combined embeddings
        outputs = self.model.bert(
            inputs_embeds=combined_embeddings, 
            attention_mask=attention_mask, 
            return_dict=True
        )
        # Using the classification head on top of BERT for sequence classification
        logits = self.model.classifier(outputs.pooler_output)

        # Return the outputs for the Trainer API
        return SequenceClassifierOutput(
            loss=nn.CrossEntropyLoss()(logits, labels) if labels is not None else None,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

In [25]:
soft_prompting_model = SoftPromptingModel(model, soft_prompt)


In [26]:
def tokenize_function(examples):
    return tokenizer(examples["sentence1"], examples["sentence2"], padding="max_length", truncation=True, max_length=128)


In [27]:
tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
tokenized_datasets.set_format("torch", columns=["input_ids", "attention_mask", "labels"])


Map: 100%|██████████| 3668/3668 [00:00<00:00, 20023.12 examples/s]


In [28]:
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",  # Evaluate at the end of each epoch
    save_strategy="epoch",  # Save the model at the end of each epoch
    learning_rate=1e-4,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=10,
    weight_decay=0.01,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    use_mps_device=True 
)




In [29]:
trainer = Trainer(
    model=soft_prompting_model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tokenizer,
)

In [30]:
trainer.train()



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

                                                   
[A                                               
 28%|██▊       | 1274/4590 [05:02<06:56,  7.97it/s]
[A

{'eval_loss': 0.5782884359359741, 'eval_runtime': 2.9903, 'eval_samples_per_second': 136.439, 'eval_steps_per_second': 17.055, 'epoch': 1.0}


                                                   
 28%|██▊       | 1274/4590 [05:08<06:56,  7.97it/s]

{'loss': 0.2864, 'grad_norm': 1.253495216369629, 'learning_rate': 8.910675381263617e-05, 'epoch': 1.09}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

                                                   
[A                                               
 28%|██▊       | 1274/4590 [06:04<06:56,  7.97it/s]
[A

{'eval_loss': 0.5485397577285767, 'eval_runtime': 2.8026, 'eval_samples_per_second': 145.58, 'eval_steps_per_second': 18.197, 'epoch': 2.0}


                                                   
 28%|██▊       | 1274/4590 [06:15<06:56,  7.97it/s]

{'loss': 0.2332, 'grad_norm': 2.7007904052734375, 'learning_rate': 7.821350762527233e-05, 'epoch': 2.18}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
                                                   
[A                                                

 28%|██▊       | 1274/4590 [07:06<06:56,  7.97it/s]
[A
[A

{'eval_loss': 0.42297425866127014, 'eval_runtime': 2.8149, 'eval_samples_per_second': 144.943, 'eval_steps_per_second': 18.118, 'epoch': 3.0}


                                                   
 28%|██▊       | 1274/4590 [07:23<06:56,  7.97it/s]

{'loss': 0.2473, 'grad_norm': 2.3062713146209717, 'learning_rate': 6.73202614379085e-05, 'epoch': 3.27}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
                                                   
[A                                                

 28%|██▊       | 1274/4590 [08:14<06:56,  7.97it/s]
[A
[A

{'eval_loss': 0.526482880115509, 'eval_runtime': 3.2631, 'eval_samples_per_second': 125.035, 'eval_steps_per_second': 15.629, 'epoch': 4.0}


                                                   
 28%|██▊       | 1274/4590 [08:40<06:56,  7.97it/s]

{'loss': 0.268, 'grad_norm': 2.947028875350952, 'learning_rate': 5.642701525054467e-05, 'epoch': 4.36}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

                                                   
[A                                                
 28%|██▊       | 1274/4590 [09:36<06:56,  7.97it/s]
[A

{'eval_loss': 0.5650869011878967, 'eval_runtime': 4.097, 'eval_samples_per_second': 99.584, 'eval_steps_per_second': 12.448, 'epoch': 5.0}


                                                   
 28%|██▊       | 1274/4590 [10:15<06:56,  7.97it/s]

{'loss': 0.2483, 'grad_norm': 4.467926979064941, 'learning_rate': 4.5533769063180834e-05, 'epoch': 5.45}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

                                                   
[A                                                
 28%|██▊       | 1274/4590 [11:03<06:56,  7.97it/s]
[A

{'eval_loss': 0.6241656541824341, 'eval_runtime': 3.5886, 'eval_samples_per_second': 113.692, 'eval_steps_per_second': 14.211, 'epoch': 6.0}


                                                   
 28%|██▊       | 1274/4590 [11:42<06:56,  7.97it/s]

{'loss': 0.2382, 'grad_norm': 2.1550838947296143, 'learning_rate': 3.464052287581699e-05, 'epoch': 6.54}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

                                                   
[A                                                
 28%|██▊       | 1274/4590 [12:18<06:56,  7.97it/s]
[A

{'eval_loss': 0.7092337012290955, 'eval_runtime': 3.324, 'eval_samples_per_second': 122.744, 'eval_steps_per_second': 15.343, 'epoch': 7.0}


                                                   
 28%|██▊       | 1274/4590 [13:02<06:56,  7.97it/s]

{'loss': 0.2141, 'grad_norm': 8.966988563537598, 'learning_rate': 2.374727668845316e-05, 'epoch': 7.63}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

                                                   
[A                                                
 28%|██▊       | 1274/4590 [13:30<06:56,  7.97it/s]
[A

{'eval_loss': 0.7193672060966492, 'eval_runtime': 3.1981, 'eval_samples_per_second': 127.575, 'eval_steps_per_second': 15.947, 'epoch': 8.0}


                                                   
 28%|██▊       | 1274/4590 [14:18<06:56,  7.97it/s]

{'loss': 0.2149, 'grad_norm': 0.3270353674888611, 'learning_rate': 1.2854030501089326e-05, 'epoch': 8.71}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

                                                   
[A                                                
 28%|██▊       | 1274/4590 [14:40<06:56,  7.97it/s]
[A

{'eval_loss': 0.7541595697402954, 'eval_runtime': 3.1178, 'eval_samples_per_second': 130.859, 'eval_steps_per_second': 16.357, 'epoch': 9.0}


                                                   
 28%|██▊       | 1274/4590 [15:32<06:56,  7.97it/s]

{'loss': 0.1949, 'grad_norm': 3.3522238731384277, 'learning_rate': 1.96078431372549e-06, 'epoch': 9.8}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

                                                   
[A                                                
 28%|██▊       | 1274/4590 [15:48<06:56,  7.97it/s]
[A

{'eval_loss': 0.7415537238121033, 'eval_runtime': 2.9967, 'eval_samples_per_second': 136.151, 'eval_steps_per_second': 17.019, 'epoch': 10.0}


                                                   
100%|██████████| 4590/4590 [11:58<00:00,  6.39it/s]

{'train_runtime': 718.8229, 'train_samples_per_second': 51.028, 'train_steps_per_second': 6.385, 'train_loss': 0.23729240151532074, 'epoch': 10.0}





TrainOutput(global_step=4590, training_loss=0.23729240151532074, metrics={'train_runtime': 718.8229, 'train_samples_per_second': 51.028, 'train_steps_per_second': 6.385, 'total_flos': 0.0, 'train_loss': 0.23729240151532074, 'epoch': 10.0})

In [31]:
eval_results = trainer.evaluate()
print(f"Validation Accuracy: {eval_results}")

100%|██████████| 51/51 [00:02<00:00, 24.34it/s]

Validation Accuracy: {'eval_loss': 0.42297425866127014, 'eval_runtime': 2.1591, 'eval_samples_per_second': 188.972, 'eval_steps_per_second': 23.621, 'epoch': 10.0}





In [36]:
model.save_pretrained("./glue_mrpc_softprompt_bert_model")
tokenizer.save_pretrained("./glue_mrpc_softprompt_bert_tokenizer")

('./glue_mrpc_softprompt_bert_tokenizer/tokenizer_config.json',
 './glue_mrpc_softprompt_bert_tokenizer/special_tokens_map.json',
 './glue_mrpc_softprompt_bert_tokenizer/vocab.txt',
 './glue_mrpc_softprompt_bert_tokenizer/added_tokens.json',
 './glue_mrpc_softprompt_bert_tokenizer/tokenizer.json')

In [32]:
predictions = trainer.predict(tokenized_datasets["validation"])



100%|██████████| 51/51 [00:02<00:00, 21.71it/s]


In [33]:
import numpy as np
from sklearn.metrics import classification_report


In [34]:

labels = np.array(tokenized_datasets["validation"]["labels"])
logits = predictions.predictions
y_pred = np.argmax(logits, axis=-1)


In [35]:
# Classification Report
class_report = classification_report(labels, y_pred)
print("Classification Report:")
print(class_report)

Classification Report:
              precision    recall  f1-score   support

           0       0.71      0.68      0.70       129
           1       0.86      0.87      0.86       279

    accuracy                           0.81       408
   macro avg       0.78      0.78      0.78       408
weighted avg       0.81      0.81      0.81       408

