# Text Classification - Vanilla Mixture of Experts (Hard, Pretrained)

----

## $\color{blue}{Sections:}$
* Preamble
* Admin - importing libraries
* Load - Loading our data from pandas
* Dataset - Create PyTorch Dataset
* Model - Create PyTorch Vanilla Model
* Helper - Training helper functions
* Training - Training Loop


## $\color{blue}{Preamble:}$

This notebook creates a vanilla mixture of experts model. We pre train experts for each author, then build and train a hard MoE classifier with these experts. The weights of the experts are frozen during the MoE classifier.

## $\color{blue}{Admin:}$

In [1]:
from google.colab import drive

In [2]:
drive.mount("/content/drive")
%cd '/content/drive/MyDrive/'


Mounted at /content/drive
/content/drive/MyDrive


In [3]:
%%capture
!pip install torch
!pip install dill

In [4]:
import torch
import numpy as np
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


## $\color{blue}{Load:}$

In [5]:
import pandas as pd
path = "class/datasets/"
df_train = pd.read_pickle(path + "df_train")
df_dev = pd.read_pickle(path + "df_dev")
df_test = pd.read_pickle(path + "df_test")

In [None]:
df_train.head()

Unnamed: 0,index,master,book_idx,book,chapter_idx,chapter,author,content,vanilla_embedding
8114,8114,Dubliners,3,Dubliners,31,GRACE,Joyce,“Is it John of Tuam?” “Are you sure of that ...,"[-0.012913608, -0.026916211, 0.0023321153, -0...."
4951,4951,Ulysses,2,Nostos,15,Eumaeus,Joyce,sibly there were several others. He personally...,"[-0.019626686, -0.035692617, -0.034875672, 0.0..."
4629,4629,Ulysses,2,Nostos,15,Eumaeus,Joyce,"Stephen, who was trying his dead best to yawn ...","[0.015934143, -0.0034991587, 0.0035751674, 0.0..."
11556,11556,Dracula,4,Dracula,59,CHAPTER XXVII: MINA HARKER’S JOURNAL,Bram Stoker,"Now to the historical, for as Madam Mina write...","[-4.009433e-05, -0.0041142944, 0.026873538, -0..."
12262,12262,Republic,5,Republic,62,Book III,Plato,The harmonies which you mean are the mixed or ...,"[0.0048890463, -0.0060007297, 0.0054147574, -0..."


### $\color{red}{Subset:}$

subset the training and the dev DF into Joyce / Stoker / Plato

In [6]:
df_train_joyce = df_train[df_train['book_idx'].isin([0,1,2,3])]
df_train_stoker = df_train[df_train['book_idx'] == 4]
df_train_plato = df_train[df_train['book_idx'] == 5]

df_dev_joyce = df_dev[df_dev['book_idx'].isin([0,1,2,3])]
df_dev_stoker = df_dev[df_dev['book_idx'] == 4]
df_dev_plato = df_dev[df_dev['book_idx'] == 5]

## $\color{blue}{Dataset:}$

In [7]:
train_embeddings = [torch.tensor(array) for array in df_train['vanilla_embedding']]
train_x = torch.stack(train_embeddings).to(device)

dev_embeddings = [torch.tensor(array) for array in df_dev['vanilla_embedding']]
dev_x = torch.stack(dev_embeddings).to(device)

test_embeddings = [torch.tensor(array) for array in df_test['vanilla_embedding']]
test_x = torch.stack(test_embeddings).to(device)

In [8]:
# train_y = torch.LongTensor(list(df_train['book_idx'])).to(device)
# dev_y = torch.LongTensor(list(df_dev['book_idx'])).to(device)
# test_y = torch.LongTensor(list(df_test['book_idx'])).to(device)

train_y = torch.LongTensor(list(df_train['chapter_idx'])).to(device)
dev_y = torch.LongTensor(list(df_dev['chapter_idx'])).to(device)
test_y = torch.LongTensor(list(df_test['chapter_idx'])).to(device)

