# Stanford RNA 3D Folding - Colab Starter

This notebook sets up the environment to train the **RibonanzaNet-based** 3D folding model.

## Features
- **Clones Source Code**: Automatically pulls the latest code from GitHub.
- **Downloads Data**: Fetches competition data using the Kaggle API.
- **Trains Model**: Runs the training loop. By default it uses mock data; see step 3 for switching to the downloaded data.

## Instructions
1. **Set Runtime to GPU**: `Runtime` -> `Change runtime type` -> `T4 GPU`.
2. **Upload `kaggle.json`**: The data download needs a Kaggle API token. On Kaggle, go to Account -> Create New API Token, which downloads `kaggle.json`. Upload that file when the notebook prompts for it.
3. **Run All Cells**.
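
A malformed or incomplete token file is a common reason the download step fails. The sketch below is one way to check the file after uploading it. The helper name `check_kaggle_token` is hypothetical and not part of the repository; it only relies on the fact that a Kaggle token is a JSON object with `username` and `key` fields.

```python
import json
from pathlib import Path


def check_kaggle_token(path="kaggle.json"):
    """Return a list of problems found in the token file (empty if it looks fine)."""
    token_path = Path(path)
    if not token_path.is_file():
        return [f"{path} not found"]
    try:
        token = json.loads(token_path.read_text())
    except json.JSONDecodeError as exc:
        return [f"{path} is not valid JSON: {exc}"]
    if not isinstance(token, dict):
        return [f"{path} should contain a JSON object"]
    # The Kaggle CLI reads these two fields from the token file.
    return [
        f"missing or empty field: {field}"
        for field in ("username", "key")
        if not token.get(field)
    ]
```

Calling `check_kaggle_token()` in the working directory returns an empty list when the file looks usable. It does not verify that the credentials are accepted by Kaggle.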

### 1. Environment Setup

In [None]:
# Install dependencies
!pip install torch numpy biopython kaggle

In [None]:
# Clone the Repository
# REPLACE WITH YOUR REPO URL IF DIFFERENT
!git clone https://github.com/YOUR_USERNAME/rna-folding-solution.git
%cd rna-folding-solution

### 2. Data Download (Kaggle)

In [None]:
from google.colab import files
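print("Please upload your kaggle.json file:")
uploaded = files.upload()
# Colab renames duplicates (e.g. "kaggle (1).json"), so check the exact name before copying
assert "kaggle.json" in uploaded, "Expected an uploaded file named kaggle.json"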

!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

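print("Downloading competition data...")
# Downloads the Part 1 data as an example. Modify for Part 2 or Ribonanza as needed.
# The Kaggle account must have accepted the competition rules, or the download is refused.
!kaggle competitions download -c stanford-rna-3d-folding
# -o overwrites existing files so re-running this cell does not stall on a prompt
!unzip -q -o stanford-rna-3d-folding.zip -d data/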
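
Because `unzip -q` prints nothing, it can be worth confirming what actually landed in `data/` before training. The following is a minimal sketch using only the standard library; `summarize_data_dir` is a hypothetical helper name, not something the repository provides.

```python
from pathlib import Path


def summarize_data_dir(data_dir="data"):
    """Return (relative_path, size_in_bytes) for every file under data_dir, sorted by path."""
    root = Path(data_dir)
    if not root.is_dir():
        raise FileNotFoundError(
            f"{data_dir} does not exist; did the download and unzip succeed?"
        )
    return sorted(
        (p.relative_to(root).as_posix(), p.stat().st_size)
        for p in root.rglob("*")
        if p.is_file()
    )


# Example use in a notebook cell:
# for name, size in summarize_data_dir("data"):
#     print(f"{size:>12,d}  {name}")
```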

### 3. Run Training

In [None]:
import torch
if torch.cuda.is_available():
    device = "cuda"
    print("Using GPU:", torch.cuda.get_device_name(0))
else:
    device = "cpu"
    print("Using CPU")

from rna_model import RNAModel
from colab_train import main_train_loop

# Initialize Model
model = RNAModel(d_model=128, n_layers=4, n_heads=4).to(device)

# Run Training Loop (Mock Data for demonstration)
# To use real data, update colab_train.py to load from 'data/' directory
pred_output = main_train_loop(model, epochs=5, batch_size=8, device=device)
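
As a starting point for replacing the mock data, the sketch below reads sequences from a CSV file and turns each one into integer token ids. Everything specific here is an assumption to verify against the extracted files and against what `RNAModel` expects: the file path (for example `data/train_sequences.csv`), the column names `target_id` and `sequence`, and the vocabulary mapping. The helper names are hypothetical and do not exist in `colab_train.py`.

```python
import csv

# Assumed vocabulary: index 0 is reserved for padding and unknown symbols.
NUCLEOTIDE_TO_ID = {"A": 1, "C": 2, "G": 3, "U": 4}


def encode_sequence(sequence):
    """Map an RNA string to a list of integer token ids (unknown symbols become 0)."""
    return [NUCLEOTIDE_TO_ID.get(base, 0) for base in sequence.upper()]


def load_sequences(csv_path, id_column="target_id", seq_column="sequence"):
    """Read (id, token_ids) pairs from a CSV file that has a header row."""
    with open(csv_path, newline="") as handle:
        reader = csv.DictReader(handle)
        return [
            (row[id_column], encode_sequence(row[seq_column]))
            for row in reader
        ]
```

The resulting lists can be converted with `torch.tensor(token_ids)` and padded to a common length per batch; how they are fed into the training loop depends on the interface in `colab_train.py`.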

### 4. Visualize Output

In [None]:
# detach() avoids an error when the tensor still tracks gradients
coords = pred_output['coords'].detach().cpu()
print("Predicted Coordinates Shape:", coords.shape)
print("Sample Coordinates (first 5 residues):\n", coords[0, :5].numpy())
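
Raw coordinates are hard to judge by eye. One quick sanity check is the distance between consecutive residues: values that are near zero or vary wildly along the chain suggest the model has not yet learned a plausible backbone. The sketch below assumes `pred_output['coords']` has shape `(batch, residues, 3)`, as the indexing above suggests; `consecutive_distances` is a hypothetical helper.

```python
import math


def consecutive_distances(coords):
    """Distances between successive residues, given a sequence of (x, y, z) points."""
    return [math.dist(a, b) for a, b in zip(coords, coords[1:])]


# Example use in a notebook cell (first structure in the batch):
# points = pred_output['coords'][0].detach().cpu().tolist()
# print(consecutive_distances(points)[:5])
```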