# Prerequisites

In [1]:
import pandas as pd
import re
import numpy as np
from nltk.tokenize import word_tokenize
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
import torch
import torch.nn as nn
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from transformers import DistilBertTokenizer, DistilBertModel
from transformers import AdamW, get_linear_schedule_with_warmup
from sklearn.utils.class_weight import compute_class_weight
from datetime import datetime
from sklearn.metrics import f1_score
import pickle
import matplotlib.pyplot as plt

class DistilBERT_Arch(nn.Module):
    def __init__(self, distilbert):
        super().__init__()
        self.distilbert = distilbert
        self.dropout = nn.Dropout(0.1)
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(768, 512)
        self.fc2 = nn.Linear(512, 4)
        self.softmax = nn.LogSoftmax(dim=1)
    def forward(self, input_ids, mask):        
        cls_hs = self.distilbert(input_ids, attention_mask=mask)[0]
        x = self.fc1(cls_hs)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.softmax(x)
        return x

## Prep

In [None]:
np.random.seed(1)

df = pd.read_csv("/home/ubuntu/Development/punctuation/data/transcripts.csv")
# len(df): 2467

# train:val:test = 0.8:0.1:0.1
num_train = int(len(df)*0.8)+2 # 0.8 -> 0.01
num_val = int(len(df)*0.1) # 0.1 -> 0.01
num_test = int(len(df)*0.1) # 0.1 -> 0.01

# assign indices
id_all = np.random.choice(len(df), len(df), replace=False)
id_train = id_all[0:num_train]
id_val = id_all[num_train : num_val+num_train]
id_test = id_all[num_val+num_train : num_val+num_train+num_test]

# actual split
train_set = df.iloc[id_train]
val_set = df.iloc[id_val]
test_set = df.iloc[id_test]

# remove transcripts containing ♫
train_set = train_set[~train_set['transcript'].str.contains('♫')]
val_set = val_set[~val_set['transcript'].str.contains('♫')]
test_set = test_set[~test_set['transcript'].str.contains('♫')]

