In [2]:
%matplotlib inline
from InferSent.models import InferSent
import numpy as np
import torch
import os
from random import choice
from tqdm import tqdm_notebook as tqdm
from scipy.spatial import distance
from annoy import AnnoyIndex

# Load `InferSent` Model

In [3]:
# Pretrained InferSent (version 2) encoder weights.
MODEL_PATH = 'InferSent/encoder/infersent2.pkl'

# Encoder hyper-parameters; these must match the checkpoint being loaded.
params = dict(
    bsize=256,
    word_emb_dim=300,
    enc_lstm_dim=2048,
    pool_type='max',
    dpout_model=0.0,
    version=2,
)

# Build the encoder on the GPU, then restore the pretrained weights.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
model = InferSent(params).cuda()
model.load_state_dict(torch.load(MODEL_PATH))

In [4]:
# fastText Common Crawl word vectors (300-d per the file name, matching
# word_emb_dim above). Only the path is registered here; the vocabulary is
# built later by the model.build_vocab call inside change_book.
W2V_PATH = 'InferSent/dataset/fastText/crawl-300d-2M.vec'
model.set_w2v_path(W2V_PATH)

# Load Books

In [5]:
# Root folder of the Gutenberg dump; file names look like 'Author___Title.txt'.
# Keep the trailing slash: later cells build paths with `prefix + name`.
prefix = '/mnt/bigfiles/dl/datasets/Gutenberg/'
# os.listdir returns entries in arbitrary order; sort so that anything
# iterating over `books` (e.g. corpus concatenation) is reproducible.
books = sorted(os.listdir(prefix))

In [6]:
def get_corpus(author, directory=None):
    """Concatenate every book by `author` into a single string.

    A file belongs to `author` when its name contains f'{author}__'
    (e.g. 'Jane Austen___Northanger Abbey.txt' for 'Jane Austen'). Each
    book's text is followed by two newlines so the boundary between books
    reads as a paragraph break.

    Args:
        author: author name as it appears before the '__' separator.
        directory: folder to search. When None (the default), the notebook's
            global `prefix` folder and its cached `books` listing are used.

    Returns:
        The concatenated text of all matching books in file-name order, or ''
        when nothing matches.
    """
    if directory is None:
        directory, names = prefix, books
    else:
        names = os.listdir(directory)

    parts = []
    # Sorted so the result does not depend on os.listdir's arbitrary order.
    for book in sorted(names):
        if f'{author}__' in book:
            # `with` closes each file instead of leaking one handle per book.
            with open(os.path.join(directory, book)) as handle:
                parts.append(handle.read() + '\n\n')
    return ''.join(parts)

In [7]:
# Placeholder substituted for paragraph breaks ('\n\n') so they survive the
# external tokenizer; detokenize_sentences turns it back into '\n\n'.
NEWLINE = '<NEWLINE>'
def tokenize_sentences(text, script='/mnt/bigfiles/dl/datasets/stanford-parser-full-2017-06-09/tokenize_sent.sh'):
    """Split `text` into sentences using an external tokenizer command.

    The text (with paragraph breaks replaced by NEWLINE) is written to
    /tmp/in.txt, `script` is run through the shell, and /tmp/out.txt is read
    back and split on '\n'. NOTE(review): the command is assumed to read
    /tmp/in.txt and write one sentence per line to /tmp/out.txt -- that
    contract lives in the shell script, not here.

    Args:
        text: raw text to split.
        script: shell command to run; defaults to the Stanford tokenizer
            wrapper script.

    Returns:
        The non-empty lines of /tmp/out.txt.

    Raises:
        RuntimeError: if the command exits with a non-zero status.
    """
    with open('/tmp/in.txt', 'w') as handle:
        handle.write(text.replace('\n\n', NEWLINE))
    status = os.system(script)
    if status != 0:
        # Without this check a failed run would silently return whatever an
        # earlier run left in /tmp/out.txt.
        raise RuntimeError(f'tokenizer command {script!r} failed with status {status}')
    with open('/tmp/out.txt') as handle:
        tokens = handle.read().split('\n')
    print('Total tokens in dataset', len(tokens))

    return [token for token in tokens if len(token) > 0]

In [8]:
def detokenize_sentences(sentences, script='/mnt/bigfiles/dl/datasets/stanford-parser-full-2017-06-09/detokenize.sh'):
    """Join tokenized `sentences` back into running text via an external command.

    The sentences are joined with single spaces, NEWLINE placeholders are
    turned back into '\n\n', the result is written to /tmp/in.txt, `script`
    is run through the shell, and /tmp/out.txt is returned. NOTE(review): the
    command is assumed to read /tmp/in.txt and write /tmp/out.txt -- confirm
    against the shell script.

    Args:
        sentences: iterable of sentence strings (as from tokenize_sentences).
        script: shell command to run; defaults to the Stanford detokenizer
            wrapper script.

    Returns:
        The full contents of /tmp/out.txt.

    Raises:
        RuntimeError: if the command exits with a non-zero status.
    """
    with open('/tmp/in.txt', 'w') as handle:
        handle.write(' '.join(sentences).replace(NEWLINE, '\n\n'))
    status = os.system(script)
    if status != 0:
        # Otherwise a failure would silently return a stale /tmp/out.txt.
        raise RuntimeError(f'detokenizer command {script!r} failed with status {status}')

    with open('/tmp/out.txt') as handle:
        return handle.read()

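# Change Books

Rewrite a target book by replacing each of its sentences with the most similar sentence (by InferSent embedding) taken from another author's corpus.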

In [9]:
def _nearest_by_cosine(queries, candidates):
    """Exact search: for each query vector, the index of the candidate with the
    smallest cosine distance (equivalently, the largest cosine similarity)."""
    candidates = np.asarray(candidates)
    norms = np.linalg.norm(candidates, axis=1, keepdims=True)
    # Zero-length candidates would divide by zero; keep them as zero vectors
    # (similarity 0) rather than producing NaNs that np.argmin would select.
    unit_candidates = candidates / np.where(norms == 0, 1, norms)

    nearest = []
    for query in tqdm(queries):
        query = np.asarray(query)
        query_norm = np.linalg.norm(query)
        if query_norm > 0:
            query = query / query_norm
        # One matrix-vector product per query replaces len(candidates)
        # separate Python-level distance.cosine calls.
        nearest.append(int(np.argmax(unit_candidates @ query)))
    return nearest


def _nearest_by_annoy(queries, candidates, n_trees):
    """Approximate search: for each query vector, the index of a near candidate
    found through an Annoy index built over `candidates`.

    Uses Annoy's 'angular' metric, which is a monotone function of cosine
    distance, so results approximate `_nearest_by_cosine`. 'dot' would rank by
    the raw inner product and favour large-norm vectors instead.
    """
    print('Building index...')
    index = AnnoyIndex(len(candidates[0]), metric='angular')
    for i, vec in enumerate(candidates):
        index.add_item(i, vec)
    index.build(n_trees)
    return [index.get_nns_by_vector(query, 1)[0] for query in tqdm(queries)]


def change_book(toChange, source, useAnnoy = False, maxChars = 1000000, nTrees = 25):
    """Rewrite `toChange` using sentences taken from `source`.

    Both texts are sentence-split with tokenize_sentences and embedded with the
    global InferSent `model`. Every sentence of `toChange` is replaced by the
    `source` sentence whose embedding is closest in cosine terms. The result is
    re-joined with detokenize_sentences.

    Args:
        toChange: text to rewrite.
        source: text supplying the replacement sentences; only its first
            `maxChars` characters are used (this may cut the last sentence).
        useAnnoy: use an approximate Annoy index instead of exact search.
        maxChars: truncation limit applied to `source`.
        nTrees: number of Annoy trees, used only when `useAnnoy` is true.
            More trees give better accuracy at the cost of build time.

    Returns:
        The detokenized rewritten text.

    Raises:
        ValueError: if `source` yields no sentences to draw from.
    """
    toChangeSent = tokenize_sentences(toChange)
    sourceSent = tokenize_sentences(source[:maxChars])
    if not sourceSent:
        raise ValueError('source produced no sentences to draw replacements from')

    model.build_vocab(toChangeSent + sourceSent, tokenize=True)

    toChangeVec = model.encode(toChangeSent, tokenize=True)
    sourceVec = model.encode(sourceSent, tokenize=True)

    if useAnnoy:
        closest = _nearest_by_annoy(toChangeVec, sourceVec, nTrees)
    else:
        closest = _nearest_by_cosine(toChangeVec, sourceVec)

    changed = [sourceSent[idx] for idx in closest]
    return detokenize_sentences(changed)

In [None]:
# Rewrite Northanger Abbey using sentences drawn from Conan Doyle's books.
# `with` closes the book file instead of leaking the handle.
with open(prefix + 'Jane Austen___Northanger Abbey.txt') as northanger_file:
    northanger_abbey = northanger_file.read()
changed = change_book(northanger_abbey, get_corpus('Sir Arthur Conan Doyle'))
# Quick smoke test on a tiny input:
# change_book('I like to do things. Things are fun.', 'Things are fun. I like to do stuff.')

Total tokens in dataset 3615
Total tokens in dataset 10678
