# Train model (or load pre-trained weights)

In [1]:
# Diagnostic only: print the installed NLTK package details.
!pip show nltk

In [2]:
import re
import nltk
from nltk.tokenize import word_tokenize, sent_tokenize
# Tokenizer models needed by sent_tokenize / word_tokenize.
# NLTK >= 3.9 loads the 'punkt_tab' resource instead of the pickled 'punkt' model, so
# download both to keep the notebook working on old and new NLTK releases.
# nltk.download reports an unavailable id by printing an error and returning False rather than raising.
nltk.download('punkt')
nltk.download('punkt_tab')

In [3]:
from torchtext.vocab import FastText
# Pre-trained FastText word vectors for the 'simple' language code (fetched/cached by torchtext).
# NOTE(review): Model.embedding_dim is hard-coded to 300, so these vectors are assumed to be 300-d — confirm.
embedding = FastText('simple')

In [4]:
import numpy as np

In [5]:
import torch
# Compute device used for training: the GPU when one is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [6]:
from collections import Counter
import re

class SherlockDataset(torch.utils.data.Dataset):
    """Sliding-window next-word dataset built from a plain-text corpus.

    Each item is a window of ``sequence_length`` consecutive tokens together with the
    same window shifted one token to the right (the next-word targets).

    Args:
        text_path: Path to the corpus text file (read as UTF-8).
        word_embedding: Object indexable by a token string; each lookup is stacked with
            ``torch.stack`` (presumably a 1-D vector per token, e.g. torchtext FastText).
        sequence_length: Number of tokens per input window.
    """

    def __init__(self, text_path, word_embedding, sequence_length):
        self.word_embedding = word_embedding
        self.sequence_length = sequence_length
        self.words = self.load_words(text_path)
        self.uniq_words = self.get_uniq_words()

        # Vocabulary ids are frequency ranks: id 0 is the most frequent token.
        self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}
        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}

        self.words_indexes = [self.word_to_index[w] for w in self.words]

    def load_words(self, text_path):
        """Read, sentence-split and word-tokenize the corpus into one flat token list.

        Text is lower-cased, an ``'</s>'`` marker is appended after every sentence,
        and tokens containing no word character (pure punctuation) are dropped.
        """
        # Explicit encoding so the result does not depend on the platform's locale.
        with open(text_path, 'r', encoding='utf-8') as f:
            text = f.read()

        word_list = []
        for sent in sent_tokenize(text.strip()):
            tokenized_sent = word_tokenize(sent.lower())
            # End-of-sentence marker; it survives the filter below because 's' matches \w.
            tokenized_sent.append('</s>')
            word_list.extend(word for word in tokenized_sent if re.search(r'\w+', word) is not None)
        return word_list

    def get_uniq_words(self):
        """Return the distinct tokens ordered by descending corpus frequency.

        Ties keep first-occurrence order (Counter preserves insertion order and
        ``sorted`` stays stable with ``reverse=True``).
        """
        word_counts = Counter(self.words)
        return sorted(word_counts, key=word_counts.get, reverse=True)

    def __len__(self):
        """Number of complete (input, target) windows; 0 if the corpus is too short."""
        # Each window needs sequence_length + 1 tokens (inputs plus the shifted targets).
        # Clamp at 0: len() raises ValueError when __len__ returns a negative number.
        return max(0, len(self.words_indexes) - self.sequence_length)

    def __getitem__(self, index):
        """Return ``(x, x_embed, y)`` for the window starting at ``index``.

        ``x`` is a LongTensor of ``sequence_length`` token ids, ``x_embed`` stacks the
        embedding of each of those tokens along a new first dimension, and ``y`` is a
        LongTensor of the ids shifted one position right (the next-word targets).
        """
        end = index + self.sequence_length
        input_text = self.words[index:end]
        x = torch.LongTensor(self.words_indexes[index:end])
        x_embed = torch.stack([self.word_embedding[word] for word in input_text])
        y = torch.LongTensor(self.words_indexes[index + 1:end + 1])
        return (x, x_embed, y)

