In [52]:
import torch
from torch import nn
from dataloader import dataset, collate_fn
import pickle
from torch.utils import data
from torch.autograd import Variable
from torchnlp.nn import Attention

In [53]:
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [54]:
data_ = []
file_path = './FinalDataset.csv'
with open(file_path, 'r') as f:
    for i in f:
        j = i.strip('\n').split('\t')
        data_.append(j[-1].lower())

tweet_pairs, distance_vectors = [], []
with open('./tweet_pairs.pkl', 'rb') as f:
    tweet_pairs = pickle.load(f)
with open('./distance_vectors.pkl', 'rb') as f:
    distance_vectors = pickle.load(f)
with open("./trigger_word_pos.pkl", 'rb') as f:
    trigger_word_pos = pickle.load(f)
    
tweet_pair_data = [[data_[i[0]], data_[i[1]]] for i in tweet_pairs]
distance_vector_data = [[distance_vectors[i[0]],distance_vectors[i[1]]] for i in tweet_pairs]
trigger_word_pos_data = [[trigger_word_pos[i[0]], trigger_word_pos[i[1]]] for i in tweet_pairs]
labels_data = [i[2] for i in tweet_pairs]

dataset_ = dataset(tweet_pair_data, distance_vector_data, trigger_word_pos_data, labels_data)
loader = data.DataLoader(dataset_, batch_size=64, collate_fn=collate_fn, shuffle=True)

In [55]:
criterion = nn.BCEWithLogitsLoss()

In [63]:
class Model(nn.Module):
    def __init__(self,max_dist):
        super().__init__()
        with open("./vocab.pkl", "rb") as f:
            self.vocab = pickle.load(f)
        pre_trained_emb = torch.Tensor(self.vocab.vectors)
        self.word_embedding = nn.Embedding.from_pretrained(pre_trained_emb)
        self.distance_embedding = nn.Embedding(64, 14)
        self.lstm = nn.LSTM(114, 64, batch_first=True, bidirectional=True)
        self.selective = nn.Linear(128, 1)
        self.attention_linear = nn.Linear(128, 128)
        self.attention = Attention(128)
        self.final = nn.Linear(512, 1)
    
    def init_hidden(self, batch_size):
        h, c = (Variable(torch.zeros(1 * 2, batch_size, 64)).to(device),
                Variable(torch.zeros(1 * 2, batch_size, 64)).to(device))
        return h, c
    
    def forward(self, tweet1, tweet2, dist1, dist2, pos1, pos2):
        
        batch_size = tweet1.shape[0]
        seq_len = tweet1.shape[1]
        h_0, c_0 = self.init_hidden(batch_size)
        
        tweet1_embedding = self.word_embedding(tweet1.long())
        tweet2_embedding = self.word_embedding(tweet2.long())
        dist1_embedding = self.distance_embedding(dist1.long())
        dist2_embedding = self.distance_embedding(dist2.long())
        
        tweet1 = torch.cat([tweet1_embedding, dist1_embedding], dim=2)        
        tweet2 = torch.cat([tweet2_embedding, dist2_embedding], dim=2)
        
        # Tweet1
        output1, (h_n, c_n) = self.lstm(tweet1, (h_0, c_0))
        output_1 = output1.view(batch_size, seq_len, 2, 64)
        indices = torch.Tensor(list(range(batch_size))).long()
        
        ment_part1 = output_1[indices, pos1[:, 1].long(), 0, :]
        ment_part2 = output_1[indices, 0, 1, :]
        
        mention_feature1 = torch.cat([ment_part1, ment_part2], dim=1)
        Rc1 = output1 * mention_feature1.view(batch_size, 1, -1)
        alpha1 = torch.tanh(self.selective(Rc1))
        select1 = alpha1 * output1
        t = self.attention(mention_feature1.view(batch_size, 1, -1), select1)
        Vem1 = torch.cat([mention_feature1, t[0].view(batch_size, -1)], dim=1)
#         ex = torch.tanh(self.attention_linear(select1))
#         print(ex.shape)
        
        # Tweet2
        batch_size = tweet2.shape[0]
        seq_len = tweet2.shape[1]
        h_0, c_0 = self.init_hidden(batch_size)
        output2, (h_n, c_n) = self.lstm(tweet2, (h_0, c_0))
        output_2 = output2.view(batch_size, seq_len, 2, 64)
        indices = torch.Tensor(list(range(batch_size))).long()
        
        ment_part1 = output_2[indices, pos2[:, 1].long(), 0, :]
        ment_part2 = output_2[indices, 0, 1, :]
        
        mention_feature2 = torch.cat([ment_part1, ment_part2], dim=1)
        Rc2 = output2 * mention_feature2.view(batch_size, 1, -1)
        alpha2 = torch.tanh(self.selective(Rc2))
        select2 = alpha2 * output2
        t = self.attention(mention_feature2.view(batch_size, 1, -1), select2)
        Vem2 = torch.cat([mention_feature2, t[0].view(batch_size, -1)], dim=1)
        
        final = torch.cat([Vem1, Vem2], dim=1)
        return self.final(final)        


In [64]:
model = Model().to(device)

In [65]:
from torch.optim import Adam

In [66]:
optimizer = Adam(model.parameters())

In [67]:
for tweet1, tweet2, dist1, dist2, pos1, pos2, label in loader:
    tweet1 = tweet1.to(device)
    tweet2 = tweet2.to(device)
    dist1 = dist1.to(device)
    dist2 = dist2.to(device)
    pos1 = pos1.to(device)
    pos2 = pos2.to(device)
    label = label.to(device)
    prediction = model(tweet1, tweet2, dist1, dist2, pos1, pos2)
    loss = criterion(prediction.squeeze(), label.squeeze())
    loss.backward()
    optimizer.step()
    print(loss)

tensor(0.7240, device='cuda:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.6564, device='cuda:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.5962, device='cuda:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.5340, device='cuda:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.4703, device='cuda:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.3969, device='cuda:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.3360, device='cuda:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.2645, device='cuda:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.1933, device='cuda:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.1266, device='cuda:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.0796, device='cuda:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.0381, device='cuda:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.0210, device='cuda:0', grad_fn=

KeyboardInterrupt: 