In [1]:
# Automatically re-import edited modules (e.g. the project's sts_data,
# siamese_dan, train and test files) before each cell executes, so changes to
# the .py files take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [4]:
import torch
from test import evaluate_test_set
import sts_data
import siamese_dan
import train
import test
from importlib import reload

### Data Preprocessing

In [5]:
# Pick up any edits to sts_data.py before constructing the dataset.
reload(sts_data)
from sts_data import STSData

# Configuration for the SICK dataset: map the generic keys STSData uses
# (sent1 / sent2 / label) onto SICK's own column names.
dataset_name = "sick"
columns_mapping = dict(
    sent1="sentence_A",
    sent2="sentence_B",
    label="relatedness_score",
)
batch_size = 64

# Labels are normalised with a constant of 5.0 — presumably the top of the SICK
# relatedness scale, so targets end up in [0, 1]; confirm in sts_data.
sick_data = STSData(
    dataset_name=dataset_name,
    columns_mapping=columns_mapping,
    normalize_labels=True,
    normalization_const=5.0,
)
sick_dataloaders = sick_data.get_data_loader(batch_size=batch_size)

INFO:root:loading and preprocessing data...
INFO:root:reading and preprocessing data completed...
INFO:root:creating vocabulary...
INFO:torchtext.vocab:Loading vectors from .vector_cache\wiki.simple.vec.pt
INFO:root:creating vocabulary completed...


### Hyperparameter initialisation (final tuned values)

In [6]:
## initialise final tuned hyperparameters
# Fix the global torch RNG (CPU and all CUDA devices) before the model is built
# and trained, so weight initialisation and any DataLoader shuffling are
# reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

vocab_size = len(sick_data.vocab)
embedding_weights = sick_data.vocab.vectors
# Take the embedding width from the pretrained vectors instead of hard-coding it,
# so switching to a different vector file cannot silently mismatch the embedding
# layer. Falls back to 300 when no pretrained vectors are attached.
# NOTE(review): assumes `vocab.vectors` is a (vocab_size, dim) tensor, as with
# torchtext's Vocab — confirm against sts_data.
embedding_size = 300 if embedding_weights is None else embedding_weights.shape[1]
learning_rate = 1e-2
max_epochs = 5
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


### Train a Siamese Deep Averaging Network (DAN) model

In [14]:
# Pick up any edits to siamese_dan.py before building the model, like the other
# cells do for their modules. Reload through a dedicated alias: the training
# cell rebinds the bare name `siamese_dan` to the trained model, so
# `reload(siamese_dan)` would raise a TypeError when this cell is re-run.
import siamese_dan as siamese_dan_module
reload(siamese_dan_module)
from siamese_dan import SiameseDAN

siamese_dan_model = SiameseDAN(
    batch_size=batch_size,
    vocab_size=vocab_size,
    embedding_size=embedding_size,
    embedding_weights=embedding_weights,
    device=device,
)
## move model to device
siamese_dan_model.to(device)

SiameseDAN(
  (embeddings): Embedding(2052, 300)
  (linear_1): Linear(in_features=300, out_features=128, bias=True)
  (linear_2): Linear(in_features=128, out_features=64, bias=True)
  (linear_3): Linear(in_features=64, out_features=32, bias=True)
  (linear_4): Linear(in_features=32, out_features=16, bias=True)
  (linear_5): Linear(in_features=16, out_features=8, bias=True)
)

In [15]:
# Pick up any edits to train.py before training.
reload(train)
from train import train_model
import torch.optim as optim

# Adam with beta2 = 0.98 instead of PyTorch's default 0.999.
ADAM_BETAS = (0.9, 0.98)
optimizer = optim.Adam(siamese_dan_model.parameters(), lr=learning_rate, betas=ADAM_BETAS)

# NOTE(review): this rebinds `siamese_dan`, shadowing the `siamese_dan` module
# imported in the imports cell (any later `reload(siamese_dan)` would fail).
# Consider renaming to e.g. `trained_siamese_dan` here and in the test-set cell.
siamese_dan = train_model(
    model=siamese_dan_model,
    optimizer=optimizer,
    dataloader=sick_dataloaders,
    data=sick_data,
    max_epochs=max_epochs,
    config_dict={
        "device": device,
        "model_name": "siamese_dan",
    },
)

  0%|                                                                                            | 0/5 [00:00<?, ?it/s]

