In [20]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import trange, tqdm

from dataset import get_loaders
from models.doc import DOC

In [31]:
# Run configuration: everything a reader might want to tune lives in this cell.

# Compute device. Prefer Apple's Metal backend (the notebook's original target),
# then CUDA, then fall back to CPU so the notebook still runs on machines
# without an accelerator instead of failing on the first `.to(DEVICE)`.
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

WINDOW = 32          # passed to get_loaders as window_size
BATCH_SIZE = 128     # passed to get_loaders as batch_size
LR = 1e-4            # learning rate handed to Adam
EPOCHS = 3           # number of passes over trainloader in the training cell
WEIGHT_DECAY = 1e-6  # weight_decay handed to Adam
LATENT_DIM = 32      # latent_dim of the DOC model and length of the center vector c

In [11]:
# Build the train / test loaders from the local `data` directory, using the
# window length and batch size chosen in the configuration cell.
trainloader, testloader = get_loaders(
    root='data',
    window_size=WINDOW,
    batch_size=BATCH_SIZE,
)

In [32]:
# Instantiate the DOC network and move it to the configured device.
# NOTE(review): input_dim=123 is presumably the per-timestep feature count of
# the dataset — confirm against `get_loaders` and consider naming it as a constant.
model = DOC(
    input_dim=123,
    hidden_size=256,
    latent_dim=LATENT_DIM,
    num_layers=4,
    bidirectional=True,
).to(DEVICE)

In [33]:
# Adam over all model parameters, with the learning rate and weight decay
# taken from the configuration cell.
optimizer = optim.Adam(
    model.parameters(),
    lr=LR,
    weight_decay=WEIGHT_DECAY,
)

# Step-wise LR decay by a factor of 0.1 at the listed epoch milestones.
# NOTE(review): a milestone at epoch 0 appears to apply the decay as soon as the
# scheduler is constructed, so training would effectively start at LR * 0.1 —
# confirm this is intended or pick the epoch at which the decay should happen.
scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[0],
    gamma=0.1,
)

In [None]:
# Init center
# The center `c` is the mean of the model's outputs over one pass of the
# training loader, computed without gradient tracking and before any training.
# assumes model(x) returns a (batch, LATENT_DIM) tensor — TODO confirm
eps = 0.1  # minimum absolute value allowed for any coordinate of c
n_samples = 0
c = torch.zeros(LATENT_DIM).to(DEVICE)

model.eval()
with torch.no_grad():
    for x, _ in tqdm(trainloader):
        x = x.to(DEVICE)
        proj = model(x)
        n_samples += proj.shape[0]
        c += torch.sum(proj, dim=0)

# An empty loader would otherwise produce a silent NaN center (0 / 0).
if n_samples == 0:
    raise ValueError("trainloader yielded no samples; cannot initialise the center")
c /= n_samples

# Keep every coordinate at least `eps` away from zero. Coordinates that are
# exactly 0 are sent to +eps; the previous strict `c > 0` test left them at 0.
c[(abs(c) < eps) & (c < 0)] = -eps
c[(abs(c) < eps) & (c >= 0)] = eps

In [None]:
# Train the model to minimise the mean squared distance between its outputs
# and the fixed center `c` computed in the previous cell, then checkpoint it.
CHECKPOINT_PATH = 'checkpoints/doc.pkl'

model.train()

for epoch in range(EPOCHS):

    curr_loss = 0
    for x, target in tqdm(trainloader):
        x = x.to(DEVICE)
        optimizer.zero_grad()
        proj = model(x)
        # Per-sample squared Euclidean distance to the center, averaged over the batch.
        dist = torch.sum((proj - c) ** 2, dim=1)
        loss = torch.mean(dist)
        curr_loss += loss.item()

        loss.backward()
        optimizer.step()

    # The LR schedule advances once per epoch, not once per batch.
    scheduler.step()
    print(f"For epoch {epoch+1}/{EPOCHS} ; loss : {curr_loss/len(trainloader)}")


# Create the output directory first so a missing folder cannot throw away the
# finished training run when torch.save is called.
os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)

# The center is stored next to the weights: scoring needs both.
checkpoint = {"state_dict":model.state_dict(), 'center':c.detach().cpu()}
torch.save(checkpoint, CHECKPOINT_PATH)

In [None]:
# Project the whole test set with the trained model and score each sample by
# its negated squared distance to the center `c` (closer to c => higher score).
targets = []
test_proj = []

model.eval()
with torch.no_grad():
    for x, target in tqdm(testloader):
        x = x.to(DEVICE)
        proj = model(x)
        targets.append(target)
        test_proj.append(proj)

# Stitch the per-batch tensors together along the batch axis.
test_targets = torch.cat(targets, dim=0)
test_proj = torch.cat(test_proj, dim=0)

# Squared Euclidean distance of every projection to the center.
test_dist = (test_proj - c).pow(2).sum(dim=1)
test_scores = -test_dist