In [29]:
import random
from itertools import chain, islice
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import utils

In [2]:
class FakeNewsDataset(Dataset):
    """Map-style dataset over a preprocessed ``.npy`` file of samples.

    Each record of the loaded array is indexed positionally: field 1 is returned as the label and
    fields 2 and 3 as float32 tensors ``x`` and ``y``. Judging by how the model consumes them,
    ``x`` presumably holds per-time-step article features and ``y`` per-user features
    (TODO confirm against the preprocessing script).

    Args:
        path: location of the ``.npy`` file. Defaults to the file this notebook was built around,
            so ``FakeNewsDataset()`` behaves exactly as before.
    """

    DEFAULT_PATH = "data/processed/data_bin12h_cut180_100w_20xu_50yu.npy"

    def __init__(self, path=DEFAULT_PATH):
        super(FakeNewsDataset, self).__init__()
        # allow_pickle is needed to load object-dtype arrays (records mix arrays and scalars).
        # Unpickling can execute arbitrary code, so only point this at files you produced yourself.
        self.data = np.load(path, allow_pickle=True)

    def __len__(self):
        """Number of records (first axis of the loaded array)."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return ``(x, y, label)`` for sample ``idx``; negative indices count from the end.

        ``x`` and ``y`` are converted to float32 tensors; ``label`` is returned as stored.
        """
        if idx < 0:
            idx = len(self) + idx
        record = self.data[idx]
        return torch.FloatTensor(record[2]), torch.FloatTensor(record[3]), record[1]

# Load the processed dataset once; the training and evaluation cells below all index into this object.
fake_news_data = FakeNewsDataset()

In [3]:
# Sanity check: shapes of one sample's x / y tensors and its label.
x, y, label = fake_news_data[2]
print(x.shape, y.shape, label)


torch.Size([65, 122]) torch.Size([161, 50]) 0


In [4]:
class Capture(nn.Module):
    """Article encoder: tanh-embed each time step, run an LSTM, and project its final hidden state.

    Args:
        feature_dim: size of each input time-step feature vector.
        feature_embedding_dim: size of the per-step tanh embedding fed to the LSTM.
        lstm_hidden_dim: LSTM hidden-state size.
        article_embedding_dim: size of the returned article representation.
    """

    def __init__(self, feature_dim, feature_embedding_dim, lstm_hidden_dim, article_embedding_dim):
        super().__init__()
        # Attribute names and Sequential indices are state_dict keys of saved checkpoints; keep them fixed.
        self.input_embedding = nn.Sequential(nn.Linear(feature_dim, feature_embedding_dim), nn.Tanh())
        self.lstm = nn.LSTM(feature_embedding_dim, lstm_hidden_dim, batch_first=True)
        self.output_embedding = nn.Sequential(nn.Linear(lstm_hidden_dim, article_embedding_dim), nn.Tanh())

    def forward(self, x):
        """Encode ``x`` (assumed (batch, time, feature_dim) since the LSTM is batch_first).

        Returns the tanh-projected last-layer final hidden state, (batch, article_embedding_dim).
        """
        embedded = self.input_embedding(x)
        _, (final_hidden, _) = self.lstm(embedded)
        # final_hidden is (num_layers, batch, hidden); index -1 keeps the top layer only.
        return self.output_embedding(final_hidden[-1])

class Score(nn.Module):
    """User module: tanh-embed each user, give each a sigmoid score, and average the scores.

    Args:
        user_dim: size of each user's input feature vector.
        user_embedding_dim: size of the per-user embedding.
    """

    def __init__(self, user_dim, user_embedding_dim):
        super().__init__()
        # TODO (carried over from an earlier note): regularize this embedding; no penalty is applied in this file.
        self.user_embedding = nn.Sequential(nn.Linear(user_dim, user_embedding_dim), nn.Tanh())
        self.user_score = nn.Sequential(nn.Linear(user_embedding_dim, 1), nn.Sigmoid())

    def forward(self, y):
        """Score users in ``y`` (assumed (batch, n_users, user_dim)).

        Returns:
            (p, s, y_hat): the mean of ``s`` over dim 1, the per-user scores in (0, 1),
            and the per-user embeddings.
        """
        user_repr = self.user_embedding(y)
        per_user = self.user_score(user_repr)
        return per_user.mean(dim=1), per_user, user_repr

