In [2]:
import sys
import os 

# Make the project's `src/` directory importable so the project modules imported
# in the next cell (models, train, criterion, prior_generation) can be resolved.
# The path is derived from the kernel's working directory, assumed to be the
# folder containing this notebook (a sibling of `src/`).
current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir)
src_dir = os.path.join(parent_dir, 'src')

# Only append once, so re-running this cell does not grow sys.path with duplicates.
if src_dir not in sys.path:
    sys.path.append(src_dir)
    print(f"Added to sys.path: {src_dir}")
else:
    print(f"Already on sys.path: {src_dir}")

Added to sys.path: /home/vdakov/Desktop/thesis/msc-thesis-vasko


In [7]:
import models.encoders as encoders
import train
from criterion.bar_distribution import BarDistribution, get_bucket_limits
from models import positional_encodings
from prior_generation import gp_prior
import torch


# Experiment configuration: every tunable value for this run lives in this dict.
args = {
    'seed': 0,  # seeds torch below so model init and prior sampling are repeatable
    'epochs': 10,
    'batch_size': 100,
    'steps_per_epoch': 50,
    'lr': 0.001,
    'sequence_length': 10,
    # Transformer size: (emsize, nhead, nhid, nlayers, dropout) are packed into a tuple later.
    'emsize': 512,
    'nlayers': 6,
    'nhead': 4,
    'nhid': 1024,
    'dropout': 0.0,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',

    # Task specific
    # Bucket count and target range handed to get_bucket_limits for the BarDistribution borders.
    # NOTE(review): if the borders are spaced uniformly, each bucket is 2.0 wide; if the GP
    # prior's targets are roughly unit-scale, most mass lands in a few buckets — confirm range.
    'num_buckets': 100,
    'min_y': -100.0,
    'max_y': 100.0,
    # NOTE(review): the prior is pinned to 'cpu' even when args['device'] is 'cuda';
    # presumably train.train moves batches to the training device — confirm.
    'prior_hyperparameters': {'num_features': 1, 'num_outputs': 100, 'device': 'cpu'},

    # Encoders
    'input_normalization': False,
    'encoder_type': 'linear', # 'linear' or 'mlp'
    'pos_encoder_type': 'sinus' # 'sinus', 'learned', 'none'
}

# Seed torch's global RNG (CPU and CUDA) before the prior, criterion and model are built,
# so a Restart & Run All reproduces the same training curve.
# TODO: also seed numpy / random if the prior generator samples from them.
torch.manual_seed(args['seed'])

# Build the data prior and the loss, then resolve the encoder classes named in `args`.
prior = gp_prior.GaussianProcessPriorGenerator()
# BarDistribution loss over the bucket borders produced for `num_buckets` bins in [min_y, max_y].
criterion = BarDistribution(borders=get_bucket_limits(args['num_buckets'], full_range=(args['min_y'], args['max_y'])))

# Supported choices for the string options in `args`. An unknown value raises instead of
# silently falling back to another encoder (e.g. a typo disabling positional encoding).
ENCODER_TYPES = {
    'linear': encoders.LinearEncoder,
    'mlp': encoders.MLPEncoder,
}
POS_ENCODER_TYPES = {
    'sinus': positional_encodings.PositionalEncoding,
    'learned': positional_encodings.LearnedPositionalEncoding,
    'none': positional_encodings.NoPositionalEncoding,
}


def _resolve_option(setting, options):
    """Return ``options[args[setting]]``.

    Raises:
        ValueError: if ``args[setting]`` is not one of the keys of ``options``.
    """
    choice = args[setting]
    if choice not in options:
        raise ValueError(f"Unknown {setting} {choice!r}; expected one of {sorted(options)}")
    return options[choice]


encoder_generator = _resolve_option('encoder_type', ENCODER_TYPES)
pos_encoder_generator = _resolve_option('pos_encoder_type', POS_ENCODER_TYPES)

# train.train expects the transformer shape positionally, in exactly this order.
TRANSFORMER_KEYS = ('emsize', 'nhead', 'nhid', 'nlayers', 'dropout')
transformer_config = tuple(args[name] for name in TRANSFORMER_KEYS)

print(f"Starting training on {args['device']}...")

# Model components built in the setup step.
training_kwargs = dict(
    prior_dataloader=prior,
    criterion=criterion,
    encoder_generator=encoder_generator,
    transformer_configuration=transformer_config,
    y_encoder_generator=encoder_generator,  # the y encoder reuses the x encoder class
    pos_encoder_generator=pos_encoder_generator,
    prior_hyperparameters=args['prior_hyperparameters'],
    device=args['device'],
    verbose=True,
)
# Optimisation / schedule settings are forwarded unchanged from the config dict.
training_kwargs.update({name: args[name] for name in ('epochs', 'steps_per_epoch', 'batch_size', 'sequence_length', 'lr')})

# The result is unpacked into the final loss, the per-position losses and the trained model.
final_loss, positional_losses, model = train.train(**training_kwargs)

Starting training on cpu...
Using cpu:0 device
{'num_features': 1, 'num_outputs': 100, 'device': 'cpu'}
Dataset.__dict__ {'num_steps': 50, 'fuse_x_y': False, 'get_batch_kwargs': {'batch_size': 100, 'seq_len': 10, 'num_features': 1, 'num_outputs': 100, 'device': 'cpu'}, 'num_features': 1, 'num_outputs': 100}
DataLoader.__dict__ {'num_steps': 50, 'fuse_x_y': False, 'get_batch_kwargs': {'batch_size': 100, 'seq_len': 10, 'num_features': 1, 'num_outputs': 100, 'device': 'cpu'}, 'PriorDataset': <class 'prior_generation.prior_dataloader.get_dataloader.<locals>.PriorDataset'>, 'num_features': 1, 'num_outputs': 100, 'dataset': <prior_generation.prior_dataloader.get_dataloader.<locals>.PriorDataset object at 0x70028e6fe510>, 'num_workers': 0, 'prefetch_factor': None, 'pin_memory': False, 'pin_memory_device': '', 'timeout': 0, 'worker_init_fn': None, '_DataLoader__multiprocessing_context': None, 'in_order': True, '_dataset_kind': 1, 'batch_size': None, 'drop_last': False, 'sampler': <torch.utils.



-----------------------------------------------------------------------------------------
| end of epoch   1 | time:  1.93s | mean loss  0.11 | pos losses  5.55, 5.54, 5.45, 5.41, 5.39, 5.33, 5.28, 5.31, 5.36, 5.37, lr 0.0 data time  0.04 step time  1.89 forward time  0.64
-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
| end of epoch   2 | time:  1.42s | mean loss  0.11 | pos losses  5.57, 5.56, 5.45, 5.41, 5.38, 5.33, 5.27, 5.32, 5.38, 5.39, lr 0.0001 data time  0.01 step time  1.40 forward time  0.47
-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
| end of epoch   3 | time:  1.57s | mean loss  0.09 | pos losses  4.69, 4.66, 4.58, 4.54, 4.54, 4.50, 4.45, 4.49, 4.53, 4.56, lr 0.0002 data time  0.01 step time  1.55 forward tim

In [None]:
import matplotlib.pyplot as plt

# The figure size must be passed by keyword: the first positional parameter of
# plt.figure is `num` (the figure id/label), and a tuple there is rejected with an error.
fig = plt.figure(figsize=(15, 5))
