In [None]:
import sys
from pathlib import Path

# Make the directory two levels above the current working directory importable
# (presumably the repository root that contains the `emgrep` package - confirm).
# The path is resolved to an absolute one so imports keep working if the
# working directory changes later, and it is added only once so that
# re-running this cell does not accumulate duplicate `sys.path` entries.
_REPO_ROOT = str(Path("../..").resolve())
if _REPO_ROOT not in sys.path:
    sys.path.append(_REPO_ROOT)

In [None]:
%load_ext autoreload

In [None]:
import logging
from argparse import Namespace

import torch
import matplotlib.pyplot as plt

from emgrep.datasets.EMGRepDataloader import get_dataloader
from emgrep.datasets.RepresentationsDataset import RepresentationDataset
from emgrep.models.cpc_model import CPCModel, CPCEncoder, CPCAR

from emgrep.train_classifier import train_classifier

In [None]:
# Tunable constants, referenced below when building `args`.
DEVICE = "cpu"

SUBJECTS = 10
DAYS = 5
TIMES = 2

POS_MODE = "none"

ENC_DIM = 256
AR_DIM = 256
AR_LAYERS = 2

# Verbose logging for the notebook; matplotlib's own logger is raised to
# WARNING so its debug output does not flood the cell outputs.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s %(levelname)s %(message)s",
)
logging.getLogger("matplotlib").setLevel(logging.WARNING)

# Settings bundled as an argparse-style namespace, which is the form passed as
# `args=` to the project functions called later in this notebook.
args = Namespace(
    data="../../data/01_raw",
    device=DEVICE,
    output_dir="logs",
    n_subjects=SUBJECTS,
    n_days=DAYS,
    n_times=TIMES,
    positive_mode=POS_MODE,
    val_idx=1,
    test_idx=2,
    seq_len=3000,
    seq_stride=3000,
    block_len=300,
    block_stride=300,
    batch_size_cpc=256,
    num_workers=0,
    encoder_dim=ENC_DIM,
    ar_dim=AR_DIM,
    ar_layers=AR_LAYERS,
    cpc_k=5,
    epochs_cpc=1,
    lr_cpc=0.0002,
    weight_decay_cpc=0.0,
    epochs_classifier=50,
    lr_classifier=0.1,
    batch_size_classifier=256,
    split_mode="day",
    debug=True,
    # NOTE(review): this run directory differs from the one the checkpoint is
    # loaded from further down in the notebook - confirm this is intended.
    log_dir="../../logs/none_1_2/2023-05-03_13-15-11",
    log_to_file=False,
    wandb=False,
    normalize=False,
    preprocessing="none",
)

In [None]:
# Saved model parameters to restore.
checkpoint = "../../logs/logs/none_1_2/2023-05-02_23-05-45/checkpoints/best_model.pt"

# Read the saved parameters onto the configured device.
# NOTE(review): torch.load unpickles the file - only point this at trusted checkpoints.
state_dict = torch.load(checkpoint, map_location=args.device)

# Rebuild the architecture with the layer sizes from `args`.
# NOTE(review): 16 is presumably the number of input channels - confirm.
enc = CPCEncoder(16, args.encoder_dim)
ar = CPCAR(args.encoder_dim, args.ar_dim, args.ar_layers)
model = CPCModel(enc, ar)

# Left as the last expression so the notebook displays the load result.
model.load_state_dict(state_dict)

In [None]:
# Dataloaders in representation-extraction mode, indexed by split name.
rep_dataloader = get_dataloader(args, extract_rep_mode=True)

# One RepresentationDataset per split, built from the loaded model and the
# matching dataloader (constructed in the order train, val, test).
representations = {
    split: RepresentationDataset(model=model, dataloader=rep_dataloader[split], args=args)
    for split in ("train", "val", "test")
}

# The dataloaders are not used again in this notebook; drop the reference.
del rep_dataloader

In [None]:
# Train the classifier with pred_block=-1 and log the accuracy of each split.
metrics = train_classifier(representations, pred_block=-1, args=args)

for split in ("train", "val", "test"):
    # Pad the header to 16 characters so the logged values line up.
    header = f"{split.capitalize()} accuracy:"
    logging.info(f"{header:<16}{metrics[split]['accuracy']:.3f}")

In [None]:
# Accuracy per split, collected over a sweep of pred_block = -1, -2, ..., -10.
# Entry i of each list belongs to pred_block = -(i + 1).
splits = ("train", "val", "test")
metrics = {split: [] for split in splits}

for pred_block in range(1, 11):
    block_metrics = train_classifier(representations, pred_block=-pred_block, args=args)

    for split in splits:
        metrics[split].append(block_metrics[split]["accuracy"])

In [None]:
# Accuracy per evaluated `pred_block`, one curve per split.
# The sweep appends one accuracy per pred_block = -1, -2, ..., so entry i
# (0-based) of each list belongs to pred_block = -(i + 1). Plot against those
# values rather than the implicit list index so the x axis is meaningful.
pred_blocks = [-(i + 1) for i in range(len(metrics["train"]))]

fig, ax = plt.subplots()
for split in ("train", "val", "test"):
    ax.plot(pred_blocks, metrics[split], marker="o", label=split)

ax.set_xticks(pred_blocks)
ax.invert_xaxis()  # keep pred_block = -1 on the left, in the order of the sweep
ax.set_xlabel("pred_block")
ax.set_ylabel("accuracy")
ax.set_title("Classifier accuracy per pred_block")
ax.legend()
plt.show()