In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules (e.g. the local C2AE module) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [9]:
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from C2AE import C2AE, save_model, load_model, Fe, Fx, Fd, eval_metrics

from sklearn.model_selection import train_test_split
from sklearn.metrics import hamming_loss, accuracy_score, f1_score, precision_score, recall_score
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

from skmultilearn.dataset import load_dataset

In [10]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing on the first tensor allocation.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _to_tensor_dataset(features, labels):
    """Build a TensorDataset of float tensors on `device` from two sparse matrices.

    Both arguments must expose `.todense()` (as the matrices returned by
    `load_dataset` do). The first tensor of the dataset holds the features,
    the second holds the labels.
    """
    return TensorDataset(
        torch.tensor(features.todense(), device=device, dtype=torch.float),
        torch.tensor(labels.todense(), device=device, dtype=torch.float),
    )


# Load the predefined train / test partitions of the mediamill dataset.
train_x, train_y, feat_names, label_names = load_dataset('mediamill', 'train')
test_x, test_y, _, _ = load_dataset('mediamill', 'test')

# The whole dataset is densified and moved to `device` up front.
train_dataset = _to_tensor_dataset(train_x, train_y)
test_dataset = _to_tensor_dataset(test_x, test_y)

mediamill:train - exists, not redownloading
mediamill:test - exists, not redownloading


In [11]:
# Shapes in the order: train features, train labels, test features, test labels.
tuple(part.shape for split in (train_dataset, test_dataset) for part in split[:])

(torch.Size([30993, 120]),
 torch.Size([30993, 101]),
 torch.Size([12914, 120]),
 torch.Size([12914, 101]))

### Metrics:

In [12]:
def micro_r(y_t, y_p):
    """Micro-averaged recall of predictions `y_p` against ground truth `y_t`."""
    score = recall_score(y_true=y_t, y_pred=y_p, average='micro')
    return score
def macro_r(y_t, y_p):
    """Macro-averaged recall of predictions `y_p` against ground truth `y_t`."""
    score = recall_score(y_true=y_t, y_pred=y_p, average='macro')
    return score
def micro_p(y_t, y_p):
    """Micro-averaged precision of predictions `y_p` against ground truth `y_t`."""
    score = precision_score(y_true=y_t, y_pred=y_p, average='micro')
    return score
def macro_p(y_t, y_p):
    """Macro-averaged precision of predictions `y_p` against ground truth `y_t`."""
    score = precision_score(y_true=y_t, y_pred=y_p, average='macro')
    return score
def micro_f1(y_t, y_p):
    """Micro-averaged F1 score of predictions `y_p` against ground truth `y_t`."""
    score = f1_score(y_true=y_t, y_pred=y_p, average='micro')
    return score
def macro_f1(y_t, y_p):
    """Macro-averaged F1 score of predictions `y_p` against ground truth `y_t`."""
    score = f1_score(y_true=y_t, y_pred=y_p, average='macro')
    return score
def ham_los(*args, **kwargs):
    """Forward every argument unchanged to `hamming_loss` and return its result.

    NOTE(review): presumably exists only to give the metric a short name in
    the `eval_metrics` output -- confirm against `eval_metrics`.
    """
    result = hamming_loss(*args, **kwargs)
    return result

In [13]:
# Training configs.
num_epochs = 1000
batch_size = 32
lr = 0.001
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Evaluation only accumulates per-batch losses, so sample order is irrelevant;
# keep the evaluation loader deterministic instead of shuffling it.
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Mediamill config. feat_dim / num_labels match the dataset shapes shown above
# (120 features, 101 labels); the hidden and latent sizes are hyper-parameters.
feat_dim = 120
num_labels = 101
latent_dim = 50
fx_h_dim = 110
fe_h_dim = 105
fd_h_dim = 105

# Mediamill sub-networks: Fx is applied to features, Fe to labels, and Fd maps
# a latent code back to label space (sigmoid as the final activation).
Fx_tmc = Fx(feat_dim, fx_h_dim, fx_h_dim, latent_dim)
Fe_tmc = Fe(num_labels, fe_h_dim, latent_dim)
Fd_tmc = Fd(latent_dim, fd_h_dim, num_labels, fin_act=torch.sigmoid)

# Initializing net. beta / alpha weight the latent and correlation losses in
# the training loop (loss = beta * l_loss + alpha * c_loss).
net = C2AE(Fx_tmc, Fe_tmc, Fd_tmc, beta=0.01, alpha=40, emb_lambda=0.01, latent_dim=latent_dim, device=device)
net = net.to(device)

