## Fine-Tune Mixture of Experts (Switch Transnformer) for Text Summarization Task

Credits: This notebook has been adapted from https://github.com/mlabonne and [@abhimishra91](https://github.com/abhimishra91) that can be found [here](https://github.com/abhimishra91/transformers-tutorials/blob/master/transformers_summarization_wandb.ipynb)

## Install dependencies

In [1]:
!pip install -q git+https://github.com/huggingface/transformers.git@main
!pip install -q accelerate
!pip install -q datasets
!pip install wandb -q
!pip install sentencepiece -q
%mkdir output

## Import Libraries

In [4]:
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from transformers import T5Tokenizer, T5ForConditionalGeneration
import wandb

In [5]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [5]:
wandb.login()

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········································


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [6]:
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [7]:
# Creating a custom dataset for reading the dataframe and loading it into the dataloader
# to pass it to the neural network at a later stage for finetuning the model and to prepare it for predictions

class CustomDataset(Dataset):

    def __init__(self, dataframe, tokenizer, source_len, summ_len):
        self.tokenizer = tokenizer
        self.data = dataframe
        self.source_len = source_len
        self.summ_len = summ_len
        self.context = self.data["document"]
        self.summaries = self.data["summary"]

    def __len__(self):
        return len(self.context)

    def __getitem__(self, index):
        context = self.context[index]
        summary = self.summaries[index]

        source = self.tokenizer.batch_encode_plus([context], max_length= self.source_len, pad_to_max_length=True,return_tensors='pt')
        target = self.tokenizer.batch_encode_plus([summary], max_length= self.summ_len, pad_to_max_length=True,return_tensors='pt')

        source_ids = source['input_ids'].squeeze()
        source_mask = source['attention_mask'].squeeze()
        target_ids = target['input_ids'].squeeze()
        target_mask = target['attention_mask'].squeeze()

        return {
            'source_ids': source_ids.to(dtype=torch.long),
            'source_mask': source_mask.to(dtype=torch.long),
            'target_ids': target_ids.to(dtype=torch.long),
            'target_ids_y': target_ids.to(dtype=torch.long)
        }

## Training and validation loop

Below is the definition of our main training loop

In [8]:
# Creating the training function. This will be called in the main function. It is run depending on the epoch value.
# The model is put into train mode and then we wnumerate over the training loader and passed to the defined network

def train(epoch, tokenizer, model, device, loader, optimizer):
    model.train()
    for _, data in enumerate(loader, 0):
        labels = data['target_ids'].to(device, dtype = torch.long)
        labels = model._shift_right(labels) # shifter implemented in model for convenience
        # We set the pad tokens (0) to -100 to be ignored by the CrossEntropy loss
        labels = labels.masked_fill_(labels == 0, -100)
        ids = data['source_ids'].to(device, dtype = torch.long)
        mask = data['source_mask'].to(device, dtype = torch.long)
        decoder_input_ids = torch.zeros_like(labels).long()

        outputs = model(input_ids = ids, attention_mask = mask, labels=labels, output_router_logits=True, return_dict=True)
        loss = outputs[0]

        if _%10 == 0:
            wandb.log({"Training Loss": loss.item()})
            wandb.log({"Training Encoder z-Loss": outputs.encoder_z_loss.item()})
            wandb.log({"Training Encoder aux-Loss": outputs.encoder_aux_loss.item()})
            wandb.log({"Training Decoder z-Loss": outputs.decoder_z_loss.item()})
            wandb.log({"Training Decoder aux-Loss": outputs.decoder_aux_loss.item()})

        if _%500==0:
            print(f'Epoch: {epoch}, Loss:  {loss.item()}')

        if (_ + 1) %2000==0:
          break

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

And validation loop!

In [9]:
def validate(epoch, tokenizer, model, device, loader):
    model.eval()
    predictions = []
    actuals = []
    with torch.no_grad():
        for _, data in enumerate(loader, 0):
            y = data['target_ids'].to(device, dtype = torch.long)
            ids = data['source_ids'].to(device, dtype = torch.long)
            mask = data['source_mask'].to(device, dtype = torch.long)

            generated_ids = model.generate(
                input_ids = ids,
                attention_mask = mask,
                max_length=16,
                num_beams=2,
                repetition_penalty=2.5,
                length_penalty=1.0,
                early_stopping=True
                )
            preds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids]
            target = [tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True)for t in y]
            if _%100==0:
                print(f'Completed {_}')
                break

            predictions.extend(preds)
            actuals.extend(target)
    return predictions, actuals

## Main loop

Below is the main script for training and validation, the trained model will be saved after `TRAIN_EPOCHS` iterations!

In [10]:
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoConfig, SwitchTransformersForConditionalGeneration


# WandB – Initialize a new run
run = wandb.init(project="moe_switch_transformers")