In [9]:
from torch.utils.data import Dataset, DataLoader
# assuming already tensors, allready on device
class VanillaDataset(Dataset):
  """Dataset maker"""

  def __init__(self, x, y):
    self.x = x
    self.y = y

  def __getitem__(self,index):
    x = self.x[index]
    y = self.y[index]

    return x, y

  def __len__(self):
    return len(self.y)


In [10]:
train_dataset = VanillaDataset(train_x, train_y)
dev_dataset = VanillaDataset(dev_x, dev_y)
test_dataset = VanillaDataset(test_x, test_y)

In [11]:
train_dataset[0][0].size()

torch.Size([768])

### $\color{red}{Joyce:}$

In [12]:
# get X
train_embeddings_joyce = [torch.tensor(array) for array in df_train_joyce['vanilla_embedding']]
train_x_joyce = torch.stack(train_embeddings_joyce).to(device)

dev_embeddings_joyce = [torch.tensor(array) for array in df_dev_joyce['vanilla_embedding']]
dev_x_joyce = torch.stack(dev_embeddings_joyce).to(device)

# get y
train_y_joyce = torch.LongTensor(list(df_train_joyce['chapter_idx'])).to(device)
dev_y_joyce = torch.LongTensor(list(df_dev_joyce['chapter_idx'])).to(device)


# create pytorch set
train_dataset_joyce = VanillaDataset(train_x_joyce, train_y_joyce)
dev_dataset_joyce = VanillaDataset(dev_x_joyce, dev_y_joyce)

### $\color{red}{Stoker:}$

In [13]:
# get X
train_embeddings_stoker = [torch.tensor(array) for array in df_train_stoker['vanilla_embedding']]
train_x_stoker = torch.stack(train_embeddings_stoker).to(device)

dev_embeddings_stoker = [torch.tensor(array) for array in df_dev_stoker['vanilla_embedding']]
dev_x_stoker = torch.stack(dev_embeddings_stoker).to(device)

# get y
train_y_stoker = torch.LongTensor(list(df_train_stoker['chapter_idx'])).to(device)
dev_y_stoker = torch.LongTensor(list(df_dev_stoker['chapter_idx'])).to(device)


# create pytorch set
train_dataset_stoker = VanillaDataset(train_x_stoker, train_y_stoker)
dev_dataset_stoker = VanillaDataset(dev_x_stoker, dev_y_stoker)

### $\color{red}{Plato:}$

In [14]:
# get X
train_embeddings_plato = [torch.tensor(array) for array in df_train_plato['vanilla_embedding']]
train_x_plato = torch.stack(train_embeddings_plato).to(device)

dev_embeddings_plato = [torch.tensor(array) for array in df_dev_plato['vanilla_embedding']]
dev_x_plato = torch.stack(dev_embeddings_plato).to(device)

# get y
train_y_plato = torch.LongTensor(list(df_train_plato['chapter_idx'])).to(device)
dev_y_plato = torch.LongTensor(list(df_dev_plato['chapter_idx'])).to(device)


# create pytorch set
train_dataset_plato = VanillaDataset(train_x_plato, train_y_plato)
dev_dataset_plato = VanillaDataset(dev_x_plato, dev_y_plato)

## $\color{blue}{Expert-Model:}$

In [13]:
import torch.nn as nn
import torch.nn.functional as F

class DenseBlock(nn.Module):
    def __init__(self, input_size, output_size, dropout_rate):
        super(DenseBlock, self).__init__()
        self.linear = nn.Linear(input_size, output_size)
        self.batch_norm = nn.BatchNorm1d(output_size)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        x = self.linear(x)
        x = self.batch_norm(x)
        x = self.activation(x)
        x = self.dropout(x)
        return x

