In [None]:
# Distributed
import sagemaker
from sagemaker.pytorch import PyTorch
from sagemaker.debugger import TensorBoardOutputConfig

# Launch a multi-node PyTorch DDP pretraining job on SageMaker.
# Everything below is configuration; the only remote call that starts work is
# the final `estimator.fit(...)`.

role = sagemaker.get_execution_role()
# Session pinned to the project bucket. It is handed to the estimator below so
# the estimator does not silently build its own default session.
sess = sagemaker.Session(default_bucket="tf-binding-sites")

training_data_s3_path = "s3://tf-binding-sites/pretraining/data/"

# TensorBoard logs: written to the container-local path and collected under
# the S3 prefix given here.
tensorboard_output_config = TensorBoardOutputConfig(
    s3_output_path="s3://tf-binding-sites/pretraining/models/results/tensorboard",
    container_local_output_path="/opt/ml/output/tensorboard",
)

# Distributed training setup.
# `enabled` is a boolean flag; the string "true" only worked because any
# non-empty string is truthy (so "false" would have enabled DDP as well).
distribution = {"pytorchddp": {"enabled": True}}

estimator = PyTorch(
    base_job_name="pretraining-unfrozen-transformer",
    entry_point="pretrain.py",
    # NOTE(review): confirm the PyTorch estimator actually honors `model_dir`;
    # kept as in the original configuration.
    model_dir="/opt/ml/model",
    source_dir="./pretraining",  # Directory containing training script and other files
    output_path="s3://tf-binding-sites/pretraining/output",
    role=role,
    py_version="py310",
    framework_version="2.0.0",
    volume_size=800,
    max_run=1209600,  # 14 days, expressed in seconds
    instance_count=3,
    instance_type="ml.g5.12xlarge",
    hyperparameters={
        "learning-rate": 1e-4,
    },
    tensorboard_output_config=tensorboard_output_config,
    distribution=distribution,
    sagemaker_session=sess,
)

# wait=False: submit the job and return immediately instead of blocking the
# notebook until training finishes.
estimator.fit({"training": training_data_s3_path}, wait=False)