# Step 3. Training
This notebook trains the punctuation model. The training dataset produced by steps 1 and 2 is also available at the following URL: https://www.kaggle.com/datasets/takuji/punctuation-model-dataset

# Libraries

In [None]:
import os
import json
import datetime
import warnings
import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score, classification_report

import torch
import torch.nn as nn
from tensorflow.keras.preprocessing.sequence import pad_sequences
from transformers import AutoTokenizer, AutoModelForTokenClassification, AdamW, get_linear_schedule_with_warmup
warnings.filterwarnings('ignore')

# Train

In [None]:
class CFG:
    """Static settings read by the data pipeline and the training script."""

    # --- runtime / data ---
    batch_size = 32  # batch size handed to both DataLoaders
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-large')

    # --- optimisation ---
    full_finetuning = True  # True: optimise all model parameters; False: classifier head only
    epochs = 2  # used to size the linear learning-rate schedule
    max_grad_norm = 1.0  # gradient-norm clipping threshold
    learning_rate = 8e-6
    apex = True  # toggles torch.cuda.amp.autocast around the training forward pass

    # --- checkpointing / logging ---
    checkpoint_dir = 'checkpoints'
    log_dir = 'runs'
    load_checkpoint = False  # when True, training state is restored from `checkpoint_path`
    checkpoint_path = 'checkpoint_last.pt'

def process_data(is_train):
    """Load sentences and label strings for the training or validation split.

    For training, ``given_train.csv`` is placed in front of each of the four
    ``indiccorpv2_<n>_train.csv`` / ``indiccorpv2_<n>_valid.csv`` pairs, so it
    appears four times in the result. The file is read from disk once and all
    frames are concatenated in a single call; row order is the same as
    concatenating them one after another.

    Args:
        is_train: If true, build the training split; otherwise read
            ``given_valid.csv``.

    Returns:
        A tuple ``(sentences, labels, encoder, tag_values)`` where
        ``sentences`` and ``labels`` are the ``sentence`` and ``label``
        columns as arrays (rows containing any missing value removed),
        ``encoder`` maps each tag name to its integer index, and
        ``tag_values`` is the ordered list of tag names, ending with ``"PAD"``.
    """
    if is_train:
        given = pd.read_csv('given_train.csv')
        frames = []
        for shard in range(4):
            frames.append(given)
            frames.append(pd.read_csv(f'indiccorpv2_{shard}_train.csv'))
            frames.append(pd.read_csv(f'indiccorpv2_{shard}_valid.csv'))
        df = pd.concat(frames).reset_index(drop=True)
    else:
        df = pd.read_csv('given_valid.csv')

    # Drop every row that has a missing value in any column.
    df = df.dropna()

    # "PAD" must stay last: its index is used as the fill value for label padding.
    tag_values = ['blank', 'end', 'comma', 'qm', 'hyp', 'PAD']
    encoder = {tag: index for index, tag in enumerate(tag_values)}
    print(f"Encoder: {encoder}")

    return df['sentence'].values, df['label'].values, encoder, tag_values

def folder_with_time_stamps(log_folder, checkpoint_folder):
    """Build run-specific output locations keyed by the current timestamp.

    Args:
        log_folder: Base directory for logs.
        checkpoint_folder: Base directory for checkpoints.

    Returns:
        A tuple ``(log_path, checkpoint_path, encoder_json_name, stamp)``.
        ``stamp`` is ``datetime.now()`` formatted as ``YYYY-MM-DD_HH-MM-SS``;
        the two paths are the base directories with ``/<stamp>`` appended and
        the JSON name is ``label_encoder_<stamp>.json``. Nothing is created on
        disk here.
    """
    stamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    return (
        '/'.join((log_folder, stamp)),
        '/'.join((checkpoint_folder, stamp)),
        ''.join(('label_encoder_', stamp, '.json')),
        stamp,
    )