class FeedForwardExpert(nn.Module):
    def __init__(self, output_size, dropout_rate):
        super(FeedForwardExpert, self).__init__()
        self.output_size = output_size

        # Define the dense blocks
        self.block1 = DenseBlock(768, 400, dropout_rate)
        self.block2 = DenseBlock(400, 200, dropout_rate)
        self.final_layer = nn.Linear(200, self.output_size)

        self.initialize_weights()

    def forward(self, x):
        x = self.block1(x)  # Bx768 -> Bx400
        x = self.block2(x)  # Bx400 -> Bx50
        x = self.final_layer(x)  # Bx50 -> Bx6
        return x

    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)


In [14]:
model = FeedForwardExpert(70,.1)
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

count_parameters(model)

403070

## $\color{blue}{Helper:}$

In [15]:
def accuracy(outputs, labels):
    # argmax to get predicted classes
    _, predicted = torch.max(outputs, 1)

    # count correct
    correct = (predicted == labels).sum().item()

    # get average
    acc = correct / labels.size(0)  # Total number of samples
    return acc

In [16]:
import numpy as np

def train(model, train_loader, criterion, optimizer):
    model.train()
    epoch_train_losses = []
    epoch_train_accuracy = []

    for batch_idx, (x, y) in enumerate(train_loader):

        optimizer.zero_grad()

        out = model(x)
        train_loss = criterion(out, y)
        train_accuracy = accuracy(out, y)

        epoch_train_losses.append(train_loss.item())
        epoch_train_accuracy.append(train_accuracy)

        # Backpropagation and optimization
        train_loss.backward()
        optimizer.step()

    return np.mean(epoch_train_losses), np.mean(epoch_train_accuracy)

In [17]:
def validate(model, dev_loader, criterion):
    model.eval()
    epoch_dev_losses = []
    epoch_dev_accuracy = []

    with torch.no_grad():
        for batch_idx, (x, y) in enumerate(dev_loader):
            out = model(x)

            dev_loss = criterion(out, y)
            dev_accuracy = accuracy(out, y)

            epoch_dev_losses.append(dev_loss.item())
            epoch_dev_accuracy.append(dev_accuracy)

    return np.mean(epoch_dev_losses), np.mean(epoch_dev_accuracy)

In [18]:
from collections import namedtuple
Stats = namedtuple('Stats', [
    'train_loss',
    'train_accuracy',
    'dev_loss',
    'dev_accuracy',
    'epoch',
    'bs',
    'lr',
    'alpha',
    'max_accuracy'
])

In [19]:
def gen_config(lr_low, lr_high, alpha_low, alpha_high, b_size, b_step):
  bs_list = [b_size - b_step, b_size, b_size + b_step]
  bs = int(2**np.random.choice(bs_list))
  lr = round(10**float(np.random.uniform(lr_low,lr_high)),6)
  alpha = round(10**float(np.random.uniform(alpha_low,alpha_high)),6)
  return lr, alpha, bs

In [20]:
def gen_ranges( lr, lr_range, alpha, alpha_range, b_size, iteration):

  lr_center = lr
  lr_low = lr_center - lr_range/2
  lr_high = lr_center + lr_range/2
  lr_diff = lr_high - lr_low

  alpha_center = alpha
  alpha_low = alpha_center - alpha_range/2
  alpha_high = alpha_center + alpha_range/2
  alpha_diff = alpha_high - alpha_low

  b_step = 2 - iteration

  return (lr_low, lr_high, alpha_low, alpha_high, b_size, b_step)

In [21]:
def search_stats(results):
  best_stats = None
  max_dev_accuracy = 0
  for i in range(len(results)):
    acc = results[i].dev_accuracy
    if acc > max_dev_accuracy:
      best_stats = results[i]
      max_dev_accuracy = acc
  return best_stats

## $\color{blue}{Training:}$

