# Import Libraries

In [2]:
import json
import os
import random

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import transformers
# from transformers import AutoConfig, AutoModel, AutoTokenizer
from sentence_transformers import SentenceTransformer
from rouge import Rouge

# import datasets
import pandas as pd
from tqdm.notebook import tqdm

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

cuda


# Load JSON Data

In [3]:
# Load the preprocessed dataset (train/val/test splits plus max_length, as used below).
with open('wikihow_trimmed.json', mode='r', encoding='utf-8') as json_file:
    data = json.load(json_file)

# Model

In [4]:
class SumExtractorModel(nn.Module):
    """Extractive-summarization scorer over pre-computed sentence embeddings.

    A stack of ``nn.TransformerEncoder`` layers contextualizes the sentence
    embeddings of one document; the flattened encoder output is then fed to an
    MLP that emits one sigmoid score per sentence slot (pair with ``nn.BCELoss``).

    Args:
        dim_in: Size of each sentence embedding; must be divisible by ``num_head``.
        dim_hid: Feed-forward size of the encoder layers and hidden size of the classifier.
        dim_out: Number of sentence slots per document. Inputs must contain exactly this
            many sentences because the classifier's first layer expects
            ``dim_in * dim_out`` features.
        num_layers: Number of encoder layers.
        num_head: Number of attention heads.
        dropout: Dropout probability inside the encoder layers.
        device: Optional device for the encoder-layer parameters. ``None`` (default)
            uses the default device; move the whole module with ``.to(device)``.
    """

    def __init__(self, dim_in, dim_hid, dim_out, num_layers, num_head, dropout=0.5, device=None):
        super().__init__()

        # Attention heads split the embedding evenly, so dim_in must be a multiple of num_head.
        assert dim_in % num_head == 0, "dim_in must be divisible by num_head"

        encoder_layer = nn.TransformerEncoderLayer(
            dim_in, nhead=num_head, dim_feedforward=dim_hid, device=device, dropout=dropout
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers)

        self.classifier = nn.Sequential(
            nn.Linear(dim_in * dim_out, dim_hid),
            nn.ReLU(),
            nn.Linear(dim_hid, dim_out),
            nn.Sigmoid(),
        )

    def forward(self, input_sent, attention_mask=None):
        """Score every sentence slot of each document in the batch.

        Args:
            input_sent: Sentence embeddings of shape ``(batch, seq_len, dim_in)`` with
                ``seq_len == dim_out``.
            attention_mask: Optional mask passed positionally as the second argument of
                ``self.encoder`` (PyTorch's ``mask``: an attention mask over sentence positions).

        Returns:
            Tensor of shape ``(batch, dim_out)`` holding per-sentence probabilities
            (sigmoid already applied, for use with ``nn.BCELoss``).
        """
        bs = input_sent.size(0)
        # The encoder is not batch_first: (batch, seq_len, dim) -> (seq_len, batch, dim).
        embeddings = input_sent.permute(1, 0, 2)

        output = self.encoder(embeddings, attention_mask)

        # Return to (batch, seq_len, dim) before flattening so each row holds exactly one
        # document; flattening the (seq_len, batch, dim) layout directly would interleave
        # sentences from different documents whenever batch > 1.
        output = output.permute(1, 0, 2).reshape(bs, -1)

        return self.classifier(output)  # sigmoid output, to use with BCELoss

# Training Hyperparameters

In [5]:
# Seed every RNG the notebook uses (training-set sampling, weight init, dropout,
# the random example printed by `test`) so a Restart & Run All is repeatable,
# up to any non-deterministic CUDA kernels.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

max_length = data['max_length']  # sentence slots per document; also the classifier's output size
dim_in = 384 #For MiniLM-L6-v2 .Size of embeddings given by pre-trained model, depends on output of pretrained model dimensions
# dim_in = 768 #For distilroberta
dim_hid = 256
encoder_layers = 3
num_heads = 4 #Attention heads

# num_iters = 15
num_iters = 20000 #Iterations per epoch
num_epochs = 10

lr = 1e-4

In [6]:
# Sentence encoder whose embedding size must equal dim_in (384 for MiniLM-L6-v2).
pretrained_name = "sentence-transformers/all-MiniLM-L6-v2"
# pretrained_name = "sentence-transformers/all-distilroberta-v1"  # requires dim_in = 768

pretrained_model = SentenceTransformer(pretrained_name)
pretrained_model = pretrained_model.to(device)

