<a href="https://colab.research.google.com/github/yarengokhn/machine_learning_group_project/blob/main/notebooks/train_on_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Train on Google Colab

This notebook checks for a GPU, clones the project repository, installs its dependencies, and runs the training script on Google Colab.

## 1. Check GPU
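Make sure the notebook uses a GPU runtime (Runtime > Change runtime type > T4 GPU). The cell below prints the attached GPU; if it reports an error, no GPU is attached to the runtime.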

In [None]:
!nvidia-smi
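
The same check can be done from Python, which is convenient if you want a later cell to stop early when no GPU is attached. This is a minimal sketch: `gpu_available` is a hypothetical helper name, not part of the project, and it relies only on the `nvidia-smi -L` listing rather than on any deep learning framework.

```python
import shutil
import subprocess


def gpu_available() -> bool:
    """Return True if nvidia-smi is installed and lists at least one GPU."""
    # nvidia-smi is missing entirely on CPU-only runtimes
    if shutil.which("nvidia-smi") is None:
        return False
    # `nvidia-smi -L` prints one line per GPU, e.g. "GPU 0: ..."
    result = subprocess.run(["nvidia-smi", "-L"], capture_output=True, text=True)
    return result.returncode == 0 and "GPU" in result.stdout


if gpu_available():
    print("GPU detected")
else:
    print("No GPU detected; change the runtime type before training")
```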

## 2. Clone Repository
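Clone the project repository and check out the `hayat.custom_dataset` branch. If the repository has already been cloned, the cell fetches and pulls the latest changes instead.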

In [None]:
import os

repo_name = "machine_learning_group_project"
repo_url = "https://github.com/yarengokhn/machine_learning_group_project.git"
branch = "hayat.custom_dataset"

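if not os.path.exists(repo_name):
    # First run: clone the repository, then switch to the target branch
    !git clone $repo_url
    %cd $repo_name
    !git checkout $branch
    %cd ..
else:
    # Repository already cloned: update it and switch to the target branch
    %cd $repo_name
    !git fetch origin
    !git checkout $branch
    !git pull origin $branch
    %cd ..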

## 3. Install Dependencies
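Install the required Python packages. This cell also changes into the repository directory, and the later cells use paths relative to it.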

In [None]:
%cd $repo_name
!pip install -r requirements.txt

## 4. Mount Google Drive (Optional)
Run this cell if you want to save model checkpoints to your Google Drive.

In [None]:
from google.colab import drive
drive.mount('/content/drive')

# Create a directory for checkpoints in Drive if it doesn't exist
drive_checkpoint_dir = "/content/drive/MyDrive/ml_project_checkpoints"
!mkdir -p $drive_checkpoint_dir

## 5. Run Training
Run the training script.

**Note:**
- Adjust `--epochs`, `--batch_size`, and `--model_name` as needed.
- If you mounted Drive, you can copy the checkpoints there after training (step 6), or change the script to save there directly.

In [None]:
# Train for 10 epochs with a batch size of 32 (run from the repository root)
!python scripts/train.py --epochs 10 --batch_size 32 --model_name colab_model
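
If you change the hyperparameters often, it can help to keep them in one dictionary and build the command from it. The sketch below is only an illustration: `build_train_command` is a hypothetical helper, and it assumes every option takes the `--name value` form used in the cell above.

```python
import shlex


def build_train_command(params: dict, script: str = "scripts/train.py") -> list:
    """Turn a dict of options into an argument list for the training script."""
    cmd = ["python", script]
    for name, value in params.items():
        # Each entry becomes a "--name value" pair
        cmd += [f"--{name}", str(value)]
    return cmd


params = {"epochs": 10, "batch_size": 32, "model_name": "colab_model"}
print(shlex.join(build_train_command(params)))
# prints: python scripts/train.py --epochs 10 --batch_size 32 --model_name colab_model
```

The resulting list can be passed to `subprocess.run` from the repository root, or the printed string can be pasted after `!` in a cell.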

## 6. Save Checkpoints to Drive (Optional)
Copy the trained model and vocabularies to Google Drive for safekeeping. Skip this cell if you did not mount Drive in step 4.

In [None]:
# Requires step 4: Drive must be mounted and drive_checkpoint_dir defined there
!cp -r checkpoints/* $drive_checkpoint_dir/
print(f"Checkpoints saved to {drive_checkpoint_dir}")
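
The shell copy above fails with an unhelpful message if Drive is not mounted or training produced no checkpoints. A Python version can check both conditions first. This is a sketch under assumptions: `copy_checkpoints` is a hypothetical helper, the paths mirror the ones used in this notebook, and `dirs_exist_ok=True` needs Python 3.8 or newer.

```python
import shutil
from pathlib import Path


def copy_checkpoints(src_dir: str, dst_dir: str) -> int:
    """Copy everything under src_dir into dst_dir and return the number of files copied."""
    src = Path(src_dir)
    if not src.is_dir():
        raise FileNotFoundError(f"No checkpoint directory at {src}")
    dst = Path(dst_dir)
    dst.mkdir(parents=True, exist_ok=True)
    # dirs_exist_ok lets repeated runs overwrite earlier copies instead of failing
    shutil.copytree(src, dst, dirs_exist_ok=True)
    return sum(1 for p in src.rglob("*") if p.is_file())


# Only copy when Drive is mounted and the local checkpoints directory exists
if Path("/content/drive/MyDrive").is_dir() and Path("checkpoints").is_dir():
    count = copy_checkpoints("checkpoints", "/content/drive/MyDrive/ml_project_checkpoints")
    print(f"Copied {count} files to Drive")
else:
    print("Drive is not mounted or no checkpoints were found; nothing copied")
```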