In [22]:
def tv_run(epochs, model, train_data, dev_data, bs, lr, alpha, max_accuracy, path, verbose = 0):
  """
  Runs a training setup
  verbose == 1 - print model results
  verbose == 2 -> print epoch and model results
  """
  # Set up new model
  model = model.to(device)
  criterion = nn.CrossEntropyLoss()
  optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=alpha)

  # Prepare data loaders
  train_loader = DataLoader(train_data, batch_size=bs, shuffle=True)
  dev_loader = DataLoader(dev_data, batch_size=bs)

  # Hold epoch stats
  train_losses = []
  train_accuracy = []
  dev_losses = []
  dev_accuracy = []
  epoch_holder = []

  # Break if no improvement
  current_best = 0
  no_improvement = 0


  # Run epochs
  for epoch in range(epochs):

    # break out of epochs
    if no_improvement >= 5:
      break

    # call training and validation functions
    train_loss, train_acc = train(model, train_loader, criterion, optimizer)
    dev_loss, dev_acc = validate(model, dev_loader, criterion)

    # Store epoch stats
    train_losses.append(train_loss)
    train_accuracy.append(train_acc)
    dev_losses.append(dev_loss)
    dev_accuracy.append(dev_acc)
    epoch_holder.append(epoch + 1)

    # check for improvement
    if dev_acc > current_best:
      current_best = dev_acc
      no_improvement = 0
    else:
      no_improvement += 1

    # save best model
    if dev_acc > max_accuracy:
      torch.save(model.state_dict(), path)
      max_accuracy = dev_acc

    # optionally print epoch results
    if verbose == 2:
      print(f'\n --------- \nEpoch: {epoch + 1}\n')
      print(f'Epoch {epoch + 1} train loss: {train_loss:.4f}')
      print(f'Epoch {epoch + 1} train accuracy: {train_acc:.4f}')
      print(f'Epoch {epoch + 1} dev loss: {dev_loss:.4f}')
      print(f'Epoch {epoch + 1} dev accuracy: {dev_acc:.4f}')

  # save best results
  max_ind = np.argmax(dev_accuracy)

  stats = Stats(
      train_losses[max_ind],
      train_accuracy[max_ind],
      dev_losses[max_ind],
      dev_accuracy[max_ind],
      epoch_holder[max_ind],
      bs, lr, alpha,
      max_accuracy
  )

  # optionally print model results
  if verbose in [1,2]:
    print('\n ######## \n')
    print(f'bs:{stats.bs}, lr:{stats.lr}, alpha:{stats.alpha} @ epoch {stats.epoch}.')
    print(f'TL:{stats.train_loss}, TA:{stats.train_accuracy}.')
    print(f'DL:{stats.dev_loss}, DA:{stats.dev_accuracy}')

  return stats

### $\color{red}{Joyce:}$

In [None]:
"""
Main Admin
"""
epochs = 40
max_accuracy = 0
path = "class/models/vanilla_moe_joyce_expert.pt"
results = []

"""
init random search
lr [10^-5 - 10^-1]
alpha [10^-5 - 10^-1]
bs [8, 32, 128]
"""
lr_low = -5
lr_high = -1
lr_range = lr_high - lr_low

alpha_low = -5
alpha_high = -1
alpha_range = alpha_high - alpha_low

b_size = 5
b_step = 2

count = 0

"""
Hyperparameter Search
"""

for i in range(3):
  # debug
  print(f'round: {i}')
  print(f'lr_low{lr_low}, lr_high{lr_high}, lr_range{lr_range}')
  print(f'alpha_low{alpha_low}, lr_high{alpha_high}, lr_range{alpha_range}')
  print(f'b_size{b_size}')
  print(f'b_step{b_step}')
  print('max', max_accuracy)

  for j in range(27):
    count += 1
    print(count)

    # get config
    lr, alpha, bs = gen_config(lr_low, lr_high, alpha_low, alpha_high, b_size, b_step)

    # define model
    model = FeedForwardExpert(70,.1) # model with dropout
    model = model.to(device)

    # run training
    res = tv_run(epochs, model, train_dataset_joyce, dev_dataset_joyce, bs, lr, alpha, max_accuracy, path, verbose = 0)
    max_accuracy = res.max_accuracy
    results.append(res)

  # get best result of the round or even so far
  stats = search_stats(results)


  print(stats) # debug

  # reconfigure the new hypers
  lr = np.log10(stats.lr)
  lr_range = lr_range / 3

  alpha = np.log10(stats.alpha)
  alpha_range = alpha_range / 3

  bs = np.log2(stats.bs)

  config = gen_ranges(lr, lr_range, alpha, alpha_range, bs, i + 1)
  lr_low, lr_high, alpha_low, alpha_high, b_size, b_step = config
  lr_range = lr_high - lr_low
  alpha_range = alpha_high - alpha_low


