# Memsum-SBERT

## Model

In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"   
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [2]:
import math
import copy
import json
import pickle

from hk_utils import print
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
from tqdm.notebook import tqdm_notebook as tqdm
from rouge_score import rouge_scorer

from src.MemSum_Full.model import AddMask, PositionalEncoding, MultiHeadAttention, FeedForward, TransformerEncoderLayer, TransformerDecoderLayer, MultiHeadPoolingLayer

import os
import gc
import time
from datetime import datetime

device = "cuda:0"

In [3]:
# gpu analyzer
import gc
import torch
import math

def convert_size(size_bytes):
    if size_bytes == 0:
        return "0B"
    size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
    i = int(math.floor(math.log(size_bytes, 1024)))
    p = math.pow(1024, i)
    s = round(size_bytes / p, 2)
    return "%s%s" % (s, size_name[i])

def analyzeGPU():
    tensor_num = 0
    tensor_size = 0
    param_num = 0
    param_size = 0
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) and obj.is_cuda:
                if type(obj) is torch.nn.parameter.Parameter:
                    param_num += 1
                    param_size += obj.element_size() * obj.nelement()
                else:
                    #tensors.append(obj)
                    tensor_num += 1
                    tensor_size += obj.element_size() * obj.nelement()
        except:
            pass
        
    return (tensor_num, convert_size(tensor_size)), (param_num, convert_size(param_size))

