In [1]:
import os

from ncn.model import *
from ncn.training import *

In [2]:
# Seed the Python and PyTorch random number generators with the shared
# project seed so that a fresh "Restart Kernel -> Run All" starts from the
# same random state every time.
random.seed(SEED)
torch.manual_seed(SEED)

# Request deterministic cuDNN kernels and pin the cuDNN auto-tuner off.
# With benchmark mode enabled cuDNN may select different convolution
# algorithms from run to run, which would defeat the deterministic flag,
# so both switches are set explicitly instead of relying on the default.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [3]:
# set up training
# Location of the dataset CSV. The default is the path used for the recorded
# run; set the NCN_DATA_PATH environment variable to point the notebook at a
# copy stored elsewhere without having to edit this cell.
DATA_PATH = os.environ.get(
    "NCN_DATA_PATH",
    "/home/jupyter/tutorials/seminar_kd/arxiv_data.csv",
)

# Returns an object exposing the fields (cntxt, aut, ttl) and the
# train/validation iterators used by the cells below.
data = get_bucketized_iterators(DATA_PATH)

# Index of the '<pad>' token, looked up in the *title* vocabulary.
# NOTE(review): this single index is later passed as `pad_idx` for the whole
# network, so it is presumably also the pad index of the context and author
# vocabularies -- confirm against the data module.
PAD_IDX = data.ttl.vocab.stoi['<pad>']

# Vocabulary sizes for the three fields; these size the embedding layers
# when the network is constructed.
cntxt_vocab_len = len(data.cntxt.vocab)
aut_vocab_len = len(data.aut.vocab)
ttl_vocab_len = len(data.ttl.vocab)

INFO:neural_citation.data:Getting fields...
INFO:neural_citation.data:Loading dataset...
INFO:neural_citation.data:Building vocab...


In [4]:
# Build the Neural Citation Network. Each keyword is listed on its own line
# so that individual hyperparameters are easy to spot and to diff.
net = NeuralCitationNetwork(
    # context side: filter lengths and vocabulary size
    context_filters=[4, 4, 5],
    context_vocab_size=cntxt_vocab_len,
    # author side: switched on, with its own filter lengths and vocabulary size
    authors=True,
    author_filters=[1, 2],
    author_vocab_size=aut_vocab_len,
    # title side: vocabulary size, the shared padding index and layer count
    title_vocab_size=ttl_vocab_len,
    pad_idx=PAD_IDX,
    num_layers=2,
)

# Move the parameters to the configured device before initialising them.
net.to(DEVICE)

# Apply the project's weight initialiser to every submodule. Kept as the
# cell's final expression so the module summary is displayed as cell output.
net.apply(init_weights)

NeuralCitationNetwork(
  (encoder): NCNEncoder(
    (dropout): Dropout(p=0.2)
    (context_embedding): Embedding(30002, 128, padding_idx=1)
    (context_encoder): TDNNEncoder(
      (fc): Linear(in_features=384, out_features=384, bias=True)
    )
    (author_embedding): Embedding(30002, 128, padding_idx=1)
    (citing_author_encoder): TDNNEncoder(
      (fc): Linear(in_features=256, out_features=256, bias=True)
    )
    (cited_author_encoder): TDNNEncoder(
      (fc): Linear(in_features=256, out_features=256, bias=True)
    )
  )
  (attention): Attention(
    (attn): Linear(in_features=256, out_features=128, bias=True)
  )
  (decoder): Decoder(
    (attention): Attention(
      (attn): Linear(in_features=256, out_features=128, bias=True)
    )
    (embedding): Embedding(30004, 128, padding_idx=1)
    (rnn): GRU(256, 128)
    (out): Linear(in_features=384, out_features=30004, bias=True)
    (dropout): Dropout(p=0.2)
  )
)