In [7]:
# dim_out = max_length: the model emits one score per (padded) sentence slot.
model = SumExtractorModel(dim_in, dim_hid, max_length, encoder_layers, num_heads)
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# BCELoss expects probabilities, matching the Sigmoid at the end of the classifier.
loss_function = nn.BCELoss().to(device)

# Training, Validation, Test and Embedding Helpers

In [8]:
def train(input_tensor, target_tensor, model, loss_fct, optimizer, mask=None):
    """Run a single optimisation step on one (input, target) pair.

    Switches ``model`` to training mode, evaluates
    ``loss_fct(model(input_tensor, mask), target_tensor)``, back-propagates and
    applies ``optimizer.step()``.

    Returns:
        The loss of this step as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    predictions = model(input_tensor, mask)
    step_loss = loss_fct(predictions, target_tensor)

    step_loss.backward()
    optimizer.step()

    return step_loss.item()

In [9]:
def validation(input_tensor, target_tensor, model, loss_fct, mask=None):
    """Compute the loss of ``model`` on one (input, target) pair without updating it.

    The model is switched to eval mode (disabling dropout) and the forward pass runs
    under ``torch.no_grad()`` so no autograd graph is built for validation samples.
    The model is left in eval mode; ``train`` switches it back.

    Returns:
        The loss as a Python float.
    """
    model.eval()

    with torch.no_grad():
        output = model(input_tensor, mask)
        loss = loss_fct(output, target_tensor)

    return loss.item()

In [14]:
def test(testset, model, rouge, sum_len=4):
    """Score ``model``'s extractive summaries on ``testset`` with ROUGE and print the averages.

    For every pair, the document's sentences are embedded with ``getEmbed`` (padded to
    ``data['max_length']`` using the global ``pretrained_model``) and scored by ``model``.
    The ``sum_len`` highest-scoring real sentences are joined in document order as the
    hypothesis. The reference summary is the sentences whose label equals 1. The example
    at a random 1-based position in [2000, 4000] is printed for manual inspection; nothing
    is printed if ``testset`` is shorter than the drawn position.

    Args:
        testset: Iterable of pairs whose item 0 is the list of sentences and item 1 the
            per-sentence labels.
        model: Module mapping a ``(1, max_length, dim)`` embedding tensor to per-sentence
            scores of shape ``(1, max_length)``.
        rouge: Object exposing ``get_scores(hyps, refs, avg=True)`` (e.g. ``rouge.Rouge``).
        sum_len: Maximum number of sentences per extracted summary (default 4).

    Note:
        Uses the notebook globals ``getEmbed``, ``data``, ``pretrained_model`` and ``device``.
    """
    rouge1, rouge2, rougel = [], [], []
    model.eval()
    target = random.randint(2000, 4000)  # 1-based index of the example to print

    # Inference only: skip autograd bookkeeping for every test document.
    with torch.no_grad():
        for i, testpair in enumerate(tqdm(testset, position=0, desc='Rouge Testing'), start=1):
            sentences, labels = testpair[0], testpair[1]

            embedding, _ = getEmbed(sentences, labels, data['max_length'], pretrained_model)
            embedding = embedding.to(device)

            tgt_sum = [' '.join(s for s, lab in zip(sentences, labels) if lab == 1)]

            # Take the single batch row (unlike squeeze(), this never drops a length-1
            # sentence axis), then trim the scores that belong to padding slots.
            scores = model(embedding)[0][:len(sentences)]

            # Up to sum_len top-scoring sentences, restored to document order.
            k = min(sum_len, scores.numel())
            picked = sorted(torch.topk(scores, k).indices.tolist())
            hyp = [' '.join(sentences[j] for j in picked)]

            if i == target:
                print(f"Model Output: {hyp}")
                print()
                print(f"Actual Summary: {tgt_sum}")
                print()
                print(i)

            score = rouge.get_scores(hyp, tgt_sum, avg=True)
            rouge1.append(score['rouge-1']['f'])
            rouge2.append(score['rouge-2']['f'])
            rougel.append(score['rouge-l']['f'])

    print("====Rouge Scores Below====")
    print(f"Rouge-1 Score: {np.mean(rouge1)}")
    print(f"Rouge-2 Score: {np.mean(rouge2)}")
    print(f"Rouge-l Score: {np.mean(rougel)}")

In [11]:
def getEmbed(source, target, maxlen, model):
    """Embed one document and build its label tensor, both sized to exactly ``maxlen`` slots.

    Documents shorter than ``maxlen`` are padded with the literal sentence ``'[PAD]'``
    (label 0). Longer ones are truncated to their first ``maxlen`` sentences so the result
    always matches the classifier's fixed input size. The caller's lists are not modified.

    Args:
        source: Sequence of sentence strings.
        target: Sequence of per-sentence labels (0/1), aligned with ``source``.
        maxlen: Number of sentence slots to produce.
        model: Encoder exposing ``encode(list_of_str)`` (e.g. a SentenceTransformer) whose
            result ``torch.tensor`` can convert; presumably one vector per sentence.

    Returns:
        ``(embeddings, labels)`` as float32 tensors with a leading batch axis of size 1:
        ``(1, maxlen, embed_dim)`` (assuming ``encode`` returns a 2-D array) and ``(1, maxlen)``.
    """
    pad_token = '[PAD]'
    sentences = list(source[:maxlen])
    labels = list(target[:maxlen])

    pad_num = maxlen - len(sentences)
    if pad_num > 0:
        sentences += [pad_token] * pad_num
        labels += [0] * pad_num

    output = model.encode(sentences)
    output = torch.tensor(output, dtype=torch.float32).unsqueeze(0)
    labels = torch.tensor(labels, dtype=torch.float32).unsqueeze(0)

    return output, labels

In [12]:
# torch.save does not create missing directories, so make sure the target exists.
os.makedirs("./model_weights", exist_ok=True)
best_val_loss = None

for e in tqdm(range(1, num_epochs + 1), position=0, desc='Epoch'):
    train_loss = []
    val_loss = []

    # Each epoch trains on num_iters documents sampled with replacement.
    train_set = [random.choice(data['train']) for _ in range(num_iters)]

    # Training
    for pair in tqdm(train_set, desc=f'Epoch {e}, Train', position=1):
        input_tensor, target_tensor = getEmbed(pair[0], pair[1], data['max_length'], pretrained_model)
        input_tensor, target_tensor = input_tensor.to(device), target_tensor.to(device)

        # No attention mask: getEmbed does not build one, and `mask` is defined nowhere
        # in this notebook, so passing it only worked through leftover kernel state.
        train_loss.append(train(input_tensor, target_tensor, model, loss_function, optimizer, mask=None))

    print(f"Train loss: {np.mean(train_loss)}")

    # Validation
    for val_pair in tqdm(data['val'], position=2, desc=f'Epoch {e}, Validation'):
        input_tensor, target_tensor = getEmbed(val_pair[0], val_pair[1], data['max_length'], pretrained_model)
        input_tensor, target_tensor = input_tensor.to(device), target_tensor.to(device)

        val_loss.append(validation(input_tensor, target_tensor, model, loss_function, mask=None))

    epoch_val_loss = np.mean(val_loss)
    # `is None` rather than falsiness, so a legitimate 0.0 best loss is not treated as unset.
    if best_val_loss is None or epoch_val_loss < best_val_loss:
        best_val_loss = epoch_val_loss
        torch.save(model.state_dict(), f"./model_weights/bestmodel{e}.pth")
        print(f"**Model saved** New best validation loss: {best_val_loss}")

    print(f"Validation loss: {epoch_val_loss}")

Epoch:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 1, Train:   0%|          | 0/20000 [00:00<?, ?it/s]

Train loss: 0.1091239801120013


Epoch 1, Validation:   0%|          | 0/5543 [00:00<?, ?it/s]

**Model saved** New best validation loss: 0.16882556703277019
Validation loss: 0.16882556703277019


Epoch 2, Train:   0%|          | 0/20000 [00:00<?, ?it/s]

Train loss: 0.10247664355374873


Epoch 2, Validation:   0%|          | 0/5543 [00:00<?, ?it/s]

**Model saved** New best validation loss: 0.1282325627142135
Validation loss: 0.1282325627142135


Epoch 3, Train:   0%|          | 0/20000 [00:00<?, ?it/s]

Train loss: 0.10023083524005487


Epoch 3, Validation:   0%|          | 0/5543 [00:00<?, ?it/s]

**Model saved** New best validation loss: 0.12094052935910637
Validation loss: 0.12094052935910637


Epoch 4, Train:   0%|          | 0/20000 [00:00<?, ?it/s]

Train loss: 0.099416930958163


Epoch 4, Validation:   0%|          | 0/5543 [00:00<?, ?it/s]

**Model saved** New best validation loss: 0.1144787656022671
Validation loss: 0.1144787656022671


Epoch 5, Train:   0%|          | 0/20000 [00:00<?, ?it/s]

Train loss: 0.1045500708853826


Epoch 5, Validation:   0%|          | 0/5543 [00:00<?, ?it/s]

Validation loss: 0.11489556273322835


Epoch 6, Train:   0%|          | 0/20000 [00:00<?, ?it/s]

Train loss: 0.09984934172555804


Epoch 6, Validation:   0%|          | 0/5543 [00:00<?, ?it/s]

**Model saved** New best validation loss: 0.11360214592395987
Validation loss: 0.11360214592395987


Epoch 7, Train:   0%|          | 0/20000 [00:00<?, ?it/s]

Train loss: 0.09830715311709791


Epoch 7, Validation:   0%|          | 0/5543 [00:00<?, ?it/s]

Validation loss: 0.11442446409069887


Epoch 8, Train:   0%|          | 0/20000 [00:00<?, ?it/s]

Train loss: 0.0991067492838949


Epoch 8, Validation:   0%|          | 0/5543 [00:00<?, ?it/s]

Validation loss: 0.11863741799033056


Epoch 9, Train:   0%|          | 0/20000 [00:00<?, ?it/s]

Train loss: 0.10385441885134206


Epoch 9, Validation:   0%|          | 0/5543 [00:00<?, ?it/s]

**Model saved** New best validation loss: 0.11151248142768411
Validation loss: 0.11151248142768411


Epoch 10, Train:   0%|          | 0/20000 [00:00<?, ?it/s]

Train loss: 0.10434386999038979


Epoch 10, Validation:   0%|          | 0/5543 [00:00<?, ?it/s]

**Model saved** New best validation loss: 0.10886168674248801
Validation loss: 0.10886168674248801


In [None]:
# torch.save(model.state_dict(), "simpletransformer.pth")

# Evaluation on test set

In [12]:
rouge = Rouge()

# NOTE(review): this path does not match what the training cell writes
# (./model_weights/bestmodel{e}.pth), and in the run logged above epoch 8 did not
# save a checkpoint — confirm which run this file comes from.
checkpoint_path = "./model_weights/3layer_4head/bestmodel8.pth"
# checkpoint_path = "./model_weights/distilroberta_3layer_4head/bestmodel10.pth"

# map_location lets a checkpoint saved on the GPU load on a CPU-only kernel (and vice versa).
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
# model.load_state_dict(torch.load("./model_weights/distilroberta_3layer_4head/bestmodel10.pth"))

<All keys matched successfully>

In [15]:
# ROUGE evaluation of the loaded checkpoint on the held-out test split.
test(testset=data['test'], model=model, rouge=rouge)

Rouge Testing:   0%|          | 0/5502 [00:00<?, ?it/s]

====Rouge Scores Below====
Rouge-1 Score: 0.37598216423969455
Rouge-2 Score: 0.21010031050480532
Rouge-l Score: 0.3517201756941423


In [15]:
# Re-run of the same evaluation (a different random example is printed each run).
test(testset=data['test'], model=model, rouge=rouge)

Rouge Testing:   0%|          | 0/5502 [00:00<?, ?it/s]

Model Output: ['In order to make dopamine , your body needs tyrosine -- after a bunch of synthesizing and technical terms , it gets turned into your happy fuel . It can be found in soy products ( like tofu , etc . ) , fish , dairy , and meats . However , many dairy and meat products are high in calories and fat , so exercise caution and monitor your caloric intake with this high - dopamine diet . Dopamine is easy to oxidize , and antioxidants may reduce free radical damage to the brain cells that produce dopamine .']

Actual Summary: ['In order to make dopamine , your body needs tyrosine -- after a bunch of synthesizing and technical terms , it gets turned into your happy fuel . Many fruits and vegetables are rich in antioxidants , including : Beta - carotene and carotenoids : Greens , orange vegetables and fruits , asparagus , broccoli , beets   Vitamin C : Peppers , oranges , strawberries , cauliflower ,']
====Rouge Scores Below====
Rouge-1 Score: 0.3860141396670621
Rouge-2 Score: 0.

In [8]:
# Count only the parameters that receive gradient updates.
trainable_param = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Model contains {trainable_param} parameters")

Model contains 24028496 parameters