# Resolve this run's timestamped output locations from the configured base directories.
(log_folder,
 checkpoint_folder,
 train_encoder_file_path,
 _) = folder_with_time_stamps(CFG.log_dir, CFG.checkpoint_dir)

print(f"train encoder path -> {train_encoder_file_path}")

# Create both output directories up front; directories that already exist are accepted.
for _run_dir in (log_folder, checkpoint_folder):
    os.makedirs(_run_dir, exist_ok=True)

# Only the training call's encoder and tag list are kept; the validation call's are discarded.
train_sentences, train_labels, train_encoder, tag_values = process_data(is_train=True)
valid_sentences, valid_labels, _, _ = process_data(is_train=False)

# Persist the tag -> index mapping next to the run so it can be reloaded later.
with open(train_encoder_file_path, "w") as encoder_file:
    json.dump(train_encoder, encoder_file)

print("--------------------------------Tag Values----------------------------------")
print(tag_values)

class PunctuationDataset(torch.utils.data.Dataset):
    """Dataset that expands word-level punctuation tags to subword tokens.

    Each item is a whitespace-separated sentence with a parallel
    whitespace-separated label string. Every word is tokenized with
    ``CFG.tokenizer`` and its label is repeated once per resulting subword.
    Ids and labels are then truncated / right-padded to ``max_len``.

    Args:
        texts: Indexable collection of sentence strings.
        labels: Indexable collection of label strings, parallel to ``texts``.
        tag2idx: Mapping from tag name to integer id; must contain ``"PAD"``.
        max_len: Fixed output sequence length (default 256).
    """

    def __init__(self, texts, labels, tag2idx, max_len=256):
        self.texts = texts
        self.labels = labels
        self.tag2idx = tag2idx
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    @staticmethod
    def _fit(values, length, fill):
        """Return ``values`` cut to ``length`` items, right-padded with ``fill``."""
        clipped = list(values[:length])
        return clipped + [fill] * (length - len(clipped))

    def __getitem__(self, item):
        """Return the ``ids``, ``mask`` and ``target_tag`` tensors for one sample.

        Raises:
            KeyError: If a label in the sample is not present in ``tag2idx``.
        """
        words = self.texts[item].split()
        word_tags = self.labels[item].split()

        token_ids = []
        tag_ids = []

        # zip stops at the shorter of the two sequences, so surplus words or
        # labels are ignored.
        for word, tag in zip(words, word_tags):
            pieces = CFG.tokenizer.tokenize(word)
            if not pieces:
                continue
            try:
                tag_id = self.tag2idx[tag]
            except KeyError:
                raise KeyError(f"Unknown label {tag!r} in sample {item}") from None

            token_ids.extend(CFG.tokenizer.convert_tokens_to_ids(pieces))
            # Every subword inherits the label of the word it came from.
            tag_ids.extend([tag_id] * len(pieces))

        # The mask is derived from the real token count rather than from
        # `id != 0`, so a genuine token whose id happens to be 0 stays attended.
        n_real = min(len(token_ids), self.max_len)
        attention_mask = [1] * n_real + [0] * (self.max_len - n_real)

        # NOTE(review): 0 is kept as the input fill id for backward
        # compatibility; confirm whether CFG.tokenizer.pad_token_id is intended.
        input_ids = self._fit(token_ids, self.max_len, 0)
        tags = self._fit(tag_ids, self.max_len, self.tag2idx["PAD"])

        return {
            "ids": torch.tensor(input_ids, dtype=torch.long),
            "mask": torch.tensor(attention_mask, dtype=torch.long),
            "target_tag": torch.tensor(tags, dtype=torch.long),
        }

# Both splits are encoded with the label mapping produced for the training data.
train_dataset = PunctuationDataset(train_sentences, train_labels, train_encoder)
valid_dataset = PunctuationDataset(valid_sentences, valid_labels, train_encoder)