class Integrate(nn.Module):
    """Fuse an article embedding with a scalar user score into a single logit.

    Deliberately has no output sigmoid: training uses ``nn.BCEWithLogitsLoss`` (which applies it
    internally) and inference applies ``torch.sigmoid`` explicitly.

    Args:
        capture_dim: width of the article embedding; one extra input column is reserved for the score.
    """

    def __init__(self, capture_dim):
        super().__init__()
        # Kept as a one-layer Sequential so the state_dict keys stay "net.0.*" for existing checkpoints.
        self.net = nn.Sequential(nn.Linear(capture_dim + 1, 1))

    def forward(self, x, y):
        """Concatenate ``x`` and ``y`` along dim 1 (assumes 2-D inputs) and return the linear logit."""
        return self.net(torch.cat((x, y), dim=1))


class CSI(nn.Module):
    """Full model: a Capture article encoder and a Score user module, fused by Integrate into one logit.

    Args:
        feature_dim, feature_embedding_dim, lstm_hidden_dim, article_embedding_dim: forwarded to Capture.
        user_dim, user_embedding_dim: forwarded to Score.
    """

    def __init__(self, feature_dim, feature_embedding_dim, lstm_hidden_dim, article_embedding_dim,
                 user_dim, user_embedding_dim):
        super().__init__()
        # Sub-module attribute names are checkpoint key prefixes ("capture.", "score.", "integrate."); keep them.
        self.capture = Capture(feature_dim, feature_embedding_dim, lstm_hidden_dim, article_embedding_dim)
        self.score = Score(user_dim, user_embedding_dim)
        self.integrate = Integrate(article_embedding_dim)

    def forward(self, x, y):
        """Return the raw (pre-sigmoid) logit for article features ``x`` and user features ``y``."""
        article_repr = self.capture(x)
        mean_user_score = self.score(y)[0]
        return self.integrate(article_repr, mean_user_score)

In [34]:
device = "cuda" if torch.cuda.is_available() else "cpu"
CHECKPOINT_PATH = Path("models/csi.pt")
# Named sizes must match both the checkpoint and the data
# (the sample printed earlier has x of width 122 and y of width 50).
csi = CSI(
    feature_dim=122,
    feature_embedding_dim=100,
    lstm_hidden_dim=100,
    article_embedding_dim=50,
    user_dim=50,
    user_embedding_dim=50,
)
# The checkpoint is only written by this notebook's final cell, so it does not exist on a fresh checkout.
# Resume from it when present instead of crashing; otherwise start from the random initialisation.
# NOTE(review): because the last cell overwrites this file, every full re-run continues training from
# the previous run's weights.
if CHECKPOINT_PATH.exists():
    # Load onto the CPU first so a checkpoint saved on a GPU machine also loads on CPU-only hosts.
    csi.load_state_dict(torch.load(CHECKPOINT_PATH, map_location="cpu"))
else:
    print(f"No checkpoint found at {CHECKPOINT_PATH}; starting from randomly initialised weights.")
csi.to(device)

CSI(
  (capture): Capture(
    (input_embedding): Sequential(
      (0): Linear(in_features=122, out_features=100, bias=True)
      (1): Tanh()
    )
    (lstm): LSTM(100, 100, batch_first=True)
    (output_embedding): Sequential(
      (0): Linear(in_features=100, out_features=50, bias=True)
      (1): Tanh()
    )
  )
  (score): Score(
    (user_embedding): Sequential(
      (0): Linear(in_features=50, out_features=50, bias=True)
      (1): Tanh()
    )
    (user_score): Sequential(
      (0): Linear(in_features=50, out_features=1, bias=True)
      (1): Sigmoid()
    )
  )
  (integrate): Integrate(
    (net): Sequential(
      (0): Linear(in_features=51, out_features=1, bias=True)
    )
  )
)

In [31]:
def grouper(iterable, n):
    """Lazily split ``iterable`` into consecutive lists of ``n`` items; the last list may be shorter.

    Args:
        iterable: any iterable; it is consumed exactly once.
        n: chunk size, expected to be a positive integer.

    Yields:
        Lists of up to ``n`` consecutive items, in their original order.
    """
    source = iter(iterable)
    # The for-loop pulls each chunk's first item; islice then drains up to n - 1 more from the same iterator.
    for head in source:
        yield [head, *islice(source, n - 1)]
