In [None]:
import os
import argparse
# from trainer import Trainer
import torch as t

import pandas as pd
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
# from model import SkipGramEmbeddings
# from sgns_loss import SGNSLoss
from tqdm import tqdm
# from datasets.pypi_lang import PyPILangDataset
# from datasets.world_order import WorldOrderDataset
from torch.utils.tensorboard import SummaryWriter
# from datasets.fine_food import FineFoodDataset
import matplotlib.pyplot as plt
import torch.nn.functional as F

# from utils import AliasMultinomial

# from .preprocess import Tokenizer

from gensim.corpora import Dictionary
from torch.utils.data.dataset import Dataset

import spacy
import re

In [None]:
class Tokenizer:
    """
    spaCy-based document tokenizer (modified from lda2vec's preprocessing).

    Runs the ``en_core_web_sm`` pipeline with named entities merged into single tokens,
    drops URL-like, e-mail-like and stop-word tokens, lower-cases the rest, strips every
    character outside ``[a-zA-Z0-9]`` and finally removes anything in ``custom_stop``.
    """

    def __init__(self, args, custom_stop=None):
        """
        :param args: parsed CLI arguments (stored on the instance; not used by tokenization)
        :param custom_stop: extra stop words (any container supporting ``in``); a fresh empty
            set is created per instance when omitted, instead of sharing one mutable default
        """
        self.args = args
        self.custom_stop = set() if custom_stop is None else custom_stop
        # Imported here rather than at module level, presumably so the model package is only
        # required when a Tokenizer is actually built.
        import en_core_web_sm
        self.nlp = en_core_web_sm.load()
        # Merge named entities into single tokens.
        self._add_merge_entities()

    def _add_merge_entities(self):
        """Append spaCy's built-in ``merge_entities`` component to ``self.nlp`` (v2 or v3 API)."""
        try:
            # spaCy v3: components are added by their registered factory name.
            self.nlp.add_pipe("merge_entities")
        except ValueError:
            # spaCy v2: add_pipe rejects a string with ValueError and needs the callable
            # produced by create_pipe.
            self.nlp.add_pipe(self.nlp.create_pipe("merge_entities"))

    def tokenize_doc(self, doc_str):
        """
        Tokenize a document string.

        Modified version of Moody's Tokenization in:
        https://github.com/cemoody/lda2vec/blob/master/lda2vec/preprocess.py

        :params doc_str: String
        :returns: list of Strings, i.e. non-empty tokens in document order
        """
        spacy_doc = self.nlp(doc_str)
        # Drop invalid tokens, then lower-case and trim the remaining ones.
        tokens = (tok.text.lower().strip() for tok in spacy_doc if self.is_valid_token(tok))
        # Keep only ASCII letters and digits; this also removes '_' and any whitespace
        # inside a (merged) token.
        tokens = (re.sub('[^a-zA-Z0-9]', '', tok) for tok in tokens)
        # Remove tokens emptied by the substitution and any custom stop words.
        return [tok for tok in tokens if len(tok) > 0 and tok not in self.custom_stop]

    def is_valid_token(self, token):
        """
        Determine whether a spaCy token should be kept.

        :params token: spaCy Token (reads ``like_url``, ``like_email``, ``is_stop``, ``text``)
        :returns: False for URL-like, e-mail-like, stop-word or custom-stop tokens, else True
        """
        if token.like_url:
            return False
        if token.like_email:
            return False
        if token.is_stop or token.text in self.custom_stop:
            return False

        return True

    def moodys_merge_noun_chunks(self, doc):
        """
        Merge noun chunks into a single token.

        Modified from sources of:
        - https://github.com/cemoody/lda2vec/blob/master/lda2vec/preprocess.py
        - https://spacy.io/api/pipeline-functions#merge_noun_chunks

        :params doc: Doc object.
        :returns: the same Doc, with trimmed noun chunks merged (unchanged if unparsed).
        """
        bad_deps = ('amod', 'compound')

        # spaCy v3 replaced Doc.is_parsed with Doc.has_annotation("DEP").
        if hasattr(doc, 'has_annotation'):
            is_parsed = doc.has_annotation('DEP')
        else:
            is_parsed = doc.is_parsed
        if not is_parsed:
            return doc

        with doc.retokenize() as retokenizer:
            for chunk in doc.noun_chunks:
                # Trim leading tokens until the chunk starts with an 'amod'/'compound'
                # dependent or only one token is left, e.g. keep "good ideas".
                while len(chunk) > 1 and chunk[0].dep_ not in bad_deps:
                    chunk = chunk[1:]

                if len(chunk) > 1:
                    # Merged token inherits the chunk root's tag and dependency label.
                    attrs = {"tag": chunk.root.tag, "dep": chunk.root.dep}
                    retokenizer.merge(chunk, attrs=attrs)
        return doc


