# Test capacity control

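Test whether each type of competence (skills, education, languages) can be inferred using only its dedicated part of the latent space, and, conversely, whether it can *not* be inferred from the remaining part.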

In [1]:
# Load the `lab_black` IPython extension (presumably auto-formats code cells with black).
%load_ext lab_black

### Imports

In [1]:
import os
import pickle
import random
from collections.abc import Iterable

import numpy as np
import pandas as pd
import torch
import yaml
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from model.encoder import CandidateEncoderConfig
from model.decoder import CandidateDecoderConfig
from model.candidate_vae import CandidateVAE
from trainer.trainer import TrainerConfig
from config.general_config import GeneralConfig
from dataset.dataset import SellersDataset

In [2]:
# Which trained run to analyse: a directory under checkpoints/ and a checkpoint file inside it.
EXPERIMENT: str = "candidate_vae_02_06_22_04_13_30"
CHECKPOINT: str = "7506_checkpoint.tar"

# When False, the cached latents are reused instead of recomputed (e.g. while testing the code).
CREATE_DATASET: bool = True

### Constants

In [3]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

config_path = os.path.join("checkpoints", EXPERIMENT, "config.yaml")
with open(config_path, "r", encoding="utf-8") as file:
    try:
        config = yaml.safe_load(file)["vae"]
    except yaml.YAMLError as exc:
        # Fail loudly. Printing and carrying on would leave `config` undefined,
        # or in a live kernel silently stale from a previous run.
        raise RuntimeError(f"Could not parse experiment config {config_path}") from exc

general_config = GeneralConfig(**config["general"])

# Each sub-config also receives the "general" keys; on key collisions the
# "general" value wins because it is unpacked last.
encoder_config = CandidateEncoderConfig(**{**config["encoder"], **config["general"]})

decoder_config = CandidateDecoderConfig(**{**config["decoder"], **config["general"]})

trainer_config = TrainerConfig(**{**config["trainer"], **config["general"]})

# TensorBoard runs directory used by the training cells below.
log_dir = os.path.join(general_config.checkpoints_dir, "runs")

os.makedirs(log_dir, exist_ok=True)

writer_tensorboard = SummaryWriter(log_dir)

In [4]:
# Optional: uncomment to open TensorBoard inline on `log_dir`, the runs
# directory that `writer_tensorboard` (and therefore the training cells) writes to.
# %reload_ext tensorboard
# %tensorboard --logdir $log_dir --port=6008

In [5]:
# Build the dataset from the experiment's saved configuration and load the
# prepared data from disk. If that data does not exist yet,
# `dataset.prepare_dataset()` presumably has to be run first.
dataset = SellersDataset(
    # data locations and split
    raw_data_path=general_config.raw_data_path,
    dataset_path=general_config.datset_path,
    test_index=general_config.test_index,
    # text representation options
    embedder_name=general_config.embedder_name,
    nn_embedding_size=encoder_config.lstm_hidden_dim,
    bow_remove_stopwords=general_config.bow_remove_stopwords,
    bow_remove_sentiment=general_config.bow_remove_sentiment,
    trim_tr=general_config.trim_tr,
    device=DEVICE,
)
dataset.load_dataset()

Loading dataset data/dataset_2/...
[2022-06-04 11:17:23,429] {dataset.py:260} INFO - Loading dataset data/dataset_2/...




Loaded dataset data/dataset_2/!
[2022-06-04 11:17:24,209] {dataset.py:288} INFO - Loaded dataset data/dataset_2/!


# Latent space configuration

In [6]:
def _layout_latent_segments(
    segments: list[tuple[str, int, int]], total_latent_dim: int
) -> dict:
    """Assign each target a contiguous slice of the latent vector.

    Args:
        segments: ``(name, latent_dim, output_dim)`` tuples, in the order the
            slices appear in the latent vector.
        total_latent_dim: size of the full latent vector.

    Returns:
        ``{name: {"latent_dim", "output_dim", "indexes": (start, end)}}``.
        ``end`` is exclusive, matching ``z[:, start:end]`` slicing.

    Raises:
        ValueError: if the segments together exceed ``total_latent_dim``.
    """
    targets = {}
    offset = 0
    for name, latent_dim, output_dim in segments:
        targets[name] = {
            "latent_dim": latent_dim,
            "output_dim": output_dim,
            "indexes": (offset, offset + latent_dim),
        }
        # Offsets accumulate, so slices can never overlap or leave gaps.
        offset += latent_dim
    if offset > total_latent_dim:
        raise ValueError(
            f"Latent segments need {offset} dims but the latent vector has only "
            f"{total_latent_dim}"
        )
    return targets


# Latent vector layout: skills | education | languages. The adversarial probes
# use the remaining `latent_dim - segment` dims as their input.
disentangled_targets = _layout_latent_segments(
    [
        ("skills", trainer_config.skills_dim, dataset.bow_vocab.n_words),
        ("education", trainer_config.education_dim, dataset.bow_vocab.n_words),
        (
            "languages",
            trainer_config.languages_dim,
            len(dataset.langs_map) * dataset.num_lang_levels,
        ),
    ],
    total_latent_dim=general_config.latent_dim,
)

# Load VAE

In [7]:
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
checkpoint = torch.load(
    os.path.join("checkpoints", EXPERIMENT, CHECKPOINT), map_location=DEVICE
)

candidate_vae = CandidateVAE(
    general_config, encoder_config, decoder_config, dataset.vocab, dataset.embedder
).to(DEVICE)

candidate_vae.encoder.load_state_dict(checkpoint["encoder"])
candidate_vae.decoder.load_state_dict(checkpoint["decoder"])
# The embedding state may be absent or falsy in the checkpoint; load it only if present.
if checkpoint.get("embedding"):
    candidate_vae.embedding.load_state_dict(checkpoint["embedding"])

# The VAE is only used for inference from here on. Modules start in training
# mode, so without this the encoder's Dropout layers would perturb the latents.
# NOTE(review): if `decoder.reparameterize` branches on `self.training`, z will
# now equal mu. Confirm that this is intended.
candidate_vae.eval()

In [8]:
# Display the module tree of the loaded VAE (layer types and sizes).
candidate_vae

CandidateVAE(
  (encoder): CandidateEncoder(
    (lstm): LSTM(100, 64, bidirectional=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (relu): ReLU()
    (fcs): ModuleList(
      (0): Linear(in_features=128, out_features=64, bias=True)
    )
    (fc_mu): Linear(in_features=64, out_features=64, bias=True)
    (fc_var): Linear(in_features=64, out_features=64, bias=True)
  )
  (decoder): CandidateDecoder(
    (lstm): LSTM(100, 64)
    (dropout): Dropout(p=0.1, inplace=False)
    (relu): ReLU()
    (fcs): ModuleList(
      (0): Linear(in_features=64, out_features=64, bias=True)
      (1): Linear(in_features=64, out_features=100, bias=True)
    )
    (attn): Attn()
    (attn_mu): Identity()
    (attn_var): Sequential(
      (0): Linear(in_features=64, out_features=64, bias=True)
      (1): Tanh()
      (2): Linear(in_features=64, out_features=64, bias=True)
    )
    (concat): Linear(in_features=128, out_features=64, bias=True)
    (out): Linear(in_features=64, out_features=28371, bias

# Prepare train/test data

In [9]:
def prepare_train_rows(dataset: SellersDataset) -> list[dict]:
    """Encode every training candidate and split its latent vector per target.

    For each item of ``dataset`` the VAE encoder produces ``mu``/``var``, and a
    latent ``z`` is drawn with ``candidate_vae.decoder.reparameterize``. Then,
    for every key of ``disentangled_targets``, the function stores
    ``[z[:, start:end], z with that slice removed, target]``.

    Args:
        dataset: dataset whose items unpack to
            ``(input_tensor, _, skills, education, languages)``.

    Returns:
        One dict per dataset item, mapping target name to the triple above.
    """
    rows = []

    for idx in tqdm(range(len(dataset))):
        # Reseed before every item so any `random`-based sampling inside
        # `__getitem__` is reproducible.
        # NOTE(review): an unused `np.random.default_rng(42)` used to be created
        # here; it seeded nothing. If `__getitem__` draws from NumPy's global
        # RNG, `np.random.seed(42)` is what is needed. Confirm.
        random.seed(42)
        input_tensor, _, skills, education, languages = dataset[idx]
        targets = {"skills": skills, "education": education, "languages": languages}

        with torch.no_grad():
            input_lengths = torch.tensor(len(input_tensor)).unsqueeze(dim=0)

            mu, var, _, _ = candidate_vae.encoder(
                input_tensor.unsqueeze(dim=1).to(DEVICE), input_lengths.to("cpu")
            )
            # TODO: decide whether the probes should see mu or the sampled z.
            z = candidate_vae.decoder.reparameterize(mu, var)

        latents = {}
        for key, spec in disentangled_targets.items():
            index_start, index_end = spec["indexes"]
            latents[key] = [
                z[:, index_start:index_end],
                # Every latent dim except this target's own segment.
                torch.cat((z[:, :index_start], z[:, index_end:]), dim=1),
                targets[key],
            ]

        rows.append(latents)
    return rows


def prepare_test_data(dataset: SellersDataset):
    """Embed the held-out rows and attach the encoder inputs as new columns.

    Adds two columns to ``dataset.test_dataset`` in place, both consumed by
    ``prepare_test_row``:

    * ``embedded``: the first element of ``dataset.embedder(text)``, moved to
      CPU and with a singleton dim inserted at position 1.
    * ``input_lengths``: 1-element tensor holding that tensor's ``len``.
    """
    # `progress_apply` only exists once tqdm's pandas integration is registered;
    # calling this again is harmless.
    tqdm.pandas()

    # Both RNGs are seeded so that the generated textual descriptions are reproducible.
    rng = np.random.default_rng(42)
    random.seed(42)

    texts = dataset.test_dataset.progress_apply(
        lambda x: dataset._create_textual_decription(x, rng), axis=1
    )

    embedded = []
    # Inference only: do not build autograd graphs that would be kept alive
    # by the stored tensors.
    with torch.no_grad():
        for text in tqdm(texts):
            embedded.append(dataset.embedder(text)[0].cpu().unsqueeze(dim=1))

    dataset.test_dataset["embedded"] = embedded
    dataset.test_dataset["input_lengths"] = [
        torch.tensor(len(row)).unsqueeze(dim=0) for row in embedded
    ]


def prepare_test_row(row: pd.Series) -> dict:
    """Encode one prepared test row and split its latent per target.

    ``row`` must carry the ``embedded`` and ``input_lengths`` columns added by
    ``prepare_test_data`` and a ``<target>_vec`` column for every target.

    Returns:
        ``{target: [own latent slice, remaining latent dims, row["<target>_vec"]]}``.
    """
    with torch.no_grad():
        mu, var, _, (_, _) = candidate_vae.encoder(
            row["embedded"].to(DEVICE), row["input_lengths"].to("cpu")
        )
        # TODO: decide whether the probes should see mu or the sampled z.
        sampled = candidate_vae.decoder.reparameterize(mu, var)

    per_target = {}
    for name, spec in disentangled_targets.items():
        start, stop = spec["indexes"]
        own_slice = sampled[:, start:stop]
        complement = torch.cat((sampled[:, :start], sampled[:, stop:]), dim=1)
        per_target[name] = [own_slice, complement, row[f"{name}_vec"]]
    return per_target

In [10]:
# Cache of the extracted latents. It is recomputed when CREATE_DATASET is set
# or when either cache file is missing.
CAPACITY_DATA_DIR = os.path.join("tests", "data", "capacity")
TRAIN_LATENTS_PATH = os.path.join(CAPACITY_DATA_DIR, "train_latents.pickle")
TEST_LATENTS_PATH = os.path.join(CAPACITY_DATA_DIR, "test_latents.pickle")

_latents_cached = os.path.isfile(TRAIN_LATENTS_PATH) and os.path.isfile(
    TEST_LATENTS_PATH
)

if CREATE_DATASET or not _latents_cached:
    # `progress_apply` only exists once tqdm's pandas integration is registered.
    tqdm.pandas()
    train_latents = prepare_train_rows(dataset)
    prepare_test_data(dataset)
    test_latents = dataset.test_dataset.progress_apply(prepare_test_row, axis=1)
    os.makedirs(CAPACITY_DATA_DIR, exist_ok=True)
    # NOTE(review): the latents presumably stay on DEVICE, so a cache written on
    # GPU may not load on a CPU-only machine. Confirm.
    with open(TRAIN_LATENTS_PATH, "wb") as f:
        pickle.dump(train_latents, f)

    with open(TEST_LATENTS_PATH, "wb") as f:
        pickle.dump(test_latents, f)
else:
    # pickle.load can execute arbitrary code: only load caches this notebook wrote.
    with open(TRAIN_LATENTS_PATH, "rb") as f:
        train_latents = pickle.load(f)

    with open(TEST_LATENTS_PATH, "rb") as f:
        test_latents = pickle.load(f)

  0%|          | 59/40014 [00:00<02:16, 293.56it/s]

torch.Size([85, 1, 100])
torch.Size([87, 1, 100])
torch.Size([79, 1, 100])
torch.Size([89, 1, 100])
torch.Size([70, 1, 100])
torch.Size([63, 1, 100])
torch.Size([127, 1, 100])
torch.Size([117, 1, 100])
torch.Size([72, 1, 100])
torch.Size([80, 1, 100])
torch.Size([131, 1, 100])
torch.Size([109, 1, 100])
torch.Size([78, 1, 100])
torch.Size([112, 1, 100])
torch.Size([46, 1, 100])
torch.Size([133, 1, 100])
torch.Size([126, 1, 100])
torch.Size([77, 1, 100])
torch.Size([143, 1, 100])
torch.Size([65, 1, 100])
torch.Size([121, 1, 100])
torch.Size([141, 1, 100])
torch.Size([81, 1, 100])
torch.Size([157, 1, 100])
torch.Size([146, 1, 100])
torch.Size([53, 1, 100])
torch.Size([122, 1, 100])
torch.Size([43, 1, 100])
torch.Size([122, 1, 100])
torch.Size([88, 1, 100])
torch.Size([100, 1, 100])
torch.Size([56, 1, 100])
torch.Size([125, 1, 100])
torch.Size([142, 1, 100])
torch.Size([114, 1, 100])
torch.Size([109, 1, 100])
torch.Size([62, 1, 100])
torch.Size([82, 1, 100])
torch.Size([123, 1, 100])
torch

  0%|          | 124/40014 [00:00<02:07, 311.84it/s]

torch.Size([103, 1, 100])
torch.Size([84, 1, 100])
torch.Size([116, 1, 100])
torch.Size([94, 1, 100])
torch.Size([100, 1, 100])
torch.Size([76, 1, 100])
torch.Size([61, 1, 100])
torch.Size([51, 1, 100])
torch.Size([110, 1, 100])
torch.Size([94, 1, 100])
torch.Size([123, 1, 100])
torch.Size([56, 1, 100])
torch.Size([123, 1, 100])
torch.Size([60, 1, 100])
torch.Size([100, 1, 100])
torch.Size([71, 1, 100])
torch.Size([93, 1, 100])
torch.Size([56, 1, 100])
torch.Size([64, 1, 100])
torch.Size([63, 1, 100])
torch.Size([138, 1, 100])
torch.Size([141, 1, 100])
torch.Size([99, 1, 100])
torch.Size([69, 1, 100])
torch.Size([104, 1, 100])
torch.Size([89, 1, 100])
torch.Size([118, 1, 100])
torch.Size([84, 1, 100])
torch.Size([103, 1, 100])
torch.Size([128, 1, 100])
torch.Size([92, 1, 100])
torch.Size([73, 1, 100])
torch.Size([94, 1, 100])
torch.Size([120, 1, 100])
torch.Size([146, 1, 100])
torch.Size([138, 1, 100])
torch.Size([100, 1, 100])
torch.Size([132, 1, 100])
torch.Size([69, 1, 100])
torch.S

  0%|          | 156/40014 [00:00<02:07, 313.58it/s]

torch.Size([63, 1, 100])
torch.Size([83, 1, 100])
torch.Size([127, 1, 100])
torch.Size([129, 1, 100])
torch.Size([70, 1, 100])
torch.Size([50, 1, 100])
torch.Size([168, 1, 100])
torch.Size([73, 1, 100])
torch.Size([88, 1, 100])
torch.Size([102, 1, 100])
torch.Size([115, 1, 100])
torch.Size([143, 1, 100])
torch.Size([122, 1, 100])
torch.Size([92, 1, 100])
torch.Size([86, 1, 100])
torch.Size([72, 1, 100])
torch.Size([91, 1, 100])
torch.Size([116, 1, 100])
torch.Size([123, 1, 100])
torch.Size([97, 1, 100])
torch.Size([90, 1, 100])
torch.Size([62, 1, 100])
torch.Size([45, 1, 100])
torch.Size([64, 1, 100])
torch.Size([84, 1, 100])
torch.Size([83, 1, 100])
torch.Size([97, 1, 100])
torch.Size([86, 1, 100])
torch.Size([73, 1, 100])
torch.Size([100, 1, 100])
torch.Size([84, 1, 100])
torch.Size([133, 1, 100])
torch.Size([69, 1, 100])
torch.Size([141, 1, 100])
torch.Size([134, 1, 100])
torch.Size([116, 1, 100])
torch.Size([95, 1, 100])
torch.Size([109, 1, 100])
torch.Size([81, 1, 100])
torch.Size

  1%|          | 219/40014 [00:00<02:11, 303.27it/s]

torch.Size([55, 1, 100])
torch.Size([100, 1, 100])
torch.Size([73, 1, 100])
torch.Size([127, 1, 100])
torch.Size([111, 1, 100])
torch.Size([50, 1, 100])
torch.Size([67, 1, 100])
torch.Size([145, 1, 100])
torch.Size([110, 1, 100])
torch.Size([146, 1, 100])
torch.Size([57, 1, 100])
torch.Size([108, 1, 100])
torch.Size([101, 1, 100])
torch.Size([79, 1, 100])
torch.Size([78, 1, 100])
torch.Size([60, 1, 100])
torch.Size([98, 1, 100])
torch.Size([104, 1, 100])
torch.Size([80, 1, 100])
torch.Size([86, 1, 100])
torch.Size([57, 1, 100])
torch.Size([114, 1, 100])
torch.Size([53, 1, 100])
torch.Size([118, 1, 100])
torch.Size([125, 1, 100])
torch.Size([97, 1, 100])
torch.Size([91, 1, 100])
torch.Size([88, 1, 100])
torch.Size([84, 1, 100])
torch.Size([59, 1, 100])
torch.Size([117, 1, 100])
torch.Size([92, 1, 100])
torch.Size([93, 1, 100])
torch.Size([65, 1, 100])
torch.Size([84, 1, 100])
torch.Size([129, 1, 100])
torch.Size([62, 1, 100])
torch.Size([129, 1, 100])
torch.Size([44, 1, 100])
torch.Size

  1%|          | 283/40014 [00:00<02:15, 293.85it/s]

torch.Size([57, 1, 100])
torch.Size([85, 1, 100])
torch.Size([101, 1, 100])
torch.Size([70, 1, 100])
torch.Size([143, 1, 100])
torch.Size([109, 1, 100])
torch.Size([88, 1, 100])
torch.Size([135, 1, 100])
torch.Size([55, 1, 100])
torch.Size([115, 1, 100])
torch.Size([73, 1, 100])
torch.Size([109, 1, 100])
torch.Size([107, 1, 100])
torch.Size([127, 1, 100])
torch.Size([50, 1, 100])
torch.Size([45, 1, 100])
torch.Size([123, 1, 100])
torch.Size([68, 1, 100])
torch.Size([85, 1, 100])
torch.Size([170, 1, 100])
torch.Size([127, 1, 100])
torch.Size([113, 1, 100])
torch.Size([103, 1, 100])
torch.Size([81, 1, 100])
torch.Size([51, 1, 100])
torch.Size([64, 1, 100])
torch.Size([94, 1, 100])
torch.Size([107, 1, 100])
torch.Size([145, 1, 100])
torch.Size([80, 1, 100])
torch.Size([79, 1, 100])
torch.Size([78, 1, 100])
torch.Size([100, 1, 100])
torch.Size([60, 1, 100])
torch.Size([83, 1, 100])
torch.Size([94, 1, 100])
torch.Size([98, 1, 100])
torch.Size([84, 1, 100])
torch.Size([90, 1, 100])
torch.Siz

  1%|          | 344/40014 [00:01<02:15, 292.63it/s]

torch.Size([125, 1, 100])
torch.Size([117, 1, 100])
torch.Size([91, 1, 100])
torch.Size([116, 1, 100])
torch.Size([93, 1, 100])
torch.Size([105, 1, 100])
torch.Size([124, 1, 100])
torch.Size([140, 1, 100])
torch.Size([98, 1, 100])
torch.Size([100, 1, 100])
torch.Size([121, 1, 100])
torch.Size([137, 1, 100])
torch.Size([92, 1, 100])
torch.Size([50, 1, 100])
torch.Size([91, 1, 100])
torch.Size([96, 1, 100])
torch.Size([70, 1, 100])
torch.Size([121, 1, 100])
torch.Size([116, 1, 100])
torch.Size([76, 1, 100])
torch.Size([73, 1, 100])
torch.Size([118, 1, 100])
torch.Size([138, 1, 100])
torch.Size([57, 1, 100])
torch.Size([87, 1, 100])
torch.Size([71, 1, 100])
torch.Size([60, 1, 100])
torch.Size([60, 1, 100])
torch.Size([110, 1, 100])
torch.Size([152, 1, 100])
torch.Size([73, 1, 100])
torch.Size([49, 1, 100])
torch.Size([50, 1, 100])
torch.Size([93, 1, 100])
torch.Size([73, 1, 100])
torch.Size([147, 1, 100])
torch.Size([101, 1, 100])
torch.Size([110, 1, 100])
torch.Size([68, 1, 100])
torch.S

  1%|          | 407/40014 [00:01<02:11, 300.24it/s]

torch.Size([105, 1, 100])
torch.Size([51, 1, 100])
torch.Size([48, 1, 100])
torch.Size([63, 1, 100])
torch.Size([64, 1, 100])
torch.Size([124, 1, 100])
torch.Size([117, 1, 100])
torch.Size([112, 1, 100])
torch.Size([129, 1, 100])
torch.Size([79, 1, 100])
torch.Size([217, 1, 100])
torch.Size([89, 1, 100])
torch.Size([89, 1, 100])
torch.Size([114, 1, 100])
torch.Size([107, 1, 100])
torch.Size([96, 1, 100])
torch.Size([129, 1, 100])
torch.Size([81, 1, 100])
torch.Size([64, 1, 100])
torch.Size([123, 1, 100])
torch.Size([116, 1, 100])
torch.Size([106, 1, 100])
torch.Size([81, 1, 100])
torch.Size([58, 1, 100])
torch.Size([117, 1, 100])
torch.Size([127, 1, 100])
torch.Size([79, 1, 100])
torch.Size([114, 1, 100])
torch.Size([108, 1, 100])
torch.Size([85, 1, 100])
torch.Size([72, 1, 100])
torch.Size([122, 1, 100])
torch.Size([108, 1, 100])
torch.Size([83, 1, 100])
torch.Size([80, 1, 100])
torch.Size([71, 1, 100])
torch.Size([79, 1, 100])
torch.Size([117, 1, 100])
torch.Size([50, 1, 100])
torch.

  1%|          | 468/40014 [00:01<02:12, 297.50it/s]

torch.Size([101, 1, 100])
torch.Size([118, 1, 100])
torch.Size([67, 1, 100])
torch.Size([59, 1, 100])
torch.Size([81, 1, 100])
torch.Size([121, 1, 100])
torch.Size([74, 1, 100])
torch.Size([119, 1, 100])
torch.Size([103, 1, 100])
torch.Size([120, 1, 100])
torch.Size([61, 1, 100])
torch.Size([119, 1, 100])
torch.Size([124, 1, 100])
torch.Size([118, 1, 100])
torch.Size([98, 1, 100])
torch.Size([108, 1, 100])
torch.Size([85, 1, 100])
torch.Size([76, 1, 100])
torch.Size([76, 1, 100])
torch.Size([74, 1, 100])
torch.Size([78, 1, 100])
torch.Size([131, 1, 100])
torch.Size([117, 1, 100])
torch.Size([80, 1, 100])
torch.Size([59, 1, 100])
torch.Size([118, 1, 100])
torch.Size([103, 1, 100])
torch.Size([106, 1, 100])
torch.Size([81, 1, 100])
torch.Size([91, 1, 100])
torch.Size([80, 1, 100])
torch.Size([144, 1, 100])
torch.Size([89, 1, 100])
torch.Size([116, 1, 100])
torch.Size([89, 1, 100])
torch.Size([120, 1, 100])
torch.Size([83, 1, 100])
torch.Size([84, 1, 100])
torch.Size([126, 1, 100])
torch.

  1%|▏         | 531/40014 [00:01<02:10, 301.98it/s]

torch.Size([65, 1, 100])
torch.Size([85, 1, 100])
torch.Size([137, 1, 100])
torch.Size([104, 1, 100])
torch.Size([67, 1, 100])
torch.Size([91, 1, 100])
torch.Size([99, 1, 100])
torch.Size([70, 1, 100])
torch.Size([93, 1, 100])
torch.Size([91, 1, 100])
torch.Size([99, 1, 100])
torch.Size([104, 1, 100])
torch.Size([37, 1, 100])
torch.Size([121, 1, 100])
torch.Size([85, 1, 100])
torch.Size([85, 1, 100])
torch.Size([93, 1, 100])
torch.Size([60, 1, 100])
torch.Size([107, 1, 100])
torch.Size([121, 1, 100])
torch.Size([64, 1, 100])
torch.Size([76, 1, 100])
torch.Size([99, 1, 100])
torch.Size([65, 1, 100])
torch.Size([78, 1, 100])
torch.Size([131, 1, 100])
torch.Size([51, 1, 100])
torch.Size([68, 1, 100])
torch.Size([58, 1, 100])
torch.Size([64, 1, 100])
torch.Size([118, 1, 100])
torch.Size([125, 1, 100])
torch.Size([116, 1, 100])
torch.Size([125, 1, 100])
torch.Size([91, 1, 100])
torch.Size([87, 1, 100])
torch.Size([135, 1, 100])
torch.Size([73, 1, 100])
torch.Size([57, 1, 100])
torch.Size([1

  1%|▏         | 595/40014 [00:01<02:07, 308.92it/s]

torch.Size([84, 1, 100])
torch.Size([82, 1, 100])
torch.Size([119, 1, 100])
torch.Size([104, 1, 100])
torch.Size([102, 1, 100])
torch.Size([59, 1, 100])
torch.Size([73, 1, 100])
torch.Size([136, 1, 100])
torch.Size([81, 1, 100])
torch.Size([97, 1, 100])
torch.Size([111, 1, 100])
torch.Size([112, 1, 100])
torch.Size([95, 1, 100])
torch.Size([76, 1, 100])
torch.Size([88, 1, 100])
torch.Size([119, 1, 100])
torch.Size([59, 1, 100])
torch.Size([145, 1, 100])
torch.Size([60, 1, 100])
torch.Size([99, 1, 100])
torch.Size([88, 1, 100])
torch.Size([107, 1, 100])
torch.Size([144, 1, 100])
torch.Size([90, 1, 100])
torch.Size([81, 1, 100])
torch.Size([142, 1, 100])
torch.Size([87, 1, 100])
torch.Size([45, 1, 100])
torch.Size([62, 1, 100])
torch.Size([72, 1, 100])
torch.Size([153, 1, 100])
torch.Size([100, 1, 100])
torch.Size([33, 1, 100])
torch.Size([55, 1, 100])
torch.Size([124, 1, 100])
torch.Size([86, 1, 100])
torch.Size([84, 1, 100])
torch.Size([56, 1, 100])
torch.Size([92, 1, 100])
torch.Size(

  2%|▏         | 663/40014 [00:02<02:10, 300.97it/s]

torch.Size([73, 1, 100])
torch.Size([76, 1, 100])
torch.Size([96, 1, 100])
torch.Size([76, 1, 100])
torch.Size([136, 1, 100])
torch.Size([122, 1, 100])
torch.Size([63, 1, 100])
torch.Size([140, 1, 100])
torch.Size([122, 1, 100])
torch.Size([137, 1, 100])
torch.Size([107, 1, 100])
torch.Size([158, 1, 100])
torch.Size([58, 1, 100])
torch.Size([65, 1, 100])
torch.Size([91, 1, 100])
torch.Size([134, 1, 100])
torch.Size([109, 1, 100])
torch.Size([58, 1, 100])
torch.Size([128, 1, 100])
torch.Size([96, 1, 100])
torch.Size([66, 1, 100])
torch.Size([88, 1, 100])
torch.Size([116, 1, 100])
torch.Size([113, 1, 100])
torch.Size([132, 1, 100])
torch.Size([128, 1, 100])
torch.Size([95, 1, 100])
torch.Size([81, 1, 100])
torch.Size([104, 1, 100])
torch.Size([132, 1, 100])
torch.Size([131, 1, 100])
torch.Size([101, 1, 100])
torch.Size([159, 1, 100])
torch.Size([69, 1, 100])
torch.Size([138, 1, 100])
torch.Size([79, 1, 100])
torch.Size([129, 1, 100])
torch.Size([71, 1, 100])
torch.Size([90, 1, 100])
torc




KeyboardInterrupt: 

In [43]:
class AdversarialDataset(torch.utils.data.Dataset):
    """Map-style dataset over the per-target latent triples of one target.

    Item ``i`` is ``latents[i][key]``, i.e. the
    ``[own latent slice, remaining latent dims, target]`` triple built for that
    candidate.
    """

    def __init__(self, latents: Iterable[dict], key: str):
        """
        Args:
            latents: iterable of per-candidate dicts keyed by target name. Any
                iterable of dicts works, e.g. a list or a pandas Series of dicts.
            key: target name to select, e.g. ``"skills"``.
        """
        self.data = [row[key] for row in latents]

    def __getitem__(self, idx: int):
        """Return the stored triple at position ``idx``."""
        return self.data[idx]

    def __len__(self):
        """Number of candidates in the dataset."""
        return len(self.data)


# One dataset per target for each split, each item being that target's
# [own slice, complement, target] triple.
adversarial_datasets_train = {}
adversarial_datasets_test = {}
for target_name in disentangled_targets:
    adversarial_datasets_train[target_name] = AdversarialDataset(
        train_latents, target_name
    )
    adversarial_datasets_test[target_name] = AdversarialDataset(
        test_latents, target_name
    )

In [44]:
TRAIN_BATCH_SIZE = 4096
TEST_BATCH_SIZE = 1024
SHUFFLE_SEED = 42

# Training batches are reshuffled every epoch. A dedicated, seeded generator
# per loader keeps that order reproducible across kernel restarts.
dataloaders_train = {
    target: DataLoader(
        adversarial_datasets_train[target],
        batch_size=TRAIN_BATCH_SIZE,
        shuffle=True,
        generator=torch.Generator().manual_seed(SHUFFLE_SEED),
    )
    for target in disentangled_targets
}

# Evaluation order does not matter, so the test loaders are not shuffled.
dataloaders_test = {
    target: DataLoader(
        adversarial_datasets_test[target],
        batch_size=TEST_BATCH_SIZE,
    )
    for target in disentangled_targets
}

In [45]:
# Sanity check: element 1 of a batch is the adversarial input (every latent
# dim outside the skills segment), shaped (batch, 1, n_dims).
_, skills_adv_batch, _ = next(iter(dataloaders_train["skills"]))
skills_adv_batch.shape

torch.Size([4096, 1, 48])

# Train / test methods

In [46]:
def validate(
    model: nn.Module,
    loss_fn: torch.nn.CrossEntropyLoss,
    dataloader: DataLoader,
    _type: str = "mult",
) -> torch.Tensor:
    """Compute the per-sample mean loss of ``model`` over ``dataloader``.

    Args:
        model: classifier to evaluate. Batches are moved to the device of its
            first parameter (CPU if it has none).
        loss_fn: loss with ``reduction="mean"`` (the default). Each batch loss
            is treated as a per-sample average and re-weighted by batch size,
            so a smaller last batch does not bias the result.
        dataloader: yields ``(X_mult, X_adv, y)`` batches. Dim 1 of the
            selected input is squeezed away before the forward pass.
        _type: ``"mult"`` evaluates on ``X_mult``; any other value on ``X_adv``.

    Returns:
        0-d tensor with the mean loss over all samples.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    total_loss = 0.0
    n_samples = 0
    # Safe to nest inside a caller's no_grad; also avoids graphs when called standalone.
    with torch.no_grad():
        for X_mult_batch, X_adv_batch, y_batch in dataloader:
            X_batch = X_mult_batch if _type == "mult" else X_adv_batch
            y_pred = model(X_batch.squeeze(dim=1).to(device))
            batch_size = len(y_batch)
            total_loss = total_loss + loss_fn(y_pred, y_batch.to(device)) * batch_size
            n_samples += batch_size

    if n_samples == 0:
        raise ValueError("validate() received a dataloader with no samples")
    return total_loss / n_samples


def fit(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    loss_fn: nn.CrossEntropyLoss,
    train_dl: DataLoader,
    val_dl: DataLoader,
    writer: SummaryWriter,
    epochs: int = 20,
    print_metrics: bool = True,
    patience: int = 5,
    run_prefix: str = "early_stopping",
    _type: str = "mult",
) -> dict[str, list]:
    """Train a probe with early stopping on the validation loss.

    Each epoch runs one optimisation pass over ``train_dl``, then measures the
    mean loss on ``train_dl`` and ``val_dl`` with ``validate``. Whenever the
    validation loss improves, the model and optimizer states are saved to
    ``tests/checkpoints/capacity/best_{run_prefix}``. Training stops after
    ``patience`` consecutive epochs without improvement.

    Args:
        model: probe to train. Batches are moved to the device of its first parameter.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: loss applied to ``(logits, y)``.
        train_dl: yields ``(X_mult, X_adv, y)`` training batches.
        val_dl: same layout as ``train_dl``; drives early stopping.
        writer: receives ``{run_prefix}/loss`` scalars (train/dev) every epoch.
        epochs: maximum number of epochs.
        print_metrics: print both losses after every epoch.
        patience: epochs without validation improvement before stopping.
        run_prefix: TensorBoard tag prefix and checkpoint file suffix.
        _type: ``"mult"`` trains on ``X_mult``; any other value on ``X_adv``.

    Returns:
        ``{"train": [...], "val": [...]}`` with one float loss per completed epoch.

    Note:
        The model is left in eval mode with the weights of the LAST epoch. The
        best weights are only available from the saved checkpoint.
    """
    # Move batches to wherever the model lives instead of assuming CUDA.
    device = next(model.parameters()).device
    checkpoint_dir = os.path.join("tests", "checkpoints", "capacity")
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, f"best_{run_prefix}")

    losses = {"train": [], "val": []}

    min_val_loss = float("inf")
    current_patience = 0
    for epoch in tqdm(range(epochs)):
        model.train()

        for X_mult_batch, X_adv_batch, y_batch in train_dl:
            X_batch = X_mult_batch if _type == "mult" else X_adv_batch
            X_batch = X_batch.squeeze(dim=1).to(device)
            y_batch = y_batch.to(device)

            optimizer.zero_grad()  # drop gradients left from the previous step
            y_pred = model(X_batch)  # logits for the minibatch
            loss = loss_fn(y_pred, y_batch)
            loss.backward()  # accumulate gradients into the parameters' .grad
            optimizer.step()  # update parameters from .grad using the optimizer's lr

        # Eval mode matters for layers such as Dropout or BatchNorm.
        model.eval()
        # No gradient tracking is needed while measuring the losses.
        with torch.no_grad():
            train_loss = float(validate(model, loss_fn, train_dl, _type))
            val_loss = float(validate(model, loss_fn, val_dl, _type))

        if val_loss < min_val_loss:
            min_val_loss = val_loss
            current_patience = 0
            torch.save(
                obj={
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                },
                f=checkpoint_path,
            )
        else:
            current_patience += 1

        losses["train"].append(train_loss)
        losses["val"].append(val_loss)

        writer.add_scalars(
            main_tag=f"{run_prefix}/loss",
            tag_scalar_dict={"train": train_loss, "dev": val_loss},
            global_step=epoch + 1,
        )

        if print_metrics:
            print(
                f"Epoch {epoch}: "
                f"train loss = {train_loss:.3f}, "
                f"validation loss = {val_loss:.3f}"
            )

        if current_patience >= patience:
            break

    model.eval()  # leave the probe ready for inference
    return losses

# Define networks to be trained

In [47]:
LEARNING_RATE = 0.05

# NOTE(review): skills/education use `output_dim == bow_vocab.n_words`. If their
# targets are multi-hot vectors, CrossEntropyLoss treats them as unnormalised
# class probabilities. Confirm this (vs. BCEWithLogitsLoss) is intended.
crossentropy_loss = nn.CrossEntropyLoss()

# Probes that predict each target from its own latent segment.
multitask_classifiers = nn.ModuleDict(
    {
        target: nn.Linear(spec["latent_dim"], spec["output_dim"]).to(DEVICE)
        for target, spec in disentangled_targets.items()
    }
)
# Probes that try to recover each target from every latent dim EXCEPT its own
# segment. If the latent space is disentangled, these should fail :)
adversarial_classifiers = nn.ModuleDict(
    {
        target: nn.Linear(
            general_config.latent_dim - spec["latent_dim"], spec["output_dim"]
        ).to(DEVICE)
        for target, spec in disentangled_targets.items()
    }
)

multitask_optimizers = {
    target: torch.optim.Adam(classifier.parameters(), lr=LEARNING_RATE)
    for target, classifier in multitask_classifiers.items()
}

adversarial_optimizers = {
    target: torch.optim.Adam(classifier.parameters(), lr=LEARNING_RATE)
    for target, classifier in adversarial_classifiers.items()
}

In [48]:
EPOCHS = 200

# For every target, first fit the probe on its own latent segment ("mult"),
# then the probe on the complementary dims ("adv").
probes_by_type = {
    "mult": (multitask_classifiers, multitask_optimizers),
    "adv": (adversarial_classifiers, adversarial_optimizers),
}

for target in disentangled_targets:
    for _type, (probe_models, probe_optimizers) in probes_by_type.items():
        fit(
            model=probe_models[target],
            optimizer=probe_optimizers[target],
            loss_fn=crossentropy_loss,
            train_dl=dataloaders_train[target],
            val_dl=dataloaders_test[target],
            writer=writer_tensorboard,
            epochs=EPOCHS,
            print_metrics=False,
            patience=10,
            run_prefix=f"capacity_{_type}_{target}_{EXPERIMENT}_{CHECKPOINT.replace('.tar', '')}",
            _type=_type,
        )

 80%|████████  | 80/100 [01:55<00:28,  1.45s/it]
 50%|█████     | 50/100 [01:18<01:18,  1.56s/it]
 23%|██▎       | 23/100 [00:39<02:13,  1.73s/it]
 23%|██▎       | 23/100 [00:38<02:10,  1.69s/it]
 82%|████████▏ | 82/100 [00:16<00:03,  4.88it/s]
 46%|████▌     | 46/100 [00:09<00:11,  4.73it/s]
