In [2]:
import os, sys
module_path = os.path.abspath(os.path.join('..'))
sys.path.append(module_path)

import torch
import torch.nn as nn
import torch.nn.functional as F

from src.trainer import *


class LinearRegression(nn.Module):
    """Single linear layer over bag-of-words sentence features.

    Despite the name, ``forward`` returns the raw output of one ``nn.Linear``
    layer (unnormalised scores); no sigmoid or softmax is applied here.

    Args:
        output_dim: Number of scores produced per sentence (default 2).
        embeddings: Optional featurizer called as ``embeddings(batch)``. It must
            expose a ``vocab_size`` attribute equal to the width of the feature
            matrix it returns. Defaults to a freshly built ``BagOfWords``.
    """

    def __init__(self, output_dim=2, embeddings=None):
        super().__init__()

        # Only build the default featurizer (which reads the vocabulary file)
        # when the caller did not supply one.
        self.embeddings = BagOfWords() if embeddings is None else embeddings
        self.linear = nn.Linear(self.embeddings.vocab_size, output_dim)

    def forward(self, batch):
        """Featurize ``batch`` and return scores of shape ``(num_sents, output_dim)``.

        When ``output_dim == 1`` the trailing singleton dimension is dropped,
        giving shape ``(num_sents,)``.
        """
        embedded = self.embeddings(batch)
        preds = self.linear(embedded)
        # NOTE(review): no sigmoid/softmax is applied; presumably the Trainer's
        # loss consumes raw logits — confirm against src.trainer.
        # Squeeze only the last dim: a bare squeeze() also removed the batch
        # dimension whenever a batch contained exactly one sentence.
        return preds.squeeze(-1)
    

class BagOfWords:
    """Binary bag-of-words featurizer backed by a fixed vocabulary file.

    Feature column ``unk_idx`` (0) is shared by all out-of-vocabulary tokens;
    the vocabulary word at position ``i`` maps to feature column ``i + 1``.
    """

    # Feature column used for every token not found in the vocabulary.
    unk_idx = 0

    def __init__(self):
        self.set_vocab()

    def __call__(self, *args):
        return self.featurize(*args)

    def featurize(self, batch):
        """Turn a batch into a binarized bag-of-words matrix.

        Args:
            batch: Object whose ``sents`` attribute iterates over sentences,
                each exposing a ``tokens`` sequence of strings.

        Returns:
            ``to_var`` applied to a float tensor of shape
            ``(num_sents, vocab_size)`` that is 1.0 in every column whose word
            occurs in the sentence (column ``unk_idx`` for unknown words) and
            0.0 elsewhere.
        """
        token_idx = [[self.stoi(tok) for tok in sent.tokens]
                     for sent in batch.sents]

        featurized = torch.zeros((len(token_idx), self.vocab_size))
        for sent_idx, indices in enumerate(token_idx):
            # Presence, not counts: a word repeated in a sentence still yields 1.0.
            featurized[sent_idx, indices] = 1.0

        return to_var(featurized)

    def set_vocab(self):
        """Build the token -> vocabulary-position map and ``vocab_size``."""
        # dict.fromkeys drops repeated words (keeping the first occurrence) so
        # positions stay dense; with duplicates, enumerate could otherwise
        # produce a position whose column is >= vocab_size.
        unique_words = dict.fromkeys(self.get_vocab())
        self._stoi = {s: i for i, s in enumerate(unique_words)}

        # Vocab size, plus one for the shared UNK column (index 0).
        self.vocab_size = len(self._stoi) + 1

    def get_vocab(self, filename='../src/vocabulary.txt'):
        """ Read in vocabulary (top 30K words, covers ~93.5% of all tokens)

        The file is expected to hold comma-separated words. Leading/trailing
        whitespace of the whole file (e.g. a final newline) is stripped so it
        does not stick to the last word.
        """
        # NOTE(review): assumes the vocabulary file is UTF-8 encoded — confirm.
        with open(filename, 'r', encoding='utf-8') as f:
            vocab = f.read().strip().split(',')
        return vocab

    def stoi(self, s):
        """ String to index (s to i): feature column for token ``s``.

        Returns ``unk_idx`` for tokens not in the vocabulary.
        """
        idx = self._stoi.get(s)
        # Test against None explicitly: position 0 (the first vocabulary word)
        # is falsy and was previously misrouted to the UNK column.
        return self.unk_idx if idx is None else idx + 1


# Build the bag-of-words baseline and hand it to the project Trainer.
model = LinearRegression()

trainer = Trainer(
    model=model,
    train_dir='../data/wiki_727/train',
    val_dir='../data/wiki_50/test',
    test_dir=None,  # no test split is used for this run
    batch_size=256,
    lr=5e-4,
)

# NOTE(review): `steps` and `val_ckpt` are interpreted by Trainer.train (not
# visible here); presumably steps per epoch and validation frequency — confirm.
trainer.train(
    num_epochs=100,
    steps=100,
    val_ckpt=1,
)

HBox(children=(IntProgress(value=0), HTML(value='')))

Label 0: 0.499121 | Label 1: 0.500880
Step: 1 | Loss: 0.157997 | Num. sents: 10761 | Segs correct: 756 / 1306 | Texts correct: 3894 / 9455
Label 0: 0.497782 | Label 1: 0.502217
Step: 2 | Loss: 0.159785 | Num. sents: 8610 | Segs correct: 775 / 1059 | Texts correct: 2344 / 7551
Label 0: 0.496486 | Label 1: 0.503514
Step: 3 | Loss: 0.154951 | Num. sents: 9865 | Segs correct: 917 / 1174 | Texts correct: 1984 / 8691
Label 0: 0.495092 | Label 1: 0.504908
Step: 4 | Loss: 0.145292 | Num. sents: 11265 | Segs correct: 1088 / 1252 | Texts correct: 1555 / 10013
Label 0: 0.493807 | Label 1: 0.506193
Step: 5 | Loss: 0.153638 | Num. sents: 10273 | Segs correct: 1108 / 1213 | Texts correct: 1019 / 9060
Label 0: 0.492255 | Label 1: 0.507745
Step: 6 | Loss: 0.154408 | Num. sents: 9568 | Segs correct: 1081 / 1136 | Texts correct: 633 / 8432
Label 0: 0.490948 | Label 1: 0.509052
Step: 7 | Loss: 0.151097 | Num. sents: 10481 | Segs correct: 1178 / 1216 | Texts correct: 577 / 9265
Label 0: 0.489256 | Label 1

KeyboardInterrupt: 