In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '4'

%config Completer.use_jedi = False

%load_ext autoreload
%autoreload 2

In [2]:
import copy
import pickle

import tensorflow as tf
import numpy as np
from sklearn import metrics
from sklearn.metrics import f1_score

import torch
import torch.nn as nn
import torch.nn.functional as func

import matplotlib.pyplot as plt
# t.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = False

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


# data

In [119]:
!pwd

/home/xyang2/project/research/audioclf/infnet-spen


In [118]:
# Load the training split.
# NOTE: pickle.load can execute arbitrary code; only open pickles from a trusted source.
with open('data/bibtex/train.pickle', "rb") as f:
    temp = pickle.load(f)
    # Each instance is indexed by 'feats' (model inputs) and 'types' (targets).
    data_x = np.array([instance['feats'] for instance in temp])
    data_y = np.array([instance['types'] for instance in temp])


# Load the test split the same way (this rebinds `temp`).
with open('data/bibtex/test.pickle', "rb") as f:
    temp = pickle.load(f)
    test_x = np.array([instance['feats'] for instance in temp])
    test_y = np.array([instance['types'] for instance in temp])


# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Convert every split to a FloatTensor and move it to the selected device.
# The numpy arrays of the same name are replaced by tensors from here on.
data_x = torch.FloatTensor(data_x).to(device)
data_y = torch.FloatTensor(data_y).to(device)
test_x = torch.FloatTensor(test_x).to(device)
test_y = torch.FloatTensor(test_y).to(device)


In [4]:
def f1_map(y, pred, threshold=None):
    """Compute the best samples-averaged F1 and the mean average precision.

    Args:
        y: torch tensor of ground-truth labels, passed to scikit-learn as a
            multilabel indicator matrix.
        pred: torch tensor of predicted scores with the same layout as ``y``.
        threshold: cut-off(s) used to binarise ``pred`` for the F1 score.
            ``None`` sweeps the grid 0.05, 0.10, ..., 0.75 and keeps the best F1.
            A single number uses exactly that cut-off. An iterable of numbers
            sweeps those cut-offs.

    Returns:
        ``(best_f1, precision)``: the highest F1 over the evaluated cut-offs
        (0 when none is positive) and the mean over labels of the average
        precision score.
    """
    if threshold is None:
        thresholds = [0.05, 0.10, 0.15, 0.2, 0.25, 0.30, 0.35, 0.4, 0.45, 0.5, 0.55, 0.60, 0.65, 0.70, 0.75]
    elif np.ndim(threshold) == 0:
        # A scalar cut-off is honoured as given instead of being replaced by 0.5.
        thresholds = [threshold]
    else:
        thresholds = list(threshold)

    # Convert once; the labels and raw scores are reused for every cut-off.
    y_true = y.detach().cpu().numpy()
    scores = pred.detach().cpu().numpy()

    best_f1 = 0
    for t in thresholds:
        # Threshold on the tensor so the comparison semantics stay those of torch.
        local_pred = (pred > t).detach().cpu().numpy()
        local_f1 = f1_score(y_true, local_pred, average='samples')
        if local_f1 > best_f1:
            best_f1 = local_f1

    precision = np.mean(metrics.average_precision_score(y_true, scores, average=None))

    return best_f1, precision

# Architectures

In [5]:

class MLP(nn.Module):
    """Two-layer fully connected feature network with ReLU activations.

    Args:
        input_dim: size of each input vector (default 1836).
        hidden_dim: width of both linear layers and of the output (default 150).

    The attribute names ``layer1`` and ``layer2`` are relied on by the
    checkpoint-loading code, so they must not be renamed.
    """

    def __init__(self, input_dim=1836, hidden_dim=150):
        super().__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        """Return the non-negative ``hidden_dim``-wide features for input ``x``."""
        hidden = func.relu(self.layer1(x))
        return func.relu(self.layer2(hidden))
    

In [8]:

# Supervised pre-training: fit the feature network and the energy network's
# per-label probabilities to the training labels with binary cross-entropy.
# NOTE(review): EnergyNet must already be defined when this cell runs; in this
# file its definition appears in a later cell.
feat_net = MLP().to(device)
energy_net = EnergyNet().to(device)
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(list(feat_net.parameters()) + list(energy_net.parameters()), lr=1e-3, weight_decay=0)

n_train = len(data_x)  # number of training examples, taken from the data itself
batch_size = 32