class SimpleTokenizer:
    """Whitespace tokenizer that discards any token contained in ``custom_stop``."""

    def __init__(self, args, custom_stop=set()):
        self.custom_stop = custom_stop
        self.args = args

    def tokenize_doc(self, doc_str):
        """Split ``doc_str`` on whitespace and return the tokens not in ``custom_stop``."""
        stop_words = self.custom_stop
        return list(filter(lambda word: word not in stop_words, doc_str.split()))

In [None]:
class SkipGramDataset(Dataset):
    """
    Base dataset of skip-gram (center, context) token-id pairs.

    Subclasses implement ``load_files`` (returning a list of tokenized documents) and then
    call either ``generate_examples_serial`` or ``load``; both populate ``self.dictionary``
    (a gensim ``Dictionary``), ``self.examples`` and ``self.word_freq`` (``dictionary.cfs``).
    Each item is ``(center_tensor, context_tensor, center_freq, context_freq)``, where the
    tensors have shape ``[1]`` and the frequencies are looked up in ``self.word_freq``.
    """

    def __init__(self, args):
        self.args = args
        self.dictionary = None
        self.examples = []
        self.name = ''

    def __getitem__(self, index):
        return self._example_to_tensor(*self.examples[index])

    def __len__(self):
        return len(self.examples)

    def save(self, examples_path, dict_path):
        """
        Save the examples (via ``torch.save``) and the dictionary (via ``Dictionary.save``).

        NOTE(review): ``examples_path`` and ``dict_path`` are currently ignored; output always
        goes to './test_examples.pth' and './test_dict.pth' (the original comment here read
        "change the file names here"). Honoring the arguments requires first checking that
        every caller passes two distinct paths.
        """
        print('Saving Dataset Examples...')
        torch.save({
             'examples': self.examples,
        }, './test_examples.pth')
        print('Saving Dataset Dictionary...')
        self.dictionary.save('./test_dict.pth')
        print('Saved Dataset!')

    def load(self, examples_path, dict_path):
        """
        Load examples and dictionary previously written by ``save``.

        Also sets ``self.word_freq`` from the loaded dictionary; ``__getitem__`` needs it.
        """
        print('Loading Dataset Examples...')
        self.examples = torch.load(examples_path)['examples']
        print('Loading Dataset Dictionary...')
        # Dictionary.load is a classmethod; no throwaway instance is needed.
        self.dictionary = Dictionary.load(dict_path)
        self.word_freq = self.dictionary.cfs
        print('Loaded Saved Dataset!')

    def generate_examples_serial(self):
        """
        Generates examples with no multiprocessing - straight through!
        :return: None - updates self.dictionary, self.word_freq and self.examples
        """
        self._build_dictionary()
        # Drop tokens found in fewer than 10 documents or in more than 75% of documents.
        self.dictionary.filter_extremes(no_below=10, no_above=0.75)
        # filter_extremes can replace and re-index dictionary.cfs (gensim 4 rebuilds it while
        # compactifying), so re-read it; otherwise word_freq stays keyed by pre-filter ids.
        self.word_freq = self.dictionary.cfs

        self.examples = []
        for file in tqdm(self.load_files(), desc="Generating Examples (serial)"):
            # Map tokens to ids; out-of-dictionary tokens become -1.
            file = self.dictionary.doc2idx(file)
            self.examples.extend(self._generate_examples_from_file(file))

    def load_files(self):
        """
        Return the corpus as a list of tokenized documents.
        :returns: List of files
        """
        # Needs to be implemented by child class
        raise NotImplementedError

    def _build_dictionary(self):
        """
        Creates a Gensim Dictionary from ``load_files()``.
        :return: None - sets self.dictionary and self.word_freq
        """
        print("Building Dictionary...")
        self.dictionary = Dictionary(self.load_files())
        self.word_freq = self.dictionary.cfs

    def _generate_examples_from_file(self, file):
        """
        Generate all (center, context) pairs from a document within the window size.

        :param file: List of token ids, -1 marking out-of-dictionary tokens
        :returns: List of (center_id, context_id) tuples; pairs involving -1 are skipped
        """
        examples = []
        for i, token in enumerate(file):
            if token == -1:
                # Out of dictionary token
                continue

            context_words = self._generate_contexts(i, file)
            # center, context - follows form: (input, target)
            examples.extend((token, ctxt) for ctxt in context_words if ctxt != -1)
        return examples

    def _generate_contexts(self, token_idx, tokenized_doc):
        """
        Generate the context words within ``args.window_size`` positions of a token.

        :param token_idx: Index at which center token is found in tokenized_doc
        :param tokenized_doc: List - Document broken into tokens
        :returns: List of context words in document order (the center itself excluded)
        """
        contexts = []
        for w in range(-self.args.window_size, self.args.window_size + 1):
            context_pos = token_idx + w

            is_outside_doc = context_pos < 0 or context_pos >= len(tokenized_doc)
            center_is_context = token_idx == context_pos

            if is_outside_doc or center_is_context:
                continue

            contexts.append(tokenized_doc[context_pos])
        return contexts

    def _example_to_tensor(self, center, target):
        """
        Turn a raw example into model inputs.

        :params center: center word id
        :params target: context (target) word id
        :returns: (tensor([center]), tensor([target]), word_freq[center], word_freq[target])
        """
        center_id, target_id = int(center), int(target)
        return (torch.tensor([center_id]), torch.tensor([target_id]),
                self.word_freq[center_id], self.word_freq[target_id])


