In [1]:
import torch
from torch import nn
from dataloader import dataset, collate_fn
import pickle
from torch.utils import data
from torch.autograd import Variable

In [16]:
# --- Load raw tweets ---------------------------------------------------------
# One tweet per line; the tweet text is the last tab-separated field, lower-cased.
# NOTE(review): despite the .csv extension the file is split on tabs -- confirm.
# Explicit UTF-8 so tweet text (emoji, non-ASCII) decodes identically on every
# platform instead of depending on the locale's default encoding.
file_path = './FinalDataset.csv'
with open(file_path, 'r', encoding='utf-8') as f:
    data_ = [line.strip('\n').split('\t')[-1].lower() for line in f]

# --- Load pre-computed artefacts ---------------------------------------------
# SECURITY: pickle.load can execute arbitrary code; only load trusted files.
with open('./tweet_pairs.pkl', 'rb') as f:
    tweet_pairs = pickle.load(f)
with open('./distance_vectors.pkl', 'rb') as f:
    distance_vectors = pickle.load(f)
with open('./trigger_word_pos.pkl', 'rb') as f:
    trigger_word_pos = pickle.load(f)

# --- Align per-tweet artefacts with each pair --------------------------------
# Each pair is indexed as pair[0] / pair[1] (positions of the two tweets in
# data_ and in the per-tweet pickles) and pair[2] (the label).
tweet_pair_data = [[data_[pair[0]], data_[pair[1]]] for pair in tweet_pairs]
distance_vector_data = [[distance_vectors[pair[0]], distance_vectors[pair[1]]] for pair in tweet_pairs]
trigger_word_pos_data = [[trigger_word_pos[pair[0]], trigger_word_pos[pair[1]]] for pair in tweet_pairs]
labels_data = [pair[2] for pair in tweet_pairs]

dataset_ = dataset(tweet_pair_data, distance_vector_data, trigger_word_pos_data, labels_data)
# NOTE(review): shuffle=True with no seed set -- batch order differs between runs.
loader = data.DataLoader(dataset_, batch_size=64, collate_fn=collate_fn, shuffle=True)