In [4]:
# utils
def print_time(time_in_sec):
    if time_in_sec <= 60:
        seconds = time_in_sec
        return f"{seconds:.3f}s"
    elif 60 < time_in_sec <= 60 * 60:
        minutes = int(time_in_sec // 60)
        seconds = time_in_sec - (minutes * 60)
        return f"{minutes:02}m {seconds:.3f}s"
    elif 60 * 60 < time_in_sec <= 60 * 60 * 24:
        hours = int(time_in_sec // (60 * 60))
        minutes = int(( time_in_sec - hours * (60 * 60) ) // 60)
        seconds = time_in_sec - hours * (60 * 60) - minutes * 60
        return f"{hours:02}h {minutes:02}m {seconds:.3f}s"
    else:
        days = int(time_in_sec % (60 * 60 * 24))
        hours = int(( time_in_sec - days * (60 * 60 * 24) ) // (60 * 60))
        minutes = int(( time_in_sec - days * (60 * 60 * 24) - hours * (60 * 60) ) // 60)
        seconds = time_in_sec - days * (60 * 60 * 24) - hours * (60 * 60) - minutes * 60
        return f"{days}d {hours:02}h {minutes:02}m {seconds:.3f}s"
    
def get_ngram(sentence, n, delim="_"):
    words = sentence.lower().split()
    
    ngram_set = set()
    for i in range(len(words) - n + 1):
        ngram_set.add(delim.join(words[i:i + n]))
        
    return ngram_set

def update_moving_average( m_ema, m, decay=0.999 ):
    with torch.no_grad():
        param_dict_m_ema = m_ema.parameters() 
        param_dict_m = m.parameters() 
        for param_m_ema, param_m in zip( param_dict_m_ema, param_dict_m ):
            param_m_ema.copy_( decay * param_m_ema + (1 - decay) * param_m )
            
def save_model(batch, scores, model_dir):
    os.makedirs(model_dir, exist_ok=True)
    
    data = {
        "sbert": sbert_model_name,
        "current_batch": batch,
        "global_context_encoder": global_context_encoder_ema.state_dict(),
        "extraction_context_decoder": extraction_context_decoder_ema.state_dict(),
        "extractor": extractor_ema.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "scores": scores
    }
    
    filename = f"{batch:06}.{scores['rouge1']:.3f}-{scores['rouge2']:.3f}-{scores['rougeL']:.3f}.pt"
    model_path = os.path.join(model_dir, filename)
    
    torch.save(data, model_path)
    
    return model_path

def load_model(model_path, device):
    ckpt = torch.load(model_path, map_location=device)
    
    global_context_encoder.load_state_dict(ckpt["global_context_encoder"])
    extraction_context_decoder.load_state_dict(ckpt["extraction_context_decoder"])
    extractor.load_state_dict(ckpt["extractor"])
    optimizer.load_state_dict(ckpt["optimizer"])
    scheduler.load_state_dict(ckpt["scheduler"])

In [5]:
# models
class LocalSentenceEncoder(nn.Module):
    def __init__(
        self,
        model_name="sentence-transformers/all-mpnet-base-v2",
        input_type="packed",
        max_sentence_num=500
    ):
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name)
        self.input_type = input_type
        self.max_sentence_num = max_sentence_num
        
    @property
    def embed_dim(self):
        return self.model.config.hidden_size
    
    def _forward_packed_input(self, input_ids, attention_mask, sentence_nums):
        # input_ids : tensor[sum(sentence_nums), max_seq_len]
        # attention_mask : tensor[sum(sentence_nums), max_seq_len]
        # sentence_nums : list[batch_size]
        
        with torch.no_grad():
            token_embeddings = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask
            ).last_hidden_state
            
            # mean pooling
            attention_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            output = torch.sum(token_embeddings * attention_mask_expanded, 1) / torch.clamp(attention_mask_expanded.sum(1), min=1e-9)
            
            # normalization
            sentence_embeddings = F.normalize(output, p=2, dim=1)
            
            # reshape : tensor[sum(sentence_nums), emb_dim] -> tensor[batch_size, max_sent_num, emb_dim]
            sentence_embeddings_by_doc = []
            idx_cumsum = 0
            for sentence_num in sentence_nums:
                padded_sentence_embeddings = torch.zeros([self.max_sentence_num, sentence_embeddings.shape[-1]]).to(input_ids.device)
                padded_sentence_embeddings[:sentence_num, :] = sentence_embeddings[idx_cumsum:idx_cumsum + sentence_num]
                sentence_embeddings_by_doc.append(padded_sentence_embeddings)
                
                idx_cumsum += sentence_num
                
            sentence_embeddings = torch.stack(sentence_embeddings_by_doc, dim=0)
            
            return sentence_embeddings  # tensor[batch_size, max_sent_num, emb_dim]
        
    def _forward_padded_input(self, input_ids, attention_mask):
        # input_ids : tensor[batch_size * max_sent_num, max_seq_len]
        # attention_mask : tensor[batch_size * max_sent_num, max_seq_len]
        
        with torch.no_grad():
            batch_size = input_ids.shape[0] // self.max_sentence_num
            
            token_embeddings = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask
            ).last_hidden_state
            
            # mean pooling
            attention_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            output = torch.sum(token_embeddings * attention_mask_expanded, 1) / torch.clamp(attention_mask_expanded.sum(1), min=1e-9)
            
            # normalization
            sentence_embeddings = F.normalize(output, p=2, dim=1)
            
            # reshape : tensor[batch_size * max_sent_num, emb_dim] -> tensor[batch_size, max_sent_num, emb_dim]
            sentence_embeddings_by_doc = []
            for i in range(batch_size):
                sentence_embeddings_by_doc.append(sentence_embeddings[i * self.max_sentence_num:(i + 1) * self.max_sentence_num, :])
            sentence_embeddings = torch.stack(sentence_embeddings_by_doc, dim=0)
            
            return sentence_embeddings  # tensor[batch_size, max_sent_num, emb_dim]
        
    def forward(self, input_ids, attention_mask, sentence_nums=None):
        if self.input_type == "packed":
            assert type(sentence_nums) is list, "Invalid input:sentence_nums"
            return self._forward_packed_input(input_ids, attention_mask, sentence_nums)
        elif self.input_type == "padded":
            return self._forward_padded_input(input_ids, attention_mask)

class GlobalContextEncoder(nn.Module):
    def __init__(self, embed_dim,  num_heads, hidden_dim, num_dec_layers ):
        super().__init__()
        # self.pos_encode = PositionalEncoding( embed_dim)
        # self.layer_list = nn.ModuleList( [  TransformerEncoderLayer( embed_dim, num_heads, hidden_dim ) for _ in range(num_dec_layers) ] )
        self.rnn = nn.LSTM(  embed_dim, embed_dim, 2, batch_first = True, bidirectional = True)
        self.norm_out = nn.LayerNorm( 2*embed_dim )
        self.ln_out = nn.Linear( 2*embed_dim, embed_dim )

    def forward(self, sen_embed, doc_mask, dropout_rate = 0.):
        net, _ = self.rnn( sen_embed )
        net = self.ln_out(F.relu( self.norm_out(net) ) )
        return net

class ExtractionContextDecoder( nn.Module ):
    def __init__( self, embed_dim,  num_heads, hidden_dim, num_dec_layers ):
        super().__init__()
        self.layer_list = nn.ModuleList( [  TransformerDecoderLayer( embed_dim, num_heads, hidden_dim ) for _ in range(num_dec_layers) ] )
    ## remaining_mask: set all unextracted sen indices as True
    ## extraction_mask: set all extracted sen indices as True
    def forward( self, sen_embed, remaining_mask, extraction_mask, dropout_rate = 0. ):
        net = sen_embed
        for layer in self.layer_list:
            #  encoder_output, x,  src_mask, trg_mask , dropout_rate = 0.
            net = layer( sen_embed, net, remaining_mask, extraction_mask, dropout_rate )
        return net

class Extractor( nn.Module ):
    def __init__( self, embed_dim, num_heads ):
        super().__init__()
        self.norm_input = nn.LayerNorm( 3*embed_dim  )
        
        self.ln_hidden1 = nn.Linear(  3*embed_dim, 2*embed_dim  )
        self.norm_hidden1 = nn.LayerNorm( 2*embed_dim  )
        
        self.ln_hidden2 = nn.Linear(  2*embed_dim, embed_dim  )
        self.norm_hidden2 = nn.LayerNorm( embed_dim  )

        self.ln_out = nn.Linear(  embed_dim, 1 )

        self.mh_pool = MultiHeadPoolingLayer( embed_dim, num_heads )
        self.norm_pool = nn.LayerNorm( embed_dim  )
        self.ln_stop = nn.Linear(  embed_dim, 1 )

        self.mh_pool_2 = MultiHeadPoolingLayer( embed_dim, num_heads )
        self.norm_pool_2 = nn.LayerNorm( embed_dim  )
        self.ln_baseline = nn.Linear(  embed_dim, 1 )

    def forward( self, sen_embed, relevance_embed, redundancy_embed , extraction_mask, dropout_rate = 0. ):
        if redundancy_embed is None:
            redundancy_embed = torch.zeros_like( sen_embed )
        net = self.norm_input( F.dropout( torch.cat( [ sen_embed, relevance_embed, redundancy_embed ], dim = 2 ) , p = dropout_rate  )  ) 
        net = F.relu( self.norm_hidden1( F.dropout( self.ln_hidden1( net ) , p = dropout_rate  )   ))
        hidden_net = F.relu( self.norm_hidden2( F.dropout( self.ln_hidden2( net)  , p = dropout_rate  )  ))
        
        p = self.ln_out( hidden_net ).sigmoid().squeeze(2)

        net = F.relu( self.norm_pool(  F.dropout( self.mh_pool( hidden_net, extraction_mask) , p = dropout_rate  )  ))
        p_stop = self.ln_stop( net ).sigmoid().squeeze(1)

        net = F.relu( self.norm_pool_2(  F.dropout( self.mh_pool_2( hidden_net, extraction_mask ) , p = dropout_rate  )  ))
        baseline = self.ln_baseline(net)

        return p, p_stop, baseline

## GovReport

In [6]:
# model instances
max_sentence_num = 500
max_sequence_len = 100

sbert_model_name = "sentence-transformers/all-mpnet-base-v2"

local_sentence_encoder = LocalSentenceEncoder(
    model_name=sbert_model_name,
    input_type="packed",
    max_sentence_num=max_sentence_num
).to(device)
tokenizer = AutoTokenizer.from_pretrained(sbert_model_name)
embed_dim = local_sentence_encoder.embed_dim

global_context_encoder = GlobalContextEncoder(
    embed_dim=embed_dim,
    num_heads=8,
    hidden_dim=1024,
    num_dec_layers=3
).to(device)

extraction_context_decoder = ExtractionContextDecoder(
    embed_dim=embed_dim,
    num_heads=8,
    hidden_dim=1024,
    num_dec_layers=3
).to(device)

extractor = Extractor(
    embed_dim=embed_dim,
    num_heads=8
).to(device)

global_context_encoder_ema = copy.deepcopy(global_context_encoder).to(device)
extraction_context_decoder_ema = copy.deepcopy(extraction_context_decoder).to(device)
extractor_ema = copy.deepcopy(extractor).to(device)

In [7]:
model_parameters = \
[par for par in global_context_encoder.parameters() if par.requires_grad] +\
[par for par in extraction_context_decoder.parameters() if par.requires_grad] +\
[par for par in extractor.parameters() if par.requires_grad]

# Memo. make lse not trainable - too small VRAM
#[par for par in local_sentence_encoder.parameters() if par.requires_grad] +\

optimizer = optim.Adam(
    model_parameters,
    lr=1e-4,
    weight_decay=1e-6
)

scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="max",
    factor=0.1,
    patience=5,
    verbose=True
)

rouge_cal = rouge_scorer.RougeScorer(['rouge1','rouge2', 'rougeLsum'], use_stemmer=True)

In [8]:
# dataset
class GovReportDataset(Dataset):
    def __init__(
        self, 
        split,
        tokenizer,
        output_type="packed",
        max_sequence_len=100,
        max_sentence_num=500
    ):
        assert output_type in ["packed", "padded"], "Invalid param:output_type"
        
        self.tokenizer = tokenizer
        self.split = split
        self.output_type = output_type
        self.max_sequence_len = max_sequence_len
        self.max_sentence_num = max_sentence_num
        
        if self.split == "train":
            data_path = "data/gov-report/train_GOVREPORT.jsonl"
        elif self.split == "val":
            data_path = "data/gov-report/val_GOVREPORT.jsonl"
        elif self.split == "test":
            data_path = "data/gov-report/test_GOVREPORT.jsonl"
        else:
            raise ValueError("Invalid param:split")
            
        with open(data_path, "r") as f:
            self.data = [json.loads(line) for line in f.readlines()]
            
        self.data = [
            x for x in self.data 
            if not (
                "text" in x and len(x["text"]) == 0
                and "indices" in x and len(x["indices"]) == 0
                and "summary" in x and len(x["summary"]) == 0
                and "score" in x and len(x["score"]) == 0
            )
        ]

    def __len__(self):
        return len(self.data)
    
    def is_good_sentence(self, text):
        # bad sentence :
        #   1) empty sentence (after strip())
        #   2) only a single character(such as ".") (after strip())
        if len(text.strip()) <= 1:
            return False
        
        return True  
    
    def __getitem__(self, idx):
        doc = self.data[idx]
        sentences = doc["text"]
        gold_summary = doc["summary"]
        
        # select only good sentences
        sentences = [(idx, text) for idx, text in enumerate(sentences) if self.is_good_sentence(text)]
        sent_idx_mapper = {old_sent_idx: new_sent_idx for new_sent_idx, (old_sent_idx, _) in enumerate(sentences)}
        sentences = [text for _, text in sentences]
        
        # trim sentences to max_sentence_num
        sentences = sentences[:self.max_sentence_num]
        sentence_num = len(sentences)
        
        # tokenize sentences
        tokenized = self.tokenizer(
            sentences,
            truncation=True,
            padding="max_length",
            max_length=self.max_sequence_len,
            return_tensors="pt"
        )
        
        input_ids = tokenized.input_ids  # [sent_num, max_seq_len]
        attention_mask = tokenized.attention_mask  # [sent_num, max_seq_len]
        
        if self.output_type == "padded":
            # pad input_ids
            padded_input_ids = torch.zeros(
                [self.max_sentence_num, self.max_sequence_len],
                dtype=input_ids.dtype,
                device=input_ids.device
            )
            padded_input_ids[:sentence_num, :] = input_ids
            input_ids = padded_input_ids  # [max_sent_num, max_seq_len]
            
            # pad attention_mask
            padded_attention_mask = torch.zeros(
                [self.max_sentence_num, self.max_sequence_len],
                dtype=attention_mask.dtype,
                device=attention_mask.device
            )
            padded_attention_mask[:sentence_num, :] = attention_mask
            attention_mask = padded_attention_mask  # [max_sent_num, max_seq_len]
        
        # make (padded) doc_mask
        doc_mask = torch.BoolTensor([False] * sentence_num + [True] * (self.max_sentence_num - sentence_num))
        # Note. doc_mask is not required for packed output(since no sentence needs to be masked),
        # but just calculate it for these two reasons
        #   - eventually packed output needs to be converted to padded output after forwarding LSE,
        #     and doc_mask needs to be built at that moment
        #   - to make the shape of output be same with padded output
        
        if self.split == "train":
            oracle_summaries = doc["indices"]
            oracle_summary_scores = doc["score"]

            # randomly select oracle summary
            random_idx = np.random.choice(len(oracle_summaries))
            oracle_summary = oracle_summaries[random_idx]
            oracle_summary_score = oracle_summary_scores[random_idx]

            # reset idxes in oracle_summary
            oracle_summary = [sent_idx_mapper[old_sent_idx] for old_sent_idx in oracle_summary if old_sent_idx in sent_idx_mapper]
            
            # trim oracle_summary to max_sentence_num
            oracle_summary = np.array(oracle_summary)
            oracle_summary = oracle_summary[oracle_summary < sentence_num]

            # shuffle oracle_summary
            np.random.shuffle(oracle_summary)
            oracle_summary = oracle_summary.tolist()

            # pad oracle_summary (pad value = -1)
            oracle_summary = torch.LongTensor(oracle_summary + [-1] * (self.max_sentence_num - len(oracle_summary)))
            # Note. pad oracle_summary even for packed output,
            # because eventually label need to be converted to padded version after forwarding LSE
        
            return input_ids, attention_mask, doc_mask, sentence_num, oracle_summary, oracle_summary_score
            
        else: # val, test
            return input_ids, attention_mask, doc_mask, sentence_num, sentences, gold_summary

class collate_fn:
    def __init__(self, split, output_type="packed"):
        assert split in ["train", "val", "test"], "Invalid input:split"
        assert output_type in ["packed", "padded"], "Invalid input:output_type"
        
        self.split = split
        self.output_type = output_type
        
    def __call__(self, batch):
        if self.output_type == "packed":
            input_ids = torch.cat([x[0] for x in batch], dim=0)  # tensor[sum(sent_nums), max_seq_len]
            attention_mask = torch.cat([x[1] for x in batch], dim=0)  # tensor[sum(sent_nums), max_seq_len]
            doc_mask = torch.stack([x[2] for x in batch], dim=0)  # tensor[batch_size, max_sent_num]  # same
            #label = torch.stack([x[3] for x in batch], dim=0)  # tensor[batch_size, max_sent_num]  # same
            sentence_num = [x[3] for x in batch]  # list[batch_size]  # same
            
            if self.split == "train":
                oracle_summary = torch.stack([x[4] for x in batch], dim=0)  # tensor[batch_size, max_sent_num]  # same
                oracle_summary_score = torch.tensor([x[5] for x in batch])  # tensor[batch_size]  # same
            else: # val, test
                sentences = [x[4] for x in batch]  # list[batch_size, sent_num]  # same
                gold_summary = [x[5] for x in batch]  # list[batch_size, gold_summary_num]  # same
                
        elif self.output_type == "padded":
            input_ids = torch.cat([x[0] for x in batch], dim=0)  # tensor[batch_size * max_sent_num, max_seq_len]
            attention_mask = torch.cat([x[1] for x in batch], dim=0)  # tensor[batch_size * max_sent_num, max_seq_len]
            doc_mask = torch.stack([x[2] for x in batch], dim=0)  # tensor[batch_size, max_sent_num]  # same
            #label = torch.stack([x[3] for x in batch], dim=0)  # tensor[batch_size, max_sent_num]  # same
            sentence_num = [x[3] for x in batch]  # list[batch_size]  # same
            
            if self.split == "train":
                oracle_summary = torch.stack([x[4] for x in batch], dim=0)  # tensor[batch_size, max_sent_num]  # same
                oracle_summary_score = torch.tensor([x[5] for x in batch])  # tensor[batch_size]  # same
            else: # val, test
                sentences = [x[4] for x in batch]  # list[batch_size, sent_num]  # same
                gold_summary = [x[5] for x in batch]  # list[batch_size, gold_summary_num]  # same
        
        if self.split == "train":
            return input_ids, attention_mask, doc_mask, sentence_num, oracle_summary, oracle_summary_score
        
        else: # val, test
            return input_ids, attention_mask, doc_mask, sentence_num, sentences, gold_summary

In [9]:
# dataset instances
dataset_output_type = "packed"
batch_size = 4  # use smaller batch_size for padded_output
num_workers = 10

train_dataset = GovReportDataset(
    split="train",
    tokenizer=tokenizer,
    output_type=dataset_output_type,
    max_sequence_len=max_sequence_len,
    max_sentence_num=max_sentence_num
)

val_dataset = GovReportDataset(
    split="val",
    tokenizer=tokenizer,
    output_type=dataset_output_type,
    max_sequence_len=max_sequence_len,
    max_sentence_num=max_sentence_num
)

test_dataset = GovReportDataset(
    split="test",
    tokenizer=tokenizer,
    output_type=dataset_output_type,
    max_sequence_len=max_sequence_len,
    max_sentence_num=max_sentence_num
)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn(
        split="train",
        output_type=dataset_output_type
    ),
    num_workers=num_workers
)

val_dataloader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_fn(
        split="val",
        output_type=dataset_output_type
    ),
    num_workers=num_workers
)

test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_fn(
        split="val",
        output_type=dataset_output_type
    ),
    num_workers=num_workers
)