In [None]:
class WorldOrderDataset(SkipGramDataset):
    """Skip-gram dataset over 'data/world_order_kissinger.txt', one document per text line."""

    def __init__(self, args, examples_path=None, dict_path=None):
        SkipGramDataset.__init__(self, args)
        self.name = 'World Order Book Dataset'
        self.queries = ['nuclear', 'mankind', 'khomeini', 'ronald']

        have_saved_files = examples_path is not None and dict_path is not None
        if not have_saved_files:
            # Tokenize the raw text and build examples from scratch.
            self.files = self.tokenize_files()
            self.generate_examples_serial()
        else:
            self.load(examples_path, dict_path)

        print(f'There are {len(self.dictionary)} tokens and {len(self.examples)} examples.')

    def load_files(self):
        return self.files

    def tokenize_files(self):
        """Return one token list per line of the book text."""
        with open('data/world_order_kissinger.txt') as handle:
            return [self._tokenize_line(line) for line in handle]

    @staticmethod
    def _tokenize_line(line):
        """Lower-case, split on non-word characters and drop tokens containing a digit."""
        words = re.sub(r'[^\w]', ' ', line.lower()).split()
        return [word for word in words if not any(ch.isdigit() for ch in word)]

In [None]:
class PyPILangDataset(SkipGramDataset):
    """Skip-gram dataset over the ``language`` column of the CSV at ``args.dataset_dir``."""

    def __init__(self, args, examples_path=None, dict_path=None):
        SkipGramDataset.__init__(self, args)
        self.name = 'PyPI Language Dataset'
        self.queries = ['tensorflow', 'pytorch', 'nlp', 'performance', 'encryption']

        if examples_path is None or dict_path is None:
            # No saved files given: tokenize the CSV, build examples and save them.
            self.tokenizer = Tokenizer(args)
            self.files = self.tokenize_files()
            self.generate_examples_serial()

            self.save('pypi_examples.pth', 'pypi_dict.pth')
        else:
            self.load(examples_path, dict_path)

        print(f'There are {len(self.dictionary)} tokens and {len(self.examples)} examples.')

    def load_files(self):
        return self.files

    def tokenize_files(self):
        """Tokenize every entry of the CSV's ``language`` column."""
        frame = pd.read_csv(self.args.dataset_dir, na_filter=False)
        documents = frame['language'].values
        tokenize = self.tokenizer.tokenize_doc
        return [tokenize(doc) for doc in tqdm(documents, desc='Tokenizing Docs')]

