In [1]:
from itertools import chain, islice
import utils
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader
import random
from sklearn.model_selection import train_test_split

In [2]:
class FakeNewsDataset(Dataset):
    def __init__(self, root):
        super(FakeNewsDataset, self).__init__()
        self.data = np.load(root, allow_pickle=True)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        if idx < 0:
            idx = len(self) + idx
        return torch.FloatTensor(self.data[idx][2]), torch.FloatTensor(self.data[idx][3]), self.data[idx][1]

name = "data/processed/data_bin1h_cut2160_100w_20xu_50yu.npy"
fake_news_data = FakeNewsDataset(name)

In [3]:
x, y, label = fake_news_data[2]
print(x.shape, y.shape, label)


torch.Size([128, 122]) torch.Size([161, 50]) 0


In [4]:
class Capture(nn.Module):
    def __init__(self, feature_dim, feature_embedding_dim, lstm_hidden_dim, article_embedding_dim):
        super(Capture, self).__init__()
        self.input_embedding = nn.Sequential(
            nn.Linear(feature_dim, feature_embedding_dim),
            nn.Tanh(),
            nn.Dropout(0.2)
        )
        self.lstm = nn.LSTM(feature_embedding_dim, lstm_hidden_dim, batch_first=True)
        self.output_embedding = nn.Sequential(
            nn.Linear(lstm_hidden_dim, article_embedding_dim),
            nn.Tanh(),
            nn.Dropout(0.2)
        )

    def forward(self, x):
        x = self.input_embedding(x)
        _, (x, _) = self.lstm(x)
        x = self.output_embedding(x[-1])
        return x