def data_prep(data_set):

    # Dataset Cleanup
    data_set = data_set.drop('url',axis=1)
    data_set = data_set['transcript']
    data_set = data_set.str.replace("\(.*?\)", " ")\
    .str.replace("\[.*?\]", " ")\
    .str.replace(";", ". ")\
    .str.replace(":", ". ")\
    .str.replace('"', ' ')\
    .str.replace('!', '. ')\
    .str.replace(" — (?=[a-z])", ", ")\
    .str.replace(" — (?=[A-Z])", ". ")\
    .str.replace("(?<=[a-z])\.(?=[A-Z])", ". ")\
    .str.replace("(?<=[a-z])\?(?=[A-Z])", ". ")\
    .str.replace("(?<= )'(?=[a-zA-Z])", " ")\
    .str.replace("(?<=[a-z])\'(?= )", " ")\
    .str.replace("\'(?= )", " ")\
    .str.replace(" — ", " ")\
    .str.replace('\.+', '.')\
    .str.replace(' +', ' ')\
    .str.lower()
    # hyphens are hard to handle. for now sentences like below still have an issue:
    # one - on - one tutoring works best so that's what we tried to emulate like with me and my mom even though we knew it would be one - on - thousands 

    temp_list_1 = []
    for sentences in data_set:
        temp_list_1 += re.split('(?<=\.)|(?<=\?)',sentences)

    temp_list_2 = []
    for item in temp_list_1:
        temp_list_2.append(re.sub('^ ','',item))

    temp_list_3 = []
    for s in temp_list_2:
        try:
            if s[-1] == ".":
                temp_list_3.append(s)
            elif s[-1] == "?":
                temp_list_3.append(s)
            else:
                pass
        except:
            pass

    del data_set
    del temp_list_1
    del temp_list_2

    total_words = 0
    combined_text = ""
    outer_list = []

    # create outer_list, a list of sentences that don't go beyond 400 words
    for s in temp_list_3:
        if total_words + len(word_tokenize(s)) < 400:
            combined_text += (s + " ")
            total_words += len(word_tokenize(s))
        else:
            outer_list.append(combined_text)
            combined_text = ""
            total_words = 0        

    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    encoded_data_data = tokenizer.batch_encode_plus(outer_list, max_length=450, padding='max_length', truncation=True, return_tensors='pt')

    punc_mask_outer = []
    ids_no_punc_outer = []
    attention_mask_outer = []

    for j in range(len(encoded_data_data['input_ids'])):

        # punctuation mask for sentences
        punc_mask = []
        for i in encoded_data_data['input_ids'][j]:
            if i == 1012:
                punc_mask.pop()
                punc_mask.append(1) # period
            elif i == 1029:
                punc_mask.pop()
                punc_mask.append(2) # question mark
            elif i == 1010:
                punc_mask.pop()
                punc_mask.append(3) # comma
            else:
                punc_mask.append(0)
        punc_mask_outer.append(torch.tensor(punc_mask))

        # sentences converted to word ids excluding punctuations
        # len(punc_mask) should be the same as len(ids_no_punc)
        ids_no_punc = []
        for i in encoded_data_data['input_ids'][j]:
            if i == 1012:
                pass
            elif i == 1029:
                pass
            elif i == 1010:
                pass
            else:
                ids_no_punc.append(i)
        ids_no_punc_outer.append(torch.tensor(ids_no_punc))

        # attention_mask with subwords set to 0 except for the last one
        attention_mask = []
        first_hash = True
        for i in encoded_data_data['input_ids'][j]:
            if (i == 101 or i == 102 or i == 0): # CLS, SEP, PAD
                attention_mask.append(0)
            elif (i == 1029 or i == 1010 or i == 1012):
                pass
            else:
                if re.match(r'^##', tokenizer.decode([i])):         
                    if first_hash == True:
                        attention_mask.pop()
                        attention_mask.append(0)
                        first_hash == False
                    attention_mask.append(1)
                else:
                    if first_hash == False:
                        attention_mask.pop()
                    attention_mask.append(1)                
        attention_mask_outer.append(torch.tensor(attention_mask))

    # figure out max length so that PADs can be added till it reaches max
    token_lengths = []
    for i in range(len(punc_mask_outer)):
        token_lengths.append(len(punc_mask_outer[i]))
    token_length_max = np.max(token_lengths)

    for i in range(len(punc_mask_outer)):
        # add PAD again because length is not equal after removing punctuations
        zeros = [0] * (token_length_max - len(punc_mask_outer[i]))

        punc_mask = torch.cat((punc_mask_outer[i], torch.tensor(zeros)), 0)
        ids_no_punc = torch.cat((ids_no_punc_outer[i], torch.tensor(zeros)), 0)
        attention_mask = torch.cat((attention_mask_outer[i], torch.tensor(zeros)), 0)

        if i != 0:
            pass
            punc_mask_outer_adjusted = torch.cat((punc_mask_outer_adjusted, punc_mask.view(1,-1)),0)
            ids_no_punc_outer_adjusted = torch.cat((ids_no_punc_outer_adjusted, ids_no_punc.view(1,-1)),0)
            attention_mask_outer_adjusted = torch.cat((attention_mask_outer_adjusted, attention_mask.view(1,-1)),0)
        else:
            punc_mask_outer_adjusted = punc_mask.view(1,-1)
            ids_no_punc_outer_adjusted = ids_no_punc.view(1,-1)
            attention_mask_outer_adjusted = attention_mask.view(1,-1)
            
    return ids_no_punc_outer_adjusted, attention_mask_outer_adjusted, punc_mask_outer_adjusted, punc_mask_outer, tokenizer
    
tarin_set = data_prep(train_set)
val_set = data_prep(val_set)
test_set = data_prep(test_set)
    
with open('train_set.pickle', 'wb') as f:
    pickle.dump(tarin_set, f) 

with open('val_set.pickle', 'wb') as f:
    pickle.dump(val_set, f) 

with open('test_set.pickle', 'wb') as f:
    pickle.dump(test_set, f) 

## Train and Val with Early Stopping

In [None]:
now = datetime.now()
current_time = now.strftime("%H:%M:%S")
print(f'Prep started: {current_time}', flush=True)
print()

torch.cuda.empty_cache()

with open('train_set.pickle', 'rb') as f:
    data = pickle.load(f)

ids_no_punc_outer_adjusted, attention_mask_outer_adjusted, punc_mask_outer_adjusted, punc_mask_outer, tokenizer_train = data

with open('val_set.pickle', 'rb') as f:
    data = pickle.load(f)

ids_no_punc_outer_adjusted_val, attention_mask_outer_adjusted_val, punc_mask_outer_adjusted_val, punc_mask_outer_val, tokenizer_val = data
    
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

batch_size = 32
epochs = 200

dataset_train = TensorDataset(ids_no_punc_outer_adjusted, attention_mask_outer_adjusted, punc_mask_outer_adjusted)
dataloader_train = DataLoader(dataset_train, sampler=RandomSampler(dataset_train), batch_size=batch_size)

