In [None]:
# Uses Glove embeddings for words.
# Hidden layers of LSTMs for word level feed into another LSTM whose last hidden layer is used to classify 

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import re
import util
import random
import embeddings

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
EMBEDDING_DIM = 50
WORD_HIDDEN_DIM = 30
SENT_HIDDEN_DIM = 30

In [4]:
vocab = embeddings.load_glove()

# language_model = embeddings.load_local_word2vec()
# text = util.get_alphanumeral(collections[0]["texts"][0])
# embeddings.get_doc_embedding(model, text)

In [5]:
df = util.get_processed_data("./data/collections_math.csv", True)
collections = util.get_collections(df)

In [6]:
df.head()

Unnamed: 0,collection_id,sequence_id,resource_id,title,description,text
0,0008d66a-753f-4639-8634-81bb3abb3269,3,2ee6f80f-0851-4cfa-b4bf-2655e9c46ab7,Solve the Linear equation: _______,Solve the Linear equation: [2],Solve Linear equation Solve Linear equation
1,0008d66a-753f-4639-8634-81bb3abb3269,4,81c6995c-dd95-418e-a8c4-c22d8ccd32e9,Solve the linear equation: _______,Solve the linear equation: [-18],Solve linear equation Solve linear equation
2,0008d66a-753f-4639-8634-81bb3abb3269,1,231eb4ad-d0e8-4e94-a552-f8bd2358a47a,Solve the linear equation: _______,Solve the linear equation: [1/2] ...,Solve linear equation Solve linear equation Pl...
3,0008d66a-753f-4639-8634-81bb3abb3269,2,0b248202-12a9-405b-acd1-4ab8250e4198,"If , then _______","If , then [1/5]. &nbsp;Please writ...",nbsp Please write answer fraction
4,001bf2c6-8ede-478a-9b8b-d7750488cb1b,1,1b922fae-619f-4f52-9551-663b3206e4e5,Lesson 11,I'll say an addition or subtraction sentence. ...,Lesson say addition subtraction sentence say a...


In [6]:
training_data, testing_data = util.get_train_test([col["texts"] for col in collections])

In [7]:
print(len(training_data), len(testing_data))

6332 2716


In [11]:
# For Gensim word2vec embeddings

# def get_embeddings(text, word_embedding_dim):
#     global language_model
#     global device
#     embeds = []
#     for word in text:
#         if word in language_model.wv.vocab:
#             embeds.append(torch.tensor(language_model.wv[word], dtype = torch.float, device = device))
#         else:
#             embeds.append(torch.zeros(word_embedding_dim, device = device))
#     return torch.cat(embeds).view(len(text),word_embedding_dim)

In [11]:
# For Glove embeddings

def get_embeddings(text, word_embedding_dim):
    global vocab
    global device
    embeds = []
    for word in text:
        if word in vocab:
            embeds.append(torch.tensor(vocab[word], dtype = torch.float, device = device))
        else:
            embeds.append(torch.zeros(word_embedding_dim, device = device))
    return torch.cat(embeds).view(len(text),word_embedding_dim)

In [12]:
#Defining model

class LSTMClassifier(nn.Module):
    
    def __init__(self, embedding_dim, word_hidden_dim, sent_hidden_dim, vocab_size):
        super(LSTMClassifier, self).__init__()
        self.word_hidden_dim = word_hidden_dim
        self.sent_hidden_dim = sent_hidden_dim
#         self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.word_lstm = nn.LSTM(embedding_dim, word_hidden_dim)
        self.sent_lstm = nn.LSTM(word_hidden_dim, sent_hidden_dim)
        
        self.hidden2out = nn.Linear(sent_hidden_dim, 2)
        self.sent_hidden = self.init_sent_hidden()
        
    def init_word_hidden(self):
        return (torch.zeros(1,1,self.word_hidden_dim, device = device), torch.zeros(1,1,self.word_hidden_dim, device = device))
    
    def init_sent_hidden(self):
        return (torch.zeros(1,1,self.sent_hidden_dim, device = device), torch.zeros(1,1,self.sent_hidden_dim, device = device))

    def forward(self, collection):
        outputs = []
        for i,word_embeds in enumerate(collection):
#             print("word_ixs: ",word_ixs)
            self.word_hidden = self.init_word_hidden()
