### Import all necessary packages 

In [None]:
import torch
import io
import torch.nn.functional as F
import random
import numpy as np
import time
import math
import datetime
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
import os
from transformers import AutoModel,AutoTokenizer

### Helper Functions

In [None]:
def get_qc_examples(input_file, num_sentences=None):
    """Read tab-separated ``(text, labels)`` examples from *input_file*.

    Each line is split on tab characters: the first field is the sentence and
    the second field (minus its line terminator) is the label string.  In this
    notebook the label string is a space-separated list of 0/1 flags, one per
    whitespace-separated word of the sentence.

    Args:
        input_file: Path of the UTF-8 TSV file to read.
        num_sentences: ``0`` returns an empty list; a positive int keeps only
            the first ``num_sentences`` lines; ``None`` (default) keeps them all.

    Returns:
        A list of ``(text, labels)`` string tuples in file order.

    Raises:
        IndexError: if a line has no tab character.
    """
    examples = []

    # UTF-8 explicitly: the data is multilingual and the platform default
    # encoding is not guaranteed to decode it.
    with open(input_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    if num_sentences == 0:
        return examples
    if num_sentences:
        lines = lines[:num_sentences]

    for line in lines:
        fields = line.split("\t")
        text = fields[0]
        # Strip only the line terminator. A fixed `[:-1]` slice would chop a
        # real label character off a final line that has no trailing newline.
        labels = fields[1].rstrip("\r\n")
        examples.append((text, labels))

    return examples

In [None]:
def generate_data_loader(input_examples, label_masks, label_map, do_shuffle = False, balance_label_examples = False):
    '''
    Generate a DataLoader for token-level disfluency tagging.

    Each example is a ``(sentence, labels)`` pair, where ``labels`` is a
    space-separated string with one integer flag per whitespace-separated word
    of ``sentence``. Examples whose mask is False are treated as NOT labeled:
    their label string is never read, and every real token gets the label 0.

    Relies on the notebook globals ``tokenizer``, ``max_seq_length``,
    ``PAD_VALUE`` and ``batch_size``.

    Args:
        input_examples: sequence of ``(sentence, labels)`` tuples.
        label_masks: sequence of booleans, one per example; True = labeled.
        label_map: unused; kept so existing call sites keep working.
        do_shuffle: use a RandomSampler instead of a SequentialSampler.
        balance_label_examples: when only part of the data is labeled, repeat
            each labeled example ``max(1, int(log2(int(1 / labeled_rate))))`` times.

    Returns:
        A DataLoader yielding batches of
        ``(input_ids, attention_mask, label_ids, label_mask, label_ids01)``.
        ``label_ids`` holds ``PAD_VALUE`` at special/padding positions, and
        ``label_ids01`` replaces those positions with 0.
    '''
    examples = _balance_examples(input_examples, label_masks, balance_label_examples)

    input_ids = []
    input_mask_array = []
    label_mask_array = []
    label_id_array = []
    label_id01_array = []

    for (text, label_mask) in examples:
        # text[0] is the sentence; text[1] holds the per-word 0/1 flags.
        # Tokenize the pre-split words ONCE. The ids, the attention mask and the
        # subword->word alignment then all come from the same encoding, so the
        # labels are guaranteed to line up with input_ids. The tokenizer's own
        # attention mask replaces `token_id > 0`, which is only correct for
        # models whose pad id is 0.
        encoding = tokenizer([text[0].split()], max_length=max_seq_length, padding="max_length",
                             truncation=True, is_split_into_words=True)
        input_ids.append(encoding["input_ids"][0])
        input_mask_array.append(encoding["attention_mask"][0])
        label_mask_array.append(label_mask)

        word_labels = text[1].split() if label_mask else None
        label_ids = _align_word_labels(encoding.word_ids(batch_index=0), word_labels)
        label_id_array.append(label_ids)
        label_id01_array.append([0 if lid == PAD_VALUE else lid for lid in label_ids])

    input_ids = torch.tensor(input_ids) # ids per sentence, padded to max_seq_length
    input_mask_array = torch.tensor(input_mask_array) # 1 for real tokens, 0 for padding
    label_id_array = torch.tensor(label_id_array, dtype=torch.long) # PAD_VALUE for special/padding, else the word label
    label_id01_array = torch.tensor(label_id01_array, dtype=torch.long) # same, with PAD_VALUE mapped to 0
    label_mask_array = torch.tensor(label_mask_array) # True for labeled sentences, False otherwise

    dataset = TensorDataset(input_ids, input_mask_array, label_id_array, label_mask_array, label_id01_array)

    sampler = RandomSampler if do_shuffle else SequentialSampler

    return DataLoader(
              dataset,
              sampler = sampler(dataset),
              batch_size = batch_size)


def _balance_examples(input_examples, label_masks, balance_label_examples):
    """Pair each example with its mask, optionally replicating labeled examples."""
    num_labeled_examples = sum(1 for label_mask in label_masks if label_mask)
    # An empty example list means there is nothing to balance (avoids 0 division).
    label_mask_rate = num_labeled_examples / len(input_examples) if len(input_examples) else 1

    examples = []
    for index, ex in enumerate(input_examples):
        label_mask = label_masks[index]
        copies = 1
        if balance_label_examples and label_mask_rate != 1 and label_mask:
            # Simulate more labeled data: log2 of the unlabeled:labeled ratio, at least 1.
            copies = max(1, int(math.log(int(1 / label_mask_rate), 2)))
        examples.extend([(ex, label_mask)] * copies)
    return examples


def _align_word_labels(word_ids, word_labels):
    """Project word-level labels onto subword positions.

    Positions with word id None (special tokens and padding) get PAD_VALUE.
    Every subword of a word, first or not, gets that word's label. When
    *word_labels* is None (an unlabeled example), every real position gets 0.
    """
    label_ids = []
    for word_idx in word_ids:
        if word_idx is None:
            label_ids.append(PAD_VALUE)
        elif word_labels is None:
            label_ids.append(0)  # placeholder; the value is irrelevant for unlabeled rows
        else:
            label_ids.append(int(word_labels[word_idx]))
    return label_ids

In [None]:
def format_time(elapsed):
    """Render a duration in seconds as text via ``str(datetime.timedelta)``.

    The value is rounded to a whole number of seconds first, which yields
    ``H:MM:SS`` (with a ``N day(s), `` prefix at 24 hours or more).
    """
    whole_seconds = int(round(elapsed))
    as_delta = datetime.timedelta(seconds=whole_seconds)
    return str(as_delta)

### Define parameters for data usage and training

In [None]:
# Seed every RNG this notebook uses so that runs are reproducible.
seed_val = 42
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)

