In [1]:
import os
import random

import numpy as np
import torch

from ncn.model import *
from ncn.training import *

In [2]:
# Seed every RNG the pipeline may draw from (Python `random`, numpy, torch CPU/CUDA)
# so that Restart & Run All reproduces the same initial weights and dropout masks
# (and, presumably, the same batch shuffling in the iterators).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Explicit CUDA seeding for older PyTorch versions; a harmless no-op without CUDA.
torch.cuda.manual_seed_all(SEED)
# cuDNN: require deterministic kernels AND disable the autotuner; with
# benchmark=True cuDNN may pick different algorithms from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [3]:
# set up training: load the arXiv citation CSV and build bucketized iterators.
# The data path can be overridden with the NCN_DATA_PATH environment variable so
# the notebook is not tied to one machine; the default is the original location.
DATA_PATH = os.environ.get("NCN_DATA_PATH",
                           "/home/jupyter/tutorials/seminar_kd/arxiv_data.csv")
BATCH_SIZE = 64
# Requested vocab size per field (context, title, author). Special tokens are added
# on top, e.g. the model summary below shows 20002 context / 20004 title entries.
VOCAB_SIZE = 20000

data = get_bucketized_iterators(DATA_PATH,
                                batch_size=BATCH_SIZE,
                                len_context_vocab=VOCAB_SIZE,
                                len_title_vocab=VOCAB_SIZE,
                                len_aut_vocab=VOCAB_SIZE)

# A single PAD_IDX (taken from the title vocab) is handed to the model for all of
# its embeddings and to train_model as `pad`, so every vocab must agree on it.
PAD_IDX = data.ttl.vocab.stoi['<pad>']
assert data.cntxt.vocab.stoi['<pad>'] == PAD_IDX, "context vocab <pad> index differs from title vocab"
assert data.aut.vocab.stoi['<pad>'] == PAD_IDX, "author vocab <pad> index differs from title vocab"

cntxt_vocab_len = len(data.cntxt.vocab)
aut_vocab_len = len(data.aut.vocab)
ttl_vocab_len = len(data.ttl.vocab)

INFO:ncn.data:Getting fields...
INFO:ncn.data:Loading dataset...
INFO:ncn.data:Building vocab...


In [4]:
# Build the Neural Citation Network: TDNN convolutional encoders over the citation
# context and author tokens, plus a 1-layer GRU decoder (see the summary below).
# Hyperparameters are gathered in one dict so they can be inspected or logged.
ncn_hparams = dict(
    context_filters=[4, 4, 5, 6, 7],  # one TDNN conv per entry, kernel width = value
    author_filters=[1, 2],
    context_vocab_size=cntxt_vocab_len,
    title_vocab_size=ttl_vocab_len,
    author_vocab_size=aut_vocab_len,
    pad_idx=PAD_IDX,
    num_filters=256,
    authors=True,
    embed_size=128,
    num_layers=1,
    hidden_size=256,
    # NOTE(review): the truncated warning in the output appears to be PyTorch's RNN
    # notice that recurrent dropout has no effect when num_layers=1; the encoder's
    # Dropout(p=0.2) layer is still applied.
    dropout_p=0.2,
    show_attention=False,
)
net = NeuralCitationNetwork(**ncn_hparams).to(DEVICE)
net

  "num_layers={}".format(dropout, num_layers))