# Shared loader settings; only the training loader shuffles.
_loader_kwargs = dict(batch_size=CFG.batch_size, num_workers=4)
train_data_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
valid_data_loader = torch.utils.data.DataLoader(valid_dataset, shuffle=False, **_loader_kwargs)

# Token-classification head with one output per entry of the label encoder.
model = AutoModelForTokenClassification.from_pretrained(
    'xlm-roberta-large',
    num_labels=len(train_encoder),
    output_attentions=False,
    output_hidden_states=False,
)

# Per-class loss weights: every real tag counts equally, while "PAD" gets
# weight 0 so padded label positions do not contribute to the loss. The vector
# is derived from `tag_values` so it stays aligned with the encoder, and it is
# placed on CFG.device (not unconditionally on CUDA) so the CPU fallback works.
weights = torch.tensor(
    [0.0 if tag == "PAD" else 1.0 for tag in tag_values],
    device=CFG.device,
)
criterion = nn.CrossEntropyLoss(weight=weights)

if CFG.full_finetuning:
    # Optimise every parameter, with weight decay switched off for biases and
    # layer-norm weights. 'gamma' / 'beta' are kept for checkpoints that use
    # the legacy layer-norm parameter names.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.weight', 'gamma', 'beta']
    # The per-group key must be 'weight_decay'; the optimizer does not read
    # 'weight_decay_rate', so with that key no decay was applied at all.
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]
else:
    # Only the classification head is optimised.
    param_optimizer = list(model.classifier.named_parameters())
    optimizer_grouped_parameters = [{"params": [p for n, p in param_optimizer]}]

optimizer = AdamW(
    optimizer_grouped_parameters,
    lr=CFG.learning_rate,
    eps=1e-8
)

# Linear decay to zero over CFG.epochs full passes of the training loader, no warm-up.
total_steps = len(train_data_loader) * CFG.epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps
)

starting_epoch = 0

if CFG.load_checkpoint:
    # Deserialize onto the CPU so a checkpoint written on a GPU machine can be
    # loaded anywhere; tensors are moved to CFG.device explicitly below.
    checkpoint = torch.load(CFG.checkpoint_path, map_location='cpu')

    # A state dict saved from an nn.DataParallel wrapper prefixes every key
    # with 'module.'; strip it so the (not yet wrapped) model can load it.
    model_state = checkpoint['state_dict']
    if model_state and all(key.startswith('module.') for key in model_state):
        model_state = {key[len('module.'):]: value for key, value in model_state.items()}

    model.load_state_dict(model_state)
    optimizer.load_state_dict(checkpoint['optimizer'])

    # Move the optimizer's state tensors to the training device.
    for state in optimizer.state.values():
        for k, v in state.items():
            if torch.is_tensor(v):
                state[k] = v.to(CFG.device)

    # NOTE(review): the LR scheduler state is not part of the checkpoint, so
    # the schedule restarts from its first step on resume.
    starting_epoch = checkpoint['epoch'] + 1

# Replicate the model across GPUs when more than one is visible.
if torch.cuda.device_count() > 1:
    print("Using ", torch.cuda.device_count(), "GPUs")
    model = nn.DataParallel(model)

# Per-epoch average losses, kept for plotting the learning curve.
loss_values, validation_loss_values = [], []

# Use the configured device instead of an unconditional .cuda(), so the CPU
# fallback selected in CFG.device does not crash here.
model.to(CFG.device)

# Best validation loss seen so far. It must live outside the epoch loop:
# resetting it every epoch would make each epoch overwrite checkpoint_best.pt.
best_val_loss = np.inf