# Preferred GPU. Fall back to cuda:0 when the machine has fewer GPUs, instead of
# failing later on `.to(device)` with an invalid device ordinal.
CUDA_DEVICE_INDEX = 3
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed_val)
    if torch.cuda.device_count() > CUDA_DEVICE_INDEX:
        device = torch.device(f"cuda:{CUDA_DEVICE_INDEX}")
    else:
        device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


EXPERIMENT_ID= 'presto-english-small-230523-2'
CHECKPOINT_DIR = f'checkpoints/{EXPERIMENT_ID}'
transformer_model_name = "google/muril-base-cased" # Check Hugging Face for experimenting with other multilingual transformers


label_list = ["1", "0"]
PAD_VALUE = 2  # label id used for special-token / padding positions

labeled_file = "./data/sample_data/sample-train-labeled.tsv"
unlabeled_file = "./data/sample_data/sample-train-unlabeled.tsv"
valid_filename = "./data/sample_data/sample-valid-labeled.tsv"
test_filename = "./data/sample_data/sample-test-labeled.tsv"

# Use every line of each training file by default. You can also specify a fixed
# number like 5000 or 10000. Files are closed promptly via `with`.
with open(labeled_file, 'r', encoding='utf-8') as f:
    NUM_TRAINING_SENTENCES = sum(1 for _ in f)
with open(unlabeled_file, 'r', encoding='utf-8') as f:
    NUM_UNLABELED_SENTENCES = sum(1 for _ in f)

print("Labeled file:", labeled_file)
print("Unlabeled file:", unlabeled_file)
print("Valid file:", valid_filename)
print("Test file:", test_filename)


### Define transformer parameters

In [None]:
#--------------------------------
#  Transformer parameters
#--------------------------------
#--------------------------------
#  Transformer parameters
#--------------------------------
max_seq_length = 64        # tokens per sentence after padding / truncation
batch_size = 16

num_hidden_layers_d = 1    # number of hidden layers in the discriminator
noise_size = 100           # size of the generator's input noisy vectors
out_dropout_rate = 0.4     # dropout to be applied to discriminator's input vectors
apply_balance = True       # replicate labeled data to balance poorly represented datasets

#--------------------------------
#  Optimization parameters
#--------------------------------
learning_rate_discriminator = 2e-5
learning_rate_generator = 2e-5
epsilon = 1e-8             # small constant added inside the GAN log-losses
num_train_epochs = 40

