In [1]:
import torch
from torch import nn
from dataloader import dataset, collate_fn
import pickle
from torch.utils import data
from torch.autograd import Variable

In [2]:
# Raw tweets: one tab-separated record per line; only the last field (the tweet
# text) is kept, lower-cased. Blank lines are kept too (as ''), so positions in
# data_ stay aligned with the indices stored in tweet_pairs below.
# NOTE(review): assumes the file is UTF-8 — TODO confirm. Without an explicit
# encoding, open() falls back to the platform's locale encoding.
file_path = './generated_dataset.txt'
with open(file_path, 'r', encoding='utf-8') as f:
    data_ = [line.strip('\n').split('\t')[-1].lower() for line in f]

# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open('./tweet_pairs.pkl', 'rb') as f:
    tweet_pairs = pickle.load(f)
with open('./distance_vectors.pkl', 'rb') as f:
    distance_vectors = pickle.load(f)

# The first two entries of each pair index into both data_ and distance_vectors.
tweet_pair_data = [[data_[pair[0]], data_[pair[1]]] for pair in tweet_pairs]
distance_vector_data = [[distance_vectors[pair[0]], distance_vectors[pair[1]]] for pair in tweet_pairs]

BATCH_SIZE = 32
dataset_ = dataset(tweet_pair_data, distance_vector_data)
loader = data.DataLoader(dataset_, batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=True)

In [9]:
class Model(nn.Module):
    """Tweet-pair encoder: word + distance embeddings fed to a bidirectional LSTM.

    Every token is represented by its pre-trained word vector concatenated with a
    learned distance embedding; the sequence is run through a single-layer
    bidirectional LSTM.

    Args:
        vocab_path: Pickle file holding an object with a ``vectors`` attribute,
            the pre-trained embedding matrix (one row per token id).
        num_distances: Number of rows in the distance-embedding table; distance
            ids passed to ``forward`` must lie in ``[0, num_distances)``.
        distance_dim: Size of each distance embedding.
        hidden_size: Hidden size of each LSTM direction.
    """

    def __init__(self, vocab_path="./vocab.pkl", num_distances=64, distance_dim=14, hidden_size=128):
        super().__init__()
        # NOTE(review): pickle.load can execute arbitrary code; only load a trusted vocab file.
        with open(vocab_path, "rb") as f:
            self.vocab = pickle.load(f)
        pre_trained_emb = torch.Tensor(self.vocab.vectors)
        # from_pretrained freezes the weights by default (freeze=True).
        self.word_embedding = nn.Embedding.from_pretrained(pre_trained_emb)
        self.distance_embedding = nn.Embedding(num_distances, distance_dim)
        # Each LSTM step consumes [word vector ; distance vector], so derive the input
        # size from the real embedding widths instead of hard-coding 114, which only
        # worked for 100-d word vectors.
        lstm_input_size = self.word_embedding.embedding_dim + distance_dim
        self.lstm = nn.LSTM(lstm_input_size, hidden_size, batch_first=True, bidirectional=True)

    def init_hidden(self, batch_size):
        """Return zero initial states ``(h_0, c_0)`` for ``self.lstm``.

        Each has shape ``(num_layers * num_directions, batch_size, hidden_size)``
        and is allocated with the LSTM weights' device and dtype, so the model keeps
        working after ``.to(device)`` / ``.double()``.
        """
        num_directions = 2 if self.lstm.bidirectional else 1
        weight = self.lstm.weight_ih_l0
        shape = (self.lstm.num_layers * num_directions, batch_size, self.lstm.hidden_size)
        h = torch.zeros(shape, device=weight.device, dtype=weight.dtype)
        c = torch.zeros(shape, device=weight.device, dtype=weight.dtype)
        return h, c

    def forward(self, tweet1, tweet2, dist1, dist2):
        """Embed both tweets and run the first one through the BiLSTM.

        Args:
            tweet1, tweet2: Token-id tensors of shape ``(batch, seq_len)``.
            dist1, dist2: Distance-id tensors sharing the first two dims of the
                matching tweet tensor (the embeddings are concatenated on dim 2).

        Returns:
            Tuple with the shapes of the tweet1, tweet2, dist1 and dist2 embeddings.
        """
        batch_size = tweet1.shape[0]
        h_0, c_0 = self.init_hidden(batch_size)

        tweet1_embedding = self.word_embedding(tweet1.long())
        tweet2_embedding = self.word_embedding(tweet2.long())
        dist1_embedding = self.distance_embedding(dist1.long())
        dist2_embedding = self.distance_embedding(dist2.long())

        tweet1_features = torch.cat([tweet1_embedding, dist1_embedding], dim=2)
        tweet2_features = torch.cat([tweet2_embedding, dist2_embedding], dim=2)

        # TODO: the LSTM output (sentence-level features) is discarded and
        # tweet2_features is never encoded; the returned shapes are debug output.
        _output, _ = self.lstm(tweet1_features, (h_0, c_0))
        return (tweet1_embedding.shape, tweet2_embedding.shape,
                dist1_embedding.shape, dist2_embedding.shape)


In [10]:
# Seed torch's global RNG so the randomly initialised distance embeddings / LSTM
# weights, and the shuffled batch order drawn later from `loader`, are
# reproducible across Restart & Run All. Model() also reads ./vocab.pkl.
SEED = 0
torch.manual_seed(SEED)
model = Model()

In [11]:
# Smoke test: push a single batch through the model and display what forward
# returns (currently the four embedding shapes). No gradients are needed here.
tweet1_batch, tweet2_batch, dist1_batch, dist2_batch = next(iter(loader))
with torch.no_grad():
    embedding_shapes = model(tweet1_batch, tweet2_batch, dist1_batch, dist2_batch)
embedding_shapes

torch.Size([32, 27, 100]) torch.Size([32, 27, 14])
torch.Size([32, 27, 100]) torch.Size([32, 27, 14])
torch.Size([32, 27, 256])