for epoch in range(20):
    # Fresh random order every epoch.
    inds = np.random.permutation(n_train)
    for i in range(0, n_train, batch_size):
        batch_idx = inds[i:i + batch_size]
        data = data_x[batch_idx]
        label = data_y[batch_idx]
        feat = feat_net(data)
        # The second return value holds probabilities (already passed through a sigmoid).
        _, probs = energy_net(feat, label)
        loss = criterion(probs, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print('epoch', epoch)
    # Evaluate on the held-out split after every epoch.
    with torch.no_grad():
        feat = feat_net(test_x)
        _, pred_test = energy_net(feat, test_y)

    best_f1, mAP = f1_map(test_y, pred_test, 0.5)
    print(best_f1, mAP)
    print()

epoch 0
0.010612515383887152 0.03239640349215638

epoch 1
0.14521789895547352 0.0950389308766338

epoch 2
0.19718947498271552 0.14722309977371306

epoch 3
0.21832717977847202 0.19360340578231233

epoch 4
0.243371832497081 0.23144460515764315

epoch 5
0.26413997312605664 0.26450110012015443

epoch 6
0.28348104292438286 0.28794679487069746

epoch 7
0.30611502210764224 0.3012620973138103

epoch 8
0.3315541749590091 0.3099285562858436

epoch 9
0.3401557106699934 0.3119738275330926

epoch 10
0.3453505929613987 0.31223580932536793

epoch 11
0.3490278477271688 0.31247202677472735

epoch 12
0.35557245319409175 0.31088071957103347

epoch 13
0.36646074427124964 0.31017692004801145

epoch 14
0.3567290072391968 0.31047401359604065

epoch 15
0.36200591356304135 0.3072352959550528

epoch 16
0.3708267047730268 0.30712782849428355

epoch 17
0.35581687696001807 0.30150322231006416

epoch 18
0.3709799706004828 0.30132950077515375

epoch 19
0.3712019472620572 0.3008649561447038

0.39986122455594064 0.300

In [9]:
# Evaluate the current feature/energy networks on the test split without tracking gradients.
with torch.no_grad():
    feat = feat_net(test_x)
    print(feat)  # debug output: the feature activations
    # Only the second return value of the energy network is kept as the prediction.
    _, pred_test = energy_net(feat, test_y)

# threshold is left at its default of None here.
f1, mAP = f1_map(test_y, pred_test)
print(f1, mAP)

tensor([[ 5.4456,  7.1403, 12.4506,  ...,  4.6460,  4.8572,  4.4169],
        [ 8.0281,  0.9246, 10.6162,  ...,  8.0007,  6.3127,  0.8347],
        [ 1.9766,  5.8080,  3.7858,  ...,  3.8194,  1.6761,  4.8637],
        ...,
        [ 0.6951,  2.3739,  9.3303,  ...,  4.7144,  2.0970,  0.6466],
        [ 4.8289,  0.9531, 13.5373,  ...,  9.1066,  3.6389,  0.0000],
        [ 6.3641,  4.8654, 14.8546,  ..., 11.9688,  6.9226,  3.9216]],
       device='cuda:0')
0.39986122455594064 0.3008649561447038


In [10]:
class InfNet(nn.Module):
    """Inference network: two ReLU layers followed by a bias-free output layer.

    Args:
        input_dim: size of each input vector (default 1836).
        hidden_dim: width of the two hidden layers (default 150).
        label_dim: number of output labels (default 159).

    The attribute names ``layer1``/``layer2``/``layer3`` are relied on by the
    checkpoint-loading code, so they must not be renamed.
    """

    def __init__(self, input_dim=1836, hidden_dim=150, label_dim=159):
        super().__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, hidden_dim)
        self.layer3 = nn.Linear(hidden_dim, label_dim, bias=False)

    def forward(self, x):
        """Return ``(probabilities, logits)`` where probabilities = sigmoid(logits)."""
        hidden = func.relu(self.layer1(x))
        hidden = func.relu(self.layer2(hidden))
        logits = self.layer3(hidden)
        return torch.sigmoid(logits), logits
    

In [7]:
class EnergyNet(nn.Module):
    """Energy function made of a local (per-label) term and a label-interaction term.

    Args:
        weights_last_layer_mlp: accepted for backward compatibility; not used.
        feature_dim: width of the feature vectors fed to ``forward`` (default 150).
        label_dim: number of labels (default 159).
        num_pairwise: number of hidden units of the label-energy term (default 16).
        non_linearity: activation applied inside the label-energy term. The
            default instance is created once and shared between networks.
    """

    def __init__(self, weights_last_layer_mlp=150, feature_dim=150, label_dim=159,
                 num_pairwise=16, non_linearity=nn.Softplus()):
        super().__init__()

        self.non_linearity = non_linearity

        # Honour feature_dim instead of a fixed input width.
        self.linear_wt = nn.Linear(feature_dim, label_dim, bias=False)

        # Label energy terms, C1/c2 in equation 5 of the SPEN paper
        self.C1 = nn.Linear(label_dim, num_pairwise)

        self.c2 = nn.Linear(num_pairwise, 1, bias=False)

    def forward(self, x, y):
        """Return ``(energy, probabilities)`` for features ``x`` and labels ``y``.

        ``energy`` has one value per example. ``probabilities`` is
        ``sigmoid(-linear_wt(x))`` and does not depend on ``y``.
        """
        # Local energy: element-wise product with the labels, summed over labels.
        negative_logits = self.linear_wt(x)
        feat_probs = torch.sigmoid(-negative_logits)
        e_local = (negative_logits * y).sum(dim=1)

        # Label energy
        e_label = self.c2(self.non_linearity(self.C1(y)))
        flat_label = e_label.view(-1)
        # Exactly one label-energy value per example is expected.
        assert flat_label.shape[0] == e_label.shape[0]
        assert flat_label.shape[0] == e_local.shape[0]
        e_global = flat_label + e_local.view(-1)

        return e_global, feat_probs