round: 0
lr_low-5, lr_high-1, lr_range4
alpha_low-5, lr_high-1, lr_range4
b_size5
b_step2
max 0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
Stats(train_loss=0.09397064116701745, train_accuracy=0.9803444910352805, dev_loss=1.1934909105300904, dev_accuracy=0.6994318181818182, epoch=6, bs=128, lr=0.002514, alpha=0.004611, max_accuracy=0.6994318181818182)
round: 1
lr_low-3.2663013933167275, lr_high-1.9329680599833945, lr_range1.333333333333333
alpha_low-3.002871544447259, lr_high-1.669538211113926, lr_range1.333333333333333
b_size7.0
b_step1
max 0.6994318181818182
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
Stats(train_loss=0.019017700971872137, train_accuracy=0.9994517543859649, dev_loss=1.2071991205215453, dev_accuracy=0.6995738636363636, epoch=16, bs=128, lr=0.000862, alpha=0.002853, max_accuracy=0.6995738636363636)
round: 2
lr_low-3.2867149563975095, lr_high-2.842270511953065, lr_range0.44444444444444464
alpha_low-2.76692

### $\color{red}{Stoker:}$

In [None]:
"""
Main Admin
"""
epochs = 40
max_accuracy = 0
path = "class/models/vanilla_moe_stoker_expert.pt"
results = []

"""
init random search
lr [10^-5 - 10^-1]
alpha [10^-5 - 10^-1]
bs [8, 32, 128]
"""
lr_low = -5
lr_high = -1
lr_range = lr_high - lr_low

alpha_low = -5
alpha_high = -1
alpha_range = alpha_high - alpha_low

b_size = 5
b_step = 2

count = 0

"""
Hyperparameter Search
"""

for i in range(3):
  # debug
  print(f'round: {i}')
  print(f'lr_low{lr_low}, lr_high{lr_high}, lr_range{lr_range}')
  print(f'alpha_low{alpha_low}, lr_high{alpha_high}, lr_range{alpha_range}')
  print(f'b_size{b_size}')
  print(f'b_step{b_step}')
  print('max', max_accuracy)

  for j in range(27):
    count += 1
    print(count)

    # get config
    lr, alpha, bs = gen_config(lr_low, lr_high, alpha_low, alpha_high, b_size, b_step)

    # define model
    model = FeedForwardExpert(70,.1) # model with dropout
    model = model.to(device)

    # run training
    res = tv_run(epochs, model, train_dataset_stoker, dev_dataset_stoker, bs, lr, alpha, max_accuracy, path, verbose = 0)
    max_accuracy = res.max_accuracy
    results.append(res)

  # get best result of the round or even so far
  stats = search_stats(results)


  print(stats) # debug

  # reconfigure the new hypers
  lr = np.log10(stats.lr)
  lr_range = lr_range / 3

  alpha = np.log10(stats.alpha)
  alpha_range = alpha_range / 3

  bs = np.log2(stats.bs)

  config = gen_ranges(lr, lr_range, alpha, alpha_range, bs, i + 1)
  lr_low, lr_high, alpha_low, alpha_high, b_size, b_step = config
  lr_range = lr_high - lr_low
  alpha_range = alpha_high - alpha_low


