Skip to content

Commit

Permalink
Update train_test.py
Browse files Browse the repository at this point in the history
  • Loading branch information
prakashpandey9 committed Jun 8, 2018
1 parent 813c117 commit f315e67
Showing 1 changed file with 11 additions and 13 deletions.
24 changes: 11 additions & 13 deletions train_test.py
Expand Up @@ -43,16 +43,16 @@
batch_size= context.size()[0] batch_size= context.size()[0]
context= Variable(context.long()) ## context.size() = (batch_size, num_sentences, embedding_length) embedding_length = hidden_size context= Variable(context.long()) ## context.size() = (batch_size, num_sentences, embedding_length) embedding_length = hidden_size
questions= Variable(questions.long()) ## questions.size() = (batch_size, num_tokens) questions= Variable(questions.long()) ## questions.size() = (batch_size, num_tokens)
answers= Variable(answers) answers= Variable(answers.long())


total_loss, acc = model.loss(context,questions,answers) ## Loss is calculated and gradients are backpropagated through the layers. total_loss, acc = model.loss(context,questions,answers) ## Loss is calculated and gradients are backpropagated through the layers.
total_loss.backward() total_loss.backward()
total_acc += acc*batch_size total_acc += acc*batch_size
count += batch_size count += batch_size


if batch_id %20 == 0: if batch_id %20 == 0:
print('training error') print('Training Error')
print ('task '+str(task_id)+',epoch '+str(epoch)+',loss ' +str(loss.data[0])+',total accuracy : '+str(total_acc/count)) print (f'[Task {task_id}, Epoch {epoch}] [Training] total_loss : {total_loss.data[0]: {10}.{8}}, acc : {total_acc / count: {5}.{4}}, batch_id : {batch_id}')
optim.step() optim.step()


'''Validation part''' '''Validation part'''
Expand All @@ -69,10 +69,9 @@
batch_size = context.size()[0] batch_size = context.size()[0]
context = Variable(context.long()) context = Variable(context.long())
questions = Variable(questions.long()) questions = Variable(questions.long())
answers = Variable(answers) answers = Variable(answers.long())


_, acc = model.loss(context,questions,answers) _, acc = model.loss(context,questions,answers)
total_loss.backward()
total_acc += acc*batch_size total_acc += acc*batch_size
count += batch_size count += batch_size


Expand All @@ -87,15 +86,14 @@
if early_stop_count > 20: # If the accuracy doesn't improve even after 20 epochs thenuse early stopping. if early_stop_count > 20: # If the accuracy doesn't improve even after 20 epochs thenuse early stopping.
early_Stop_flag = True early_Stop_flag = True


print ('itr '+str(itr)+',task_id '+str(task_id)+',epoch '+str(epoch)+',total_acc '+str(total_acc)) print (f'[Run {itr}, Task {task_id}, Epoch {epoch}] [Validate] Accuracy : {total_acc: {5}.{4}}')


with open('log.txt', 'a') as fp: with open('log.txt', 'a') as fp:
fp.write('itr '+str(itr)+', task_id '+str(task_id)+',epoch '+str(epoch)+',total_acc '+str(total_acc)+'+\n ') fp.write(f'[Run {itr}, Task {task_id}, Epoch {epoch}] [Validate] Accuracy : {total_acc: {5}.{4}}' + '\n')
if total_acc == 1.0: if total_acc == 1.0:
break break
else: else:
print('iteration'+str(itr)+'task' +str(task_id)+' Early Stopping at Epoch' +str(epoch)+'validation accuracy :' +str(best_acc)) print(f'[Run {itr}, Task {task_id}] Early Stopping at Epoch {epoch}, Valid Accuracy : {best_acc: {5}.{4}}')





dataset.set_mode('test') dataset.set_mode('test')
Expand All @@ -109,19 +107,19 @@
batch_size = context.size()[0] batch_size = context.size()[0]
context = Variable(context.long()) context = Variable(context.long())
questions = Variable(questions.long()) questions = Variable(questions.long())
answers = Variable(answers) answers = Variable(answers.long())


model.load_state_dict(best_state) # Loading the best model model.load_state_dict(best_state) # Loading the best model
_, acc = model.loss(context, questions, answers) _, acc = model.loss(context, questions, answers)


test_acc += acc* batch_size test_acc += acc* batch_size
count += batch_size count += batch_size
print ('itr '+ str(itr)+'task =' +str(task_id)+ 'Epoch ' +str(epoch)+' test accuracy : '+str(test_acc / count)) print (f'[Run {itr}, Task {task_id}, Epoch {epoch}] [Test] Accuracy : {test_acc / cnt: {5}.{4}}')






os.makedirs('models',exist_ok=True) os.makedirs('models',exist_ok=True)
with open('models/task'+str(task_id)+'_epoch'+str(epoch)+'_run'+str(run)+'_acc'+str(test_acc/count)+'.pth', 'wb') as fp: with open(f'models/task{task_id}_epoch{epoch}_run{itr}_acc{test_acc/cnt}.pth', 'wb') as fp:
torch.save(model.state_dict(), fp) torch.save(model.state_dict(), fp)
with open('log.txt', 'a') as fp: with open('log.txt', 'a') as fp:
fp.write('[itr '+str(itr)+', Task '+str(task_id)+', Epoch '+str(epoch)+'] [Test] Accuracy : '+str(total_acc)+' + \n') fp.write(f'[Run {itr}, Task {task_id}, Epoch {epoch}] [Test] Accuracy : {total_acc: {5}.{4}}' + '\n')

0 comments on commit f315e67

Please sign in to comment.