In [1]:

import warnings 
from dial2vec_mlx.model import DialogueTransformer
from dial2vec.data import get_sessions 

from mlx_embeddings import load
import mlx.nn as nn 
import mlx.core as mx
import mlx.optimizers as optim 

# Checkpoint identifier handed to `load` (presumably a Hugging Face Hub id - confirm).
# `load` returns a pair that is unpacked into the encoder and its tokenizer.
model_name = "answerdotai/ModernBERT-base"
model, tokenizer = load(model_name)


# Some libraries call deprecated functions; ignore DeprecationWarning to prevent output spam.
# NOTE(review): this filter is installed after the imports and `load(...)` above have
# already run, so any warnings raised by those statements are not suppressed by it.
warnings.filterwarnings("ignore", category=DeprecationWarning) 


# Dataset download: https://drive.google.com/file/d/1KpxQGXg9gvH-2u21bAMykL5N-tpYU2Dr/view?usp=sharing
# Both paths are relative, so they resolve against the notebook's working directory.
training_file = "datasets/doc2dial/train.tsv"
testing_file = "datasets/doc2dial/clustering_test.tsv"

# Hyper-parameters shared by the whole notebook. The dict is expanded as keyword
# arguments into `get_sessions` and passed whole to `DialogueTransformer`, so the
# key names are part of those callables' interfaces.
model_config = {
    # Token budget per sequence. Presumably longer inputs are truncated and shorter
    # ones padded to this length - confirm in `get_sessions`.
    # NOTE(review): the trailing comment points at config.max_position_embeddings,
    # but nothing in this notebook shows that 386 equals that value - verify.
    "max_tokens" : 386,     # config.max_position_embeddings
    # Temperature for the cosine-similarity based loss (presumably used as a
    # divisor on the similarity scores - confirm in DialogueTransformer).
    "temperature" : 1.0, 
    # Number of samples per training step (the training loop slices batches of this size).
    "batch_train_size": 5, 
    # Number of samples per evaluation batch; not read by any cell visible here.
    "batch_test_size" : 5,
    # Number of leading encoder layers to freeze: layers with index 0 .. freeze_upto - 1.
    # The paper freezes the bottom 6 of BERT's 12 layers (~50%).
    # NOTE(review): whether 14 is "close to ~50%" depends on how many layers this
    # checkpoint has - TODO confirm against model.config.
    "freeze_upto" : 14, 
    # Peak learning rate; used both for AdamW and as the peak of the warmup/decay schedule.
    "learning_rate" : 1e-5,
    # NOTE(review): stored as a float and not consumed by the visible training loop,
    # which performs one optimizer update per batch.
    "gradient_accumulation_steps" : 1.0 
}

Fetching 5 files:   0%|          | 0/5 [00:00<?, ?it/s]

In [2]:
# Parse the training and clustering-test TSV files with the same tokenizer, encoder
# config and hyper-parameters. Each element of `features` is later read for its
# `input_ids`, `input_mask`, `position_ids`, `role_ids` and `label_id` attributes.
# NOTE(review): `clustering_test` is not referenced by any other cell visible here.
features = get_sessions(training_file, tokenizer, model.config, **model_config)
clustering_test = get_sessions(testing_file, tokenizer, model.config, **model_config)

In [3]:
# Start from a fully trainable encoder so that re-running this cell always yields
# the same frozen/trainable split, regardless of any earlier freezing.
model.unfreeze()

# Resolve the encoder layer list once: calling `model.modules()` walks the module
# tree, so doing it on every loop iteration repeats the same traversal needlessly.
encoder_layers = model.modules()[0]["model"]["layers"]

# Freeze the bottom `freeze_upto` layers (indices 0 .. freeze_upto - 1). Plain
# indexing is kept on purpose: a `freeze_upto` larger than the number of layers
# raises IndexError instead of silently freezing the whole stack.
for layer_index in range(model_config["freeze_upto"]):
    encoder_layers[layer_index].freeze()


In [4]:
import random 


# Shuffle in place with the standard library: in the author's experience,
# mlx.data overloaded RAM and crashed the OS.
random.shuffle(features)


def _stack_field(field_name):
    """Gather one attribute from every entry of `features` into a single int64 array."""
    return mx.array([getattr(f, field_name) for f in features], dtype = mx.int64)


# Turn ids and segment ids are deliberately left out (believed unnecessary for ModernBERT).
all_input_ids = _stack_field("input_ids")
all_input_mask = _stack_field("input_mask")
all_position_ids = _stack_field("position_ids")
all_role_ids = _stack_field("role_ids")
all_labels = _stack_field("label_id")

# Total number of training samples; drives the batching loop further down.
sample_num = len(features)

In [5]:
# See: https://huggingface.co/transformers/v1.0.0/migration.html