In [7]:
# Training corpus: 10-token input windows over the Sherlock Holmes training text.
train_dataset = SherlockDataset('../input/sherlock-holmes/train_Sherlock.txt', embedding, 10)

In [8]:
# Vocabulary size: number of distinct tokens, including the '</s>' sentence marker.
len(train_dataset.uniq_words)

In [9]:
# Inspect the first training example: (token ids, stacked embeddings, next-token targets).
train_dataset[0]

In [10]:
from torch import nn

class Model(nn.Module):
    """Stacked LSTM language model over pre-computed word embeddings.

    300-d input embeddings pass through a 3-layer LSTM (hidden size 128, inter-layer
    dropout 0.1) and a linear head producing one logit per word in ``dataset.uniq_words``.

    NOTE(review): ``nn.LSTM`` is built without ``batch_first=True``, so PyTorch treats
    input dim 0 as time and dim 1 as batch. The training DataLoader in this notebook
    yields ``(batch, seq, emb)`` and the decoders pass ``(1, n_words, emb)``, so dim 0
    is being consumed as the time axis and ``init_state``'s argument is effectively the
    LSTM batch size. Changing this needs retraining and updating every caller.
    """

    def __init__(self, dataset):
        super().__init__()
        self.lstm_size = 128
        self.embedding_dim = 300
        self.num_layers = 3
        self.dataset = dataset

        self.lstm = nn.LSTM(
            input_size=self.embedding_dim,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            dropout=0.1,
        )
        self.fc = nn.Linear(self.lstm_size, len(dataset.uniq_words))

    def forward(self, embed, prev_state):
        """Run the LSTM over ``embed`` from ``prev_state``; return ``(logits, new_state)``."""
        lstm_out, new_state = self.lstm(embed, prev_state)
        return self.fc(lstm_out), new_state

    def init_state(self, sequence_length):
        """Return zero ``(h0, c0)``, each of shape ``(num_layers, sequence_length, lstm_size)``."""
        shape = (self.num_layers, sequence_length, self.lstm_size)
        return torch.zeros(shape), torch.zeros(shape)

In [11]:
# Build the model on `device` and load pre-trained weights instead of training from scratch.
# map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only machine; load_state_dict
# then copies the tensors into the model's existing parameters.
# NOTE: fc's output size comes from train_dataset's vocabulary, so the corpus and tokenization
# must match the ones the checkpoint was trained with.
model = Model(train_dataset).to(device)
# Alternative checkpoints, kept for reference:
# model.load_state_dict(torch.load('../input/lstm-sherlock-state/lstm_state.pth'))
model.load_state_dict(torch.load('../input/lstmsherlockstate2/lstm_state.pth', map_location='cpu'))
# model.load_state_dict(torch.load('../input/lstm-sherlock-torchtext-500-epochs/lstm_state.pth'))

In [12]:
import torch
import numpy as np
from torch import nn, optim
from torch.utils.data import DataLoader

