In [1]:
import pathlib
import pickle
import string
import sys
import re

import numpy as np
import pandas as pd

In [2]:
import torch
import torch.nn.functional as F

In [3]:
# Make the directory one level above the working directory importable, so the
# `src` package imported in the next cell can be found.
# NOTE: despite its name, `current_dir` holds the *parent* of the working
# directory (presumably the repository root — confirm against the repo layout).
current_dir = pathlib.Path('.').resolve().parent
# Guard the append so re-running this cell does not stack duplicate entries
# onto sys.path.
if str(current_dir) not in sys.path:
    sys.path.append(str(current_dir))

In [4]:
from src.dataset import Dataset
from src.trainer import Trainer
from src.models.convolution import ConvModel

In [None]:
# Build the training dataset from the artefacts stored under ../output:
# a vocabulary file, a tag file and a gzipped CSV sample.
dataset = Dataset(
    vocabulary='../output/vocabulary.pkl',
    tags='../output/tags.pkl',
    dataset='../output/processed.sample.csv.gz',
)

In [None]:
# DataLoader settings: mini-batches of 128 samples, reshuffled each epoch,
# loaded by 4 worker processes.
params = dict(
    batch_size=128,
    shuffle=True,
    num_workers=4,
)

# Batch iterator over `dataset` that is handed to the trainer later on.
data_gen = torch.utils.data.DataLoader(dataset, **params)

In [None]:
# Network sized from the dataset: vocabulary length and padded sequence length.
# NOTE(review): this reaches into private attributes (`_embedder._vocabulary`);
# prefer a public accessor on Dataset if one exists.
model = ConvModel(
    embedding_dim=32,
    vocab_size=len(dataset._embedder._vocabulary),
    seq_len=dataset.pad,
)

# Adam over all model parameters, learning rate 1e-3, no weight decay.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.001,
    weight_decay=0,
)

# Binary cross-entropy computed directly on raw logits.
loss_fn = torch.nn.BCEWithLogitsLoss()

In [None]:
# Bundle the network, optimiser and loss into the project's Trainer.
trainer = Trainer(
    network=model,
    optimizer=optimizer,
    loss=loss_fn,
)

In [None]:
# Run training for 2 epochs over the batches produced by `data_gen`.
trainer.train(
    input_loader=data_gen,
    n_epochs=2,
)

In [12]:
# NOTE(review): ConvModel is already imported in an earlier cell, and notebook
# imports normally belong in the top import cell; kept here so this cell still
# runs on its own.
from src.models.convolution import ConvModel
from src.tagger import Tagger

# Fresh network with fixed hyper-parameters (presumably those used when the
# checkpoint was trained — confirm). This rebinds `model`, replacing the
# instance created in the training cell above.
model = ConvModel(embedding_dim=32, vocab_size=10000, seq_len=250)

# Restore the saved weights onto the CPU.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
with open('../output/model.pkl', 'rb') as weights_file:
    state_dict = torch.load(weights_file, map_location='cpu')
model.load_state_dict(state_dict)

# Switch to evaluation mode; eval() returns the module, so the cell displays
# the architecture.
model.eval()

ConvModel(
  (embeddings): ScaledEmbedding(10000, 32)
  (conv1): Conv1d(32, 64, kernel_size=(3,), stride=(1,), padding=(1,))
  (conv2): Conv1d(64, 32, kernel_size=(5,), stride=(1,), padding=(2,))
  (mp1): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (mp2): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=1984, out_features=1024, bias=True)
  (fc2): Linear(in_features=1024, out_features=100, bias=True)
)

In [13]:
# Inference helper, pointed at the saved vocabulary and tag files.
tagger = Tagger(
    '../output/vocabulary.pkl',
    '../output/tags.pkl',
)

In [14]:
# Attach the trained ConvModel checkpoint to the tagger; the keyword arguments
# are the same hyper-parameters used to construct ConvModel elsewhere in this
# notebook.
tagger.register_trained_model(
    ConvModel,
    '../output/model.pkl',
    embedding_dim=32,
    vocab_size=10000,
    seq_len=250,
)