# Scheduler (constant LR after a linear warmup over this fraction of steps)
apply_scheduler = False
warmup_proportion = 0.1

# Logging: report progress every this many batches
print_each_n_step = 10


### Load transformer, tokenizer and data for training/testing

In [None]:
# Tokenizer and encoder weights come from the same Hugging Face checkpoint name.
tokenizer = AutoTokenizer.from_pretrained(transformer_model_name)
transformer = AutoModel.from_pretrained(transformer_model_name)

In [None]:
# Read all four splits (evaluated left to right); only the training splits are capped.
labeled_examples, unlabeled_examples, valid_examples, test_examples = (
    get_qc_examples(labeled_file, num_sentences=NUM_TRAINING_SENTENCES),
    get_qc_examples(unlabeled_file, num_sentences=NUM_UNLABELED_SENTENCES),
    get_qc_examples(valid_filename),
    get_qc_examples(test_filename),
)

print(
    f"Labeled Training Examples: {len(labeled_examples)}, "
    f"Unlabeled Training Examples: {len(unlabeled_examples)}, "
    f"Validation Examples: {len(valid_examples)}, "
    f"Test Examples: {len(test_examples)}"
)

In [None]:
# Label string -> class index, following label_list order.
label_map = {label: i for i, label in enumerate(label_list)}

# Training set: labeled examples (mask True), then any unlabeled ones (mask False).
train_examples = labeled_examples
train_label_masks = np.ones(len(labeled_examples), dtype=bool)
if unlabeled_examples:
    tmp_masks = np.zeros(len(unlabeled_examples), dtype=bool)
    train_examples = train_examples + unlabeled_examples
    train_label_masks = np.concatenate([train_label_masks, tmp_masks])

train_dataloader = generate_data_loader(
    train_examples, train_label_masks, label_map,
    do_shuffle=True, balance_label_examples=apply_balance,
)

# Validation and test sets are fully labeled, kept in file order and never rebalanced.
valid_label_masks = np.ones(len(valid_examples), dtype=bool)
valid_dataloader = generate_data_loader(
    valid_examples, valid_label_masks, label_map,
    do_shuffle=False, balance_label_examples=False,
)

test_label_masks = np.ones(len(test_examples), dtype=bool)
test_dataloader = generate_data_loader(
    test_examples, test_label_masks, label_map,
    do_shuffle=False, balance_label_examples=False,
)

### Define model based on hyperparameters

In [None]:
class Generator(nn.Module):
    """Transformer-decoder generator that turns noise into fake token representations.

    Args:
        noise_size: accepted for interface compatibility but not used; the
            noise tensor passed to ``forward`` must already have ``hidden_size``
            features.
        hidden_size: ``d_model`` of every decoder layer; must be divisible by
            ``nhead``.
        dropout_rate: accepted for interface compatibility but NOT forwarded;
            the decoder layers keep PyTorch's default dropout.
        nhead: attention heads per decoder layer (default 8).
        num_layers: number of stacked decoder layers (default 6).
    """

    def __init__(self, noise_size=100, hidden_size=512, dropout_rate=0.1, nhead=8, num_layers=6):
        super(Generator, self).__init__()
        decoder_layer = nn.TransformerDecoderLayer(d_model=hidden_size, nhead=nhead)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

    def forward(self, noise, memory):
        """Decode *noise* while attending to *memory*.

        With PyTorch's default ``batch_first=False``, inputs and output are
        ``(seq_len, batch, hidden_size)``.
        """
        return self.transformer_decoder(noise, memory)


class Discriminator(nn.Module):
    """MLP discriminator with ``num_labels`` real classes plus one extra "fake" class.

    Args:
        input_size: feature size of the incoming representations.
        hidden_sizes: widths of the hidden Linear -> LeakyReLU -> Dropout stages
            (any sequence; default ``(512,)``). Empty means no hidden stage.
        num_labels: number of real classes; the logit layer outputs
            ``num_labels + 1`` values, with the last one for "fake".
        dropout_rate: dropout applied to the input and after each hidden stage.
    """

    def __init__(self, input_size=512, hidden_sizes=(512,), num_labels=2, dropout_rate=0.1):
        super(Discriminator, self).__init__()
        self.input_dropout = nn.Dropout(p=dropout_rate)
        # Immutable default + a fresh list: no state shared between instances.
        sizes = [input_size, *hidden_sizes]
        layers = []
        for in_features, out_features in zip(sizes[:-1], sizes[1:]):
            layers.extend([
                nn.Linear(in_features, out_features),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Dropout(dropout_rate),
            ])

        self.layers = nn.Sequential(*layers)
        self.logit = nn.Linear(sizes[-1], num_labels + 1) # +1 for the probability of this sample being fake/real.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input_rep):
        """Return ``(last hidden features, logits, softmax probabilities)`` over the last dim."""
        input_rep = self.input_dropout(input_rep)
        last_rep = self.layers(input_rep)
        logits = self.logit(last_rep)
        probs = self.softmax(logits)
        return last_rep, logits, probs

