In [1]:
import torch
import torch.nn as nn

In [31]:
class Bottle(nn.Module):
    def forward(self, input):
        if len(input.size()) <= 2:
            return super(Bottle, self).forward(input)
        size = input.size()[:2]
        out = super(Bottle, self).forward(input.view(size[0]*size[1], -1))
        return out.view(size[0], size[1], -1)
    
class Linear(Bottle, nn.Linear):
    pass

class Encoder(nn.Module):

    def __init__(self, config):
        super(Encoder, self).__init__()
        self.config = config
        input_size = config.d_proj if config.projection else config.d_embed
        dropout = 0 if config.n_layers == 1 else config.dp_ratio
        self.rnn = nn.LSTM(input_size=input_size, hidden_size=config.d_hidden,
                        num_layers=config.n_layers, dropout=dropout,
                        bidirectional=config.birnn)

    def forward(self, inputs):
        batch_size = inputs.size()[1]
        state_shape = self.config.n_cells, batch_size, self.config.d_hidden
        h0 = c0 = inputs.new_zeros(state_shape)
        outputs, (ht, ct) = self.rnn(inputs, (h0, c0))
        return ht[-1] if not self.config.birnn else ht[-2:].transpose(0, 1).contiguous().view(batch_size, -1)


In [27]:
class SNLIClassifier(nn.Module):

    def __init__(self, config):
        super(SNLIClassifier, self).__init__()
        self.config = config
        self.embed = nn.Embedding(config.n_embed, config.d_embed)
        self.projection = Linear(config.d_embed, config.d_proj)
        self.encoder = Encoder(config)
        self.dropout = nn.Dropout(p=config.dp_ratio)
        self.relu = nn.ReLU()
        seq_in_size = 2*config.d_hidden
        if self.config.birnn:
            seq_in_size *= 2
        lin_config = [seq_in_size]*2
        self.out = nn.Sequential(
            Linear(*lin_config),
            self.relu,
            self.dropout,
            Linear(*lin_config),
            self.relu,
            self.dropout,
            Linear(*lin_config),
            self.relu,
            self.dropout,
            Linear(seq_in_size, config.d_out))

    def forward(self, batch):
        prem_embed = self.embed(batch.premise)
        hypo_embed = self.embed(batch.hypothesis)
        if self.config.fix_emb:
            prem_embed = prem_embed.detach()
            hypo_embed = hypo_embed.detach()
        if self.config.projection:
            prem_embed = self.relu(self.projection(prem_embed))
            hypo_embed = self.relu(self.projection(hypo_embed))
        premise = self.encoder(prem_embed)
        hypothesis = self.encoder(hypo_embed)
        scores = self.out(torch.cat([premise, hypothesis], 1))
        return scores

In [5]:
def makedirs(name):
    import os, errno
    try:
        os.makedirs(name)
    except OSError as ex:
        if ex.errno == errno.EEXIST and os.path.isdir(name):
            # ignore existing directory
            pass
        else:
            # a different error happened
            raise

In [6]:
import os
import time
import glob

import torch
import torch.optim as O
import torch.nn as nn

from torchtext import data
from torchtext import datasets

In [7]:
device = torch.device('cpu')
device

device(type='cpu')

In [10]:
inputs = data.Field(lower=True, tokenize='spacy')
answers = data.Field(sequential=False)

answers

<torchtext.data.field.Field at 0x11fd2c2d0>

In [11]:
train, dev, test = datasets.SNLI.splits(inputs, answers)
train

downloading snli_1.0.zip


snli_1.0.zip: 100%|██████████| 94.6M/94.6M [01:40<00:00, 938kB/s] 


extracting


<torchtext.datasets.nli.SNLI at 0x12b2fb510>

In [12]:
inputs.build_vocab(train, dev, test)

In [71]:
print(train.examples[0].hypothesis)
print(train.examples[0].label)
print(train.examples[0].premise)

['a', 'person', 'is', 'training', 'his', 'horse', 'for', 'a', 'competition', '.']
neutral
['a', 'person', 'on', 'a', 'horse', 'jumps', 'over', 'a', 'broken', 'down', 'airplane', '.']


In [72]:
class Config():
    def __init__(self):
        self.batch_size=128
        self.birnn=True
        self.d_embed=100
        self.d_hidden=300
        self.d_proj=300
        self.dev_every=1000
        self.dp_ratio=0.2
        self.epochs=50
        self.fix_emb=True
        self.gpu=0
        self.log_every=50
        self.lr=0.001
        self.n_layers=1
        self.projection=True
        
config = Config()

In [73]:
VECTOR_CACHE = ".vector_cache/input_vectors.pt"
inputs.vocab.vectors = torch.load(VECTOR_CACHE)
answers.build_vocab(train)