distilbert = DistilBertModel.from_pretrained("distilbert-base-uncased",
                                                    num_labels=4,  
                                                    output_attentions=False,
                                                    output_hidden_states=False)
model = DistilBERT_Arch(distilbert)
optimizer = AdamW(model.parameters(),
                 lr=1e-5,
                 eps=1e-8)
scheduler = get_linear_schedule_with_warmup(optimizer,
                                           num_warmup_steps=0,
                                           num_training_steps=len(dataloader_train)*epochs)

# prep for class_weights
for i, p in enumerate(punc_mask_outer):    
    if i != 0:
        punc_cat = torch.cat((punc_cat, p), dim=0)
    else:
        punc_cat = p
        
class_weights = compute_class_weight('balanced', np.unique(punc_cat), punc_cat.numpy())
#TODO: haven't considered attention_mask yet!!
weights = torch.tensor(class_weights, dtype=torch.float)
weights = weights.to(device)
cross_entropy = nn.NLLLoss(weight=weights)

patience_cnt = 0
prev_f1 = 0

param_list = []

now = datetime.now()
current_time = now.strftime("%H:%M:%S")
print(f'Training started: {current_time}', flush=True)
print()

val_f1_list = []
val_f1_micro_list = []
val_f1_macro_list = []
val_f1_weighted_list = []

now = datetime.now()
current_time = now.strftime("%m%d%H%M")

folder = './val_results_' + current_time

os.makedirs(folder, exist_ok=True)

for epoch in range(epochs):

    if patience_cnt <= 2:

        loss_total = 0

        preds_masked_all = torch.tensor([0]).to(device)
        labels_masked_all = torch.tensor([0]).to(device)
        
        for batch in dataloader_train:

            model.train()
            model.zero_grad()

            batch = [b.to(device) for b in batch]

            model.to(device)

            outputs = model(batch[0].to(torch.long), batch[1].to(torch.long))
            
            loss = cross_entropy(outputs.to(torch.float32).view(-1, 4), batch[2].to(torch.long).view(-1))

            loss.backward()
            optimizer.step()
            scheduler.step()

            model.eval()

            loss_total += loss         
                
            preds = torch.argmax(outputs, axis=2)
            attention_masks = batch[1].to(torch.bool)
            labels = batch[2]

            preds_masked = torch.masked_select(preds, attention_masks)
            labels_masked = torch.masked_select(labels, attention_masks)
            
            preds_masked_all = torch.cat([preds_masked_all, preds_masked])
            labels_masked_all = torch.cat([labels_masked_all, labels_masked])
                           
        loss = loss_total/len(dataloader_train)
        acc = (preds_masked_all == labels_masked_all).sum() / len(preds_masked_all)        

        print(f'epoch: {epoch+1}, tr_loss: {loss.item():.3f}, tr_acc: {acc:.3f}', flush=True)

        with torch.no_grad():

            model.eval()

            dataset_val = TensorDataset(ids_no_punc_outer_adjusted_val, attention_mask_outer_adjusted_val, punc_mask_outer_adjusted_val)
            dataloader_val = DataLoader(dataset_val, sampler=RandomSampler(dataset_val), batch_size=batch_size)
            
            preds_masked_all = torch.tensor([0]).to(device)
            labels_masked_all = torch.tensor([0]).to(device)
                        
            for val_batch in dataloader_val:

                val_batch = [b.to(device) for b in val_batch]
                val_outputs = model(val_batch[0].to(torch.long), val_batch[1].to(torch.long))
            
                preds = torch.argmax(val_outputs, axis=2)
                attention_masks = val_batch[1].to(torch.bool)
                labels = val_batch[2]

                preds_masked = torch.masked_select(preds, attention_masks)
                labels_masked = torch.masked_select(labels, attention_masks)

                preds_masked_all = torch.cat([preds_masked_all, preds_masked])  
                labels_masked_all = torch.cat([labels_masked_all, labels_masked])
            
            val_acc = (preds_masked_all == labels_masked_all).sum() / len(preds_masked_all)        
    
            preds_masked_all = preds_masked_all.to('cpu').numpy()
            labels_masked_all = labels_masked_all.to('cpu').numpy()
    
            val_f1 = f1_score(labels_masked_all, preds_masked_all, average=None)
            val_f1_micro = f1_score(labels_masked_all, preds_masked_all, average='micro')
            val_f1_macro = f1_score(labels_masked_all, preds_masked_all, average='macro')
            val_f1_weighted = f1_score(labels_masked_all, preds_masked_all, average='weighted')
            
            val_f1_list.append(val_f1)
            val_f1_micro_list.append(val_f1_micro)
            val_f1_macro_list.append(val_f1_macro)
            val_f1_weighted_list.append(val_f1_weighted)
                        
            now = datetime.now()
            current_time = now.strftime("%H:%M:%S")
            print(f'Epoch complete: {current_time}', flush=True)
            
            print(f'val_acc: {val_acc:.4f}', flush=True)
            print(f'val_f1: {val_f1}', flush=True)
            print(f'val_f1_micro: {val_f1_micro:.4f}', flush=True)
            print(f'val_f1_macro: {val_f1_macro:.4f}', flush=True)
            print(f'val_f1_weighted: {val_f1_weighted:.4f}', flush=True)
            print()
                               
            if prev_f1 >= val_f1_macro:
                patience_cnt += 1
            else:
                patience_cnt = 0
            
            prev_f1 = val_f1_macro
            param_list.append(model)
                
            torch.save(param_list, folder + '/distilbert_result.pt')
            
            f = open(folder + '/val_f1_list.txt', 'wb')
            pickle.dump(val_f1_list, f)
            
            f = open(folder + '/val_f1_micro_list.txt', 'wb')
            pickle.dump(val_f1_micro_list, f)
            
            f = open(folder + '/val_f1_macro_list.txt', 'wb')
            pickle.dump(val_f1_macro_list, f)

            f = open(folder + '/val_f1_weighted_list.txt', 'wb')
            pickle.dump(val_f1_weighted_list, f)

    else:
        print(f'3rd consecutive degrades observed at epoch {epoch}. So the best is epoch {epoch-3}', flush=True)
        break