In [None]:
from transformers import AutoConfig

# Read the encoder's config by the same checkpoint name used to load it;
# there is no `model_name` variable in this notebook.
config = AutoConfig.from_pretrained(transformer_model_name)

hidden_size = int(config.hidden_size)
# num_hidden_layers_d hidden stages, each as wide as the encoder's hidden size.
hidden_levels_d = [hidden_size] * num_hidden_layers_d

generator = Generator(noise_size=noise_size, hidden_size=hidden_size, dropout_rate=out_dropout_rate)
discriminator = Discriminator(input_size=hidden_size, hidden_sizes=hidden_levels_d, num_labels=len(label_list), dropout_rate=out_dropout_rate)

if torch.cuda.is_available():
    generator.to(device)
    discriminator.to(device)
    transformer.to(device)

### Start Training

In [None]:
import shutil
from torch.utils.tensorboard import SummaryWriter

training_stats = []

total_t0 = time.time()

# The discriminator optimizer also fine-tunes the transformer encoder;
# the generator optimizer only updates the generator.
transformer_vars = [i for i in transformer.parameters()]
d_vars = transformer_vars + [v for v in discriminator.parameters()]
g_vars = [v for v in generator.parameters()]

dis_optimizer = torch.optim.AdamW(d_vars, lr=learning_rate_discriminator)
gen_optimizer = torch.optim.AdamW(g_vars, lr=learning_rate_generator)

if apply_scheduler:
    # Imported here: nothing else in the notebook brings it into scope, so
    # apply_scheduler = True previously raised a NameError.
    from transformers import get_constant_schedule_with_warmup

    num_train_examples = len(train_examples)
    num_train_steps = int(num_train_examples / batch_size * num_train_epochs)
    num_warmup_steps = int(num_train_steps * warmup_proportion)

    scheduler_d = get_constant_schedule_with_warmup(dis_optimizer,
                                           num_warmup_steps = num_warmup_steps)
    scheduler_g = get_constant_schedule_with_warmup(gen_optimizer,
                                           num_warmup_steps = num_warmup_steps)

writer = SummaryWriter(log_dir=f"{CHECKPOINT_DIR}/runs/")

step_cnt = 0 # global step counter; every 3rd step trains the generator, the others the discriminator
best_f1 = 0
best_checkpoint_so_far = None
# Epochs whose checkpoints tie the current best F1. All of them are kept because
# the later selection cell breaks F1 ties by accuracy. All are deleted as soon
# as F1 strictly improves.
best_checkpoints = []

