In [None]:
import regex as re 
from glob import glob
from mlx_embeddings import load, generate 
from functools import reduce
import mlx.core as mx
import mlx.nn as nn 
import mlx.optimizers as optim 
import numpy as np 
from sklearn.metrics import f1_score

# Root directory of the corpus; the train/test/val file globs below are built from it.
data_dir = "MRDA-Corpus/mrda_data"
# Identifier handed to mlx_embeddings.load (presumably a Hugging Face repo id — confirm).
model_name = "answerdotai/ModernBERT-base"

# Load the encoder and its tokenizer once; both are used as module-level globals by the code below.
model, tokenizer = load(model_name)

def get_dialog(file_paths: list[str]) -> dict[str, list]:
    """Parse transcript files into per-file lists of tokenized utterances.

    Every non-blank line is split on ``|``: the first field is taken as the
    speaker, the second as the utterance text and the last as the dialog-act
    label. The text is tokenized with the module-level ``tokenizer`` and the
    CLS token id is prepended.

    Args:
        file_paths: Paths of the transcript files to read.

    Returns:
        A dict mapping each file path to a list of records with the keys
        ``speaker``, ``dialog_act``, ``input_ids``, ``input_mask`` (all ones)
        and ``position_ids`` (0..len-1), in file order.

    Raises:
        IndexError: If a non-blank line has no ``|`` separator.
    """
    cache = {}

    for file_path in file_paths:
        # The context manager releases the handle even if reading fails;
        # previously every file stayed open for the life of the kernel.
        with open(file_path, 'r') as file:
            lines = file.readlines()

        tuples = []
        for l in lines:
            stripped = l.strip()
            if not stripped:
                # A blank (e.g. trailing) line carries no utterance and
                # would otherwise fail on the field lookup below.
                continue

            ls = stripped.split('|')

            author = ls[0]
            text = ls[1]
            da = ls[-1]

            tokens = tokenizer.tokenize(text)
            tokens = [tokenizer.cls_token_id] + tokenizer.convert_tokens_to_ids(tokens)

            tuples.append({
                'speaker' : author,
                'dialog_act' : da,
                'input_ids' : tokens,
                'input_mask' : [1] * len(tokens),
                'position_ids' : list(range(len(tokens)))
            })

        cache[file_path] = tuples

    return cache

def sweep(set: dict[str, list]) -> int:
    """Return the length of the longest ``input_ids`` sequence in a split.

    ``set`` maps a key to a list of utterance records; the result is 0 when
    the split holds no utterances at all.
    """
    lengths = (
        len(utterance['input_ids'])
        for utterances in set.values()
        for utterance in utterances
    )
    return max(lengths, default=0)

def peak(set: dict[str, list]) -> set[str]:
    """Collect the distinct ``dialog_act`` values found in a split."""
    # The parameter is called `set`, which hides the builtin inside this body
    # (hence `set()` cannot be called here); a set comprehension builds the
    # result without needing that name.
    return {
        utterance['dialog_act']
        for utterances in set.values()
        for utterance in utterances
    }


def pad(set: dict[str, list], max_tokens: int, pad_id: int):
    """Right-pad every utterance to ``max_tokens`` and convert it to arrays.

    Works in place: each row's ``input_ids``, ``input_mask`` and
    ``position_ids`` lists are extended and then replaced by ``mx.array``
    objects. Rows already at or beyond ``max_tokens`` get no padding.
    """
    # Field name paired with the value used to fill its padded tail.
    fill_values = (
        ('input_ids', pad_id),
        ('input_mask', 0),
        ('position_ids', 0),
    )

    for utterances in set.values():
        for row in utterances:
            # Measured once, before any field of this row is rewritten.
            holdover = max_tokens - len(row['input_ids'])
            for field, fill in fill_values:
                row[field] = mx.array(row[field] + ([fill] * holdover))