def train(dataset, model, batch_size, max_epochs, sequence_length,
          lr=0.001, checkpoint_dir='/kaggle/working', device=None):
    """Train ``model`` on ``dataset`` with Adam and cross-entropy, checkpointing every epoch.

    The recurrent state is carried across consecutive batches (the DataLoader is not
    shuffled) and detached after each step, i.e. truncated backpropagation through time.
    After every epoch the model and optimizer ``state_dict``s are written to
    ``<checkpoint_dir>/lstm_state.pth`` and ``<checkpoint_dir>/lstm_optim.pth``.

    Args:
        dataset: yields ``(x, x_embed, y)``; only ``x_embed`` (model input) and ``y``
            (next-token targets) are used.
        model: exposes ``init_state(n)`` and ``forward(embed, state) -> (logits, state)``.
        batch_size: DataLoader batch size.
        max_epochs: number of passes over ``dataset``.
        sequence_length: argument passed to ``model.init_state``.
        lr: Adam learning rate.
        checkpoint_dir: directory receiving the per-epoch checkpoints.
        device: where batches and state are placed; defaults to the device of the
            model's parameters.
    """
    import os

    if device is None:
        device = next(model.parameters()).device
    model.train()

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(max_epochs):
        state_h, state_c = model.init_state(sequence_length)
        state_h, state_c = state_h.to(device), state_c.to(device)

        for batch, (_, x_embed, y) in enumerate(dataloader):
            optimizer.zero_grad()

            # Token ids are not fed to the model, so only embeddings and targets are moved.
            x_embed, y = x_embed.to(device), y.to(device)
            y_pred, (state_h, state_c) = model(x_embed, (state_h, state_c))
            # CrossEntropyLoss expects the class (vocabulary) dimension second.
            loss = criterion(y_pred.transpose(1, 2), y)

            # Keep the state values but cut the autograd graph between batches.
            state_h, state_c = state_h.detach(), state_c.detach()

            loss.backward()
            optimizer.step()
            if (batch + 1) % 50 == 0:
                print({ 'epoch': epoch, 'batch': batch, 'loss': loss.item() })

        torch.save(model.state_dict(), os.path.join(checkpoint_dir, 'lstm_state.pth'))
        torch.save(optimizer.state_dict(), os.path.join(checkpoint_dir, 'lstm_optim.pth'))

In [13]:
# train(dataset=train_dataset,
#       model=model,
#       batch_size=256,
#       max_epochs=500,
#       sequence_length=10)

In [14]:
# torch.save(model, "/kaggle/working/lstm.pth")

# Decoding

In [15]:
def custom_tokenize(prompt):
    """Lower-case and word-tokenize ``prompt``, dropping tokens with no word character.

    Mirrors the corpus preprocessing, except that no '</s>' marker is appended.
    """
    tokens = word_tokenize(prompt.lower())
    kept = []
    for token in tokens:
        # Pure-punctuation tokens contain no \w character and are skipped.
        if re.search(r'\w+', token) is None:
            continue
        kept.append(token)
    return kept

## Search-based decoding: greedy and beam search

In [92]:
def greedy_search_predict(dataset, embedding, model, prompt, next_words=20):
    """Extend ``prompt`` by ``next_words`` tokens, always taking the arg-max word.

    Args:
        dataset: provides ``index_to_word`` to map predicted ids back to words.
        embedding: token -> vector lookup with a ``dim`` attribute.
        model: trained model; switched to eval mode by this call.
        prompt: seed text, tokenized with ``custom_tokenize``.
        next_words: number of tokens to generate.

    Returns:
        The prompt tokens followed by the generated tokens, as a list of strings.

    Raises:
        ValueError: if ``prompt`` contains no word tokens.
    """
    model.eval()
    words = custom_tokenize(prompt)
    if not words:
        raise ValueError(f'prompt {prompt!r} contains no word tokens to condition on')
    window = len(words)
    # Run wherever the model lives (it may still be on CUDA after training).
    model_device = next(model.parameters()).device

    state_h, state_c = model.init_state(window)
    state_h, state_c = state_h.to(model_device), state_c.to(model_device)
    # Decoding needs no gradients; without no_grad the graph grows through the carried state.
    with torch.no_grad():
        for i in range(next_words):
            # words grows by one per step, so words[i:] is always the latest `window` words.
            x = torch.stack([embedding[word] for word in words[i:]])
            x = x.reshape((1, window, embedding.dim)).to(model_device)
            y_pred, (state_h, state_c) = model(x, (state_h, state_c))

            last_word_logits = y_pred[0][-1]
            word_index = int(torch.argmax(last_word_logits))
            words.append(dataset.index_to_word[word_index])
    return words

In [94]:
# Greedy continuation of the prompt 'logic' (prompt tokens followed by 20 generated tokens).
result = greedy_search_predict(train_dataset, embedding, model, 'logic')
print(result)