In [None]:
class AliasMultinomial(object):
    """
    Fast sampling from a multinomial distribution.
    https://hips.seas.harvard.edu/blog/2013/03/03/the-alias-method-efficient-sampling-with-many-discrete-outcomes/

    Code taken from: https://github.com/TropComplique/lda2vec-pytorch/blob/master/utils/alias_multinomial.py
    """

    def __init__(self, probs, device):
        """
        probs: a sequence (list, array or 1-D tensor) of K probabilities summing to one.
        device: device on which the alias tables ``q`` (float32) and ``J`` (int64) live.
        """
        self.device = device

        K = len(probs)
        # Build the tables with plain Python numbers on the host: element-wise writes and
        # comparisons on a (possibly CUDA) tensor cost one device operation per element.
        q = [K * float(prob) for prob in probs]
        J = [0] * K

        # sort the data into the outcomes with probabilities
        # that are larger and smaller than 1/K
        smaller = [kk for kk, qk in enumerate(q) if qk < 1.0]
        larger = [kk for kk, qk in enumerate(q) if not qk < 1.0]

        # loop through and create little binary mixtures that
        # appropriately allocate the larger outcomes over the
        # overall uniform mixture
        while smaller and larger:
            small = smaller.pop()
            large = larger.pop()

            J[small] = large
            q[large] = (q[large] - 1.0) + q[small]

            if q[large] < 1.0:
                smaller.append(large)
            else:
                larger.append(large)

        # clamp_ (in place) keeps round-off from pushing probabilities outside [0, 1];
        # a plain .clamp() would return a new tensor and discard it.
        self.q = t.tensor(q, dtype=t.float32).clamp_(0.0, 1.0).to(device)
        self.J = t.tensor(J, dtype=t.long).clamp_(0, max(K - 1, 0)).to(device)

    def draw(self, N):
        """Draw N samples (an int64 tensor of outcome indices on ``self.device``)."""

        K = self.J.size(0)
        # Pick a uniform column, then keep it with probability q[col] or take its alias.
        r = t.LongTensor(np.random.randint(0, K, size=N)).to(self.device)
        q = self.q.index_select(0, r).clamp(0.0, 1.0)
        j = self.J.index_select(0, r)
        b = t.bernoulli(q)
        oq = r.mul(b.long())
        oj = j.mul((1 - b).long())

        return oq + oj