for epoch_i in range(0, num_train_epochs):
    # ========================================
    #               Training
    # ========================================
    # Perform one full pass over the training set.
    print("")
    print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, num_train_epochs))
    print('Training...')

    # Measure how long the training epoch takes.
    t0 = time.time()

    # Reset the total loss for this epoch.
    tr_g_loss = 0
    tr_d_loss = 0

    # Put the model into training mode.
    transformer.train()

    # For each batch of training data...
    for step, batch in enumerate(train_dataloader):

        step_cnt += 1
        if step_cnt % 3 == 0:
            generator.train()
            discriminator.eval()
        else:
            generator.eval()
            discriminator.train()

        # Progress update every print_each_n_step batches.
        if step % print_each_n_step == 0 and not step == 0:
            # Calculate elapsed time in minutes.
            elapsed = format_time(time.time() - t0)

            # Report progress.
            print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, len(train_dataloader), elapsed))

        # Unpack this training batch from our dataloader.
        b_input_ids = batch[0].to(device)
        b_input_mask = batch[1].to(device)
        b_labels = batch[2].to(device)
        b_label_mask = batch[3].to(device)
        b_labels01 = batch[4].to(device)
        real_batch_size = b_input_ids.shape[0]

        model_outputs = transformer(b_input_ids, attention_mask=b_input_mask)
        hidden_states = model_outputs[0]

        # Generator using Transformer Decoder. Inputs are (seq, batch, hidden);
        # the output is permuted to (batch, seq, hidden) to match hidden_states.
        noise = torch.randn(max_seq_length, real_batch_size, hidden_size).to(device)
        memory = torch.randn(max_seq_length, real_batch_size, hidden_size).to(device)
        gen_rep = generator(noise, memory).permute(1, 0, 2)

        # Run the Discriminator on real and fake data in one batch: real rows first, then fake.
        discriminator_input = torch.cat([hidden_states, gen_rep], dim=0)

        # Then, we select the output of the discriminator
        features, logits, probs = discriminator(discriminator_input)

        # Finally, we separate the discriminator's output for the real and fake data
        features_list = torch.split(features, real_batch_size)
        D_real_features = features_list[0]
        D_fake_features = features_list[1]

        logits_list = torch.split(logits, real_batch_size)
        D_real_logits = logits_list[0]
        D_fake_logits = logits_list[1]

        probs_list = torch.split(probs, real_batch_size)
        D_real_probs = probs_list[0] # (batch, max_seq_length, len(label_list) + 1)
        D_fake_probs = probs_list[1] # (batch, max_seq_length, len(label_list) + 1)

        # Generator's LOSS estimation; the last class is "fake".
        g_loss_d = -1 * torch.mean(torch.log(1 - D_fake_probs[:, :, -1] + epsilon))
        g_feat_reg = torch.mean(torch.pow(torch.mean(D_real_features, dim=0) - torch.mean(D_fake_features, dim=0), 2))
        g_loss = g_loss_d + g_feat_reg

        # Discriminator's LOSS estimation: drop the "fake" class for the supervised part.
        logits = D_real_logits[:, :, 0:-1]

        log_probs = F.log_softmax(logits, dim=-1)
        label2one_hot = torch.nn.functional.one_hot(b_labels01, len(label_list))
        # NOTE(review): this sums over every position, including padding/special
        # tokens (label 0 in labels01). Confirm that this is intended.
        per_example_loss = -torch.sum(label2one_hot * log_probs, dim=(-2, -1))
        per_example_loss = torch.masked_select(per_example_loss, b_label_mask.to(device))
        labeled_example_count = per_example_loss.type(torch.float32).numel()

        # It may be the case that a batch does not contain labeled examples, so the "supervised loss" in this case is not evaluated
        if labeled_example_count == 0:
            D_L_Supervised = 0
        else:
            D_L_Supervised = torch.div(torch.sum(per_example_loss.to(device)), labeled_example_count)

        D_L_unsupervised1U = -1 * torch.mean(torch.log(1 - D_real_probs[:, :, -1] + epsilon))
        D_L_unsupervised2U = -1 * torch.mean(torch.log(D_fake_probs[:, :, -1] + epsilon))
        d_loss = D_L_Supervised + D_L_unsupervised1U + D_L_unsupervised2U

        #---------------------------------
        #  OPTIMIZATION
        #---------------------------------
        # Avoid gradient accumulation
        gen_optimizer.zero_grad()
        dis_optimizer.zero_grad()

        if step_cnt % 3 == 0:
            g_loss.backward()
            gen_optimizer.step()
            if apply_scheduler:
                scheduler_g.step()
        else:
            d_loss.backward()
            dis_optimizer.step()
            if apply_scheduler:
                scheduler_d.step()

        # Save the losses to print them later
        tr_g_loss += g_loss.item()
        tr_d_loss += d_loss.item()

    # Calculate the average loss over all of the batches.
    avg_train_loss_g = tr_g_loss / len(train_dataloader)
    avg_train_loss_d = tr_d_loss / len(train_dataloader)

    # Measure how long this epoch took.
    training_time = format_time(time.time() - t0)

    print("")
    print("  Average training loss generator: {0:.3f}".format(avg_train_loss_g))
    print("  Average training loss discriminator: {0:.3f}".format(avg_train_loss_d))
    print("  Training epoch took: {:}".format(training_time))

    # ========================================
    #     TEST ON THE EVALUATION DATASET
    # ========================================
    print("")
    print("Running Validation...")

    t0 = time.time()

    # Put the model in evaluation mode--the dropout layers behave differently during evaluation.
    transformer.eval()
    discriminator.eval()
    generator.eval()

    # labels01 never contains -1, so no position is ignored by this loss.
    nll_loss = torch.nn.CrossEntropyLoss(ignore_index=-1)
    total_valid_loss = 0

    all_preds = []
    all_labels_ids = []

    for batch in valid_dataloader:

        # Unpack this validation batch from our dataloader.
        b_input_ids = batch[0].to(device)
        b_input_mask = batch[1].to(device)
        b_labels = batch[2].to(device)
        b_labels01 = batch[4].to(device)

        with torch.no_grad():
            model_outputs = transformer(b_input_ids, attention_mask=b_input_mask)
            hidden_states = model_outputs[0]
            _, logits, probs = discriminator(hidden_states)
            filtered_logits = logits[:, :, 0:-1]

            total_valid_loss += nll_loss(filtered_logits.reshape((-1, len(label_list))), b_labels01.reshape((-1)))

        _, preds = torch.max(filtered_logits, 2)

        all_preds += preds.detach().cpu()
        all_labels_ids += b_labels.detach().cpu()

    # Report the final accuracy for this validation run (padding positions excluded).
    all_preds = torch.stack(all_preds).numpy()
    all_labels_ids = torch.stack(all_labels_ids).numpy()

    pad_ids = np.full((all_preds.shape), PAD_VALUE)
    valid_accuracy = np.sum(np.logical_and(all_preds == all_labels_ids, all_labels_ids != pad_ids)) / np.sum(all_labels_ids != pad_ids)
    print("  Validation Accuracy: {0:.3f}".format(valid_accuracy))

    # NOTE: We consider prediction = 2 (PAD token) as FLUENT

    tp = np.sum(np.logical_and(all_preds == all_labels_ids, all_labels_ids == np.full((all_preds.shape), 1)))
    tn = np.sum(np.logical_and(all_labels_ids == np.full((all_preds.shape), 0),
                               np.logical_or(all_preds == np.full((all_preds.shape), 0),
                                             all_preds == np.full((all_preds.shape), 2))))
    fn = np.sum(np.logical_and(all_labels_ids == np.full((all_preds.shape), 1),
                               all_labels_ids != all_preds))
    fp = np.sum(np.logical_and(all_preds == np.full((all_preds.shape), 1),
                               all_labels_ids == np.full((all_preds.shape), 0)))

    # numpy division: a zero denominator yields nan (with a warning), handled below for f1.
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)

    # Calculate the average loss over all of the batches.
    avg_valid_loss = total_valid_loss / len(valid_dataloader)
    avg_valid_loss = avg_valid_loss.item()

    # Measure how long the validation run took.
    valid_time = format_time(time.time() - t0)

    print("  Validation Precision: {0:.3f}".format(precision))
    print("  Validation Recall: {0:.3f}".format(recall))
    print("  Validation F1 Score: {0:.3f}".format(f1))
    print("  Validation Loss: {0:.3f}".format(avg_valid_loss))
    print("  Validation took: {:}".format(valid_time))

    # Record all statistics from this epoch.
    training_stats.append(
        {
            'epoch': epoch_i + 1,
            'Training Loss generator': avg_train_loss_g,
            'Training Loss discriminator': avg_train_loss_d,
            'Training Time': training_time,

            'Validation Loss': avg_valid_loss,
            'Validation Time': valid_time,
            'Validation tp': tp.item(),
            'Validation tn': tn.item(),
            'Validation fp': fp.item(),
            'Validation fn': fn.item(),
            'Validation Accuracy': valid_accuracy.item(),
            'Validation Precision': precision.item(),
            'Validation Recall': recall.item(),
            'Validation F1 Score': f1.item()
        }
    )

    writer.add_scalar('Train Loss Generator', avg_train_loss_g, epoch_i)
    writer.add_scalar('Train Loss Discriminator', avg_train_loss_d, epoch_i)
    writer.add_scalar('Validation Loss', avg_valid_loss, epoch_i)
    writer.add_scalar('Validation Accuracy', valid_accuracy, epoch_i)
    writer.add_scalar('Validation Precision', precision, epoch_i)
    writer.add_scalar('Validation Recall', recall, epoch_i)
    writer.add_scalar('Validation F1 Score', f1, epoch_i)
    writer.flush()

    if math.isnan(f1):
        f1 = 0

    if f1 >= best_f1 and f1 > 0:
        checkpoint_path = f"{CHECKPOINT_DIR}/checkpoint-{epoch_i + 1}"
        os.makedirs(checkpoint_path, exist_ok=True)
        torch.save(discriminator, f"{checkpoint_path}/discriminator.pt")
        torch.save(generator, f"{checkpoint_path}/generator.pt")
        transformer.save_pretrained(checkpoint_path)
        tokenizer.save_pretrained(f"{checkpoint_path}/")

        # On a strict improvement, remove EVERY previously kept best checkpoint,
        # not just the latest one. Otherwise earlier ties would leak on disk.
        if f1 > best_f1:
            for stale_epoch in best_checkpoints:
                shutil.rmtree(f"{CHECKPOINT_DIR}/checkpoint-{stale_epoch}", ignore_errors=True)
            best_checkpoints = []

        best_checkpoints.append(epoch_i + 1)
        best_f1 = f1
        best_checkpoint_so_far = epoch_i + 1

    print("Best checkpoint so far:", best_checkpoint_so_far)