In [53]:
def beam_search_predict(dataset, embedding, model, prompt, beam_width, next_words=20):
    """Extend ``prompt`` with beam search and return the ``beam_width`` best sentences.

    A beam's score is the sum of the log-probabilities of its generated tokens, i.e. the
    log of the sequence probability; the result is ordered best first.

    Args:
        dataset: provides ``index_to_word`` to map predicted ids back to words.
        embedding: token -> vector lookup with a ``dim`` attribute.
        model: trained model; switched to eval mode by this call.
        prompt: seed text, tokenized with ``custom_tokenize``.
        beam_width: number of hypotheses kept after every step.
        next_words: number of tokens to generate.

    Returns:
        A list of up to ``beam_width`` token lists (prompt tokens + generated tokens).

    Raises:
        ValueError: if ``prompt`` contains no word tokens.
    """
    model.eval()
    words = custom_tokenize(prompt)
    if not words:
        raise ValueError(f'prompt {prompt!r} contains no word tokens to condition on')
    window = len(words)
    model_device = next(model.parameters()).device

    state_h, state_c = model.init_state(window)
    beams = [{'sentence': words,
              'sentence_embedding': [embedding[word] for word in words],
              'score': 0.0,
              'hidden_state': (state_h.to(model_device), state_c.to(model_device))}]
    with torch.no_grad():
        for i in range(next_words):
            all_candidates = []
            for beam in beams:
                # Each beam grows by one token per step, so [i:] is always its last `window` tokens.
                x = torch.stack(beam['sentence_embedding'][i:])
                x = x.reshape((1, window, embedding.dim)).to(model_device)
                y_pred, new_state = model(x, beam['hidden_state'])

                # Sum log-probabilities, not probabilities: the beam objective is the
                # product of per-token probabilities.
                log_probs = torch.log_softmax(y_pred[0][-1], dim=0)
                # Only a beam's own top `beam_width` extensions can reach the overall top
                # `beam_width`, so there is no need to expand the whole vocabulary.
                top_log_probs, top_indices = torch.topk(log_probs, min(beam_width, log_probs.numel()))
                for log_prob, word_index in zip(top_log_probs.tolist(), top_indices.tolist()):
                    next_word = dataset.index_to_word[word_index]
                    all_candidates.append({
                        'sentence': beam['sentence'] + [next_word],
                        'sentence_embedding': beam['sentence_embedding'] + [embedding[next_word]],
                        'score': beam['score'] + log_prob,
                        'hidden_state': new_state,
                    })
            # Keep the best `beam_width` hypotheses across all beams.
            all_candidates.sort(key=lambda candidate: candidate['score'], reverse=True)
            beams = all_candidates[:beam_width]

    return [beam['sentence'] for beam in beams]

In [55]:
# Decode on the CPU. This reassigns the global `model`, so every later cell also runs on the CPU.
model = model.to('cpu')
# Top-5 beams for the prompt 'logic', best first.
result = beam_search_predict(train_dataset, embedding, model, 'logic', 5)
for sentence in result:
    print(' '.join(sentence))

## Sampling-based decoding: temperature, top-k and nucleus sampling

In [69]:
def random_sampling(logits, temperature):
    """Sample a vocabulary index from ``softmax(logits / temperature)``.

    Lower temperatures sharpen the distribution and higher ones flatten it. Draws use
    NumPy's global RNG, so ``np.random.seed`` makes them reproducible.

    Args:
        logits: 1-D tensor of unnormalised scores over the vocabulary (any device).
        temperature: strictly positive scaling factor.

    Returns:
        The sampled index (a NumPy integer).

    Raises:
        ValueError: if ``temperature`` is not positive (a negative value would silently
            invert the distribution).
    """
    if temperature <= 0:
        raise ValueError(f'temperature must be > 0, got {temperature}')
    scaled = logits.detach() / temperature
    # .cpu() so logits produced by a CUDA model can be converted to NumPy.
    word_prob = torch.nn.functional.softmax(scaled, dim=0).cpu().numpy()
    return np.random.choice(len(word_prob), p=word_prob)

