In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''

%config Completer.use_jedi = False

%load_ext autoreload
%autoreload 2

In [2]:
import pickle
import numpy as np
from sklearn import metrics
from sklearn.metrics import f1_score

import torch
import torch.nn as nn
import torch.nn.functional as func

# data

In [3]:
with open('data/bibtex/train.pickle', "rb") as f:
    temp = pickle.load(f)
    data_x = np.array([instance['feats'] for instance in temp])
    data_y = np.array([instance['types'] for instance in temp])


with open('data/bibtex/test.pickle', "rb") as f:
    temp = pickle.load(f)
    test_x = np.array([instance['feats'] for instance in temp])
    test_y = np.array([instance['types'] for instance in temp])


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

data_x = torch.FloatTensor(data_x).to(device)
data_y = torch.FloatTensor(data_y).to(device)
test_x = torch.FloatTensor(test_x).to(device)
test_y = torch.FloatTensor(test_y).to(device)


  return torch._C._cuda_getDeviceCount() > 0


In [4]:
def f1_map(y, pred, threshold=None):
    if threshold is None:
        threshold = [0.05, 0.10, 0.15, 0.2, 0.25, 0.30, 0.35, 0.4, 0.45, 0.5, 0.55, 0.60, 0.65, 0.70, 0.75]
    else:
        threshold = [0.5]
    best_f1 = 0
    for t in threshold:
        local_pred = pred > t
        local_f1 = f1_score(y.data.cpu().numpy(), local_pred.data.cpu().numpy(), average='samples')
        if local_f1 > best_f1:
            best_f1 = local_f1
    precision = np.mean(metrics.average_precision_score(
        y.data.cpu().numpy(), pred.data.cpu().numpy(), average=None
    ))

    return best_f1, precision

# Arch s

In [5]:

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1836, 150)
        self.layer2 = nn.Linear(150, 150)      
    
    def forward(self, x):
        out = func.relu(self.layer1(x))
        out = func.relu(self.layer2(out))
        return out
    

In [91]:
class EnergyNet(nn.Module):
    def __init__(self, weights_last_layer_mlp=150, feature_dim=150, label_dim=159,
                 num_pairwise=16, non_linearity=nn.Softplus()):
        super().__init__()

        self.non_linearity = non_linearity

        self.linear_wt = nn.Linear(150, label_dim, bias=False) 

        # Label energy terms, C1/c2  in equation 5 of SPEN paper
        self.C1 = nn.Linear(label_dim, num_pairwise)

        self.c2 = nn.Linear(num_pairwise, 1, bias=False)

    def forward(self, x, y):
        # Local energy
        negative_logits = self.linear_wt(x)
        feat_probs = torch.sigmoid(-1 * negative_logits)
        
        # element-wise product
        e_local = torch.mul(negative_logits, y)
        e_local = torch.sum(e_local, dim=1)

        # Label energy
        e_label = self.non_linearity(self.C1(y))
        e_label = self.c2(e_label)
        assert e_label.view(-1).shape[0] == e_label.shape[0]
        assert e_label.view(-1).shape[0] == e_local.shape[0]
        e_global = torch.add(e_label.view(-1), e_local.view(-1))

        return e_global, feat_probs

In [92]:

feat_net = MLP().to(device)
energy_net = EnergyNet()
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(list(feat_net.parameters()) + list(energy_net.parameters()),
                                  lr=1e-3, weight_decay=0)

