In [None]:
!pip install -U pyarrow --quiet

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m39.9/39.9 MB[0m [31m50.5 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
cudf-cu12 24.4.1 requires pyarrow<15.0.0a0,>=14.0.1, but you have pyarrow 17.0.0 which is incompatible.
ibis-framework 8.0.0 requires pyarrow<16,>=2, but you have pyarrow 17.0.0 which is incompatible.[0m[31m
[0m

In [None]:
!pip install datasets transformers torch seqeval evaluate  --quiet

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/43.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m527.3/527.3 kB[0m [31m14.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m11.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m14.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m17.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for seqeval (setup.py) ... [?25l[?25hdone


In [None]:
import torch
from torch import nn
import numpy as np
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments
from datasets import load_dataset
from typing import Dict, List

class EWC(nn.Module):
    """Elastic Weight Consolidation (EWC) penalty for a sequence classifier.

    On construction the current values of every trainable parameter are
    snapshotted and a per-parameter importance weight is estimated as the
    mean squared gradient of the model's loss over ``dataset``. ``penalty``
    then returns the importance-weighted squared distance between the
    current parameters and the snapshot, scaled by ``importance``.

    Only parameters with ``requires_grad=True`` at construction time are
    tracked; frozen parameters are ignored everywhere.

    Args:
        model: Model whose forward accepts the tokenizer output as keyword
            arguments plus ``labels`` and returns an object with ``.loss``.
        dataset: Indexable, sized collection whose items expose
            ``'sentence'`` and ``'label'`` keys.
        tokenizer: Callable turning a sentence into a mapping of tensors
            (called with ``return_tensors='pt'``).
        importance: Multiplier applied to the summed penalty.
    """

    def __init__(self, model: nn.Module, dataset, tokenizer, importance: float = 1000):
        super().__init__()
        self.model = model
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.importance = importance
        # Trainable parameters only; every other method iterates this dict so
        # that the keys of _means / _precision_matrices always match.
        self.params = {n: p for n, p in self.model.named_parameters() if p.requires_grad}
        self._means = {}
        self._precision_matrices = {}
        self.device = next(model.parameters()).device
        self._initialize_means_and_precision_matrices()
        self._calculate_importance()

    @staticmethod
    def _progress(iterable, **kwargs):
        """Wrap ``iterable`` in a tqdm progress bar when tqdm is available."""
        try:
            from tqdm.auto import tqdm
        except ImportError:
            # Progress reporting is cosmetic; fall back to the bare iterable.
            return iterable
        return tqdm(iterable, **kwargs)

    def _initialize_means_and_precision_matrices(self):
        """Snapshot the tracked parameters and zero their importance weights."""
        for name, param in self.params.items():
            self._means[name] = param.detach().clone().to(self.device)
            self._precision_matrices[name] = torch.zeros_like(param, device=self.device)

    def _calculate_importance(self):
        """Accumulate the mean squared loss gradient of each tracked parameter."""
        num_samples = len(self.dataset)
        was_training = self.model.training
        self.model.eval()
        for i in self._progress(range(num_samples), desc="EWC importance"):
            self.model.zero_grad()
            data = self.dataset[i]
            inputs = self.tokenizer(data['sentence'], return_tensors='pt', padding=True, truncation=True)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            labels = torch.tensor([data['label']]).to(self.device)
            outputs = self.model(**inputs, labels=labels)
            outputs.loss.backward()

            for name, param in self.params.items():
                # A trainable parameter that did not take part in this
                # forward pass has no gradient; it contributes nothing.
                if param.grad is None:
                    continue
                grad = param.grad.detach().to(self.device)
                self._precision_matrices[name] += grad ** 2 / num_samples

        # Do not leak the last sample's gradients or the eval-mode switch
        # to whatever uses the model next.
        self.model.zero_grad()
        self.model.train(was_training)

    def penalty(self):
        """Return ``importance * sum(F * (theta - theta_snapshot) ** 2)``.

        The result is a scalar tensor (or the integer ``0`` when there are no
        tracked parameters) that can be added to a task loss.
        """
        loss = 0
        for name, param in self.params.items():
            drift = param.to(self.device) - self._means[name]
            loss += (self._precision_matrices[name] * drift ** 2).sum()
        return self.importance * loss

    def update(self):
        """Re-snapshot the tracked parameters as the new anchor point.

        The importance weights are left unchanged.
        """
        for name, param in self.params.items():
            self._means[name] = param.detach().clone().to(self.device)

class EWCTrainer(Trainer):
    """``Trainer`` whose loss is the model loss plus an EWC penalty.

    Args:
        ewc: Object exposing ``penalty()``, which returns a scalar that can
            be added to the model loss.
        *args, **kwargs: Forwarded unchanged to ``Trainer.__init__``.
    """

    def __init__(self, ewc, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.ewc = ewc

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the model loss with the EWC penalty added.

        Args:
            model: Model to run; called as ``model(**inputs)`` and expected to
                return an object with a ``.loss`` attribute.
            inputs: Keyword arguments for the model's forward pass.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                only the loss.
            num_items_in_batch: Accepted so that ``Trainer`` versions which
                pass this keyword do not fail with ``TypeError``; it is not
                used here.
        """
        outputs = model(**inputs)
        loss = outputs.loss
        ewc_loss = self.ewc.penalty()
        total_loss = loss + ewc_loss
        return (total_loss, outputs) if return_outputs else total_loss

def compute_metrics(eval_pred):
    """Compute classification accuracy from a ``(logits, labels)`` pair.

    The predicted class is the arg-max over the last axis of the logits.
    Returns a dict with the single key ``"accuracy"``.
    """
    scores, references = eval_pred
    predicted_classes = np.argmax(scores, axis=-1)
    hits = predicted_classes == references
    return {"accuracy": hits.mean()}

def main():
    """Fine-tune BERT on GLUE SST-2 with an EWC penalty and report accuracy.

    Evaluates on the validation split once before training and once after,
    printing both result dicts.
    """
    # Raw data and tokenizer
    raw = load_dataset("glue", "sst2")
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    def encode_batch(batch):
        return tokenizer(batch["sentence"], padding="max_length", truncation=True)

    # Tokenize, drop the columns the model does not consume, and expose the
    # target under the name the model's forward expects.
    encoded = (
        raw.map(encode_batch, batched=True)
        .remove_columns(['sentence', 'idx'])
        .rename_column("label", "labels")
    )
    encoded.set_format("torch")

    # Pre-trained classifier placed on GPU when one is available
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
    model = model.to(device)

    # EWC importance is estimated on the first 100 raw training examples
    ewc_subset = raw["train"].select(range(100))
    ewc = EWC(model, ewc_subset, tokenizer)

    training_args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=3,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir="./logs",
    )

    trainer = EWCTrainer(
        ewc=ewc,
        model=model,
        args=training_args,
        train_dataset=encoded["train"],
        eval_dataset=encoded["validation"],
        compute_metrics=compute_metrics,
    )

    # Baseline metrics before any fine-tuning
    print(trainer.evaluate())

    trainer.train()

    # Anchor EWC at the fine-tuned weights
    ewc.update()

    # Metrics after fine-tuning
    print(trainer.evaluate())

# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

Map:   0%|          | 0/872 [00:00<?, ? examples/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


{'eval_loss': 0.7311079502105713, 'eval_accuracy': 0.4908256880733945, 'eval_runtime': 6.3536, 'eval_samples_per_second': 137.246, 'eval_steps_per_second': 17.156}


Step,Training Loss
500,0.4634
1000,0.4037
1500,0.3705
2000,0.368
2500,0.3559
3000,0.3372
3500,0.3484
4000,0.3169
4500,0.3317
5000,0.3268


{'eval_loss': 0.3378593325614929, 'eval_accuracy': 0.9185779816513762, 'eval_runtime': 6.4433, 'eval_samples_per_second': 135.335, 'eval_steps_per_second': 16.917, 'epoch': 3.0}