Running EPOCH 1
Running loss:  0.08791243582963944
Training set accuracy: 0.9120875656604767
Running loss:  0.06105341725051403
Training set accuracy: 0.9389465834945441
Running loss:  0.051468998193740845
Training set accuracy: 0.9485310014337301
Running loss:  0.043895182013511655
Training set accuracy: 0.9561048176139593
Running loss:  0.04205652084201574
Training set accuracy: 0.9579434793442487
Running loss:  0.03931999020278454
Training set accuracy: 0.9606800097972155


INFO:root:Evaluating accuracy on dev set


Evaluating validation set ....
Validation loss: 0.038
Validation set accuracy: 0.962
Validation loss: 0.041
Validation set accuracy: 0.959
Validation loss: 0.040
Validation set accuracy: 0.960
Validation loss: 0.044
Validation set accuracy: 0.956
Validation loss: 0.038
Validation set accuracy: 0.962


INFO:root:new model saved
INFO:root:Train loss: 0.05236326903104782 - acc: 0.9476367320487464 -- Validation loss: 0.030444644391536713 - acc: 0.9625064182494368
 20%|████████████████▊                                                                   | 1/5 [00:02<00:11,  2.83s/it]

Validation loss: 0.032
Validation set accuracy: 0.968
Validation loss: 0.030
Validation set accuracy: 0.970
Finished Training
Running EPOCH 2
Running loss:  0.03227730300277472
Training set accuracy: 0.9677226977422834
Running loss:  0.032437744364142415
Training set accuracy: 0.9675622552633285
Running loss:  0.03229008764028549
Training set accuracy: 0.9677099140360952
Running loss:  0.03083040229976177
Training set accuracy: 0.9691695986315608
Running loss:  0.033030600287020206
Training set accuracy: 0.9669693997129798
Running loss:  0.03551183231174946
Training set accuracy: 0.9644881688058377


INFO:root:Evaluating accuracy on dev set


Evaluating validation set ....
Validation loss: 0.026
Validation set accuracy: 0.974
Validation loss: 0.036
Validation set accuracy: 0.964
Validation loss: 0.038
Validation set accuracy: 0.962
Validation loss: 0.034
Validation set accuracy: 0.966
Validation loss: 0.032
Validation set accuracy: 0.968
Validation loss: 0.030
Validation set accuracy: 0.970


INFO:root:new model saved
INFO:root:Train loss: 0.0324363075196743 - acc: 0.9675636906986651 -- Validation loss: 0.026247471570968628 - acc: 0.9682045557669231
 40%|█████████████████████████████████▌                                                  | 2/5 [00:05<00:08,  2.85s/it]

Validation loss: 0.026
Validation set accuracy: 0.974
Finished Training
Running EPOCH 3
Running loss:  0.024848736822605133
Training set accuracy: 0.9751512637361884
Running loss:  0.02753753010183573
Training set accuracy: 0.9724624697118998
Running loss:  0.03071925789117813
Training set accuracy: 0.9692807413637639
Running loss:  0.026398716680705547
Training set accuracy: 0.9736012829467654
Running loss:  0.027652000822126866
Training set accuracy: 0.9723479982465506
Running loss:  0.02964489087462425
Training set accuracy: 0.9703551094979048


INFO:root:Evaluating accuracy on dev set


Evaluating validation set ....
Validation loss: 0.025
Validation set accuracy: 0.975
Validation loss: 0.034
Validation set accuracy: 0.966
Validation loss: 0.035
Validation set accuracy: 0.965
Validation loss: 0.031
Validation set accuracy: 0.969
Validation loss: 0.035
Validation set accuracy: 0.965
Validation loss: 0.031
Validation set accuracy: 0.969