NeuralCitationNetwork(
  (encoder): NCNEncoder(
    (dropout): Dropout(p=0.2)
    (context_embedding): Embedding(20002, 128, padding_idx=1)
    (context_encoder): TDNNEncoder(
      (encoder): ModuleList(
        (0): TDNN(
          (conv): Conv2d(1, 256, kernel_size=(128, 4), stride=(1, 1), bias=False)
        )
        (1): TDNN(
          (conv): Conv2d(1, 256, kernel_size=(128, 4), stride=(1, 1), bias=False)
        )
        (2): TDNN(
          (conv): Conv2d(1, 256, kernel_size=(128, 5), stride=(1, 1), bias=False)
        )
        (3): TDNN(
          (conv): Conv2d(1, 256, kernel_size=(128, 6), stride=(1, 1), bias=False)
        )
        (4): TDNN(
          (conv): Conv2d(1, 256, kernel_size=(128, 7), stride=(1, 1), bias=False)
        )
      )
      (fc): Linear(in_features=1280, out_features=1280, bias=True)
    )
    (author_embedding): Embedding(20002, 128, padding_idx=1)
    (citing_author_encoder): TDNNEncoder(
      (encoder): ModuleList(
        (0): TDNN(
        

In [5]:
# Train the network. The epoch count is not passed here, so train_model presumably
# uses its own default (the log below reports 30 epochs, batch size 64).
# PAD_IDX is forwarded as `pad`, presumably so padded target positions are ignored
# by the loss; confirm in ncn.training.train_model.
# NOTE(review): the log line "Context filter length = [1, 2]" actually reports the
# author filters; the label inside ncn.training looks like a copy-paste slip.
MODEL_NAME = "embed_128_hid_256_1_GRU"  # appears to encode embed size, hidden size, # GRU layers
train_losses, valid_losses = train_model(
    model=net,
    train_iterator=data.train_iter,
    valid_iterator=data.valid_iter,
    lr=0.001,
    pad=PAD_IDX,
    model_name=MODEL_NAME,
)

INFO:ncn.training:INITIALIZING NEURAL CITATION NETWORK WITH AUTHORS = True
Running on: cuda
Number of model parameters: 24,341,796
Encoders: # Filters = 256, Context filter length = [4, 4, 5, 6, 7],  Context filter length = [1, 2]
Embeddings: Dimension = 128, Pad index = 1, Context vocab = 20002, Author vocab = 20002, Title vocab = 20004
Decoder: # GRU cells = 1, Hidden size = 256
Parameters: Dropout = 0.2, Show attention = False
-------------------------------------------------
TRAINING SETTINGS
Seed = 34, # Epochs = 30, Batch size = 64, Initial lr = 0.001


HBox(children=(IntProgress(value=0, description='Epochs', max=30, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Training batches', max=6280, style=ProgressStyle(description_…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=785, style=ProgressStyle(description…

INFO:ncn.training:Epoch: 01 | Time: 15m 34s
INFO:ncn.training:	Train Loss: 2594.097
INFO:ncn.training:	 Val. Loss: 2180.488


HBox(children=(IntProgress(value=0, description='Training batches', max=6280, style=ProgressStyle(description_…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=785, style=ProgressStyle(description…

INFO:ncn.training:Epoch: 02 | Time: 15m 39s
INFO:ncn.training:	Train Loss: 2132.993
INFO:ncn.training:	 Val. Loss: 2068.490


HBox(children=(IntProgress(value=0, description='Training batches', max=6280, style=ProgressStyle(description_…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=785, style=ProgressStyle(description…

INFO:ncn.training:Epoch: 03 | Time: 10m 45s
INFO:ncn.training:	Train Loss: 2023.243
INFO:ncn.training:	 Val. Loss: 1975.022


HBox(children=(IntProgress(value=0, description='Training batches', max=6280, style=ProgressStyle(description_…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=785, style=ProgressStyle(description…

INFO:ncn.training:Epoch: 04 | Time: 8m 7s
INFO:ncn.training:	Train Loss: 1966.124
INFO:ncn.training:	 Val. Loss: 1944.621


HBox(children=(IntProgress(value=0, description='Training batches', max=6280, style=ProgressStyle(description_…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=785, style=ProgressStyle(description…

INFO:ncn.training:Epoch: 05 | Time: 8m 11s
INFO:ncn.training:	Train Loss: 1919.624
INFO:ncn.training:	 Val. Loss: 1912.295


HBox(children=(IntProgress(value=0, description='Training batches', max=6280, style=ProgressStyle(description_…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=785, style=ProgressStyle(description…

INFO:ncn.training:Epoch: 06 | Time: 8m 7s
INFO:ncn.training:	Train Loss: 1882.007
INFO:ncn.training:	 Val. Loss: 1879.405


HBox(children=(IntProgress(value=0, description='Training batches', max=6280, style=ProgressStyle(description_…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=785, style=ProgressStyle(description…

INFO:ncn.training:Epoch: 07 | Time: 8m 8s
INFO:ncn.training:	Train Loss: 1855.756
INFO:ncn.training:	 Val. Loss: 1871.833


HBox(children=(IntProgress(value=0, description='Training batches', max=6280, style=ProgressStyle(description_…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=785, style=ProgressStyle(description…

INFO:ncn.training:Epoch: 08 | Time: 8m 9s
INFO:ncn.training:	Train Loss: 1831.764
INFO:ncn.training:	 Val. Loss: 1850.504


HBox(children=(IntProgress(value=0, description='Training batches', max=6280, style=ProgressStyle(description_…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=785, style=ProgressStyle(description…

INFO:ncn.training:Epoch: 09 | Time: 8m 11s
INFO:ncn.training:	Train Loss: 1809.247
INFO:ncn.training:	 Val. Loss: 1838.891


HBox(children=(IntProgress(value=0, description='Training batches', max=6280, style=ProgressStyle(description_…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=785, style=ProgressStyle(description…

INFO:ncn.training:Epoch: 10 | Time: 8m 10s
INFO:ncn.training:	Train Loss: 1793.052
INFO:ncn.training:	 Val. Loss: 1836.095


HBox(children=(IntProgress(value=0, description='Training batches', max=6280, style=ProgressStyle(description_…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=785, style=ProgressStyle(description…

INFO:ncn.training:Epoch: 11 | Time: 8m 15s
INFO:ncn.training:	Train Loss: 1780.649
INFO:ncn.training:	 Val. Loss: 1833.148


HBox(children=(IntProgress(value=0, description='Training batches', max=6280, style=ProgressStyle(description_…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=785, style=ProgressStyle(description…

INFO:ncn.training:Epoch: 12 | Time: 8m 25s
INFO:ncn.training:	Train Loss: 1765.161
INFO:ncn.training:	 Val. Loss: 1827.139


HBox(children=(IntProgress(value=0, description='Training batches', max=6280, style=ProgressStyle(description_…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=785, style=ProgressStyle(description…

INFO:ncn.training:Epoch: 13 | Time: 8m 32s
INFO:ncn.training:	Train Loss: 1755.556
INFO:ncn.training:	 Val. Loss: 1811.238


HBox(children=(IntProgress(value=0, description='Training batches', max=6280, style=ProgressStyle(description_…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=785, style=ProgressStyle(description…

INFO:ncn.training:Epoch: 14 | Time: 8m 34s
INFO:ncn.training:	Train Loss: 1744.238
INFO:ncn.training:	 Val. Loss: 1818.713


HBox(children=(IntProgress(value=0, description='Training batches', max=6280, style=ProgressStyle(description_…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=785, style=ProgressStyle(description…

INFO:ncn.training:Epoch: 15 | Time: 8m 33s
INFO:ncn.training:	Train Loss: 1733.692
INFO:ncn.training:	 Val. Loss: 1812.673


HBox(children=(IntProgress(value=0, description='Training batches', max=6280, style=ProgressStyle(description_…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=785, style=ProgressStyle(description…

INFO:ncn.training:Epoch: 16 | Time: 8m 29s
INFO:ncn.training:	Train Loss: 1723.186
INFO:ncn.training:	 Val. Loss: 1783.967


HBox(children=(IntProgress(value=0, description='Training batches', max=6280, style=ProgressStyle(description_…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=785, style=ProgressStyle(description…

INFO:ncn.training:Epoch: 17 | Time: 8m 27s
INFO:ncn.training:	Train Loss: 1712.181
INFO:ncn.training:	 Val. Loss: 1791.940


HBox(children=(IntProgress(value=0, description='Training batches', max=6280, style=ProgressStyle(description_…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=785, style=ProgressStyle(description…

INFO:ncn.training:Epoch: 18 | Time: 8m 27s
INFO:ncn.training:	Train Loss: 1703.743
INFO:ncn.training:	 Val. Loss: 1771.110


HBox(children=(IntProgress(value=0, description='Training batches', max=6280, style=ProgressStyle(description_…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=785, style=ProgressStyle(description…

INFO:ncn.training:Epoch: 19 | Time: 8m 27s
INFO:ncn.training:	Train Loss: 1696.656
INFO:ncn.training:	 Val. Loss: 1783.945


HBox(children=(IntProgress(value=0, description='Training batches', max=6280, style=ProgressStyle(description_…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=785, style=ProgressStyle(description…

INFO:ncn.training:Epoch: 20 | Time: 8m 26s
INFO:ncn.training:	Train Loss: 1687.520
INFO:ncn.training:	 Val. Loss: 1768.322


HBox(children=(IntProgress(value=0, description='Training batches', max=6280, style=ProgressStyle(description_…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=785, style=ProgressStyle(description…

INFO:ncn.training:Epoch: 21 | Time: 8m 26s
INFO:ncn.training:	Train Loss: 1679.543
INFO:ncn.training:	 Val. Loss: 1763.635


HBox(children=(IntProgress(value=0, description='Training batches', max=6280, style=ProgressStyle(description_…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=785, style=ProgressStyle(description…

INFO:ncn.training:Epoch: 22 | Time: 8m 27s
INFO:ncn.training:	Train Loss: 1674.936
INFO:ncn.training:	 Val. Loss: 1759.226


HBox(children=(IntProgress(value=0, description='Training batches', max=6280, style=ProgressStyle(description_…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=785, style=ProgressStyle(description…

INFO:ncn.training:Epoch: 23 | Time: 8m 25s
INFO:ncn.training:	Train Loss: 1669.297
INFO:ncn.training:	 Val. Loss: 1758.367


HBox(children=(IntProgress(value=0, description='Training batches', max=6280, style=ProgressStyle(description_…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=785, style=ProgressStyle(description…

INFO:ncn.training:Epoch: 24 | Time: 8m 28s
INFO:ncn.training:	Train Loss: 1660.654
INFO:ncn.training:	 Val. Loss: 1746.081


HBox(children=(IntProgress(value=0, description='Training batches', max=6280, style=ProgressStyle(description_…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=785, style=ProgressStyle(description…

INFO:ncn.training:Epoch: 25 | Time: 8m 27s
INFO:ncn.training:	Train Loss: 1655.955
INFO:ncn.training:	 Val. Loss: 1753.151


HBox(children=(IntProgress(value=0, description='Training batches', max=6280, style=ProgressStyle(description_…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=785, style=ProgressStyle(description…

INFO:ncn.training:Epoch: 26 | Time: 8m 27s
INFO:ncn.training:	Train Loss: 1649.255
INFO:ncn.training:	 Val. Loss: 1743.118


HBox(children=(IntProgress(value=0, description='Training batches', max=6280, style=ProgressStyle(description_…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=785, style=ProgressStyle(description…

INFO:ncn.training:Epoch: 27 | Time: 9m 14s
INFO:ncn.training:	Train Loss: 1642.262
INFO:ncn.training:	 Val. Loss: 1737.279


HBox(children=(IntProgress(value=0, description='Training batches', max=6280, style=ProgressStyle(description_…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=785, style=ProgressStyle(description…

INFO:ncn.training:Epoch: 28 | Time: 13m 21s
INFO:ncn.training:	Train Loss: 1637.807
INFO:ncn.training:	 Val. Loss: 1737.793


HBox(children=(IntProgress(value=0, description='Training batches', max=6280, style=ProgressStyle(description_…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=785, style=ProgressStyle(description…

INFO:ncn.training:Epoch: 29 | Time: 13m 15s
INFO:ncn.training:	Train Loss: 1632.062
INFO:ncn.training:	 Val. Loss: 1755.706


HBox(children=(IntProgress(value=0, description='Training batches', max=6280, style=ProgressStyle(description_…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=785, style=ProgressStyle(description…

INFO:ncn.training:Epoch: 30 | Time: 13m 16s
INFO:ncn.training:	Train Loss: 1627.023
INFO:ncn.training:	 Val. Loss: 1730.720