In [None]:
class SGNSLoss(nn.Module):
    """
    Skip-gram negative-sampling loss.

    Positive pairs are scored with a dot product against label 1; ``NUM_SAMPLES`` rounds of
    negatives drawn from the unigram distribution (document frequency ** BETA) are scored
    against label 0. The returned loss is the sum of the per-round batch means.
    """
    BETA = 0.75  # exponent to adjust sampling frequency
    NUM_SAMPLES = 2

    def __init__(self, dataset, word_embeddings, device):
        super(SGNSLoss, self).__init__()
        self.dataset = dataset
        self.criterion = nn.BCEWithLogitsLoss(reduction='none')
        self.vocab_len = len(dataset.dictionary)
        self.word_embeddings = word_embeddings
        self.device = device

        # Helpful values for unigram distribution generation
        # Should use cfs instead but: https://github.com/RaRe-Technologies/gensim/issues/2574
        self.transformed_freq_vec = t.tensor(
            np.array([dataset.dictionary.dfs[i] for i in range(self.vocab_len)]) ** self.BETA
        )
        self.freq_sum = t.sum(self.transformed_freq_vec)
        # Generate table
        self.unigram_table = self.generate_unigram_table()
        # (word_freq object, its length, regularisation tensor); see _reg_tensor.
        self._reg_cache = None

    def forward(self, center, context, center_id, context_id, word_freq):
        """
        :param center, context: embeddings; assumed (batch, 1, embed) or (batch, embed)
        :param center_id, context_id: token-id tensors matching the embeddings
        :param word_freq: token-id -> corpus frequency mapping
        :returns: scalar loss tensor
        """
        # Flatten to (batch, embed) without collapsing a batch of size 1 (squeeze() would).
        embed_len = center.shape[-1]
        center = center.reshape(-1, embed_len)
        context = context.reshape(-1, embed_len)

        freq_reg = self._reg_tensor(word_freq)
        # Frequency-dependent L2 term. Computed but intentionally not yet added to the loss.
        # NOTE(review): this sums squared norms over the whole batch; a per-example
        # (sum(-1)) weight may be what is intended once it is enabled.
        reg = freq_reg[context_id.reshape(-1).to(freq_reg.device)]
        reg = reg * (context ** 2).sum()

        # Compute true portion
        true_scores = (center * context).sum(-1)  # batch_size
        loss = self.criterion(true_scores, t.ones_like(true_scores)).mean()

        # Compute negatively sampled portion
        for _ in range(self.NUM_SAMPLES):
            samples, sample_reg = self.get_unigram_samples(n=center.shape[0], freq_reg=freq_reg)
            neg_sample_scores = (center * samples).sum(-1)
            neg_loss = self.criterion(neg_sample_scores, t.zeros_like(neg_sample_scores))
            # TODO: once enabled, add sample_reg * (samples ** 2).sum() here.
            loss = loss + neg_loss.mean().to(loss.device)

        return loss

    def _reg_tensor(self, word_freq):
        """
        Return ``get_reg_param(word_freq)`` as a tensor on ``self.device``.

        Cached per ``word_freq`` object (and length) so the O(vocab) Python loop is not
        repeated on every batch; assumes the mapping is not edited in place between calls.
        """
        cache = self._reg_cache
        if cache is None or cache[0] is not word_freq or cache[1] != len(word_freq):
            reg = t.tensor(self.get_reg_param(word_freq), device=self.device)
            self._reg_cache = (word_freq, len(word_freq), reg)
        return self._reg_cache[2]

    def get_reg_param(self, freq):
        """
        Per-token regularisation weights ``0.001 / log(freq + 2)``, in token-id order.

        :param freq: token-id -> frequency mapping keyed 0..V-1 (e.g. gensim ``cfs``) or a
            sequence of frequencies indexed by token id
        :returns: list of V floats
        """
        # Index explicitly: iterating a dict would yield its keys (token ids), not frequencies.
        return [0.001 / np.log(freq[i] + 2) for i in range(len(freq))]

    @staticmethod
    def bce_loss_w_logits(x, y):
        """Numerically stable mean binary cross-entropy on logits ``x`` for labels ``y`` (NumPy)."""
        # max(-x, 0) keeps both exponentials <= 1, so large-magnitude logits cannot overflow.
        max_val = np.clip(-x, 0, None)
        loss = x - x * y + max_val + np.log(np.exp(-max_val) + np.exp((-x - max_val)))
        return loss.mean()

    def get_unigram_samples(self, n, freq_reg):
        """
        Draw ``n`` token ids from the unigram table.

        :returns: (their embeddings as (n, embed), their regularisation weights)
        """
        rand_idxs = self.unigram_table.draw(n).to(self.device)
        sample_reg = freq_reg[rand_idxs.to(freq_reg.device)]
        embeds = self.word_embeddings(rand_idxs)
        return embeds.reshape(n, embeds.shape[-1]), sample_reg

    def get_unigram_prob(self, token_idx):
        return (self.transformed_freq_vec[token_idx].item()) / self.freq_sum.item()

    def generate_unigram_table(self):
        # Probability at each index corresponds to probability of selecting that token
        pdf = [self.get_unigram_prob(t_idx) for t_idx in range(0, self.vocab_len)]
        # Generate the table from PDF
        return AliasMultinomial(pdf, self.device)