# Train for CFG.epochs epochs, the same count the LR schedule was sized for.
# Running past that horizon would only add steps at a learning rate of zero.
for epoch in range(starting_epoch, CFG.epochs):

    # ------------------------------ training ------------------------------
    model.train()
    total_loss = 0

    tk0 = tqdm(train_data_loader, total=int(len(train_data_loader)), unit='batch')
    tk0.set_description(f'Epoch {epoch + 1}')

    for batch in tk0:
        # Move the whole batch to the training device.
        for k, v in batch.items():
            batch[k] = v.to(CFG.device)

        b_input_ids, b_input_mask, b_labels = batch['ids'], batch['mask'], batch['target_tag']

        model.zero_grad()

        with torch.cuda.amp.autocast(enabled=CFG.apex, dtype=torch.bfloat16):
            outputs = model(b_input_ids, token_type_ids=None,
                            attention_mask=b_input_mask, labels=b_labels)
            # Labels are passed, so index 1 holds the logits. The model's own
            # loss is ignored in favour of the class-weighted criterion.
            logits = outputs[1]
            loss = criterion(logits.view(-1, logits.size(-1)), b_labels.view(-1))
        loss.backward()
        total_loss += loss.item()

        torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=CFG.max_grad_norm)

        optimizer.step()

        scheduler.step()

    # Calculate the average loss over the training data.
    avg_train_loss = total_loss / len(train_data_loader)
    print("Average train loss: {}".format(avg_train_loss))

    # NOTE(review): if the model was wrapped in nn.DataParallel, the saved keys
    # carry a 'module.' prefix; confirm that whatever loads these checkpoints
    # expects that.
    state = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    torch.save(state, checkpoint_folder + '/checkpoint_last.pt')
    # Store the loss value for plotting the learning curve.
    loss_values.append(avg_train_loss)

    # ----------------------------- validation -----------------------------
    model.eval()
    eval_loss = 0
    predictions, true_labels = [], []

    for batch in tqdm(valid_data_loader, total=int(len(valid_data_loader)), unit='batch', leave=True):
        for k, v in batch.items():
            batch[k] = v.to(CFG.device)
        b_input_ids, b_input_mask, b_labels = batch['ids'], batch['mask'], batch['target_tag']

        with torch.no_grad():
            outputs = model(b_input_ids, token_type_ids=None,
                            attention_mask=b_input_mask, labels=b_labels)
            logits = outputs[1]
            eval_loss += criterion(logits.view(-1, logits.size(-1)), b_labels.view(-1)).item()

        batch_predictions = np.argmax(logits.detach().cpu().numpy(), axis=2)
        predictions.extend([list(p) for p in batch_predictions])
        true_labels.extend(b_labels.to('cpu').numpy())

    eval_loss = eval_loss / len(valid_data_loader)

    # Keep a separate copy of the weights only when validation loss improves.
    # The model and optimizer are unchanged since `state` was built above.
    if eval_loss < best_val_loss:
        torch.save(state, checkpoint_folder + '/checkpoint_best.pt')
        best_val_loss = eval_loss

    validation_loss_values.append(eval_loss)
    print("Validation loss: {}".format(eval_loss))

    # Collect gold and predicted tags, skipping positions whose gold label is "PAD".
    pred_tags, valid_tags = [], []
    for pred_row, label_row in zip(predictions, true_labels):
        for pred_id, label_id in zip(pred_row, label_row):
            gold_tag = tag_values[label_id]
            if gold_tag != "PAD":
                valid_tags.append(gold_tag)
                pred_tags.append(tag_values[pred_id])

    val_accuracy = accuracy_score(valid_tags, pred_tags)
    val_f1_score = f1_score(valid_tags, pred_tags, average='macro')
    report = classification_report(valid_tags, pred_tags, output_dict=True, labels=np.unique(pred_tags))

    # Turn the report into a table with the category name as the first column.
    df_report = pd.DataFrame(report).transpose()
    df_report['categories'] = list(df_report.index)
    df_report = df_report[['categories'] + [col for col in df_report.columns if col != 'categories']]

    print("Validation Accuracy: {}".format(val_accuracy))
    print("Validation F1-Score: {}".format(val_f1_score))
    print("Classification Report: {}".format(report))

    df_report.to_csv(f'report_epoch{epoch + 1}.csv', index=False)