In [5]:
# Run training: `net` is fitted on the training iterator and, judging by the
# "Evaluating batches" / "Val. Loss" lines in the log, evaluated on the
# validation iterator after every epoch. PAD_IDX is passed through to the
# training routine -- presumably so padded positions are excluded from the
# loss; TODO confirm in ncn.training.
# NOTE(review): in the recorded output each epoch took roughly 10 minutes and
# the run was stopped by hand (KeyboardInterrupt) during epoch 20, right after
# the validation loss rose for the first time (943.169 -> 944.372 in epoch 19).
# NOTE(review): any return value is discarded here; if train_model returns
# loss histories, consider binding them to a name for later plotting -- confirm.
train_model(net, data.train_iter, data.valid_iter, PAD_IDX)

INFO:neural_citation.train:INITIALIZING NEURAL CITATION NETWORK WITH AUTHORS = True
Running on: cuda
Number of model parameters: 23,533,236
Encoders: # Filters = 128, Context filter length = [4, 4, 5],  Context filter length = [1, 2]
Embeddings: Dimension = 128, Pad index = 1, Context vocab = 30002, Author vocab = 30002, Title vocab = 30004
Decoder: # GRU cells = 2, Hidden size = 128
Parameters: Dropout = 0.2
-------------------------------------------------
TRAINING SETTINGSSeed = 34, # Epochs = 20, Batch size = 32, Initial lr = 0.001
-------------------------------------------------