def topk_sampling(logits, k):
    """Sample a vocabulary index from the ``k`` most probable entries of ``softmax(logits)``.

    Probabilities outside the top ``k`` are zeroed and the rest renormalised to sum to 1.
    Draws use NumPy's global RNG. A ``k`` larger than the vocabulary keeps every entry.

    Args:
        logits: 1-D tensor of unnormalised scores over the vocabulary (any device).
        k: number of candidates to keep; must be >= 1.

    Returns:
        The sampled index (a NumPy integer).

    Raises:
        ValueError: if ``k`` < 1 (previously ``k=0`` silently kept the whole vocabulary).
    """
    if k < 1:
        raise ValueError(f'k must be >= 1, got {k}')
    # float32 softmax, then float64 so the renormalised probabilities sum to 1 precisely.
    word_prob = torch.nn.functional.softmax(logits.detach(), dim=0).cpu().numpy().astype(np.float64)
    # np.argpartition rejects a kth beyond the array length.
    k = min(k, len(word_prob))
    k_largest_index = np.argpartition(word_prob, -k)[-k:]
    # Keep only the k largest probabilities.
    mask = np.zeros(word_prob.shape)
    mask[k_largest_index] = 1
    word_prob = word_prob * mask
    word_prob = word_prob / word_prob.sum()
    return np.random.choice(len(word_prob), p=word_prob)

def nucleus_sampling(logits, p):
    """Sample a vocabulary index with nucleus (top-p) sampling.

    Candidates are ranked by probability and the smallest prefix whose cumulative
    probability exceeds ``p`` is kept (always at least the single most likely token).
    The kept probabilities are renormalised and a token is drawn with NumPy's global RNG.

    All masking happens in *sorted* order, and the drawn rank is mapped back to a
    vocabulary index through ``sorted_indices``. The previous version zeroed the sorted
    array at original-vocabulary positions and returned the rank itself as the word index.

    Args:
        logits: 1-D tensor of unnormalised scores over the vocabulary (any device).
        p: cumulative-probability threshold, typically in (0, 1].

    Returns:
        The sampled vocabulary index as an ``int``.
    """
    sorted_logits, sorted_indices = torch.sort(logits.detach().cpu(), descending=True)
    sorted_prob = torch.nn.functional.softmax(sorted_logits, dim=0).numpy().astype(np.float64)
    cumulative_probs = np.cumsum(sorted_prob)

    # Drop ranks whose cumulative probability is above the threshold...
    remove = cumulative_probs > p
    # ...shifted right by one so the token that crosses the threshold is kept too.
    remove[1:] = remove[:-1].copy()
    remove[0] = False
    sorted_prob[remove] = 0.0

    sorted_prob = sorted_prob / sorted_prob.sum()
    rank = np.random.choice(len(sorted_prob), p=sorted_prob)
    return int(sorted_indices[rank])