In [10]:
def train_iteration(batch):
    input_ids, attention_mask, doc_mask, sentence_num, oracle_summary, oracle_summary_score = batch
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)
    doc_mask = doc_mask.to(device)  # tensor[batch_size, max_sent_num]
    oracle_summary = oracle_summary.to(device)
    oracle_summary_score = oracle_summary_score.to(device)
    
    batch_size = len(sentence_num)
    
    local_sen_embed = local_sentence_encoder(input_ids, attention_mask, sentence_num)  # [batch_size, max_sent_num, emb_dim]
    global_context_embed = global_context_encoder(local_sen_embed, doc_mask, dropout_rate=0.1)  # [batch_size, max_sent_num, emb_dim]
    
    doc_mask_np = doc_mask.detach().cpu().numpy()  # np[batch_size, max_sent_num]
    remaining_mask_np = np.ones_like( doc_mask_np ).astype( bool ) | doc_mask_np
    extraction_mask_np = np.zeros_like( doc_mask_np ).astype( bool ) | doc_mask_np
    
    log_action_prob_list = []
    log_stop_prob_list = []

    done_list = []
    extraction_context_embed = None
    
    for step in range(max_sentence_num):
        remaining_mask = torch.from_numpy(remaining_mask_np).to(device)  # tensor[batch_size, max_sent_num]
        extraction_mask = torch.from_numpy(extraction_mask_np).to(device)  # tensor[batch_size, max_sent_num]

        if step > 0: # if at least one sentence is selected
            extraction_context_embed = extraction_context_decoder(local_sen_embed, remaining_mask, extraction_mask, dropout_rate=0.1)

        sentence_scores, p_stop, _ = extractor(local_sen_embed, global_context_embed, extraction_context_embed, extraction_mask, dropout_rate=0.1)
        # sentence_scores : tensor[batch_size, max_sent_num]
        # p_stop : tensor[batch_size]
        
        p_stop = p_stop.unsqueeze(1)  # p_stop : tensor[batch_size, 1]
        m_stop = Categorical(torch.cat([1 - p_stop, p_stop], dim=1))

        # grep step-th summary sentence idx
        summary_sent_idxs = oracle_summary[:, step]  # tensor[batch_size]

        # find documents that all summary sentences are extracted(= padding is selected)
        is_summarization_over = (summary_sent_idxs == -1)  # tensor[batch_size]

        if len(done_list) > 0:
            # is_just_stop : 이번 step에 막 summarization over 한거면 true, 아니면 false
            is_just_stop = torch.logical_and(~done_list[-1], is_summarization_over)
        else:
            is_just_stop = is_summarization_over

        if torch.all(is_summarization_over) and not torch.any(is_just_stop): # 모든 doc들의 summarization이 끝나고 1 step 진행한 경우 break
            break

        sentence_scores = sentence_scores.masked_fill(extraction_mask, 1e-12)
        normalized_sentence_scores = sentence_scores / sentence_scores.sum(dim=1, keepdims=True)  # tensor[batch_size, max_sent_num]

        extracted_sentence_scores = normalized_sentence_scores[range(batch_size), summary_sent_idxs]  # tensor[batch_size]
        log_action_prob = extracted_sentence_scores.masked_fill(is_summarization_over, 1.0).log()

        log_stop_prob = m_stop.log_prob(is_summarization_over.long())
        log_stop_prob = log_stop_prob.masked_fill(torch.logical_xor(is_summarization_over, is_just_stop), 0)

        log_action_prob_list.append(log_action_prob.unsqueeze(1))
        log_stop_prob_list.append(log_stop_prob.unsqueeze(1))
        done_list.append(is_summarization_over)

        for doc_idx, summary_sent_idx in enumerate(summary_sent_idxs.tolist()):
            if summary_sent_idx != -1:
                remaining_mask_np[doc_idx, summary_sent_idx] = False
                extraction_mask_np[doc_idx, summary_sent_idx] = True

    log_action_prob_list = torch.cat(log_action_prob_list, dim=1)
    log_stop_prob_list = torch.cat(log_stop_prob_list, dim=1)
    log_prob_list = log_action_prob_list + log_stop_prob_list
    
    log_prob_list = log_prob_list.sum(dim=1) / ((log_prob_list != 0).float().sum(dim=1))
    
    loss = (-log_prob_list * oracle_summary_score).mean()
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    return loss.item()

In [11]:
p_stop_threshold = 0.6
max_extracted_sentences_per_document = 22

def validation_iteration(batch):
    with torch.no_grad():
        input_ids, attention_mask, doc_mask, sentence_num, sentences, gold_summary = batch
        
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        doc_mask = doc_mask.to(device)  # tensor[batch_size, max_sent_num]

        batch_size = len(sentence_num)

        local_sent_embed = local_sentence_encoder(input_ids, attention_mask, sentence_num)  # [batch_size, max_sent_num, emb_dim]
        global_context_embed = global_context_encoder_ema(local_sent_embed, doc_mask, dropout_rate=0.1)  # [batch_size, max_sent_num, emb_dim]
        
        doc_mask_np = doc_mask.detach().cpu().numpy()  # np[batch_size, max_sent_num]
        
        extracted_sents = []
        extracted_sent_idxs = []
        
        sent_score_history = []
        p_stop_history = []
        
        for doc_i in range(batch_size):
            current_doc_mask_np = doc_mask_np[doc_i:doc_i + 1]
            current_remaining_mask_np = np.ones_like(current_doc_mask_np).astype(bool) | current_doc_mask_np
            current_extraction_mask_np = np.zeros_like(current_doc_mask_np).astype(bool) | current_doc_mask_np
            
            current_local_sent_embed = local_sent_embed[doc_i:doc_i + 1]
            current_global_context_embed = global_context_embed[doc_i:doc_i + 1]
            current_extraction_context_embed = None
            
            current_extracted_sent_idxs = []
            extracted_sent_ngrams = set()
            sent_score_history_for_doc_i = []
            p_stop_history_for_doc_i = []
            
            for step in range(max_extracted_sentences_per_document + 1):
                current_extraction_mask = torch.from_numpy(current_extraction_mask_np).to(device)
                current_remaining_mask = torch.from_numpy(current_remaining_mask_np).to(device)
                
                if step > 0:
                    current_extraction_context_embed = extraction_context_decoder_ema(
                        current_local_sent_embed,
                        current_remaining_mask,
                        current_extraction_mask
                    )
                    
                sent_scores, p_stop, _ = extractor_ema(
                    current_local_sent_embed,
                    current_global_context_embed,
                    current_extraction_context_embed,
                    current_extraction_mask
                )
                
                p_stop = p_stop.item()
                
                p_stop_history_for_doc_i.append(p_stop)
                sent_score_history_for_doc_i.append(sent_scores.detach().cpu().numpy())
                
                sent_scores = sent_scores.masked_fill(current_extraction_mask, 1e-12)
                normalized_sent_scores = sent_scores / sent_scores.sum(dim=1, keepdims=True)
                
                is_summarization_over = (p_stop > p_stop_threshold)
                
                _, sorted_sent_idxs = normalized_sent_scores.sort(dim=1, descending=True)
                sorted_sent_idxs = sorted_sent_idxs[0]
                
                # skip if selected sentence's ngrams are already in extracted sentences' ngram set
                # (avoid redundancy)
                extracted = False
                for sent_idx in [x.item() for x in sorted_sent_idxs]:
                    if sent_idx >= sentence_num[doc_i]:
                        break
                    
                    selected_sent = sentences[doc_i][sent_idx]
                    selected_sent_ngrams = get_ngram(selected_sent, n=3)
                        
                    if len(selected_sent_ngrams & extracted_sent_ngrams) < 1:
                        extracted_sent_ngrams.update(selected_sent_ngrams)
                        extracted = True
                        break
                
                if not is_summarization_over and step < max_extracted_sentences_per_document and extracted:
                    current_extracted_sent_idxs.append(sent_idx)
                    current_extraction_mask_np[0, sent_idx] = True
                    current_remaining_mask_np[0, sent_idx] = False
                else:
                    extracted_sents.append([sentences[doc_i][sent_idx] for sent_idx in current_extracted_sent_idxs])
                    extracted_sent_idxs.append(current_extracted_sent_idxs)
                    break
                    
                    
            sent_score_history.append(sent_score_history_for_doc_i)
            p_stop_history.append(p_stop_history_for_doc_i)
        
        scores = []
        for doc_i in range(batch_size):
            hyp = "\n".join(extracted_sents[doc_i]).strip()
            ref = "\n".join(gold_summary[doc_i]).strip()

            score = rouge_cal.score(hyp, ref)
            scores.append((
                score["rouge1"].fmeasure,
                score["rouge2"].fmeasure,
                score["rougeLsum"].fmeasure
            ))

        return scores

In [12]:
p_stop_threshold = 0.6
max_extracted_sentences_per_document = 22

def test_iteration(batch):
    with torch.no_grad():
        input_ids, attention_mask, doc_mask, sentence_num, sentences, gold_summary = batch
        
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        doc_mask = doc_mask.to(device)  # tensor[batch_size, max_sent_num]

        batch_size = len(sentence_num)

        local_sent_embed = local_sentence_encoder(input_ids, attention_mask, sentence_num)  # [batch_size, max_sent_num, emb_dim]
        global_context_embed = global_context_encoder(local_sent_embed, doc_mask, dropout_rate=0.1)  # [batch_size, max_sent_num, emb_dim]
        
        doc_mask_np = doc_mask.detach().cpu().numpy()  # np[batch_size, max_sent_num]
        
        extracted_sents = []
        extracted_sent_idxs = []
        
        sent_score_history = []
        p_stop_history = []
        
        for doc_i in range(batch_size):
            current_doc_mask_np = doc_mask_np[doc_i:doc_i + 1]
            current_remaining_mask_np = np.ones_like(current_doc_mask_np).astype(bool) | current_doc_mask_np
            current_extraction_mask_np = np.zeros_like(current_doc_mask_np).astype(bool) | current_doc_mask_np
            
            current_local_sent_embed = local_sent_embed[doc_i:doc_i + 1]
            current_global_context_embed = global_context_embed[doc_i:doc_i + 1]
            current_extraction_context_embed = None
            
            current_extracted_sent_idxs = []
            extracted_sent_ngrams = set()
            sent_score_history_for_doc_i = []
            p_stop_history_for_doc_i = []
            
            for step in range(max_extracted_sentences_per_document + 1):
                current_extraction_mask = torch.from_numpy(current_extraction_mask_np).to(device)
                current_remaining_mask = torch.from_numpy(current_remaining_mask_np).to(device)
                
                if step > 0:
                    current_extraction_context_embed = extraction_context_decoder(
                        current_local_sent_embed,
                        current_remaining_mask,
                        current_extraction_mask
                    )
                    
                sent_scores, p_stop, _ = extractor(
                    current_local_sent_embed,
                    current_global_context_embed,
                    current_extraction_context_embed,
                    current_extraction_mask
                )
                
                p_stop = p_stop.item()
                
                p_stop_history_for_doc_i.append(p_stop)
                sent_score_history_for_doc_i.append(sent_scores.detach().cpu().numpy())
                
                sent_scores = sent_scores.masked_fill(current_extraction_mask, 1e-12)
                normalized_sent_scores = sent_scores / sent_scores.sum(dim=1, keepdims=True)
                
                is_summarization_over = (p_stop > p_stop_threshold)
                
                _, sorted_sent_idxs = normalized_sent_scores.sort(dim=1, descending=True)
                sorted_sent_idxs = sorted_sent_idxs[0]
                
                # skip if selected sentence's ngrams are already in extracted sentences' ngram set
                # (avoid redundancy)
                extracted = False
                for sent_idx in [x.item() for x in sorted_sent_idxs]:
                    if sent_idx >= sentence_num[doc_i]:
                        break
                    
                    selected_sent = sentences[doc_i][sent_idx]
                    selected_sent_ngrams = get_ngram(selected_sent, n=3)
                        
                    if len(selected_sent_ngrams & extracted_sent_ngrams) < 1:
                        extracted_sent_ngrams.update(selected_sent_ngrams)
                        extracted = True
                        break
                
                if not is_summarization_over and step < max_extracted_sentences_per_document and extracted:
                    current_extracted_sent_idxs.append(sent_idx)
                    current_extraction_mask_np[0, sent_idx] = True
                    current_remaining_mask_np[0, sent_idx] = False
                else:
                    extracted_sents.append([sentences[doc_i][sent_idx] for sent_idx in current_extracted_sent_idxs])
                    extracted_sent_idxs.append(current_extracted_sent_idxs)
                    break
                    
                    
            sent_score_history.append(sent_score_history_for_doc_i)
            p_stop_history.append(p_stop_history_for_doc_i)
        
        scores = []
        for doc_i in range(batch_size):
            hyp = "\n".join(extracted_sents[doc_i]).strip()
            ref = "\n".join(gold_summary[doc_i]).strip()

            score = rouge_cal.score(hyp, ref)
            scores.append((
                score["rouge1"].fmeasure,
                score["rouge2"].fmeasure,
                score["rougeLsum"].fmeasure
            ))

        return scores