In [74]:
train_iter, dev_iter, test_iter = data.BucketIterator.splits((train, dev, test), batch_size=config.batch_size, device=device)

In [75]:
config.n_embed = len(inputs.vocab)
config.d_out = len(answers.vocab)
config.n_cells = config.n_layers
# double the number of cells for bidirectional networks
if config.birnn:
    config.n_cells *= 2

In [76]:
model = SNLIClassifier(config)
model.embed.weight.data.copy_(inputs.vocab.vectors)
model.to(device)

SNLIClassifier(
  (embed): Embedding(34193, 100)
  (projection): Linear(in_features=100, out_features=300, bias=True)
  (encoder): Encoder(
    (rnn): LSTM(300, 300, bidirectional=True)
  )
  (dropout): Dropout(p=0.2, inplace=False)
  (relu): ReLU()
  (out): Sequential(
    (0): Linear(in_features=1200, out_features=1200, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=1200, out_features=1200, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.2, inplace=False)
    (6): Linear(in_features=1200, out_features=1200, bias=True)
    (7): ReLU()
    (8): Dropout(p=0.2, inplace=False)
    (9): Linear(in_features=1200, out_features=4, bias=True)
  )
)

In [77]:
criterion = nn.CrossEntropyLoss()
opt = O.Adam(model.parameters(), lr=config.lr)

In [78]:
#saving results 
config.save_path = 'results'
config.save_every = 1000
config.dev_every = 1000
config.log_every = 50

In [79]:
iterations = 0
start = time.time()
best_dev_acc = -1
header = '  Time Epoch Iteration Progress    (%Epoch)   Loss   Dev/Loss     Accuracy  Dev/Accuracy'
dev_log_template = ' '.join('{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{:8.6f},{:12.4f},{:12.4f}'.split(','))
log_template =     ' '.join('{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{},{:12.4f},{}'.split(','))
makedirs(config.save_path)
print(header)

  Time Epoch Iteration Progress    (%Epoch)   Loss   Dev/Loss     Accuracy  Dev/Accuracy