def sample_predict(dataset, embedding, model, prompt, mode, temperature=1, k=5, p=0.9, next_words=20):
    """Extend ``prompt`` by ``next_words`` tokens drawn with a stochastic decoder.

    Args:
        dataset: provides ``index_to_word`` to map sampled ids back to words.
        embedding: token -> vector lookup with a ``dim`` attribute.
        model: trained model; switched to eval mode by this call.
        prompt: seed text, tokenized with ``custom_tokenize``.
        mode: ``'random'`` (temperature sampling), ``'top-k'`` or ``'nucleus'``.
        temperature: used by ``'random'`` mode.
        k: used by ``'top-k'`` mode.
        p: used by ``'nucleus'`` mode.
        next_words: number of tokens to generate.

    Returns:
        The prompt tokens followed by the generated tokens, as a list of strings.

    Raises:
        ValueError: on an unknown ``mode`` or a prompt with no word tokens.
    """
    samplers = {
        'random': lambda logits: random_sampling(logits, temperature),
        'top-k': lambda logits: topk_sampling(logits, k),
        'nucleus': lambda logits: nucleus_sampling(logits, p),
    }
    # An explicit check instead of `assert`, which is stripped under `python -O`.
    if mode not in samplers:
        raise ValueError(f"mode must be one of ['random', 'top-k', 'nucleus'], got {mode!r}")
    sample_next = samplers[mode]

    model.eval()
    words = custom_tokenize(prompt)
    if not words:
        raise ValueError(f'prompt {prompt!r} contains no word tokens to condition on')
    window = len(words)
    model_device = next(model.parameters()).device

    state_h, state_c = model.init_state(window)
    state_h, state_c = state_h.to(model_device), state_c.to(model_device)
    with torch.no_grad():
        for i in range(next_words):
            # words grows by one per step, so words[i:] is always the latest `window` words.
            x = torch.stack([embedding[word] for word in words[i:]])
            x = x.reshape((1, window, embedding.dim)).to(model_device)
            y_pred, (state_h, state_c) = model(x, (state_h, state_c))

            # The samplers convert to NumPy, so hand them CPU logits.
            last_word_logits = y_pred[0][-1].cpu()
            word_index = sample_next(last_word_logits)
            words.append(dataset.index_to_word[int(word_index)])

    return words

# print(' '.join(sample_predict(train_dataset, embedding, model, 'I could', 'random', temperature=0.9)))
# print(' '.join(sample_predict(train_dataset, embedding, model, 'I could', 'top-k', k=10)))
# print(' '.join(sample_predict(train_dataset, embedding, model, 'I could', 'nucleus', p=0.7)))

In [70]:
# Seed NumPy's global RNG (used by all three samplers) so these demo outputs are reproducible.
np.random.seed(0)
print(' '.join(sample_predict(train_dataset, embedding, model, 'I could', 'random', temperature=0.9)))
print(' '.join(sample_predict(train_dataset, embedding, model, 'I could', 'top-k', k=40)))
print(' '.join(sample_predict(train_dataset, embedding, model, 'I could', 'nucleus', p=0.9)))

# Experiment: repetition rate of each decoding strategy on test prompts

In [20]:
class PromptDataset(torch.utils.data.Dataset):
    """(prompt, reference) pairs taken from consecutive sentences of a text file.

    Each item is ``(prompt, real)``. ``real`` is a sentence lower-cased with
    punctuation-only tokens removed, and ``prompt`` is its first ``prompt_length`` words.
    By default sentences ``[1:201]`` are used (the first sentence is skipped).

    Args:
        text_path: Path to the text file (read as UTF-8).
        prompt_length: Number of leading words used as the prompt.
        start: Index of the first sentence to use.
        num_sentences: Maximum number of sentences to use.
    """

    def __init__(self, text_path, prompt_length, start=1, num_sentences=200):
        self.prompts = self.load_sentences(text_path, prompt_length, start, num_sentences)

    def load_sentences(self, text_path, prompt_length=2, start=1, num_sentences=200):
        """Return the list of ``(prompt, real)`` string pairs for the selected sentences."""
        # Explicit encoding so the result does not depend on the platform's locale.
        with open(text_path, 'r', encoding='utf-8') as f:
            text = f.read()

        prompts = []
        for sent in sent_tokenize(text.strip())[start:start + num_sentences]:
            words = [word for word in word_tokenize(sent.lower()) if re.search(r'\w+', word) is not None]
            # A punctuation-only sentence would give an empty prompt, and the decoders
            # cannot condition on zero tokens (torch.stack of an empty list fails).
            if not words:
                continue
            prompts.append((' '.join(words[:prompt_length]), ' '.join(words)))
        return prompts

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        return self.prompts[idx]

