In [None]:
import sys

sys.path.append("../..")

In [None]:
from tqdm import tqdm
import torch
import umap
import matplotlib.pyplot as plt
import numpy as np

In [None]:
from emgrep.models.dummy_baseline_model import DummyBaselineModel, DummyBaselineEncoder, DummyBaselineAR
from emgrep.datasets.EMGRepDataloader import EMGRepDataloader
from emgrep.criterion import CPCCriterion

In [None]:
# Build the EMG dataloaders from the raw recordings.
# Every recording is addressed by a 3-tuple whose first element is fixed to 1 here;
# presumably the layout is (subject, day, time) -- TODO confirm against EMGRepDataloader.
dataloader = EMGRepDataloader(
    data_path="../../data/01_raw",
    # Training recordings: second element 1..3, third element 1..2.
    train_data=[(1, day, time) for day in (1, 2, 3) for time in (1, 2)],
    # Validation recordings: second element 4, third element 1..2.
    val_data=[(1, 4, time) for time in (1, 2)],
    # Test split is currently disabled:
    # test_data=[(1, 5, time) for time in (1, 2)],
    positive_mode="none",
    # Both window sizes use a stride equal to their length.
    seq_len=3000,
    seq_stride=3000,
    block_len=300,
    block_stride=300,
    batch_size=32,
    num_workers=0,
)

# get_dataloaders() is unpacked into three loaders even though no test_data was passed above.
train_dataloader, val_dataloader, test_dataloader = dataloader.get_dataloaders()

In [None]:
# Seed torch's global RNG before the model is built so that the initial weights
# (and later draws from the same generator) are identical after Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Baseline model composed of an encoder and an autoregressive ("ar") component,
# both constructed with their default arguments.
model = DummyBaselineModel(
    encoder=DummyBaselineEncoder(),
    ar=DummyBaselineAR(),
)

In [None]:
# Show the length of the training and validation loaders, in that order, as a tuple.
tuple(len(loader) for loader in (train_dataloader, val_dataloader))

In [None]:
def train_loop(dataloader, model, criterion, optimizer):
    """Train ``model`` for one pass over ``dataloader``.

    Each batch is cast to double precision and fed to the model, which returns
    a pair ``(z, c)``. The pair is scored with ``criterion(z, c)`` and one
    optimizer step is taken. A tqdm progress bar shows the current loss,
    refreshed on every 10th batch.

    Args:
        dataloader: Sized iterable yielding ``(emg, stimulus, info)`` batches.
            Only ``emg`` is used for training.
        model: ``torch.nn.Module`` called as ``model(emg.double())`` and
            returning ``(z, c)``.
        criterion: Callable mapping ``(z, c)`` to a scalar loss tensor.
        optimizer: Optimizer that updates the model's parameters.
    """
    # Explicitly enable train-time behaviour; an evaluation helper may have
    # left the model in eval mode.
    model.train()

    progress = tqdm(enumerate(dataloader), total=len(dataloader))
    for batch, (emg, _stimulus, _info) in progress:
        z, c = model(emg.double())
        loss = criterion(z, c)

        # Standard backpropagation step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Only sync the loss to Python every 10th batch to keep the bar cheap.
        if batch % 10 == 0:
            progress.set_description(f"loss: {loss.item():>7f}")