# Positions of every sample; the training cell shuffles this list in place each epoch.
index_list = [*range(len(fake_news_data))]

In [52]:
batch_size = 48
n_epochs = 10
# Ceiling division gives the number of batches grouper() yields per epoch, including a partial final one.
# The previous `int(len / batch_size + 1)` over-counted by one whenever the dataset size was an exact
# multiple of batch_size, which also deflated the per-epoch average loss printed below.
n_batches = -(-len(fake_news_data) // batch_size)
# Seed the shuffle so re-running this cell reproduces the same batch order.
random.seed(0)
csi.train()
optimizer = torch.optim.Adam(csi.parameters(), lr=0.003)
# CSI outputs raw logits (Integrate has no sigmoid), so use the fused, numerically stable sigmoid + BCE loss.
criterion = nn.BCEWithLogitsLoss()
for i_epoch in range(n_epochs):
    random.shuffle(index_list)
    epoch_loss = 0.0
    for i_batch, indices in enumerate(grouper(index_list, batch_size)):
        labels = []
        outputs = []
        # Samples are fed one at a time (x/y are not padded, presumably because lengths vary; TODO confirm).
        # Their logits are then stacked so one loss/optimizer step covers the whole batch.
        for idx in indices:
            x, y, label = fake_news_data[idx]
            x = x.unsqueeze(0).to(device)
            y = y.unsqueeze(0).to(device)
            label = torch.tensor([[int(label)]], dtype=torch.float32, device=device)
            output = csi(x, y)
            labels.append(label)
            outputs.append(output)
        labels = torch.vstack(labels)
        outputs = torch.vstack(outputs)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        if i_batch % 25 == 0:
            print(f"Batch {i_batch + 1}/{n_batches}: {loss.item()}")
    # Mean of the per-batch losses (the smaller final batch is weighted like the others);
    # max() guards against an empty dataset.
    print(f"Epoch {i_epoch + 1}/{n_epochs}: {epoch_loss / max(n_batches, 1)}")

Batch 1/21: 0.2018938511610031
Epoch 1/10: 0.18743291958456948
Batch 1/21: 0.08315072953701019
Epoch 2/10: 0.1101628149016982
Batch 1/21: 0.03916238248348236
Epoch 3/10: 0.08473625353404454
Batch 1/21: 0.04589641094207764
Epoch 4/10: 0.07845223349119936
Batch 1/21: 0.015788085758686066
Epoch 5/10: 0.07068554400688126
Batch 1/21: 0.029216861352324486
Epoch 6/10: 0.05664382972532794
Batch 1/21: 0.05986647307872772
Epoch 7/10: 0.0640338005586749
Batch 1/21: 0.07190493494272232
Epoch 8/10: 0.056665426741043724
Batch 1/21: 0.005099854432046413
Epoch 9/10: 0.04383390307581673
Batch 1/21: 0.03523223102092743
Epoch 10/10: 0.02362060153834699


In [46]:
# Switch to inference mode. CSI has no dropout/batch-norm today, but eval() keeps this correct if such
# layers are added. The training cell calls csi.train() itself, so re-running it is unaffected.
csi.eval()
# NOTE(review): this scores the same samples the model was trained on, so it measures training fit,
# not generalisation; a held-out split is needed for a real accuracy estimate.
result = []
with torch.no_grad():
    for i in range(len(fake_news_data)):
        x, y, label = fake_news_data[i]
        x = x.unsqueeze(0).to(device)
        y = y.unsqueeze(0).to(device)
        label = int(label)
        output = csi(x, y)
        # The model returns a logit; sigmoid turns it into the predicted probability of label 1.
        result.append((label, torch.sigmoid(output).item()))

In [48]:
# round() maps a probability p in [0, 1] to 1 only when p > 0.5 (round(0.5) == 0 under
# round-half-to-even), so this counts predictions that agree with their label.
count = sum(round(p_label) == label for label, p_label in result)

In [50]:
# Accuracy: fraction of samples whose rounded prediction matched the label.
# NOTE(review): these are the same samples the model was trained on, so this is training accuracy.
count / len(fake_news_data)

0.9375

In [51]:
# Persist the trained weights to the same path the model-construction cell loads from,
# so the next run resumes from these weights.
torch.save(obj=csi.state_dict(), f="models/csi.pt")