An interactive notebook that trains a Vision Transformer (ViT) model to classify images from a **local dataset stored in Google Drive**. It mounts your Google Drive, loads the custom image dataset, and uses the Hugging Face `transformers` library for training and evaluation. The final model and its training metrics are saved back to Drive, and the test-set metrics are printed.

# ViT Model and Image Requirements

### Hardware Note
I used Google Colab Pro with a GPU runtime, but the free tier of Google Colab with a T4 GPU also works.

---

### Image Specifications

**Image Size** :
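The `google/vit-base-patch16-224-in21k` model expects input images of 224x224 pixels; the `224` in the checkpoint name specifies this resolution. The image processor (`ViTFeatureExtractor` in older code, superseded by `ViTImageProcessor` in current `transformers` releases) automatically resizes every input image to 224x224 during preprocessing, then rescales the pixel values and normalizes them with the checkpoint's mean and standard deviation, so you do not need to resize your files beforehand.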
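The preprocessing steps described above can be sketched in plain NumPy. This is a simplified stand-in, not the `transformers` implementation: it uses nearest-neighbour resizing instead of the bilinear resampling the real image processor uses by default, and it assumes a per-channel mean and standard deviation of 0.5, which is my understanding of the `google/vit-base-patch16-224-in21k` preprocessor config (check its `preprocessor_config.json`). The names `resize_nearest` and `preprocess` are illustrative only.

```python
import numpy as np

# Assumed values: check the checkpoint's preprocessor_config.json before relying on them.
TARGET_SIZE = 224
IMAGE_MEAN = np.array([0.5, 0.5, 0.5], dtype=np.float32)
IMAGE_STD = np.array([0.5, 0.5, 0.5], dtype=np.float32)


def resize_nearest(img: np.ndarray, size: int) -> np.ndarray:
    """Nearest-neighbour resize of an (H, W, C) array to (size, size, C).

    Simplification: the real image processor resamples with PIL (bilinear),
    so pixel values differ slightly, but the output shape is the same.
    """
    h, w = img.shape[:2]
    rows = np.arange(size) * h // size  # source row index for each output row
    cols = np.arange(size) * w // size  # source column index for each output column
    return img[rows[:, None], cols[None, :]]


def preprocess(img: np.ndarray) -> np.ndarray:
    """Resize, rescale to [0, 1], normalize, and move channels first."""
    resized = resize_nearest(img, TARGET_SIZE).astype(np.float32)
    rescaled = resized / 255.0
    normalized = (rescaled - IMAGE_MEAN) / IMAGE_STD  # values now in [-1, 1]
    return normalized.transpose(2, 0, 1)  # (H, W, C) -> (C, H, W)


if __name__ == "__main__":
    # A fake 300x400 RGB photo with random pixel values.
    rng = np.random.default_rng(0)
    photo = rng.integers(0, 256, size=(300, 400, 3), dtype=np.uint8)
    pixel_values = preprocess(photo)
    print(pixel_values.shape)  # (3, 224, 224)
```

Whatever the original file size, every image ends up as a 3x224x224 float array, which is why the dataset can mix resolutions.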

**Color Format** :
The model requires images in the RGB color format, meaning they must have 3 color channels. Casting the column with `ds.cast_column("image", Image(decode=True, id=None))` only decodes the files into PIL images; it does not change their color mode. Grayscale, RGBA, or palette images must therefore be converted to RGB explicitly, for example with `image.convert("RGB")` inside the preprocessing transform, before they are fed to the model.
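The channel logic of that conversion can be seen on NumPy arrays. `ensure_rgb` below is a hypothetical helper for illustration; with PIL images, `image.convert("RGB")` does the job and also covers palette and other modes this sketch ignores.

```python
import numpy as np


def ensure_rgb(img: np.ndarray) -> np.ndarray:
    """Return an (H, W, 3) array for grayscale, single-channel, RGB, or RGBA input.

    Hypothetical helper; in the notebook, PIL's img.convert("RGB") is the practical route.
    """
    if img.ndim == 2:                          # (H, W) grayscale: copy into 3 channels
        return np.stack([img, img, img], axis=-1)
    if img.ndim == 3 and img.shape[-1] == 1:   # (H, W, 1) grayscale
        return np.repeat(img, 3, axis=-1)
    if img.ndim == 3 and img.shape[-1] == 4:   # (H, W, 4) RGBA: drop the alpha channel
        return img[..., :3]
    if img.ndim == 3 and img.shape[-1] == 3:   # already RGB
        return img
    raise ValueError(f"Unsupported image shape: {img.shape}")


if __name__ == "__main__":
    gray = np.arange(12, dtype=np.uint8).reshape(3, 4)
    rgb = ensure_rgb(gray)
    print(rgb.shape)  # (3, 4, 3)
    print(bool((rgb[..., 0] == rgb[..., 2]).all()))  # True: the channels are copies
```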

**Dimensionality** :
A decoded RGB image is a 3D array laid out as Height x Width x Channels (e.g., 224 x 224 x 3). The image processor returns it channels-first (3 x 224 x 224), so a batch of images fed to the model is a 4D tensor of shape (batch size, channels, height, width), e.g., 16 x 3 x 224 x 224.
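The batch shape, and what the `patch16` in the checkpoint name does with it, can be worked through in NumPy. A 224x224 image split into 16x16 patches gives 224 / 16 = 14 patches per side, so 14 x 14 = 196 patches, each flattened to 16 x 16 x 3 = 768 values; ViT then prepends a [CLS] token, giving 197 tokens. The helpers `make_batch` and `patchify` are illustrative: they reproduce only the shape arithmetic, whereas the real model projects patches with a strided convolution.

```python
import numpy as np

PATCH = 16  # the "patch16" in google/vit-base-patch16-224-in21k
SIZE = 224  # the "224" in the checkpoint name


def make_batch(images: list) -> np.ndarray:
    """Stack channels-first (3, H, W) images into one (B, 3, H, W) array."""
    return np.stack(images, axis=0)


def patchify(batch: np.ndarray, patch: int = PATCH) -> np.ndarray:
    """Split a (B, C, H, W) batch into (B, num_patches, patch * patch * C)."""
    b, c, h, w = batch.shape
    if h % patch or w % patch:
        raise ValueError("Height and width must be divisible by the patch size")
    gh, gw = h // patch, w // patch          # patch grid, 14 x 14 for 224 / 16
    x = batch.reshape(b, c, gh, patch, gw, patch)
    x = x.transpose(0, 2, 4, 3, 5, 1)        # (B, gh, gw, patch, patch, C)
    return x.reshape(b, gh * gw, patch * patch * c)


if __name__ == "__main__":
    images = [np.zeros((3, SIZE, SIZE), dtype=np.float32) for _ in range(16)]
    batch = make_batch(images)
    print(batch.shape)            # (16, 3, 224, 224)
    print(patchify(batch).shape)  # (16, 196, 768)
```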

### 1. Install/Upgrade Libraries

This cell installs or upgrades the `datasets` and `evaluate` libraries. Colab already ships `transformers`, `torch`, and `scikit-learn`, which the script also uses.

```python
%pip install -U datasets evaluate
```

### 2. Verify Library and Python Versions

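This cell prints the Python and library versions loaded in the environment, so you can confirm the upgrade took effect before running the main script. If Colab asks you to restart the session after the upgrade, restart it before running this cell.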

```python
import sys
import datasets
import evaluate
import transformers

print(f"Python version: {sys.version}")
print(f"Datasets version: {datasets.__version__}")
print(f"Evaluate version: {evaluate.__version__}")
print(f"Transformers version: {transformers.__version__}")
```
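Printing versions still leaves the comparison to you. The sketch below checks installed versions against minimums using only the standard library (`importlib.metadata`). The minimum shown is an assumption: the training script passes `eval_strategy` to `TrainingArguments`, which, as far as I know, older `transformers` releases called `evaluation_strategy` (renamed around v4.41); verify against the changelog. `parse_version`, `meets_minimum`, and `check_versions` are hypothetical helpers, and the parser ignores pre-release tags, so `packaging.version.Version` is the more robust choice if it is available.

```python
from importlib.metadata import PackageNotFoundError, version

# Assumed minimum for illustration; add other packages as needed.
MINIMUM_VERSIONS = {"transformers": "4.41"}


def parse_version(text: str) -> tuple:
    """Return the leading numeric components, e.g. '2.19.0.dev0' -> (2, 19, 0)."""
    parts = []
    for piece in text.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def meets_minimum(installed: str, required: str) -> bool:
    """Compare versions after padding with zeros so '4.41' equals '4.41.0'."""
    a, b = parse_version(installed), parse_version(required)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def check_versions(minimums: dict = MINIMUM_VERSIONS) -> dict:
    """Print and return {package: meets_minimum} for each package."""
    results = {}
    for package, required in minimums.items():
        try:
            installed = version(package)
        except PackageNotFoundError:
            print(f"{package}: not installed")
            results[package] = False
            continue
        results[package] = meets_minimum(installed, required)
        status = "OK" if results[package] else f"older than {required}"
        print(f"{package} {installed}: {status}")
    return results


if __name__ == "__main__":
    check_versions()
```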

### 3. Mount Google Drive

This cell mounts your Google Drive in the Colab environment so the notebook can read your dataset and save the trained model. Colab asks you to authorize access the first time.

```python
from google.colab import drive
drive.mount('/content/drive')
```

### 4. How to Structure Your Local Dataset

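For the script to work, organize your images in the `ImageFolder` layout read by the Hugging Face `datasets` `imagefolder` loader and upload them to Google Drive: one subfolder per split, and inside each split one subfolder per class. The class folder names become the label names. The structure should be as follows: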

```
My_Dataset/               <-- Your main dataset folder
├── train/                <-- Training images
│   ├── Class_A/          <-- Folder for the first class
│   │   ├── image1.jpg
│   │   └── image2.png
│   └── Class_B/          <-- Folder for the second class
│       ├── image3.jpg
│       └── ...
├── validation/           <-- Validation images (same structure as train)
│   ├── Class_A/
│   └── Class_B/
└── test/                 <-- Test images (same structure as train)
    ├── Class_A/
    └── Class_B/
```

Upload the main folder (e.g., `My_Dataset`) to your Google Drive and set `local_dataset_path` in the next cell to its location (e.g., `/content/drive/MyDrive/My_Dataset`).
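A misplaced or misspelled class folder otherwise surfaces only as an unexpected label set or a split mismatch during training, so a quick check of the layout above can save a run. This stdlib-only sketch counts images per class in each split and verifies that every split has the same class folders as `train`. The function name `summarize_image_folder`, the extension list, and the three required split names are assumptions of this sketch, not part of the `datasets` API.

```python
from pathlib import Path

# Extensions counted as images here; an assumption, extend as needed.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp", ".tif", ".tiff"}
SPLITS = ("train", "validation", "test")


def summarize_image_folder(root: str) -> dict:
    """Return {split: {class_name: image_count}} and check class folders match."""
    root_path = Path(root)
    summary = {}
    for split in SPLITS:
        split_dir = root_path / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"Missing split folder: {split_dir}")
        summary[split] = {
            class_dir.name: sum(
                1 for f in class_dir.iterdir()
                if f.is_file() and f.suffix.lower() in IMAGE_EXTENSIONS
            )
            for class_dir in sorted(split_dir.iterdir())
            if class_dir.is_dir()
        }
    expected = set(summary["train"])
    for split, counts in summary.items():
        if set(counts) != expected:
            raise ValueError(
                f"Class folders in '{split}' {sorted(counts)} "
                f"differ from 'train' {sorted(expected)}"
            )
    return summary
```

After mounting Drive, calling `summarize_image_folder("/content/drive/MyDrive/My_Dataset")` and printing the result shows the class balance per split, which is also worth knowing before reading the accuracy numbers.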

### 5. The Complete Model Training and Evaluation Pipeline

This is the main script. It loads the dataset from the Google Drive path you specify, fine-tunes the ViT model, saves it to Drive, and evaluates it on the test split.
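The script evaluates and checkpoints by step count (`eval_steps=100`, `save_steps=100`) while training for `num_train_epochs=1`, so on a small dataset training can end before step 100 and no evaluation or checkpoint ever happens. A rough estimate, assuming no dropped last batch: steps per epoch = ceil(N / (batch size x devices)) // gradient-accumulation steps. `training_schedule` is a hypothetical helper for this arithmetic, not a `transformers` function.

```python
import math


def training_schedule(num_train_images: int, batch_size: int = 16, num_epochs: float = 1,
                      eval_steps: int = 100, num_devices: int = 1,
                      grad_accum_steps: int = 1) -> dict:
    """Estimate optimizer steps and how many step-based evaluations will run."""
    batches_per_epoch = math.ceil(num_train_images / (batch_size * num_devices))
    steps_per_epoch = max(batches_per_epoch // grad_accum_steps, 1)
    total_steps = math.ceil(steps_per_epoch * num_epochs)
    return {
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        "evaluations": total_steps // eval_steps,
    }


if __name__ == "__main__":
    # Hypothetical dataset of 800 training images with the script's settings.
    print(training_schedule(800))
    # {'steps_per_epoch': 50, 'total_steps': 50, 'evaluations': 0}
    print(training_schedule(800, num_epochs=10))
    # {'steps_per_epoch': 50, 'total_steps': 500, 'evaluations': 5}
```

If `evaluations` comes out as 0, lower `eval_steps` and `save_steps` (keeping `save_steps` a multiple of `eval_steps`, which `load_best_model_at_end=True` requires) or train for more epochs.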

```python
# train_model.py

import numpy as np
import torch
import evaluate
from datasets import load_dataset, Image
from google.colab import drive
from scipy.special import softmax
from sklearn.metrics import (
    precision_recall_fscore_support,
    roc_auc_score,
    average_precision_score,
    matthews_corrcoef,
)
from transformers import (
    ViTImageProcessor,  # replaces the deprecated ViTFeatureExtractor
    ViTForImageClassification,
    Trainer,
    TrainingArguments,
)


def main():
    """
    Main function to run the model training and evaluation pipeline.
    """
    # --- Mount Google Drive ---
    # Remounting keeps the script self-contained if it is run on its own.
    print("Mounting Google Drive...")
    drive.mount('/content/drive', force_remount=True)
    print("Google Drive mounted successfully.")

    # --- HARDWARE CHECK ---
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Hardware check: Using device = {device}")
    if device == "cpu":
        print("WARNING: Training on a CPU is very slow. Consider using a GPU runtime.")

    print("\nStarting Phase 1: AI Model Training...")

    # 1. Load Your Local Dataset
    # IMPORTANT: Update this path to point to your dataset folder in Google Drive.
    local_dataset_path = "/content/drive/MyDrive/My_Dataset"

    print(f"\nStep 1: Loading local dataset from '{local_dataset_path}'...")
    try:
        ds = load_dataset('imagefolder', data_dir=local_dataset_path)
        # Decode the image column to PIL images. This does NOT change the color
        # mode; conversion to RGB happens in transform() below.
        ds = ds.cast_column("image", Image(decode=True, id=None))
        print("Dataset loaded successfully.")
        print(ds)
    except Exception as e:
        print(f"Error loading dataset: {e}")
        print("Please ensure the path is correct and the directory structure matches the 'ImageFolder' format.")
        return

    # Extract labels for model configuration (class names come from folder names)
    labels = ds['train'].features['label'].names
    label2id = {label: i for i, label in enumerate(labels)}
    id2label = {i: label for i, label in enumerate(labels)}
    print(f"\nClasses found: {labels}")

    # 2. Preprocessing
    print("\nStep 2: Initializing image processor...")
    model_name_or_path = 'google/vit-base-patch16-224-in21k'
    image_processor = ViTImageProcessor.from_pretrained(model_name_or_path)
    print("Image processor initialized.")

    # Applied lazily, batch by batch, whenever examples are accessed.
    def transform(example_batch):
        # Convert grayscale, RGBA, or palette images to 3-channel RGB first.
        images = [img.convert("RGB") for img in example_batch['image']]
        # Resize to 224x224, rescale, normalize; returns channels-first tensors.
        inputs = image_processor(images, return_tensors='pt')
        inputs['label'] = example_batch['label']
        return inputs

    print("Attaching on-the-fly transform to the dataset...")
    prepared_ds = ds.with_transform(transform)
    print("Transform attached (applied during training and evaluation).")

    # 3. Define Model and Training Script
    print("\nStep 3: Defining Model Architecture (ViT)...")
    model = ViTForImageClassification.from_pretrained(
        model_name_or_path,
        num_labels=len(labels),
        id2label=id2label,
        label2id=label2id,
    )
    print("Model defined.")

    # Define training arguments
    training_args = TrainingArguments(
        output_dir="/content/drive/MyDrive/my-pulmonary-fibrosis-vit-local",
        per_device_train_batch_size=16,
        eval_strategy="steps",  # called evaluation_strategy in older transformers releases
        num_train_epochs=1,  # Set to a higher number (e.g., 10) for full training
        fp16=(device == "cuda"),  # mixed precision only on GPU
        save_steps=100,  # must be a multiple of eval_steps for load_best_model_at_end
        eval_steps=100,
        logging_steps=10,
        learning_rate=2e-4,
        save_total_limit=2,
        remove_unused_columns=False,  # keep the 'image' column so transform() can see it
        push_to_hub=False,
        report_to='tensorboard',
        load_best_model_at_end=True,
    )

    # Define metrics computation (validation accuracy during training)
    accuracy_metric = evaluate.load("accuracy")
    def compute_metrics(p):
        predictions = np.argmax(p.predictions, axis=1)
        return accuracy_metric.compute(predictions=predictions, references=p.label_ids)

    # Define data collator: stack per-example tensors into a batch
    def collate_fn(batch):
        return {
            'pixel_values': torch.stack([x['pixel_values'] for x in batch]),  # (B, 3, 224, 224)
            'labels': torch.tensor([x['label'] for x in batch])
        }

    # Initialize the Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=collate_fn,
        compute_metrics=compute_metrics,
        train_dataset=prepared_ds["train"],
        eval_dataset=prepared_ds["validation"],
        # Saved alongside the model. On transformers < 4.46, pass tokenizer= instead.
        processing_class=image_processor,
    )

    # Train the model
    print("\nStarting model training...")
    train_results = trainer.train()
    print("Training complete.")

    # 4. Save the Trained Model
    print("\nStep 4: Saving the final model...")
    trainer.save_model()
    trainer.log_metrics("train", train_results.metrics)
    trainer.save_metrics("train", train_results.metrics)
    trainer.save_state()
    print(f"Model saved successfully to '{training_args.output_dir}'")

    # --- Evaluation on Test Set ---
    print("\n--- Starting Evaluation on Test Set ---")
    test_dataset = prepared_ds["test"]
    predictions_output = trainer.predict(test_dataset)
    logits = predictions_output.predictions
    y_true = predictions_output.label_ids  # labels collected by the Trainer
    y_pred = np.argmax(logits, axis=1)
    # AUC metrics rank examples by score, so use class probabilities, not hard labels.
    probs = softmax(logits.astype(np.float64), axis=1)

    # Binary: metrics for the positive class (label id 1, i.e. labels[1]).
    # Multi-class: macro-averaged precision/recall/F1 and one-vs-rest AUC-ROC.
    is_binary = len(labels) == 2
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average='binary' if is_binary else 'macro', zero_division=0
    )
    mcc = matthews_corrcoef(y_true, y_pred)
    try:
        if is_binary:
            auc_roc = roc_auc_score(y_true, probs[:, 1])
            auc_pr = average_precision_score(y_true, probs[:, 1])
        else:
            auc_roc = roc_auc_score(y_true, probs, multi_class='ovr')
            auc_pr = float('nan')  # not computed for multi-class in this script
    except ValueError:  # e.g. a class is missing from the test set
        auc_roc = float('nan')
        auc_pr = float('nan')
    print(f"Test Set F1 Score: {f1:.4f}")
    print(f"Test Set Precision: {precision:.4f}")
    print(f"Test Set Recall: {recall:.4f}")
    print(f"Test Set MCC: {mcc:.4f}")
    print(f"Test Set AUC-ROC: {auc_roc:.4f}")
    print(f"Test Set AUC-PR: {auc_pr:.4f}")

    # Accuracy is always applicable
    accuracy = accuracy_metric.compute(predictions=y_pred, references=y_true)["accuracy"]
    print(f"\nTest Set Accuracy: {accuracy:.4f}")

if __name__ == "__main__":
    main()
```
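ROC AUC and average precision rank test examples by a score, so passing hard 0/1 predictions collapses that ranking to a single threshold and usually understates the model. The NumPy sketch below turns logits into probabilities with a numerically stable softmax and computes binary ROC AUC as the Mann-Whitney statistic, to make the difference visible; in the notebook itself, `sklearn.metrics.roc_auc_score(y_true, probs[:, 1])` is the practical call. The function names and the logits are made up for illustration.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax; subtracting the row max avoids overflow in exp()."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def roc_auc(y_true, scores) -> float:
    """Binary ROC AUC as the Mann-Whitney statistic.

    Probability that a random positive scores higher than a random negative,
    counting ties as one half. Uses O(P * N) memory, fine for small test sets.
    """
    y_true = np.asarray(y_true)
    scores = np.asarray(scores, dtype=float)
    pos, neg = scores[y_true == 1], scores[y_true == 0]
    if len(pos) == 0 or len(neg) == 0:
        return float("nan")  # undefined when only one class is present
    greater = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return float((greater + 0.5 * ties) / (len(pos) * len(neg)))


if __name__ == "__main__":
    # Made-up logits for four test images (columns: class 0, class 1).
    logits = np.array([[2.0, 0.5], [0.2, 0.4], [1.0, 1.5], [0.1, 3.0]])
    y_true = np.array([0, 0, 1, 1])
    probs = softmax(logits)
    y_pred = probs.argmax(axis=1)          # [0, 1, 1, 1]
    print(roc_auc(y_true, probs[:, 1]))    # 1.0  (ranking by probability)
    print(roc_auc(y_true, y_pred))         # 0.75 (ranking by hard labels)
```

Here the probabilities separate the two classes perfectly, but the hard labels tie a negative with both positives, which drops the AUC to 0.75.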