# Text Classification - Direct Finetune Soft MoE (augmented data)

----

## $\color{blue}{Sections:}$
* Preamble
* Admin - importing libraries
* Load - Loading our data with pandas
* Dataset - Create PyTorch Dataset
* Model - Create PyTorch direct finetune soft MoE model
* Helper - Training helper functions
* Training - Training Loop


## $\color{blue}{Preamble:}$

This notebook loads our embedding model (`thenlper/gte-base`), adds a soft 3-expert MoE classification head, and then fine-tunes the complete network end to end. This version of the model is trained on an augmented dataset.

## $\color{blue}{Admin:}$

In [None]:
from google.colab import drive

In [None]:
# Mount Google Drive inside the Colab VM, then make MyDrive the working
# directory so the relative "class/..." paths used below resolve inside it.
drive.mount("/content/drive")
%cd '/content/drive/MyDrive/'


Mounted at /content/drive
/content/drive/MyDrive


In [None]:
%%capture
# Install the runtime dependencies; the %%capture cell magic above keeps
# pip's console output out of the notebook.
!pip install torch
!pip install dill

In [None]:
import torch
import numpy as np
# Pick the compute device once: the GPU when CUDA is usable, otherwise the CPU.
# Later cells read this module-level `device`.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print(device)

cuda


In [None]:
import os
from getpass import getpass
from huggingface_hub import login

# Prompt for your Hugging Face token securely: getpass reads it interactively
# instead of it being hard-coded in the notebook. An empty entry is allowed;
# the login cell below then skips authentication.
token = getpass("Please enter your Hugging Face token: ")

Please enter your Hugging Face token: ··········


In [None]:
# Authenticate with the Hugging Face Hub only when a token was supplied;
# an empty token means we carry on anonymously.
if not token:
    print("Continuing without Hugging Face login")
else:
    print("HuggingFace token has been successfully entered.")
    login(token=token)

HuggingFace token has been successfully entered.


## $\color{blue}{Load:}$

In [None]:
import pandas as pd
# Directory (relative to the current working directory) holding the pickled splits.
path = "class/datasets/"
# NOTE(security): read_pickle unpickles the file, which can execute arbitrary
# code. Only load pickles that you produced yourself or otherwise trust.
df_train = pd.read_pickle(path + "df_train_augmentation.1")
df_dev = pd.read_pickle(path + "df_dev_augmentation.1")
df_test = pd.read_pickle(path + "df_test_augmentation.1")

In [None]:
# Preview the first rows of the training split to confirm the expected columns loaded.
df_train.head()

Unnamed: 0,master,book_idx,chapter_idx,content,vanilla_embedding.1
0,Ulysses,0,0,"Halted, he peered down the dark winding stairs...","[-0.01852537, -0.021713095, 0.041504614, -0.00..."
1,Ulysses,0,0,"Then, catching sight of Stephen Dedalus, he be...","[-0.019168912, -0.0048065097, -0.012622914, -0..."
2,Ulysses,0,0,"Stephen Dedalus, displeased and sleepy, leaned...","[-0.025832051, -0.0060330997, -0.013755375, 0...."
3,Ulysses,0,0,he said sternly. He added in a preacher’s to...,"[-0.008437265, -0.011068143, 0.029162964, 0.00..."
4,Ulysses,0,0,He peered sideways up and gave a long slow whi...,"[-0.016204245, 0.015205742, 0.023865266, -0.01..."


In [None]:
# (rows, columns) of the training split.
df_train.shape

(20474, 5)

## $\color{blue}{Dataset:}$

In [None]:
# Custom Dataset
from torch.utils.data import Dataset, DataLoader

class CustomTextDataset(Dataset):
    """Map-style dataset that tokenizes a single text on every lookup.

    Args:
        texts: Indexable collection of raw strings.
        labels: Indexable collection of integer class ids, aligned with ``texts``.
        tokenizer: Callable invoked as ``tokenizer(text, truncation=True,
            padding='max_length', max_length=..., return_tensors='pt')`` whose
            result exposes ``'input_ids'`` and ``'attention_mask'`` tensors.
        max_length: Length every sequence is padded or truncated to.
    """

    def __init__(self, texts, labels, tokenizer, max_length=512):
        self.texts, self.labels = texts, labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Number of examples, i.e. one per text."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenize example ``idx`` and return its tensors on the global ``device``.

        Returns:
            dict with ``'input_ids'``, ``'attention_mask'`` and ``'labels'``
            (a ``torch.long`` scalar tensor).
        """
        raw_text, raw_label = self.texts[idx], self.labels[idx]

        tokens = self.tokenizer(
            raw_text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt',
        )

        item = {}
        for key in ('input_ids', 'attention_mask'):
            # squeeze(0) removes the leading axis of the tokenizer output so the
            # DataLoader can stack individual examples into a batch itself.
            item[key] = tokens[key].squeeze(0).to(device)
        item['labels'] = torch.tensor(raw_label, dtype=torch.long).to(device)
        return item

In [None]:
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup

# Tokenizer for the same "thenlper/gte-base" checkpoint used for the encoder.
tokenizer = AutoTokenizer.from_pretrained("thenlper/gte-base")

# Raw passages and their chapter indices (the classification target), per split.
train_texts, train_labels = list(df_train['content']), list(df_train['chapter_idx'])
dev_texts, dev_labels = list(df_dev['content']), list(df_dev['chapter_idx'])
test_texts, test_labels = list(df_test['content']), list(df_test['chapter_idx'])

# Wrap every split in a dataset; tokenization happens per example on access.
train_dataset, dev_dataset, test_dataset = (
    CustomTextDataset(texts, labels, tokenizer)
    for texts, labels in (
        (train_texts, train_labels),
        (dev_texts, dev_labels),
        (test_texts, test_labels),
    )
)

## $\color{blue}{Model:}$

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class DenseBlock(nn.Module):
    """Linear -> BatchNorm1d -> ReLU -> Dropout building block.

    Args:
        input_size: Number of input features.
        output_size: Number of output features.
        dropout_rate: Probability passed to ``nn.Dropout``.
    """

    def __init__(self, input_size, output_size, dropout_rate):
        super().__init__()
        # Attribute names double as state_dict keys, so they are kept stable.
        self.linear = nn.Linear(input_size, output_size)
        self.batch_norm = nn.BatchNorm1d(output_size)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Run ``x`` through the four stages in order and return the result."""
        normalized = self.batch_norm(self.linear(x))
        return self.dropout(self.activation(normalized))

class FeedForwardExpert(nn.Module):
    """Single expert: two ``DenseBlock`` stages followed by a linear output layer.

    The layer widths used to be hard-coded; they are now keyword parameters whose
    defaults reproduce the original 768 -> 400 -> 200 -> 70 network, so existing
    calls such as ``FeedForwardExpert(dropout_rate)`` behave exactly as before.

    Args:
        dropout_rate: Dropout probability forwarded to both dense blocks.
        input_size: Width of the incoming feature vector.
        hidden1_size: Output width of the first dense block.
        hidden2_size: Output width of the second dense block.
        output_size: Number of output logits.
    """

    def __init__(self, dropout_rate, input_size=768, hidden1_size=400,
                 hidden2_size=200, output_size=70):
        super(FeedForwardExpert, self).__init__()

        # Attribute names are kept (block1 / block2 / final_layer) because they
        # are the keys of any previously saved state_dict.
        self.block1 = DenseBlock(input_size, hidden1_size, dropout_rate)
        self.block2 = DenseBlock(hidden1_size, hidden2_size, dropout_rate)
        self.final_layer = nn.Linear(hidden2_size, output_size)

        self.initialize_weights()

    def forward(self, x):
        """Map features of width ``input_size`` to ``output_size`` logits."""
        x = self.block1(x)  # B x input_size   -> B x hidden1_size
        x = self.block2(x)  # B x hidden1_size -> B x hidden2_size
        x = self.final_layer(x)  # B x hidden2_size -> B x output_size

        return x

    def initialize_weights(self):
        """Xavier-uniform initialise every ``nn.Linear`` weight and zero its bias."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

class MixtureOfExperts(nn.Module):
    """Soft mixture of experts: every expert runs on every input and the
    per-expert outputs are blended with softmax gate weights.

    Args:
        num_experts: Number of ``FeedForwardExpert`` modules to create.
        dropout_rate: Dropout probability handed to each expert.
        gate_hidden_size: Width of the gating network's hidden layer.
    """

    def __init__(self, num_experts, dropout_rate=0.1, gate_hidden_size=128):
        super().__init__()
        self.dropout = dropout_rate
        self.num_experts = num_experts
        self.gate_hidden_size = gate_hidden_size

        # The experts all share the same architecture and dropout rate.
        expert_modules = []
        for _ in range(num_experts):
            expert_modules.append(FeedForwardExpert(dropout_rate))
        self.experts = nn.ModuleList(expert_modules)

        # Two-layer gating network: 768 -> gate_hidden_size -> num_experts.
        self.gate_fc1 = nn.Linear(768, gate_hidden_size)
        self.gate_fc2 = nn.Linear(gate_hidden_size, num_experts)

    def forward(self, x):
        """Return the gate-weighted sum of expert outputs, shape (batch, output_size)."""
        gate_logits = self.gate_fc2(F.relu(self.gate_fc1(x)))
        weights = F.softmax(gate_logits, dim=1).unsqueeze(2)  # (batch, num_experts, 1)

        per_expert = [expert(x) for expert in self.experts]
        outputs = torch.stack(per_expert, dim=2)  # (batch, output_size, num_experts)

        # (batch, output_size, E) @ (batch, E, 1) -> (batch, output_size, 1)
        return torch.bmm(outputs, weights).squeeze(2)

In [None]:
import torch.nn as nn
import torch.nn.functional as F

def average_pool(last_hidden_states, attention_mask):
    """Mean-pool token embeddings over the positions selected by ``attention_mask``.

    Args:
        last_hidden_states: Tensor indexed as (batch, tokens, features).
        attention_mask: Tensor indexed as (batch, tokens); non-zero marks a
            position that contributes to the average.

    Returns:
        Tensor of shape (batch, features). A row whose mask is entirely zero
        yields a zero vector instead of NaN.
    """
    keep = attention_mask[..., None].bool()
    # Zero the masked-out positions so they do not contribute to the sum.
    summed = last_hidden_states.masked_fill(~keep, 0.0).sum(dim=1)
    # Clamp the token count to at least 1: an all-zero mask would otherwise
    # divide 0 by 0 and propagate NaN into the classifier and the loss.
    token_counts = attention_mask.sum(dim=1)[..., None].clamp(min=1)
    return summed / token_counts

class CustomModel(nn.Module):
    """Encoder plus classification head, trained as one network.

    Args:
        base_model: Module called with ``input_ids=`` and ``attention_mask=``
            whose output exposes ``last_hidden_state``.
        classifier: Module mapping the pooled sentence vector to logits.
    """

    def __init__(self, base_model, classifier):
        super().__init__()
        self.base_model = base_model
        self.classifier = classifier

    def forward(self, input_ids, attention_mask):
        """Encode, mean-pool over the attention mask, L2-normalise, classify."""
        encoded = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
        sentence_vec = average_pool(encoded.last_hidden_state, attention_mask)
        unit_vec = F.normalize(sentence_vec, p=2, dim=1)
        return self.classifier(unit_vec)


In [None]:
# Pretrained encoder plus a freshly initialised 3-expert soft MoE head.
base_model = AutoModel.from_pretrained("thenlper/gte-base")
classifier = MixtureOfExperts(num_experts=3, dropout_rate=0.15)
# Optional warm start of the head from an earlier checkpoint (disabled):
# path_moe = "class/models/vanilla_moe_e2e_soft.pt"
# classifier.load_state_dict(torch.load(path_moe))
model = CustomModel(base_model=base_model, classifier=classifier).to(device)
# Cell output: total number of trainable parameters.
sum(param.numel() for param in model.parameters() if param.requires_grad)

config.json:   0%|          | 0.00/618 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/219M [00:00<?, ?B/s]

110790269

## $\color{blue}{Helper:}$

In [None]:
def accuracy(outputs, labels):
    """Fraction of rows in ``outputs`` whose arg-max over dim 1 equals ``labels``.

    Args:
        outputs: Score tensor with one row per sample.
        labels: Tensor of target class indices, one per sample.

    Returns:
        Python float: correct predictions divided by ``labels.size(0)``.
    """
    predicted = outputs.argmax(dim=1)
    hits = (predicted == labels).sum().item()
    return hits / labels.size(0)

In [None]:
import numpy as np

def train(model, train_loader, criterion, optimizer, scheduler):
    """Run one training epoch.

    Args:
        model: Called as ``model(input_ids, attention_mask)``; returns logits.
        train_loader: Iterable of batches, each a dict with ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'``.
        criterion: Loss callable ``criterion(logits, labels)``.
        optimizer: Optimizer stepped once per batch.
        scheduler: Learning-rate scheduler, also stepped once per batch.

    Returns:
        (mean batch loss, mean batch accuracy). Both are unweighted means over
        batches, so a smaller final batch counts as much as a full one.
    """
    model.train()
    batch_losses, batch_accuracies = [], []

    for batch in train_loader:
        optimizer.zero_grad()

        targets = batch['labels']
        logits = model(batch['input_ids'], batch['attention_mask'])
        loss = criterion(logits, targets)

        batch_losses.append(loss.item())
        batch_accuracies.append(accuracy(logits, targets))

        # Backward pass, parameter update, then advance the LR schedule.
        loss.backward()
        optimizer.step()
        scheduler.step()

    return np.mean(batch_losses), np.mean(batch_accuracies)

In [None]:
def validate(model, dev_loader, criterion):
    """Evaluate ``model`` on ``dev_loader`` without tracking gradients.

    Args:
        model: Called as ``model(input_ids, attention_mask)``; returns logits.
        dev_loader: Iterable of batches, each a dict with ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'``.
        criterion: Loss callable ``criterion(logits, labels)``.

    Returns:
        (mean batch loss, mean batch accuracy). Both are unweighted means over
        batches, so a smaller final batch counts as much as a full one.
    """
    model.eval()
    epoch_dev_losses = []
    epoch_dev_accuracy = []

    with torch.no_grad():
        for batch in dev_loader:

            input_ids = batch['input_ids']
            attention_mask = batch['attention_mask']
            y = batch['labels']

            out = model(input_ids, attention_mask)

            # Loss and accuracy are computed once per batch; the original
            # evaluated both twice and discarded the first results.
            dev_loss = criterion(out, y)
            dev_accuracy = accuracy(out, y)

            epoch_dev_losses.append(dev_loss.item())
            epoch_dev_accuracy.append(dev_accuracy)

    return np.mean(epoch_dev_losses), np.mean(epoch_dev_accuracy)

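Sanity check: evaluate the combined embedding model and classifier on the dev set. With a pretrained classifier loaded, this should reproduce the results obtained by embedding the texts first and then passing those embeddings to the classifier; with a freshly initialised classifier (as when the checkpoint load above is left commented out), expect roughly chance-level accuracy.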

In [None]:
# Baseline evaluation before any fine-tuning. batch_size=len(dev_dataset) puts
# the whole dev split into a single batch, so validate() returns exact
# dataset-level means here (and needs enough memory for one full-split pass).
dev_loader = DataLoader(dev_dataset, batch_size=len(dev_dataset))
criterion = nn.CrossEntropyLoss()
dev_loss, dev_accuracy = validate(model,dev_loader, criterion)
print(f'Dev loss: {dev_loss}')
print(f'Dev accuracy: {dev_accuracy}')

Dev loss: 4.24925422668457
Dev accuracy: 0.01742627345844504


In [None]:
from collections import namedtuple
# Summary of one training run: the metrics of its best epoch, the
# hyper-parameters used (bs, lr, alpha) and the best dev accuracy seen so far.
Stats = namedtuple(
    'Stats',
    'train_loss train_accuracy dev_loss dev_accuracy epoch bs lr alpha max_accuracy',
)

In [None]:
def search_stats(results):
  """Return the entry of ``results`` with the highest ``dev_accuracy``.

  Args:
    results: Iterable of records exposing a ``dev_accuracy`` attribute.

  Returns:
    The best record, or ``None`` when ``results`` is empty. Ties keep the
    earliest entry. Unlike the original ``> 0`` seeded scan, a run whose
    accuracy is exactly 0.0 can still be returned, so a non-empty list never
    produces ``None``.
  """
  return max(results, key=lambda run: run.dev_accuracy, default=None)

## $\color{blue}{Training:}$

In [None]:
def tv_run(epochs, model, bs, lr, alpha, max_accuracy, path, verbose=0):
  """
  Train and validate ``model`` for one hyper-parameter setting.

  The model is trained in place (its weights are not re-initialised here).
  Training stops early once dev accuracy has failed to improve on the best
  value of this run for two consecutive epochs. Whenever an epoch's dev
  accuracy exceeds ``max_accuracy`` the state_dict is written to ``path``.

  Args:
    epochs: Maximum number of epochs; also sizes the LR schedule.
    model: Network to train.
    bs: Requested batch size, clamped to the range [16, 64].
    lr: Peak learning rate for AdamW.
    alpha: AdamW weight decay.
    max_accuracy: Best dev accuracy seen before this run (checkpoint threshold).
    path: Destination file for the best checkpoint.
    verbose: 1 prints the run summary; 2 also prints per-epoch results.

  Returns:
    Stats for the epoch with the highest dev accuracy in this run, plus the
    (possibly clamped) bs, lr, alpha and the updated max_accuracy.
  """
  # Keep the batch size inside the supported window.
  bs = max(16, min(bs, 64))

  for _ in range(3):
    print('\n ##### \n')
  print(' Run Starting')
  for tag, value in (('bs', bs), ('lr', lr), ('alpha', alpha)):
    print(tag, value)

  # Data loaders: only the training split is shuffled.
  train_loader = DataLoader(train_dataset, batch_size=bs, shuffle=True)
  dev_loader = DataLoader(dev_dataset, batch_size=bs)

  # Loss, optimizer and a linear schedule whose warm-up covers 1/12 of all steps.
  model = model.to(device)
  criterion = nn.CrossEntropyLoss()
  optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=alpha)

  total_steps = epochs * len(train_loader)
  scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=total_steps // 12,
    num_training_steps=total_steps,
  )

  # One row per finished epoch: (train_loss, train_acc, dev_loss, dev_acc, epoch number)
  history = []

  # Early-stopping bookkeeping
  best_in_run = 0
  stale_epochs = 0

  for epoch in range(epochs):

    # stop after two consecutive epochs without improvement
    if stale_epochs >= 2:
      break

    train_loss, train_acc = train(model, train_loader, criterion, optimizer, scheduler)
    dev_loss, dev_acc = validate(model, dev_loader, criterion)
    history.append((train_loss, train_acc, dev_loss, dev_acc, epoch + 1))

    if dev_acc > best_in_run:
      best_in_run, stale_epochs = dev_acc, 0
    else:
      stale_epochs += 1

    # checkpoint only when this beats the best accuracy across all runs
    if dev_acc > max_accuracy:
      torch.save(model.state_dict(), path)
      max_accuracy = dev_acc

    if verbose == 2:
      print(f'\n --------- \nEpoch: {epoch + 1}\n')
      for _ in range(2):
        print('\n ---------- \n ')
      print(f'Epoch {epoch + 1} train loss: {train_loss:.4f}')
      print(f'Epoch {epoch + 1} train accuracy: {train_acc:.4f}')
      print(f'Epoch {epoch + 1} dev loss: {dev_loss:.4f}')
      print(f'Epoch {epoch + 1} dev accuracy: {dev_acc:.4f}')

  # Report the epoch with the highest dev accuracy (first one on ties).
  best_row = history[np.argmax([row[3] for row in history])]
  stats = Stats(*best_row, bs, lr, alpha, max_accuracy)

  if verbose in (1, 2):
    print('\n ######## \n')
    print(f'bs:{stats.bs}, lr:{stats.lr}, alpha:{stats.alpha} @ epoch {stats.epoch}.')
    print(f'TL:{stats.train_loss}, TA:{stats.train_accuracy}.')
    print(f'DL:{stats.dev_loss}, DA:{stats.dev_accuracy}')

  return stats

In [None]:
"""
Main Admin
"""
epochs = 15
max_accuracy = 0
path = "class/models/direct_ft_moe_augmented.pt"
results = []

# define model once
model = CustomModel(base_model, classifier) # model with dropout
model = model.to(device)

# Snapshot the starting weights on the CPU. tv_run trains the model in place,
# and every grid point wraps the same base_model / classifier objects, so
# without a reset each run after the first would continue from the previous
# run's trained weights instead of from the same starting point.
initial_state = {
  name: tensor.detach().cpu().clone()
  for name, tensor in model.state_dict().items()
}

for lr in [0.00004]:
  for bs in [64]:
    for alpha in [0.012]:

      # restore the starting weights so every configuration is comparable
      model.load_state_dict(initial_state)

      # run training
      res = tv_run(epochs, model, bs, lr, alpha, max_accuracy, path, verbose = 2)
      max_accuracy = res.max_accuracy
      results.append(res)

      # get best result of the round or even so far
      stats = search_stats(results)
      print(stats) # debug


 ##### 


 ##### 


 ##### 

 Run Starting
bs 64
lr 4e-05
alpha 0.012

 --------- 
Epoch: 1


 ---------- 
 

 ---------- 
 
Epoch 1 train loss: 3.7698
Epoch 1 train accuracy: 0.1458
Epoch 1 dev loss: 2.4627
Epoch 1 dev accuracy: 0.3929

 --------- 
Epoch: 2


 ---------- 
 

 ---------- 
 
Epoch 2 train loss: 2.1705
Epoch 2 train accuracy: 0.4796
Epoch 2 dev loss: 1.6583
Epoch 2 dev accuracy: 0.5831

 --------- 
Epoch: 3


 ---------- 
 

 ---------- 
 
Epoch 3 train loss: 1.3368
Epoch 3 train accuracy: 0.6916
Epoch 3 dev loss: 1.3800
Epoch 3 dev accuracy: 0.6262

 --------- 
Epoch: 4


 ---------- 
 

 ---------- 
 
Epoch 4 train loss: 0.7390
Epoch 4 train accuracy: 0.8435
Epoch 4 dev loss: 1.3524
Epoch 4 dev accuracy: 0.6381

 --------- 
Epoch: 5


 ---------- 
 

 ---------- 
 
Epoch 5 train loss: 0.3311
Epoch 5 train accuracy: 0.9319
Epoch 5 dev loss: 1.4071
Epoch 5 dev accuracy: 0.6563

 --------- 
Epoch: 6


 ---------- 
 

 ---------- 
 
Epoch 6 train loss: 0.1286
Epoch 6 trai

In [None]:
import dill
def save_results_to_file(namedtuples, filename):
    """Serialise ``namedtuples`` with dill into ``filename``, overwriting it.

    Args:
        namedtuples: Object to persist (here, the list of run records).
        filename: Destination path, opened in binary write mode.
    """
    with open(filename, 'wb') as sink:
        dill.dump(namedtuples, sink)

def load_results_from_file(filename):
    """Read back an object written by dill from ``filename``.

    NOTE(security): dill, like pickle, can execute arbitrary code while
    loading. Only open files that you created yourself or otherwise trust.

    Args:
        filename: Path of the file, opened in binary read mode.

    Returns:
        The deserialised object.
    """
    with open(filename, 'rb') as source:
        loaded = dill.load(source)
    return loaded

In [None]:
# Persist the per-run summaries. Note that `path` is rebound here to the
# results directory (it previously held the checkpoint file path).
path = 'class/results/'
save_results_to_file(results, path + 'direct_ft_moe_augmented.pk')

In [None]:
# Display the module tree of the combined encoder + classifier.
model

CustomModel(
  (base_model): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [None]:
# Reload the checkpoint written during training (tv_run saves to this path
# whenever dev accuracy beats the best so far). This replaces the weights left
# in memory after the final epoch, which are not necessarily the best ones.
path = "class/models/direct_ft_moe_augmented.pt"
model.load_state_dict(torch.load(path))

<All keys matched successfully>

In [None]:
# Pull the encoder and the classification head back out of the combined model.
# This rebinds the global `base_model` name.
base_model = model.base_model
classification_model = model.classifier

In [None]:
# Save the two parts as separate state_dicts so the encoder and the MoE head
# can each be loaded on their own.
torch.save(base_model.state_dict(), "class/models/direct_ft_augmented_embedding_model.pt")
torch.save(classification_model.state_dict(), "class/models/direct_ft_augmented_soft_moe_model.pt")