torch.save(param_list[-4], folder + '/distilbert_result.pt')
        
now = datetime.now()
current_time = now.strftime("%H:%M:%S")
print(f'Completed: {current_time}', flush=True)

### Val Results Visualization

In [None]:
folder = "/home/ubuntu/Development/punctuation/val_results_05060418"

file = open(folder + '/val_f1_list.txt', 'rb')
val_f1 = pickle.load(file)
file.close()

file = open(folder + '/val_f1_micro_list.txt', 'rb')
val_f1_micro = pickle.load(file)
file.close()

file = open(folder + '/val_f1_macro_list.txt', 'rb')
val_f1_macro = pickle.load(file)
file.close()

file = open(folder + '/val_f1_weighted_list.txt', 'rb')
val_f1_weighted = pickle.load(file)
file.close()

In [None]:
fig = plt.figure(figsize=(12, 10))
ax1 = fig.add_subplot(311)
ax2 = fig.add_subplot(312)
ax3 = fig.add_subplot(313)

ax1.plot(range(1, len(val_f1_micro)+1), val_f1_micro)
ax2.plot(range(1, len(val_f1_macro)+1), val_f1_macro)
ax3.plot(range(1, len(val_f1_weighted)+1), val_f1_weighted)

ax1.axvline(14, color="yellow")
ax2.axvline(14, color="yellow")
ax3.axvline(14, color="yellow")

ax1.set_title('val_f1_micro')
ax2.set_title('val_f1_macro')
ax3.set_title('val_f1_weighted');

In [None]:
label_0 = [x[0] for x in val_f1]
label_1 = [x[1] for x in val_f1]
label_2 = [x[2] for x in val_f1]
label_3 = [x[3] for x in val_f1]

fig = plt.figure(figsize=(15, 12))
ax1 = fig.add_subplot(411)
ax2 = fig.add_subplot(412)
ax3 = fig.add_subplot(413)
ax4 = fig.add_subplot(414)

ax1.plot(range(1, len(label_0)+1), label_0)
ax2.plot(range(1, len(label_1)+1), label_1)
ax3.plot(range(1, len(label_2)+1), label_2)
ax4.plot(range(1, len(label_3)+1), label_3)

ax1.axvline(14, color="yellow")
ax2.axvline(14, color="yellow")
ax3.axvline(14, color="yellow")
ax4.axvline(14, color="yellow")

ax1.set_title('f1: non_punc')
ax2.set_title('f1: period')
ax3.set_title('f1: question')
ax4.set_title('f1: comma');

## Test

In [6]:
folder = "/home/ubuntu/Development/punctuation/val_results_05060418"

now = datetime.now()
current_time = now.strftime("%H:%M:%S")
print(f'Started: {current_time}')