writer.close()

In [None]:
# An undefined F1 (nan, e.g. from 0/0 precision/recall) is recorded as 0 so that epochs can be ranked.
for epoch_stats in training_stats:
    f1_value = epoch_stats['Validation F1 Score']
    if not math.isnan(f1_value):
        continue
    epoch_stats['Validation F1 Score'] = 0

In [None]:
# training_stats is still in epoch order here (it is only sorted in the next
# cell), so index 0 would be epoch 1. Pick the best epoch explicitly instead:
# highest F1, then highest accuracy, then lowest validation loss. This is the
# same ranking the checkpoint-selection cell produces.
best_valid_stats = min(
    training_stats,
    key=lambda x: (-x['Validation F1 Score'], -x['Validation Accuracy'], x['Validation Loss']),
)

print("Best Validation performance")
print("Accuracy: {:0.2f}".format(best_valid_stats['Validation Accuracy'] * 100))
print("Precision: {:0.2f}".format(best_valid_stats['Validation Precision'] * 100))
print("Recall: {:0.2f}".format(best_valid_stats['Validation Recall'] * 100))
print("F1 Score: {:0.2f}".format(best_valid_stats['Validation F1 Score'] * 100))

### Load best model from training & infer on test set

In [None]:
# Rank epochs by F1, then accuracy, then lowest validation loss. One sort on the
# combined key gives the same order as sorting by loss first and then
# stable-sorting by (F1, accuracy).
training_stats.sort(key = lambda x: (-x['Validation F1 Score'], -x['Validation Accuracy'], x['Validation Loss']))