In [10]:
# train
cur_time = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
model_dir = os.path.join("model/memsum_sbert/gov-report", cur_time)

epochs = 5  # 50
print_every = 100  # 100
validate_every = 500  # 100
early_stop = False  # False

print(f"model dir: {model_dir}")
print()

total_train_batch = 0
for epoch in range(epochs):
    print.info(f"[Epoch {epoch + 1}]")
    running_loss = 0
    total_train_elapsed_time = 0
    
    for i, train_batch in enumerate(train_dataloader):
        train_start_time = time.time()
        total_train_batch += 1
        
        loss = train_iteration(train_batch)
        running_loss += loss
        
        #scheduler.step()
        
        update_moving_average(global_context_encoder_ema, global_context_encoder)
        update_moving_average(extraction_context_decoder_ema, extraction_context_decoder)
        update_moving_average(extractor_ema, extractor)
        
        train_elapsed_time = time.time() - train_start_time
        total_train_elapsed_time += train_elapsed_time
        
        if total_train_batch % print_every == 0:
            print.info(f"  [Training Batch {i + 1} / {len(train_dataloader)} (total_batch = {total_train_batch})]")
            #current_lr = optimizer.param_groups[0]['lr']
            print.info(f"    training : loss = {running_loss / print_every : .3f}, elapsed time = (current) {print_time(train_elapsed_time)} / (mean) {print_time(total_train_elapsed_time / print_every)}")
            running_loss = 0
            total_train_elapsed_time = 0
                
        if (total_train_batch % validate_every == 0) or (i == len(train_dataloader) - 1):
            val_start_time = time.time()
            val_scores = []
            for j, val_batch in enumerate(val_dataloader):
                print.info(f"    [Validation Batch {j + 1} / {len(val_dataloader)}]", end="\r")
                val_scores += validation_iteration(val_batch)
                    
            val_rouge1, val_rouge2, val_rougeL = list(zip(*val_scores))
            
            avg_val_rouge1 = np.mean(val_rouge1)
            avg_val_rouge2 = np.mean(val_rouge2)
            avg_val_rougeL = np.mean(val_rougeL)
            
            model_path = save_model(
                batch=total_train_batch, 
                scores={
                    "rouge1": avg_val_rouge1,
                    "rouge2": avg_val_rouge2,
                    "rougeL": avg_val_rougeL
                },
                model_dir=model_dir
            )
            
            val_elapsed_time = time.time() - val_start_time
            
            print.info(f"    validation : rouge1 = {avg_val_rouge1 : .4f}, rouge2 = {avg_val_rouge2 : .4f}, rougeL = {avg_val_rougeL : .4f}, time = {print_time(val_elapsed_time)} => {model_path}")
               
    if early_stop:
        break

model dir: model/memsum_sbert/gov-report/2022-12-12-11-36-03