# train cost inf

In [17]:

def tf2torch(checkpoint, feat_net, inf_net, energy_net):
    """Copy weights from a TensorFlow checkpoint into the three PyTorch networks.

    Every variable listed in the checkpoint is loaded and squeezed. Variables are
    then selected by their position in that listing and transposed before being
    assigned to the matching parameter's ``.data``.

    NOTE(review): the positional offsets (8, 12, 18) presumably match the variable
    ordering of one specific checkpoint — confirm before reusing with another.
    Assigning through ``.data`` performs no shape check.

    Returns:
        ``(feat_net, inf_net, energy_net)`` moved to the module-level ``device``.
    """
    ckpt_path = os.path.abspath(checkpoint)
    loaded = [
        (var_name, tf.train.load_variable(ckpt_path, var_name).squeeze())
        for var_name, _ in tf.train.list_variables(ckpt_path)
    ]

    def transposed(position):
        # Transposed copy of the array stored at `position` in the listing.
        return torch.from_numpy(loaded[position][1].T)

    energy_base, feat_base, inf_base = 8, 12, 18
    targets = [
        (feat_net.layer1.bias, feat_base),
        (feat_net.layer1.weight, feat_base + 1),
        (feat_net.layer2.bias, feat_base + 2),
        (feat_net.layer2.weight, feat_base + 3),
        (energy_net.C1.bias, energy_base),
        (energy_net.C1.weight, energy_base + 1),
        (energy_net.c2.weight, energy_base + 2),
        (energy_net.linear_wt.weight, energy_base + 3),
        (inf_net.layer1.bias, inf_base),
        (inf_net.layer1.weight, inf_base + 1),
        (inf_net.layer2.bias, inf_base + 2),
        (inf_net.layer2.weight, inf_base + 3),
        (inf_net.layer3.weight, inf_base + 4),
    ]
    for parameter, position in targets:
        parameter.data = transposed(position)

    return feat_net.to(device), inf_net.to(device), energy_net.to(device)

# Build fresh networks and overwrite their weights from the TensorFlow checkpoint.
feat_net1 = MLP()
inf_net1 = InfNet()
energy_net1 = EnergyNet()
feat_net1, inf_net1, energy_net1 = tf2torch('./copied.ckpt', feat_net1, inf_net1, energy_net1)

# Score the inference network: its first return value (probabilities) is the prediction.
with torch.no_grad():
    pred_test, _ = inf_net1(test_x)

f1, mAP = f1_map(test_y, pred_test)
print(f1, mAP)


# Score the feature network + energy network: the second return value is the prediction.
with torch.no_grad():
    feat = feat_net1(test_x)
    _, pred_test = energy_net1(feat, test_y)

f1, mAP = f1_map(test_y, pred_test)
print(f1, mAP)

0.4105516296229916 0.3304564718587311
0.4105516296229916 0.3304564718587311


# print different

In [12]:
def print_model(feat_net, inf_net, energy_net):
    """Print ``<parameter name> : <sum of its values>`` for each network in turn.

    The networks are reported in the order feature, inference, energy, with six
    decimal places per sum.
    """
    for net in (feat_net, inf_net, energy_net):
        for name, param in net.named_parameters():
            total = param.sum().item()
            print(name, ':', '%.6f' % total)

        
def _print_param_grads(prefix, net):
    """Print ``<prefix><name> : <grad sum> <value sum>`` for every parameter of ``net``.

    A parameter that has no gradient yet (``grad is None``) is reported as
    ``None`` instead of raising.
    """
    for name, param in net.named_parameters():
        if param.grad is None:
            grad_text = '%9s' % 'None'
        else:
            grad_text = '%9.6f' % param.grad.sum().item()
        print(prefix + name, ':', grad_text, '%.6f' % param.sum().item())