# WandB – Config is a variable that holds and saves hyperparameters and inputs
# Defining some key variables that will be used later on in the training
config = wandb.config          # Initialize config
config.TRAIN_BATCH_SIZE = 2    # input batch size for training (default: 64)
config.VALID_BATCH_SIZE = 2    # input batch size for testing (default: 1000)
config.TRAIN_EPOCHS = 3       # number of epochs to train (default: 10)
config.VAL_EPOCHS = 1
config.LEARNING_RATE = 1e-4    # learning rate (default: 0.01)
config.SEED = 42               # random seed (default: 42)
config.MAX_LEN = 52
config.SUMMARY_LEN = 16


# Set random seeds and deterministic pytorch for reproducibility
torch.manual_seed(config.SEED) # pytorch random seed
np.random.seed(config.SEED) # numpy random seed
torch.backends.cudnn.deterministic = True

[34m[1mwandb[0m: Currently logged in as: [33mshah_zeb_naveed[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [11]:
# tokenzier for encoding the text
tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8")

dataset = load_dataset("xsum")
def preprend(example):
    return {"document":["summarize: "+ x for x in example["document"]]}
encoded_dataset = dataset.map(preprend, batched=True)

# Creation of Dataset and Dataloader
# Defining the train size. So 80% of the data will be used for training and the rest will be used for validation.
train_dataset = encoded_dataset["train"]
val_dataset = encoded_dataset["validation"]

tokenizer_config.json:   0%|          | 0.00/2.35k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/2.05k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/954 [00:00<?, ?B/s]

Downloading and preparing dataset xsum/default (download: 245.38 MiB, generated: 507.60 MiB, post-processed: Unknown size, total: 752.98 MiB) to /root/.cache/huggingface/datasets/xsum/default/1.2.0/32c23220eadddb1149b16ed2e9430a05293768cfffbdfd151058697d4c11f934...


Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/255M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.00M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/204045 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/11332 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/11334 [00:00<?, ? examples/s]

Dataset xsum downloaded and prepared to /root/.cache/huggingface/datasets/xsum/default/1.2.0/32c23220eadddb1149b16ed2e9430a05293768cfffbdfd151058697d4c11f934. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/205 [00:00<?, ?ba/s]

  0%|          | 0/12 [00:00<?, ?ba/s]

  0%|          | 0/12 [00:00<?, ?ba/s]

In [12]:
val_dataset[0]

{'document': 'summarize: The ex-Reading defender denied fraudulent trading charges relating to the Sodje Sports Foundation - a charity to raise money for Nigerian sport.\nMr Sodje, 37, is jointly charged with elder brothers Efe, 44, Bright, 50 and Stephen, 42.\nAppearing at the Old Bailey earlier, all four denied the offence.\nThe charge relates to offences which allegedly took place between 2008 and 2014.\nSam, from Kent, Efe and Bright, of Greater Manchester, and Stephen, from Bexley, are due to stand trial in July.\nThey were all released on bail.',
 'summary': 'Former Premier League footballer Sam Sodje has appeared in court alongside three brothers accused of charity fraud.',
 'id': '38295789'}

In [13]:
# Creating the Training and Validation dataset for further creation of Dataloader
training_set = CustomDataset(train_dataset, tokenizer, config.MAX_LEN, config.SUMMARY_LEN)
val_set = CustomDataset(val_dataset, tokenizer, config.MAX_LEN, config.SUMMARY_LEN)
val_set[0]

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


{'source_ids': tensor([21603,    10,    37,  1215,    18, 19915,    53,     3, 13720, 11958,
         24283,  3415,  3991,     3,  8321,    12,     8,   264,    26,  1924,
          5716,  2941,     3,    18,     3,     9,  7813,    12,  3033,   540,
            21,  7904,    29,  2600,     5,  1363,   264,    26,  1924,     6,
          6862,     6,    19, 22801,  4977,    28, 17813, 10740,   262,    89,
            15,     1]),
 'source_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1]),
 'target_ids': tensor([18263,  6552,  3815,  3370,    49,  3084,   264,    26,  1924,    65,
          4283,    16,  1614,  5815,   386,     1]),
 'target_ids_y': tensor([18263,  6552,  3815,  3370,    49,  3084,   264,    26,  1924,    65,
          4283,    16,  1614,  5815,   386,     1])}

In [14]:
tokenizer.decode(val_set[0]['target_ids_y'])

'Former Premier League footballer Sam Sodje has appeared in court alongside three</s>'

In [15]:
# Defining the parameters for creation of dataloaders
train_params = {
    'batch_size': config.TRAIN_BATCH_SIZE,
    'shuffle': True,
    'num_workers': 0
    }

val_params = {
    'batch_size': config.VALID_BATCH_SIZE,
    'shuffle': False,
    'num_workers': 0
    }

# Creation of Dataloaders for testing and validation. This will be used down for training and validation stage for the model.
training_loader = DataLoader(training_set, **train_params)
val_loader = DataLoader(val_set, **val_params)

In [16]:
# Defining the model. We are using t5-base model and added a Language model layer on top for generation of Summary.
# Further this model is sent to device (GPU/TPU) for using the hardware.
model = SwitchTransformersForConditionalGeneration.from_pretrained("google/switch-base-8", torch_dtype=torch.bfloat16)
model = model.to(device)

# Defining the optimizer that will be used to tune the weights of the network in the training session.
optimizer = torch.optim.Adam(params =  model.parameters(), lr=config.LEARNING_RATE)

# Log metrics with wandb
wandb.watch(model, log="all")

# Training loop
print('Initiating Fine-Tuning for the model on our dataset')

for epoch in range(config.TRAIN_EPOCHS):
    train(epoch, tokenizer, model, device, training_loader, optimizer)

config.json:   0%|          | 0.00/1.86k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.24G [00:00<?, ?B/s]

  return self.fget.__get__(instance, owner)()


generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

Initiating Fine-Tuning for the model on our dataset
Epoch: 0, Loss:  16.965253829956055
Epoch: 0, Loss:  5.750101566314697
Epoch: 0, Loss:  4.125947952270508
Epoch: 0, Loss:  5.532607078552246
Epoch: 1, Loss:  4.291008472442627
Epoch: 1, Loss:  3.6450307369232178
Epoch: 1, Loss:  4.662201404571533
Epoch: 1, Loss:  4.721072196960449
Epoch: 2, Loss:  4.152870178222656
Epoch: 2, Loss:  3.409647226333618
Epoch: 2, Loss:  3.3837666511535645
Epoch: 2, Loss:  4.019887447357178


In [17]:
# Validation loop and saving the resulting file with predictions and acutals in a dataframe.
# Saving the dataframe as predictions.csv
print('Now generating summaries on our fine tuned model for the validation dataset and saving it in a dataframe')
for epoch in range(config.VAL_EPOCHS):
    predictions, actuals = validate(epoch, tokenizer, model, device, val_loader)
    final_df = pd.DataFrame({'Generated Text':predictions,'Actual Text':actuals})
    final_df.to_csv('./output/predictions.csv')
    try:
        run.log({'t1': wandb.Table(final_df)})
    except Exception as e:
        print(e)
        
    print('Output Files generated for review')

Now generating summaries on our fine tuned model for the validation dataset and saving it in a dataframe
Completed 0
columns argument expects a `list` object
Output Files generated for review


In [18]:
eval_df = pd.read_csv('./output/predictions.csv')
eval_df

Unnamed: 0.1,Unnamed: 0,Generated Text,Actual Text


# Test the model

In [11]:
from transformers import AutoTokenizer, SwitchTransformersForConditionalGeneration

trained_model = SwitchTransformersForConditionalGeneration.from_pretrained("shahzebnaveed/moe_switch_transformer_summarization", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained("shahzebnaveed/moe_switch_transformer_summarization")

config.json:   0%|          | 0.00/1.97k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.24G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/20.7k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

## Generate some text!

In [26]:
text = "summarize: Peter and Elizabeth took a taxi to attend the night party in the city. While in the party, Elizabeth collapsed and was rushed to the hospital. Since she was diagnosed with a brain injury, the doctor told Peter to stay besides her until she gets well. Therefore, Peter stayed with her at the hospital for 3 days without leaving."
input_ids = tokenizer(text, return_tensors="pt").input_ids
output_ids = trained_model.generate(input_ids)

#print(tokenizer.decode(output_ids[0], decoder_input_ids=[0], skip_special_tokens=False))

In [28]:
repo_name = "shahzebnaveed/moe_switch_transformer_summarization"
trained_model.push_to_hub(repo_name)

model.safetensors:   0%|          | 0.00/1.24G [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/shahzebnaveed/moe_switch_transformer_summarization/commit/f0c22c764ef60aeed6cfedb0c6c8058554294578', commit_message='Upload SwitchTransformersForConditionalGeneration', commit_description='', oid='f0c22c764ef60aeed6cfedb0c6c8058554294578', pr_url=None, pr_revision=None, pr_num=None)

In [10]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8")
tokenizer.push_to_hub("shahzebnaveed/moe_switch_transformer_summarization")

tokenizer_config.json:   0%|          | 0.00/2.35k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/5.18k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/shahzebnaveed/moe_switch_transformer_summarization/commit/d13f0b8cd9806bbf049d1c0aab3ca2ad7e50bce5', commit_message='Upload tokenizer', commit_description='', oid='d13f0b8cd9806bbf049d1c0aab3ca2ad7e50bce5', pr_url=None, pr_revision=None, pr_num=None)