In [21]:
# Up to 200 test prompts: the first 2 words of each test sentence, paired with the full sentence.
# batch_size=1 makes the default collate wrap each string in a 1-element list, hence the
# sample[0][0] / sample[1][0] unpacking in the experiment cells.
test_data = PromptDataset('../input/sherlock-holmes/test_Sherlock.txt', 2)
test_loader = DataLoader(test_data, batch_size=1, shuffle=False)

In [28]:
from collections import Counter

def count_repetition_percent(generated_strings):
    """Fraction (0..1, despite the name) of whitespace-separated tokens that repeat an earlier token.

    Computed as ``(n_tokens - n_distinct_tokens) / n_tokens``; e.g. ``'a b a'`` -> 1/3.
    Returns 0.0 for a string with no tokens instead of dividing by zero.
    """
    words = generated_strings.split()
    if not words:
        return 0.0
    return (len(words) - len(set(words))) / len(words)
    

In [106]:
def count_repetition_percent_modified(generated_strings, ignore_token='</s>'):
    """Repetition fraction of the tokens in ``generated_strings``, ignoring ``ignore_token``.

    Every occurrence of ``ignore_token`` (by default the ``'</s>'`` sentence marker) is
    dropped first, so several sentences in one generation do not count as repetitions.
    Returns the fraction (0..1, not a percentage) of the remaining whitespace-separated
    tokens that repeat an earlier one: ``(n_tokens - n_distinct_tokens) / n_tokens``.
    Returns 0.0 when no tokens remain instead of dividing by zero.
    """
    words = [word for word in generated_strings.split() if word != ignore_token]
    if not words:
        return 0.0
    return (len(words) - len(set(words))) / len(words)

In [115]:
# Greedy decoding: average repetition rate over the test prompts ('</s>' markers excluded).
total_repeated_percent = 0
for idx, sample in enumerate(test_loader):
    prompt = sample[0][0]
    # Reference sentence; not used by the repetition metric.
    real = sample[1][0]
    greedy_result = ' '.join(greedy_search_predict(train_dataset, embedding, model, prompt))
    # repeated_percent = count_repetition_percent(greedy_result)
    print(greedy_result)
    repeated_percent = count_repetition_percent_modified(greedy_result)
    total_repeated_percent += repeated_percent
# batch_size=1, so len(test_loader) is the number of prompts.
avg_repeated_percent = total_repeated_percent / len(test_loader)
print(avg_repeated_percent)

In [99]:
# Beam search (width 5), keeping only the best beam per prompt.
# Scored with count_repetition_percent_modified, the same metric as the greedy and sampling
# experiments, so '</s>' sentence markers are not counted as repetitions and the numbers
# are comparable across decoding strategies.
total_repeated_percent = 0
for idx, sample in enumerate(test_loader):
    prompt = sample[0][0]
    real = sample[1][0]
    beam_result = ' '.join(beam_search_predict(train_dataset, embedding, model, prompt, 5)[0])
    print(beam_result)
    repeated_percent = count_repetition_percent_modified(beam_result)
    total_repeated_percent += repeated_percent
# batch_size=1, so len(test_loader) is the number of prompts.
avg_repeated_percent = total_repeated_percent / len(test_loader)

In [100]:
# Average repetition rate of the beam-search results computed in the previous cell.
print(avg_repeated_percent)

In [117]:
# Temperature sampling (T=1): average repetition rate ('</s>' markers excluded).
# Seeding NumPy's global RNG makes the sampled continuations reproducible.
np.random.seed(0)
total_repeated_percent = 0
for idx, sample in enumerate(test_loader):
    prompt = sample[0][0]
    real = sample[1][0]
    random_result = ' '.join(sample_predict(train_dataset, embedding, model, prompt, 'random', temperature=1))
    # repeated_percent = count_repetition_percent(random_result)
    print(random_result)
    repeated_percent = count_repetition_percent_modified(random_result)
    total_repeated_percent += repeated_percent
avg_repeated_percent = total_repeated_percent / len(test_loader)
print(avg_repeated_percent)