def print_model_grad(feat_net, inf_net, energy_net):
    """Print gradient and value sums for the feature, inference and energy networks."""
    _print_param_grads('feat_net/', feat_net)
    _print_param_grads('infer_net/', inf_net)
    _print_param_grads('energy_net/', energy_net)


def print_energy_grad(energy_net):
    """Print gradient and value sums for the energy network only."""
    _print_param_grads('energy_net/', energy_net)
        
def print_grads(grads):
    """Print the sum of each gradient tensor in ``grads``, one per line."""
    for total in (grad.sum().item() for grad in grads):
        print('grads: ', total)
        
def print_summary(summary, j):
    """Print the iteration number ``j`` and the 'infer cost' entry of ``summary``."""
    infer_cost = summary['infer cost']
    line = 'iter %d:  cost_phi %.7f' % (j, infer_cost)
    print(line)
    # print('cost_theta %.7f' % summary['energy cost'])
    # print('base_obj %.7f' % summary['base_objective'])

# SPEN

In [41]:
# Two fixed mini-batches used for step-by-step debugging.
# First batch: training examples 0..31.
i = 0
l = np.arange(i, i+32)
data = data_x[l]
label = data_y[l]

# Second batch: training examples 32..63.
i = 32
l = np.arange(i, i+32)
data2 = data_x[l]
label2 = data_y[l]

In [15]:
class EnergyNet(nn.Module):
    """Energy function made of a local (per-label) term and a label-interaction term.

    Args:
        weights_last_layer_mlp: accepted for backward compatibility; not used.
        feature_dim: width of the feature vectors fed to ``forward`` (default 150).
        label_dim: number of labels (default 159).
        num_pairwise: number of hidden units of the label-energy term (default 16).
        non_linearity: activation applied inside the label-energy term. The
            default instance is created once and shared between networks.
    """

    def __init__(self, weights_last_layer_mlp=150, feature_dim=150, label_dim=159,
                 num_pairwise=16, non_linearity=nn.Softplus()):
        super().__init__()

        self.non_linearity = non_linearity

        # Honour feature_dim instead of a fixed input width.
        self.linear_wt = nn.Linear(feature_dim, label_dim, bias=False)

        # Label energy terms, C1/c2 in equation 5 of the SPEN paper
        self.C1 = nn.Linear(label_dim, num_pairwise)

        self.c2 = nn.Linear(num_pairwise, 1, bias=False)

    def forward(self, x, y):
        """Return ``(energy, probabilities)`` for features ``x`` and labels ``y``.

        ``energy`` has one value per example. ``probabilities`` is
        ``sigmoid(-linear_wt(x))`` and does not depend on ``y``.
        """
        # Local energy: element-wise product with the labels, summed over labels.
        negative_logits = self.linear_wt(x)
        feat_probs = torch.sigmoid(-negative_logits)
        e_local = (negative_logits * y).sum(dim=1)

        # Label energy
        e_label = self.c2(self.non_linearity(self.C1(y)))
        flat_label = e_label.view(-1)
        # Exactly one label-energy value per example is expected.
        assert flat_label.shape[0] == e_label.shape[0]
        assert flat_label.shape[0] == e_local.shape[0]
        e_global = flat_label + e_local.view(-1)

        return e_global, feat_probs