best_checkpoint = training_stats[0]['epoch']
print("Best checkpoint is:", best_checkpoint)

model_dir = f"{CHECKPOINT_DIR}/checkpoint-{best_checkpoint}"
tokenizer = AutoTokenizer.from_pretrained(model_dir)
transformer = AutoModel.from_pretrained(model_dir).to(device)
# discriminator.pt is a whole pickled nn.Module (saved with torch.save(module)),
# so it needs weights_only=False; PyTorch >= 2.6 defaults to True and would
# refuse it. Only load checkpoints you produced yourself: unpickling can run code.
discriminator = torch.load(f"{model_dir}/discriminator.pt", map_location=device, weights_only=False)

# Uncomment next two cells if you want to load a model directly for inference on a defined test set

In [None]:
# model_dir = "/home/development/vineet/DDP_1/seq-gan-bert/checkpoints/presto-english-small-210523-7/checkpoint-27/"
# tokenizer = AutoTokenizer.from_pretrained(model_dir)
# transformer = AutoModel.from_pretrained(model_dir).to(device)

In [None]:
# discriminator = torch.load(f"{model_dir}/discriminator.pt")
# generator = torch.load(f"{model_dir}/generator.pt")

In [None]:
from datasets import Dataset

def test_tokenizer(examples):
    tokenized_inputs = tokenizer(examples["disfluent"], truncation=True, max_length=512, is_split_into_words=True)
    return tokenized_inputs


def _tally_confusion(ref_labels, pred_labels):
    """Return token-level ``(tp, fp, fn, tn)`` for one sentence.

    Label 1 (disfluent) is the positive class. A prediction of 2 (PAD) counts
    as fluent, matching the validation metrics. Reference labels other than
    0/1 (e.g. PAD positions) are skipped.
    """
    tp = fp = fn = tn = 0
    for ref_label, pred_label in zip(ref_labels, pred_labels):
        pred_disfluent = pred_label == 1
        if ref_label == 1:
            if pred_disfluent:
                tp += 1
            else:
                fn += 1 # pred_label can be 0 or 2
        elif ref_label == 0:
            if pred_disfluent:
                fp += 1
            else:
                tn += 1 # pred_label 0, or 2 (PAD) which is treated as FLUENT
    return tp, fp, fn, tn