round: 0
lr_low-5, lr_high-1, lr_range4
alpha_low-5, lr_high-1, lr_range4
b_size5
b_step2
max 0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
Stats(train_loss=0.8481314434223093, train_accuracy=0.7277559867877788, dev_loss=2.4918209676231657, dev_accuracy=0.42410714285714285, epoch=8, bs=8, lr=0.002578, alpha=4.4e-05, max_accuracy=0.42410714285714285)
round: 1
lr_low-3.2553837536492822, lr_high-1.9220504203159492, lr_range1.333333333333333
alpha_low-5.02321399018048, lr_high-3.6898806568471465, lr_range1.3333333333333335
b_size3.0
b_step1
max 0.42410714285714285
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
Stats(train_loss=0.8030121854231881, train_accuracy=0.7488645747316268, dev_loss=2.2395271488598416, dev_accuracy=0.4330357142857143, epoch=7, bs=8, lr=0.000988, alpha=0.000153, max_accuracy=0.4330357142857143)
round: 2
lr_low-3.2274652776345945, lr_high-2.78302083319015, lr_range0.44444444444444464
alpha_low-4.03753079140

### $\color{red}{Plato:}$

In [None]:
"""
Main Admin
"""
epochs = 40
max_accuracy = 0
path = "class/models/vanilla_moe_plato_expert.pt"
results = []

"""
init random search
lr [10^-5 - 10^-1]
alpha [10^-5 - 10^-1]
bs [8, 32, 128]
"""
lr_low = -5
lr_high = -1
lr_range = lr_high - lr_low

alpha_low = -5
alpha_high = -1
alpha_range = alpha_high - alpha_low

b_size = 5
b_step = 2

count = 0

"""
Hyperparameter Search
"""

for i in range(3):
  # debug
  print(f'round: {i}')
  print(f'lr_low{lr_low}, lr_high{lr_high}, lr_range{lr_range}')
  print(f'alpha_low{alpha_low}, lr_high{alpha_high}, lr_range{alpha_range}')
  print(f'b_size{b_size}')
  print(f'b_step{b_step}')
  print('max', max_accuracy)

  for j in range(27):
    count += 1
    print(count)

    # get config
    lr, alpha, bs = gen_config(lr_low, lr_high, alpha_low, alpha_high, b_size, b_step)

    # define model
    model = FeedForwardExpert(70,.1) # model with dropout
    model = model.to(device)

    # run training
    res = tv_run(epochs, model, train_dataset_plato, dev_dataset_plato, bs, lr, alpha, max_accuracy, path, verbose = 0)
    max_accuracy = res.max_accuracy
    results.append(res)

  # get best result of the round or even so far
  stats = search_stats(results)


  print(stats) # debug

  # reconfigure the new hypers
  lr = np.log10(stats.lr)
  lr_range = lr_range / 3

  alpha = np.log10(stats.alpha)
  alpha_range = alpha_range / 3

  bs = np.log2(stats.bs)

  config = gen_ranges(lr, lr_range, alpha, alpha_range, bs, i + 1)
  lr_low, lr_high, alpha_low, alpha_high, b_size, b_step = config
  lr_range = lr_high - lr_low
  alpha_range = alpha_high - alpha_low


round: 0
lr_low-5, lr_high-1, lr_range4
alpha_low-5, lr_high-1, lr_range4
b_size5
b_step2
max 0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
Stats(train_loss=0.024264571722596884, train_accuracy=0.99951171875, dev_loss=1.5249077677726746, dev_accuracy=0.6223958333333333, epoch=8, bs=128, lr=0.00371, alpha=0.000111, max_accuracy=0.6223958333333333)
round: 1
lr_low-3.0972927570516204, lr_high-1.7639594237182874, lr_range1.333333333333333
alpha_low-4.621343687880009, lr_high-3.288010354546676, lr_range1.3333333333333335
b_size7.0
b_step1
max 0.6223958333333333
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
Stats(train_loss=0.024264571722596884, train_accuracy=0.99951171875, dev_loss=1.5249077677726746, dev_accuracy=0.6223958333333333, epoch=8, bs=128, lr=0.00371, alpha=0.000111, max_accuracy=0.6223958333333333)
round: 2
lr_low-2.6528483126071762, lr_high-2.2084038681627316, lr_range0.44444444444444464
alpha_low-4.176899243435565