In [187]:
class Model(nn.Module):
    """Shared Bi-LSTM encoder with mention-selective gating for a pair of tweets.

    Each tweet is embedded as the concatenation of a pre-trained word embedding
    (loaded from a pickled vocab with a ``vectors`` attribute) and a learned
    distance embedding, encoded by one bidirectional LSTM shared by both tweets,
    and re-weighted per time step by a tanh gate computed from a "mention
    feature" read out of the LSTM states.

    NOTE(review): ``forward`` is still debugging scaffolding -- it prints the
    gated-feature shapes and returns only embedding shapes.
    """

    def __init__(self, vocab_path="./vocab.pkl", num_distances=64,
                 distance_dim=14, hidden_size=128):
        """Build the layers.

        Args:
            vocab_path: pickled vocab object exposing ``.vectors``
                (one row per token id).
            num_distances: number of distinct distance ids (embedding rows).
            distance_dim: width of the distance embedding.
            hidden_size: LSTM hidden size per direction.
        """
        super().__init__()
        # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
        with open(vocab_path, "rb") as f:
            self.vocab = pickle.load(f)
        pre_trained_emb = torch.Tensor(self.vocab.vectors)
        # from_pretrained uses freeze=True by default, so word vectors are not trained.
        self.word_embedding = nn.Embedding.from_pretrained(pre_trained_emb)
        self.distance_embedding = nn.Embedding(num_distances, distance_dim)
        self.hidden_size = hidden_size
        # Derive the LSTM input width from the real embedding sizes instead of
        # hard-coding 114 (= 100-d word vectors + 14-d distance embedding).
        lstm_input_size = self.word_embedding.embedding_dim + distance_dim
        self.lstm = nn.LSTM(lstm_input_size, hidden_size, batch_first=True, bidirectional=True)
        # Maps a (2 * hidden_size) per-step feature to one scalar gate.
        self.selective = nn.Linear(2 * hidden_size, 1)

    def init_hidden(self, batch_size, device=None):
        """Return zero (h_0, c_0) of shape (2, batch_size, hidden_size).

        The leading 2 is num_layers (1) * num_directions (2). ``device`` lets
        the states live next to the inputs; ``None`` keeps the CPU default.
        """
        shape = (1 * 2, batch_size, self.hidden_size)
        return torch.zeros(shape, device=device), torch.zeros(shape, device=device)

    def _encode(self, tweet, dist, pos):
        """Embed and encode one tweet; return (word_emb, dist_emb, gated LSTM output).

        Assumes ``tweet`` and ``dist`` are (batch, seq) id tensors and ``pos`` is
        (batch, >=2) with column 1 the trigger-word index -- TODO confirm.
        """
        batch_size, seq_len = tweet.shape[0], tweet.shape[1]
        word_emb = self.word_embedding(tweet.long())
        dist_emb = self.distance_embedding(dist.long())
        lstm_in = torch.cat([word_emb, dist_emb], dim=2)

        h_0, c_0 = self.init_hidden(batch_size, device=lstm_in.device)
        output, _ = self.lstm(lstm_in, (h_0, c_0))
        # batch_first bidirectional output is (batch, seq, 2 * hidden); split the
        # last axis into (direction, hidden): direction 0 = forward, 1 = backward.
        per_direction = output.view(batch_size, seq_len, 2, self.hidden_size)
        rows = torch.arange(batch_size, device=output.device)

        # Forward state at the trigger word of each example.
        forward_at_trigger = per_direction[rows, pos[:, 1].long(), 0, :]
        # NOTE(review): backward state at position 0 is the backward pass's final
        # state over the whole (possibly padded) sequence, not the state at the
        # trigger start -- confirm this is intended.
        backward_at_start = per_direction[rows, 0, 1, :]

        mention_feature = torch.cat([forward_at_trigger, backward_at_start], dim=1)
        # Broadcast the mention feature over time, score each step, gate the output.
        alpha = torch.tanh(self.selective(output * mention_feature.unsqueeze(1)))
        return word_emb, dist_emb, alpha * output

    def forward(self, tweet1, tweet2, dist1, dist2, pos1, pos2):
        """Encode both tweets with the shared encoder.

        Prints the shapes of the two gated sequences and returns the shapes of
        (tweet1 word emb, tweet2 word emb, dist1 emb, dist2 emb).
        TODO: return the gated features once the downstream head exists.
        """
        tweet1_embedding, dist1_embedding, select1 = self._encode(tweet1, dist1, pos1)
        tweet2_embedding, dist2_embedding, select2 = self._encode(tweet2, dist2, pos2)

        print(select1.shape, select2.shape)
        return (tweet1_embedding.shape, tweet2_embedding.shape,
                dist1_embedding.shape, dist2_embedding.shape)


In [188]:
# Instantiate the encoder; Model.__init__ unpickles ./vocab.pkl, so this must be
# run from the directory that contains that file.
model = Model()

In [189]:
# Smoke test: push every batch through the model to check shapes/indexing.
# The model's return value is discarded here, so no gradients are needed;
# no_grad avoids building an autograd graph for every batch.
with torch.no_grad():
    for tweet1, tweet2, dist1, dist2, pos1, pos2, label in loader:
        model(tweet1, tweet2, dist1, dist2, pos1, pos2)

[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
torch.Size([64, 29, 256]) torch.Size([64, 35, 256])
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
torch.Size([64, 33, 256]) torch.Size([64, 24, 256])
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0

[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
torch.Size([64, 30, 256]) torch.Size([64, 27, 256])
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
torch.Size([64, 35, 256]) torch.Size([64, 26, 256])
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0

KeyboardInterrupt: 