<a href="https://colab.research.google.com/github/pablocosta/deleteRetreaveGenerate/blob/master/experiments.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
!git clone https://github.com/pablocosta/deleteRetreaveGenerate
%cd deleteRetreaveGenerate
!pip install -r requirements.txt



Cloning into 'deleteRetreaveGenerate'...
remote: Enumerating objects: 266, done.[K
remote: Counting objects: 100% (28/28), done.[K
remote: Compressing objects: 100% (25/25), done.[K
remote: Total 266 (delta 5), reused 21 (delta 3), pack-reused 238[K
Receiving objects: 100% (266/266), 220.37 MiB | 26.56 MiB/s, done.
Resolving deltas: 100% (123/123), done.
Checking out files: 100% (123/123), done.
/content/deleteRetreaveGenerate/deleteRetreaveGenerate


In [None]:
!cat ./data/Ustance/pos.neg.txt

cat: ./data/Ustance/pos.neg.txt: No such file or directory


# Training

In [4]:
#run experiment
import sys

import json
import numpy as np
import logging
import argparse
import os
import time
import numpy as np
import glob

import torch
import torch.nn as nn
import torch.optim as optim
from tensorboardX import SummaryWriter

import src.evaluation as evaluation
from src.cuda import CUDA
import src.data as data
import src.models as models


overfit = False
bleu = True
config = json.load(open("./yelp_config.json", 'r'))

workingDir = config['data']['working_dir']

if not os.path.exists(workingDir):
    os.makedirs(workingDir)

config_path = os.path.join(workingDir, 'config.json')
if not os.path.exists(config_path):
    with open(config_path, 'w') as f:
        json.dump(config, f)

# set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    filename='%s/train_log' % workingDir,
)

console = logging.StreamHandler()
console.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
console.setFormatter(formatter)
logging.getLogger('').addHandler(console)

logging.info('Reading data ...')
src, tgt = data.read_nmt_data(
    src=config['data']['src'],
    config=config,
    tgt=config['data']['tgt'],
    attribute_vocab=config['data']['attribute_vocab'],
    ngram_attributes=config['data']['ngram_attributes']
)

srcTest, tgtTest = data.read_nmt_data(
    src=config['data']['src_test'],
    config=config,
    tgt=config['data']['tgt_test'],
    attribute_vocab=config['data']['attribute_vocab'],
    ngram_attributes=config['data']['ngram_attributes'],
    train_src=src,
    train_tgt=tgt
)
logging.info('...done!')

logging.info('...done!')

#model configs

batchSize    = config['data']['batch_size']
maxLength    = config['data']['max_len']
srcVocabSize = len(src['tok2id'])
tgtVocabSize = len(tgt['tok2id'])

weightMask                         = torch.ones(tgtVocabSize)
weightMask[tgt['tok2id']['<pad>']] = 0
lossCriterion                      = nn.CrossEntropyLoss(weight=weightMask)

if CUDA:
    weightMask    = weightMask.cuda()
    lossCriterion = lossCriterion.cuda()

torch.manual_seed(config['training']['random_seed'])
np.random.seed(config['training']['random_seed'])


#model definition

model = models.SeqModel(
    srcVocabSize=srcVocabSize,
    tgtVocabSize=tgtVocabSize,
    padIdSrc=src['tok2id']['<pad>'],
    padIdTgt=tgt['tok2id']['<pad>'],
    batchSize=batchSize,
    config=config
)

logging.info('MODEL HAS %s params' %  model.countParams())
model, startEpoch = models.attemptLoadModel(
    model          = model,
    checkpointDir  = workingDir)

if CUDA:
    model = model.cuda()

writer = SummaryWriter(workingDir)


if config['training']['optimizer'] == 'adam':
    lr        = config['training']['learning_rate']
    optimizer = optim.Adam(model.parameters(), lr=lr)
elif config['training']['optimizer'] == 'sgd':
    lr        = config['training']['learning_rate']
    optimizer = optim.SGD(model.parameters(), lr=lr)
else:
    raise NotImplementedError("Learning method not recommend for task")

epochLoss             = []
startSinceLastReport  = time.time()
wordsSinceLastReport  = 0
lossesSinceLastReport = []
bestMetric            = 0.0
bestEpoch             = 0
curMetric             = 0.0 # log perplexity or BLEU
numExamples           = min(len(src['content']), len(tgt['content']))
numBatches            = numExamples / batchSize



STEP = 0
for epoch in range(startEpoch, config['training']['epochs']):
    if curMetric > bestMetric:
        # rm old checkpoint
        for ckpt_path in glob.glob(workingDir + '/model.*'):
            os.system("rm %s" % ckpt_path)
        # replace with new checkpoint
        torch.save(model.state_dict(), workingDir + '/model.%s.ckpt' % epoch)

        bestMetric = curMetric
        bestEpoch  = epoch - 1

    losses = []
    for i in range(0, numExamples, batchSize):

        if overfit:
            i = 50

        batchIdx = i / batchSize
        
        inputContent, inputAux, outPut = data.minibatch(
            src, tgt, i, batchSize, maxLength, config['model']['model_type']
            )
        
        inputLinesSrc, _, srcLens, srcMask, _ = inputContent
        inputIdsAux, _, auxLens, auxMask, _ = inputAux
        inputLinesTgt, outputLinesTgt, _, _, _ = outPut
        
        decoderLogit, decoderProbs = model(inputLinesSrc, inputLinesTgt, srcMask, srcLens,
            inputIdsAux, auxLens, auxMask)

        optimizer.zero_grad()

        loss = lossCriterion(
            decoderLogit.contiguous().view(-1, tgtVocabSize), outputLinesTgt.view(-1)
        )

        losses.append(loss.item())
        lossesSinceLastReport.append(loss.item())
        epochLoss.append(loss.item())
        loss.backward()
        norm = nn.utils.clip_grad_norm_(model.parameters(), config['training']['max_norm'])

        writer.add_scalar('stats/grad_norm', norm, STEP)

        optimizer.step()

        if overfit or batchIdx % config['training']['batches_per_report'] == 0:

            s = float(time.time() - startSinceLastReport)
            eps = (batchSize * config['training']['batches_per_report']) / s
            avgLoss = np.mean(lossesSinceLastReport)
            info = (epoch, batchIdx, numBatches, eps, avgLoss, curMetric)
            writer.add_scalar('stats/EPS', eps, STEP)
            writer.add_scalar('stats/loss', avgLoss, STEP)
            logging.info('EPOCH: %s ITER: %s/%s EPS: %.2f LOSS: %.4f METRIC: %.4f' % info)
            startSinceLastReport = time.time()
            wordsSinceLastReport = 0
            lossesSinceLastReport = []


        STEP += 1
    if overfit:
        continue

    logging.info('EPOCH %s COMPLETE. EVALUATING...' % epoch)
    

    start = time.time()
    model.eval()
    
    devLoss = evaluation.evaluateLpp(model, srcTest, tgtTest, config)

    writer.add_scalar('eval/loss', devLoss, epoch)

    if bleu and epoch >= config['training'].get('inference_start_epoch', 1):
        curMetric, editDistance, inputs, preds, golds, auxs = evaluation.inferenceMetrics(
            model, srcTest, tgtTest, config)

        with open(workingDir + '/auxs.%s' % epoch, 'w') as f:
            f.write('\n'.join(auxs) + '\n')
        with open(workingDir + '/inputs.%s' % epoch, 'w') as f:
            f.write('\n'.join(inputs) + '\n')
        with open(workingDir + '/preds.%s' % epoch, 'w') as f:
            f.write('\n'.join(preds) + '\n')
        with open(workingDir + '/golds.%s' % epoch, 'w') as f:
            f.write('\n'.join(golds) + '\n')

        writer.add_scalar('eval/edit_distance', editDistance, epoch)
        writer.add_scalar('eval/bleu', curMetric, epoch)

    else:
        cur_metric = devLoss

    model.train()

    logging.info('METRIC: %s. TIME: %.2fs CHECKPOINTING...' % (
        curMetric, (time.time() - start)))
    avgLoss = np.mean(epochLoss)
    epochLoss = []

writer.close()

2021-07-07 17:08:06,515 - INFO - Reading data ...
2021-07-07 17:08:34,545 - INFO - ...done!
2021-07-07 17:08:34,547 - INFO - ...done!
  "num_layers={}".format(dropout, num_layers))
2021-07-07 17:08:36,949 - INFO - MODEL HAS 9181445 params
2021-07-07 17:08:37,081 - INFO - EPOCH: 0 ITER: 0.0/692.2578125 EPS: 458793.44 LOSS: 9.1708 METRIC: 0.0000
2021-07-07 17:08:46,307 - INFO - EPOCH: 0 ITER: 200.0/692.2578125 EPS: 5550.42 LOSS: 5.7927 METRIC: 0.0000
2021-07-07 17:08:55,517 - INFO - EPOCH: 0 ITER: 400.0/692.2578125 EPS: 5560.37 LOSS: 5.0524 METRIC: 0.0000
2021-07-07 17:09:04,689 - INFO - EPOCH: 0 ITER: 600.0/692.2578125 EPS: 5582.99 LOSS: 4.8063 METRIC: 0.0000
2021-07-07 17:09:08,930 - INFO - EPOCH 0 COMPLETE. EVALUATING...


256/500...

2021-07-07 17:09:08,984 - INFO - METRIC: 0.0. TIME: 0.05s CHECKPOINTING...
2021-07-07 17:09:09,035 - INFO - EPOCH: 1 ITER: 0.0/692.2578125 EPS: 11786.72 LOSS: 4.5929 METRIC: 0.0000
2021-07-07 17:09:18,214 - INFO - EPOCH: 1 ITER: 200.0/692.2578125 EPS: 5579.07 LOSS: 4.3767 METRIC: 0.0000
2021-07-07 17:09:28,015 - INFO - EPOCH: 1 ITER: 400.0/692.2578125 EPS: 5224.78 LOSS: 4.1506 METRIC: 0.0000
2021-07-07 17:09:37,191 - INFO - EPOCH: 1 ITER: 600.0/692.2578125 EPS: 5581.11 LOSS: 3.9426 METRIC: 0.0000
2021-07-07 17:09:41,422 - INFO - EPOCH 1 COMPLETE. EVALUATING...


256/500...

2021-07-07 17:09:55,980 - INFO - METRIC: 1.6151974366083746. TIME: 14.56s CHECKPOINTING...
2021-07-07 17:09:56,173 - INFO - EPOCH: 2 ITER: 0.0/692.2578125 EPS: 2697.49 LOSS: 3.7895 METRIC: 1.6152
2021-07-07 17:10:05,432 - INFO - EPOCH: 2 ITER: 200.0/692.2578125 EPS: 5530.68 LOSS: 3.7266 METRIC: 1.6152
2021-07-07 17:10:14,663 - INFO - EPOCH: 2 ITER: 400.0/692.2578125 EPS: 5548.23 LOSS: 3.5867 METRIC: 1.6152
2021-07-07 17:10:23,854 - INFO - EPOCH: 2 ITER: 600.0/692.2578125 EPS: 5571.46 LOSS: 3.4787 METRIC: 1.6152
2021-07-07 17:10:28,095 - INFO - EPOCH 2 COMPLETE. EVALUATING...


256/500...

2021-07-07 17:10:42,508 - INFO - METRIC: 2.5511577080801695. TIME: 14.41s CHECKPOINTING...
2021-07-07 17:10:42,690 - INFO - EPOCH: 3 ITER: 0.0/692.2578125 EPS: 2718.51 LOSS: 3.3667 METRIC: 2.5512
2021-07-07 17:10:51,899 - INFO - EPOCH: 3 ITER: 200.0/692.2578125 EPS: 5560.90 LOSS: 3.3384 METRIC: 2.5512
2021-07-07 17:11:01,089 - INFO - EPOCH: 3 ITER: 400.0/692.2578125 EPS: 5572.48 LOSS: 3.2442 METRIC: 2.5512
2021-07-07 17:11:10,316 - INFO - EPOCH: 3 ITER: 600.0/692.2578125 EPS: 5550.05 LOSS: 3.1596 METRIC: 2.5512
2021-07-07 17:11:14,535 - INFO - EPOCH 3 COMPLETE. EVALUATING...


256/500...

2021-07-07 17:11:28,851 - INFO - METRIC: 3.167290665322437. TIME: 14.31s CHECKPOINTING...
2021-07-07 17:11:29,039 - INFO - EPOCH: 4 ITER: 0.0/692.2578125 EPS: 2734.84 LOSS: 3.0743 METRIC: 3.1673
2021-07-07 17:11:38,282 - INFO - EPOCH: 4 ITER: 200.0/692.2578125 EPS: 5540.75 LOSS: 3.0498 METRIC: 3.1673
2021-07-07 17:11:47,464 - INFO - EPOCH: 4 ITER: 400.0/692.2578125 EPS: 5577.77 LOSS: 2.9370 METRIC: 3.1673
2021-07-07 17:11:56,686 - INFO - EPOCH: 4 ITER: 600.0/692.2578125 EPS: 5552.86 LOSS: 2.9022 METRIC: 3.1673
2021-07-07 17:12:00,936 - INFO - EPOCH 4 COMPLETE. EVALUATING...


256/500...

2021-07-07 17:12:15,200 - INFO - METRIC: 3.8436157660167733. TIME: 14.26s CHECKPOINTING...
2021-07-07 17:12:15,386 - INFO - EPOCH: 5 ITER: 0.0/692.2578125 EPS: 2738.19 LOSS: 2.8199 METRIC: 3.8436
2021-07-07 17:12:24,623 - INFO - EPOCH: 5 ITER: 200.0/692.2578125 EPS: 5544.07 LOSS: 2.8338 METRIC: 3.8436
2021-07-07 17:12:34,500 - INFO - EPOCH: 5 ITER: 400.0/692.2578125 EPS: 5184.80 LOSS: 2.7631 METRIC: 3.8436
2021-07-07 17:12:43,746 - INFO - EPOCH: 5 ITER: 600.0/692.2578125 EPS: 5540.49 LOSS: 2.6907 METRIC: 3.8436
2021-07-07 17:12:47,967 - INFO - EPOCH 5 COMPLETE. EVALUATING...


256/500...

2021-07-07 17:13:02,283 - INFO - METRIC: 4.425431556382094. TIME: 14.31s CHECKPOINTING...
2021-07-07 17:13:02,466 - INFO - EPOCH: 6 ITER: 0.0/692.2578125 EPS: 2735.35 LOSS: 2.6379 METRIC: 4.4254
2021-07-07 17:13:11,725 - INFO - EPOCH: 6 ITER: 200.0/692.2578125 EPS: 5530.86 LOSS: 2.6231 METRIC: 4.4254
2021-07-07 17:13:20,956 - INFO - EPOCH: 6 ITER: 400.0/692.2578125 EPS: 5547.09 LOSS: 2.5437 METRIC: 4.4254
2021-07-07 17:13:30,190 - INFO - EPOCH: 6 ITER: 600.0/692.2578125 EPS: 5546.23 LOSS: 2.5027 METRIC: 4.4254
2021-07-07 17:13:34,433 - INFO - EPOCH 6 COMPLETE. EVALUATING...


256/500...

2021-07-07 17:13:48,760 - INFO - METRIC: 5.06247972907667. TIME: 14.33s CHECKPOINTING...
2021-07-07 17:13:48,944 - INFO - EPOCH: 7 ITER: 0.0/692.2578125 EPS: 2730.64 LOSS: 2.4439 METRIC: 5.0625
2021-07-07 17:13:58,173 - INFO - EPOCH: 7 ITER: 200.0/692.2578125 EPS: 5548.94 LOSS: 2.4385 METRIC: 5.0625
2021-07-07 17:14:07,405 - INFO - EPOCH: 7 ITER: 400.0/692.2578125 EPS: 5547.04 LOSS: 2.3966 METRIC: 5.0625
2021-07-07 17:14:16,642 - INFO - EPOCH: 7 ITER: 600.0/692.2578125 EPS: 5543.93 LOSS: 2.3495 METRIC: 5.0625
2021-07-07 17:14:20,856 - INFO - EPOCH 7 COMPLETE. EVALUATING...


256/500...

2021-07-07 17:14:35,236 - INFO - METRIC: 5.947313043957265. TIME: 14.38s CHECKPOINTING...
2021-07-07 17:14:35,421 - INFO - EPOCH: 8 ITER: 0.0/692.2578125 EPS: 2726.68 LOSS: 2.3029 METRIC: 5.9473
2021-07-07 17:14:44,636 - INFO - EPOCH: 8 ITER: 200.0/692.2578125 EPS: 5557.57 LOSS: 2.2920 METRIC: 5.9473
2021-07-07 17:14:53,865 - INFO - EPOCH: 8 ITER: 400.0/692.2578125 EPS: 5548.68 LOSS: 2.2527 METRIC: 5.9473
2021-07-07 17:15:03,098 - INFO - EPOCH: 8 ITER: 600.0/692.2578125 EPS: 5546.21 LOSS: 2.2180 METRIC: 5.9473
2021-07-07 17:15:07,327 - INFO - EPOCH 8 COMPLETE. EVALUATING...


256/500...

2021-07-07 17:15:21,645 - INFO - METRIC: 6.1886609236088175. TIME: 14.32s CHECKPOINTING...
2021-07-07 17:15:21,834 - INFO - EPOCH: 9 ITER: 0.0/692.2578125 EPS: 2732.97 LOSS: 2.1591 METRIC: 6.1887
2021-07-07 17:15:31,072 - INFO - EPOCH: 9 ITER: 200.0/692.2578125 EPS: 5543.65 LOSS: 2.1668 METRIC: 6.1887
2021-07-07 17:15:40,928 - INFO - EPOCH: 9 ITER: 400.0/692.2578125 EPS: 5195.65 LOSS: 2.1359 METRIC: 6.1887
2021-07-07 17:15:50,144 - INFO - EPOCH: 9 ITER: 600.0/692.2578125 EPS: 5556.19 LOSS: 2.0900 METRIC: 6.1887
2021-07-07 17:15:54,378 - INFO - EPOCH 9 COMPLETE. EVALUATING...


256/500...

2021-07-07 17:16:08,701 - INFO - METRIC: 7.066296518908807. TIME: 14.32s CHECKPOINTING...
2021-07-07 17:16:08,884 - INFO - EPOCH: 10 ITER: 0.0/692.2578125 EPS: 2732.45 LOSS: 2.0283 METRIC: 7.0663
2021-07-07 17:16:18,133 - INFO - EPOCH: 10 ITER: 200.0/692.2578125 EPS: 5537.04 LOSS: 2.0516 METRIC: 7.0663
2021-07-07 17:16:27,371 - INFO - EPOCH: 10 ITER: 400.0/692.2578125 EPS: 5542.97 LOSS: 2.0105 METRIC: 7.0663
2021-07-07 17:16:36,637 - INFO - EPOCH: 10 ITER: 600.0/692.2578125 EPS: 5528.47 LOSS: 1.9864 METRIC: 7.0663
2021-07-07 17:16:40,854 - INFO - EPOCH 10 COMPLETE. EVALUATING...


256/500...

2021-07-07 17:16:55,082 - INFO - METRIC: 8.12474004240489. TIME: 14.23s CHECKPOINTING...
2021-07-07 17:16:55,269 - INFO - EPOCH: 11 ITER: 0.0/692.2578125 EPS: 2748.25 LOSS: 1.9284 METRIC: 8.1247
2021-07-07 17:17:04,497 - INFO - EPOCH: 11 ITER: 200.0/692.2578125 EPS: 5549.30 LOSS: 1.9390 METRIC: 8.1247
2021-07-07 17:17:13,733 - INFO - EPOCH: 11 ITER: 400.0/692.2578125 EPS: 5544.80 LOSS: 1.9190 METRIC: 8.1247
2021-07-07 17:17:22,934 - INFO - EPOCH: 11 ITER: 600.0/692.2578125 EPS: 5565.38 LOSS: 1.8764 METRIC: 8.1247
2021-07-07 17:17:27,190 - INFO - EPOCH 11 COMPLETE. EVALUATING...


256/500...

2021-07-07 17:17:41,573 - INFO - METRIC: 8.593755621565462. TIME: 14.38s CHECKPOINTING...
2021-07-07 17:17:41,770 - INFO - EPOCH: 12 ITER: 0.0/692.2578125 EPS: 2718.43 LOSS: 1.8344 METRIC: 8.5938
2021-07-07 17:17:51,031 - INFO - EPOCH: 12 ITER: 200.0/692.2578125 EPS: 5529.83 LOSS: 1.8506 METRIC: 8.5938
2021-07-07 17:18:00,239 - INFO - EPOCH: 12 ITER: 400.0/692.2578125 EPS: 5561.19 LOSS: 1.8333 METRIC: 8.5938
2021-07-07 17:18:09,477 - INFO - EPOCH: 12 ITER: 600.0/692.2578125 EPS: 5545.29 LOSS: 1.8127 METRIC: 8.5938
2021-07-07 17:18:13,684 - INFO - EPOCH 12 COMPLETE. EVALUATING...


256/500...

2021-07-07 17:18:28,091 - INFO - METRIC: 8.828225793607809. TIME: 14.41s CHECKPOINTING...
2021-07-07 17:18:28,277 - INFO - EPOCH: 13 ITER: 0.0/692.2578125 EPS: 2723.65 LOSS: 1.7494 METRIC: 8.8282
2021-07-07 17:18:37,540 - INFO - EPOCH: 13 ITER: 200.0/692.2578125 EPS: 5528.66 LOSS: 1.7601 METRIC: 8.8282
2021-07-07 17:18:47,402 - INFO - EPOCH: 13 ITER: 400.0/692.2578125 EPS: 5192.34 LOSS: 1.7431 METRIC: 8.8282


KeyboardInterrupt: ignored

In [None]:
%load_ext tensorboard
%tensorboard --logdir working_dir

In [None]:
!rm -r ./working_dir/*