[36m[Epoch 1][0m
[36m  [Training Batch 100 / 4380 (total_batch = 100)][0m
[36m    training : loss =  3.247, elapsed time = (current) 2.058s / (mean) 2.144s[0m
[36m  [Training Batch 200 / 4380 (total_batch = 200)][0m
[36m    training : loss =  3.138, elapsed time = (current) 2.081s / (mean) 2.078s[0m
[36m  [Training Batch 300 / 4380 (total_batch = 300)][0m
[36m    training : loss =  3.158, elapsed time = (current) 1.669s / (mean) 2.056s[0m
[36m  [Training Batch 400 / 4380 (total_batch = 400)][0m
[36m    training : loss =  3.156, elapsed time = (current) 1.789s / (mean) 2.086s[0m
[36m  [Training Batch 500 / 4380 (total_batch = 500)][0m
[36m    training : loss =  3.098, elapsed time = (current) 1.524s / (mean) 2.016s[0m
[36m    validation : rouge1 =  0.5424, rouge2 =  0.2033, rougeL =  0.5106, time = 11m 34.150s => model/memsum_sbert/gov-report/2022-12-12-11-36-03/000501.0.542-0.203-0.511.pt[0m
[36m  [T

In [13]:
# test
ckpt_path = "model/memsum_sbert/gov-report/2022-12-12-11-36-03/017501.0.581-0.242-0.549.pt"  # edit here
load_model(ckpt_path, device=device)

test_scores = []
for i, batch in enumerate(test_dataloader):
    print.info(f"[Test Batch {i + 1} / {len(test_dataloader)}]", end="\r")
    current_test_scores = test_iteration(batch)

    test_scores += current_test_scores

test_rouge1, test_rouge2, test_rougeL = list(zip(*test_scores))
            
avg_test_rouge1 = np.mean(test_rouge1)
avg_test_rouge2 = np.mean(test_rouge2)
avg_test_rougeL = np.mean(test_rougeL)
print.info(f"test : rouge1 = {avg_test_rouge1 : .4f}, rouge2 = {avg_test_rouge2 : .4f}, rougeL = {avg_test_rougeL : .4f}")

[36mtest : rouge1 =  0.5827, rouge2 =  0.2465, rougeL =  0.5507[0m


## TOS;DR

In [6]:
# model instances
max_sentence_num = 300
max_sequence_len = 50

sbert_model_name = "sentence-transformers/all-mpnet-base-v2"

local_sentence_encoder = LocalSentenceEncoder(
    model_name=sbert_model_name,
    input_type="packed",
    max_sentence_num=max_sentence_num
).to(device)
tokenizer = AutoTokenizer.from_pretrained(sbert_model_name)
embed_dim = local_sentence_encoder.embed_dim

global_context_encoder = GlobalContextEncoder(
    embed_dim=embed_dim,
    num_heads=8,
    hidden_dim=1024,
    num_dec_layers=3
).to(device)

extraction_context_decoder = ExtractionContextDecoder(
    embed_dim=embed_dim,
    num_heads=8,
    hidden_dim=1024,
    num_dec_layers=3
).to(device)

extractor = Extractor(
    embed_dim=embed_dim,
    num_heads=8
).to(device)

global_context_encoder_ema = copy.deepcopy(global_context_encoder).to(device)
extraction_context_decoder_ema = copy.deepcopy(extraction_context_decoder).to(device)
extractor_ema = copy.deepcopy(extractor).to(device)

In [7]:
model_parameters = \
[par for par in global_context_encoder.parameters() if par.requires_grad] +\
[par for par in extraction_context_decoder.parameters() if par.requires_grad] +\
[par for par in extractor.parameters() if par.requires_grad]

# Memo. make lse not trainable - too small VRAM
#[par for par in local_sentence_encoder.parameters() if par.requires_grad] +\

optimizer = optim.Adam(
    model_parameters,
    lr=1e-4,
    weight_decay=1e-6
)

scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="max",
    factor=0.1,
    patience=5,
    verbose=True
)

rouge_cal = rouge_scorer.RougeScorer(['rouge1','rouge2', 'rougeLsum'], use_stemmer=True)

In [8]:
# from chkpoint
ckpt = "model/memsum_sbert/gov-report/2022-12-12-11-36-03/017501.0.581-0.242-0.549.pt"
load_model(ckpt, device=device)

In [9]:
# dataset
class TOSDRDataset(Dataset):
    def __init__(
        self, 
        split,
        tokenizer,
        output_type="packed",
        mode="only-gold-summary",
        max_sequence_len=100,
        max_sentence_num=500
    ):
        assert mode in ["only-gold-summary", "all"], "Invalid param:mode"
        assert output_type in ["packed", "padded"], "Invalid param:output_type"
        
        self.tokenizer = tokenizer
        self.split = split
        self.output_type = output_type
        self.mode = mode
        self.max_sequence_len = max_sequence_len
        self.max_sentence_num = max_sentence_num
        
        if self.split == "train":
            data_path = "data/tos-dr/train.jsonl"
        elif self.split == "val":
            data_path = "data/tos-dr/val.jsonl"
        elif self.split == "test":
            data_path = "data/tos-dr/test.jsonl"
        else:
            raise ValueError("Invalid param:split")
            
        with open(data_path, "r") as f:
            self.data = [json.loads(line) for line in f.readlines()]
            
        self.data = [
            x for x in self.data 
            if not (
                "text" in x and len(x["text"]) == 0
                and "indices" in x and len(x["indices"]) == 0
                and "summary" in x and len(x["summary"]) == 0
                and "score" in x and len(x["score"]) == 0
            )
        ]

    def __len__(self):
        return len(self.data)
    
    def is_good_sentence(self, text):
        # bad sentence :
        #   1) empty sentence (after strip())
        #   2) only a single character(such as ".") (after strip())
        if len(text.strip()) <= 1:
            return False
        
        return True  
    
    def __getitem__(self, idx):
        doc = self.data[idx]
        sentences = doc["text"]
        gold_summary = doc["summary"]
        
        # select only good sentences
        sentences = [(idx, text) for idx, text in enumerate(sentences) if self.is_good_sentence(text)]
        sent_idx_mapper = {old_sent_idx: new_sent_idx for new_sent_idx, (old_sent_idx, _) in enumerate(sentences)}
        sentences = [text for _, text in sentences]
        
        # trim sentences to max_sentence_num
        sentences = sentences[:self.max_sentence_num]
        sentence_num = len(sentences)
        
        # tokenize sentences
        tokenized = self.tokenizer(
            sentences,
            truncation=True,
            padding="max_length",
            max_length=self.max_sequence_len,
            return_tensors="pt"
        )
        
        input_ids = tokenized.input_ids  # [sent_num, max_seq_len]
        attention_mask = tokenized.attention_mask  # [sent_num, max_seq_len]
        
        if self.output_type == "padded":
            # pad input_ids
            padded_input_ids = torch.zeros(
                [self.max_sentence_num, self.max_sequence_len],
                dtype=input_ids.dtype,
                device=input_ids.device
            )
            padded_input_ids[:sentence_num, :] = input_ids
            input_ids = padded_input_ids  # [max_sent_num, max_seq_len]
            
            # pad attention_mask
            padded_attention_mask = torch.zeros(
                [self.max_sentence_num, self.max_sequence_len],
                dtype=attention_mask.dtype,
                device=attention_mask.device
            )
            padded_attention_mask[:sentence_num, :] = attention_mask
            attention_mask = padded_attention_mask  # [max_sent_num, max_seq_len]
        
        # make (padded) doc_mask
        doc_mask = torch.BoolTensor([False] * sentence_num + [True] * (self.max_sentence_num - sentence_num))
        # Note. doc_mask is not required for packed output(since no sentence needs to be masked),
        # but just calculate it for these two reasons
        #   - eventually packed output needs to be converted to padded output after forwarding LSE,
        #     and doc_mask needs to be built at that moment
        #   - to make the shape of output be same with padded output
        
        if self.split == "train":
            oracle_summaries = doc["indices"]
            oracle_summary_scores = doc["score"]
            
            if self.mode == "only-gold-summary":
                oracle_summary = oracle_summaries[0]
                oracle_summary_score = oracle_summary_scores[0]
            else: # self.mode == "all"
                # randomly select oracle summary
                random_idx = np.random.choice(len(oracle_summaries))
                oracle_summary = oracle_summaries[random_idx]
                oracle_summary_score = oracle_summary_scores[random_idx]

            # reset idxes in oracle_summary
            oracle_summary = [sent_idx_mapper[old_sent_idx] for old_sent_idx in oracle_summary if old_sent_idx in sent_idx_mapper]
            
            # trim oracle_summary to max_sentence_num
            oracle_summary = np.array(oracle_summary)
            oracle_summary = oracle_summary[oracle_summary < sentence_num]

            # shuffle oracle_summary
            np.random.shuffle(oracle_summary)
            oracle_summary = oracle_summary.tolist()

            # pad oracle_summary (pad value = -1)
            oracle_summary = torch.LongTensor(oracle_summary + [-1] * (self.max_sentence_num - len(oracle_summary)))
            # Note. pad oracle_summary even for packed output,
            # because eventually label need to be converted to padded version after forwarding LSE
        
            return input_ids, attention_mask, doc_mask, sentence_num, oracle_summary, oracle_summary_score
            
        else: # val, test
            return input_ids, attention_mask, doc_mask, sentence_num, sentences, gold_summary

class collate_fn:
    def __init__(self, split, output_type="packed"):
        assert split in ["train", "val", "test"], "Invalid input:split"
        assert output_type in ["packed", "padded"], "Invalid input:output_type"
        
        self.split = split
        self.output_type = output_type
        
    def __call__(self, batch):
        if self.output_type == "packed":
            input_ids = torch.cat([x[0] for x in batch], dim=0)  # tensor[sum(sent_nums), max_seq_len]
            attention_mask = torch.cat([x[1] for x in batch], dim=0)  # tensor[sum(sent_nums), max_seq_len]
            doc_mask = torch.stack([x[2] for x in batch], dim=0)  # tensor[batch_size, max_sent_num]  # same
            #label = torch.stack([x[3] for x in batch], dim=0)  # tensor[batch_size, max_sent_num]  # same
            sentence_num = [x[3] for x in batch]  # list[batch_size]  # same
            
            if self.split == "train":
                oracle_summary = torch.stack([x[4] for x in batch], dim=0)  # tensor[batch_size, max_sent_num]  # same
                oracle_summary_score = torch.tensor([x[5] for x in batch])  # tensor[batch_size]  # same
            else: # val, test
                sentences = [x[4] for x in batch]  # list[batch_size, sent_num]  # same
                gold_summary = [x[5] for x in batch]  # list[batch_size, gold_summary_num]  # same
                
        elif self.output_type == "padded":
            input_ids = torch.cat([x[0] for x in batch], dim=0)  # tensor[batch_size * max_sent_num, max_seq_len]
            attention_mask = torch.cat([x[1] for x in batch], dim=0)  # tensor[batch_size * max_sent_num, max_seq_len]
            doc_mask = torch.stack([x[2] for x in batch], dim=0)  # tensor[batch_size, max_sent_num]  # same
            #label = torch.stack([x[3] for x in batch], dim=0)  # tensor[batch_size, max_sent_num]  # same
            sentence_num = [x[3] for x in batch]  # list[batch_size]  # same
            
            if self.split == "train":
                oracle_summary = torch.stack([x[4] for x in batch], dim=0)  # tensor[batch_size, max_sent_num]  # same
                oracle_summary_score = torch.tensor([x[5] for x in batch])  # tensor[batch_size]  # same
            else: # val, test
                sentences = [x[4] for x in batch]  # list[batch_size, sent_num]  # same
                gold_summary = [x[5] for x in batch]  # list[batch_size, gold_summary_num]  # same
        
        if self.split == "train":
            return input_ids, attention_mask, doc_mask, sentence_num, oracle_summary, oracle_summary_score
        
        else: # val, test
            return input_ids, attention_mask, doc_mask, sentence_num, sentences, gold_summary

In [10]:
# dataset instances
dataset_output_type = "packed"
dataset_mode = "only-gold-summary" # ["only-gold-summary", "all"]
batch_size = 4  # use smaller batch_size for padded_output
num_workers = 10

train_dataset = TOSDRDataset(
    split="train",
    tokenizer=tokenizer,
    output_type=dataset_output_type,
    mode=dataset_mode,
    max_sequence_len=max_sequence_len,
    max_sentence_num=max_sentence_num
)

val_dataset = TOSDRDataset(
    split="val",
    tokenizer=tokenizer,
    output_type=dataset_output_type,
    mode=dataset_mode,
    max_sequence_len=max_sequence_len,
    max_sentence_num=max_sentence_num
)

test_dataset = TOSDRDataset(
    split="test",
    tokenizer=tokenizer,
    output_type=dataset_output_type,
    mode=dataset_mode,
    max_sequence_len=max_sequence_len,
    max_sentence_num=max_sentence_num
)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn(
        split="train",
        output_type=dataset_output_type
    ),
    num_workers=num_workers
)

val_dataloader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_fn(
        split="val",
        output_type=dataset_output_type
    ),
    num_workers=num_workers
)

test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_fn(
        split="val",
        output_type=dataset_output_type
    ),
    num_workers=num_workers
)

In [11]:
max_probing_sentence_num = max_sentence_num
max_probing_sentence_num = 50

def train_iteration(batch):
    input_ids, attention_mask, doc_mask, sentence_num, oracle_summary, oracle_summary_score = batch
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)
    doc_mask = doc_mask.to(device)  # tensor[batch_size, max_sent_num]
    oracle_summary = oracle_summary.to(device)
    oracle_summary_score = oracle_summary_score.to(device)

    batch_size = len(sentence_num)

    local_sen_embed = local_sentence_encoder(input_ids, attention_mask, sentence_num)  # [batch_size, max_sent_num, emb_dim]
    global_context_embed = global_context_encoder(local_sen_embed, doc_mask, dropout_rate=0.1)  # [batch_size, max_sent_num, emb_dim]

    doc_mask_np = doc_mask.detach().cpu().numpy()  # np[batch_size, max_sent_num]
    remaining_mask_np = np.ones_like( doc_mask_np ).astype( bool ) | doc_mask_np
    extraction_mask_np = np.zeros_like( doc_mask_np ).astype( bool ) | doc_mask_np

    log_action_prob_list = []
    log_stop_prob_list = []

    done_list = []

    extraction_context_embed = None

    #for step in range(max_sentence_num):
    for step in range(max_probing_sentence_num):
        remaining_mask = torch.from_numpy(remaining_mask_np).to(device)  # tensor[batch_size, max_sent_num]
        extraction_mask = torch.from_numpy(extraction_mask_np).to(device)  # tensor[batch_size, max_sent_num]

        if step > 0: # if at least one sentence is selected
            extraction_context_embed = extraction_context_decoder(local_sen_embed, remaining_mask, extraction_mask, dropout_rate=0.1)

        sentence_scores, p_stop, _ = extractor(local_sen_embed, global_context_embed, extraction_context_embed, extraction_mask, dropout_rate=0.1)
        # sentence_scores : tensor[batch_size, max_sent_num]
        # p_stop : tensor[batch_size]

        p_stop = p_stop.unsqueeze(1)  # p_stop : tensor[batch_size, 1]
        m_stop = Categorical(torch.cat([1 - p_stop, p_stop], dim=1))

        # grep step-th summary sentence idx
        summary_sent_idxs = oracle_summary[:, step]  # tensor[batch_size]

        # find documents that all summary sentences are extracted(= padding is selected)
        is_summarization_over = (summary_sent_idxs == -1)  # tensor[batch_size]

        if len(done_list) > 0:
            # is_just_stop : 이번 step에 막 summarization over 한거면 true, 아니면 false
            is_just_stop = torch.logical_and(~done_list[-1], is_summarization_over)
        else:
            is_just_stop = is_summarization_over

        if torch.all(is_summarization_over) and not torch.any(is_just_stop): # 모든 doc들의 summarization이 끝나고 1 step 진행한 경우 break
            break

        sentence_scores = sentence_scores.masked_fill(extraction_mask, 1e-12)
        normalized_sentence_scores = sentence_scores / sentence_scores.sum(dim=1, keepdims=True)  # tensor[batch_size, max_sent_num]

        extracted_sentence_scores = normalized_sentence_scores[range(batch_size), summary_sent_idxs]  # tensor[batch_size]
        log_action_prob = extracted_sentence_scores.masked_fill(is_summarization_over, 1.0).log()

        log_stop_prob = m_stop.log_prob(is_summarization_over.long())
        log_stop_prob = log_stop_prob.masked_fill(torch.logical_xor(is_summarization_over, is_just_stop), 0)

        log_action_prob_list.append(log_action_prob.unsqueeze(1))
        log_stop_prob_list.append(log_stop_prob.unsqueeze(1))
        done_list.append(is_summarization_over)

        for doc_idx, summary_sent_idx in enumerate(summary_sent_idxs.tolist()):
            if summary_sent_idx != -1:
                remaining_mask_np[doc_idx, summary_sent_idx] = False
                extraction_mask_np[doc_idx, summary_sent_idx] = True

    log_action_prob_list = torch.cat(log_action_prob_list, dim=1)
    log_stop_prob_list = torch.cat(log_stop_prob_list, dim=1)
    log_prob_list = log_action_prob_list + log_stop_prob_list

    log_prob_list = log_prob_list.sum(dim=1) / ((log_prob_list != 0).float().sum(dim=1))

    loss = (-log_prob_list * oracle_summary_score).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()

In [12]:
p_stop_threshold = 0.6
max_extracted_sentences_per_document = 13

def validation_iteration(batch):
    with torch.no_grad():
        input_ids, attention_mask, doc_mask, sentence_num, sentences, gold_summary = batch
        
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        doc_mask = doc_mask.to(device)  # tensor[batch_size, max_sent_num]

        batch_size = len(sentence_num)

        local_sent_embed = local_sentence_encoder(input_ids, attention_mask, sentence_num)  # [batch_size, max_sent_num, emb_dim]
        global_context_embed = global_context_encoder_ema(local_sent_embed, doc_mask, dropout_rate=0.1)  # [batch_size, max_sent_num, emb_dim]
        
        doc_mask_np = doc_mask.detach().cpu().numpy()  # np[batch_size, max_sent_num]
        
        extracted_sents = []
        extracted_sent_idxs = []
        
        sent_score_history = []
        p_stop_history = []
        
        for doc_i in range(batch_size):
            current_doc_mask_np = doc_mask_np[doc_i:doc_i + 1]
            current_remaining_mask_np = np.ones_like(current_doc_mask_np).astype(bool) | current_doc_mask_np
            current_extraction_mask_np = np.zeros_like(current_doc_mask_np).astype(bool) | current_doc_mask_np
            
            current_local_sent_embed = local_sent_embed[doc_i:doc_i + 1]
            current_global_context_embed = global_context_embed[doc_i:doc_i + 1]
            current_extraction_context_embed = None
            
            current_extracted_sent_idxs = []
            extracted_sent_ngrams = set()
            sent_score_history_for_doc_i = []
            p_stop_history_for_doc_i = []
            
            for step in range(max_extracted_sentences_per_document + 1):
                #gc.collect()
                #torch.cuda.empty_cache()
                
                current_extraction_mask = torch.from_numpy(current_extraction_mask_np).to(device)
                current_remaining_mask = torch.from_numpy(current_remaining_mask_np).to(device)
                
                if step > 0:
                    current_extraction_context_embed = extraction_context_decoder_ema(
                        current_local_sent_embed,
                        current_remaining_mask,
                        current_extraction_mask
                    )
                    
                sent_scores, p_stop, _ = extractor_ema(
                    current_local_sent_embed,
                    current_global_context_embed,
                    current_extraction_context_embed,
                    current_extraction_mask
                )
                
                p_stop = p_stop.item()
                
                p_stop_history_for_doc_i.append(p_stop)
                sent_score_history_for_doc_i.append(sent_scores.detach().cpu().numpy())
                
                sent_scores = sent_scores.masked_fill(current_extraction_mask, 1e-12)
                normalized_sent_scores = sent_scores / sent_scores.sum(dim=1, keepdims=True)
                
                is_summarization_over = (p_stop > p_stop_threshold)
                
                _, sorted_sent_idxs = normalized_sent_scores.sort(dim=1, descending=True)
                sorted_sent_idxs = sorted_sent_idxs[0]
                
                # skip if selected sentence's ngrams are already in extracted sentences' ngram set
                # (avoid redundancy)
                extracted = False
                for sent_idx in [x.item() for x in sorted_sent_idxs]:
                    if sent_idx >= sentence_num[doc_i]:
                        break
                    
                    selected_sent = sentences[doc_i][sent_idx]
                    selected_sent_ngrams = get_ngram(selected_sent, n=3)
                        
                    if len(selected_sent_ngrams & extracted_sent_ngrams) < 1:
                        extracted_sent_ngrams.update(selected_sent_ngrams)
                        extracted = True
                        break
                
                if not is_summarization_over and step < max_extracted_sentences_per_document and extracted:
                    current_extracted_sent_idxs.append(sent_idx)
                    current_extraction_mask_np[0, sent_idx] = True
                    current_remaining_mask_np[0, sent_idx] = False
                else:
                    extracted_sents.append([sentences[doc_i][sent_idx] for sent_idx in current_extracted_sent_idxs])
                    extracted_sent_idxs.append(current_extracted_sent_idxs)
                    break
                    
                    
            sent_score_history.append(sent_score_history_for_doc_i)
            p_stop_history.append(p_stop_history_for_doc_i)
        
        scores = []
        for doc_i in range(batch_size):
            hyp = "\n".join(extracted_sents[doc_i]).strip()
            ref = "\n".join(gold_summary[doc_i]).strip()

            score = rouge_cal.score(hyp, ref)
            scores.append((
                score["rouge1"].fmeasure,
                score["rouge2"].fmeasure,
                score["rougeLsum"].fmeasure
            ))

        return scores

In [13]:
p_stop_threshold = 0.6
max_extracted_sentences_per_document = 13

def test_iteration(batch):
    with torch.no_grad():
        input_ids, attention_mask, doc_mask, sentence_num, sentences, gold_summary = batch
        
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        doc_mask = doc_mask.to(device)  # tensor[batch_size, max_sent_num]

        batch_size = len(sentence_num)

        local_sent_embed = local_sentence_encoder(input_ids, attention_mask, sentence_num)  # [batch_size, max_sent_num, emb_dim]
        global_context_embed = global_context_encoder(local_sent_embed, doc_mask, dropout_rate=0.1)  # [batch_size, max_sent_num, emb_dim]
        
        doc_mask_np = doc_mask.detach().cpu().numpy()  # np[batch_size, max_sent_num]
        
        extracted_sents = []
        extracted_sent_idxs = []
        
        sent_score_history = []
        p_stop_history = []
        
        for doc_i in range(batch_size):
            current_doc_mask_np = doc_mask_np[doc_i:doc_i + 1]
            current_remaining_mask_np = np.ones_like(current_doc_mask_np).astype(bool) | current_doc_mask_np
            current_extraction_mask_np = np.zeros_like(current_doc_mask_np).astype(bool) | current_doc_mask_np
            
            current_local_sent_embed = local_sent_embed[doc_i:doc_i + 1]
            current_global_context_embed = global_context_embed[doc_i:doc_i + 1]
            current_extraction_context_embed = None
            
            current_extracted_sent_idxs = []
            extracted_sent_ngrams = set()
            sent_score_history_for_doc_i = []
            p_stop_history_for_doc_i = []
            
            for step in range(max_extracted_sentences_per_document + 1):
                #gc.collect()
                #torch.cuda.empty_cache()
                
                current_extraction_mask = torch.from_numpy(current_extraction_mask_np).to(device)
                current_remaining_mask = torch.from_numpy(current_remaining_mask_np).to(device)
                
                if step > 0:
                    current_extraction_context_embed = extraction_context_decoder(
                        current_local_sent_embed,
                        current_remaining_mask,
                        current_extraction_mask
                    )
                    
                sent_scores, p_stop, _ = extractor(
                    current_local_sent_embed,
                    current_global_context_embed,
                    current_extraction_context_embed,
                    current_extraction_mask
                )
                
                p_stop = p_stop.item()
                
                p_stop_history_for_doc_i.append(p_stop)
                sent_score_history_for_doc_i.append(sent_scores.detach().cpu().numpy())
                
                sent_scores = sent_scores.masked_fill(current_extraction_mask, 1e-12)
                normalized_sent_scores = sent_scores / sent_scores.sum(dim=1, keepdims=True)
                
                is_summarization_over = (p_stop > p_stop_threshold)
                
                _, sorted_sent_idxs = normalized_sent_scores.sort(dim=1, descending=True)
                sorted_sent_idxs = sorted_sent_idxs[0]
                
                # skip if selected sentence's ngrams are already in extracted sentences' ngram set
                # (avoid redundancy)
                extracted = False
                for sent_idx in [x.item() for x in sorted_sent_idxs]:
                    if sent_idx >= sentence_num[doc_i]:
                        break
                    
                    selected_sent = sentences[doc_i][sent_idx]
                    selected_sent_ngrams = get_ngram(selected_sent, n=3)
                        
                    if len(selected_sent_ngrams & extracted_sent_ngrams) < 1:
                        extracted_sent_ngrams.update(selected_sent_ngrams)
                        extracted = True
                        break
                
                if not is_summarization_over and step < max_extracted_sentences_per_document and extracted:
                    current_extracted_sent_idxs.append(sent_idx)
                    current_extraction_mask_np[0, sent_idx] = True
                    current_remaining_mask_np[0, sent_idx] = False
                else:
                    extracted_sents.append([sentences[doc_i][sent_idx] for sent_idx in current_extracted_sent_idxs])
                    extracted_sent_idxs.append(current_extracted_sent_idxs)
                    break
                    
                    
            sent_score_history.append(sent_score_history_for_doc_i)
            p_stop_history.append(p_stop_history_for_doc_i)
        
        scores = []
        for doc_i in range(batch_size):
            hyp = "\n".join(extracted_sents[doc_i]).strip()
            ref = "\n".join(gold_summary[doc_i]).strip()

            score = rouge_cal.score(hyp, ref)
            scores.append((
                score["rouge1"].fmeasure,
                score["rouge2"].fmeasure,
                score["rougeLsum"].fmeasure
            ))

        return scores

In [15]:
# train
cur_time = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
#model_dir = os.path.join("model/memsum_sbert/tos-dr", dataset_mode, cur_time)
model_dir = os.path.join("model/memsum_sbert/tos-dr-from-pretrained", dataset_mode, cur_time)

epochs = 10  # 50
print_every = 100  # 100
validate_every = 200  # 100
early_stop = False  # False

print(f"model dir: {model_dir}")
print()

total_train_batch = 0
for epoch in range(epochs):
    print.info(f"[Epoch {epoch + 1}]")
    running_loss = 0
    total_train_elapsed_time = 0
    
    for i, train_batch in enumerate(train_dataloader):
        train_start_time = time.time()
        total_train_batch += 1
        
        loss = train_iteration(train_batch)
        running_loss += loss
        
        #scheduler.step()
        
        update_moving_average(global_context_encoder_ema, global_context_encoder)
        update_moving_average(extraction_context_decoder_ema, extraction_context_decoder)
        update_moving_average(extractor_ema, extractor)
        
        train_elapsed_time = time.time() - train_start_time
        total_train_elapsed_time += train_elapsed_time
        
        if total_train_batch % print_every == 0:
            print.info(f"  [Training Batch {i + 1} / {len(train_dataloader)} (total_batch = {total_train_batch})]")
            #current_lr = optimizer.param_groups[0]['lr']
            
            #(tensor_num, tensor_size), (param_num, param_size) = analyzeGPU()
            
            print.info(f"    training : loss = {running_loss / print_every : .3f}, elapsed time = (current) {print_time(train_elapsed_time)} / (mean) {print_time(total_train_elapsed_time / print_every)}")
            #print.info(f"    actual step = {actual_step}, # of cuda tensors = {tensor_num}, mem = {tensor_size}")
            running_loss = 0
            total_train_elapsed_time = 0
                
        if (total_train_batch % validate_every == 0) or (i == len(train_dataloader) - 1):
            val_start_time = time.time()
            val_scores = []
            for j, val_batch in enumerate(val_dataloader):
                print.info(f"    [Validation Batch {j + 1} / {len(val_dataloader)}]", end="\r")
                val_scores += validation_iteration(val_batch)
                    
            val_rouge1, val_rouge2, val_rougeL = list(zip(*val_scores))
            
            avg_val_rouge1 = np.mean(val_rouge1)
            avg_val_rouge2 = np.mean(val_rouge2)
            avg_val_rougeL = np.mean(val_rougeL)
            
            model_path = save_model(
                batch=total_train_batch, 
                scores={
                    "rouge1": avg_val_rouge1,
                    "rouge2": avg_val_rouge2,
                    "rougeL": avg_val_rougeL
                },
                model_dir=model_dir
            )
            
            val_elapsed_time = time.time() - val_start_time
            
            print.info(f"    validation : rouge1 = {avg_val_rouge1 : .4f}, rouge2 = {avg_val_rouge2 : .4f}, rougeL = {avg_val_rougeL : .4f}, time = {print_time(val_elapsed_time)} => {model_path}")
               
    if early_stop:
        break

model dir: model/memsum_sbert/tos-dr-from-pretrained/only-gold-summary/2022-12-15-08-33-53

[36m[Epoch 1][0m
[36m  [Training Batch 100 / 403 (total_batch = 100)][0m
[36m    training : loss =  4.384, elapsed time = (current) 0.973s / (mean) 1.025s[0m
[36m  [Training Batch 200 / 403 (total_batch = 200)][0m
[36m    training : loss =  4.231, elapsed time = (current) 1.418s / (mean) 0.982s[0m
[36m    validation : rouge1 =  0.3411, rouge2 =  0.1460, rougeL =  0.3218, time = 51.429s => model/memsum_sbert/tos-dr-from-pretrained/only-gold-summary/2022-12-15-08-33-53/000200.0.341-0.146-0.322.pt[0m
[36m  [Training Batch 300 / 403 (total_batch = 300)][0m
[36m    training : loss =  4.311, elapsed time = (current) 1.451s / (mean) 1.008s[0m
[36m  [Training Batch 400 / 403 (total_batch = 400)][0m
[36m    training : loss =  4.329, elapsed time = (current) 1.148s / (mean) 1.072s[0m
[36m    validation : rouge1 =  0.3493, rouge2 =  0.1687, rougeL =  0.3312, time = 51.872s => model/mem

In [16]:
# test
#ckpt_path = "model/memsum_sbert/tos-dr/only-gold-summary/2022-12-15-03-43-15/001612.0.419-0.267-0.406.pt"
#ckpt_path = "model/memsum_sbert/tos-dr/all/2022-12-15-06-21-02/002000.0.421-0.272-0.407.pt"
ckpt_path = "model/memsum_sbert/tos-dr-from-pretrained/only-gold-summary/2022-12-15-08-33-53/001600.0.401-0.242-0.386.pt"
load_model(ckpt_path, device=device)

test_scores = []
for i, batch in enumerate(test_dataloader):
    print.info(f"[Test Batch {i + 1} / {len(test_dataloader)}]", end="\r")
    current_test_scores = test_iteration(batch)

    test_scores += current_test_scores

test_rouge1, test_rouge2, test_rougeL = list(zip(*test_scores))

avg_test_rouge1 = np.mean(test_rouge1)
avg_test_rouge2 = np.mean(test_rouge2)
avg_test_rougeL = np.mean(test_rougeL)
print.info(f"test : rouge1 = {avg_test_rouge1 : .4f}, rouge2 = {avg_test_rouge2 : .4f}, rougeL = {avg_test_rougeL : .4f}")

[36mtest : rouge1 =  0.4178, rouge2 =  0.2631, rougeL =  0.4034[0m


# Old

In [7]:
batch = iter(train_dataloader).__next__()
input_ids, attention_mask, doc_mask, label, sentence_num, oracle_summary, oracle_summary_score = batch

input_ids, attention_mask, doc_mask = input_ids.to(device), attention_mask.to(device), doc_mask.to(device)

In [8]:
local_sen_embed = local_sentence_encoder(input_ids, attention_mask, sentence_num)  # [batch_size, max_sent_num, emb_dim]
global_context_embed = global_context_encoder(local_sen_embed, doc_mask, dropout_rate=0.1)  # [batch_size, max_sent_num, emb_dim]

In [9]:
doc_mask_np = doc_mask.detach().cpu().numpy()
remaining_mask_np = np.ones_like( doc_mask_np ).astype( bool ) | doc_mask_np
extraction_mask_np = np.zeros_like( doc_mask_np ).astype( bool ) | doc_mask_np

# True : masked(not visible), False : not masked(visible)

In [10]:
log_action_prob_list = []
log_stop_prob_list = []

done_list = []
extraction_context_embed = None

In [27]:
# valid_sen_idxs.shape[1] == max_sentence_num
for step in range(max_sentence_num):
    remaining_mask = torch.from_numpy(remaining_mask_np).to(device)
    extraction_mask = torch.from_numpy(extraction_mask_np).to(device)
    
    if step > 0: # if at least one sentence is selected
        extraction_context_embed = extraction_context_decoder(local_sen_embed, remaining_mask, extraction_mask, dropout_rate=0.1)
    
    sentence_scores, p_stop, baseline = extractor(local_sen_embed, global_context_embed, extraction_context_embed, extraction_mask, dropout_rate=0.1)
    # p : tensor[batch_size, max_sent_num]
    # Note. baseline is not used -> deletable?
    
    p_stop = p_stop.unsqueeze(1)
    m_stop = Categorical(torch.cat([1 - p_stop, p_stop], dim=1))
    
    # grep step-th summary sentence idx
    summary_sent_idxs = oracle_summary[:, step]  # tensor[batch_size]
    
    # find documents that all summary sentences are extracted(= padding is selected)
    is_summarization_over = (summary_sent_idxs == -1)  # tensor[batch_size]
    
    if len(done_list) > 0:
        # is_just_stop : 이번 step에 막 summarization over 한거면 true, 아니면 false
        is_just_stop = torch.logical_and(~done_list[-1], is_summarization_over)
    else:
        is_just_stop = is_summarization_over
    
    if torch.all(is_summarization_over) and not torch.any(just_stop): # 모든 doc들의 summarization이 끝나고 1 step 진행한 경우 break
        break
    
    sentence_scores = sentence_scores.masked_fill(extraction_mask, 1e-12)
    normalized_sentence_scores = sentence_scores / sentence_scores.sum(dim=1, keepdims=True)  # tensor[batch_size, max_sent_num]
    
    extracted_sentence_scores = normalized_sentence_scores[range(batch_size), summary_sent_idxs]  # tensor[batch_size]
    log_action_prob = extracted_sentence_scores.masked_fill(is_summarization_over, 1.0).log()
    
    log_stop_prob = m_stop.log_prob(is_summarization_over.to(torch.long))
    log_stop_prob = log_stop_prob.masked_fill(torch.logical_xor(is_summarization_over, is_just_stop), 0)
    
    log_action_prob_list.append(log_action_prob.unsqueeze(1))
    log_stop_prob_list.append(log_stop_prob.unsqueeze(1))
    done_list.append(is_summarization_over)
    
    for doc_idx, summary_sent_idx in enumerate(summary_sent_idxs.tolist()):
        if summary_sent_idx != -1:
            remaining_mask_np[doc_idx, summary_sent_idx] = False
            extraction_mask_np[doc_idx, summary_sent_idx] = True
    log_action
        
    
    break

In [41]:
sent_idxs.tolist()

[69, 153, 0, 48, 78, 56, 82, 119, 10, 67, 223, 225, 16, 5, 11, 28]

In [31]:
x

tensor([0.0051, 0.0039, 0.0227, 0.0069, 0.0083, 0.0086, 0.0020, 0.0027, 0.0050,
        0.0025, 0.0040, 0.0036, 0.0041, 0.0051, 0.0096, 0.0021],
       device='cuda:0', grad_fn=<IndexBackward0>)

In [18]:
extraction_mask_np

array([[False, False, False, ...,  True,  True,  True],
       [False, False, False, ...,  True,  True,  True],
       [False, False, False, ...,  True,  True,  True],
       ...,
       [False, False, False, ...,  True,  True,  True],
       [False, False, False, ...,  True,  True,  True],
       [False, False, False, ..., False, False, False]])

In [13]:
log_action_prob_list = []
log_stop_prob_list = []

done_list = []
extraction_context_embed = None

for step in range(valid_sen_idxs.shape[1]):
    remaining_mask = torch.from_numpy( remaining_mask_np ).to(device)
    extraction_mask = torch.from_numpy( extraction_mask_np ).to(device)
    if step > 0:
        extraction_context_embed = extraction_context_decoder( local_sen_embed, remaining_mask, extraction_mask, dropout_rate )
    p, p_stop, baseline = extractor( local_sen_embed, global_context_embed, extraction_context_embed , extraction_mask , dropout_rate )

    p_stop = p_stop.unsqueeze(1)
    m_stop = Categorical( torch.cat( [ 1-p_stop, p_stop  ], dim =1 ) )

    sen_indices = valid_sen_idxs[:, step]
    done = sen_indices == -1
    if len(done_list) > 0:
        done = torch.logical_or(done_list[-1], done)
        just_stop = torch.logical_and( ~done_list[-1], done )
    else:
        just_stop = done

    if torch.all( done ) and not torch.any(just_stop):
        break

    p = p.masked_fill( extraction_mask, 1e-12 )  
    normalized_p = p / p.sum(dim=1, keepdims = True)
    ## Here the sen_indices is actually pre-sampled action
    normalized_p = normalized_p[ np.arange( num_documents ), sen_indices ]
    log_action_prob = normalized_p.masked_fill( done, 1.0 ).log()

    log_stop_prob = m_stop.log_prob( done.to(torch.long)  )
    log_stop_prob = log_stop_prob.masked_fill( torch.logical_xor( done, just_stop ), 0.0 )

    log_action_prob_list.append( log_action_prob.unsqueeze(1) )
    log_stop_prob_list.append( log_stop_prob.unsqueeze(1) )
    done_list.append(done)

    for doc_i in range( num_documents ):
        sen_i = sen_indices[ doc_i ].item()
        if sen_i != -1:
            remaining_mask_np[doc_i,sen_i] = False
            extraction_mask_np[doc_i,sen_i] = True

In [11]:
local_sentence_embeddings.shape

torch.Size([3863, 768])

In [12]:
x = mean_pooling(lse_output, attention_mask)

In [15]:
x.shape

torch.Size([440, 768])

In [16]:
lse_output.pooler_output.shape

torch.Size([440, 768])

In [83]:
local_sentence_embedding.shape

(32, 768)

In [None]:
def train_teration(batch):
    sentences, label = batch
    label = label.to(device)
    
    batch_size = sentences.shape[0]
    sentence_num = sentences.shape[1]
    
    local_sentence_embedding = local_sentence_encoder.encode(
        batch[0]
    )

In [30]:
model = SentenceTransformer("all-mpnet-base-v2")

In [75]:
model

SentenceTransformer(
  (0): Transformer({'max_seq_length': 384, 'do_lower_case': False}) with Transformer model: MPNetModel 
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
  (2): Normalize()
)

In [62]:
train_dataset = ExtractionTrainingDataset(
    "data/gov-report/train_GOVREPORT.jsonl",
    tokenizer=model.tokenizer,
)

train_dataloader = DataLoader(train_dataset, batch_size=2, num_workers=3)

In [58]:
t

In [63]:
x = iter(dataloader).__next__()

In [65]:
x[1].shape

torch.Size([2, 500])

In [None]:
class MemSum:
    def __init__(
        self,
        num_heads=8,
        hidden_dim=1024,
        N_enc_l=2 ,
        N_enc_g=2,
        N_dec=3,
        max_seq_len=100,
        max_doc_len=500
    ):
        self.device = "cuda:0"
        
        self.local_sentence_encoder = SentenceTransformer("all-mpnet-base-v2").to(self.device)
        embed_dim = self.local_sentence_encoder.get_sentence_embedding_dimension()
        
        self.global_context_encoder = GlobalContextEncoder( embed_dim, num_heads, hidden_dim, N_enc_g ).to(self.device)
        
        self.extraction_context_decoder = GlobalContextEncoder( embed_dim, num_heads, hidden_dim, N_dec ).to(self.device)
        
        self.extractor = Extractor( embed_dim, num_heads ).to(self.device)
        
        """
        ckpt = torch.load( model_path, map_location = "cpu" )
        self.local_sentence_encoder.load_state_dict( ckpt["local_sentence_encoder"] )
        self.global_context_encoder.load_state_dict( ckpt["global_context_encoder"] )
        self.extraction_context_decoder.load_state_dict( ckpt["extraction_context_decoder"] )
        self.extractor.load_state_dict(ckpt["extractor"])
        """
        
        self.max_seq_len = max_seq_len
        self.max_doc_len = max_doc_len
    
    def extract( self, document_batch, p_stop_thres = 0.7, ngram_blocking = False, ngram = 3, return_sentence_position = False, return_sentence_score_history = False, max_extracted_sentences_per_document = 4 ):
        """document_batch is a batch of documents:
        [  [ sen1, sen2, ... , senL1 ], 
           [ sen1, sen2, ... , senL2], ...
         ]
        """
        extracted_sentences = []
        sentence_score_history = []
        p_stop_history = []
        
        with torch.no_grad():
            num_sentences = seqs.size(1)
            sen_embed  = self.local_sentence_encoder( seqs.view(-1, seqs.size(2) )  )
            sen_embed = sen_embed.view( -1, num_sentences, sen_embed.size(1) )
            relevance_embed = self.global_context_encoder( sen_embed, doc_mask  )
    
            num_documents = seqs.size(0)
            doc_mask = doc_mask.detach().cpu().numpy()
            seqs = seqs.detach().cpu().numpy()
    
            extracted_sentences = []
            extracted_sentences_positions = []
        
            for doc_i in range(num_documents):
                current_doc_mask = doc_mask[doc_i:doc_i+1]
                current_remaining_mask_np = np.ones_like(current_doc_mask ).astype(np.bool) | current_doc_mask
                current_extraction_mask_np = np.zeros_like(current_doc_mask).astype(np.bool) | current_doc_mask
        
                current_sen_embed = sen_embed[doc_i:doc_i+1]
                current_relevance_embed = relevance_embed[ doc_i:doc_i+1 ]
                current_redundancy_embed = None
        
                current_hyps = []
                extracted_sen_ngrams = set()

                sentence_score_history_for_doc_i = []

                p_stop_history_for_doc_i = []
                
                for step in range( max_extracted_sentences_per_document+1 ) :
                    current_extraction_mask = torch.from_numpy( current_extraction_mask_np ).to(self.device)
                    current_remaining_mask = torch.from_numpy( current_remaining_mask_np ).to(self.device)
                    if step > 0:
                        current_redundancy_embed = self.extraction_context_decoder( current_sen_embed, current_remaining_mask, current_extraction_mask  )
                    p, p_stop, _ = self.extractor( current_sen_embed, current_relevance_embed, current_redundancy_embed , current_extraction_mask  )
                    p_stop = p_stop.unsqueeze(1)
            
            
                    p = p.masked_fill( current_extraction_mask, 1e-12 ) 

                    sentence_score_history_for_doc_i.append( p.detach().cpu().numpy() )

                    p_stop_history_for_doc_i.append(  p_stop.squeeze(1).item() )

                    normalized_p = p / p.sum(dim=1, keepdims = True)

                    stop = p_stop.squeeze(1).item()> p_stop_thres #and step > 0
                    
                    #sen_i = normalized_p.argmax(dim=1)[0]
                    _, sorted_sen_indices =normalized_p.sort(dim=1, descending= True)
                    sorted_sen_indices = sorted_sen_indices[0]
                    
                    extracted = False
                    for sen_i in sorted_sen_indices:
                        sen_i = sen_i.item()
                        if sen_i< len(document_batch[doc_i]):
                            sen = document_batch[doc_i][sen_i]
                        else:
                            break
                        sen_ngrams = self.get_ngram( sen.lower().split(), ngram )
                        if not ngram_blocking or len( extracted_sen_ngrams &  sen_ngrams ) < 1:
                            extracted_sen_ngrams.update( sen_ngrams )
                            extracted = True
                            break
                                        
                    if stop or step == max_extracted_sentences_per_document or not extracted:
                        extracted_sentences.append( [ document_batch[doc_i][sen_i] for sen_i in  current_hyps if sen_i < len(document_batch[doc_i])    ] )
                        extracted_sentences_positions.append( [ sen_i for sen_i in  current_hyps if sen_i < len(document_batch[doc_i])  ]  )
                        break
                    else:
                        current_hyps.append(sen_i)
                        current_extraction_mask_np[0, sen_i] = True
                        current_remaining_mask_np[0, sen_i] = False

                sentence_score_history.append(sentence_score_history_for_doc_i)
                p_stop_history.append( p_stop_history_for_doc_i )

        # if return_sentence_position:
        #     return extracted_sentences, extracted_sentences_positions 
        # else:
        #     return extracted_sentences

        results = [extracted_sentences]
        if return_sentence_position:
            results.append( extracted_sentences_positions )
        if return_sentence_score_history:
            results+=[sentence_score_history , p_stop_history ]
        if len(results) == 1:
            results = results[0]
        
        return results



class ExtractiveSummarizer_NeuSum:
    def __init__( self, model_path, vocabulary_path, gpu = None , embed_dim=200,
                 max_seq_len =100, max_doc_len = 500 , **kwargs ):
        with open( vocabulary_path , "rb" ) as f:
            words = pickle.load(f)
        self.vocab = Vocab_NeuSum( words )
        vocab_size = len(words)
        self.local_sentence_encoder = LocalSentenceEncoder_NeuSum( vocab_size, self.vocab.pad_index, embed_dim, None )
        self.global_context_encoder = GlobalContextEncoder_NeuSum( embed_dim)
        self.extraction_context_decoder = ExtractionContextDecoder_NeuSum( embed_dim)
        self.extractor = Extractor_NeuSum( embed_dim )
        ckpt = torch.load( model_path, map_location = "cpu" )
        self.local_sentence_encoder.load_state_dict( ckpt["local_sentence_encoder"] )
        self.global_context_encoder.load_state_dict( ckpt["global_context_encoder"] )
        self.extraction_context_decoder.load_state_dict( ckpt["extraction_context_decoder"] )
        self.extractor.load_state_dict(ckpt["extractor"])
        
        self.device =  torch.device( "cuda:%d"%(gpu) if gpu is not None and torch.cuda.is_available() else "cpu"  )        
        self.local_sentence_encoder.to(self.device)
        self.global_context_encoder.to(self.device)
        self.extraction_context_decoder.to(self.device)
        self.extractor.to(self.device)
        
        self.sentence_tokenizer = SentenceTokenizer_NeuSum()
        self.max_seq_len = max_seq_len
        self.max_doc_len = max_doc_len
    

    def extract( self, document_batch, return_sentence_position = False, max_extracted_sentences_per_document = 7, **kwargs ):
        """document_batch is a batch of documents:
        [  [ sen1, sen2, ... , senL1 ], 
           [ sen1, sen2, ... , senL2], ...
         ]
        """
        ## tokenization:
        document_length_list = []
        sentence_length_list = []
        tokenized_document_batch = []
        for document in document_batch:
            tokenized_document = []
            for sen in document:
                tokenized_sen = self.sentence_tokenizer.tokenize( sen )
                tokenized_document.append( tokenized_sen )
                sentence_length_list.append( len(tokenized_sen.split()) )
            tokenized_document_batch.append( tokenized_document )
            document_length_list.append( len(tokenized_document) )

        max_document_length =  self.max_doc_len 
        max_sentence_length =  self.max_seq_len 
        ## convert to sequence
        seqs = []
        doc_mask = []
        
        for document in tokenized_document_batch:
            if len(document) > max_document_length:
                # doc_mask.append(  [0] * max_document_length )
                document = document[:max_document_length]
            else:
                # doc_mask.append(  [0] * len(document) +[1] * ( max_document_length -  len(document) ) )
                document = document + [""] * ( max_document_length -  len(document) )

            doc_mask.append(  [ 1 if sen.strip() == "" else 0 for sen in  document   ] )

            document_sequences = []
            for sen in document:
                seq = self.vocab.sent2seq( sen, max_sentence_length )
                document_sequences.append(seq)
            seqs.append(document_sequences)
        seqs = np.asarray(seqs)
        doc_mask = np.asarray(doc_mask) == 1
        seqs = torch.from_numpy(seqs).to(self.device)
        doc_mask = torch.from_numpy(doc_mask).to(self.device)
        
        with torch.no_grad():
            num_sentences = seqs.size(1)
            sen_embed  = self.local_sentence_encoder( seqs.view(-1, seqs.size(2) )  )
            sen_embed = sen_embed.view( -1, num_sentences, sen_embed.size(1) )
            global_context_embed, backward_state = self.global_context_encoder( sen_embed, doc_mask,  return_backward_state = True  )
    
            num_documents = seqs.size(0)
            doc_mask = doc_mask.detach().cpu().numpy()
            seqs = seqs.detach().cpu().numpy()
    
            extracted_sentences = []
            extracted_sentences_positions = []
        
            for doc_i in range(num_documents):
                current_doc_mask = doc_mask[doc_i:doc_i+1]
                current_remaining_mask_np = np.ones_like(current_doc_mask ).astype(np.bool) | current_doc_mask
                current_extraction_mask_np = np.zeros_like(current_doc_mask).astype(np.bool) | current_doc_mask
        
                current_global_context_embed = global_context_embed[doc_i:doc_i+1]
                current_hidden_state = backward_state[ doc_i:doc_i+1 ]
                current_extracted_sen_embed =  torch.zeros_like(  current_global_context_embed[:,:1,:] )
        
                current_hyps = []

                for step in range( max_extracted_sentences_per_document ) :
                    current_extraction_mask = torch.from_numpy( current_extraction_mask_np ).to(self.device)
                    current_remaining_mask = torch.from_numpy( current_remaining_mask_np ).to(self.device)
                    current_hidden_state  = self.extraction_context_decoder( current_extracted_sen_embed, current_hidden_state )
                    p = self.extractor( current_global_context_embed, current_hidden_state, current_extraction_mask )
                                                            
                    sen_i = p.argmax(dim=1)[0]
                    sen_i = sen_i.item()
                    
                    current_hyps.append(sen_i)
                    current_extraction_mask_np[0, sen_i] = True
                    current_remaining_mask_np[0, sen_i] = False
                    
                                        
                extracted_sentences.append( [ document_batch[doc_i][sen_i] for sen_i in  current_hyps if sen_i < len(document_batch[doc_i])    ] )
                extracted_sentences_positions.append( [ sen_i for sen_i in  current_hyps if sen_i < len(document_batch[doc_i])  ]  )


        results = [extracted_sentences]
        if return_sentence_position:
            results.append( extracted_sentences_positions )
        if len(results) == 1:
            results = results[0]
        return results