# https://github.com/huggingface/transformers/blob/df99f8c5a1c54d64fb013b43107011390c3be0d5/transformers/optimization.py#L45
# https://github.com/pytorch/pytorch/blob/main/torch/optim/lr_scheduler.py
def linear_warmup_with_decay(step, warmup_steps, total_steps, peak_lr, min_lr = 0.0):
    """Learning rate for `step` under linear warmup followed by linear decay.

    The rate rises linearly from 0 to `peak_lr` over the first `warmup_steps`
    steps, then falls linearly so that it reaches `min_lr` at `total_steps`.

    Args:
        step: Zero-based index of the current optimizer step.
        warmup_steps: Number of steps spent ramping up to `peak_lr`.
        total_steps: Step at which the decay reaches `min_lr`.
        peak_lr: Learning rate at the end of warmup.
        min_lr: Floor of the decay phase; also returned for any step at or
            beyond `total_steps`.

    Returns:
        The learning rate for `step` as a float.
    """
    if step < warmup_steps:
        return peak_lr * (step / warmup_steps)

    decay_steps = total_steps - warmup_steps
    if decay_steps <= 0:
        # Warmup covers the whole run, so there is no decay phase to interpolate
        # over. Return the floor instead of dividing by zero.
        return min_lr

    # Clamp at 0 so that steps past `total_steps` stay at `min_lr` rather than
    # extrapolating the line into a negative learning rate.
    decay_ratio = max(0.0, (total_steps - step) / decay_steps)
    return min_lr + (peak_lr - min_lr) * decay_ratio
    
def create_scheduler(peak_lr, warmup_steps, total_steps):
    """Bind the schedule hyper-parameters and return a `step -> learning rate` callable.

    The returned callable forwards to `linear_warmup_with_decay`, leaving
    `min_lr` at that function's default.
    """
    return lambda step: linear_warmup_with_decay(
        step, warmup_steps, total_steps, peak_lr
    )

In [None]:
from livelossplot import PlotLosses
from numpy import average

# Live loss chart; it is fed from inside the training loop below.
plotlosses = PlotLosses()

# Wrap the partially frozen encoder in the dialogue model.
# NOTE(review): the output saved with this cell is
#   TypeError: DialogueTransformer.__init__() missing 1 required positional argument: 'config'
# so the constructor appears to expect one more argument named `config` than is
# passed here. The signature is not visible in this notebook and the saved output
# may be stale - confirm against dial2vec_mlx.model before changing the call.
dial2vec = DialogueTransformer(model, model_config)
# Force evaluation of the model parameters before training starts.
mx.eval(dial2vec.parameters())

def loss_fn(model, inputs):
    """Run one forward pass and return the scalar training loss.

    Args:
        model: The module to evaluate. `nn.value_and_grad(module, loss_fn)`
            forwards the call arguments to this function unchanged, so this is
            the same module the gradients are taken with respect to.
        inputs: Tuple of positional arguments that is unpacked into the model call.

    Returns:
        The value stored under the 'loss' key of the model's output.
    """
    # Use the `model` argument rather than a notebook global: the forward pass then
    # always runs on the module handed in by the caller, and the function no longer
    # breaks if the global is renamed or rebound.
    out = model(*inputs)
    return out['loss']


# Returns (loss, grads) with gradients taken w.r.t. dial2vec's trainable parameters.
loss_and_grad_fn = nn.value_and_grad(dial2vec, loss_fn)
optimizer = optim.AdamW(learning_rate = model_config['learning_rate'], eps = 1e-6)

epochs = 5
batch = model_config["batch_train_size"]

# The last batch of an epoch may be partial, so the number of optimizer steps per
# epoch is ceil(sample_num / batch). The plain quotient sample_num * epochs / batch
# under-counts whenever sample_num is not a multiple of batch; the schedule would
# then run past its end and produce a negative learning rate.
steps_per_epoch = (sample_num + batch - 1) // batch
full_run = steps_per_epoch * epochs

# Warm up over the first 10% of all optimizer steps, then decay linearly.
scheduler = create_scheduler(model_config['learning_rate'], warmup_steps = int(full_run * 0.1), total_steps = full_run)

step = 0
losses = []
for epoch in range(epochs):

    # range() never yields i == sample_num, so every slice below is non-empty.
    for i in range(0, sample_num, batch):

        # The schedule is applied by hand: set the rate for this step before updating.
        lr = scheduler(step)
        optimizer.learning_rate = lr

        inputs = (
            all_input_ids[i : i + batch, :, :],
            all_input_mask[i : i + batch, :, :],
            all_position_ids[i : i + batch, :, :],
            all_role_ids[i : i + batch, :, :],
            all_labels[i : i + batch]
        )

        loss, grads = loss_and_grad_fn(dial2vec, inputs)
        optimizer.update(dial2vec, grads)
        # Force evaluation of the updated parameters, optimizer state and loss
        # at every step.
        mx.eval(dial2vec.parameters(), optimizer.state, loss)

        losses.append(loss.item())

        # Every 20 steps, plot the mean loss since the previous plot point.
        if step % 20 == 0:
            plotlosses.update({'loss' : average(losses)})
            plotlosses.send()
            losses = []

        step += 1



TypeError: DialogueTransformer.__init__() missing 1 required positional argument: 'config'