def _reconstruct_fluent_sentence(words, word_ids, predictions):
    """Keep each word whose fluent subwords are at least as many as its disfluent ones.

    ``word_ids[k]`` is the word index of subword ``k``. ``predictions[k]`` is its
    predicted label, where 0 or 2 (PAD) means fluent.
    """
    fluent_sentence = []
    previous_word_idx = None
    disfluent = 0 # Count of (predicted) disfluent subwords of a word
    fluent = 0 # count of (predicted) fluent subwords of a words

    for word_idx, prediction in zip(word_ids, predictions):
        # On entering a new word, decide whether to keep the previous one.
        if word_idx != previous_word_idx:
            if previous_word_idx is not None and fluent >= disfluent:
                fluent_sentence.append(words[previous_word_idx])
            fluent, disfluent = 0, 0

        if prediction == 0 or prediction == 2: # consider prediction = 2 (PAD token) is also FLUENT
            fluent += 1
        else:
            disfluent += 1

        previous_word_idx = word_idx

    # Don't forget to add the last word
    if previous_word_idx is not None and fluent >= disfluent:
        fluent_sentence.append(words[previous_word_idx])
    return fluent_sentence


# Evaluate on blind sentences (read the file once and close it).
with open(test_filename, 'r', encoding='utf-8') as test_file:
    test_lines = test_file.readlines()

test_dict = {
                'disfluent': [sentence.split("\t")[0].split() for sentence in test_lines],
                'labels': [sentence.split("\t")[1].split() for sentence in test_lines],
            }

test_dataset = Dataset.from_dict(test_dict)
test_dataset = test_dataset.map(test_tokenizer, batched=True)
tokenized_for_word_ids = tokenizer(test_dataset["disfluent"], truncation=True, is_split_into_words=True)

all_preds = []
all_labels_ids = []

transformer.eval()
discriminator.eval()
generator.eval()

for batch in test_dataloader:
    # Unpack this test batch from our dataloader.
    b_input_ids = batch[0].to(device)
    b_input_mask = batch[1].to(device)
    b_labels = batch[2].to(device)

    with torch.no_grad():
        model_outputs = transformer(b_input_ids, attention_mask=b_input_mask)
        hidden_states = model_outputs[0]
        _, logits, probs = discriminator(hidden_states)
        filtered_logits = logits[:, :, 0:-1]

    _, preds = torch.max(filtered_logits, 2)
    all_preds += preds.detach().cpu()
    all_labels_ids += b_labels.detach().cpu()


tp, fp, tn, fn = 0, 0, 0, 0

# Materialize the column once; indexing a Dataset column copies the whole column.
test_sentences = test_dataset["disfluent"]

for i in range(len(all_preds)):
    actual_input = test_sentences[i]
    word_ids = tokenized_for_word_ids.word_ids(i)[1:-1] # Remove [CLS] & [SEP] token's word_id
    # Positions 1..len(word_ids) skip [CLS] and stop before [SEP] / PAD tokens.
    ref_labels = all_labels_ids[i][1: 1 + len(word_ids)].tolist()
    pred_labels = all_preds[i][1: 1 + len(word_ids)].tolist()

    print("Input           :\t", ' '.join(actual_input))
    print("Ref Labels      :\t", ' '.join(map(str, ref_labels)))
    print("Predicted Labels:\t", ' '.join(map(str, pred_labels)))

    s_tp, s_fp, s_fn, s_tn = _tally_confusion(ref_labels, pred_labels)
    tp, fp, fn, tn = tp + s_tp, fp + s_fp, fn + s_fn, tn + s_tn

    fluent_sentence = _reconstruct_fluent_sentence(actual_input, word_ids, pred_labels)
    print("Prediction:", ' '.join(fluent_sentence))

    if i % 500 == 0 and i > 0:
        print(f"Testing Example-{i}")


print('tp, fp, fn, tn:', tp, fp, fn, tn)
# Guard every ratio: with no predicted/actual positives these would divide by zero.
precision = 100 * tp / (tp + fp) if tp + fp else 0.0
recall = 100 * tp / (tp + fn) if tp + fn else 0.0
f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
total_tokens = tp + fp + fn + tn

print("Accuracy:", round(100 * (tp + tn) / total_tokens, 2) if total_tokens else 0.0)
print("Precision:", round(precision, 2))
print("Recall:", round(recall, 2))
print("F1 Score:", round(f1, 2))

### Save data from training as log file

In [None]:
import json

print(training_stats[0])

# Persist the per-epoch statistics (in their current order) as one JSON array.
log_path = f"{CHECKPOINT_DIR}/log"
with open(log_path, 'w') as log:
    json.dump(training_stats, log)