for epoch in range(20):
    inds = np.random.permutation(list(range(len(data_x))))
    for i in range(0, 4880, 80):
        l = inds[i:i+80]
        data = data_x[l]
        label = data_y[l]
        feat = feat_net(data)
        _, logits = energy_net(feat, label)
        loss = criterion(logits, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print('epoch', epoch)
    with torch.no_grad():
        feat = feat_net(test_x)
        _, pred_test = energy_net(feat, test_y)
    
    best_f1, mAP = f1_map(test_y, pred_test)
    print(best_f1, mAP)
    print()

#%%

with torch.no_grad():
    feat = feat_net(test_x)
    _, pred_test = energy_net(feat, test_y)

best_f1, mAP = f1_map(test_y, pred_test)
print(best_f1, mAP)


epoch 0
0.07702077652393424 0.02075228933809243

epoch 1
0.10425859853892766 0.028361680183831842

epoch 2
0.15637174975887747 0.053388570189344654

epoch 3
0.20183496788640407 0.0930577746249585

epoch 4
0.21475814643095767 0.11609853677749406

epoch 5
0.2500084685828593 0.13940866948476346

epoch 6
0.27325017872195206 0.16497941486537399

epoch 7
0.28653947035232186 0.18902257228666106

epoch 8
0.30766674531205895 0.21249409255524154

epoch 9
0.32588053567851133 0.2284908196945189

epoch 10
0.34165172914429254 0.2437831384621451

epoch 11
0.34755725569151297 0.25463065368460297

epoch 12
0.3622423219253687 0.2649500044879656

epoch 13
0.3663316263977641 0.2753344939804642

epoch 14
0.3681037263254425 0.28134065656099055

epoch 15
0.3816848401753928 0.28649331763652985

epoch 16
0.38664342640487787 0.2921747442302924

epoch 17
0.39041074616526494 0.29210938399327213

epoch 18
0.39482345197493696 0.29871975362279063

epoch 19
0.39402649532207923 0.29800158196302456

0.39402649532207923

In [94]:
with torch.no_grad():
    feat = feat_net(test_x)
    print(feat)
    _, pred_test = energy_net(feat, test_y)

f1, mAP = f1_map(test_y, pred_test)
print(f1, mAP)

tensor([[ 0.0000,  4.0921, 18.9302,  ...,  0.0000,  4.1326,  2.4208],
        [ 0.0000, 10.5515, 12.6726,  ...,  0.0000,  6.6926, 11.5328],
        [ 0.0000,  1.5048,  7.3782,  ...,  0.0000,  2.1468,  1.1797],
        ...,
        [ 0.0000,  3.4105,  6.5482,  ...,  0.0000,  0.0000,  2.2689],
        [ 0.0000,  8.5150, 10.4562,  ...,  0.0000,  6.7461,  8.7696],
        [ 0.0000,  6.3319, 16.1583,  ...,  0.0000,  7.4409,  8.9332]])
0.39402649532207923 0.29800158196302456


In [96]:
class InfNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1836, 150)
        self.layer2 = nn.Linear(150, 150)
        self.layer3 = nn.Linear(150, 159, bias=False) 
    
    def forward(self, x):
        out = func.relu(self.layer1(x))
        out = func.relu(self.layer2(out))
        out = self.layer3(out)
        return torch.sigmoid(out)
    
inf_net = InfNet().to(device)
inf_net.layer1.weight.data = feat_net.layer1.weight.data.clone()
inf_net.layer1.bias.data = feat_net.layer1.bias.data.clone()
inf_net.layer2.weight.data = feat_net.layer2.weight.data.clone()
inf_net.layer2.bias.data = feat_net.layer2.bias.data.clone()
inf_net.layer3.weight.data = -energy_net.linear_wt.weight.data.clone()

In [97]:
with torch.no_grad():
    pred_test = inf_net(test_x)

f1, mAP = f1_map(test_y, pred_test)
print(f1, mAP)

0.39402649532207923 0.29800158196302456


# train cost inf

