In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '7'

In [2]:
%config Completer.use_jedi = False

In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
import pickle
import numpy as np
from sklearn import metrics
from sklearn.metrics import f1_score

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as func

In [6]:
with open('data/bibtex/train.pickle', "rb") as f:
    temp = pickle.load(f)
    data_x = np.array([instance['feats'] for instance in temp])
    data_y = np.array([instance['types'] for instance in temp])

In [7]:
with open('data/bibtex/test.pickle', "rb") as f:
    temp = pickle.load(f)
    test_x = np.array([instance['feats'] for instance in temp])
    test_y = np.array([instance['types'] for instance in temp])


In [8]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [9]:
data_x = torch.FloatTensor(data_x).to(device)
data_y = torch.FloatTensor(data_y).to(device)
test_x = torch.FloatTensor(test_x).to(device)
test_y = torch.FloatTensor(test_y).to(device)

# model

In [10]:
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1836, 150)
        self.layer2 = nn.Linear(150, 150)
        self.out_l = nn.Linear(150, 159)
        
    
    def forward(self, x, only_feature_extraction=False):
        out = func.relu(self.layer1(x))
        out = func.relu(self.layer2(out))
        if not only_feature_extraction:
            out = self.out_l(out)
            out = torch.sigmoid(out)
        return out

# f1 and mAP

In [11]:
Threshold = [0.05, 0.10, 0.15, 0.2, 0.25, 0.30, 0.35, 0.4, 0.45, 0.5, 0.55, 0.60, 0.65, 0.70, 0.75]
def f1_map(test_y, pred_test):
    best_f1 = 0
    for t in Threshold:
        pred = pred_test > t
        f1 = f1_score(test_y.data.cpu().numpy(), pred.data.cpu().numpy(), average='samples')
        if f1 > best_f1:
            best_f1 = f1
    mAP = np.mean(metrics.average_precision_score(
        test_y.data.cpu().numpy(), pred_test.data.cpu().numpy(), average=None
    ))
    
    return best_f1, mAP

# pre-training feature net

In [14]:
model = MLP().to(device)
# criterion = nn.BCEWithLogitsLoss()
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=0)

In [15]:

for epoch in range(20):
    inds = np.random.permutation(list(range(len(data_x))))
    for i in range(0, 4880, 80):
        l = inds[i:i+80]
        data = data_x[l]
        label = data_y[l]
        logits = model(data)
        
        loss = criterion(logits, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print('epoch', epoch)
    with torch.no_grad():
        logits = model(test_x)
        pred_test = torch.sigmoid(logits)
    
    best_f1, mAP = f1_map(test_y, pred_test)
    print(best_f1, mAP)
    print()

epoch 0
0.0414399394039205 0.020842633841436933

epoch 1
0.05286597873078986 0.02799239158708507

epoch 2
0.07068036227081953 0.052707762726149784

epoch 3
0.1673552778125343 0.09121999194169501

epoch 4
0.21151587312621106 0.118025021028019

epoch 5
0.23361619447249293 0.14419005436022742

epoch 6
0.26242346012723744 0.1692342302364163

epoch 7
0.29072475836716877 0.1967457966908786

epoch 8
0.3097807762073427 0.21543534492927094

epoch 9
0.3349576502623345 0.23439222164592144

epoch 10
0.34892444884855606 0.25013439869614496

epoch 11
0.35706535406406764 0.2622146933649457

epoch 12
0.3628598005731947 0.27239799682330157

epoch 13
0.37195683954212655 0.27919367332978173

epoch 14
0.37369825977039467 0.28451671429981823

epoch 15
0.37713901792080695 0.2898621354885521

epoch 16
0.3822975290311186 0.2924142790193828

epoch 17
0.3834344100502664 0.293635805789803

epoch 18
0.39442181802600956 0.2970049204138833

epoch 19
0.38807548450631135 0.2973065234538496



In [30]:
with torch.no_grad():
    logits = model(test_x)
    pred_test = torch.sigmoid(logits)

best_f1, mAP = f1_map(test_y, pred_test)
print(best_f1, mAP)

0.38807548450631135 0.2973065234538496


In [17]:
class EnergyNetwork(nn.Module):
    def __init__(self, weights_last_layer_mlp=150, feature_dim=150, label_dim=159,
                 num_pairwise=16, non_linearity=nn.Softplus()):
        super().__init__()

        self.non_linearity = non_linearity

        self.B = torch.nn.Parameter(torch.transpose(-weights_last_layer_mlp, 0, 1))

        # Label energy terms, C1/c2  in equation 5 of SPEN paper
        self.C1 = torch.nn.Parameter(torch.empty(label_dim, num_pairwise))
        torch.nn.init.normal_(self.C1, mean=0, std=np.sqrt(2.0 / label_dim))

        self.c2 = torch.nn.Parameter(torch.empty(num_pairwise, 1))
        torch.nn.init.normal_(self.c2, mean=0, std=np.sqrt(2.0 / num_pairwise))

    def forward(self, x, y):
        # Local energy
        e_local = torch.mm(x, self.B)
        # element-wise product
        e_local = torch.mul(y, e_local)
        e_local = torch.sum(e_local, dim=1)
        e_local = e_local.view(e_local.size()[0], 1)

        # Label energy
        e_label = self.non_linearity(torch.mm(y, self.C1))
        e_label = torch.mm(e_label, self.c2)
        e_global = torch.add(e_label, e_local)

        return e_global

In [18]:
import torch.optim as optim

class SPEN():
    def __init__(self, feature_net, energy_net, inf_net, n_steps_inf=1, input_dim=1836, label_dim=159, num_pairwise=16,
                 learning_rate=1e-5, weight_decay=1e-4, non_linearity=nn.Softplus()):
        self.feature_extractor = feature_net
        self.feature_extractor.eval()
        self.energy_net = energy_net
        self.inf_net = inf_net
        self.loss_fn = torch.nn.MSELoss(reduction='sum')
        
    def _compute_energy(self, inputs, targets):
        f_x = self.feature_extractor(inputs, only_feature_extraction=True)
        
        # Energy ground truth
        gt_energy = self.energy_net(f_x, targets)
        
        # Cost-augmented inference network
        pred_labels = self.inf_net(inputs)
        pred_energy = self.energy_net(f_x, pred_labels)
        return pred_labels, pred_energy, gt_energy
    
    def compute_loss(self, inputs, targets):
        
        pred_labels, pred_energy, gt_energy = self._compute_energy(inputs, targets)
        # Max-margin Loss
        pre_loss = self.loss_fn(pred_labels, targets) - pred_energy + gt_energy
        eneryg_loss = torch.max(pre_loss, torch.zeros_like(pre_loss))
        eneryg_loss = torch.mean(eneryg_loss)        
        inf_net_loss = eneryg_loss
        e_net_loss = eneryg_loss + 0.01 * sum(p.pow(2.0).sum() for p in self.energy_net.parameters())
        return pred_labels, e_net_loss, inf_net_loss

    def pred(self, x):
        with torch.no_grad():
            y_pred = self.inf_net(x)
        return y_pred
    
    def inference(self, x, training=False, n_steps=1):
        
        sd = self.inf_net.state_dict()
        inf_net2 = MLP()
        inf_net2.load_state_dict(sd)
        self.inf_net.eval()
        optimizer = optim.SGD(self.inf_net.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
        with torch.no_grad():
            y_pred = self.inf_net(x)
        
        self.inf_net.train()
        
        return y_pred

In [31]:
import torch.optim as optim

eps = 1e-8
class SPEN():
    def __init__(self, feature_net, energy_net, inf_net, n_steps_inf=1, input_dim=1836, label_dim=159, num_pairwise=16,
                 learning_rate=1e-5, weight_decay=1e-4, non_linearity=nn.Softplus()):
        self.feature_extractor = feature_net
        self.feature_extractor.eval()
        self.energy_net = energy_net
        self.inf_net = inf_net
        self.loss_fn = torch.nn.MSELoss(reduction='sum')
        
    def _compute_energy(self, inputs, targets):
        f_x = self.feature_extractor(inputs, only_feature_extraction=True)
        
        # Energy ground truth
        gt_energy = self.energy_net(f_x, targets)
        
        # Cost-augmented inference network
        pred_labels = self.inf_net(inputs)
        pred_energy = self.energy_net(f_x, pred_labels)
        return pred_labels, pred_energy, gt_energy
    
    def compute_loss(self, inputs, targets):
        
        pred_labels, pred_energy, gt_energy = self._compute_energy(inputs, targets)
        # Max-margin Loss
        pre_loss = self.loss_fn(pred_labels, targets) - pred_energy + gt_energy
        eneryg_loss = torch.max(pre_loss, torch.zeros_like(pre_loss))
        eneryg_loss = torch.mean(eneryg_loss)
        
        pred_y = pred_labels

        entropy_loss = - torch.mean(pred_y * torch.log(pred_y) + (1 - pred_y) * torch.log(1 - pred_y))
        
        inf_net_loss = -eneryg_loss + 0.01 * sum(p.pow(2.0).sum() for p in self.inf_net.parameters()) \
                       + 10 * sum((x - y).pow(2.0).sum() for x, y in zip(self.inf_net.state_dict().values(), self.feature_extractor.state_dict().values())) \
                       + entropy_loss
        e_net_loss = eneryg_loss + 0.01 * sum(p.pow(2.0).sum() for p in self.energy_net.parameters())
        return pred_labels, e_net_loss, inf_net_loss

    def pred(self, x):
        with torch.no_grad():
            y_pred = self.inf_net(x)
        return y_pred
    
    def inference(self, x, training=False, n_steps=1):
        
        sd = self.inf_net.state_dict()
        inf_net2 = MLP()
        inf_net2.load_state_dict(sd)
        self.inf_net.eval()
        optimizer = optim.SGD(self.inf_net.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
        with torch.no_grad():
            y_pred = self.inf_net(x)
        
        self.inf_net.train()
        
        return y_pred

In [32]:
inf_net = MLP().to(device)
energy_net = EnergyNetwork(model.out_l.weight).to(device)

optim_inf = torch.optim.Adam(inf_net.parameters(), lr=3e-4, weight_decay=0)
optim_energy = torch.optim.Adam(energy_net.parameters(), lr=1e-5, weight_decay=0)

spen = SPEN(model, energy_net, inf_net, )

In [33]:

for epoch in range(100):
    inds = np.random.permutation(list(range(len(data_x))))
    for i in range(0, 4880, 80):
        l = inds[i:i+80]
        data = data_x[l]
        label = data_y[l]
        
        optim_inf.zero_grad()
        preds, e_loss, inf_loss = spen.compute_loss(data, label)
        # preds, inf_loss = spen.compute_loss(data, label)
        
        inf_loss.backward()
        optim_inf.step()
        
        optim_energy.zero_grad()
        preds, e_loss, inf_loss = spen.compute_loss(data, label)
        
        e_loss.backward()
        optim_energy.step()
    print(epoch)

    pred_test = spen.pred(test_x)
    best_f1, mAP = f1_map(test_y, pred_test)
    print(best_f1, mAP)
    
    # pred_test = spen.inference(test_x)
    # best_f1, mAP = f1_map(test_y, pred_test)
    # print(best_f1, mAP)
    print()

0


ValueError: Input contains NaN, infinity or a value too large for dtype('float32').