# Welcome: train one ResNet-18 classifier per attribute and plot its loss and accuracy

In [1]:
# import here
import os
import sys
sys.path.append("../")
from src.pytorch_cl_vae.model import ResNet18Classifier 
from sklearn.metrics import accuracy_score
import torch
from torch.nn import CrossEntropyLoss
from matplotlib import pyplot as plt
%matplotlib inline
# params here
# Experiment configuration. `attributes` and `classes_dim` are paired
# element-wise when the classifiers are built, so keep them the same length
# (one class count per attribute).
params = {
    'attributes': ['Black', '2', '3', '4'],
    'classes_dim': [2, 2, 2, 2],
    'learning_rate': 5e-4,
    'num_epochs': 10,
    'batch_size': 128,
}

## 1 - Load Dataset

In [None]:
# Placeholder: replace with a real data source before running the training
# cell. The training loop iterates it, unpacking each item as
# (x_batch, ws_batch), and reads ws_batch[i] as the target for the i-th
# attribute classifier. Iterating None raises TypeError, so the training cell
# cannot run while this is still None.
# NOTE(review): presumably a torch DataLoader over the attribute dataset — confirm.
data_loader = None

## 2 - Create Classifiers

In [None]:
# Build one independent classifier and Adam optimizer per attribute.
# classifiers[i] / optimizers[i] correspond to params['attributes'][i].
classifiers = []
optimizers = []
for name, classes_dim in zip(params['attributes'], params['classes_dim']):
    # Each classifier gets only its own attribute's class count, not the
    # whole list.
    # NOTE(review): assumes ResNet18Classifier's `label_dim` is the number of
    # output classes for a single attribute — confirm.
    classifier = ResNet18Classifier(label_dim=classes_dim)
    optimizer = torch.optim.Adam(classifier.parameters(), lr=params['learning_rate'])

    classifiers.append(classifier)
    optimizers.append(optimizer)

## 3 - Train classifiers

In [None]:
# Fallback sample count for the progress read-out. It is used only when
# `data_loader` does not support len(); set it to the dataset size in that case.
num_train_samples = 0

# The criterion keeps no state between calls, so build it once.
criterion = CrossEntropyLoss()

# Steps per epoch, used only for the progress percentage (never for training).
try:
    steps_per_epoch = len(data_loader)
except TypeError:
    # The loader has no length (e.g. a plain generator): estimate from samples.
    steps_per_epoch = num_train_samples // params['batch_size']
total_steps = steps_per_epoch * params['num_epochs']

train_step = 0
train_losses = []      # one row per step: per-attribute losses (Python floats)
train_accuracies = []  # one row per step: per-attribute accuracies (Python floats)
for epoch in range(params['num_epochs']):
    for x_batch, ws_batch in data_loader:
        train_step += 1
        losses = []
        accuracies = []
        for label_index, (classifier, optim) in enumerate(zip(classifiers, optimizers)):
            classifier.train()
            optim.zero_grad()
            predictions = classifier(x_batch)
            w_true = ws_batch[label_index]
            loss = criterion(predictions, w_true)
            loss.backward()
            optim.step()
            # .item() yields a detached float. Storing the tensor itself would
            # keep every step's autograd graph alive for the whole run.
            losses.append(loss.item())
            # NOTE(review): targets are read as per-class scores / one-hot
            # (argmax over dim 1), as the original max(1)[1] did — confirm.
            labels = w_true.argmax(1)
            labels_predict = predictions.argmax(1)
            # Compare in torch: this works on any device, and avoids the
            # .squeeze() that turned a batch of one into a 0-d tensor.
            acc = (labels == labels_predict).float().mean().item()
            accuracies.append(acc)

        train_losses.append(losses)
        train_accuracies.append(accuracies)
        loss_str = ', '.join('{}: {:.4f}'.format(attr, value)
                             for attr, value in zip(params['attributes'], losses))
        acc_str = ', '.join('{}: {:.4f}'.format(attr, value)
                            for attr, value in zip(params['attributes'], accuracies))
        # Without a known step count, show n/a instead of dividing by zero.
        if total_steps > 0:
            progress = '{:.2f}%'.format(100. * train_step / total_steps)
        else:
            progress = 'n/a'

        print("\r|progress: {} | train step: {} | losses: {} | accuracies: {} |".format(
            progress, train_step, loss_str, acc_str), end='')
        if train_step % 100 == 0:
            print()
print()  # terminate the last carriage-return progress line

## 4 - Plot results

In [None]:
import numpy as np  # numpy is not imported by the first cell

# Convert the per-step histories to 2-D float arrays (steps x attributes).
# float() also accepts 0-d torch tensors, including ones that require grad, so
# this works whether the training cell stored floats or loss tensors.
train_losses = np.array([[float(v) for v in row] for row in train_losses], dtype=float)
train_accuracies = np.array([[float(v) for v in row] for row in train_accuracies], dtype=float)

if train_losses.ndim != 2 or train_losses.shape[0] == 0:
    # Nothing was recorded (e.g. the loader yielded no batches); the column
    # indexing below would fail on an empty/1-D array.
    print('No training history to plot.')
else:
    # One column per trained classifier, in params['attributes'] order.
    for i, attr in enumerate(params['attributes'][:train_losses.shape[1]]):
        plt.plot(train_losses[:, i])
        plt.title(attr)
        plt.xlabel('train step')
        plt.ylabel('loss')
        plt.show()

        plt.plot(train_accuracies[:, i])
        plt.title(attr)
        plt.xlabel('train step')
        plt.ylabel('accuracy')
        plt.show()

## 5 - Save models