In [109]:
# Temperature sampling (T=0.9): average repetition rate ('</s>' markers excluded).
# Same seed as the other sampling cells so they start from the same RNG stream.
np.random.seed(0)
total_repeated_percent = 0
for idx, sample in enumerate(test_loader):
    prompt = sample[0][0]
    real = sample[1][0]
    random_result2 = ' '.join(sample_predict(train_dataset, embedding, model, prompt, 'random', temperature=0.9))
#     repeated_percent = count_repetition_percent(random_result2)
    repeated_percent = count_repetition_percent_modified(random_result2)
    total_repeated_percent += repeated_percent
avg_repeated_percent = total_repeated_percent / len(test_loader)
print(avg_repeated_percent)

In [120]:
# Top-k sampling (k=40): average repetition rate ('</s>' markers excluded).
np.random.seed(0)
total_repeated_percent = 0
for idx, sample in enumerate(test_loader):
    prompt = sample[0][0]
    real = sample[1][0]
    topk_result = ' '.join(sample_predict(train_dataset, embedding, model, prompt, 'top-k', k=40))
#     repeated_percent = count_repetition_percent(topk_result)
    print(topk_result)
    repeated_percent = count_repetition_percent_modified(topk_result)
    total_repeated_percent += repeated_percent
avg_repeated_percent = total_repeated_percent / len(test_loader)
print(avg_repeated_percent)

In [119]:
# Top-k sampling (k=400): average repetition rate ('</s>' markers excluded).
np.random.seed(0)
total_repeated_percent = 0
for idx, sample in enumerate(test_loader):
    prompt = sample[0][0]
    real = sample[1][0]
    topk_result = ' '.join(sample_predict(train_dataset, embedding, model, prompt, 'top-k', k=400))
#     repeated_percent = count_repetition_percent(topk_result)
    repeated_percent = count_repetition_percent_modified(topk_result)
    total_repeated_percent += repeated_percent
avg_repeated_percent = total_repeated_percent / len(test_loader)
print(avg_repeated_percent)

In [112]:
# Nucleus sampling (p=0.95): average repetition rate ('</s>' markers excluded).
np.random.seed(0)
total_repeated_percent = 0
for idx, sample in enumerate(test_loader):
    prompt = sample[0][0]
    real = sample[1][0]
    nucleus_result = ' '.join(sample_predict(train_dataset, embedding, model, prompt, 'nucleus', p=0.95))
    print(nucleus_result)
#     repeated_percent = count_repetition_percent(nucleus_result)
    repeated_percent = count_repetition_percent_modified(nucleus_result)
    total_repeated_percent += repeated_percent
avg_repeated_percent = total_repeated_percent / len(test_loader)
print(avg_repeated_percent)

In [113]:
# Nucleus sampling (p=0.9): average repetition rate ('</s>' markers excluded).
np.random.seed(0)
total_repeated_percent = 0
for idx, sample in enumerate(test_loader):
    prompt = sample[0][0]
    real = sample[1][0]
    nucleus_result = ' '.join(sample_predict(train_dataset, embedding, model, prompt, 'nucleus', p=0.9))
    print(nucleus_result)
#     repeated_percent = count_repetition_percent(nucleus_result)
    repeated_percent = count_repetition_percent_modified(nucleus_result)
    total_repeated_percent += repeated_percent
avg_repeated_percent = total_repeated_percent / len(test_loader)
print(avg_repeated_percent)

In [101]:
# Human baseline: repetition rate of the reference test sentences themselves.
# PromptDataset does not append '</s>', so the unmodified metric gives the same value here.
total_repeated_percent = 0
for idx, sample in enumerate(test_loader):
    prompt = sample[0][0]
    real = sample[1][0]
    # Score the reference sentence, not a model generation.
    repeated_percent = count_repetition_percent(real)
    total_repeated_percent += repeated_percent
avg_repeated_percent = total_repeated_percent / len(test_loader)
print(avg_repeated_percent)