INFO:root:new model saved
INFO:root:Train loss: 0.028189506381750107 - acc: 0.971810492646435 -- Validation loss: 0.018739506602287292 - acc: 0.9700927178242377
 60%|██████████████████████████████████████████████████▍                                 | 3/5 [00:08<00:05,  2.83s/it]

Validation loss: 0.019
Validation set accuracy: 0.981
Finished Training
Running EPOCH 4
Running loss:  0.028798578679561614
Training set accuracy: 0.9712014202028513
Running loss:  0.02426996547728777
Training set accuracy: 0.97573003442958
Running loss:  0.025085293874144553
Training set accuracy: 0.9749147064983845
Running loss:  0.025552907958626747
Training set accuracy: 0.9744470920413733
Running loss:  0.024340344406664372
Training set accuracy: 0.9756596546620131
Running loss:  0.02697857953608036
Training set accuracy: 0.9730214199051261


INFO:root:Evaluating accuracy on dev set


Evaluating validation set ....
Validation loss: 0.026
Validation set accuracy: 0.974
Validation loss: 0.030
Validation set accuracy: 0.970
Validation loss: 0.031
Validation set accuracy: 0.969
Validation loss: 0.028
Validation set accuracy: 0.972
Validation loss: 0.030
Validation set accuracy: 0.970


INFO:root:new model saved
INFO:root:Train loss: 0.026176396757364273 - acc: 0.9738236011775292 -- Validation loss: 0.013278643600642681 - acc: 0.973828983094011
 80%|███████████████████████████████████████████████████████████████████▏                | 4/5 [00:11<00:02,  2.84s/it]

Validation loss: 0.024
Validation set accuracy: 0.976
Validation loss: 0.013
Validation set accuracy: 0.987
Finished Training
Running EPOCH 5
Running loss:  0.022468891367316245
Training set accuracy: 0.9775311090052128
Running loss:  0.02257870975881815
Training set accuracy: 0.9774212902411819
Running loss:  0.024132957868278026
Training set accuracy: 0.975867042504251
Running loss:  0.024064053688198327
Training set accuracy: 0.9759359462186694
Running loss:  0.02322063222527504
Training set accuracy: 0.976779367774725
Running loss:  0.02433588244020939
Training set accuracy: 0.9756641170009971


INFO:root:Evaluating accuracy on dev set


Evaluating validation set ....
Validation loss: 0.024
Validation set accuracy: 0.976
Validation loss: 0.031
Validation set accuracy: 0.969
Validation loss: 0.034
Validation set accuracy: 0.966
Validation loss: 0.029
Validation set accuracy: 0.971
Validation loss: 0.029
Validation set accuracy: 0.971


INFO:root:Train loss: 0.02354118414223194 - acc: 0.9764588141840869 -- Validation loss: 0.018343064934015274 - acc: 0.9728157001414469
100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:14<00:00,  2.87s/it]

Validation loss: 0.025
Validation set accuracy: 0.975
Validation loss: 0.018
Validation set accuracy: 0.982
Finished Training





### Compute test set accuracy

In [16]:
# Pick up any edits to test.py, then re-bind the function. The imports cell did
# `from test import evaluate_test_set`, and reload() does not update names that
# were already imported with `from ... import ...`. Without the re-import, the
# reload would have no effect on the call below.
# NOTE(review): a local module named `test` shadows the standard library's
# `test` package; it only resolves here because the notebook's directory comes
# first on sys.path. Consider renaming the module.
reload(test)
from test import evaluate_test_set

evaluate_test_set(
    model=siamese_dan,
    data_loader=sick_dataloaders,
    config_dict={
        "device": device,
        "model_name": "siamese_dan",
    },
)



INFO:root:Evaluating accuracy on test set
INFO:root:Evaluating accuracy on test set


Finished testing..............
Total test set accuracy: 0.972


### References

1. DAN code reference: [PyTorch Deep Average Network as baseline (Kaggle)](https://www.kaggle.com/bobazooba/pytorch-deep-average-network-as-baseline/)
2. Universal Sentence Encoder (DAN variant): [arXiv:1803.11175](https://arxiv.org/pdf/1803.11175.pdf)