model = torch.load(folder + '/distilbert_result.pt')
model.eval()

with torch.no_grad(): # is this necessary?

    with open('/home/ubuntu/Development/punctuation/data/test_set.pickle', 'rb') as f:
        data = pickle.load(f)

    ids_no_punc_outer_adjusted_test, attention_mask_outer_adjusted_test, punc_mask_outer_adjusted_test, punc_mask_outer_test, tokenizer_test = data
    
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    batch_size = 32
#     epochs = 19

    dataset_test = TensorDataset(ids_no_punc_outer_adjusted_test, attention_mask_outer_adjusted_test, punc_mask_outer_adjusted_test)
    dataloader_test = DataLoader(dataset_test, sampler=RandomSampler(dataset_test), batch_size=batch_size)

#     test_nume = 0
#     test_deno = 0

#     test_preds_cat = []
#     test_labels_cat = []
    
    preds_masked_all = torch.tensor([0]).to(device)
    labels_masked_all = torch.tensor([0]).to(device)
    
    for test_batch in dataloader_test:

#         test_batch = [b.to(device) for b in test_batch]
#         test_outputs = model(test_batch[0].to(torch.long), test_batch[1].to(torch.long))

#         for j in range(test_outputs.shape[0]):
#         # for jth sample in a batch

#             test_preds = np.argmax(test_outputs[j].to('cpu').detach().numpy(), axis=1)
#             test_labels = test_batch[2].to(torch.long)[j].to('cpu').detach().numpy()

#             # for ith token in a jth sample
#             # if attention mask is not 0, check if predictinon matches label
#             for i in range(len(test_batch[1][j])):
#                 if test_batch[1][j][i] != 0:
#                     if test_preds[i] == test_labels[i]:    
#                         test_nume += 1
                        
#                     test_preds_cat.append(test_preds[i])
#                     test_labels_cat.append(test_labels[i])
                        
#                     test_deno += 1

#     test_preds_cat = np.array(test_preds_cat)
#     test_labels_cat = np.array(test_labels_cat)

        test_batch = [b.to(device) for b in test_batch]
        test_outputs = model(test_batch[0].to(torch.long), test_batch[1].to(torch.long))

        preds = torch.argmax(test_outputs, axis=2)
        attention_masks = test_batch[1].to(torch.bool)
        labels = test_batch[2]

        preds_masked = torch.masked_select(preds, attention_masks)
        labels_masked = torch.masked_select(labels, attention_masks)

        preds_masked_all = torch.cat([preds_masked_all, preds_masked])  
        labels_masked_all = torch.cat([labels_masked_all, labels_masked])

    test_acc = (preds_masked_all == labels_masked_all).sum() / len(preds_masked_all)        

    preds_masked_all = preds_masked_all.to('cpu').numpy()
    labels_masked_all = labels_masked_all.to('cpu').numpy()
    
#     test_acc = test_nume/test_deno    
#     test_f1 = f1_score(test_labels_cat, test_preds_cat, average=None)
#     test_f1_micro = f1_score(test_labels_cat, test_preds_cat, average='micro')
#     test_f1_macro = f1_score(test_labels_cat, test_preds_cat, average='macro')
#     test_f1_weighted = f1_score(test_labels_cat, test_preds_cat, average='weighted')
    
    test_f1 = f1_score(labels_masked_all, preds_masked_all, average=None)
    test_f1_micro = f1_score(labels_masked_all, preds_masked_all, average='micro')
    test_f1_macro = f1_score(labels_masked_all, preds_masked_all, average='macro')
    test_f1_weighted = f1_score(labels_masked_all, preds_masked_all, average='weighted')
    
    print(f'test_acc: {test_acc:.3f}', flush=True)
    print(f'test_f1: {test_f1}', flush=True)
    print(f'test_f1_micro: {test_f1_micro:.3f}', flush=True)
    print(f'test_f1_macro: {test_f1_macro:.3f}', flush=True)
    print(f'test_f1_weighted: {test_f1_weighted:.3f}', flush=True)
    
now = datetime.now()
current_time = now.strftime("%H:%M:%S")
print(f'Completed: {current_time}', flush=True)

Started: 12:47:02
test_acc: 0.901
test_f1: [0.95720808 0.72939145 0.46546088 0.5636657 ]
test_f1_micro: 0.901
test_f1_macro: 0.679
test_f1_weighted: 0.913
Completed: 12:47:24


In [9]:
pwd

'/home/ubuntu/Development/punctuation'