In [15]:
# Score a sample sentence; the cell displays the raw per-tag scores
# (presumably logits, given the BCEWithLogitsLoss used for training).
tagger.predict(
    string='javascript is super fun I like it',
)

array([[-19.064157 ,  -3.9674811,  -3.1815257,  -4.571185 ,  -3.10343  ,
         -3.4125044,  -1.6382846,  -2.028486 ,  -6.2389565,  -6.513911 ,
         -3.0123672,  -5.2532845,  -3.8068266,  -5.5408993,  -7.1865926,
         -4.8517094,  -4.8548346,  -6.3721204,  -6.8400273,  -4.220871 ,
         -3.871672 ,  -4.082539 ,  -7.3245077,  -6.656459 ,  -5.599983 ,
         -4.4831676,  -5.07969  ,  -3.8734856,  -8.289972 ,  -4.4749093,
         -5.6665516,  -5.2264657,  -3.9710698,  -5.947859 ,  -6.594405 ,
         -6.4984484,  -8.301108 ,  -5.675887 ,  -6.159503 ,  -5.7758727,
         -6.690189 ,  -5.0172095,  -7.0990715,  -6.1434946,  -7.5694423,
         -5.397152 ,  -4.036919 ,  -6.4244895,  -7.4717255,  -6.8372717,
         -6.245531 ,  -7.3477283,  -5.2969294,  -4.936797 ,  -8.215247 ,
         -6.2334595,  -5.43516  ,  -7.2845693,  -6.920402 ,  -7.850501 ,
         -6.3050585,  -6.9365153,  -9.779724 ,  -7.484503 ,  -8.781606 ,
         -5.6009164,  -5.1778197,  -7.3600187,  -6.

In [16]:
# Capture the prediction explicitly instead of reading IPython's `_`
# (last-output) variable. `_` is hidden kernel state: it silently points at
# something else as soon as any other cell with an output runs in between.
res = tagger.predict(string='javascript is super fun I like it')

In [17]:
# Flatten the (1, n_tags) prediction into a 1-D vector. Using -1 lets NumPy
# infer the length instead of hard-coding the current tag count (100).
res.reshape(-1)

array([-19.064157 ,  -3.9674811,  -3.1815257,  -4.571185 ,  -3.10343  ,
        -3.4125044,  -1.6382846,  -2.028486 ,  -6.2389565,  -6.513911 ,
        -3.0123672,  -5.2532845,  -3.8068266,  -5.5408993,  -7.1865926,
        -4.8517094,  -4.8548346,  -6.3721204,  -6.8400273,  -4.220871 ,
        -3.871672 ,  -4.082539 ,  -7.3245077,  -6.656459 ,  -5.599983 ,
        -4.4831676,  -5.07969  ,  -3.8734856,  -8.289972 ,  -4.4749093,
        -5.6665516,  -5.2264657,  -3.9710698,  -5.947859 ,  -6.594405 ,
        -6.4984484,  -8.301108 ,  -5.675887 ,  -6.159503 ,  -5.7758727,
        -6.690189 ,  -5.0172095,  -7.0990715,  -6.1434946,  -7.5694423,
        -5.397152 ,  -4.036919 ,  -6.4244895,  -7.4717255,  -6.8372717,
        -6.245531 ,  -7.3477283,  -5.2969294,  -4.936797 ,  -8.215247 ,
        -6.2334595,  -5.43516  ,  -7.2845693,  -6.920402 ,  -7.850501 ,
        -6.3050585,  -6.9365153,  -9.779724 ,  -7.484503 ,  -8.781606 ,
        -5.6009164,  -5.1778197,  -7.3600187,  -6.2094865,  -7.9

In [18]:
# Translate the flattened score vector into a tag -> score mapping for the
# highest-scoring tags (as displayed below). reshape(-1) avoids hard-coding
# the current tag count (100).
tagger.decrypt_top_tags(target=res.reshape(-1))

{'java': -3.9674811,
 'json': -3.871672,
 'ajax': -3.8734856,
 'asp.net': -3.8068266,
 'php': -3.10343,
 'python': -3.4125044,
 'jquery': -1.6382846,
 'html': -2.028486,
 'c#': -3.1815257,
 'css': -3.0123672}