In [None]:
class SkipGramEmbeddings(nn.Module):
    """Skip-gram model whose center and context words share one embedding table."""

    def __init__(self, vocab_size, embed_len):
        super().__init__()
        self.word_embeds = nn.Embedding(vocab_size, embed_len)

    def forward(self, center, context):
        """
        Look up embeddings for the center and context word indices.

        :param center: The center word index tensor
        :param context: The context word index tensor
        :return: (center embeddings, context embeddings), both from ``word_embeds``
        """
        lookup = self.word_embeds
        return lookup(center), lookup(context)

    def nearest_neighbors(self, word, dictionary):
        """
        Return the 10 tokens whose embeddings have the highest cosine similarity to ``word``.

        The query word itself is included in the ranking.
        :param word: token string present in ``dictionary.token2id``
        :param dictionary: gensim-style Dictionary (``token2id`` and id -> token indexing)
        :return: list of up to 10 token strings, most similar first
        """
        table = self.word_embeds.weight.data.cpu().numpy()
        query_vec = table[dictionary.token2id[word]]

        dots = table.dot(query_vec).squeeze()
        norm_product = query_vec.T.dot(query_vec).squeeze() * np.sum(table ** 2, 1)
        cosine = dots / np.sqrt(norm_product)

        top_ids = list(cosine.argsort()[::-1])[:10]
        return [dictionary[idx] for idx in top_ids]


In [None]:
class Trainer:
    """
    Trains SkipGramEmbeddings with SGNSLoss on FineFoodDataset.

    Logs per-epoch loss and query-word neighbours to stdout and saves the model weights
    after every epoch.
    """

    def __init__(self, args):
        self.args = args
        self.writer = SummaryWriter(log_dir='./experiments/', flush_secs=3)
        self.dataset = FineFoodDataset(args, dict_path='../input/ffdata/dict.pth')
        self.vocab_size = len(self.dataset.dictionary)
        print("Finished loading dataset")
        self.word_freq = self.dataset.word_freq
        self.dataloader = DataLoader(self.dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers)

        self.model = SkipGramEmbeddings(self.vocab_size, args.embedding_len).to(args.device)
        self.optim = optim.Adam(self.model.parameters(), lr=args.lr)
        self.sgns = SGNSLoss(self.dataset, self.model.word_embeds, self.args.device)
        # Debug print: corpus frequency of the token 'word', skipped if it was filtered out.
        word_id = self.dataset.dictionary.token2id.get('word')
        if word_id is not None:
            print(self.word_freq[word_id])

    def vis_freq(self):
        """Plot how many tokens share each corpus frequency, then the reg_param curve."""
        freq_values = np.fromiter(self.word_freq.values(), dtype=np.int64)
        # Sorted unique frequencies and how many tokens have each (O(n log n)).
        freqs, counts = np.unique(freq_values, return_counts=True)
        fig, ax = plt.subplots()
        ax.plot(freqs, counts)
        ax.set(title='Token frequency distribution', xlabel='corpus frequency',
               ylabel='number of tokens')
        plt.show()
        self.vis_tanh(0.01, 20, freqs.tolist(), counts.tolist())

    def reg_param(self, beta, v, freq_param):
        """
        Regulariser curve ``1 / (beta / (tanh((v - freq_param) / 2) + 2.2))``.

        The tanh is centred at ``freq_param``, matching the ``(xi - 20) / 2`` form used for
        the per-token weights; ``v - freq_param / 2`` would centre it at ``freq_param / 2``.
        """
        v = torch.tensor(v, dtype=torch.float)
        freq_param = torch.tensor(freq_param, dtype=torch.float)
        return 1 / (beta / (torch.tanh((v - freq_param) / 2) + 2.2))

    def vis_tanh(self, beta, freq_param, x, y):
        """
        Plot ``reg_param(beta, xi, freq_param)`` against the frequencies ``x``.

        ``y`` (token counts per frequency) is accepted for backward compatibility; the curve
        is a function of frequency, so it is evaluated on ``x``.
        """
        reg = [float(self.reg_param(beta, xi, freq_param)) for xi in x]
        fig, ax = plt.subplots()
        ax.plot(x, reg)
        ax.set(title='reg_param vs. frequency', xlabel='corpus frequency', ylabel='reg_param')
        plt.show()

    def train(self):
        """
        Run ``args.epochs`` epochs of SGNS training.

        After each epoch, saves ``model.state_dict()`` to
        ``./model_result/normal/emb_model<epoch>.pth`` (creating the directory if needed), logs
        the mean per-batch loss via ``log_step`` and prints the summed embedding gradient.
        """
        print('Training on device: {}'.format(self.args.device))
        ckpt_dir = './model_result/normal'
        os.makedirs(ckpt_dir, exist_ok=True)
        batches_per_epoch = len(self.dataloader)

        for epoch in range(self.args.epochs):
            print(f'Beginning epoch: {epoch + 1}/{self.args.epochs}')
            running_loss = 0.0
            num_batches = 0
            global_step = epoch * batches_per_epoch

            progress = tqdm(self.dataloader, total=batches_per_epoch, desc='epoch : ' + str(epoch))
            for center, context, center_freq, context_freq in progress:
                center, context = center.to(self.args.device), context.to(self.args.device)
                # Remove accumulated gradients
                self.optim.zero_grad()
                center_embed, context_embed = self.model(center, context)
                loss = self.sgns(center_embed, context_embed, center, context, self.word_freq)
                loss.backward()
                self.optim.step()

                running_loss += loss.item()
                num_batches += 1
                global_step += 1

            # Each loss is already a batch mean, so the epoch loss is the mean over batches.
            mean_loss = running_loss / max(num_batches, 1)
            torch.save(self.model.state_dict(),
                       os.path.join(ckpt_dir, 'emb_model' + str(epoch) + '.pth'))
            self.log_step(epoch, global_step, mean_loss)
            grad = self.model.word_embeds.weight.grad
            if grad is not None:
                print('\nGRAD:', np.sum(grad.detach().cpu().numpy()))

        self.writer.close()

    def log_and_save_epoch(self, epoch, loss):
        """Write the embedding table to TensorBoard and save a full training checkpoint."""
        self.writer.add_embedding(
            self.model.word_embeds.weight,
            global_step=epoch,
            tag=f'we_epoch_{epoch}',
        )

        print(f'Beginning to save checkpoint')
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optim.state_dict(),
            'loss': loss,
        }, f'epoch_{epoch}_ckpt.pth')
        print(f'Finished saving checkpoint')

    def log_step(self, epoch, global_step, loss):
        """Print the loss and the nearest neighbours of every query word."""
        print(f'#############################################')
        print(f'EPOCH: {epoch} | STEP: {global_step} | LOSS {loss}')
        print(f'#############################################')

        print('\nLearned embeddings:')
        for word in self.dataset.queries:
            print(f'word: {word} neighbors: {self.model.nearest_neighbors(word, self.dataset.dictionary)}')