# NOTE: no `weight_decay` is passed to Adam, so the optimizer itself applies no
# L2 penalty. Passing `weight_decay=...` here would be equivalent to adding an
# L2 term on the parameters to the loss.
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
writer = SummaryWriter(comment='mediamill-4c2ae')

In [14]:
print("Starting training!")
# `np.infty` was removed in NumPy 2.0; `np.inf` is the portable spelling.
# NOTE(review): best_loss is never updated in this cell; model selection is
# done manually from the saved checkpoints further down.
best_loss = np.inf
for epoch in range(num_epochs+1):
    # ---- Training ----
    net.train()
    loss_tracker = 0.0
    latent_loss_tracker = 0.0
    cor_loss_tracker = 0.0
    for x, y in train_dataloader:
        optimizer.zero_grad()

        # Pass x, y to network. Retrieve both encodings, and decoding of ys encoding.
        fx_x, fe_y, fd_z = net(x, y)
        # Calc loss.
        l_loss, c_loss = net.losses(fx_x, fe_y, fd_z, y)
        # Normalize losses by batch.
        l_loss /= x.shape[0]
        c_loss /= x.shape[0]
        loss = net.beta*l_loss + net.alpha*c_loss
        loss.backward()
        optimizer.step()

        # Trackers hold the SUM of per-batch values over the epoch (not a mean),
        # so train and val curves are on different scales (different batch counts).
        loss_tracker += loss.item()
        latent_loss_tracker += l_loss.item()
        cor_loss_tracker += c_loss.item()
    writer.add_scalar('train/loss', loss_tracker, epoch)
    writer.add_scalar('train/latent_loss', latent_loss_tracker, epoch)
    writer.add_scalar('train/corr_loss', cor_loss_tracker, epoch)

    # ---- Evaluation ----
    net.eval()
    loss_tracker = 0.0
    latent_loss_tracker = 0.0
    cor_loss_tracker = 0.0
    # No gradients are needed here: skip building the autograd graph.
    with torch.no_grad():
        for x, y in test_dataloader:
            # Prediction is Fd(Fx(x)); Fe(y) is only computed so that the
            # latent loss between both encodings can be reported.
            fx_x, fe_y = net.Fx(x), net.Fe(y)
            fd_z = net.Fd(fx_x)

            l_loss, c_loss = net.losses(fx_x, fe_y, fd_z, y)
            # Normalize losses by batch.
            l_loss /= x.shape[0]
            c_loss /= x.shape[0]
            loss = net.beta*l_loss + net.alpha*c_loss

            latent_loss_tracker += l_loss.item()
            cor_loss_tracker += c_loss.item()
            loss_tracker += loss.item()

    print(f"Epoch: {epoch}, Loss: {loss_tracker},  L-Loss: {latent_loss_tracker}, C-Loss: {cor_loss_tracker}")
    # One checkpoint per epoch. The file name is the literal prefix 'v4_2'
    # immediately followed by the epoch number (e.g. epoch 950 -> 'v4_2950.pt').
    # The target directory must already exist.
    torch.save(net.state_dict(), f'./models/mediamill/c2ae/v4_2{epoch}.pt')
    writer.add_scalar('val/loss', loss_tracker, epoch)
    writer.add_scalar('val/latent_loss', latent_loss_tracker, epoch)
    writer.add_scalar('val/corr_loss', cor_loss_tracker, epoch)

    # Log metrics on whole dataset. Datasets are passed as [test, train], so
    # 'dataset_0' is the test split and 'dataset_1' is the train split.
    # Use the notebook's `device` rather than a second hard-coded CUDA device.
    mets = eval_metrics(net, [ham_los, accuracy_score, micro_f1, micro_p, micro_r, macro_f1, macro_p, macro_r],
                        [test_dataset, train_dataset], device)
    for k, v in mets['dataset_1'].items():
        writer.add_scalar(f'train/{k}', v, epoch)
    for k, v in mets['dataset_0'].items():
        writer.add_scalar(f'val/{k}', v, epoch)

Starting training!
Epoch: 0, Loss: 9605.383268356323,  L-Loss: 3741.042100429535, C-Loss: 239.19932079315186


  _warn_prf(average, modifier, msg_start, len(result))


