In [None]:
#| include: false
import warnings
# Silence every warning so the rendered notebook output stays clean.
# NOTE(review): a blanket "ignore" also hides deprecation/security warnings
# (e.g. from torch.load); consider filtering specific categories instead.
warnings.filterwarnings("ignore")

This notebook validates and evaluates a trained model that imitates a mental health counselor.

In [None]:
# Necessary imports
from modules.model import *
from modules.dataset_utils import *

import torch
import torch.nn as nn
import torch.functional as F

from torchmetrics.text import Perplexity, BLEUScore
from jiwer import wer

import matplotlib.pyplot as plt
import seaborn as sns

import ipywidgets as widgets
from tqdm import tqdm

In [None]:
#| include: false
# Checkpoint holding both the trained weights ('model_state_dict') and the
# fitted tokenizer object ('tokenizer').
# NOTE(review): absolute local path - point this at your own copy of the checkpoint.
MODEL_PATH = '/home/piotr/Downloads/model_15-06.pt'

# weights_only=False is required because the checkpoint pickles a tokenizer
# object, not only tensors (recent PyTorch releases default to weights_only=True).
# Only load checkpoints from a trusted source: unpickling can execute arbitrary code.
weights = torch.load(MODEL_PATH, map_location=torch.device('cpu'), weights_only=False)

tokenizer = weights['tokenizer']
# Source and target share one tokenizer, hence the same vocabulary size.
src_vocab_size = tokenizer.get_vocab_size()
tgt_vocab_size = tokenizer.get_vocab_size()

# Encoder / decoder sequence lengths; presumably these must match the training
# configuration for load_state_dict to accept the checkpoint - TODO confirm.
MAX_SRC_LEN = 90
MAX_TGT_LEN = 300

model = build_transformer(N = 4, d_model = 512, h = 16, src_vocab_size=src_vocab_size, tgt_vocab_size=tgt_vocab_size, src_seq_len=MAX_SRC_LEN, tgt_seq_len=MAX_TGT_LEN)

model.load_state_dict(weights['model_state_dict'])
# Inference only: evaluation mode disables dropout (if any) so the generations
# below are repeatable. Trailing ';' suppresses the module repr in the notebook.
model.eval();

In [None]:
#| include: false
# Held-out test split used for every evaluation below.
dataset = DS.from_parquet('./MH_test.parquet')
test_ds = MHDataset(dataset, tokenizer, MAX_SRC_LEN, MAX_TGT_LEN)

# Seed for picking the examples shown by the decoding demo, so reruns show the same samples.
SEED = 42

# Metrics are aggregated over the whole split, so shuffling only adds
# run-to-run variation in batch composition (and thus in batch-averaged metrics).
test_loader = DataLoader(test_ds, batch_size = 16, shuffle = False)
# decode_batches asserts batch_size == 1; the seeded generator keeps the
# random sample of shown examples reproducible.
test_dec_loader = DataLoader(test_ds, batch_size = 1, shuffle = True,
                             generator = torch.Generator().manual_seed(SEED))

# Human rating

While evaluation metrics are crucial for assessing large language models (LLMs), the ultimate measure of their value lies in their practical usefulness for specific tasks. Metrics provide a quantitative snapshot of performance, but they may not capture all nuances (especially for an open-ended task like this one). Therefore, it's essential to also review and analyze the results manually to ensure they meet the real-world needs and expectations of users.

Let's generate some responses using the *greedy decoding* or *beam search* method.

In [None]:
def greedy_decode(model, source, source_mask, tokenizer, max_len, device):
    """Generate a sequence by always taking the highest-scoring next token.

    Args:
        model: seq2seq transformer exposing ``encode``, ``decode`` and ``project``.
        source: encoder input ids; the batch dimension must be 1.
        source_mask: encoder mask matching ``source``.
        tokenizer: tokenizer providing the ``[SOS]`` / ``[EOS]`` ids.
        max_len: maximum length of the returned sequence, ``[SOS]`` included.
        device: device on which decoder tensors are created.

    Returns:
        1-D tensor of token ids starting with ``[SOS]``; it ends with ``[EOS]``
        unless ``max_len`` was reached first.
    """
    sos_idx = tokenizer.token_to_id('[SOS]')
    eos_idx = tokenizer.token_to_id('[EOS]')

    # Pure inference: skip building an autograd graph (saves memory on long outputs).
    with torch.no_grad():
        # Precompute the encoder output and reuse it for every step
        encoder_output = model.encode(source, source_mask)
        # Initialize the decoder input with the sos token
        decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)

        # `<` rather than `==` so that max_len <= 1 cannot run past the limit
        while decoder_input.size(1) < max_len:
            # target mask for the tokens generated so far (presumably causal)
            deco_mask = decoder_mask(decoder_input.size(1)).type_as(source_mask).to(device)

            out = model.decode(encoder_output, source_mask, decoder_input, deco_mask)

            # score only the last position and keep the best token id as a Python int
            prob = model.project(out[:, -1])
            next_word = torch.max(prob, dim=1).indices.item()

            decoder_input = torch.cat(
                [decoder_input, torch.empty(1, 1).type_as(source).fill_(next_word).to(device)], dim=1
            )

            if next_word == eos_idx:
                break

    return decoder_input.squeeze(0)



def beam_search_decode(model, beam_size, source, source_mask, tokenizer, max_len, device):
    """Generate a sequence with beam search over summed token log-probabilities.

    Hypotheses that already ended with ``[EOS]`` stay in the beam with their
    final score instead of being discarded, so a short, likely answer can win
    against longer continuations.

    Args:
        model: seq2seq transformer exposing ``encode``, ``decode`` and ``project``.
        beam_size: number of hypotheses kept after every step.
        source: encoder input ids; the batch dimension must be 1.
        source_mask: encoder mask matching ``source``.
        tokenizer: tokenizer providing the ``[SOS]`` / ``[EOS]`` ids.
        max_len: the search stops once any hypothesis reaches this length.
        device: device on which decoder tensors are created.

    Returns:
        1-D tensor of token ids of the best-scoring hypothesis, starting with ``[SOS]``.
    """
    sos_idx = tokenizer.token_to_id('[SOS]')
    eos_idx = tokenizer.token_to_id('[EOS]')

    def is_finished(cand):
        # a hypothesis is complete once its last token is [EOS]
        return cand[0, -1].item() == eos_idx

    with torch.no_grad():
        # Precompute the encoder output and reuse it for every step
        encoder_output = model.encode(source, source_mask)
        decoder_initial_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)

        # (token ids of shape (1, length), cumulative log-probability)
        candidates = [(decoder_initial_input, 0.0)]

        while True:
            # `>=` so that max_len <= 1 cannot push the search past the limit
            if any(cand.size(1) >= max_len for cand, _ in candidates):
                break

            new_candidates = []
            for candidate, score in candidates:
                if is_finished(candidate):
                    # keep finished hypotheses competing with their final score
                    new_candidates.append((candidate, score))
                    continue

                candidate_mask = decoder_mask(candidate.size(1)).type_as(source_mask).to(device)
                out = model.decode(encoder_output, source_mask, candidate, candidate_mask)
                # project() output is used as logits elsewhere in this notebook
                # (CrossEntropyLoss / Perplexity); log_softmax turns logits into
                # log-probabilities and leaves already-normalised log-probs unchanged.
                log_probs = torch.log_softmax(model.project(out[:, -1]), dim=-1)
                topk_logp, topk_idx = torch.topk(log_probs, beam_size, dim=1)
                for i in range(beam_size):
                    token = topk_idx[0][i].view(1, 1)
                    new_candidate = torch.cat([candidate, token], dim=1)
                    # log-probabilities add up along the sequence
                    new_candidates.append((new_candidate, score + topk_logp[0][i].item()))

            # keep only the best `beam_size` hypotheses
            candidates = sorted(new_candidates, key=lambda x: x[1], reverse=True)[:beam_size]

            if all(is_finished(cand) for cand, _ in candidates):
                break

    # squeeze(0) keeps a 1-D result even for a single-token sequence
    return candidates[0][0].squeeze(0)

In [None]:
def prepare_input(input_text, tokenizer, max_len=None):
    """Tokenize ``input_text`` into a fixed-length encoder input and its padding mask.

    The token ids are wrapped in ``[SOS]`` ... ``[EOS]``, truncated if they do
    not fit, and right-padded with ``[PAD]`` up to ``max_len`` tokens.

    Args:
        input_text: raw user message.
        tokenizer: tokenizer exposing ``encode(text).ids`` and ``token_to_id``.
        max_len: total encoder length including ``[SOS]``/``[EOS]``;
            defaults to the notebook-level ``MAX_SRC_LEN``.

    Returns:
        dict with ``'encoder_input'`` (int64, shape ``(1, max_len)``) and
        ``'encoder_mask'`` (int, shape ``(1, 1, 1, max_len)``; 1 = real token, 0 = padding).
    """
    if max_len is None:
        max_len = MAX_SRC_LEN

    sos_token = torch.tensor([tokenizer.token_to_id('[SOS]')], dtype = torch.int64)
    eos_token = torch.tensor([tokenizer.token_to_id('[EOS]')], dtype = torch.int64)
    pad_token = torch.tensor([tokenizer.token_to_id('[PAD]')], dtype = torch.int64)

    # Reserve two slots for [SOS] and [EOS]; longer inputs are truncated.
    input_tokens = tokenizer.encode(input_text).ids[:max_len - 2]
    padding_len = max_len - 2 - len(input_tokens)

    encoder_input = torch.cat(
        [
            sos_token,
            torch.tensor(input_tokens, dtype = torch.int64),
            eos_token,
            pad_token.repeat(padding_len),  # empty when the input fills the window
        ]
    ).unsqueeze(0)  # add batch dim -> (1, max_len)

    # (1, max_len) -> (1, 1, 1, max_len)
    encoder_mask = (encoder_input != pad_token).unsqueeze(0).unsqueeze(0).int()

    return {'encoder_input': encoder_input,
            'encoder_mask': encoder_mask}


def generate_text(input_text: str, model: Transformer, tokenizer: Tokenizer, output_len: int = 50, method: str = 'greedy', device: str = 'cpu', beam_size: int = 3):
    """Generate the model's reply to ``input_text``.

    Args:
        input_text: user message.
        model: trained transformer.
        tokenizer: tokenizer used for encoding the input and decoding the reply.
        output_len: maximum number of generated tokens, ``[SOS]`` included.
        method: ``'greedy'`` for greedy decoding; any other value selects beam search.
        device: device the model lives on.
        beam_size: beam width used when ``method`` is not ``'greedy'``.

    Returns:
        The decoded reply string.
    """
    # Like the evaluation helpers in this notebook, generate in eval mode (no dropout).
    model.eval()

    encoder_ins = prepare_input(input_text, tokenizer)
    # prepare_input builds CPU tensors; move them next to the model.
    encoder_input = encoder_ins['encoder_input'].to(device)
    encoder_mask = encoder_ins['encoder_mask'].to(device)

    with torch.no_grad():
        if method == 'greedy':
            model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer, output_len, device)
        else:
            model_out = beam_search_decode(model, beam_size, encoder_input, encoder_mask, tokenizer, output_len, device)

    return tokenizer.decode(model_out.tolist())

In [None]:
# Hand-written prompts: small talk plus two typical counselling topics.
prompts = {
    'GREETINGS': "Hello there! How are you?",
    'INTRO': "What is your name? Tell me something about yourself.",
    'SADGE': "I have problems with my boss in my job. He gives me too many task and I'm overwhelmed by work.",
    'SADGE2': "I don't know what I want to do in my life",
}

# Identical settings for every prompt: greedy decoding, at most 30 tokens, on CPU.
responses = {
    label: generate_text(prompt, model = model, method = 'greedy', tokenizer=tokenizer, output_len=30, device = 'cpu')
    for label, prompt in prompts.items()
}

for label, prompt in prompts.items():
    print('*' * 40, label, '*' * 40)
    print('INPUT:', prompt)
    print('OUTPUT:', responses[label])

## Decode test dataset 

In [None]:
def decode_batches(model, validation_ds, tokenizer, max_len, device, num_examples=2, beam_size=2):
    """Decode the first ``num_examples`` samples of ``validation_ds`` with beam search and print them.

    Args:
        model: trained transformer.
        validation_ds: iterable of batches of size 1 with ``encoder_input``,
            ``encoder_mask``, ``src_text`` and ``tgt_text`` entries.
        tokenizer: tokenizer used to decode the generated ids.
        max_len: maximum length of each generated sequence.
        device: device the batches are moved to.
        num_examples: number of samples to decode; nothing is decoded if it is < 1.
        beam_size: beam width passed to ``beam_search_decode``.

    Returns:
        dict of parallel lists ``'source'``, ``'expected'`` and ``'predicted'``.
    """
    model.eval()

    source_texts = []
    expected = []
    predicted = []

    if num_examples < 1:
        return {'source': source_texts, 'expected': expected, 'predicted': predicted}

    with torch.no_grad():
        for count, batch in enumerate(validation_ds, start=1):
            encoder_input = batch["encoder_input"].to(device) # (b, seq_len)
            encoder_mask = batch["encoder_mask"].to(device) # (b, 1, 1, seq_len)

            # the decoders below handle exactly one sequence at a time
            assert encoder_input.size(0) == 1, "Batch size must be 1 for validation"

            model_out = beam_search_decode(model, beam_size, encoder_input, encoder_mask, tokenizer, max_len, device)
            source_text = batch["src_text"][0]
            target_text = batch["tgt_text"][0]
            model_out_text = tokenizer.decode(model_out.detach().cpu().tolist())

            source_texts.append(source_text)
            expected.append(target_text)
            predicted.append(model_out_text)

            # Print the source, target and model output
            print('-'*80)
            print(f"{f'SOURCE: ':>12}{source_text}")
            print(f"{f'TARGET: ':>12}{target_text}")
            print(f"{f'PREDICTED: ':>12}{model_out_text}")

            if count >= num_examples:
                print('-'*80)
                break

    return {'source': source_texts, 'expected': expected, 'predicted': predicted}


# Assigned so the returned lists are kept for inspection instead of being echoed as cell output.
decoded_examples = decode_batches(model, test_dec_loader, tokenizer, 20, 'cpu', 10)

The responses generated by the model are quite satisfying. They maintain a strong alignment with the subject matter of the inputs, ensuring relevance. Additionally, the sentences are well-constructed, and the lengths of the outputs are appropriate, providing a balanced and coherent response.

# Evaluation on test dataset

It is finally time to compute metrics on the test dataset.

Finding valid metrics for this task is not easy. We will observe:

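- `perplexity` - measures how well a language model predicts a text sample. It is the exponential of the average per-token negative log-likelihood, $\mathrm{PPL} = \exp\left(-\frac{1}{N}\sum_{i=1}^{N} \log p(w_i \mid w_{<i})\right)$, equivalently $2^{H}$ where $H$ is the average number of bits per token (the cross-entropy). Lower perplexity implies better performance in the context of predicting text,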

- `word error rate` (WER) - a metric borrowed from automatic speech recognition. It is the minimum number of word substitutions $S$, deletions $D$ and insertions $I$ needed to turn the prediction into the reference, divided by the number of words $N$ in the reference: $\mathrm{WER} = \frac{S + D + I}{N}$. The lower the value, the better; a WER of 0 is a perfect match (values above 1 are possible),

- `token accuracy` - measures the percentage of tokens (words or subwords) that the model predicts correctly when it is given the reference prefix (teacher forcing). Higher token accuracy indicates better performance.

In [None]:
ppx = Perplexity(ignore_index=tokenizer.token_to_id('[PAD]'))

def evaluate_metrics(model:Transformer, loss_fn, test_loader:DataLoader, tokenizer:Tokenizer, device:torch.device) -> dict:
    """Compute teacher-forced test metrics for ``model`` over ``test_loader``.

    Each batch runs through the model with the reference ``decoder_input``
    (teacher forcing). Predictions are the argmax of the projected logits.
    Padding positions of ``label`` are excluded from token accuracy and WER,
    matching the ``ignore_index`` of the ``ppx`` perplexity metric.

    Args:
        model: trained transformer exposing ``encode``, ``decode`` and ``project``.
        loss_fn: criterion applied to the flattened logits and labels.
        test_loader: DataLoader yielding dicts with ``encoder_input``, ``decoder_input``,
            ``encoder_mask``, ``decoder_mask`` and ``label`` tensors.
        tokenizer: tokenizer providing the vocabulary size, ``[PAD]`` id and decoding.
        device: device the batches are moved to.

    Returns:
        dict of 0-dim float64 tensors:
        ``'Token_accuracy'``: mean per-sample accuracy over non-padding positions;
        ``'Perplexity'``: perplexity over all non-padding tokens of the test set;
        ``'WER'``: mean per-sample word error rate;
        ``'Loss'``: mean of the per-batch ``loss_fn`` values.
    """
    model.eval()

    pad_id = tokenizer.token_to_id('[PAD]')
    vocab_size = tokenizer.get_vocab_size()

    # Fresh metric state on every call, so re-running the cell does not mix in
    # earlier runs; accumulating over the whole set (instead of averaging
    # per-batch perplexities) makes the result independent of batch composition.
    ppx.reset()
    ppx.to(device)

    wers = []
    token_accuracies = []
    losses = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Validation step"):
            encoder_input = batch['encoder_input'].to(device)
            decoder_input = batch['decoder_input'].to(device)
            encoder_mask = batch['encoder_mask'].to(device)
            dec_mask = batch['decoder_mask'].to(device)
            label = batch['label'].to(device) # ground truth

            encoder_output = model.encode(encoder_input, encoder_mask)
            decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, dec_mask)
            proj_output = model.project(decoder_output) # logits

            loss = loss_fn(proj_output.reshape(-1, vocab_size), label.reshape(-1))
            losses.append(loss.item())

            ppx.update(proj_output, label)

            pred_tokens = torch.argmax(proj_output, dim=-1)
            for preds, target in zip(pred_tokens, label):
                # Only score positions that hold real target tokens.
                keep = target != pad_id
                n_real = int(keep.sum().item())
                if n_real == 0:
                    continue
                preds_real = preds[keep]
                target_real = target[keep]

                # 1) Token accuracy
                token_accuracies.append((preds_real == target_real).sum().item() / n_real)

                # 2) WER
                pred_text = tokenizer.decode(preds_real.tolist())
                target_text = tokenizer.decode(target_real.tolist())
                try:
                    wers.append(wer(target_text, pred_text))
                except ValueError:
                    # presumably raised by jiwer for an empty reference string; skip the sample
                    pass

    return {'Token_accuracy': torch.tensor(token_accuracies, dtype = float).mean(),
            'Perplexity': ppx.compute().detach().cpu().to(torch.float64),
            'WER': torch.tensor(wers, dtype = float).mean(),
            'Loss': torch.tensor(losses, dtype = float).mean()}

In [None]:
device = 'cpu'

# Ignore padded target positions, as the `ppx` perplexity metric does, so the
# reported loss is computed over real (non-padding) tokens only.
loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.token_to_id('[PAD]')).to(device)

evaluate_metrics(model, loss_fn = loss_fn, test_loader=test_loader, tokenizer=tokenizer, device = device)

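Token accuracy of 0.9032 means that, given the reference prefix, the model correctly predicts the next token (word or subword) about 90.32% of the time. This is a relatively high accuracy, indicating that the model is quite effective at predicting the correct tokens. Keep in mind that this figure is sensitive to whether padding positions are counted; if they are, the many padding positions can distort it considerably.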
However, there is also a downside. While token accuracy is high, it doesn't account for the overall context or semantic meaning of the responses. In mental health counseling, understanding and generating contextually and emotionally appropriate responses is crucial, and token accuracy alone may not capture this.

A perplexity of 85.1507 means that, on average, the model is as uncertain about the next token as if it were choosing uniformly among about 85 candidates. Lower perplexity values are preferred, as they indicate more confident and precise predictions. However, a perplexity around 85 on a relatively small dataset suggests that the model is still learning and has room for improvement. High perplexity might mean the model struggles to generate contextually coherent responses. In the context of mental health, this could lead to responses that are less tailored to the nuances of a patient's input.

A WER of 0.6263 means that the word substitutions, deletions and insertions needed to turn the generated response into the expected one amount to about 62.63% of the expected response's length. In our case this is quite high, suggesting that, roughly speaking, more than half of the model's output differs from the reference.





Considering the relatively small training dataset, the model's performance is promising, but for practical deployment in mental health counseling, additional training with more data and fine-tuning might be necessary to enhance its contextual and semantic capabilities.

# Visualization of attention layers

Visualizing attention layers of an encoder helps understand model behavior, debug issues, and verify learning. It provides insights into how the model processes information, identifies where it focuses during predictions, and aids in diagnosing errors or biases in its decision-making. This transparency enhances interpretability and trustworthiness of the model's outputs.

In [None]:
def get_attention_scores(input_text, model: Transformer, tokenizer, layer: int, head: int):
    """Return one head's encoder self-attention and decoder cross-attention maps.

    A short generation pass is run first. The maps are then read from the
    ``attention_scores`` attribute of the attention blocks in ``layer``
    (presumably cached there by that forward pass - confirm in modules.model).

    Returns:
        dict with ``'encoder_attention_scores'`` and ``'cross_attention_scores'``:
        the cached scores indexed ``[:, head, :, :]``, i.e. the head axis
        (dim 1) is removed.
    """
    # output_len=2 runs the encoder plus at most one decoder step after [SOS].
    generate_text(input_text, model, tokenizer, output_len=2, device='cpu')

    enc_block = model.encoder.layers[layer]
    dec_block = model.decoder.layers[layer]  # decoder block (type defined in modules.model)

    scores = {
        'encoder_attention_scores': enc_block.self_attention_block.attention_scores,
        'cross_attention_scores': dec_block.cross_attention_block.attention_scores,
    }
    return {name: attn[:, head, :, :] for name, attn in scores.items()}


def plot_attention(input_text, model, tokenizer, num_layers: int = 4, num_heads: int = 16):
    """Show an interactive heatmap of encoder self-attention for ``input_text``.

    Two sliders select the layer and head. Each change re-runs the model via
    ``get_attention_scores`` and redraws the map restricted to the real
    (non-padding) input tokens, ``[SOS]`` and ``[EOS]`` included.

    Args:
        input_text: message to visualise.
        model: trained transformer.
        tokenizer: tokenizer used to label the axes.
        num_layers: slider range for the layer; 4 matches ``N`` used to build the model.
        num_heads: slider range for the head; 16 matches ``h`` used to build the model.
    """
    # The token layout does not depend on the slider values, so compute it once.
    encoded = prepare_input(input_text, tokenizer)
    # Count the real tokens from the padding mask; a whitespace word count is
    # wrong for a sub-word tokenizer and also missed [EOS].
    input_len = int(encoded['encoder_mask'].sum().item())
    input_ids = encoded['encoder_input'][0, :input_len].tolist()
    decoded_tokens = [tokenizer.id_to_token(tok) for tok in input_ids]

    def plot_layer_head(layer, head):
        attention = get_attention_scores(input_text, model, tokenizer, layer, head)
        attention_encoder = attention['encoder_attention_scores'].squeeze().detach().cpu().numpy()
        attention_encoder = attention_encoder[:input_len, :input_len]

        fig, ax = plt.subplots(figsize=(10, 8))
        sns.heatmap(attention_encoder, xticklabels=decoded_tokens, yticklabels=decoded_tokens,
                    cmap='coolwarm', cbar_kws={"shrink": .8}, ax=ax)
        ax.set_title(f'Attention Scores for Layer {layer}, Head {head}')
        ax.set_xlabel('Tokens')
        ax.set_ylabel('Tokens')
        ax.tick_params(axis='x', labelrotation=90)
        ax.tick_params(axis='y', labelrotation=0)
        plt.show()

    layer_widget = widgets.IntSlider(min=0, max=num_layers-1, value=0, description='Layer:')
    head_widget = widgets.IntSlider(min=0, max=num_heads-1, value=0, description='Head:')

    widgets.interact(plot_layer_head, layer=layer_widget, head=head_widget)

In [None]:
# Interactive encoder self-attention heatmap (layer/head sliders) for a sample counselling prompt.
plot_attention("I'm having trouble with my romantic relationship, what can I do?", model, tokenizer)

# Conclusions

The goals of the project have been achieved. Although the model is not perfect, it gives quite satisfying, on-topic responses.