In [None]:
def get_args(argv=()):
    """
    Build the training configuration.

    :param argv: command-line style argument list. Defaults to an empty sequence so a
        notebook kernel's own ``sys.argv`` is never parsed; pass e.g. ``['--epochs', '3']``
        to override individual options.
    :returns: argparse.Namespace with all defaults applied
    """
    parser = argparse.ArgumentParser(description="PyTorch LDA2Vec Training")

    # Data handling
    parser.add_argument('--dataset-dir', type=str, default='../input/ffdata/',
                        help='dataset directory (default: ../input/ffdata/)')
    parser.add_argument('--workers', type=int, default=4, metavar='N',
                       help='dataloader threads (default: 4)')
    parser.add_argument('--window-size', type=int, default=5, help='Window size\
                        used when generating training examples (default: 5)')
    parser.add_argument('--file-batch-size', type=int, default=250, help='Batch size\
                        used when multi-threading the generation of training examples\
                        (default: 250)')

    # Model Parameters
    parser.add_argument('--embedding-len', type=int, default=256, help='Length of\
                        embeddings in model (default: 256)')

    # Training Hyperparameters
    parser.add_argument('--epochs', type=int, default=15, metavar='N',
                        help='number of epochs to train for - iterations over the dataset (default: 15)')
    parser.add_argument('--batch-size', type=int, default=1024,
                        metavar='N', help='number of examples in a training batch (default: 1024)')
    parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
                        help='learning rate (default: 1e-3)')
    parser.add_argument('--seed', type=int, default=42, metavar='S',
                        help='random seed (default: 42)')

    # Checkpoint options
    parser.add_argument('--log-step', type=int, default=250, help='Step at which for every step training info\
                        is logged. (default: 250)')

    # Training Settings
    parser.add_argument('--device', type=str, default=t.device("cuda:0" if t.cuda.is_available() else "cpu"),
                        help='device to train on (default: cuda:0 if cuda is available otherwise cpu)')

    return parser.parse_args(list(argv))