In [111]:
class SPEN():
    def __init__(self, feature_net, energy_net, inf_net, n_steps_inf=1, input_dim=1836, label_dim=159):
        self.feature_extractor = feature_net
        self.feature_extractor.eval()
        self.energy_net = energy_net
        self.inf_net = inf_net
        
        self.phi0 = InfNet().to(device)
        self.phi0.load_state_dict(inf_net.state_dict())
        
    def _compute_energy(self, inputs, targets):
        f_x = self.feature_extractor(inputs)
        
        # Energy ground truth
        gt_energy, _ = self.energy_net(f_x, targets)
        
        # Cost-augmented inference network
        pred_probs = self.inf_net(inputs)
        
        pred_energy, _ = self.energy_net(f_x, pred_probs)
        
        return pred_probs, pred_energy, gt_energy
    
    def compute_loss(self, inputs, targets):
        
        pred_probs, pred_energy, gt_energy = self._compute_energy(inputs, targets)
        # Max-margin Loss
        delta = torch.sum((pred_probs - targets)**2, dim=1)
        pre_loss_real = delta - pred_energy + gt_energy
        
        energy_loss = torch.relu(pre_loss_real)
        pre_loss_real = torch.mean(pre_loss_real)
        energy_loss = torch.mean(energy_loss)

        entropy_loss = nn.BCELoss()(pred_probs, pred_probs.detach())
        
        reg_losses_phi = 0.5 * sum(p.pow(2.0).sum() for p in self.inf_net.parameters())
        pretrain_bias = sum((x - y).pow(2.0).sum() for x, y in zip(self.inf_net.state_dict().values(), self.phi0.state_dict().values()))
        reg_losses_theta = 0.5 * sum(p.pow(2.0).sum() for p in self.energy_net.parameters())
        
        inf_net_loss = energy_loss \
                       - 0.001 * reg_losses_phi\
                       - 1 * pretrain_bias \
                       - 1 * entropy_loss
        inf_net_loss = -inf_net_loss
        
        e_net_loss = energy_loss + 0.001 * reg_losses_theta
        
        summaries = {
            'base_objective': energy_loss.item(),
            'base_obj_real': pre_loss_real.item(),
            'energy_inf_net': pred_energy.mean().item(),
            'energy_ground_truth': gt_energy.mean().item(),
            'reg_losses_theta': reg_losses_theta.item(),
            'reg_losses_phi': reg_losses_phi.item(),
            'reg_losses_entropy': entropy_loss.item(),
            'pretrain_bias': pretrain_bias.item()
        }
        
        return pred_probs, e_net_loss, inf_net_loss, summaries

    def pred(self, x):
        with torch.no_grad():
            y_pred = self.inf_net(x)
        return y_pred
    
    def inference(self, x, training=False, n_steps=1):
        
        sd = self.inf_net.state_dict()
        inf_net2 = MLP()
        inf_net2.load_state_dict(sd)
        self.inf_net.eval()
        optimizer = optim.SGD(self.inf_net.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
        with torch.no_grad():
            y_pred = self.inf_net(x)
        
        self.inf_net.train()
        
        return y_pred

In [99]:


optim_inf = torch.optim.Adam(inf_net.parameters(), lr=1e-3, weight_decay=0)
optim_energy = torch.optim.Adam(list(energy_net.C1.parameters()) + list(energy_net.c2.parameters()), 
                                lr=1e-3, weight_decay=0)

spen = SPEN(feat_net, energy_net, inf_net)
pred_test = spen.pred(test_x)
best_f1, mAP = f1_map(test_y, pred_test)
print(best_f1, mAP)


for epoch in range(100):
    inds = np.random.permutation(list(range(len(data_x))))
    for i in range(0, 4880, 32):
        l = inds[i:i+32]
        data = data_x[l]
        label = data_y[l]
        
        optim_inf.zero_grad()
        preds, e_loss, inf_loss = spen.compute_loss(data, label)
        # preds, inf_loss = spen.compute_loss(data, label)
        
        inf_loss.backward()
        optim_inf.step()
        
        optim_energy.zero_grad()
        preds, e_loss, inf_loss = spen.compute_loss(data, label)
        
        e_loss.backward()
        optim_energy.step()
    print(epoch)

    pred_test = spen.pred(test_x)
    best_f1, mAP = f1_map(test_y, pred_test)
    print(best_f1, mAP)
    
    # pred_test = spen.inference(test_x)
    # best_f1, mAP = f1_map(test_y, pred_test)
    # print(best_f1, mAP)
    with torch.no_grad():
        feat = feat_net(test_x)
        _, pred_test = energy_net(feat, test_y)

    f1, mAP = f1_map(test_y, pred_test)
    print('feature net', f1, mAP)
    print()

0.39402649532207923 0.29800158196302456
0
0.20420840039130295 0.0943624346668638
feature net 0.39402649532207923 0.29800158196302456

1
0.1622137019154912 0.07552265274980116
feature net 0.39402649532207923 0.29800158196302456

2
0.182375587743381 0.06693666780955289
feature net 0.39402649532207923 0.29800158196302456



KeyboardInterrupt: 

In [103]:
import tensorflow as tf
def tf2torch(checkpoint, feat_net, inf_net, energy_net):
    
    tf_path = os.path.abspath(checkpoint)
    init_vars = tf.train.list_variables(tf_path)

    tf_vars = []
    for name, shape in init_vars:
        # print("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        tf_vars.append((name, array.squeeze()))
    
    feat_i = 12
    energy_i = 8
    inf_i = 18
    feat_net.layer1.bias.data = torch.from_numpy(tf_vars[feat_i][1].T)
    feat_net.layer1.weight.data = torch.from_numpy(tf_vars[feat_i + 1][1].T)
    feat_net.layer2.bias.data = torch.from_numpy(tf_vars[feat_i + 2][1].T)
    feat_net.layer2.weight.data = torch.from_numpy(tf_vars[feat_i + 3][1].T)
    
    energy_net.C1.bias.data = torch.from_numpy(tf_vars[energy_i][1].T)
    energy_net.C1.weight.data = torch.from_numpy(tf_vars[energy_i + 1][1].T)
    energy_net.c2.weight.data = torch.from_numpy(tf_vars[energy_i + 2][1].T)
    energy_net.linear_wt.weight.data = torch.from_numpy(tf_vars[energy_i + 3][1].T)


    inf_net.layer1.bias.data = torch.from_numpy(tf_vars[inf_i][1].T)
    inf_net.layer1.weight.data = torch.from_numpy(tf_vars[inf_i + 1][1].T)
    inf_net.layer2.bias.data = torch.from_numpy(tf_vars[inf_i + 2][1].T)
    inf_net.layer2.weight.data = torch.from_numpy(tf_vars[inf_i + 3][1].T)
    inf_net.layer3.weight.data = torch.from_numpy(tf_vars[inf_i + 4][1].T)
    
    return feat_net, inf_net, energy_net

feat_net1 = MLP()
inf_net1 = InfNet()
energy_net1 = EnergyNet()
feat_net1, inf_net1, energy_net1 = tf2torch('./copied.ckpt', feat_net1, inf_net1, energy_net1)
with torch.no_grad():
    pred_test = inf_net1(test_x)

f1, mAP = f1_map(test_y, pred_test)
print(f1, mAP)


with torch.no_grad():
    feat = feat_net1(test_x)
    _, pred_test = energy_net1(feat, test_y)

f1, mAP = f1_map(test_y, pred_test)
print(f1, mAP)

0.4105516296229916 0.3304564624708914
0.4105516296229916 0.3304564624708914


In [116]:
feat_net = MLP()
inf_net = InfNet()
energy_net = EnergyNet()
feat_net, inf_net, energy_net = tf2torch('./copied.ckpt', feat_net, inf_net, energy_net)

optim_inf = torch.optim.Adam(inf_net.parameters(), lr=1e-3, weight_decay=0)
optim_energy = torch.optim.Adam(list(energy_net.C1.parameters()) + list(energy_net.c2.parameters()), 
                                lr=1e-3, weight_decay=0)

spen = SPEN(feat_net, energy_net, inf_net)
pred_test = spen.pred(test_x)
best_f1, mAP = f1_map(test_y, pred_test)
print('inf net start', best_f1, mAP)


for epoch in range(100):

    for i in range(0, 4880, 32):
        if i+32 > 4880:
            i = 0
        l = np.arange(i, i+32)
        data = data_x[l]
        label = data_y[l]
        
        optim_inf.zero_grad()
        preds, e_loss, inf_loss, summary = spen.compute_loss(data, label)
        # preds, inf_loss = spen.compute_loss(data, label)
        
        inf_loss.backward()
        optim_inf.step()
        
        optim_energy.zero_grad()
        preds, e_loss, inf_loss, summary = spen.compute_loss(data, label)
        e_loss.backward()
        optim_energy.step()
    print(epoch)

    pred_test = spen.pred(test_x)
    best_f1, mAP = f1_map(test_y, pred_test)
    print('current inf net', best_f1, mAP)
    
    
    with torch.no_grad():
        feat = feat_net(test_x)
        _, pred_test = energy_net(feat, test_y)

    f1, mAP = f1_map(test_y, pred_test)
    print('feature net', f1, mAP)
    print()
    if epoch > 10: break

inf net start 0.4105516296229916 0.33045648352092005
0
current inf net 0.12936284170876616 0.0997916248653658
feature net 0.4105516296229916 0.33045648352092005

1
current inf net 0.1363707244820565 0.08314041959203788
feature net 0.4105516296229916 0.33045648352092005

2
current inf net 0.1290234395602189 0.07465214780016267
feature net 0.4105516296229916 0.33045648352092005

3
current inf net 0.14288885692568856 0.07560781111297611
feature net 0.4105516296229916 0.33045648352092005

4
current inf net 0.13243287968934092 0.06903962678071672
feature net 0.4105516296229916 0.33045648352092005

5
current inf net 0.05080132389862246 0.06384816783175366
feature net 0.4105516296229916 0.33045648352092005

6
current inf net 0.03879185400907539 0.05476117936805065
feature net 0.4105516296229916 0.33045648352092005

7
current inf net 0.03454222813738034 0.0461665098021007
feature net 0.4105516296229916 0.33045648352092005

8
current inf net 0.03268306262005996 0.0406527228992705
feature net 0.

In [122]:
feat_net = MLP()
inf_net = InfNet()
energy_net = EnergyNet()
feat_net, inf_net, energy_net = tf2torch('./copied.ckpt', feat_net, inf_net, energy_net)

optim_inf = torch.optim.Adam(inf_net.parameters(), lr=1e-3, weight_decay=0)
optim_energy = torch.optim.Adam(list(energy_net.C1.parameters()) + list(energy_net.c2.parameters()), 
                                lr=1e-3, weight_decay=0)

spen = SPEN(feat_net, energy_net, inf_net)
pred_test = spen.pred(test_x)
best_f1, mAP = f1_map(test_y, pred_test)
print(best_f1, mAP)


for epoch in range(100):

    for i in range(0, 4880, 32):
        if i+32 > 4880:
            i = 0
        l = np.arange(i, i+32)
        data = data_x[l]
        label = data_y[l]
        
        optim_inf.zero_grad()
        preds, e_loss, inf_loss, summaries = spen.compute_loss(data, label)
        
        summary = {}
        summary['cost_phi'] = inf_loss.item()
        summary['cost_theta'] = e_loss.item()
        summary.update(summaries)
        print('before')
        print(summary)
        # preds, inf_loss = spen.compute_loss(data, label)
        
        inf_loss.backward()
        optim_inf.step()
        
        optim_energy.zero_grad()
        preds, e_loss, inf_loss, summaries = spen.compute_loss(data, label)
        summary = {}
        summary['cost_phi'] = inf_loss.item()
        summary['cost_theta'] = e_loss.item()
        summary.update(summaries)
        print('update phi')
        print(summary)
        
        e_loss.backward()
        optim_energy.step()
        
        preds, e_loss, inf_loss, summaries = spen.compute_loss(data, label)
        summary = {}
        summary['cost_phi'] = inf_loss.item()
        summary['cost_theta'] = e_loss.item()
        summary.update(summaries)
        print('update theta')
        print(summary)
    print(epoch)

    pred_test = spen.pred(test_x)
    best_f1, mAP = f1_map(test_y, pred_test)
    print(best_f1, mAP)
    
    # pred_test = spen.inference(test_x)
    # best_f1, mAP = f1_map(test_y, pred_test)
    # print(best_f1, mAP)
    with torch.no_grad():
        feat = feat_net(test_x)
        _, pred_test = energy_net(feat, test_y)

    f1, mAP = f1_map(test_y, pred_test)
    print('feature net', f1, mAP)
    print()
    break

0.4105516296229916 0.33045648352092005
before
{'cost_phi': 0.3497370183467865, 'cost_theta': 1.0473439693450928, 'base_objective': 0.8605383634567261, 'base_obj_real': -1.0311715602874756, 'energy_inf_net': -3.121706962585449, 'energy_ground_truth': -4.839283466339111, 'reg_losses_theta': 186.80557250976562, 'reg_losses_phi': 1181.28125, 'reg_losses_entropy': 0.028994053602218628, 'pretrain_bias': 0.0}
update phi
{'cost_phi': 0.45749542117118835, 'cost_theta': 1.2389947175979614, 'base_objective': 1.0521891117095947, 'base_obj_real': -0.07407276332378387, 'energy_inf_net': -4.043091297149658, 'energy_ground_truth': -4.839283466339111, 'reg_losses_theta': 186.80557250976562, 'reg_losses_phi': 1165.79248046875, 'reg_losses_entropy': 0.022366493940353394, 'pretrain_bias': 0.3215254545211792}
update theta
{'cost_phi': 0.45882007479667664, 'cost_theta': 1.2375274896621704, 'base_objective': 1.0508644580841064, 'base_obj_real': -0.0757099986076355, 'energy_inf_net': -4.040521621704102, 'ener

before
{'cost_phi': 21.778825759887695, 'cost_theta': 2.047877550125122, 'base_objective': 1.8621761798858643, 'base_obj_real': 0.7311524152755737, 'energy_inf_net': -2.6110777854919434, 'energy_ground_truth': -3.7848024368286133, 'reg_losses_theta': 185.70132446289062, 'reg_losses_phi': 1061.175537109375, 'reg_losses_entropy': 0.002505768556147814, 'pretrain_bias': 22.577320098876953}
update phi
{'cost_phi': 24.83225440979004, 'cost_theta': 2.0832901000976562, 'base_objective': 1.897588849067688, 'base_obj_real': 0.7517707347869873, 'energy_inf_net': -2.6109960079193115, 'energy_ground_truth': -3.7848024368286133, 'reg_losses_theta': 185.70132446289062, 'reg_losses_phi': 1053.9039306640625, 'reg_losses_entropy': 0.0022111183498054743, 'pretrain_bias': 25.673728942871094}
update theta
{'cost_phi': 24.83437156677246, 'cost_theta': 2.081122636795044, 'base_objective': 1.8954709768295288, 'base_obj_real': 0.7492126226425171, 'energy_inf_net': -2.613797903060913, 'energy_ground_truth': -3.

update theta
{'cost_phi': 68.47626495361328, 'cost_theta': 0.6681180596351624, 'base_objective': 0.48284250497817993, 'base_obj_real': -2.4851064682006836, 'energy_inf_net': -3.070526361465454, 'energy_ground_truth': -7.343853950500488, 'reg_losses_theta': 185.27554321289062, 'reg_losses_phi': 979.17431640625, 'reg_losses_entropy': 0.0013620661338791251, 'pretrain_bias': 67.97856903076172}
before
{'cost_phi': 67.62606048583984, 'cost_theta': 1.5186601877212524, 'base_objective': 1.33338463306427, 'base_obj_real': -0.6380000710487366, 'energy_inf_net': -2.091622829437256, 'energy_ground_truth': -4.757229328155518, 'reg_losses_theta': 185.27554321289062, 'reg_losses_phi': 979.17431640625, 'reg_losses_entropy': 0.0017048587324097753, 'pretrain_bias': 67.97856903076172}
update phi
{'cost_phi': 71.56405639648438, 'cost_theta': 1.5238126516342163, 'base_objective': 1.3385370969772339, 'base_obj_real': -0.6321216821670532, 'energy_inf_net': -2.088642120361328, 'energy_ground_truth': -4.757229

before
{'cost_phi': 121.0038070678711, 'cost_theta': 1.7357548475265503, 'base_objective': 1.5506882667541504, 'base_obj_real': -1.8603342771530151, 'energy_inf_net': -1.6212464570999146, 'energy_ground_truth': -5.491611480712891, 'reg_losses_theta': 185.06654357910156, 'reg_losses_phi': 910.0670166015625, 'reg_losses_entropy': 0.0009218505001626909, 'pretrain_bias': 121.64350891113281}
update phi
{'cost_phi': 125.2634048461914, 'cost_theta': 1.7393110990524292, 'base_objective': 1.5542445182800293, 'base_obj_real': -1.868951678276062, 'energy_inf_net': -1.6064882278442383, 'energy_ground_truth': -5.491611480712891, 'reg_losses_theta': 185.06654357910156, 'reg_losses_phi': 905.25, 'reg_losses_entropy': 0.0009163186186924577, 'pretrain_bias': 125.91148376464844}
update theta
{'cost_phi': 125.26541900634766, 'cost_theta': 1.737287998199463, 'base_objective': 1.5522303581237793, 'base_obj_real': -1.8720641136169434, 'energy_inf_net': -1.6064177751541138, 'energy_ground_truth': -5.49465370

update phi
{'cost_phi': 178.26927185058594, 'cost_theta': 1.331270456314087, 'base_objective': 1.1462836265563965, 'base_obj_real': -0.2757537066936493, 'energy_inf_net': -1.9894158840179443, 'energy_ground_truth': -4.354167461395264, 'reg_losses_theta': 184.98687744140625, 'reg_losses_phi': 850.85107421875, 'reg_losses_entropy': 0.0009717802749946713, 'pretrain_bias': 178.563720703125}
update theta
{'cost_phi': 178.27174377441406, 'cost_theta': 1.3287913799285889, 'base_objective': 1.1438076496124268, 'base_obj_real': -0.27893465757369995, 'energy_inf_net': -1.9896999597549438, 'energy_ground_truth': -4.357632637023926, 'reg_losses_theta': 184.98370361328125, 'reg_losses_phi': 850.85107421875, 'reg_losses_entropy': 0.0009717802749946713, 'pretrain_bias': 178.563720703125}
before
{'cost_phi': 178.40936279296875, 'cost_theta': 1.191227674484253, 'base_objective': 1.0062439441680908, 'base_obj_real': -0.8800417184829712, 'energy_inf_net': -1.2364283800125122, 'energy_ground_truth': -4.08

update phi
{'cost_phi': 235.3199920654297, 'cost_theta': 1.370072364807129, 'base_objective': 1.1850680112838745, 'base_obj_real': -1.5407465696334839, 'energy_inf_net': -1.8407552242279053, 'energy_ground_truth': -5.459286212921143, 'reg_losses_theta': 185.00437927246094, 'reg_losses_phi': 800.8876342773438, 'reg_losses_entropy': 0.0008199019939638674, 'pretrain_bias': 235.70335388183594}
update theta
{'cost_phi': 235.32118225097656, 'cost_theta': 1.3689002990722656, 'base_objective': 1.1838905811309814, 'base_obj_real': -1.5437723398208618, 'energy_inf_net': -1.8409357070922852, 'energy_ground_truth': -5.462491989135742, 'reg_losses_theta': 185.00975036621094, 'reg_losses_phi': 800.8876342773438, 'reg_losses_entropy': 0.0008199019939638674, 'pretrain_bias': 235.70335388183594}
before
{'cost_phi': 235.60498046875, 'cost_theta': 1.0851922035217285, 'base_objective': 0.9001824855804443, 'base_obj_real': -1.773045539855957, 'energy_inf_net': -2.5372066497802734, 'energy_ground_truth': -6

update phi
{'cost_phi': 287.334228515625, 'cost_theta': 0.8695760369300842, 'base_objective': 0.684491753578186, 'base_obj_real': -2.502622127532959, 'energy_inf_net': -2.1110377311706543, 'energy_ground_truth': -6.709946632385254, 'reg_losses_theta': 185.08425903320312, 'reg_losses_phi': 760.84130859375, 'reg_losses_entropy': 0.0011371155269443989, 'pretrain_bias': 287.2567443847656}
update theta
{'cost_phi': 287.3354187011719, 'cost_theta': 0.8683751225471497, 'base_objective': 0.6832839846611023, 'base_obj_real': -2.5055103302001953, 'energy_inf_net': -2.109485387802124, 'energy_ground_truth': -6.711281776428223, 'reg_losses_theta': 185.0911102294922, 'reg_losses_phi': 760.84130859375, 'reg_losses_entropy': 0.0011371155269443989, 'pretrain_bias': 287.2567443847656}
before
{'cost_phi': 286.71746826171875, 'cost_theta': 1.4861305952072144, 'base_objective': 1.301039457321167, 'base_obj_real': -1.1951936483383179, 'energy_inf_net': -1.0539398193359375, 'energy_ground_truth': -4.5338149

update theta
{'cost_phi': 336.7590637207031, 'cost_theta': 0.8728657364845276, 'base_objective': 0.6876125335693359, 'base_obj_real': -2.3772711753845215, 'energy_inf_net': -1.952147364616394, 'energy_ground_truth': -6.2753424644470215, 'reg_losses_theta': 185.253173828125, 'reg_losses_phi': 726.4801635742188, 'reg_losses_entropy': 0.0018505251500755548, 'pretrain_bias': 336.71832275390625}
before
{'cost_phi': 336.4315490722656, 'cost_theta': 1.199550747871399, 'base_objective': 1.014297604560852, 'base_obj_real': -3.136312961578369, 'energy_inf_net': -1.737684965133667, 'energy_ground_truth': -7.0505852699279785, 'reg_losses_theta': 185.253173828125, 'reg_losses_phi': 726.4801635742188, 'reg_losses_entropy': 0.0010445894440636039, 'pretrain_bias': 336.71832275390625}
update phi
{'cost_phi': 340.4871520996094, 'cost_theta': 1.206635594367981, 'base_objective': 1.021382451057434, 'base_obj_real': -3.134955406188965, 'energy_inf_net': -1.734800100326538, 'energy_ground_truth': -7.0505852

update theta
{'cost_phi': 388.7220153808594, 'cost_theta': 0.907557487487793, 'base_objective': 0.7221083641052246, 'base_obj_real': -2.7348711490631104, 'energy_inf_net': -2.2534780502319336, 'energy_ground_truth': -7.229969501495361, 'reg_losses_theta': 185.44912719726562, 'reg_losses_phi': 692.57763671875, 'reg_losses_entropy': 0.0018302170792594552, 'pretrain_bias': 388.7497253417969}
before
{'cost_phi': 388.69696044921875, 'cost_theta': 0.9317079782485962, 'base_objective': 0.7462588548660278, 'base_obj_real': -1.7808057069778442, 'energy_inf_net': -1.0916314125061035, 'energy_ground_truth': -4.7167277336120605, 'reg_losses_theta': 185.44912719726562, 'reg_losses_phi': 692.57763671875, 'reg_losses_entropy': 0.0009228029521182179, 'pretrain_bias': 388.7497253417969}
update phi
{'cost_phi': 392.5980224609375, 'cost_theta': 0.9317370653152466, 'base_objective': 0.7462879419326782, 'base_obj_real': -1.768691062927246, 'energy_inf_net': -1.106534719467163, 'energy_ground_truth': -4.716

before
{'cost_phi': 438.9075012207031, 'cost_theta': 0.5975261926651001, 'base_objective': 0.4117731750011444, 'base_obj_real': -3.227128267288208, 'energy_inf_net': -1.79839289188385, 'energy_ground_truth': -7.484766960144043, 'reg_losses_theta': 185.75299072265625, 'reg_losses_phi': 662.0377197265625, 'reg_losses_entropy': 0.0010442000348120928, 'pretrain_bias': 438.65618896484375}
update phi
{'cost_phi': 442.68865966796875, 'cost_theta': 0.5979669094085693, 'base_objective': 0.41221392154693604, 'base_obj_real': -3.230823040008545, 'energy_inf_net': -1.7864971160888672, 'energy_ground_truth': -7.484766960144043, 'reg_losses_theta': 185.75299072265625, 'reg_losses_phi': 659.7939453125, 'reg_losses_entropy': 0.0010187547886744142, 'pretrain_bias': 442.4400634765625}
update theta
{'cost_phi': 442.68975830078125, 'cost_theta': 0.5968936085700989, 'base_objective': 0.4111159145832062, 'base_obj_real': -3.2350776195526123, 'energy_inf_net': -1.7866640090942383, 'energy_ground_truth': -7.4

before
{'cost_phi': 486.1108703613281, 'cost_theta': 0.8460253477096558, 'base_objective': 0.6599858999252319, 'base_obj_real': -1.9833253622055054, 'energy_inf_net': -0.9771665930747986, 'energy_ground_truth': -4.854970932006836, 'reg_losses_theta': 186.03941345214844, 'reg_losses_phi': 634.5336303710938, 'reg_losses_entropy': 0.0009170790435746312, 'pretrain_bias': 486.1354064941406}
update phi
{'cost_phi': 489.65460205078125, 'cost_theta': 0.8468353748321533, 'base_objective': 0.6607959270477295, 'base_obj_real': -1.9681118726730347, 'energy_inf_net': -0.998907744884491, 'energy_ground_truth': -4.854970932006836, 'reg_losses_theta': 186.03941345214844, 'reg_losses_phi': 632.5296630859375, 'reg_losses_entropy': 0.0009413573425263166, 'pretrain_bias': 489.6819152832031}
update theta
{'cost_phi': 489.6558837890625, 'cost_theta': 0.8455739617347717, 'base_objective': 0.6595079898834229, 'base_obj_real': -1.9711475372314453, 'energy_inf_net': -0.9983985424041748, 'energy_ground_truth': -

update phi
{'cost_phi': 531.4177856445312, 'cost_theta': 0.6358431577682495, 'base_objective': 0.4494699239730835, 'base_obj_real': -3.0774316787719727, 'energy_inf_net': -2.295029401779175, 'energy_ground_truth': -7.4089837074279785, 'reg_losses_theta': 186.37319946289062, 'reg_losses_phi': 609.0543212890625, 'reg_losses_entropy': 0.0008928346796892583, 'pretrain_bias': 531.2572631835938}
update theta
{'cost_phi': 531.41845703125, 'cost_theta': 0.6351814270019531, 'base_objective': 0.44878485798835754, 'base_obj_real': -3.0804443359375, 'energy_inf_net': -2.2945611476898193, 'energy_ground_truth': -7.41152811050415, 'reg_losses_theta': 186.3965301513672, 'reg_losses_phi': 609.0543212890625, 'reg_losses_entropy': 0.0008928346796892583, 'pretrain_bias': 531.2572631835938}
before
{'cost_phi': 531.372314453125, 'cost_theta': 0.6811087131500244, 'base_objective': 0.4947122037410736, 'base_obj_real': -2.7003369331359863, 'energy_inf_net': -0.9514268040657043, 'energy_ground_truth': -5.36718

update phi
{'cost_phi': 574.2403564453125, 'cost_theta': 0.9403105974197388, 'base_objective': 0.7535959482192993, 'base_obj_real': -1.6196935176849365, 'energy_inf_net': -2.228994131088257, 'energy_ground_truth': -5.8391642570495605, 'reg_losses_theta': 186.71461486816406, 'reg_losses_phi': 584.6969604492188, 'reg_losses_entropy': 0.0004222320858389139, 'pretrain_bias': 574.4088134765625}
update theta
{'cost_phi': 574.2422485351562, 'cost_theta': 0.9384360313415527, 'base_objective': 0.7516823410987854, 'base_obj_real': -1.6241979598999023, 'energy_inf_net': -2.2285096645355225, 'energy_ground_truth': -5.843184471130371, 'reg_losses_theta': 186.75369262695312, 'reg_losses_phi': 584.6969604492188, 'reg_losses_entropy': 0.0004222320858389139, 'pretrain_bias': 574.4088134765625}
before
{'cost_phi': 573.1408081054688, 'cost_theta': 2.040365219116211, 'base_objective': 1.8536115884780884, 'base_obj_real': -1.9214661121368408, 'energy_inf_net': -1.1012948751449585, 'energy_ground_truth': -5