In [80]:
for epoch in range(config.epochs):
    train_iter.init_epoch()
    n_correct, n_total = 0, 0
    for batch_idx, batch in enumerate(train_iter):
        # switch model to training mode, clear gradient accumulators
        model.train(); opt.zero_grad()

        iterations += 1

        # forward pass
        answer = model(batch)

        # calculate accuracy of predictions in the current batch
        n_correct += (torch.max(answer, 1)[1].view(batch.label.size()) == batch.label).sum().item()
        n_total += batch.batch_size
        train_acc = 100. * n_correct/n_total

        # calculate loss of the network output with respect to training labels
        loss = criterion(answer, batch.label)

        # backpropagate and update optimizer learning rate
        loss.backward(); opt.step()

        # checkpoint model periodically
        if iterations % config.save_every == 0:
            snapshot_prefix = os.path.join(config.save_path, 'snapshot')
            snapshot_path = snapshot_prefix + '_acc_{:.4f}_loss_{:.6f}_iter_{}_model.pt'.format(train_acc, loss.item(), iterations)
            torch.save(model, snapshot_path)
            for f in glob.glob(snapshot_prefix + '*'):
                if f != snapshot_path:
                    os.remove(f)

        # evaluate performance on validation set periodically
        if iterations % config.dev_every == 0:

            # switch model to evaluation mode
            model.eval(); dev_iter.init_epoch()

            # calculate accuracy on validation set
            n_dev_correct, dev_loss = 0, 0
            with torch.no_grad():
                for dev_batch_idx, dev_batch in enumerate(dev_iter):
                     answer = model(dev_batch)
                     n_dev_correct += (torch.max(answer, 1)[1].view(dev_batch.label.size()) == dev_batch.label).sum().item()
                     dev_loss = criterion(answer, dev_batch.label)
            dev_acc = 100. * n_dev_correct / len(dev)

            print(dev_log_template.format(time.time()-start,
                epoch, iterations, 1+batch_idx, len(train_iter),
                100. * (1+batch_idx) / len(train_iter), loss.item(), dev_loss.item(), train_acc, dev_acc))

            # update best valiation set accuracy
            if dev_acc > best_dev_acc:

                # found a model with better validation set accuracy

                best_dev_acc = dev_acc
                snapshot_prefix = os.path.join(config.save_path, 'best_snapshot')
                snapshot_path = snapshot_prefix + '_devacc_{}_devloss_{}__iter_{}_model.pt'.format(dev_acc, dev_loss.item(), iterations)

                # save model, delete previous 'best_snapshot' files
                torch.save(model, snapshot_path)
                for f in glob.glob(snapshot_prefix + '*'):
                    if f != snapshot_path:
                        os.remove(f)

        elif iterations % config.log_every == 0:

            # print progress message
            print(log_template.format(time.time()-start,
                epoch, iterations, 1+batch_idx, len(train_iter),
                100. * (1+batch_idx) / len(train_iter), loss.item(), ' '*8, n_correct/n_total*100, ' '*12))


    30     0        50    50/4292        1% 1.097565               32.2031             
    59     0       100   100/4292        2% 1.083437               33.1172             
    87     0       150   150/4292        3% 1.054205               34.5625             
   115     0       200   200/4292        5% 0.998015               37.4844             
   144     0       250   250/4292        6% 1.036661               39.9500             
   172     0       300   300/4292        7% 0.953668               42.1589             
   202     0       350   350/4292        8% 0.930430               44.3795             
   230     0       400   400/4292        9% 1.033562               46.1777             
   259     0       450   450/4292       10% 0.842066               47.6389             
   286     0       500   500/4292       12% 0.823885               48.8656             
   314     0       550   550/4292       13% 0.878075               49.9176             
   341     0       600   600/429

  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


   581     0      1000  1000/4292       23% 0.708396 0.841920      55.6461      65.0376
   611     0      1050  1050/4292       24% 0.747746               56.0536             
   639     0      1100  1100/4292       26% 0.757141               56.5135             
   668     0      1150  1150/4292       27% 0.745852               56.9035             
   698     0      1200  1200/4292       28% 0.767681               57.2871             
   726     0      1250  1250/4292       29% 0.754053               57.6463             
   754     0      1300  1300/4292       30% 0.670202               57.9940             
   784     0      1350  1350/4292       31% 0.699412               58.3345             
   813     0      1400  1400/4292       33% 0.738962               58.6105             
   842     0      1450  1450/4292       34% 0.676197               58.8992             
   872     0      1500  1500/4292       35% 0.878886               59.1807             
   902     0      1550  1550/429

  3372     1      5700  1408/4292       33% 0.507411               73.7826             
  3401     1      5750  1458/4292       34% 0.719232               73.8040             
  3430     1      5800  1508/4292       35% 0.576160               73.8499             
  3459     1      5850  1558/4292       36% 0.594627               73.8697             
  3488     1      5900  1608/4292       37% 0.655011               73.8850             
  3518     1      5950  1658/4292       39% 0.757308               73.9120             
  3554     1      6000  1708/4292       40% 0.626556 0.834286      73.9603      74.2837
  3584     1      6050  1758/4292       41% 0.589518               73.9703             
  3614     1      6100  1808/4292       42% 0.691125               74.0096             
  3644     1      6150  1858/4292       43% 0.628673               74.0157             
  3674     1      6200  1908/4292       44% 0.610645               74.0308             
  3703     1      6250  1958/429

  6175     2     10400  1816/4292       42% 0.696220               77.3850             
  6205     2     10450  1866/4292       43% 0.547834               77.3865             
  6235     2     10500  1916/4292       45% 0.565296               77.3894             
  6265     2     10550  1966/4292       46% 0.575840               77.3851             
  6294     2     10600  2016/4292       47% 0.660185               77.4034             
  6323     2     10650  2066/4292       48% 0.630932               77.4092             
  6352     2     10700  2116/4292       49% 0.610800               77.4301             
  6381     2     10750  2166/4292       50% 0.478904               77.4343             
  6411     2     10800  2216/4292       52% 0.444697               77.4354             
  6440     2     10850  2266/4292       53% 0.589423               77.4530             
  6470     2     10900  2316/4292       54% 0.610539               77.4335             
  6500     2     10950  2366/429

  8980     3     15100  2224/4292       52% 0.598095               79.5923             
  9009     3     15150  2274/4292       53% 0.538246               79.5875             
  9038     3     15200  2324/4292       54% 0.483874               79.5735             
  9068     3     15250  2374/4292       55% 0.540715               79.5779             
  9097     3     15300  2424/4292       56% 0.537867               79.5818             
  9126     3     15350  2474/4292       58% 0.508657               79.5782             
  9157     3     15400  2524/4292       59% 0.495590               79.5779             
  9187     3     15450  2574/4292       60% 0.519140               79.5615             
  9217     3     15500  2624/4292       61% 0.584871               79.5604             
  9245     3     15550  2674/4292       62% 0.461898               79.5721             
  9276     3     15600  2724/4292       63% 0.474244               79.5653             
  9305     3     15650  2774/429

 11780     4     19800  2632/4292       61% 0.508601               81.2310             
 11810     4     19850  2682/4292       62% 0.429348               81.2337             
 11839     4     19900  2732/4292       64% 0.473875               81.2214             
 11869     4     19950  2782/4292       65% 0.451760               81.2171             
 11904     4     20000  2832/4292       66% 0.490215 0.693060      81.2208      80.4206
 11933     4     20050  2882/4292       67% 0.539341               81.2083             
 11962     4     20100  2932/4292       68% 0.518873               81.1906             
 11992     4     20150  2982/4292       69% 0.472233               81.1803             
 12022     4     20200  3032/4292       71% 0.443345               81.1884             
 12051     4     20250  3082/4292       72% 0.443081               81.1864             
 12079     4     20300  3132/4292       73% 0.463803               81.1822             
 12109     4     20350  3182/429

 14577     5     24500  3040/4292       71% 0.424061               82.6249             
 14607     5     24550  3090/4292       72% 0.434195               82.6173             
 14637     5     24600  3140/4292       73% 0.470540               82.6052             
 14665     5     24650  3190/4292       74% 0.464009               82.6173             
 14695     5     24700  3240/4292       75% 0.544853               82.6085             
 14725     5     24750  3290/4292       77% 0.420452               82.5898             
 14753     5     24800  3340/4292       78% 0.497484               82.5786             
 14783     5     24850  3390/4292       79% 0.468179               82.5574             
 14812     5     24900  3440/4292       80% 0.378668               82.5491             
 14842     5     24950  3490/4292       81% 0.505691               82.5531             
 14878     5     25000  3540/4292       82% 0.489978 0.688883      82.5532      79.9939
 14907     5     25050  3590/429

 17387     6     29200  3448/4292       80% 0.571949               83.8777             
 17416     6     29250  3498/4292       82% 0.425050               83.8689             
 17446     6     29300  3548/4292       83% 0.494647               83.8597             
 17475     6     29350  3598/4292       84% 0.425615               83.8422             
 17505     6     29400  3648/4292       85% 0.472370               83.8261             
 17533     6     29450  3698/4292       86% 0.538235               83.8287             
 17563     6     29500  3748/4292       87% 0.449195               83.8197             
 17593     6     29550  3798/4292       88% 0.360688               83.8155             
 17623     6     29600  3848/4292       90% 0.558718               83.8148             
 17652     6     29650  3898/4292       91% 0.329401               83.8160             
 17680     6     29700  3948/4292       92% 0.347525               83.8045             
 17710     6     29750  3998/429

 20191     7     33900  3856/4292       90% 0.434044               85.0861             
 20219     7     33950  3906/4292       91% 0.375216               85.0788             
 20254     7     34000  3956/4292       92% 0.351897 0.675020      85.0759      80.6442
 20284     7     34050  4006/4292       93% 0.403616               85.0730             
 20312     7     34100  4056/4292       95% 0.395270               85.0638             
 20342     7     34150  4106/4292       96% 0.331280               85.0573             
 20371     7     34200  4156/4292       97% 0.478914               85.0521             
 20402     7     34250  4206/4292       98% 0.352717               85.0420             
 20432     7     34300  4256/4292       99% 0.388877               85.0268             
 20463     8     34350    14/4292        0% 0.405923               87.8906             
 20493     8     34400    64/4292        1% 0.249378               87.4756             
 20522     8     34450   114/429

 22982     8     38600  4264/4292       99% 0.329886               86.2000             
 23012     9     38650    22/4292        1% 0.302724               88.9915             
 23041     9     38700    72/4292        2% 0.377225               88.3789             
 23071     9     38750   122/4292        3% 0.360438               88.2877             
 23100     9     38800   172/4292        4% 0.233788               88.2676             
 23130     9     38850   222/4292        5% 0.284300               88.3552             
 23160     9     38900   272/4292        6% 0.291704               88.3990             
 23190     9     38950   322/4292        8% 0.392189               88.4559             
 23225     9     39000   372/4292        9% 0.361709 0.756018      88.4325      79.9431
 23254     9     39050   422/4292       10% 0.328974               88.3201             
 23283     9     39100   472/4292       11% 0.378663               88.3094             
 23312     9     39150   522/429

 25791    10     43300   380/4292        9% 0.242617               89.2804             
 25821    10     43350   430/4292       10% 0.253985               89.2424             
 25851    10     43400   480/4292       11% 0.356879               89.2920             
 25881    10     43450   530/4292       12% 0.328446               89.2659             
 25911    10     43500   580/4292       14% 0.242100               89.2996             
 25940    10     43550   630/4292       15% 0.348431               89.2696             
 25969    10     43600   680/4292       16% 0.242464               89.2245             
 25998    10     43650   730/4292       17% 0.236783               89.2252             
 26028    10     43700   780/4292       18% 0.277941               89.1877             
 26059    10     43750   830/4292       19% 0.233382               89.1274             
 26088    10     43800   880/4292       21% 0.388963               89.0705             
 26119    10     43850   930/429

KeyboardInterrupt: 

In [None]:
v