HBox(children=(IntProgress(value=0, description='Epochs', max=20, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Training batches', max=10989, style=ProgressStyle(description…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=1570, style=ProgressStyle(descriptio…

INFO:neural_citation.train:Epoch: 01 | Time: 10m 12s
INFO:neural_citation.train:	Train Loss: 1358.829
INFO:neural_citation.train:	 Val. Loss: 1204.091


HBox(children=(IntProgress(value=0, description='Training batches', max=10989, style=ProgressStyle(description…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=1570, style=ProgressStyle(descriptio…

INFO:neural_citation.train:Epoch: 02 | Time: 10m 16s
INFO:neural_citation.train:	Train Loss: 1171.287
INFO:neural_citation.train:	 Val. Loss: 1127.387


HBox(children=(IntProgress(value=0, description='Training batches', max=10989, style=ProgressStyle(description…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=1570, style=ProgressStyle(descriptio…

INFO:neural_citation.train:Epoch: 03 | Time: 10m 16s
INFO:neural_citation.train:	Train Loss: 1105.230
INFO:neural_citation.train:	 Val. Loss: 1088.705


HBox(children=(IntProgress(value=0, description='Training batches', max=10989, style=ProgressStyle(description…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=1570, style=ProgressStyle(descriptio…

INFO:neural_citation.train:Epoch: 04 | Time: 10m 16s
INFO:neural_citation.train:	Train Loss: 1052.335
INFO:neural_citation.train:	 Val. Loss: 1040.227


HBox(children=(IntProgress(value=0, description='Training batches', max=10989, style=ProgressStyle(description…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=1570, style=ProgressStyle(descriptio…

INFO:neural_citation.train:Epoch: 05 | Time: 10m 18s
INFO:neural_citation.train:	Train Loss: 1011.470
INFO:neural_citation.train:	 Val. Loss: 1017.329


HBox(children=(IntProgress(value=0, description='Training batches', max=10989, style=ProgressStyle(description…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=1570, style=ProgressStyle(descriptio…

INFO:neural_citation.train:Epoch: 06 | Time: 10m 17s
INFO:neural_citation.train:	Train Loss: 988.245
INFO:neural_citation.train:	 Val. Loss: 1002.652


HBox(children=(IntProgress(value=0, description='Training batches', max=10989, style=ProgressStyle(description…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=1570, style=ProgressStyle(descriptio…

INFO:neural_citation.train:Epoch: 07 | Time: 10m 17s
INFO:neural_citation.train:	Train Loss: 970.202
INFO:neural_citation.train:	 Val. Loss: 993.547


HBox(children=(IntProgress(value=0, description='Training batches', max=10989, style=ProgressStyle(description…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=1570, style=ProgressStyle(descriptio…

INFO:neural_citation.train:Epoch: 08 | Time: 10m 18s
INFO:neural_citation.train:	Train Loss: 956.020
INFO:neural_citation.train:	 Val. Loss: 984.732


HBox(children=(IntProgress(value=0, description='Training batches', max=10989, style=ProgressStyle(description…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=1570, style=ProgressStyle(descriptio…

INFO:neural_citation.train:Epoch: 09 | Time: 10m 17s
INFO:neural_citation.train:	Train Loss: 943.725
INFO:neural_citation.train:	 Val. Loss: 976.897


HBox(children=(IntProgress(value=0, description='Training batches', max=10989, style=ProgressStyle(description…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=1570, style=ProgressStyle(descriptio…

INFO:neural_citation.train:Epoch: 10 | Time: 10m 17s
INFO:neural_citation.train:	Train Loss: 932.823
INFO:neural_citation.train:	 Val. Loss: 971.426


HBox(children=(IntProgress(value=0, description='Training batches', max=10989, style=ProgressStyle(description…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=1570, style=ProgressStyle(descriptio…

INFO:neural_citation.train:Epoch: 11 | Time: 10m 17s
INFO:neural_citation.train:	Train Loss: 924.169
INFO:neural_citation.train:	 Val. Loss: 966.751


HBox(children=(IntProgress(value=0, description='Training batches', max=10989, style=ProgressStyle(description…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=1570, style=ProgressStyle(descriptio…

INFO:neural_citation.train:Epoch: 12 | Time: 10m 19s
INFO:neural_citation.train:	Train Loss: 916.005
INFO:neural_citation.train:	 Val. Loss: 961.944


HBox(children=(IntProgress(value=0, description='Training batches', max=10989, style=ProgressStyle(description…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=1570, style=ProgressStyle(descriptio…

INFO:neural_citation.train:Epoch: 13 | Time: 10m 19s
INFO:neural_citation.train:	Train Loss: 908.501
INFO:neural_citation.train:	 Val. Loss: 956.897


HBox(children=(IntProgress(value=0, description='Training batches', max=10989, style=ProgressStyle(description…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=1570, style=ProgressStyle(descriptio…

INFO:neural_citation.train:Epoch: 14 | Time: 10m 18s
INFO:neural_citation.train:	Train Loss: 901.657
INFO:neural_citation.train:	 Val. Loss: 953.097


HBox(children=(IntProgress(value=0, description='Training batches', max=10989, style=ProgressStyle(description…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=1570, style=ProgressStyle(descriptio…

INFO:neural_citation.train:Epoch: 15 | Time: 10m 17s
INFO:neural_citation.train:	Train Loss: 895.585
INFO:neural_citation.train:	 Val. Loss: 950.381


HBox(children=(IntProgress(value=0, description='Training batches', max=10989, style=ProgressStyle(description…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=1570, style=ProgressStyle(descriptio…

INFO:neural_citation.train:Epoch: 16 | Time: 10m 19s
INFO:neural_citation.train:	Train Loss: 889.688
INFO:neural_citation.train:	 Val. Loss: 948.221


HBox(children=(IntProgress(value=0, description='Training batches', max=10989, style=ProgressStyle(description…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=1570, style=ProgressStyle(descriptio…

INFO:neural_citation.train:Epoch: 17 | Time: 10m 18s
INFO:neural_citation.train:	Train Loss: 884.770
INFO:neural_citation.train:	 Val. Loss: 945.932


HBox(children=(IntProgress(value=0, description='Training batches', max=10989, style=ProgressStyle(description…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=1570, style=ProgressStyle(descriptio…

INFO:neural_citation.train:Epoch: 18 | Time: 10m 18s
INFO:neural_citation.train:	Train Loss: 880.080
INFO:neural_citation.train:	 Val. Loss: 943.169


HBox(children=(IntProgress(value=0, description='Training batches', max=10989, style=ProgressStyle(description…

HBox(children=(IntProgress(value=0, description='Evaluating batches', max=1570, style=ProgressStyle(descriptio…

INFO:neural_citation.train:Epoch: 19 | Time: 10m 19s
INFO:neural_citation.train:	Train Loss: 875.977
INFO:neural_citation.train:	 Val. Loss: 944.372


HBox(children=(IntProgress(value=0, description='Training batches', max=10989, style=ProgressStyle(description…

KeyboardInterrupt: 