In [1]:
import boto3
import sagemaker
from sagemaker import get_execution_role
from sagemaker.pytorch import PyTorch
import os

In [2]:
# --- AWS / dataset configuration ---------------------------------------------
# Bucket the generated data is uploaded to.
bucket = "sagemaker-gpl"
# `prefix`/`dataset` form both the local folder and the S3 key prefix of the
# uploaded data (see the upload cell below).
prefix = "generated"
dataset = "simple"

# SageMaker session and IAM execution role used to launch the training job.
sagemaker_session = sagemaker.Session()
role = get_execution_role()

In [3]:
# Upload the local folder `<prefix>/<dataset>` to s3://<bucket>/<prefix>/<dataset>.
# `upload_data` returns the resulting S3 URI, which becomes the 'train' channel.
# Docs:
# https://sagemaker.readthedocs.io/en/stable/api/utility/session.html#sagemaker.session.Session.upload_data
inputs = sagemaker_session.upload_data(
    path=f"{prefix}/{dataset}",
    bucket=bucket,
    key_prefix=f"{prefix}/{dataset}",
)
# Inside the training container this channel is materialized as a local
# directory, so `inputs` is not handed to the script as an output location.
print(inputs)

s3://sagemaker-gpl/generated/simple


In [4]:
# Hyperparameters SageMaker forwards to ./gpl/train.py. Key names must match
# what train.py parses (not visible here — TODO confirm against the script).
hyperparameters = {
    # Model / scoring settings.
    "base_ckpt": "GPL/msmarco-distilbert-margin-mse",
    "gpl_score_function": "dot",
    # Training schedule.
    "batch_size_gpl": 32,
    "gpl_steps": 140000,
    # NOTE(review): -1 is interpreted by train.py (presumably "use all / default") — confirm.
    "new_size": -1,
    "queries_per_passage": -1,
    # NOTE(review): this is the same S3 URI as `inputs`. Whether train.py can
    # write to an s3:// path is not visible here; SageMaker only persists what
    # the script writes under the container's model/output dirs — confirm.
    "output_dir": f"s3://{bucket}/{prefix}/{dataset}",
    # Disabled options, kept for reference:
    # "path_to_generated_data": f"generated/{dataset}",
    # "evaluation_data": f"./{dataset}",
    # "evaluation_output": f"evaluation/{dataset}",
    "generator": "BeIR/query-gen-msmarco-t5-base-v1",
    # "retrievers": ["msmarco-distilbert-base-v3", "msmarco-MiniLM-L-6-v3"],
    # "retriever_score_functions": ["cos_sim", "cos_sim"],
    "cross_encoder": "cross-encoder/ms-marco-MiniLM-L-6-v2",
    "qgen_prefix": "qgen",
}

In [5]:
# Configure the SageMaker PyTorch job that runs ./gpl/train.py.
#
# NOTE(review): ml.m5.large has no GPU (the job log reports "No GPUs detected").
# With gpl_steps=140000 plus T5 query generation this is likely to be very slow;
# consider a GPU instance type (e.g. ml.g4dn.xlarge) if the budget allows.
trainer = PyTorch(
    role=role,
    entry_point="train.py",
    source_dir="./gpl",
    framework_version="1.9.0",
    py_version="py38",
    instance_type="ml.m5.large",
    instance_count=1,
    sagemaker_session=sagemaker_session,
    # Job artifacts must NOT be written under the training-data prefix.
    # `fit({'train': inputs})` treats `inputs` as an S3 key-name prefix, so
    # anything below it would be pulled into the 'train' channel of every later
    # run. That includes sibling prefixes such as "generated/simple-output/".
    # Hence a separate top-level prefix here. Artifacts already left under
    # generated/simple/ by earlier runs should be deleted.
    output_path="s3://{}/training-output/{}".format(bucket, dataset),
    hyperparameters=hyperparameters,
)

In [None]:
# Launch the training job with the uploaded S3 prefix as the 'train' channel.
# With the default arguments this call waits for the job and streams its logs here.
trainer.fit(inputs={"train": inputs})

2022-09-21 16:24:54 Starting - Starting the training job...
2022-09-21 16:25:18 Starting - Preparing the instances for trainingProfilerReport-1663777494: InProgress
......
2022-09-21 16:26:18 Downloading - Downloading input data......
2022-09-21 16:27:19 Training - Downloading the training image......
2022-09-21 16:28:19 Training - Training image download completed. Training in progress.[34mbash: cannot set terminal process group (-1): Inappropriate ioctl for device[0m
[34mbash: no job control in this shell[0m
[34m2022-09-21 16:28:11,080 sagemaker-training-toolkit INFO     Imported framework sagemaker_pytorch_container.training[0m
[34m2022-09-21 16:28:11,082 sagemaker-training-toolkit INFO     No GPUs detected (normal if no gpus installed)[0m
[34m2022-09-21 16:28:11,097 sagemaker_pytorch_container.training INFO     Block until all host DNS lookups succeed.[0m
[34m2022-09-21 16:28:11,104 sagemaker_pytorch_container.training INFO     Invoking user training script.[0m
[34m202