## $\color{blue}{MoE-Model:}$

In [34]:
# load joyce expert
model_joyce = FeedForwardExpert(70, dropout_rate=0.1)
path_joyce = "class/models/vanilla_moe_joyce_expert.pt"
model_joyce.load_state_dict(torch.load(path_joyce))

# load stoker expert
model_stoker = FeedForwardExpert(70, dropout_rate=0.1)
path_stoker = "class/models/vanilla_moe_stoker_expert.pt"
model_stoker.load_state_dict(torch.load(path_stoker))

# load plato model
model_plato = FeedForwardExpert(70, dropout_rate=0.1)
path_plato = "class/models/vanilla_moe_plato_expert.pt"
model_plato.load_state_dict(torch.load(path_plato))

  model_joyce.load_state_dict(torch.load(path_joyce))
  model_stoker.load_state_dict(torch.load(path_stoker))
  model_plato.load_state_dict(torch.load(path_plato))


<All keys matched successfully>

In [None]:
# freeze joyce expert
for param in model_joyce.parameters():
  param.requires_grad = False

# freeze stoker expert
for param in model_stoker.parameters():
  param.requires_grad = False

# freeze plato
for param in model_stoker.parameters():
  param.requires_grad = False

NameError: name 'model_joyce' is not defined

In [31]:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Router(nn.Module):
    def __init__(self, num_experts, temperature=2):
        super().__init__()
        self.num_experts = num_experts
        self.fc1 = nn.Linear(768, 128)
        self.fc2 = nn.Linear(128, self.num_experts)
        self.temperature = temperature

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x) / self.temperature

        if self.temperature > 1:
          self.temperature *= 0.99
        else:
          self.temperature = 1

        return F.softmax(x, dim=-1)

class MoE(nn.Module):
    def __init__(self, expert_joyce, expert_stoker, expert_plato, temperature=1.2, num_experts=3, output_size=70, dropout_rate=0.11, top_k=1):
        super().__init__()
        self.num_experts = num_experts
        self.dropout_rate = dropout_rate
        self.k = top_k
        self.output_size = output_size
        self.temperature = temperature
        self.experts = nn.ModuleList([expert_joyce, expert_stoker, expert_plato])
        self.router = Router(self.num_experts)

    def forward(self, x):
        # Get routing weights
        routing_weights = self.router(x)  # Shape (bs, num_experts)

        # Sample k experts according to the routing weights
        # Ensure sum of weights is 1 (needed condition for probabilities)
        routing_weights = F.normalize(routing_weights, p=1, dim=-1)

        # Get the indices of experts based on probabilities
        topk_indices = torch.multinomial(routing_weights, num_samples=self.k, replacement=False)
        topk_vals = routing_weights.gather(1, topk_indices)  # Get the probability values for selected experts
        topk_vals_sum = topk_vals.sum(dim=1, keepdim=True)
        topk_vals_normalized = topk_vals / topk_vals_sum  # Normalize values for each sample

        # Initialize an output tensor with zeros
        outputs = torch.zeros(x.size(0), self.output_size, device=x.device)  # Shape (bs, c)

        # Iterate through the experts
        for i in range(self.k):
            expert_indices = topk_indices[:, i]

            for j in range(self.num_experts):
                # Check if the expert j is selected in current batch
                expert_mask = (expert_indices == j)
                if expert_mask.any():
                    expert_weight = topk_vals_normalized[:, i].view(-1, 1) * expert_mask.float().view(-1, 1)

                    # Get output from the expert
                    expert_output = self.experts[j](x)  # Shape (bs, c)

                    # Multiply the output by the corresponding weights and sum up
                    outputs += expert_output * expert_weight  # Shape (bs, c)

        return outputs