In [None]:
class FineFoodDataset(SkipGramDataset):
    """Skip-gram dataset built from the ``Text`` column of '../input/ffdata/Reviews.csv'."""

    def __init__(self, args, examples_path='./examples.pth', dict_path='./dict.pth', rebuild=True):
        """
        :param examples_path, dict_path: saved files to load when ``rebuild`` is False
        :param rebuild: when True (the default) the given paths are ignored and the corpus is
            re-tokenized from the CSV; pass False to load the saved files instead
        """
        SkipGramDataset.__init__(self, args)
        self.name = 'Fine Food Dataset'
        self.queries = ['good', 'tasty','bad', 'negtively', 'charming']
        self.raw_doc = None
        if rebuild:
            examples_path = None
            dict_path = None
        if examples_path is not None and dict_path is not None:
            self.load(examples_path, dict_path)
        else:
            self.tokenizer = Tokenizer(args)
            self.files = self.tokenize_files()
            self.generate_examples_serial()
            # Distinct names so the dictionary cannot overwrite the examples file.
            self.save('Reviews_examples.pth', 'Reviews_dict.pth')

        print(f'There are {len(self.dictionary)} tokens and {len(self.examples)} examples.')
        # Keep word_freq in the id space of the dictionary that indexed self.examples;
        # rebuilding the dictionary here would re-assign ids that no longer match them.
        self.word_freq = self.dictionary.cfs
        self.queries = self.get_unusual_words()

    def get_unusual_words(self, max_freq=10):
        """
        Return the dictionary tokens whose corpus frequency is at most ``max_freq``.

        Order follows iteration order of ``self.word_freq``.
        """
        return [self.dictionary[token_id]
                for token_id, freq in self.word_freq.items() if freq <= max_freq]

    def load_files(self):
        return self.files

    def remove_token(self, Target='good', keep_per_10000=100):
        """
        Randomly drop occurrences of ``Target`` from ``self.raw_doc`` (modified in place).

        Each occurrence is kept when ``np.random.randint(10000) < keep_per_10000``
        (1% by default) and removed otherwise.
        :returns: ``self.raw_doc``
        """
        removed, kept = 0, 0
        for sentence_token in self.raw_doc:
            survivors = []
            for word in sentence_token:
                if word == Target:
                    if np.random.randint(10000) >= keep_per_10000:
                        removed += 1
                        continue
                    kept += 1
                survivors.append(word)
            # Slice-assign after the scan instead of removing while iterating, which
            # silently skipped the element following each removal.
            sentence_token[:] = survivors
        print("'", Target, "' is the target word!")
        print("Original: ", removed + kept, "words, we have removed: ", removed, ", currently remain: ", kept)
        return self.raw_doc

    def tokenize_files(self, max_reviews=100000, target_word='tasty'):
        """
        Tokenize the first ``max_reviews`` reviews, then thin out ``target_word``.

        :returns: list of token lists (also stored as ``self.raw_doc``)
        """
        ff_df = pd.read_csv('../input/ffdata/Reviews.csv', na_filter=False)
        review_data = ff_df['Text'].values[:max_reviews]
        self.raw_doc = [self.tokenizer.tokenize_doc(f) for f in tqdm(review_data, desc='Tokenizing Docs')]
        self.raw_doc = self.remove_token(target_word)
        return self.raw_doc

In [None]:
args = get_args()
# Apply --seed to every RNG the pipeline draws from: token thinning and alias sampling use
# NumPy's global RNG; embedding initialisation and DataLoader shuffling use torch.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
trainer = Trainer(args)
w = trainer.dataset.word_freq