def convert_dialog(set: dict[str, list], dialogs: list):
    """Replace each row's ``dialog_act`` label with its index in ``dialogs``.

    Works in place on the rows of every list in ``set``.

    Args:
        set: Mapping from a key to a list of utterance records.
        dialogs: Ordered list of labels; a label's position is its class id.
            If a label occurs more than once, its first position is used.

    Raises:
        ValueError: If a row's label does not occur in ``dialogs``.
    """
    # Build the label -> index table once instead of scanning `dialogs` for
    # every row. setdefault keeps the first occurrence, matching list.index.
    index_of = {}
    for position, label in enumerate(dialogs):
        index_of.setdefault(label, position)

    for utterances in set.values():
        for row in utterances:
            label = row['dialog_act']
            try:
                row['dialog_act'] = index_of[label]
            except KeyError:
                # Same exception type list.index raised for unknown labels.
                raise ValueError(f"{label!r} is not in list") from None

# Discover the transcript files of each split. glob returns paths in arbitrary
# order, so they are sorted to keep the order dialogs are visited in stable
# from run to run.
train_fp, test_fp, val_fp = (
    sorted(glob(f"{data_dir}/{split_name}/*.txt"))
    for split_name in ("train", "test", "val")
)
train, test, val = map(get_dialog, (train_fp, test_fp, val_fp))
x, y, z = map(sweep, (train, test, val))

# Pad every split to the longest utterance found in any of them, so all rows
# share one length and can be stacked into batches.
max_tokens = max(x, y, z)
for split in (train, test, val):
    pad(split, max_tokens, tokenizer.pad_token_id)

# Get dialog acts
# The labels are sorted because iterating a set of strings has no stable
# order across interpreter runs; an unsorted list would give each class a
# different index every time the kernel restarts.
classes = sorted(reduce(lambda x, y: x | y, map(peak, (train, test, val))))
num_classes = len(classes)

for split in (train, test, val):
    convert_dialog(split, classes)


# Defined in paper as best chunk size
chunk_size = 96

In [None]:
def loss_fn(model, X: tuple[list[str], mx.array, mx.array, mx.array], y : mx.array):
    """Mean cross-entropy between the model's logits for ``X`` and labels ``y``.

    ``X`` is indexed as (speakers, input ids, input masks, position ids); the
    model itself is called with the speakers as its last argument.
    """
    speakers, ids, masks, pos = X[0], X[1], X[2], X[3]
    per_example = nn.losses.cross_entropy(model(ids, masks, pos, speakers), y)
    return mx.mean(per_example)

def eval_fn(model, X: dict[str, list]) -> tuple[float, float]:
    """Score ``model`` on every dialog in ``X``.

    Each dialog is fed to the model in consecutive windows of ``chunk_size``
    utterances (module-level setting). The predicted class is the arg-max of
    the logits along axis 1.

    Args:
        model: Callable taking (input ids, input masks, position ids,
            speakers) and returning per-utterance logits.
        X: Mapping from a key to a list of padded utterance records.

    Returns:
        A ``(macro_f1, accuracy)`` pair computed over all utterances.

    Raises:
        ValueError: If ``X`` contains no utterances.
    """
    pred_list = []
    y_list = []

    for utterances in X.values():
        for start in range(0, len(utterances), chunk_size):
            # Named `window` so the builtin `slice` is not shadowed.
            window = utterances[start : start + chunk_size]

            speakers = [row['speaker'] for row in window]
            ids = mx.stack([row['input_ids'] for row in window], axis = 0)
            masks = mx.stack([row['input_mask'] for row in window], axis = 0)
            pos = mx.stack([row['position_ids'] for row in window], axis = 0)
            acts = mx.array([row['dialog_act'] for row in window])

            logits = model(ids, masks, pos, speakers)

            pred_list.append(mx.argmax(logits, axis = 1))
            y_list.append(acts)

    if not pred_list:
        # Fail with a clear message instead of an error from concatenating
        # an empty list.
        raise ValueError("eval_fn received no utterances to evaluate")

    pred_list = mx.concat(pred_list, axis = 0)
    y_list = mx.concat(y_list, axis = 0)

    acc = mx.sum((pred_list == y_list), axis = 0) / pred_list.shape[0]

    return f1_score(np.array(y_list), np.array(pred_list), average = 'macro'), acc.item()

# Project-local model class; imported here rather than in the notebook's top import cell.
from dagc.model import DAGC 
# Build the classifier from the loaded encoder and the number of dialog-act classes.
dagc = DAGC(model, num_classes)
# Force the parameters to be computed now (MLX evaluates arrays lazily).
mx.eval(dagc.parameters())