In [None]:
"""
Main Admin
"""
epochs = 40
max_accuracy = 0
path = "class/models/vanilla_moe_hard_pre_free.pt"
results = []

"""
init random search
lr [10^-5 - 10^-1]
alpha [10^-5 - 10^-1]
bs [8, 32, 128]
"""
lr_low = -5
lr_high = -1
lr_range = lr_high - lr_low

alpha_low = -5
alpha_high = -1
alpha_range = alpha_high - alpha_low

b_size = 5
b_step = 2

count = 0

"""
Hyperparameter Search
"""

for i in range(3):
  # debug
  print(f'round: {i}')
  print(f'lr_low{lr_low}, lr_high{lr_high}, lr_range{lr_range}')
  print(f'alpha_low{alpha_low}, lr_high{alpha_high}, lr_range{alpha_range}')
  print(f'b_size{b_size}')
  print(f'b_step{b_step}')
  print('max', max_accuracy)

  for j in range(27):
    count += 1
    print(count)

    # get config
    lr, alpha, bs = gen_config(lr_low, lr_high, alpha_low, alpha_high, b_size, b_step)

    # define model
    model = MoE(model_joyce, model_stoker, model_plato, temperature=1.2) # model with dropout
    model = model.to(device)

    # run training
    res = tv_run(epochs, model, train_dataset, dev_dataset, bs, lr, alpha, max_accuracy, path, verbose = 0)
    max_accuracy = res.max_accuracy
    results.append(res)

  # get best result of the round or even so far
  stats = search_stats(results)


  print(stats) # debug

  # reconfigure the new hypers
  lr = np.log10(stats.lr)
  lr_range = lr_range / 3

  alpha = np.log10(stats.alpha)
  alpha_range = alpha_range / 3

  bs = np.log2(stats.bs)

  config = gen_ranges(lr, lr_range, alpha, alpha_range, bs, i + 1)
  lr_low, lr_high, alpha_low, alpha_high, b_size, b_step = config
  lr_range = lr_high - lr_low
  alpha_range = alpha_high - alpha_low

round: 0
lr_low-5, lr_high-1, lr_range4
alpha_low-5, lr_high-1, lr_range4
b_size5
b_step2
max 0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
Stats(train_loss=0.348825791100661, train_accuracy=0.9011666666666667, dev_loss=1.74107817680605, dev_accuracy=0.5574596774193549, epoch=6, bs=32, lr=0.000723, alpha=0.080893, max_accuracy=0.5574596774193549)
round: 1
lr_low-3.8075283693721356, lr_high-2.4741950360388025, lr_range1.333333333333333
alpha_low-1.75875572469474, lr_high-0.4254223913614067, lr_range1.3333333333333335
b_size5.0
b_step1
max 0.5574596774193549
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
Stats(train_loss=0.348825791100661, train_accuracy=0.9011666666666667, dev_loss=1.74107817680605, dev_accuracy=0.5574596774193549, epoch=6, bs=32, lr=0.000723, alpha=0.080893, max_accuracy=0.5574596774193549)
round: 2
lr_low-3.3630839249276914, lr_high-2.9186394804832467, lr_range0.44444444444444464
alpha_low-1.314311280250295

In [None]:
import dill
def save_results_to_file(namedtuples, filename):
    """Saves a list of namedtuples to a specified file using dill."""
    with open(filename, 'wb') as f:
        dill.dump(namedtuples, f)

def load_results_from_file(filename):
    """Loads a list of namedtuples from a specified file using dill."""
    with open(filename, 'rb') as f:
        return dill.load(f)

In [None]:
path = 'class/results/'
save_results_to_file(results, path + 'vanilla_moe_hard_pre_free.pk')

57.4 fixed
with free experts let's see