In [111]:
class SPEN():
    """Couples a feature network, an energy network and an inference network.

    ``compute_loss`` produces the max-margin objectives for jointly training the
    energy network and the cost-augmented inference network. ``inference_loss``
    produces the objective for tuning the inference network against the energy
    alone. ``pred`` returns the inference network's probabilities.
    """

    def __init__(self, feature_net, energy_net, inf_net, n_steps_inf=1, input_dim=1836, label_dim=159):
        """Store the networks and snapshot the initial inference-network weights.

        Args:
            feature_net: module mapping inputs to features; switched to eval mode here.
            energy_net: module called as ``energy_net(features, labels)`` and
                returning ``(energy, probabilities)``.
            inf_net: module called as ``inf_net(inputs)`` and returning
                ``(probabilities, logits)``.
            n_steps_inf, input_dim, label_dim: accepted for backward
                compatibility; not used.
        """
        self.feature_extractor = feature_net
        self.feature_extractor.eval()
        self.energy_net = energy_net
        self.inf_net = inf_net

        # Frozen snapshot of the starting inference network. A deep copy works
        # for any architecture and keeps the copy on the same device.
        self.phi0 = copy.deepcopy(inf_net)
        for frozen in self.phi0.parameters():
            frozen.requires_grad_(False)

    def compute_loss(self, inputs, targets):
        """Return ``(pred_probs, e_net_loss, inf_net_loss, summaries)`` for one batch.

        ``e_net_loss`` is the hinged margin loss plus an L2 penalty on the energy
        network. ``inf_net_loss`` is the negated margin loss plus an L2 penalty on
        the inference network and its squared distance to the initial snapshot.
        ``summaries`` maps names to the individual terms.
        """
        # Features are detached, so no gradient reaches the feature extractor.
        f_x = self.feature_extractor(inputs).detach()

        # Energy ground truth
        gt_energy, _ = self.energy_net(f_x, targets)

        # Cost-augmented inference network
        pred_probs, logits = self.inf_net(inputs)

        pred_energy, _ = self.energy_net(f_x, pred_probs)

        # Max-margin loss: squared error as the cost, minus the predicted energy,
        # plus the ground-truth energy.
        diff = torch.sum((pred_probs - targets)**2, dim=1)
        pre_loss_real = diff - pred_energy + gt_energy

        energy_loss = torch.mean(torch.relu(pre_loss_real))
        pre_loss_real = torch.mean(pre_loss_real)

        # Reported in the summaries only; it is not part of either loss.
        entropy_loss = nn.BCELoss()(pred_probs, pred_probs.detach())

        reg_losses_phi = 0.5 * sum(p.pow(2.0).sum() for p in self.inf_net.parameters())

        # Squared distance between the current and the initial inference network.
        pretrain_bias = sum(
            (current - initial).pow(2.0).sum()
            for current, initial in zip(self.inf_net.parameters(), self.phi0.parameters())
        )

        reg_losses_theta = 0.5 * sum(p.pow(2.0).sum() for p in self.energy_net.parameters())

        # The inference network maximises this quantity, hence the negation below.
        inf_net_loss = energy_loss  \
                       - 0.001 * reg_losses_phi \
                       - 1 * pretrain_bias

        inf_net_loss = -inf_net_loss

        e_net_loss = energy_loss + 0.001 * reg_losses_theta

        summaries = {
            'infer cost': inf_net_loss,
            'energy cost': e_net_loss,
            'base_objective': energy_loss,
            'base_obj_real': pre_loss_real,
            'energy_inf_net': pred_energy.mean(),
            'energy_ground_truth': gt_energy.mean(),
            'reg_losses_theta': reg_losses_theta,
            'reg_losses_phi': reg_losses_phi,
            'reg_losses_entropy': entropy_loss,
            'pretrain_bias': pretrain_bias
        }

        return pred_probs, e_net_loss, inf_net_loss, summaries

    def pred(self, x):
        """Return the inference network's probabilities for ``x`` without tracking gradients."""
        with torch.no_grad():
            y_pred, _ = self.inf_net(x)
        return y_pred

    def inference_loss(self, x):
        """Return the objective for tuning the inference network on inputs ``x``.

        The objective is the mean predicted energy, plus an L2 penalty on the
        inference network, plus a cross-entropy term between its logits and its
        own detached probabilities.
        """
        # Detached as in compute_loss: the feature extractor is not trained here.
        f_x = self.feature_extractor(x).detach()
        # inference network
        pred_probs, logits = self.inf_net(x)
        pred_energy, _ = self.energy_net(f_x, pred_probs)

        entropy_loss = func.binary_cross_entropy_with_logits(logits, pred_probs.detach())
        reg_losses_phi = 0.5 * sum(p.pow(2.0).sum() for p in self.inf_net.parameters())

        inf_net_loss = torch.mean(pred_energy)  \
                       + 0.001 * reg_losses_phi \
                       + 1 * entropy_loss

        return inf_net_loss

# step debug

In [68]:
# Single-step gradient inspection on the second debug batch (data2 / label2).
# NOTE(review): optim_e, optim_inf, spen, inf_net and energy_net are not created
# in this cell; in this file they are assigned in later cells, so this cell only
# works after those cells have been executed.
feat_net.zero_grad()
optim_e.zero_grad()
optim_inf.zero_grad()
preds, e_loss, inf_loss, summary = spen.compute_loss(data2, label2)
print_summary(summary, 0)
# Back-propagate the inference-network objective only, then dump per-parameter gradient sums.
inf_loss.backward()
print('compute inf net gradients')
print_model_grad(feat_net, inf_net, energy_net)

iter 0:  cost_phi 1.1570342
compute inf net gradients
feat_net/layer1.weight : 17.899414 766.057251
feat_net/layer1.bias :  0.196697 17.295366
feat_net/layer2.weight : 19.131044 502.172089
feat_net/layer2.bias :  0.136933 6.113752
infer_net/layer1.weight :  7.929173 766.511963
infer_net/layer1.bias :  0.208551 17.356606
infer_net/layer2.weight :  9.020711 502.805725
infer_net/layer2.bias :  0.088460 6.128996
infer_net/layer3.weight : -10.122227 -720.555420
energy_net/linear_wt.weight : 19.223404 719.875610
energy_net/C1.weight :  0.004540 2.112450
energy_net/C1.bias : -0.000532 -0.000042
energy_net/c2.weight :  0.002886 -0.030406