# AdamW with a base learning rate of 1e-3; the training loop lowers the rate when F1 stalls.
optimizer = optim.AdamW(learning_rate = 1e-3, weight_decay = 5e-4)

# Function that returns both the loss and its gradients; step() calls it as (dagc, batch, targets).
loss_and_grad_fn = nn.value_and_grad(dagc, loss_fn)

def step(speakers: list[str], 
         acts: list[int], 
         input_ids: list[mx.array], 
         input_masks: list[mx.array], 
         position_ids: list[mx.array]):
    """Run one optimisation step on a single chunk of utterances.

    The per-utterance arrays are stacked along axis 0 into a batch, the loss
    gradients are computed with the module-level ``loss_and_grad_fn`` and
    applied to ``dagc`` through the module-level ``optimizer``.
    """
    batch = (
        speakers,
        mx.stack(input_ids, axis = 0),
        mx.stack(input_masks, axis = 0),
        mx.stack(position_ids, axis = 0),
    )
    targets = mx.array(acts)

    # Only the gradients are needed here; the loss value is discarded.
    _, grads = loss_and_grad_fn(dagc, batch, targets)
    optimizer.update(dagc, grads)

    # Materialise the updated parameters and optimizer state before returning.
    mx.eval(dagc.parameters(), optimizer.state)


def early_stop(scores: list[float], window: int = 10) -> bool:
    """Report whether the score has stopped improving.

    Looks at the last ``window`` scores and returns True only when none of
    them is higher than the score immediately before it, i.e. there was no
    improvement at any step inside the window. This is the same
    "non-increasing" rule the training loop uses for its learning-rate decay.

    The previous implementation had the comparison inverted: it returned True
    when the scores had never decreased, which stops training exactly while
    the model is still improving.

    Args:
        scores: Score history, oldest first.
        window: Number of most recent scores to inspect (default 10).

    Returns:
        True if training should stop. Always False when fewer than ``window``
        scores are available or when ``window`` is below 2.
    """
    if window < 2 or len(scores) < window:
        return False

    recent = scores[-window:]
    # Any step where the later score beats the earlier one counts as progress.
    return all(later <= earlier for earlier, later in zip(recent, recent[1:]))

# Per-epoch validation F1, kept for the early-stopping check.
scores = []
# Validation F1 since the last learning-rate reduction.
lr_reduce_scores = []
for e in range(100):

    # One pass over the training dialogs, chunk_size utterances per update.
    for utterances in train.values():
        for start in range(0, len(utterances), chunk_size):
            # Named `window` so the builtin `slice` is not overwritten at
            # notebook level for every later cell.
            window = utterances[start : start + chunk_size]
            step(
                [row['speaker'] for row in window],
                [row['dialog_act'] for row in window],
                [row['input_ids'] for row in window],
                [row['input_mask'] for row in window],
                [row['position_ids'] for row in window]
            )

    # Early stopping and learning-rate decay are driven by the validation
    # split. Using the test split for these decisions would tune the run on
    # the data it is finally reported on, so the metrics printed here are
    # validation metrics and the test split is scored once after the loop.
    score, acc = eval_fn(dagc, val)
    scores.append(score)
    lr_reduce_scores.append(score)
    print(f"Epoch {e + 1}, F1-Score: ", score, "Accuracy: ", acc)

    if len(scores) > 14:
        if early_stop(scores):
            print("Stopping early due to no f1 improvement")
            break

    if len(lr_reduce_scores) > 4:
        # Decay the rate after three consecutive epochs without improvement,
        # then start a fresh history so the next decay needs new evidence.
        newest, prev1, prev2, prev3 = (lr_reduce_scores[-n] for n in (1, 2, 3, 4))
        if newest <= prev1 <= prev2 <= prev3:
            optimizer.learning_rate = optimizer.learning_rate * 0.9
            print("Reducing learning rate")
            lr_reduce_scores = []

# Held-out result, computed once after training has finished.
test_score, test_acc = eval_fn(dagc, test)
print("Test F1-Score: ", test_score, "Accuracy: ", test_acc)