def val_loop(dataloader, model, criterion):
    """Compute and print the average loss of ``model`` over ``dataloader``.

    The model is evaluated without gradient tracking. If it is in training
    mode, it is switched to eval mode for the duration of the call and put
    back into training mode afterwards (this uses ``model.train()``, which
    sets training mode on the whole module tree).

    Args:
        dataloader: Iterable yielding ``(emg, stimulus, info)`` batches.
            Only ``emg`` is used.
        model: Callable invoked as ``model(emg.double())`` returning ``(z, c)``.
            Objects without a ``training`` attribute are called as-is.
        criterion: Callable mapping ``(z, c)`` to a scalar loss tensor.

    Returns:
        The mean per-batch loss as a Python float.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    was_training = getattr(model, "training", False)
    if was_training:
        model.eval()

    total_loss = 0.0
    num_batches = 0
    try:
        with torch.no_grad():
            for emg, _stimulus, _info in dataloader:
                z, c = model(emg.double())
                total_loss += criterion(z, c).item()
                num_batches += 1
    finally:
        # Restore the caller's mode even if evaluation raised.
        if was_training:
            model.train()

    if num_batches == 0:
        raise ValueError("val_loop received an empty dataloader; cannot average the loss.")

    test_loss = total_loss / num_batches
    print(f"Validation Error: \n Avg loss: {test_loss:>8f} \n")
    return test_loss

def visualize_embeddings(dataloader, model, epoch=0):
    """Embed the whole dataset and save a 2D UMAP scatter plot of the embeddings.

    Every batch is passed through ``model`` without gradient tracking; the
    first model output ``z`` is flattened to one row per embedding vector.
    One label per row is read from ``stimulus[:, 0, :, -1, 0]``. The rows are
    projected to 2D with UMAP, coloured by label, and written to
    ``umap_{epoch}.png`` in the current working directory.

    If the model is in training mode it is switched to eval mode for the
    duration of the call and switched back afterwards.

    Args:
        dataloader: Iterable yielding ``(emg, stimulus, info)`` batches.
            ``stimulus`` must be indexable with five indices; presumably the
            third axis matches the per-sample embedding axis of ``z`` --
            TODO confirm against the dataset.
        model: Callable invoked as ``model(emg.double())`` returning ``(z, c)``.
            The last axis of ``z`` is assumed to be the embedding dimension
            (previously hard-coded to 128).
        epoch: Value interpolated into the output file name.
    """
    was_training = getattr(model, "training", False)
    if was_training:
        model.eval()

    emg_embeddings = []
    labels = []
    try:
        with torch.no_grad():
            for emg, stimulus, _info in dataloader:
                z, _c = model(emg.double())
                # One row per embedding vector, whatever the embedding width is.
                emg_embeddings.append(z.reshape(-1, z.shape[-1]))
                labels.append(stimulus[:, 0, :, -1, 0].reshape(-1))
    finally:
        # Restore the caller's mode even if the forward pass raised.
        if was_training:
            model.train()

    # UMAP and matplotlib operate on NumPy arrays; convert explicitly so this
    # also works when the tensors live on an accelerator.
    emg_embeddings = torch.cat(emg_embeddings, dim=0).detach().cpu().numpy()
    labels = torch.cat(labels, dim=0).detach().cpu().numpy()

    reducer = umap.UMAP()
    projection = reducer.fit_transform(emg_embeddings)

    # Use an explicit figure so no pre-existing pyplot state is drawn on,
    # and close it afterwards to avoid accumulating figures across epochs.
    fig, ax = plt.subplots()
    ax.scatter(projection[:, 0], projection[:, 1], c=labels, cmap="Spectral", s=0.1)
    ax.set_title("UMAP projection of the EMG embeddings")
    fig.savefig(f"umap_{epoch}.png", dpi=300)
    plt.close(fig)


In [None]:
# Loss and optimizer for training.
# The positional 3 is interpreted by CPCCriterion; presumably the number of
# prediction steps -- TODO confirm against emgrep.criterion.
criterion = CPCCriterion(3)
# Adam over all model parameters with a learning rate of 1e-3.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)

In [None]:
# Run the full training schedule. After each epoch the validation embeddings
# are projected with UMAP and saved under a file name carrying the 0-based epoch index.
epochs = 100
for epoch_idx in range(epochs):
    print(f"Epoch {epoch_idx + 1}\n-------------------------------")
    train_loop(train_dataloader, model, criterion, optimizer)
    # Validation loss is currently disabled:
    # val_loop(val_dataloader, model, criterion)
    visualize_embeddings(val_dataloader, model, epoch_idx)

print("Done!")