In [1]:
import random

from common.imports import *
from common.util import *
from common.const import *
from data.dataloader import *
from model.transformer import *
from model.run_model import *

In [2]:
# Use the first CUDA GPU when one is visible to torch, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
print(device)

# Release unoccupied memory held by PyTorch's CUDA caching allocator.
torch.cuda.empty_cache()

cuda:0


In [3]:
# RUN SETTINGS
load_model = False          # when True, the OBJECTS cell loads weights from `model_name`
save_model = True           # when False, the SAVE cell never writes a checkpoint
model_name = "best_model.pt"
max_src_list = [100]        # values assigned to batch_processor.max_src_window in TRAIN
test_size = 20              # forwarded to create_splits and test_stonks_transformer
num_epochs = 2              # epochs per (max_src, days_pred) combination in TRAIN
batch_size = 4
num_workers = 0
days_pred_list = [50]       # `days_pred` values swept by the TRAIN cell

# OPTIM PARAMS
lr = 0.1
# NOTE(review): presumably a gradient-clipping norm inside train_stonks_transformer;
# a bound this large likely never triggers — confirm that is intended.
max_norm = 100000

# REPRODUCIBILITY
# Seed every RNG before the model and dataloaders are built, so that repeated
# runs (and the test-loss comparison in the SAVE cell) are comparable.
# NOTE(review): if create_splits shuffles its data, this also pins the
# train/test split — confirm against data.dataloader.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)  # seeds the CPU generator and all CUDA devices

In [4]:
# OBJECTS

## DATALOADER
train_loader, test_loader = create_splits(
    batch_size=batch_size,
    test_size=test_size,
    num_workers=num_workers
)

## MODEL
model = stonks_transformer_model(
    d_model=D_MODEL,
    nhead=NUM_HEADS,
    num_encoder_layers=NUM_ENCODER_LAYERS,
    num_decoder_layers=NUM_DECODER_LAYERS,
    dropout=DROPOUT
).to(device)
if load_model:
    # map_location remaps tensors saved on another device (e.g. a CUDA
    # checkpoint opened on a CPU-only machine) onto the current `device`;
    # without it torch.load fails when the saving device is unavailable.
    checkpoint = torch.load(model_name, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])

## BATCH PROCESSOR
# NOTE(review): -1 looks like a placeholder — the TRAIN cell assigns
# `max_src_window` from `max_src_list` before training. Confirm what -1 means
# to BatchProcessor if it is used before TRAIN runs (e.g. by TEST after a load).
batch_processor = BatchProcessor(-1)



In [None]:
# TRAIN
# Train for `num_epochs` epochs on every (max_src, days_pred) combination and
# append each epoch's losses to `all_losses`. Losses are accumulated per epoch,
# so results from completed epochs remain in `all_losses` if the cell is interrupted.

ENDING_LOSS_WINDOW = 200  # number of most recent losses averaged for "Ending loss"

# Progress-bar total assumes train_stonks_transformer advances `pbar` once per
# batch of `train_loader` — TODO confirm against model.run_model.
total_iters = len(max_src_list) * len(days_pred_list) * num_epochs * len(train_loader)
all_losses = np.empty((0,))
with tqdm(total=total_iters) as pbar:
    for max_src in max_src_list:
        batch_processor.max_src_window = max_src
        for days_pred in days_pred_list:
            for _ in range(num_epochs):
                epoch_losses = train_stonks_transformer(
                    model, lr, max_norm, days_pred, train_loader,
                    batch_processor, device, pbar
                )
                all_losses = np.concatenate((all_losses, epoch_losses))

plot_loss(all_losses)
if all_losses.size > 0:
    print("Ending loss:", np.mean(all_losses[-ENDING_LOSS_WINDOW:]))
else:
    # np.mean of an empty slice would print nan and emit a RuntimeWarning.
    print("Ending loss: n/a (no training iterations ran)")

In [6]:
# TEST
# Run the held-out split through the model; the progress bar is sized to the
# number of batches in `test_loader`.
# NOTE(review): `test_size` is also the value given to create_splits — confirm
# that test_stonks_transformer expects that same number as its second argument.
with tqdm(total=len(test_loader)) as pbar:
    losses = test_stonks_transformer(
        model,
        test_size,
        test_loader,
        batch_processor,
        device,
        pbar,
    )

# Single summary number; the SAVE cell compares it with the loss stored in the
# existing checkpoint rather than with a hard-coded historical value.
avg_test_loss = np.mean(losses)
print("Avg Loss:", avg_test_loss)

  0%|          | 0/1154 [00:00<?, ?it/s]

Avg Loss: 2.4162297365800876


In [7]:
# SAVE
# Overwrite the checkpoint at `model_name` only when this run's average test
# loss beats the one recorded in it. A missing checkpoint, or one with no
# recorded loss, counts as "no previous best".
try:
    # map_location="cpu": reading the stored loss must not depend on the device
    # the checkpoint was saved from (e.g. a CUDA checkpoint on a CPU-only box).
    checkpoint = torch.load(model_name, map_location="cpu")
    prev_best_avg_test_loss = checkpoint['avg_test_loss']
except (FileNotFoundError, KeyError):
    # Deliberately narrow: any other failure (corrupted or unreadable file)
    # surfaces instead of being mistaken for "no previous best", which would
    # silently overwrite the best model with a worse one.
    prev_best_avg_test_loss = float("inf")

if save_model and avg_test_loss < prev_best_avg_test_loss:
    print("saving model")
    print("prev best:", prev_best_avg_test_loss)
    print("new best:", avg_test_loss)
    torch.save(
        {
            'state_dict': model.state_dict(),
            # Store a plain Python float rather than a NumPy scalar so the
            # checkpoint stays loadable by torch.load's weights-only unpickler.
            'avg_test_loss': float(avg_test_loss)
        },
        # Same path the checkpoint is read from above and by the OBJECTS cell.
        f=model_name
    )