In [54]:
summary

{'infer cost': tensor(1.1570, device='cuda:0', grad_fn=<NegBackward>),
 'energy cost': tensor(0.2115, device='cuda:0', grad_fn=<AddBackward0>),
 'base_objective': tensor(0.0247, device='cuda:0', grad_fn=<MeanBackward0>),
 'base_obj_real': tensor(-1.7291, device='cuda:0', grad_fn=<MeanBackward0>),
 'energy_inf_net': tensor(-3.8876, device='cuda:0', grad_fn=<MeanBackward0>),
 'energy_ground_truth': tensor(-5.9633, device='cuda:0', grad_fn=<MeanBackward0>),
 'reg_losses_theta': tensor(186.8051, device='cuda:0', grad_fn=<MulBackward0>),
 'reg_losses_phi': tensor(1181.3073, device='cuda:0', grad_fn=<MulBackward0>),
 'reg_losses_entropy': tensor(0.0234, device='cuda:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward>),
 'pretrain_bias': tensor(0.0004, device='cuda:0')}

# all

In [106]:
# Joint training, variant with two forward passes per batch:
# the loss is recomputed after the inference-network step and before the energy step.
feat_net = MLP()
inf_net = InfNet()
energy_net = EnergyNet()
# Start every network from the weights stored in the TensorFlow checkpoint.
feat_net, inf_net, energy_net = tf2torch('./copied.ckpt', feat_net, inf_net, energy_net)

# optim_inf updates the inference network.
optim_inf = torch.optim.Adam(inf_net.parameters(), lr=1e-3, weight_decay=0)
# optim_energy only holds the C1 and c2 parameters; energy_net.linear_wt is not updated by it.
optim_energy = torch.optim.Adam(list(energy_net.C1.parameters()) + list(energy_net.c2.parameters()), 
                                lr=1e-3, weight_decay=0)
# optim_e is only used below to clear gradients (zero_grad); its step() is never called in this cell.
optim_e  = torch.optim.Adam(list(inf_net.parameters()) + list(energy_net.parameters()), 
                                lr=1e-3, weight_decay=0)

spen = SPEN(feat_net, energy_net, inf_net)
# Baseline score of the inference network before any joint training.
pred_test = spen.pred(test_x)
best_f1, mAP = f1_map(test_y, pred_test)
print('inf net start', best_f1, mAP)

# NOTE(review): these three lists are initialised but never appended to in this cell.
phi_energies = []
theta_energies = []
f1s = [best_f1]
for epoch in range(10):

    # NOTE(review): 4880 is presumably the number of training examples — confirm against len(data_x).
    # Batches are taken in fixed order (no shuffling).
    for j, i in enumerate(range(0, 4880, 32)):
        # When the 32-wide window would run past index 4880, reuse the first batch instead.
        if i+32 > 4880:
            i = 0
        l = np.arange(i, i+32)
        data = data_x[l]
        # print(data.sum())
        label = data_y[l]
        
        # Step 1: update the inference network on its own objective.
        optim_e.zero_grad()
        optim_inf.zero_grad()
        preds, e_loss, inf_loss, summary = spen.compute_loss(data, label)

        inf_loss.backward()

        optim_inf.step()

        
        # Step 2: recompute the losses with the updated inference network, then update C1/c2.
        optim_e.zero_grad()
        optim_inf.zero_grad()
        preds, e_loss, inf_loss, summary = spen.compute_loss(data, label)
        
        e_loss.backward()

        optim_energy.step()

    print(epoch)
    
    # Per-epoch evaluation of the inference network.
    pred_test = spen.pred(test_x)
    best_f1, mAP = f1_map(test_y, pred_test)
    print('current inf net', best_f1, mAP)
    
    # Per-epoch evaluation of the feature network + energy network's probabilities.
    with torch.no_grad():
        feat = feat_net(test_x)
        _, pred_test = energy_net(feat, test_y)

    f1, mAP = f1_map(test_y, pred_test, 0.5)
    print('feature net', f1, mAP)
    print()


inf net start 0.4105516296229916 0.33045646957084107
0
current inf net 0.3951645406964536 0.3183652439194566
feature net 0.37872409762358 0.33045646957084107

1
current inf net 0.39884343141395623 0.32234328773344706
feature net 0.37872409762358 0.33045646957084107

2
current inf net 0.4027456771023543 0.3243291427038416
feature net 0.37872409762358 0.33045646957084107

3
current inf net 0.4066585445422124 0.32528390548265845
feature net 0.37872409762358 0.33045646957084107

4
current inf net 0.4061004560066561 0.32584806409754813
feature net 0.37872409762358 0.33045646957084107

5
current inf net 0.4076676155627696 0.3269871627331186
feature net 0.37872409762358 0.33045646957084107

6
current inf net 0.4081970385611802 0.32743184774168776
feature net 0.37872409762358 0.33045646957084107

7
current inf net 0.4088775479476274 0.3282031100263219
feature net 0.37872409762358 0.33045646957084107

8
current inf net 0.40689603730262164 0.32757802909668715
feature net 0.37872409762358 0.33045

In [115]:
# Joint training, variant with a single forward pass per batch:
# both losses are back-propagated through the same graph (retain_graph=True).
feat_net = MLP()
inf_net = InfNet()
energy_net = EnergyNet()
# Start every network from the weights stored in the TensorFlow checkpoint.
feat_net, inf_net, energy_net = tf2torch('./copied.ckpt', feat_net, inf_net, energy_net)

# optim_inf updates the inference network.
optim_inf = torch.optim.Adam(inf_net.parameters(), lr=1e-3, weight_decay=0)
# optim_energy only holds the C1 and c2 parameters; energy_net.linear_wt is not updated by it.
optim_energy = torch.optim.Adam(list(energy_net.C1.parameters()) + list(energy_net.c2.parameters()), 
                                lr=1e-3, weight_decay=0)
# optim_e is only used below to clear gradients (zero_grad); its step() is never called in this cell.
optim_e  = torch.optim.Adam(list(inf_net.parameters()) + list(energy_net.parameters()), 
                                lr=1e-3, weight_decay=0)

spen = SPEN(feat_net, energy_net, inf_net)
# Baseline score of the inference network before any joint training.
pred_test = spen.pred(test_x)
best_f1, mAP = f1_map(test_y, pred_test)
print('inf net start', best_f1, mAP)

# NOTE(review): these three lists are initialised but never appended to in this cell.
phi_energies = []
theta_energies = []
f1s = [best_f1]
for epoch in range(10):

    # NOTE(review): 4880 is presumably the number of training examples — confirm against len(data_x).
    # Batches are taken in fixed order (no shuffling).
    for j, i in enumerate(range(0, 4880, 32)):
        # When the 32-wide window would run past index 4880, reuse the first batch instead.
        if i+32 > 4880:
            i = 0
        l = np.arange(i, i+32)
        data = data_x[l]
        # print(data.sum())
        label = data_y[l]
        
        optim_e.zero_grad()
        optim_inf.zero_grad()
        preds, e_loss, inf_loss, summary = spen.compute_loss(data, label)

        # Keep the graph alive so e_loss can be back-propagated through it afterwards.
        inf_loss.backward(retain_graph=True)

        optim_inf.step()

        # NOTE(review): e_loss was computed before optim_inf.step() changed the inference
        # network's parameters in place, so this backward pass uses a stale graph. Recent
        # PyTorch versions may reject it with an in-place-modification error — confirm.
        optim_e.zero_grad()        
        e_loss.backward()
        optim_energy.step()

    print(epoch)
    
#     pred_test = spen.pred(test_x)
#     best_f1, mAP = f1_map(test_y, pred_test)
#     print('current inf net', best_f1, mAP)
    
#     with torch.no_grad():
#         feat = feat_net(test_x)
#         _, pred_test = energy_net(feat, test_y)

#     f1, mAP = f1_map(test_y, pred_test, 0.5)
#     print('feature net', f1, mAP)
    print()


inf net start 0.4105516296229916 0.33045646957084107
0

1

2

3

4

5

6

7

8

9



In [116]:
# Joint training, variant with two forward passes per batch and the per-epoch
# evaluation commented out.
feat_net = MLP()
inf_net = InfNet()
energy_net = EnergyNet()
# Start every network from the weights stored in the TensorFlow checkpoint.
feat_net, inf_net, energy_net = tf2torch('./copied.ckpt', feat_net, inf_net, energy_net)

# optim_inf updates the inference network.
optim_inf = torch.optim.Adam(inf_net.parameters(), lr=1e-3, weight_decay=0)
# optim_energy only holds the C1 and c2 parameters; energy_net.linear_wt is not updated by it.
optim_energy = torch.optim.Adam(list(energy_net.C1.parameters()) + list(energy_net.c2.parameters()), 
                                lr=1e-3, weight_decay=0)
# optim_e is only used below to clear gradients (zero_grad); its step() is never called in this cell.
optim_e  = torch.optim.Adam(list(inf_net.parameters()) + list(energy_net.parameters()), 
                                lr=1e-3, weight_decay=0)

spen = SPEN(feat_net, energy_net, inf_net)
# Baseline score of the inference network before any joint training.
pred_test = spen.pred(test_x)
best_f1, mAP = f1_map(test_y, pred_test)
print('inf net start', best_f1, mAP)

# NOTE(review): these three lists are initialised but never appended to in this cell.
phi_energies = []
theta_energies = []
f1s = [best_f1]
for epoch in range(10):

    # NOTE(review): 4880 is presumably the number of training examples — confirm against len(data_x).
    # Batches are taken in fixed order (no shuffling).
    for j, i in enumerate(range(0, 4880, 32)):
        # When the 32-wide window would run past index 4880, reuse the first batch instead.
        if i+32 > 4880:
            i = 0
        l = np.arange(i, i+32)
        data = data_x[l]
        # print(data.sum())
        label = data_y[l]
        
        # Step 1: update the inference network on its own objective.
        optim_e.zero_grad()
        optim_inf.zero_grad()
        preds, e_loss, inf_loss, summary = spen.compute_loss(data, label)

        inf_loss.backward()

        optim_inf.step()

        
        # Step 2: recompute the losses with the updated inference network, then update C1/c2.
        optim_e.zero_grad()
        optim_inf.zero_grad()
        preds, e_loss, inf_loss, summary = spen.compute_loss(data, label)
        
        e_loss.backward()

        optim_energy.step()

    print(epoch)
    
#     pred_test = spen.pred(test_x)
#     best_f1, mAP = f1_map(test_y, pred_test)
#     print('current inf net', best_f1, mAP)
    
#     with torch.no_grad():
#         feat = feat_net(test_x)
#         _, pred_test = energy_net(feat, test_y)

#     f1, mAP = f1_map(test_y, pred_test, 0.5)
#     print('feature net', f1, mAP)
    print()


inf net start 0.4105516296229916 0.33045646957084107
0

1

2

3

4

5

6

7

8

9



# Phase 3

In [107]:

# Fine-tune only the inference network against the energy, with a small learning rate.
# NOTE(review): this cell reuses `spen`, `feat_net` and `energy_net` left over from a
# previously executed training cell.
optimizer = torch.optim.Adam(spen.inf_net.parameters(), lr=0.00001, weight_decay=0)
# Score of the inference network before fine-tuning.
pred_test = spen.pred(test_x)
best_f1, mAP = f1_map(test_y, pred_test)
print('inf net start', best_f1, mAP)

# NOTE(review): these three lists are initialised but never appended to in this cell.
phi_energies = []
theta_energies = []
f1s = [best_f1]
for epoch in range(10):

    # NOTE(review): 4880 is presumably the number of training examples — confirm against len(data_x).
    for j, i in enumerate(range(0, 4880, 32)):
        # When the 32-wide window would run past index 4880, reuse the first batch instead.
        if i+32 > 4880:
            i = 0
        l = np.arange(i, i+32)
        data = data_x[l]
        # `label` is assigned but not used by the label-free objective below.
        label = data_y[l]
        
        optimizer.zero_grad()
        inf_loss = spen.inference_loss(data)

        inf_loss.backward()

        optimizer.step()

    print(epoch)
    
    # Per-epoch evaluation of the inference network.
    pred_test = spen.pred(test_x)
    best_f1, mAP = f1_map(test_y, pred_test)
    print('current inf net', best_f1, mAP)
    
    # Per-epoch evaluation of the feature network + energy network's probabilities.
    with torch.no_grad():
        feat = feat_net(test_x)
        _, pred_test = energy_net(feat, test_y)

    f1, mAP = f1_map(test_y, pred_test, 0.5)
    print('feature net', f1, mAP)
    print()


inf net start 0.4077304623147065 0.3272108909447864
0
current inf net 0.4161842068681086 0.32669033678144255
feature net 0.37872409762358 0.33045646957084107

1
current inf net 0.418357964342694 0.32651885712400663
feature net 0.37872409762358 0.33045646957084107

2
current inf net 0.4191447379841899 0.3270761622976327
feature net 0.37872409762358 0.33045646957084107

3
current inf net 0.4196247764621689 0.3271073043555753
feature net 0.37872409762358 0.33045646957084107

4
current inf net 0.41921874467094966 0.32723173381353665
feature net 0.37872409762358 0.33045646957084107

5
current inf net 0.4181720907474342 0.32710549376776443
feature net 0.37872409762358 0.33045646957084107

6
current inf net 0.4187331102869359 0.32743135824273834
feature net 0.37872409762358 0.33045646957084107

7
current inf net 0.4178384834864306 0.32773200635445077
feature net 0.37872409762358 0.33045646957084107

8
current inf net 0.41696608064906815 0.32776618015913955
feature net 0.37872409762358 0.33045

# only inf