# Test disentanglement

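Test whether a given type of competence can be inferred using only the corresponding part of the latent space. We generate groups of candidate descriptions that share one fixed factor (all other factors are resampled from random rows), encode them with the trained VAE, and train a linear classifier to predict which factor was held fixed from the mean absolute difference between paired latent codes (in the spirit of the β-VAE disentanglement metric).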

In [1]:
# Auto-format code cells with black (JupyterLab extension shipped with nb_black).
%load_ext lab_black

### Imports

In [2]:
import os
import yaml
import random
from copy import deepcopy

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import pandas as pd
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score

from model.encoder import CandidateEncoderConfig
from model.decoder import CandidateDecoderConfig
from model.candidate_vae import CandidateVAE
from trainer.trainer import TrainerConfig
from config.general_config import GeneralConfig
from dataset.dataset import SellersDataset

In [3]:
# --- Which trained VAE run / checkpoint to evaluate ---
EXPERIMENT = "candidate_vae_04_06_22_01_46_39"
CHECKPOINT = "7506_checkpoint.tar"

# If False, cached content can be used instead (e.g. while testing the code).
# NOTE(review): not referenced by any code cell in this notebook.
CREATE_DATASET = True

# --- Linear-probe training hyper-parameters ---
LR = 0.001  # Adam learning rate
PATIENCE = 10  # evaluations without improvement before early stopping
VALIDATE_EVERY = 50  # evaluate every N training iterations
LABEL_BATCH_SIZE = 16  # latent pairs averaged into one probe example

### Constants and configuration

In [4]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load the hyper-parameters the VAE was trained with, so the encoder/decoder below
# are rebuilt with the same configuration as the checkpoint.
config_path = os.path.join("checkpoints", EXPERIMENT, "config.yaml")
with open(config_path, "r") as file:
    try:
        config = yaml.safe_load(file)["vae"]
    except yaml.YAMLError as exc:
        # Fail loudly: without a parsed config none of the following cells can run.
        raise ValueError(f"Malformed experiment config: {config_path}") from exc

general_config = GeneralConfig(**config["general"])

# Each component config is its own section merged with the shared "general"
# section; on key conflicts the "general" value wins ({**a, **b} keeps b's value).
encoder_config = CandidateEncoderConfig(**{**config["encoder"], **config["general"]})

decoder_config = CandidateDecoderConfig(**{**config["decoder"], **config["general"]})

trainer_config = TrainerConfig(**{**config["trainer"], **config["general"]})

# TensorBoard logs for the probe training go to <checkpoints_dir>/runs.
log_dir = os.path.join(general_config.checkpoints_dir, "runs")

os.makedirs(log_dir, exist_ok=True)

writer_tensorboard = SummaryWriter(log_dir)

In [5]:
# %reload_ext tensorboard
# %tensorboard --logdir $log_dir --port=6008

In [6]:
# Rebuild the sellers dataset with the same preprocessing options the VAE used.
# `datset_path` (sic) is the attribute name as it exists on GeneralConfig.
# NOTE(review): CREATE_DATASET is not consulted; the cached dataset is always loaded
# and `dataset.prepare_dataset()` is not called (presumably it rebuilds from raw data).
dataset_kwargs = dict(
    dataset_path=general_config.datset_path,
    test_index=general_config.test_index,
    embedder_name=general_config.embedder_name,
    raw_data_path=general_config.raw_data_path,
    device=DEVICE,
    bow_remove_stopwords=general_config.bow_remove_stopwords,
    bow_remove_sentiment=general_config.bow_remove_sentiment,
    nn_embedding_size=encoder_config.lstm_hidden_dim,
    trim_tr=general_config.trim_tr,
)
dataset = SellersDataset(**dataset_kwargs)
dataset.load_dataset()

Loading dataset data/dataset_3/...
[2022-06-04 17:33:40,260] {dataset.py:260} INFO - Loading dataset data/dataset_3/...
Loaded dataset data/dataset_3/!
[2022-06-04 17:33:40,777] {dataset.py:288} INFO - Loaded dataset data/dataset_3/!




# Load VAE

In [7]:
# Restore the trained VAE. `map_location` lets a checkpoint saved on GPU be
# loaded on a CPU-only machine (DEVICE falls back to "cpu").
checkpoint = torch.load(
    os.path.join("checkpoints", EXPERIMENT, CHECKPOINT), map_location=DEVICE
)

candidate_vae = CandidateVAE(
    general_config, encoder_config, decoder_config, dataset.vocab, dataset.embedder
).to(DEVICE)

candidate_vae.encoder.load_state_dict(checkpoint["encoder"])
candidate_vae.decoder.load_state_dict(checkpoint["decoder"])
# The embedding state is optional: restore it only when the checkpoint has one.
if checkpoint.get("embedding"):
    candidate_vae.embedding.load_state_dict(checkpoint["embedding"])

# The VAE is used only as a frozen feature extractor below; eval mode disables the
# encoder's Dropout so latent codes are not randomly perturbed.
# NOTE(review): if `decoder.reparameterize` depends on training mode, confirm the
# eval-mode sampling behaviour is the one intended for the probe.
candidate_vae.eval();

In [8]:
# Display the module tree (layer types and sizes) of the restored VAE.
candidate_vae

CandidateVAE(
  (encoder): CandidateEncoder(
    (lstm): LSTM(100, 256, bidirectional=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (relu): ReLU()
    (fcs): ModuleList(
      (0): Linear(in_features=512, out_features=256, bias=True)
    )
    (fc_mu): Linear(in_features=256, out_features=256, bias=True)
    (fc_var): Linear(in_features=256, out_features=256, bias=True)
  )
  (decoder): CandidateDecoder(
    (lstm): LSTM(100, 256)
    (dropout): Dropout(p=0.1, inplace=False)
    (relu): ReLU()
    (fcs): ModuleList(
      (0): Linear(in_features=256, out_features=256, bias=True)
      (1): Linear(in_features=256, out_features=100, bias=True)
    )
    (attn): Attn()
    (attn_mu): Identity()
    (attn_var): Sequential(
      (0): Linear(in_features=256, out_features=256, bias=True)
      (1): Tanh()
      (2): Linear(in_features=256, out_features=256, bias=True)
    )
    (concat): Linear(in_features=512, out_features=256, bias=True)
    (out): Linear(in_features=256, out_feat

# Prepare train / test data

In [9]:
def _encode_mean_latent_difference(
    dataset: SellersDataset,
    fixed_row,
    fixed_factor_label: str,
    pool_size: int,
    batch_size: int,
) -> torch.Tensor:
    """Encode ``2 * batch_size`` variants of ``fixed_row`` and average their pairwise latent gaps.

    Every variant keeps ``fixed_factor_label`` from ``fixed_row``. Each other factor
    column in ``dataset.string_keys`` is replaced with the value from a random row
    index in ``[0, pool_size)``. The variants are embedded and encoded with the
    VAE encoder, and a latent sample is drawn via ``decoder.reparameterize``.
    Adjacent samples are paired, and the mean absolute difference over the
    ``batch_size`` pairs is returned.

    Relies on the notebook globals ``candidate_vae`` and ``DEVICE``.
    """
    # Loop-invariant: every factor except the one held fixed gets resampled.
    varied_columns = [
        factor for factor in dataset.string_keys if factor != fixed_factor_label
    ]

    samples = []
    for _ in range(batch_size * 2):
        variant = deepcopy(fixed_row)
        for column in varied_columns:
            source = random.randint(0, pool_size - 1)
            variant[column] = dataset.get_column_by_idx(source, column)

        variant_text = dataset._create_textual_decription(variant)
        variant_emb = dataset.embedder(variant_text)[0].cpu()
        samples.append((variant_emb, torch.tensor(len(variant_emb))))

    # Sort by sequence length, longest first.
    # Presumably the encoder packs the padded batch and requires this order; TODO confirm.
    samples.sort(key=lambda sample: sample[1], reverse=True)
    input_tensors, input_lens = zip(*samples)
    input_pad = pad_sequence(input_tensors, padding_value=dataset.vocab.pad_token)

    with torch.no_grad():
        mu, var, _outputs, (_hn, _cn) = candidate_vae.encoder(
            input_pad.to(DEVICE), input_lens
        )
        z = candidate_vae.decoder.reparameterize(mu, var)
        latent_dim = z.shape[1]
        # Assuming z is (2 * batch_size, latent_dim): the view puts rows 2i and 2i+1
        # side by side, so each row of `pairs` holds one pair of latent codes.
        pairs = z.view(batch_size, -1)
        diffs = torch.abs(pairs[:, :latent_dim] - pairs[:, latent_dim:])

    return diffs.mean(dim=0)


def prepare_train_batch(
    dataset: SellersDataset, batch_size: int
) -> tuple[torch.Tensor, int]:
    """Draw one probe example from the training split.

    Picks a random training row and a random factor to hold fixed, then resamples
    the other factors from random training rows.

    Returns:
        ``(mean absolute latent difference, index of the fixed factor)``.
    """
    factor_labels = dataset.string_keys
    source_idx = random.randint(0, len(dataset) - 1)
    fixed_factor_idx = random.randint(0, len(factor_labels) - 1)
    fixed_row = dataset._load_row(source_idx)

    z_diff = _encode_mean_latent_difference(
        dataset,
        fixed_row,
        factor_labels[fixed_factor_idx],
        pool_size=len(dataset),
        batch_size=batch_size,
    )
    return z_diff, fixed_factor_idx


def prepare_test_batch(
    dataset: SellersDataset, batch_size: int
) -> tuple[torch.Tensor, int]:
    """Draw one probe example whose fixed row comes from the test split.

    For reproducible test batches, seed Python's ``random`` module before sampling
    (e.g. ``random.seed(42)``); this function does not seed it itself.

    NOTE(review): other factors are fetched with ``dataset.get_column_by_idx`` using
    indices in ``[0, dataset.test_size)``. Confirm that this reads from the test
    split; otherwise, test variants mix in training rows.

    Returns:
        ``(mean absolute latent difference, index of the fixed factor)``.
    """
    factor_labels = dataset.string_keys
    source_idx = random.randint(0, dataset.test_size - 1)
    fixed_factor_idx = random.randint(0, len(factor_labels) - 1)
    fixed_row = dataset.test_dataset.iloc[source_idx]

    z_diff = _encode_mean_latent_difference(
        dataset,
        fixed_row,
        factor_labels[fixed_factor_idx],
        pool_size=dataset.test_size,
        batch_size=batch_size,
    )
    return z_diff, fixed_factor_idx

In [10]:
class AdversarialDataset(torch.utils.data.Dataset):
    """Map-style dataset that synthesises a fresh probe example on every access.

    The index is ignored. Each ``__getitem__`` call draws a new random
    ``(mean latent difference, fixed factor index)`` example, built by
    ``prepare_train_batch`` or ``prepare_test_batch`` depending on ``_type``.

    Raises:
        ValueError: if ``_type`` is neither ``"train"`` nor ``"test"``.
    """

    def __init__(self, sellers_dataset: SellersDataset, _type: str = "train"):
        # Reject typos explicitly instead of silently treating them as "test".
        if _type not in ("train", "test"):
            raise ValueError(f"_type must be 'train' or 'test', got {_type!r}")
        self._type = _type
        self.sellers_dataset = sellers_dataset

    def __getitem__(self, idx: int):
        # `idx` is intentionally unused: every access draws a new random example.
        if self._type == "train":
            return prepare_train_batch(self.sellers_dataset, LABEL_BATCH_SIZE)
        return prepare_test_batch(self.sellers_dataset, LABEL_BATCH_SIZE)

    def __len__(self):
        # DataLoader needs a finite length; this is effectively an endless stream
        # that the training loop ends through early stopping.
        return int(1e11)


train_dataset = AdversarialDataset(dataset, "train")
test_dataset = AdversarialDataset(dataset, "test")

In [11]:
# Both loaders stack 64 synthetic probe examples per batch. Every example is
# freshly sampled on access, so the default (unshuffled) order is fine.
PROBE_BATCH_SIZE = 64
train_dataloader = DataLoader(train_dataset, batch_size=PROBE_BATCH_SIZE)
test_dataloader = DataLoader(test_dataset, batch_size=PROBE_BATCH_SIZE)

# Train / test methods

In [12]:
def calc_accuracy(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    preds = torch.argmax(y_pred, dim=1)
    return recall_score(
        y_true.cpu(), preds.cpu(), labels=list(range(0, 4)), average="micro"
    )


def calc_fscore(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    preds = torch.argmax(y_pred, dim=1)
    return f1_score(
        y_true.cpu(), preds.cpu(), average="micro", labels=list(range(0, 4))
    )


def calc_precission(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    preds = torch.argmax(y_pred, dim=1)
    return precision_score(
        y_true.cpu(), preds.cpu(), average="micro", labels=list(range(0, 4))
    )


0


def calc_recall(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    preds = torch.argmax(y_pred, dim=1)
    return recall_score(
        y_true.cpu(), preds.cpu(), average="micro", labels=list(range(0, 4))
    )


def validate(
    model: nn.Module,
    loss_fn: torch.nn.CrossEntropyLoss,
    dataloader: DataLoader,
    max_batches: int = 10,
) -> tuple[float, float, float, float, float]:
    """Average loss and classification metrics of ``model`` over a few batches.

    Args:
        model: classifier mapping a batch of inputs to class logits.
        loss_fn: loss applied to ``(logits, targets)``.
        dataloader: yields ``(X_batch, y_batch)`` pairs. Only the first
            ``max_batches`` are consumed, because the probe loaders are effectively endless.
        max_batches: number of batches to average over.

    Returns:
        ``(loss, accuracy, f1, precision, recall)``, each averaged over batches.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    from itertools import islice

    metric_fns = (calc_accuracy, calc_fscore, calc_precission, calc_recall)
    totals = [0.0] * (1 + len(metric_fns))
    n_batches = 0

    with torch.no_grad():
        # islice stops after `max_batches` without fetching (and synthesising)
        # an extra batch.
        for X_batch, y_batch in islice(dataloader, max_batches):
            y_pred = model(X_batch)
            # Put targets on the same device as the logits (works on CPU and GPU).
            y_batch = y_batch.to(y_pred.device)

            totals[0] += loss_fn(y_pred, y_batch).item()
            for i, metric_fn in enumerate(metric_fns, start=1):
                totals[i] += float(metric_fn(y_pred, y_batch))
            n_batches += 1

    if n_batches == 0:
        raise ValueError("validate(): dataloader yielded no batches")

    loss, accuracy, f_score, precision, recall = (t / n_batches for t in totals)
    return loss, accuracy, f_score, precision, recall


def fit(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    loss_fn: nn.CrossEntropyLoss,
    train_dl: DataLoader,
    val_dl: DataLoader,
    writer: SummaryWriter,
    validate_every: int = 100,
    print_metrics: bool = True,
    patience: int = 5,
    run_prefix: str = "early_stopping",
    checkpoint_dir: str = "tests/checkpoints/disentanglement",
    use_val_dl: bool = False,
) -> dict[str, dict[str, list]]:
    """Train ``model`` on ``train_dl`` with periodic evaluation and early stopping.

    Every ``validate_every`` steps the model is evaluated with ``validate`` on a few
    batches of ``train_dl``, and also of ``val_dl`` when ``use_val_dl`` is True.
    Metrics are logged to TensorBoard under ``<run_prefix>/<metric>``. The monitored
    loss is the validation loss if ``use_val_dl`` is True, else the training loss.
    Whenever it improves, the model and optimizer state are saved to
    ``<checkpoint_dir>/best_<run_prefix>``. Training stops after ``patience``
    consecutive evaluations without improvement, or when ``train_dl`` is exhausted.

    Returns:
        History ``{metric: {"train": [...], "val": [...]}}`` for the metrics
        ``loss``, ``acc``, ``f1``, ``precision`` and ``recall``, with one entry per
        evaluation. The ``"val"`` lists stay empty unless ``use_val_dl`` is True.
    """
    metric_names = ("loss", "acc", "f1", "precision", "recall")
    history = {name: {"train": [], "val": []} for name in metric_names}

    device = next(model.parameters()).device
    checkpoint_path = os.path.join(checkpoint_dir, f"best_{run_prefix}")

    min_loss = float("inf")
    current_patience = 0
    for step, (X_batch, y_batch) in enumerate(tqdm(train_dl), start=1):
        model.train()
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)

        y_pred = model(X_batch)  # class logits for the minibatch
        loss = loss_fn(y_pred, y_batch)

        optimizer.zero_grad()  # clear stale gradients before backprop
        loss.backward()  # compute gradients of the loss w.r.t. the parameters
        optimizer.step()  # update parameters from the stored gradients and lr

        if step % validate_every != 0:
            continue

        # Eval mode matters for layers such as Dropout or BatchNorm.
        model.eval()
        with torch.no_grad():  # no gradient tracking is needed for evaluation
            train_metrics = dict(zip(metric_names, validate(model, loss_fn, train_dl)))
            val_metrics = (
                dict(zip(metric_names, validate(model, loss_fn, val_dl)))
                if use_val_dl
                else None
            )

        monitored_loss = val_metrics["loss"] if val_metrics else train_metrics["loss"]
        if monitored_loss < min_loss:
            min_loss = monitored_loss
            current_patience = 0
            os.makedirs(checkpoint_dir, exist_ok=True)
            torch.save(
                obj={
                    "iter": step,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                },
                f=checkpoint_path,
            )
        else:
            current_patience += 1

        for name in metric_names:
            scalars = {"train": train_metrics[name]}
            history[name]["train"].append(train_metrics[name])
            if val_metrics:
                scalars["val"] = val_metrics[name]
                history[name]["val"].append(val_metrics[name])
            writer.add_scalars(
                main_tag=f"{run_prefix}/{name}",
                tag_scalar_dict=scalars,
                global_step=step,
            )

        if print_metrics:
            message = f"Iter {step}: train loss = {train_metrics['loss']:.3f}, "
            if val_metrics:
                message += f"validation loss = {val_metrics['loss']:.3f}"
            print(message)

        if current_patience >= patience:
            break

    # Leave the model in eval mode for downstream inference.
    model.eval()
    return history

# Define the network to be trained

In [13]:
crossentropy_loss = nn.CrossEntropyLoss()

# Linear probe that predicts which factor was held fixed. It needs one output per
# entry of `dataset.string_keys`, because `fixed_factor_idx` is drawn from that list.
n_factors = len(dataset.string_keys)
classifier = nn.Linear(general_config.latent_dim, n_factors).to(DEVICE)
optimizer = torch.optim.Adam(classifier.parameters(), lr=LR)

In [14]:
# Run label used for the TensorBoard tags and the best-checkpoint file name.
probe_run_prefix = f"disentanglement_{EXPERIMENT}_{CHECKPOINT.replace('.tar', '')}"

# Train the linear probe; early stopping is controlled by PATIENCE / VALIDATE_EVERY.
fit(
    classifier,
    optimizer,
    crossentropy_loss,
    train_dataloader,
    test_dataloader,
    writer_tensorboard,
    validate_every=VALIDATE_EVERY,
    print_metrics=False,
    patience=PATIENCE,
    run_prefix=probe_run_prefix,
)

  0%|          | 400/1562500000 [34:57<2276224:17:03,  5.24s/it]


KeyboardInterrupt: 