Epoch: 1, Loss: 9789.852140426636,  L-Loss: 2743.5333018302917, C-Loss: 244.06042009592056
Epoch: 2, Loss: 10016.533346176147,  L-Loss: 2283.5240881443024, C-Loss: 249.84245246648788
Epoch: 3, Loss: 9950.966569900513,  L-Loss: 1715.6772162914276, C-Loss: 248.34524548053741
Epoch: 4, Loss: 9896.89337348938,  L-Loss: 1516.278756260872, C-Loss: 247.04326450824738
Epoch: 5, Loss: 9978.80080986023,  L-Loss: 1355.9878435134888, C-Loss: 249.131023645401
Epoch: 6, Loss: 9980.269666671753,  L-Loss: 1210.139478445053, C-Loss: 249.20420688390732
Epoch: 7, Loss: 9951.54649925232,  L-Loss: 1157.1998167037964, C-Loss: 248.49936228990555
Epoch: 8, Loss: 10125.19736289978,  L-Loss: 936.8685793876648, C-Loss: 252.8957164287567
Epoch: 9, Loss: 9956.332695007324,  L-Loss: 807.5172189474106, C-Loss: 248.70643770694733
Epoch: 10, Loss: 9878.29787826538,  L-Loss: 771.1928068399429, C-Loss: 246.76464849710464
Epoch: 11, Loss: 10346.298871994019,  L-Loss: 721.2692561149597, C-Loss: 258.47715455293655
Epoch: 1

### Picking best model

In [21]:
# Restore the checkpoint written by the training loop for epoch 950
# (file name = 'v4_2' prefix + epoch number).
# NOTE(review): this reuses the same Fx_tmc / Fe_tmc / Fd_tmc instances that
# `net` was built from; if load_model loads weights into them in place, `net`
# is overwritten as well -- confirm against load_model.
eval_net = load_model(
    C2AE,
    './models/mediamill/c2ae/v4_2950.pt',
    Fx=Fx_tmc,
    Fe=Fe_tmc,
    Fd=Fd_tmc,
    device=device,
)
eval_net = eval_net.to(device)

In [22]:
# Score the loaded model on both splits. Datasets are passed as [test, train],
# so in the result 'dataset_0' is the test split and 'dataset_1' the train split.
mets = eval_metrics(
    eval_net,
    [ham_los, accuracy_score, micro_f1, micro_p, micro_r, macro_f1, macro_p, macro_r],
    [test_dataset, train_dataset],
    device,
)
mets

{'dataset_0': {'ham_los': 0.03430385628000619,
  'accuracy_score': 0.07573176397707913,
  'micro_f1': 0.5826874469533748,
  'micro_p': 0.620816440098577,
  'micro_r': 0.5489710198414791,
  'macro_f1': 0.1278815755808191,
  'macro_p': 0.21369858766535993,
  'macro_r': 0.11195772712532844},
 'dataset_1': {'ham_los': 0.030500020285640993,
  'accuracy_score': 0.09515051785887135,
  'micro_f1': 0.6269701729298044,
  'micro_p': 0.6646013667425968,
  'micro_r': 0.5933721351605198,
  'macro_f1': 0.1810227401352932,
  'macro_p': 0.3406434281722783,
  'macro_r': 0.15087116899792413}}

In [19]:
# Restore the checkpoint written by the training loop for epoch 430
# (file name = 'v4_2' prefix + epoch number).
# NOTE(review): this reuses the same Fx_tmc / Fe_tmc / Fd_tmc instances that
# `net` was built from; if load_model loads weights into them in place, `net`
# is overwritten as well -- confirm against load_model.
eval_net = load_model(
    C2AE,
    './models/mediamill/c2ae/v4_2430.pt',
    Fx=Fx_tmc,
    Fe=Fe_tmc,
    Fd=Fd_tmc,
    device=device,
)
eval_net = eval_net.to(device)

In [20]:
# Score the loaded model on both splits. Datasets are passed as [test, train],
# so in the result 'dataset_0' is the test split and 'dataset_1' the train split.
mets = eval_metrics(
    eval_net,
    [ham_los, accuracy_score, micro_f1, micro_p, micro_r, macro_f1, macro_p, macro_r],
    [test_dataset, train_dataset],
    device,
)
mets

{'dataset_0': {'ham_los': 0.031299978379439305,
  'accuracy_score': 0.08432708688245315,
  'micro_f1': 0.5695608624598029,
  'micro_p': 0.7118384988403964,
  'micro_r': 0.47468410045517656,
  'macro_f1': 0.10003849766216992,
  'macro_p': 0.19545522979837945,
  'macro_r': 0.07995643575974129},
 'dataset_1': {'ham_los': 0.029350926574605,
  'accuracy_score': 0.08508372858387378,
  'micro_f1': 0.5964554738113539,
  'micro_p': 0.7343766899565208,
  'micro_r': 0.5021483985001886,
  'macro_f1': 0.14216146757861564,
  'macro_p': 0.3395320937990518,
  'macro_r': 0.10828125431332984}}