#             word_embeds = self.embeddings(word_ixs)
            word_lstm_out, self.word_hidden = self.word_lstm(word_embeds.view(len(word_embeds),1,-1), self.word_hidden)
            outputs.append(self.word_hidden[0])
            
        word_hiddens = torch.cat(outputs)
        sent_lstm_out, self.sent_hidden = self.sent_lstm(word_hiddens, self.sent_hidden)
#         print("Sent hidden, ",self.sent_hidden[0], self.sent_hidden[0].shape)
        out = self.hidden2out(sent_lstm_out[-1])
        score = F.log_softmax(out, dim = 1)
        return score

In [14]:
#Training
model = LSTMClassifier(EMBEDDING_DIM, WORD_HIDDEN_DIM, SENT_HIDDEN_DIM, len(word_to_ix)).to(device)
loss_function = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum = 0.9)

for epoch in range(2):
    total_loss = 0
    count = 0
    printed_percentages = []
    corrects = 0
    for collection,label in training_data:
        
        complete_percentage = int(count*100/len(training_data))
        if complete_percentage%5 == 0 and (complete_percentage not in printed_percentages):
            print("epoch ",epoch+1,complete_percentage,"percent complete")
            print("Current accuracy: ",corrects/(count+1))
            printed_percentages.append(complete_percentage)
        count+=1
        
        label = torch.tensor([label], dtype = torch.long, device = device)
#         collection_ixs = [prepare_sequence(text.split(),word_to_ix) for text in collection]
        collection_embeds =[get_embeddings(text.split(), EMBEDDING_DIM) for text in collection]
        model.zero_grad()

        model.sent_hidden = model.init_sent_hidden()
#         print(collection_ixs)
        
        score = model(collection_embeds)
        
        _, predicted = torch.max(score,1)
        correct = 1 if (predicted == label) else 0
        corrects += correct
        
#         print(score.shape,label.shape)
        loss = loss_function(score, label)
        
        loss.backward()
        
        total_loss += loss.item()
        
        optimizer.step()
    print("\nepoch "+str(epoch+1)+" loss: "+str(total_loss)+"\n")

epoch  1 0 percent complete
Current accuracy:  0.0
epoch  1 5 percent complete
Current accuracy:  0.5188679245283019
epoch  1 10 percent complete
Current accuracy:  0.5669291338582677
epoch  1 15 percent complete
Current accuracy:  0.5688748685594112
epoch  1 20 percent complete
Current accuracy:  0.6096214511041009
epoch  1 25 percent complete
Current accuracy:  0.6477272727272727
epoch  1 30 percent complete
Current accuracy:  0.6917411888479748
epoch  1 35 percent complete
Current accuracy:  0.7168620378719567
epoch  1 40 percent complete
Current accuracy:  0.7438831886345698
epoch  1 45 percent complete
Current accuracy:  0.7362329007365837
epoch  1 50 percent complete
Current accuracy:  0.7527628670666245
epoch  1 55 percent complete
Current accuracy:  0.7680826636050516
epoch  1 60 percent complete
Current accuracy:  0.7774269928966061
epoch  1 65 percent complete
Current accuracy:  0.7908671362642701
epoch  1 70 percent complete
Current accuracy:  0.8022101939557961
epoch  1 75 

In [15]:
#Testing
with torch.no_grad():
    total_coll = len(testing_data)
    correct_preds = 0
    count = 0
    for collection,label  in testing_data:
        label = torch.tensor(label, dtype = torch.long, device = device)
        collection_embeds =[get_embeddings(text.split(), EMBEDDING_DIM) for text in collection]
#         collection_ixs = [prepare_sequence(text.split(),word_to_ix) for text in collection]
        score = model(collection_embeds)
        _, predicted = torch.max(score,1)
        correct = 1 if (predicted == label) else 0
        correct_preds += correct
                
    print("Count: ",count)
    print("Total collections : "+str(total_coll))
    print("Correct predictions: "+str(correct_preds))
    print ("Accuracy : "+str(correct_preds/total_coll))

Count:  0
Total collections : 2716
Correct predictions: 1358
Accuracy : 0.5


In [None]:
# torch.save(model,"./models/word_lstm_collections_csv.pt")