class Score(nn.Module):
    def __init__(self, user_dim, user_embedding_dim):
        super(Score, self).__init__()
        self.user_embedding = nn.Sequential(
            nn.Linear(user_dim, user_embedding_dim), # add regularization
            nn.Tanh()
        )
        self.user_score = nn.Sequential(
            nn.Linear(user_embedding_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, y):
        y_hat = self.user_embedding(y)
        s = self.user_score(y_hat)
        p = s.mean(dim=1)
        return p, s, y_hat

class Integrate(nn.Module):
    def __init__(self, capture_dim):
        super(Integrate, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(capture_dim + 1, 1),
            #nn.Sigmoid() #  disabled, part of loss function
        )

    def forward(self, x, y):
        v = torch.cat((x, y), 1)
        v = self.net(v)
        return v


class CSI(nn.Module):
    def __init__(self, feature_dim, feature_embedding_dim, lstm_hidden_dim, article_embedding_dim,
                 user_dim, user_embedding_dim):
        super(CSI, self).__init__()
        self.capture = Capture(feature_dim, feature_embedding_dim, lstm_hidden_dim, article_embedding_dim)
        self.score = Score(user_dim, user_embedding_dim)
        self.integrate = Integrate(article_embedding_dim)

    def forward(self, x, y):
        x = self.capture(x)
        y, _, _ = self.score(y)
        l = self.integrate(x, y)
        return l

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"
csi = CSI(122, 100, 100, 100, 50, 100)
#csi.load_state_dict(torch.load(f"models/csi.pt", map_location=torch.device('cpu')))
csi.to(device)

CSI(
  (capture): Capture(
    (input_embedding): Sequential(
      (0): Linear(in_features=122, out_features=100, bias=True)
      (1): Tanh()
      (2): Dropout(p=0.2, inplace=False)
    )
    (lstm): LSTM(100, 100, batch_first=True)
    (output_embedding): Sequential(
      (0): Linear(in_features=100, out_features=100, bias=True)
      (1): Tanh()
      (2): Dropout(p=0.2, inplace=False)
    )
  )
  (score): Score(
    (user_embedding): Sequential(
      (0): Linear(in_features=50, out_features=100, bias=True)
      (1): Tanh()
    )
    (user_score): Sequential(
      (0): Linear(in_features=100, out_features=1, bias=True)
      (1): Sigmoid()
    )
  )
  (integrate): Integrate(
    (net): Sequential(
      (0): Linear(in_features=101, out_features=1, bias=True)
    )
  )
)

In [8]:
pytorch_total_params = sum(p.numel() for p in csi.parameters() if p.requires_grad)
pytorch_total_params

108503

In [6]:
def grouper(iterable, n):
    iterable = iter(iterable)
    try:
        while True:
            yield list(chain((next(iterable),), islice(iterable, n - 1)))
    except StopIteration:
        return
index_list = list(range(len(fake_news_data)))
x_train, x_test = train_test_split(index_list, test_size=0.2, random_state=420)

In [7]:
csi.train()
batch_size = 48
n_epochs = 30
n_batches = int(len(fake_news_data) / batch_size + 1)
csi.train()
optimizer = torch.optim.Adam(csi.parameters(), lr=0.001)
criterion = nn.BCEWithLogitsLoss()
for i_epoch in range(n_epochs):
    random.shuffle(x_train)
    epoch_loss = 0
    for i_batch, indices in enumerate(grouper(x_train, batch_size)):
        labels = []
        outputs = []
        for idx in indices:
            x, y, label = fake_news_data[idx]
            x = x.unsqueeze(0).to(device)
            y = y.unsqueeze(0).to(device)
            label = torch.FloatTensor([[int(label)]]).to(device)
            output = csi(x, y)
            labels.append(label)
            outputs.append(output)
        labels = torch.vstack(labels)
        outputs = torch.vstack(outputs)
        # print(outputs, labels)
        loss = criterion(outputs, labels) + 0.001 / 2 * torch.norm(next(csi.score.user_embedding[0].parameters()))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        if i_batch % 25 == 0:
            print(f"Batch {i_batch + 1}/{n_batches}: {loss.item()}")
    if i_epoch % 1 == 0:
        print(f"Epoch {i_epoch + 1}/{n_epochs}: {epoch_loss / n_batches}")

Batch 1/21: 0.6965411901473999
Epoch 1/30: 0.5370955921354748
Batch 1/21: 0.6281033158302307
Epoch 2/30: 0.4885947718506768
Batch 1/21: 0.6639443635940552
Epoch 3/30: 0.4698239124956585
Batch 1/21: 0.6195600032806396
Epoch 4/30: 0.4541335105895996
Batch 1/21: 0.7003993391990662
Epoch 5/30: 0.44634845285188585
Batch 1/21: 0.5354519486427307
Epoch 6/30: 0.4433618372394925
Batch 1/21: 0.629942774772644
Epoch 7/30: 0.4199687199933188
Batch 1/21: 0.4480827748775482
Epoch 8/30: 0.39924075206120807
Batch 1/21: 0.44649407267570496
Epoch 9/30: 0.3992875459648314
Batch 1/21: 0.4752514064311981
Epoch 10/30: 0.38192891648837496
Batch 1/21: 0.4371241629123688
Epoch 11/30: 0.3613453053292774
Batch 1/21: 0.48646754026412964
Epoch 12/30: 0.38374326484543936
Batch 1/21: 0.43573570251464844
Epoch 13/30: 0.3695550887357621
Batch 1/21: 0.37570542097091675
Epoch 14/30: 0.3436501891840072
Batch 1/21: 0.5979343056678772
Epoch 15/30: 0.3571427024546124
Batch 1/21: 0.5215819478034973
Epoch 16/30: 0.33353018760

In [8]:
csi.eval()
result = []
for i in x_test:
    x, y, label = fake_news_data[i]
    x = x.unsqueeze(0).to(device)
    y = y.unsqueeze(0).to(device)
    label = int(label)
    with torch.no_grad():
        output = csi(x, y)
    result.append((label, torch.sigmoid(output).item()))
count = 0
for label, p_label in result:
    p_label = round(p_label)
    if p_label == label:
        count += 1
count / len(x_test)

0.7236180904522613

In [14]:
torch.save(csi.state_dict(), f"models/csi.pt")