In [85]:
# import torch
# from bilstm_crf import BiLSTMCRF
from torch import optim
from utils import *
from tqdm import tqdm
from sklearn.preprocessing import MultiLabelBinarizer


# import evaluation metrics 
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
from sklearn.model_selection import train_test_split

torch.manual_seed(1)

<torch._C.Generator at 0x1f9e81ada50>

In [86]:
import torch
import torch.nn as nn
from TorchCRF import CRF

class BiLSTMCRF(nn.Module):
    def __init__(self, vocab_size, word_embedding_dim, hidden_dim, output_dim):
        
        super(BiLSTMCRF, self).__init__()
        # hyperparameters
        self.word_embedding_dim = word_embedding_dim
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.output_dim = output_dim

        # model layers
        self.word_embedding = nn.Embedding(vocab_size, word_embedding_dim)
        self.lstm = nn.LSTM(word_embedding_dim, hidden_dim // 2, bidirectional=True, dropout=0.2)
        self.hidden_to_tag = nn.Linear(hidden_dim, output_dim)
        self.crf = CRF(output_dim)

    def __init_hidden(self):
        return (torch.randn(2, 1, self.hidden_dim // 2), torch.randn(2, 1, self.hidden_dim // 2)) # initialize hidden state

    def __get_lstm_features(self, sentence):
        self.hidden = self.__init_hidden()
        word_embeddings = self.word_embedding(sentence).view(len(sentence), 1, -1)
        lstm_out, self.hidden = self.lstm(word_embeddings, self.hidden)
        lstm_features = self.hidden_to_tag(lstm_out)
        return lstm_features

    def neg_log_likelihood(self, sentence, tags):
        features = self.__get_lstm_features(sentence) # emissions from the lstm layer
        tags = tags.view(-1, 1)
        loss = -self.crf(features, tags)
        return loss

    def forward(self, sentence):
        lstm_features = self.__get_lstm_features(sentence) # emissions from the lstm layer
        tag_sequence = self.crf.decode(lstm_features)
        return tag_sequence
    

In [87]:
def read_dataset():
    data = pd.read_csv('../ner_dataset/ner_dataset.csv', encoding='latin1')

    # remove white spaces from column names
    data.columns = data.columns.str.strip()

    print(data.columns)
    # Group by 'Sentence #' and aggregate
    grouped_data = data.groupby('Sentence #').agg({
        'Word': lambda x: ''.join(x),  # Join words into a single sentence
        'Tag': lambda x: list(x.str.strip()),       # Collect tags into a list
        'Intent': lambda x: list(x.str.strip().str.replace('_', ' '))     # Collect intents into a list
    }).reset_index()  # Reset index to make 'Sentence #' a regular column

    return data, grouped_data


def prepare_data(dataframe):
    dataset = []
    for _, row in dataframe.iterrows():
        sentence = row['Word'][1:]
        tags = row['Tag']
        intents = row['Intent']
        dataset.append((sentence, tags, intents[0]))

    return dataset

In [88]:
# Format data
data, goruped_data = read_dataset()
print(data["Intent"].unique())

goruped_data.head()

Index(['Sentence #', 'Word', 'Tag', 'Intent'], dtype='object')
[' variable declaration' ' function declaration' ' class declaration'
 ' assignment operation' ' array operation' ' bitwise operation'
 ' mathematical operation' ' membership operation' ' casting'
 ' io operation' ' assertion' ' libraries' ' file system' ' ide operation'
 ' comment' ' conditional operation' ' iterative operation']


Unnamed: 0,Sentence #,Word,Tag,Intent
0,0,make start time as double and initialize 0.00...,"[O, B-VAR, I-VAR, O, B-TYPE, O, O, B-VAL]","[variable declaration, variable declaration, v..."
1,1,declare min value as integer and value 131313,"[O, B-VAR, I-VAR, O, B-TYPE, O, O, B-VAL]","[variable declaration, variable declaration, v..."
2,2,define settings as boolean and value false,"[O, B-VAR, O, B-TYPE, O, O, B-VAL]","[variable declaration, variable declaration, v..."
3,3,define y as integer and assign to 12345,"[O, B-VAR, O, B-TYPE, O, O, O, B-VAL]","[variable declaration, variable declaration, v..."
4,4,initialize k as string and initialize it with...,"[O, B-VAR, O, B-TYPE, O, O, O, O, B-VAL, I-VAL]","[variable declaration, variable declaration, v..."


In [89]:
training_data = prepare_data(goruped_data)

sentences = [sentence for sentence, _, _ in training_data]
tags = [tag for _, tag, _ in training_data]
intents = set([intent for _, _, intent in training_data])

intents

{'array operation',
 'assertion',
 'assignment operation',
 'bitwise operation',
 'casting',
 'class declaration',
 'comment',
 'conditional operation',
 'file system',
 'function declaration',
 'ide operation',
 'io operation',
 'iterative operation',
 'libraries',
 'mathematical operation',
 'membership operation',
 'variable declaration'}

In [90]:
separate_data = {} # a dictionary to store the data for each intent

# separate training data by intent
for intent in intents:
    separate_data[intent] = [(sentence, tag) for sentence, tag, _intent in training_data if _intent == intent]

for intent in separate_data:
    print(f'Intent: {intent}')
    print(f'Number of sentences: {len(separate_data[intent])}')
    print(f'Example: {separate_data[intent][0]}')
    print('\n')

Intent: conditional operation
Number of sentences: 160
Example: ('whether j less than -2020202', ['O', 'B-LHS', 'B-CONDITION', 'I-CONDITION', 'B-RHS'])


Intent: mathematical operation
Number of sentences: 370
Example: ('divide order queue and j', ['B-OPERATOR', 'B-OPERAND', 'I-OPERAND', 'O', 'B-OPERAND'])


Intent: casting
Number of sentences: 78
Example: ('make budget allocation to int', ['O', 'B-VAR', 'I-VAR', 'O', 'B-CAST_TYPE'])


Intent: comment
Number of sentences: 50
Example: ('comment a This is a date comment', ['I-COMMENT', 'I-COMMENT', 'B-COMMENT', 'I-COMMENT', 'O', 'I-COMMENT', 'O'])


Intent: function declaration
Number of sentences: 112
Example: ('write a procedure named rayleigh quotient iteration algorithm', ['O', 'O', 'O', 'O', 'B-FUNC', 'I-FUNC', 'I-FUNC', 'I-FUNC'])


Intent: ide operation
Number of sentences: 230
Example: ('launch terminal', ['B-ACTION', 'B-TYPE'])


Intent: membership operation
Number of sentences: 114
Example: ('check whether item array is in set 

In [109]:
separate_word_to_index = {}
separate_tag_to_index = {}
separate_index_to_tag = {}

for intent in separate_data:
    words = [word for sentence, _ in separate_data[intent] for word in sentence.split()]
    tags = [tag for _, tag in separate_data[intent] for tag in tag]

    word_to_index = {word: i + 1 for i, word in enumerate(set(words))}
    tag_to_index = {tag: i + 1 for i, tag in enumerate(set(tags))}

    word_to_index['<UNK>'] = 0
    tag_to_index['<UNK>'] = 0

    separate_word_to_index[intent] = word_to_index
    separate_tag_to_index[intent] = tag_to_index

separate_index_to_tag = {intent: {i: tag for tag, i in separate_tag_to_index[intent].items()} for intent in separate_tag_to_index}

for intent in separate_data:
    print(f'Intent: {intent}')
    print(f'Number of unique words: {len(separate_word_to_index[intent])}')
    print(f'Number of unique tags: {len(separate_tag_to_index[intent])}')
    print('\n')

Intent: conditional operation
Number of unique words: 283
Number of unique tags: 9


Intent: mathematical operation
Number of unique words: 285
Number of unique tags: 7


Intent: casting
Number of unique words: 112
Number of unique tags: 5


Intent: comment
Number of unique words: 47
Number of unique tags: 4


Intent: function declaration
Number of unique words: 173
Number of unique tags: 7


Intent: ide operation
Number of unique words: 115
Number of unique tags: 7


Intent: membership operation
Number of unique words: 249
Number of unique tags: 8


Intent: class declaration
Number of unique words: 58
Number of unique tags: 4


Intent: assignment operation
Number of unique words: 178
Number of unique tags: 6


Intent: assertion
Number of unique words: 117
Number of unique tags: 8


Intent: variable declaration
Number of unique words: 298
Number of unique tags: 7


Intent: iterative operation
Number of unique words: 162
Number of unique tags: 12


Intent: libraries
Number of unique wor

In [92]:
with open('../ner_dataset/intent_to_tags.json', 'r') as f:
    intent_to_tags = json.load(f)

final_intent_to_tags = {}

for intent in intent_to_tags.keys():
    final_intent_to_tags[intent] = []
    for tag in intent_to_tags[intent]:
        if tag == 'O':
            final_intent_to_tags[intent].append(tag)
        else:
            final_intent_to_tags[intent].append('B-' + tag)
            final_intent_to_tags[intent].append('I-' + tag)

        
final_intent_to_tags

{'variable declaration': ['O',
  'B-VAR',
  'I-VAR',
  'B-VAL',
  'I-VAL',
  'B-TYPE',
  'I-TYPE'],
 'function declaration': ['O',
  'B-FUNC',
  'I-FUNC',
  'B-PARAM',
  'I-PARAM',
  'B-RET_TYPE',
  'I-RET_TYPE'],
 'class declaration': ['O', 'B-CLASS_NAME', 'I-CLASS_NAME'],
 'assignment operation': ['O', 'B-LHS', 'I-LHS', 'B-RHS', 'I-RHS'],
 'conditional operation': ['O',
  'B-LHS',
  'I-LHS',
  'B-RHS',
  'I-RHS',
  'B-CONDITION',
  'I-CONDITION',
  'B-LOG',
  'I-LOG'],
 'iterative operation': ['O',
  'B-LOOP',
  'I-LOOP',
  'B-START',
  'I-START',
  'B-END',
  'I-END',
  'B-LHS',
  'I-LHS',
  'B-RHS',
  'I-RHS',
  'B-CONDITION',
  'I-CONDITION',
  'B-STEP',
  'I-STEP'],
 'array operation': ['O',
  'B-ARRAY',
  'I-ARRAY',
  'B-OPERATION',
  'I-OPERATION',
  'B-ELEMENT',
  'I-ELEMENT'],
 'bitwise operation': ['O',
  'B-OPERATOR',
  'I-OPERATOR',
  'B-OPERAND',
  'I-OPERAND'],
 'mathematical operation': ['O',
  'B-OPERAND',
  'I-OPERAND',
  'B-OPERATOR',
  'I-OPERATOR',
  'B-VAR',
  'I-

In [93]:
# training data for each intent
# training_data = {intent: [x_train, y_train, x_test, y_test]}
training_data = {} 
for intent in separate_data:
    sentences = [sentence for sentence, _ in separate_data[intent]]
    tags = [tag for _, tag in separate_data[intent]]
    x_train, x_test, y_train, y_test = train_test_split(sentences, tags, test_size=0.1, random_state=42)
    training_data[intent] = [x_train, y_train, x_test, y_test]



In [96]:
# models = {
#    intent: [model, optimizer]}
models = {}

for intent in final_intent_to_tags.keys():
    models[intent] = []
    models[intent].append(BiLSTMCRF(vocab_size=len(separate_word_to_index[intent]), word_embedding_dim=100, hidden_dim=128, output_dim=len(separate_tag_to_index[intent])))
    models[intent].append(optim.SGD(models[intent][0].parameters(), lr=0.01, weight_decay=1e-4))

In [98]:
# Check predictions before training
# tarining data {
#     intent: [x_train, y_train, x_test, y_test]
# }
with torch.no_grad():
    for intent in intents:
        print(f'Intent: {intent}')
        precheck_sent = prepare_sequence(training_data[intent][0][0].split(), separate_word_to_index[intent])
        precheck_tag = prepare_sequence(training_data[intent][1][0], separate_tag_to_index[intent])
        print("Actual tags:", precheck_tag)
        print("Predicted tags:", models[intent][0](precheck_sent))


Intent: conditional operation
Actual tags: tensor([3, 5, 2, 3, 6, 7, 3, 8, 1, 4, 3, 3, 3, 6, 3, 3, 3])
Predicted tags: [[0, 8, 7, 0, 4, 7, 7, 7, 0, 4, 8, 7, 7, 7, 7, 7, 7]]
Intent: mathematical operation
Actual tags: tensor([2, 4, 3, 4])
Predicted tags: [[2, 2, 2, 6]]
Intent: casting
Actual tags: tensor([4, 4, 4, 3, 1, 4, 4, 2])
Predicted tags: [[3, 0, 2, 0, 3, 3, 0, 0]]
Intent: comment
Actual tags: tensor([2, 1, 3, 3])
Predicted tags: [[1, 2, 2, 2]]
Intent: function declaration
Actual tags: tensor([3, 3, 3, 3, 3, 3, 4, 3, 3, 3, 1, 5, 3, 3, 3])
Predicted tags: [[4, 4, 0, 3, 0, 2, 3, 6, 4, 2, 6, 6, 1, 6, 1]]
Intent: ide operation
Actual tags: tensor([1])
Predicted tags: [[2]]
Intent: membership operation
Actual tags: tensor([3, 4, 3, 7, 3, 3, 5])
Predicted tags: [[7, 4, 0, 4, 2, 4, 2]]
Intent: class declaration
Actual tags: tensor([2, 3, 2, 2, 2])
Predicted tags: [[2, 1, 2, 1, 2]]
Intent: assignment operation
Actual tags: tensor([3, 1, 2, 4])
Predicted tags: [[5, 2, 3, 5]]
Intent: asser

In [99]:
# training loop
def train(model, optimizer, x_train, y_train,epochs=10):
    for epoch in tqdm(range(epochs)):
        total_loss = 0
        for sentence, tags in zip(x_train, y_train):
            # Step 1. Remember that Pytorch accumulates gradients.
            # We need to clear them out before each instance
            model.zero_grad()
        
            sentence = prepare_sequence(sentence.split(), separate_word_to_index[intent])
            target_tags = prepare_sequence(tags, separate_tag_to_index[intent])

            # Step 3. Run our forward pass.
            loss = model.neg_log_likelihood(sentence, target_tags)

            total_loss += loss.item()

            # Step 4. Compute the loss, gradients, and update the parameters by
            loss.backward()
            optimizer.step()
        
        print(f"Epoch: {epoch}, Loss: {total_loss / len(x_train)}")

In [100]:
for intent in intents:
    x_train, y_train = training_data[intent][0], training_data[intent][1]
    train(models[intent][0], models[intent][1], x_train, y_train, epochs=20)

  5%|▌         | 1/20 [00:03<01:09,  3.63s/it]

Epoch: 0, Loss: 16.38363363676601


 10%|█         | 2/20 [00:05<00:50,  2.82s/it]

Epoch: 1, Loss: 8.70021105143759


 15%|█▌        | 3/20 [00:07<00:41,  2.46s/it]

Epoch: 2, Loss: 5.118986778789097


 20%|██        | 4/20 [00:10<00:37,  2.37s/it]

Epoch: 3, Loss: 3.5051518281300864


 25%|██▌       | 5/20 [00:12<00:34,  2.28s/it]

Epoch: 4, Loss: 2.6357622279061212


 30%|███       | 6/20 [00:14<00:31,  2.26s/it]

Epoch: 5, Loss: 2.048637310663859


 35%|███▌      | 7/20 [00:16<00:29,  2.23s/it]

Epoch: 6, Loss: 1.609494646390279


 40%|████      | 8/20 [00:19<00:27,  2.33s/it]

Epoch: 7, Loss: 1.333845575650533


 45%|████▌     | 9/20 [00:21<00:25,  2.30s/it]

Epoch: 8, Loss: 1.1059735616048176


 50%|█████     | 10/20 [00:23<00:21,  2.20s/it]

Epoch: 9, Loss: 0.9380903111563789


 55%|█████▌    | 11/20 [00:24<00:18,  2.01s/it]

Epoch: 10, Loss: 0.7744817601309882


 60%|██████    | 12/20 [00:26<00:15,  1.91s/it]

Epoch: 11, Loss: 0.6795238918728299


 65%|██████▌   | 13/20 [00:28<00:12,  1.83s/it]

Epoch: 12, Loss: 0.5962640841801962


 70%|███████   | 14/20 [00:29<00:10,  1.77s/it]

Epoch: 13, Loss: 0.5084752506679959


 75%|███████▌  | 15/20 [00:31<00:08,  1.74s/it]

Epoch: 14, Loss: 0.45874888367123073


 80%|████████  | 16/20 [00:33<00:06,  1.71s/it]

Epoch: 15, Loss: 0.42343129052056205


 85%|████████▌ | 17/20 [00:34<00:05,  1.68s/it]

Epoch: 16, Loss: 0.4637728797064887


 90%|█████████ | 18/20 [00:37<00:03,  1.83s/it]

Epoch: 17, Loss: 0.42807432015736896


 95%|█████████▌| 19/20 [00:38<00:01,  1.78s/it]

Epoch: 18, Loss: 0.31075159708658856


100%|██████████| 20/20 [00:40<00:00,  2.02s/it]


Epoch: 19, Loss: 0.32725338141123456


  5%|▌         | 1/20 [00:03<00:59,  3.11s/it]

Epoch: 0, Loss: 7.821479060628393


 10%|█         | 2/20 [00:06<00:56,  3.12s/it]

Epoch: 1, Loss: 2.8539040797465556


 15%|█▌        | 3/20 [00:09<00:51,  3.05s/it]

Epoch: 2, Loss: 1.616919620617016


 20%|██        | 4/20 [00:13<00:55,  3.46s/it]

Epoch: 3, Loss: 1.1639325912292298


 25%|██▌       | 5/20 [00:16<00:51,  3.46s/it]

Epoch: 4, Loss: 0.9114704518704801


 30%|███       | 6/20 [00:20<00:48,  3.50s/it]

Epoch: 5, Loss: 0.7853591964767502


 35%|███▌      | 7/20 [00:24<00:48,  3.73s/it]

Epoch: 6, Loss: 0.6789160820098968


 40%|████      | 8/20 [00:29<00:48,  4.07s/it]

Epoch: 7, Loss: 0.5685593616497051


 45%|████▌     | 9/20 [00:33<00:46,  4.25s/it]

Epoch: 8, Loss: 0.4831279419563912


 50%|█████     | 10/20 [00:37<00:41,  4.13s/it]

Epoch: 9, Loss: 0.43413879062320376


 55%|█████▌    | 11/20 [00:41<00:36,  4.02s/it]

Epoch: 10, Loss: 0.4239639791998419


 60%|██████    | 12/20 [00:45<00:32,  4.06s/it]

Epoch: 11, Loss: 0.3514131483014997


 65%|██████▌   | 13/20 [00:49<00:28,  4.09s/it]

Epoch: 12, Loss: 0.3057255959725595


 70%|███████   | 14/20 [00:53<00:23,  3.99s/it]

Epoch: 13, Loss: 0.323412313833609


 75%|███████▌  | 15/20 [00:57<00:19,  3.97s/it]

Epoch: 14, Loss: 0.27526089665410036


 80%|████████  | 16/20 [01:01<00:15,  3.91s/it]

Epoch: 15, Loss: 0.2423672919516807


 85%|████████▌ | 17/20 [01:05<00:12,  4.06s/it]

Epoch: 16, Loss: 0.23860607204494533


 90%|█████████ | 18/20 [01:09<00:07,  3.86s/it]

Epoch: 17, Loss: 0.20821612590068095


 95%|█████████▌| 19/20 [01:12<00:03,  3.76s/it]

Epoch: 18, Loss: 0.19942618132353546


100%|██████████| 20/20 [01:16<00:00,  3.85s/it]


Epoch: 19, Loss: 0.18166590739298868


  5%|▌         | 1/20 [00:01<00:23,  1.25s/it]

Epoch: 0, Loss: 6.590379340308053


 10%|█         | 2/20 [00:02<00:18,  1.00s/it]

Epoch: 1, Loss: 3.529719727379935


 15%|█▌        | 3/20 [00:02<00:15,  1.10it/s]

Epoch: 2, Loss: 1.9500179971967424


 20%|██        | 4/20 [00:04<00:15,  1.00it/s]

Epoch: 3, Loss: 1.1385491371154786


 25%|██▌       | 5/20 [00:04<00:14,  1.02it/s]

Epoch: 4, Loss: 0.7511830738612584


 30%|███       | 6/20 [00:06<00:14,  1.03s/it]

Epoch: 5, Loss: 0.5413041114807129


 35%|███▌      | 7/20 [00:06<00:12,  1.07it/s]

Epoch: 6, Loss: 0.39870786666870117


 40%|████      | 8/20 [00:07<00:10,  1.14it/s]

Epoch: 7, Loss: 0.32622153418404715


 45%|████▌     | 9/20 [00:08<00:09,  1.15it/s]

Epoch: 8, Loss: 0.2726176806858608


 50%|█████     | 10/20 [00:09<00:08,  1.20it/s]

Epoch: 9, Loss: 0.23230018615722656


 55%|█████▌    | 11/20 [00:09<00:06,  1.29it/s]

Epoch: 10, Loss: 0.20424070358276367


 60%|██████    | 12/20 [00:10<00:05,  1.36it/s]

Epoch: 11, Loss: 0.1783672196524484


 65%|██████▌   | 13/20 [00:11<00:05,  1.40it/s]

Epoch: 12, Loss: 0.1553335326058524


 70%|███████   | 14/20 [00:11<00:04,  1.45it/s]

Epoch: 13, Loss: 0.14760082789829798


 75%|███████▌  | 15/20 [00:12<00:03,  1.45it/s]

Epoch: 14, Loss: 0.13318319320678712


 80%|████████  | 16/20 [00:13<00:02,  1.37it/s]

Epoch: 15, Loss: 0.11857948303222657


 85%|████████▌ | 17/20 [00:13<00:02,  1.41it/s]

Epoch: 16, Loss: 0.10796639578683036


 90%|█████████ | 18/20 [00:14<00:01,  1.48it/s]

Epoch: 17, Loss: 0.102712767464774


 95%|█████████▌| 19/20 [00:15<00:00,  1.48it/s]

Epoch: 18, Loss: 0.09640609196254185


100%|██████████| 20/20 [00:15<00:00,  1.27it/s]


Epoch: 19, Loss: 0.08676580701555525


  5%|▌         | 1/20 [00:00<00:07,  2.42it/s]

Epoch: 0, Loss: 7.379187827640109


 10%|█         | 2/20 [00:00<00:07,  2.35it/s]

Epoch: 1, Loss: 4.836205318239


 15%|█▌        | 3/20 [00:01<00:08,  2.11it/s]

Epoch: 2, Loss: 3.6711947864956325


 20%|██        | 4/20 [00:01<00:07,  2.13it/s]

Epoch: 3, Loss: 3.031232664320204


 25%|██▌       | 5/20 [00:02<00:06,  2.19it/s]

Epoch: 4, Loss: 2.672249370151096


 30%|███       | 6/20 [00:02<00:06,  2.24it/s]

Epoch: 5, Loss: 2.336929840511746


 35%|███▌      | 7/20 [00:03<00:06,  1.99it/s]

Epoch: 6, Loss: 2.060345575544569


 40%|████      | 8/20 [00:03<00:06,  1.97it/s]

Epoch: 7, Loss: 1.8826168802049426


 45%|████▌     | 9/20 [00:04<00:05,  2.04it/s]

Epoch: 8, Loss: 1.7889499452379014


 50%|█████     | 10/20 [00:04<00:04,  2.07it/s]

Epoch: 9, Loss: 1.5911222457885743


 55%|█████▌    | 11/20 [00:05<00:04,  2.13it/s]

Epoch: 10, Loss: 1.4682529025607638


 60%|██████    | 12/20 [00:05<00:03,  2.13it/s]

Epoch: 11, Loss: 1.3571880552503797


 65%|██████▌   | 13/20 [00:06<00:03,  2.10it/s]

Epoch: 12, Loss: 1.2600887086656358


 70%|███████   | 14/20 [00:06<00:02,  2.15it/s]

Epoch: 13, Loss: 1.1411502414279513


 75%|███████▌  | 15/20 [00:07<00:02,  2.07it/s]

Epoch: 14, Loss: 1.0535507837931315


 80%|████████  | 16/20 [00:07<00:02,  1.84it/s]

Epoch: 15, Loss: 0.8903545803493924


 85%|████████▌ | 17/20 [00:08<00:01,  1.78it/s]

Epoch: 16, Loss: 0.8742980109320746


 90%|█████████ | 18/20 [00:08<00:01,  1.81it/s]

Epoch: 17, Loss: 0.7578209771050347


 95%|█████████▌| 19/20 [00:09<00:00,  1.80it/s]

Epoch: 18, Loss: 0.7107759475708008


100%|██████████| 20/20 [00:10<00:00,  1.99it/s]


Epoch: 19, Loss: 0.6618079503377279


  5%|▌         | 1/20 [00:01<00:34,  1.82s/it]

Epoch: 0, Loss: 12.01746741771698


 10%|█         | 2/20 [00:03<00:33,  1.85s/it]

Epoch: 1, Loss: 5.720541152954102


 15%|█▌        | 3/20 [00:05<00:27,  1.63s/it]

Epoch: 2, Loss: 3.1285230255126955


 20%|██        | 4/20 [00:06<00:24,  1.53s/it]

Epoch: 3, Loss: 1.9941353607177734


 25%|██▌       | 5/20 [00:07<00:22,  1.52s/it]

Epoch: 4, Loss: 1.5183394813537598


 30%|███       | 6/20 [00:09<00:20,  1.49s/it]

Epoch: 5, Loss: 1.2282610893249513


 35%|███▌      | 7/20 [00:10<00:18,  1.45s/it]

Epoch: 6, Loss: 1.0009594917297364


 40%|████      | 8/20 [00:12<00:17,  1.45s/it]

Epoch: 7, Loss: 0.8633214950561523


 45%|████▌     | 9/20 [00:13<00:16,  1.46s/it]

Epoch: 8, Loss: 0.7666165924072266


 50%|█████     | 10/20 [00:15<00:15,  1.50s/it]

Epoch: 9, Loss: 0.6893827438354492


 55%|█████▌    | 11/20 [00:16<00:13,  1.49s/it]

Epoch: 10, Loss: 0.6162561798095703


 60%|██████    | 12/20 [00:18<00:11,  1.49s/it]

Epoch: 11, Loss: 0.5991575622558594


 65%|██████▌   | 13/20 [00:19<00:10,  1.49s/it]

Epoch: 12, Loss: 0.5376598930358887


 70%|███████   | 14/20 [00:21<00:08,  1.48s/it]

Epoch: 13, Loss: 0.48841270446777346


 75%|███████▌  | 15/20 [00:22<00:07,  1.47s/it]

Epoch: 14, Loss: 0.4545210647583008


 80%|████████  | 16/20 [00:24<00:06,  1.50s/it]

Epoch: 15, Loss: 0.42499658584594724


 85%|████████▌ | 17/20 [00:25<00:04,  1.49s/it]

Epoch: 16, Loss: 0.40585979461669924


 90%|█████████ | 18/20 [00:27<00:02,  1.48s/it]

Epoch: 17, Loss: 0.37060779571533203


 95%|█████████▌| 19/20 [00:28<00:01,  1.51s/it]

Epoch: 18, Loss: 0.339699764251709


100%|██████████| 20/20 [00:30<00:00,  1.52s/it]


Epoch: 19, Loss: 0.3360927200317383


  5%|▌         | 1/20 [00:01<00:30,  1.62s/it]

Epoch: 0, Loss: 4.068798196488532


 10%|█         | 2/20 [00:03<00:28,  1.59s/it]

Epoch: 1, Loss: 2.023626169144819


 15%|█▌        | 3/20 [00:05<00:28,  1.70s/it]

Epoch: 2, Loss: 1.3664182464857608


 20%|██        | 4/20 [00:06<00:26,  1.68s/it]

Epoch: 3, Loss: 0.9705313864537484


 25%|██▌       | 5/20 [00:08<00:24,  1.66s/it]

Epoch: 4, Loss: 0.7043221330873056


 30%|███       | 6/20 [00:10<00:24,  1.72s/it]

Epoch: 5, Loss: 0.5576138358185256


 35%|███▌      | 7/20 [00:11<00:21,  1.66s/it]

Epoch: 6, Loss: 0.4528328497052769


 40%|████      | 8/20 [00:13<00:20,  1.71s/it]

Epoch: 7, Loss: 0.3411784886162062


 45%|████▌     | 9/20 [00:15<00:19,  1.76s/it]

Epoch: 8, Loss: 0.30561040104299353


 50%|█████     | 10/20 [00:16<00:16,  1.68s/it]

Epoch: 9, Loss: 0.2652533134976447


 55%|█████▌    | 11/20 [00:18<00:14,  1.62s/it]

Epoch: 10, Loss: 0.2405216394415224


 60%|██████    | 12/20 [00:19<00:12,  1.58s/it]

Epoch: 11, Loss: 0.2113763111225073


 65%|██████▌   | 13/20 [00:21<00:11,  1.58s/it]

Epoch: 12, Loss: 0.19612665567996998


 70%|███████   | 14/20 [00:22<00:09,  1.55s/it]

Epoch: 13, Loss: 0.17688302832525132


 75%|███████▌  | 15/20 [00:24<00:07,  1.57s/it]

Epoch: 14, Loss: 0.16830824769061545


 80%|████████  | 16/20 [00:26<00:06,  1.68s/it]

Epoch: 15, Loss: 0.15284763902857684


 85%|████████▌ | 17/20 [00:27<00:04,  1.58s/it]

Epoch: 16, Loss: 0.1459309294603873


 90%|█████████ | 18/20 [00:29<00:03,  1.55s/it]

Epoch: 17, Loss: 0.14890676074557835


 95%|█████████▌| 19/20 [00:30<00:01,  1.52s/it]

Epoch: 18, Loss: 0.12853605159814807


100%|██████████| 20/20 [00:32<00:00,  1.64s/it]


Epoch: 19, Loss: 0.1237243454237491


  5%|▌         | 1/20 [00:01<00:23,  1.25s/it]

Epoch: 0, Loss: 10.372847893658806


 10%|█         | 2/20 [00:02<00:21,  1.17s/it]

Epoch: 1, Loss: 4.497578742457371


 15%|█▌        | 3/20 [00:03<00:20,  1.19s/it]

Epoch: 2, Loss: 2.5604704314587163


 20%|██        | 4/20 [00:04<00:18,  1.14s/it]

Epoch: 3, Loss: 1.7333823933320887


 25%|██▌       | 5/20 [00:05<00:16,  1.10s/it]

Epoch: 4, Loss: 1.3528949700149835


 30%|███       | 6/20 [00:06<00:15,  1.07s/it]

Epoch: 5, Loss: 1.08629295872707


 35%|███▌      | 7/20 [00:07<00:14,  1.10s/it]

Epoch: 6, Loss: 0.9470780503516104


 40%|████      | 8/20 [00:09<00:13,  1.13s/it]

Epoch: 7, Loss: 0.8138606501560585


 45%|████▌     | 9/20 [00:10<00:12,  1.10s/it]

Epoch: 8, Loss: 0.7496346866383272


 50%|█████     | 10/20 [00:11<00:11,  1.13s/it]

Epoch: 9, Loss: 0.6495488484700521


 55%|█████▌    | 11/20 [00:12<00:10,  1.12s/it]

Epoch: 10, Loss: 0.5987204196406346


 60%|██████    | 12/20 [00:13<00:09,  1.19s/it]

Epoch: 11, Loss: 0.5322781731100643


 65%|██████▌   | 13/20 [00:15<00:08,  1.24s/it]

Epoch: 12, Loss: 0.5099851009892482


 70%|███████   | 14/20 [00:16<00:07,  1.31s/it]

Epoch: 13, Loss: 0.452398374968884


 75%|███████▌  | 15/20 [00:17<00:06,  1.28s/it]

Epoch: 14, Loss: 0.4263424031874713


 80%|████████  | 16/20 [00:18<00:04,  1.24s/it]

Epoch: 15, Loss: 0.41484907561657475


 85%|████████▌ | 17/20 [00:20<00:03,  1.22s/it]

Epoch: 16, Loss: 0.3904273463230507


 90%|█████████ | 18/20 [00:21<00:02,  1.26s/it]

Epoch: 17, Loss: 0.36516761779785156


 95%|█████████▌| 19/20 [00:22<00:01,  1.23s/it]

Epoch: 18, Loss: 0.339630126953125


100%|██████████| 20/20 [00:23<00:00,  1.19s/it]


Epoch: 19, Loss: 0.3159651662789139


  5%|▌         | 1/20 [00:00<00:10,  1.89it/s]

Epoch: 0, Loss: 4.812975154203527


 10%|█         | 2/20 [00:00<00:07,  2.28it/s]

Epoch: 1, Loss: 2.803708188674029


 15%|█▌        | 3/20 [00:01<00:07,  2.34it/s]

Epoch: 2, Loss: 2.209087231579949


 20%|██        | 4/20 [00:01<00:08,  1.95it/s]

Epoch: 3, Loss: 1.8128666877746582


 25%|██▌       | 5/20 [00:02<00:07,  2.08it/s]

Epoch: 4, Loss: 1.592224625980153


 30%|███       | 6/20 [00:02<00:06,  2.29it/s]

Epoch: 5, Loss: 1.3122379078584558


 35%|███▌      | 7/20 [00:03<00:06,  1.91it/s]

Epoch: 6, Loss: 1.1152818904203528


 40%|████      | 8/20 [00:03<00:05,  2.11it/s]

Epoch: 7, Loss: 0.9862801888409782


 45%|████▌     | 9/20 [00:04<00:04,  2.23it/s]

Epoch: 8, Loss: 0.8322737076703239


 50%|█████     | 10/20 [00:04<00:04,  2.19it/s]

Epoch: 9, Loss: 0.7214311150943532


 55%|█████▌    | 11/20 [00:05<00:03,  2.26it/s]

Epoch: 10, Loss: 0.6528690282036277


 60%|██████    | 12/20 [00:05<00:03,  2.43it/s]

Epoch: 11, Loss: 0.5359686402713552


 65%|██████▌   | 13/20 [00:05<00:02,  2.48it/s]

Epoch: 12, Loss: 0.5196100683773265


 70%|███████   | 14/20 [00:06<00:02,  2.60it/s]

Epoch: 13, Loss: 0.45110337874468637


 75%|███████▌  | 15/20 [00:06<00:01,  2.64it/s]

Epoch: 14, Loss: 0.39300811991972084


 80%|████████  | 16/20 [00:06<00:01,  2.62it/s]

Epoch: 15, Loss: 0.34115062040441174


 85%|████████▌ | 17/20 [00:07<00:01,  2.39it/s]

Epoch: 16, Loss: 0.3300890642053941


 90%|█████████ | 18/20 [00:07<00:00,  2.38it/s]

Epoch: 17, Loss: 0.2816321990069221


 95%|█████████▌| 19/20 [00:08<00:00,  2.48it/s]

Epoch: 18, Loss: 0.24671358220717488


100%|██████████| 20/20 [00:08<00:00,  2.34it/s]


Epoch: 19, Loss: 0.24730783350327434


  5%|▌         | 1/20 [00:00<00:13,  1.41it/s]

Epoch: 0, Loss: 8.436452375517952


 10%|█         | 2/20 [00:01<00:14,  1.27it/s]

Epoch: 1, Loss: 5.835526009400685


 15%|█▌        | 3/20 [00:02<00:12,  1.32it/s]

Epoch: 2, Loss: 4.353655987315708


 20%|██        | 4/20 [00:02<00:11,  1.39it/s]

Epoch: 3, Loss: 3.3288498057259455


 25%|██▌       | 5/20 [00:03<00:10,  1.43it/s]

Epoch: 4, Loss: 2.6764301591449313


 30%|███       | 6/20 [00:04<00:09,  1.44it/s]

Epoch: 5, Loss: 2.1311446958118014


 35%|███▌      | 7/20 [00:04<00:08,  1.49it/s]

Epoch: 6, Loss: 1.7856677373250325


 40%|████      | 8/20 [00:05<00:07,  1.50it/s]

Epoch: 7, Loss: 1.4811155663596258


 45%|████▌     | 9/20 [00:06<00:07,  1.51it/s]

Epoch: 8, Loss: 1.3100445138083563


 50%|█████     | 10/20 [00:06<00:06,  1.54it/s]

Epoch: 9, Loss: 1.0882348616917927


 55%|█████▌    | 11/20 [00:07<00:06,  1.50it/s]

Epoch: 10, Loss: 0.8917707867092557


 60%|██████    | 12/20 [00:08<00:05,  1.48it/s]

Epoch: 11, Loss: 0.7904523213704427


 65%|██████▌   | 13/20 [00:08<00:04,  1.44it/s]

Epoch: 12, Loss: 0.7149979141023424


 70%|███████   | 14/20 [00:09<00:04,  1.34it/s]

Epoch: 13, Loss: 0.6356095605426364


 75%|███████▌  | 15/20 [00:10<00:03,  1.32it/s]

Epoch: 14, Loss: 0.5892967780431112


 80%|████████  | 16/20 [00:11<00:02,  1.37it/s]

Epoch: 15, Loss: 0.5696613126330905


 85%|████████▌ | 17/20 [00:11<00:02,  1.42it/s]

Epoch: 16, Loss: 0.46452906396653915


 90%|█████████ | 18/20 [00:12<00:01,  1.35it/s]

Epoch: 17, Loss: 0.43990659713745117


 95%|█████████▌| 19/20 [00:13<00:00,  1.38it/s]

Epoch: 18, Loss: 0.43774375650617814


100%|██████████| 20/20 [00:14<00:00,  1.42it/s]


Epoch: 19, Loss: 0.3814397123124864


  5%|▌         | 1/20 [00:00<00:08,  2.17it/s]

Epoch: 0, Loss: 14.08773346380754


 10%|█         | 2/20 [00:00<00:08,  2.06it/s]

Epoch: 1, Loss: 10.31571114063263


 15%|█▌        | 3/20 [00:01<00:08,  2.06it/s]

Epoch: 2, Loss: 8.186459389599888


 20%|██        | 4/20 [00:02<00:08,  1.95it/s]

Epoch: 3, Loss: 6.747642853043296


 25%|██▌       | 5/20 [00:02<00:07,  1.91it/s]

Epoch: 4, Loss: 5.633931701833552


 30%|███       | 6/20 [00:03<00:07,  1.90it/s]

Epoch: 5, Loss: 4.808382120999423


 35%|███▌      | 7/20 [00:03<00:06,  1.91it/s]

Epoch: 6, Loss: 4.0193868550387295


 40%|████      | 8/20 [00:04<00:06,  1.78it/s]

Epoch: 7, Loss: 3.4804000854492188


 45%|████▌     | 9/20 [00:04<00:05,  1.87it/s]

Epoch: 8, Loss: 3.0541657317768443


 50%|█████     | 10/20 [00:05<00:05,  1.96it/s]

Epoch: 9, Loss: 2.708812215111472


 55%|█████▌    | 11/20 [00:05<00:05,  1.78it/s]

Epoch: 10, Loss: 2.460697065700184


 60%|██████    | 12/20 [00:06<00:04,  1.79it/s]

Epoch: 11, Loss: 2.1397795460440894


 65%|██████▌   | 13/20 [00:06<00:03,  1.75it/s]

Epoch: 12, Loss: 1.9117380055514248


 70%|███████   | 14/20 [00:07<00:03,  1.76it/s]

Epoch: 13, Loss: 1.664451924237338


 75%|███████▌  | 15/20 [00:08<00:03,  1.64it/s]

Epoch: 14, Loss: 1.5230879133397883


 80%|████████  | 16/20 [00:08<00:02,  1.65it/s]

Epoch: 15, Loss: 1.384453621777621


 85%|████████▌ | 17/20 [00:09<00:01,  1.67it/s]

Epoch: 16, Loss: 1.3012884746898303


 90%|█████████ | 18/20 [00:10<00:01,  1.41it/s]

Epoch: 17, Loss: 1.191570368680087


 95%|█████████▌| 19/20 [00:11<00:00,  1.46it/s]

Epoch: 18, Loss: 1.0541667071255771


100%|██████████| 20/20 [00:11<00:00,  1.69it/s]


Epoch: 19, Loss: 0.9777818159623579


  5%|▌         | 1/20 [00:02<00:56,  2.95s/it]

Epoch: 0, Loss: 7.9695827100012036


 10%|█         | 2/20 [00:05<00:54,  3.01s/it]

Epoch: 1, Loss: 3.473751626632832


 15%|█▌        | 3/20 [00:08<00:49,  2.92s/it]

Epoch: 2, Loss: 1.9217214054531522


 20%|██        | 4/20 [00:11<00:47,  2.99s/it]

Epoch: 3, Loss: 1.2242681582768757


 25%|██▌       | 5/20 [00:14<00:43,  2.91s/it]

Epoch: 4, Loss: 0.8877141034161603


 30%|███       | 6/20 [00:17<00:41,  2.99s/it]

Epoch: 5, Loss: 0.6825412158612851


 35%|███▌      | 7/20 [00:20<00:37,  2.92s/it]

Epoch: 6, Loss: 0.5878373252020942


 40%|████      | 8/20 [00:23<00:34,  2.87s/it]

Epoch: 7, Loss: 0.48570706226207594


 45%|████▌     | 9/20 [00:25<00:30,  2.80s/it]

Epoch: 8, Loss: 0.41740225862573693


 50%|█████     | 10/20 [00:28<00:28,  2.83s/it]

Epoch: 9, Loss: 0.35650592380099827


 55%|█████▌    | 11/20 [00:31<00:25,  2.80s/it]

Epoch: 10, Loss: 0.3109044145654749


 60%|██████    | 12/20 [00:34<00:22,  2.76s/it]

Epoch: 11, Loss: 0.3071440502449318


 65%|██████▌   | 13/20 [00:36<00:19,  2.74s/it]

Epoch: 12, Loss: 0.2657142656820792


 70%|███████   | 14/20 [00:39<00:16,  2.69s/it]

Epoch: 13, Loss: 0.24679337607489693


 75%|███████▌  | 15/20 [00:42<00:13,  2.71s/it]

Epoch: 14, Loss: 0.2062425790009675


 80%|████████  | 16/20 [00:45<00:11,  2.90s/it]

Epoch: 15, Loss: 0.19849730420995643


 85%|████████▌ | 17/20 [00:49<00:09,  3.20s/it]

Epoch: 16, Loss: 0.17543862484119557


 90%|█████████ | 18/20 [00:53<00:06,  3.44s/it]

Epoch: 17, Loss: 0.17010088320131656


 95%|█████████▌| 19/20 [00:56<00:03,  3.39s/it]

Epoch: 18, Loss: 0.15708921573780202


100%|██████████| 20/20 [01:01<00:00,  3.07s/it]


Epoch: 19, Loss: 0.15529790631047002


  5%|▌         | 1/20 [00:01<00:37,  1.98s/it]

Epoch: 0, Loss: 13.892914705806309


 10%|█         | 2/20 [00:04<00:36,  2.03s/it]

Epoch: 1, Loss: 8.222750809457567


 15%|█▌        | 3/20 [00:06<00:36,  2.13s/it]

Epoch: 2, Loss: 5.722471722850093


 20%|██        | 4/20 [00:08<00:33,  2.10s/it]

Epoch: 3, Loss: 4.192185631504765


 25%|██▌       | 5/20 [00:10<00:32,  2.16s/it]

Epoch: 4, Loss: 3.117456065283881


 30%|███       | 6/20 [00:12<00:29,  2.12s/it]

Epoch: 5, Loss: 2.407395009641294


 35%|███▌      | 7/20 [00:14<00:27,  2.14s/it]

Epoch: 6, Loss: 1.9378366470336914


 40%|████      | 8/20 [00:17<00:26,  2.17s/it]

Epoch: 7, Loss: 1.6271971596611872


 45%|████▌     | 9/20 [00:19<00:24,  2.20s/it]

Epoch: 8, Loss: 1.4102669115419741


 50%|█████     | 10/20 [00:21<00:22,  2.23s/it]

Epoch: 9, Loss: 1.1893378010502569


 55%|█████▌    | 11/20 [00:24<00:21,  2.34s/it]

Epoch: 10, Loss: 1.0258190543563277


 60%|██████    | 12/20 [00:26<00:18,  2.34s/it]

Epoch: 11, Loss: 0.9545555644565158


 65%|██████▌   | 13/20 [00:28<00:15,  2.22s/it]

Epoch: 12, Loss: 0.8767504162258573


 70%|███████   | 14/20 [00:30<00:12,  2.13s/it]

Epoch: 13, Loss: 0.7630902572914406


 75%|███████▌  | 15/20 [00:32<00:10,  2.07s/it]

Epoch: 14, Loss: 0.7001514258208098


 80%|████████  | 16/20 [00:34<00:08,  2.08s/it]

Epoch: 15, Loss: 0.6500841246710883


 85%|████████▌ | 17/20 [00:36<00:06,  2.03s/it]

Epoch: 16, Loss: 0.5789300600687662


 90%|█████████ | 18/20 [00:38<00:04,  2.01s/it]

Epoch: 17, Loss: 0.5292427804734972


 95%|█████████▌| 19/20 [00:40<00:02,  2.02s/it]

Epoch: 18, Loss: 0.4968057738410102


100%|██████████| 20/20 [00:42<00:00,  2.13s/it]


Epoch: 19, Loss: 0.4622416143064146


  5%|▌         | 1/20 [00:00<00:09,  2.07it/s]

Epoch: 0, Loss: 2.8190551651848685


 10%|█         | 2/20 [00:01<00:12,  1.48it/s]

Epoch: 1, Loss: 1.805382858382331


 15%|█▌        | 3/20 [00:01<00:10,  1.67it/s]

Epoch: 2, Loss: 1.3741064972347683


 20%|██        | 4/20 [00:02<00:09,  1.67it/s]

Epoch: 3, Loss: 1.0254766040378147


 25%|██▌       | 5/20 [00:02<00:08,  1.70it/s]

Epoch: 4, Loss: 0.752214707268609


 30%|███       | 6/20 [00:03<00:09,  1.55it/s]

Epoch: 5, Loss: 0.5832211706373427


 35%|███▌      | 7/20 [00:04<00:09,  1.34it/s]

Epoch: 6, Loss: 0.45064992904663087


 40%|████      | 8/20 [00:05<00:08,  1.49it/s]

Epoch: 7, Loss: 0.3710272683037652


 45%|████▌     | 9/20 [00:05<00:06,  1.60it/s]

Epoch: 8, Loss: 0.2892147594028049


 50%|█████     | 10/20 [00:06<00:06,  1.58it/s]

Epoch: 9, Loss: 0.24174439642164441


 55%|█████▌    | 11/20 [00:06<00:05,  1.70it/s]

Epoch: 10, Loss: 0.21066721810234917


 60%|██████    | 12/20 [00:07<00:04,  1.86it/s]

Epoch: 11, Loss: 0.17164486249287922


 65%|██████▌   | 13/20 [00:07<00:03,  1.91it/s]

Epoch: 12, Loss: 0.154703860812717


 70%|███████   | 14/20 [00:08<00:03,  1.93it/s]

Epoch: 13, Loss: 0.1370867517259386


 75%|███████▌  | 15/20 [00:08<00:02,  2.11it/s]

Epoch: 14, Loss: 0.11792159610324436


 80%|████████  | 16/20 [00:09<00:02,  1.96it/s]

Epoch: 15, Loss: 0.10726593865288628


 85%|████████▌ | 17/20 [00:09<00:01,  2.17it/s]

Epoch: 16, Loss: 0.09568341573079427


 90%|█████████ | 18/20 [00:09<00:00,  2.34it/s]

Epoch: 17, Loss: 0.08643312454223633


 95%|█████████▌| 19/20 [00:10<00:00,  2.49it/s]

Epoch: 18, Loss: 0.0889907095167372


100%|██████████| 20/20 [00:10<00:00,  1.88it/s]


Epoch: 19, Loss: 0.07582491768731012


  5%|▌         | 1/20 [00:01<00:22,  1.18s/it]

Epoch: 0, Loss: 3.0050328616742736


 10%|█         | 2/20 [00:02<00:22,  1.28s/it]

Epoch: 1, Loss: 1.2106543629257767


 15%|█▌        | 3/20 [00:03<00:21,  1.29s/it]

Epoch: 2, Loss: 0.7581196343457257


 20%|██        | 4/20 [00:04<00:19,  1.23s/it]

Epoch: 3, Loss: 0.5399786242732295


 25%|██▌       | 5/20 [00:05<00:16,  1.10s/it]

Epoch: 4, Loss: 0.41145861855259647


 30%|███       | 6/20 [00:06<00:14,  1.07s/it]

Epoch: 5, Loss: 0.3338964532922815


 35%|███▌      | 7/20 [00:08<00:15,  1.20s/it]

Epoch: 6, Loss: 0.3008021001462583


 40%|████      | 8/20 [00:09<00:13,  1.14s/it]

Epoch: 7, Loss: 0.23617470176131636


 45%|████▌     | 9/20 [00:10<00:11,  1.04s/it]

Epoch: 8, Loss: 0.20295118049339012


 50%|█████     | 10/20 [00:10<00:09,  1.05it/s]

Epoch: 9, Loss: 0.17726558402732567


 55%|█████▌    | 11/20 [00:11<00:08,  1.06it/s]

Epoch: 10, Loss: 0.16188382925810638


 60%|██████    | 12/20 [00:13<00:08,  1.05s/it]

Epoch: 11, Loss: 0.1381419164163095


 65%|██████▌   | 13/20 [00:14<00:07,  1.10s/it]

Epoch: 12, Loss: 0.1484994305504693


 70%|███████   | 14/20 [00:15<00:06,  1.07s/it]

Epoch: 13, Loss: 0.10749455911141854


 75%|███████▌  | 15/20 [00:16<00:05,  1.10s/it]

Epoch: 14, Loss: 0.11595676916616934


 80%|████████  | 16/20 [00:17<00:04,  1.21s/it]

Epoch: 15, Loss: 0.09984918523717809


 85%|████████▌ | 17/20 [00:19<00:04,  1.39s/it]

Epoch: 16, Loss: 0.10004318555196126


 90%|█████████ | 18/20 [00:20<00:02,  1.29s/it]

Epoch: 17, Loss: 0.08519969869543005


 95%|█████████▌| 19/20 [00:21<00:01,  1.19s/it]

Epoch: 18, Loss: 0.0861011028289795


100%|██████████| 20/20 [00:22<00:00,  1.15s/it]


Epoch: 19, Loss: 0.07811024630511249


  5%|▌         | 1/20 [00:01<00:26,  1.41s/it]

Epoch: 0, Loss: 7.586920875549317


 10%|█         | 2/20 [00:02<00:22,  1.23s/it]

Epoch: 1, Loss: 4.530473655700684


 15%|█▌        | 3/20 [00:03<00:19,  1.14s/it]

Epoch: 2, Loss: 3.0764700088500976


 20%|██        | 4/20 [00:04<00:17,  1.07s/it]

Epoch: 3, Loss: 2.274331069946289


 25%|██▌       | 5/20 [00:05<00:16,  1.12s/it]

Epoch: 4, Loss: 1.7743843612670898


 30%|███       | 6/20 [00:06<00:14,  1.07s/it]

Epoch: 5, Loss: 1.4343551712036133


 35%|███▌      | 7/20 [00:07<00:13,  1.02s/it]

Epoch: 6, Loss: 1.2036347885131835


 40%|████      | 8/20 [00:08<00:11,  1.00it/s]

Epoch: 7, Loss: 1.0236400222778321


 45%|████▌     | 9/20 [00:09<00:10,  1.04it/s]

Epoch: 8, Loss: 0.9142070922851563


 50%|█████     | 10/20 [00:10<00:10,  1.05s/it]

Epoch: 9, Loss: 0.798502197265625


 55%|█████▌    | 11/20 [00:11<00:10,  1.11s/it]

Epoch: 10, Loss: 0.6968284912109375


 60%|██████    | 12/20 [00:13<00:09,  1.16s/it]

Epoch: 11, Loss: 0.6236907958984375


 65%|██████▌   | 13/20 [00:14<00:07,  1.11s/it]

Epoch: 12, Loss: 0.5825645751953125


 70%|███████   | 14/20 [00:15<00:06,  1.13s/it]

Epoch: 13, Loss: 0.5058963775634766


 75%|███████▌  | 15/20 [00:16<00:05,  1.12s/it]

Epoch: 14, Loss: 0.4411800842285156


 80%|████████  | 16/20 [00:17<00:04,  1.23s/it]

Epoch: 15, Loss: 0.4274100189208984


 85%|████████▌ | 17/20 [00:19<00:03,  1.21s/it]

Epoch: 16, Loss: 0.383775146484375


 90%|█████████ | 18/20 [00:20<00:02,  1.19s/it]

Epoch: 17, Loss: 0.3711180114746094


 95%|█████████▌| 19/20 [00:21<00:01,  1.11s/it]

Epoch: 18, Loss: 0.3236155242919922


100%|██████████| 20/20 [00:22<00:00,  1.10s/it]


Epoch: 19, Loss: 0.318846435546875


  5%|▌         | 1/20 [00:00<00:11,  1.66it/s]

Epoch: 0, Loss: 9.630860931571872


 10%|█         | 2/20 [00:01<00:16,  1.09it/s]

Epoch: 1, Loss: 5.451766545745148


 15%|█▌        | 3/20 [00:02<00:15,  1.08it/s]

Epoch: 2, Loss: 3.413822294651777


 20%|██        | 4/20 [00:03<00:13,  1.18it/s]

Epoch: 3, Loss: 2.4526499496109184


 25%|██▌       | 5/20 [00:04<00:11,  1.28it/s]

Epoch: 4, Loss: 1.7829839333720592


 30%|███       | 6/20 [00:04<00:10,  1.29it/s]

Epoch: 5, Loss: 1.3397858827963642


 35%|███▌      | 7/20 [00:05<00:09,  1.35it/s]

Epoch: 6, Loss: 1.0912554839561726


 40%|████      | 8/20 [00:06<00:08,  1.37it/s]

Epoch: 7, Loss: 0.8818527857462565


 45%|████▌     | 9/20 [00:07<00:08,  1.32it/s]

Epoch: 8, Loss: 0.6862025096498686


 50%|█████     | 10/20 [00:07<00:07,  1.36it/s]

Epoch: 9, Loss: 0.5522515636751022


 55%|█████▌    | 11/20 [00:08<00:06,  1.42it/s]

Epoch: 10, Loss: 0.4979499400347129


 60%|██████    | 12/20 [00:09<00:05,  1.41it/s]

Epoch: 11, Loss: 0.44764948439324037


 65%|██████▌   | 13/20 [00:09<00:04,  1.44it/s]

Epoch: 12, Loss: 0.36625432420050963


 70%|███████   | 14/20 [00:10<00:04,  1.43it/s]

Epoch: 13, Loss: 0.28570955649189567


 75%|███████▌  | 15/20 [00:11<00:03,  1.31it/s]

Epoch: 14, Loss: 0.23705855183217717


 80%|████████  | 16/20 [00:11<00:02,  1.39it/s]

Epoch: 15, Loss: 0.24214757721999597


 85%|████████▌ | 17/20 [00:12<00:02,  1.46it/s]

Epoch: 16, Loss: 0.23252338102494163


 90%|█████████ | 18/20 [00:13<00:01,  1.56it/s]

Epoch: 17, Loss: 0.1807409703046426


 95%|█████████▌| 19/20 [00:13<00:00,  1.61it/s]

Epoch: 18, Loss: 0.1644628897480581


100%|██████████| 20/20 [00:14<00:00,  1.40it/s]


Epoch: 19, Loss: 0.16496436897365527


  5%|▌         | 1/20 [00:00<00:09,  1.94it/s]

Epoch: 0, Loss: 7.218933524756596


 10%|█         | 2/20 [00:01<00:10,  1.74it/s]

Epoch: 1, Loss: 3.8499192177564248


 15%|█▌        | 3/20 [00:01<00:10,  1.69it/s]

Epoch: 2, Loss: 2.4028056407796927


 20%|██        | 4/20 [00:02<00:08,  1.85it/s]

Epoch: 3, Loss: 1.6717845982518689


 25%|██▌       | 5/20 [00:02<00:07,  1.92it/s]

Epoch: 4, Loss: 1.2869682421629456


 30%|███       | 6/20 [00:03<00:07,  1.99it/s]

Epoch: 5, Loss: 1.04040169441837


 35%|███▌      | 7/20 [00:03<00:06,  1.99it/s]

Epoch: 6, Loss: 0.8723142503321856


 40%|████      | 8/20 [00:04<00:05,  2.01it/s]

Epoch: 7, Loss: 0.8053792372517202


 45%|████▌     | 9/20 [00:04<00:05,  1.83it/s]

Epoch: 8, Loss: 0.6964112807964457


 50%|█████     | 10/20 [00:05<00:05,  1.89it/s]

Epoch: 9, Loss: 0.6315587690506859


 55%|█████▌    | 11/20 [00:05<00:04,  1.94it/s]

Epoch: 10, Loss: 0.5824301708703754


 60%|██████    | 12/20 [00:06<00:04,  1.94it/s]

Epoch: 11, Loss: 0.5382239462315351


 65%|██████▌   | 13/20 [00:06<00:03,  1.93it/s]

Epoch: 12, Loss: 0.4857608641701183


 70%|███████   | 14/20 [00:07<00:03,  1.90it/s]

Epoch: 13, Loss: 0.49670763673453494


 75%|███████▌  | 15/20 [00:07<00:02,  1.87it/s]

Epoch: 14, Loss: 0.4117313253468481


 80%|████████  | 16/20 [00:08<00:02,  1.84it/s]

Epoch: 15, Loss: 0.4269136944036374


 85%|████████▌ | 17/20 [00:09<00:01,  1.86it/s]

Epoch: 16, Loss: 0.4094566038285179


 90%|█████████ | 18/20 [00:09<00:01,  1.86it/s]

Epoch: 17, Loss: 0.34391419640902815


 95%|█████████▌| 19/20 [00:10<00:00,  1.72it/s]

Epoch: 18, Loss: 0.3542645772298177


100%|██████████| 20/20 [00:10<00:00,  1.84it/s]

Epoch: 19, Loss: 0.3301591982786683





In [101]:
# Check predictions before training
# tarining data {
#     intent: [x_train, y_train, x_test, y_test]
# }
with torch.no_grad():
    for intent in intents:
        print(f'Intent: {intent}')
        precheck_sent = prepare_sequence(training_data[intent][0][0].split(), separate_word_to_index[intent])
        precheck_tag = prepare_sequence(training_data[intent][1][0], separate_tag_to_index[intent])
        print("Actual tags:", precheck_tag)
        print("Predicted tags:", models[intent][0](precheck_sent))


Intent: conditional operation
Actual tags: tensor([3, 5, 2, 3, 6, 7, 3, 8, 1, 4, 3, 3, 3, 6, 3, 3, 3])
Predicted tags: [[3, 5, 2, 3, 6, 7, 3, 8, 1, 4, 3, 3, 3, 6, 3, 3, 3]]
Intent: mathematical operation
Actual tags: tensor([2, 4, 3, 4])
Predicted tags: [[2, 4, 3, 4]]
Intent: casting
Actual tags: tensor([4, 4, 4, 3, 1, 4, 4, 2])
Predicted tags: [[4, 4, 4, 3, 1, 4, 4, 2]]
Intent: comment
Actual tags: tensor([2, 1, 3, 3])
Predicted tags: [[2, 1, 3, 3]]
Intent: function declaration
Actual tags: tensor([3, 3, 3, 3, 3, 3, 4, 3, 3, 3, 1, 5, 3, 3, 3])
Predicted tags: [[3, 3, 3, 3, 3, 3, 4, 3, 3, 3, 1, 2, 3, 3, 3]]
Intent: ide operation
Actual tags: tensor([1])
Predicted tags: [[1]]
Intent: membership operation
Actual tags: tensor([3, 4, 3, 7, 3, 3, 5])
Predicted tags: [[3, 4, 3, 7, 3, 3, 5]]
Intent: class declaration
Actual tags: tensor([2, 3, 2, 2, 2])
Predicted tags: [[2, 3, 2, 2, 2]]
Intent: assignment operation
Actual tags: tensor([3, 1, 2, 4])
Predicted tags: [[3, 1, 2, 4]]
Intent: asser

In [103]:
def predict(model, test_data):
    model.eval()
    predictions = [] # a dictionary to store the predictions for each intent
    with torch.no_grad():
        for sentence in test_data:
            precheck_sent = prepare_sequence(sentence.split(), separate_word_to_index[intent])
            predictions.append(model(precheck_sent)[0])
    return predictions

In [104]:
# evaluation 
def evaluate(y_true, y_pred):
    mlb = MultiLabelBinarizer()
    y_true_binarized = mlb.fit_transform(y_true)
    y_pred_binarized = mlb.transform(y_pred)
    
    f1 = f1_score(y_true_binarized, y_pred_binarized, average='weighted')
    precision = precision_score(y_true_binarized, y_pred_binarized, average='weighted')
    recall = recall_score(y_true_binarized, y_pred_binarized, average='weighted')
    accuracy = accuracy_score(y_true_binarized, y_pred_binarized)
    
    return f1, precision, recall, accuracy

In [111]:
for intent in intents:
    print(f'Intent: {intent}')
    x_test, y_test = training_data[intent][2], training_data[intent][3]
    predictions = predict(models[intent][0], x_test)

    mapped_predictions = [[separate_index_to_tag[intent][tag] for tag in tags] for tags in predictions]

    f1, precision, recall, accuracy = evaluate(y_test, mapped_predictions)
    print("F1 Score:", f1)
    print("Precision:", precision)
    print("Recall:", recall)
    print("Accuracy:", accuracy)
    print("----------------------------------------------------------------------")

Intent: conditional operation
F1 Score: 0.9656853924365445
Precision: 0.9826839826839827
Recall: 0.9523809523809523
Accuracy: 0.6875
----------------------------------------------------------------------
Intent: mathematical operation
F1 Score: 0.9841082929170324
Precision: 0.9876169510181618
Recall: 0.9810126582278481
Accuracy: 0.918918918918919
----------------------------------------------------------------------
Intent: casting
F1 Score: 0.9808429118773946
Precision: 1.0
Recall: 0.9655172413793104
Accuracy: 0.875
----------------------------------------------------------------------
Intent: comment
F1 Score: 1.0
Precision: 1.0
Recall: 1.0
Accuracy: 1.0
----------------------------------------------------------------------
Intent: function declaration
F1 Score: 0.9891304347826088
Precision: 1.0
Recall: 0.9791666666666666
Accuracy: 0.9166666666666666
----------------------------------------------------------------------
Intent: ide operation
F1 Score: 0.9913194444444444
Precision: 1.



In [113]:
# Check predictions before training
test_example_1 = ("create a variable called ay khara and set it to yasmine ghanem", ["O", "O", "O", "O", "B-VAR", "I-VAR", "O", "O", "O", "B-VAL", "I-VAL"], "variable declaration")
test_example_2 = ("cast x to an int", ["O", "B-VAR", "O", "O", "B-CAST_TYPE"], "casting")
test_example_3 = ("import the pandas library", ["O", "O", "B-LIB_NAME", "O"], "libraries")
test_example_4 = ('write a new comment multiply the emissions and transitions', ['O', 'O', 'O', 'B-COMMENT', 'I-COMMENT', 'I-COMMENT', 'I-COMMENT', 'I-COMMENT' ], 'comment')
test_example_5 = ('define a new class person', ['O', 'O', 'O', 'O', 'B-CLASS_NAME'], 'class declaration')
test_example_6 = ('create a new function named add that takes parameters x and y', ['O', 'O', 'O', 'O', 'O', 'B-FUNC', 'O', 'O', 'O', 'B-PARAM', 'O', 'B-PARAM'], 'function declaration')
test_example_7 = ('create a new function called add with parameters num1 and num2', ['O', 'O', 'O', 'O', 'O', 'B-FUNC', 'O', 'O', 'O', 'B-PARAM', 'O', 'B-PARAM'], 'function declaration')

test_data = [test_example_1, test_example_2, test_example_3, test_example_4, test_example_5, test_example_6, test_example_7]

for index, example in enumerate(test_data):
    print("Sentence:", example[0])
    print("Intent:", example[2])

    predictions = predict(models[example[2]][0], [example[0]])

    mapped_predictions = [[separate_index_to_tag[example[2]][tag] for tag in tags] for tags in predictions]

    print("Actual tags:", example[1])
    print("Predicted tags:", mapped_predictions[0])

    evaluated = evaluate([example[1]], mapped_predictions)

    print("----------------------------------------------------------------------------------------------------")

# predictions = predict(test_model, test_data)
# labels = []
# predicted_tags = []

# for index, prediction in enumerate(predictions):
#     print("Test Example:", index)
#     mapped_tags = [all_tags[tag] for tag in prediction]
#     predicted_tags.append(mapped_tags)
#     labels.append(test_data[index][1])
#     print("Actual tags:", test_data[index][1])
#     print("Predicted tags:", mapped_tags)

Sentence: create a variable called ay khara and set it to yasmine ghanem
Intent: variable declaration
Actual tags: ['O', 'O', 'O', 'O', 'B-VAR', 'I-VAR', 'O', 'O', 'O', 'B-VAL', 'I-VAL']
Predicted tags: ['O', 'B-VAR', 'I-VAR', 'B-VAR', 'I-VAR', 'O', 'B-VAR', 'I-VAR', 'O', 'B-VAR', 'I-VAR', 'I-VAL']
----------------------------------------------------------------------------------------------------
Sentence: cast x to an int
Intent: casting
Actual tags: ['O', 'B-VAR', 'O', 'O', 'B-CAST_TYPE']
Predicted tags: ['O', 'B-VAR', 'I-VAR', 'O', 'B-CAST_TYPE']
----------------------------------------------------------------------------------------------------
Sentence: import the pandas library
Intent: libraries
Actual tags: ['O', 'O', 'B-LIB_NAME', 'O']
Predicted tags: ['O', 'B-LIB_NAME', 'O', 'B-LIB_NAME']
----------------------------------------------------------------------------------------------------
Sentence: write a new comment multiply the emissions and transitions
Intent: comment


  _warn_prf(average, modifier, msg_start, len(result))


IndexError: index out of range in self