Skip to content
Permalink
Browse files

Update train_test.py

  • Loading branch information...
prakashpandey9 committed Jun 8, 2018
1 parent 813c117 commit f315e67c7179f77de508174ea45539991c575895
Showing with 11 additions and 13 deletions.
  1. +11 −13 train_test.py
@@ -43,16 +43,16 @@
batch_size= context.size()[0]
context= Variable(context.long()) ## context.size() = (batch_size, num_sentences, embedding_length) embedding_length = hidden_size
questions= Variable(questions.long()) ## questions.size() = (batch_size, num_tokens)
answers= Variable(answers)
answers= Variable(answers.long())

total_loss, acc = model.loss(context,questions,answers) ## Loss is calculated and gradients are backpropagated through the layers.
total_loss.backward()
total_acc += acc*batch_size
count += batch_size

if batch_id %20 == 0:
print('training error')
print ('task '+str(task_id)+',epoch '+str(epoch)+',loss ' +str(loss.data[0])+',total accuracy : '+str(total_acc/count))
print('Training Error')
print (f'[Task {task_id}, Epoch {epoch}] [Training] total_loss : {total_loss.data[0]: {10}.{8}}, acc : {total_acc / count: {5}.{4}}, batch_id : {batch_id}')
optim.step()

'''Validation part'''
@@ -69,10 +69,9 @@
batch_size = context.size()[0]
context = Variable(context.long())
questions = Variable(questions.long())
answers = Variable(answers)
answers = Variable(answers.long())

_, acc = model.loss(context,questions,answers)
total_loss.backward()
total_acc += acc*batch_size
count += batch_size

@@ -87,15 +86,14 @@
if early_stop_count > 20: # If the accuracy doesn't improve even after 20 epochs thenuse early stopping.
early_Stop_flag = True

print ('itr '+str(itr)+',task_id '+str(task_id)+',epoch '+str(epoch)+',total_acc '+str(total_acc))
print (f'[Run {itr}, Task {task_id}, Epoch {epoch}] [Validate] Accuracy : {total_acc: {5}.{4}}')

with open('log.txt', 'a') as fp:
fp.write('itr '+str(itr)+', task_id '+str(task_id)+',epoch '+str(epoch)+',total_acc '+str(total_acc)+'+\n ')
fp.write(f'[Run {itr}, Task {task_id}, Epoch {epoch}] [Validate] Accuracy : {total_acc: {5}.{4}}' + '\n')
if total_acc == 1.0:
break
else:
print('iteration'+str(itr)+'task' +str(task_id)+' Early Stopping at Epoch' +str(epoch)+'validation accuracy :' +str(best_acc))

print(f'[Run {itr}, Task {task_id}] Early Stopping at Epoch {epoch}, Valid Accuracy : {best_acc: {5}.{4}}')


dataset.set_mode('test')
@@ -109,19 +107,19 @@
batch_size = context.size()[0]
context = Variable(context.long())
questions = Variable(questions.long())
answers = Variable(answers)
answers = Variable(answers.long())

model.load_state_dict(best_state) # Loading the best model
_, acc = model.loss(context, questions, answers)

test_acc += acc* batch_size
count += batch_size
print ('itr '+ str(itr)+'task =' +str(task_id)+ 'Epoch ' +str(epoch)+' test accuracy : '+str(test_acc / count))
print (f'[Run {itr}, Task {task_id}, Epoch {epoch}] [Test] Accuracy : {test_acc / cnt: {5}.{4}}')



os.makedirs('models',exist_ok=True)
with open('models/task'+str(task_id)+'_epoch'+str(epoch)+'_run'+str(run)+'_acc'+str(test_acc/count)+'.pth', 'wb') as fp:
with open(f'models/task{task_id}_epoch{epoch}_run{itr}_acc{test_acc/cnt}.pth', 'wb') as fp:
torch.save(model.state_dict(), fp)
with open('log.txt', 'a') as fp:
fp.write('[itr '+str(itr)+', Task '+str(task_id)+', Epoch '+str(epoch)+'] [Test] Accuracy : '+str(total_acc)+' + \n')
fp.write(f'[Run {itr}, Task {task_id}, Epoch {epoch}] [Test] Accuracy : {total_acc: {5}.{4}}' + '\n')

0 comments on commit f315e67

Please sign in to comment.
You can’t perform that action at this time.