<a href="https://colab.research.google.com/github/piyengar/vehicle-predictor/blob/master/Vehicle_color_trainer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Colab specific
- Install additional libraries
- Set up integration with Google Drive
    - Needs these paths: mount point, dataset storage path, checkpoint storage path
- Set up the content folder as a git repo and pull the codebase from GitHub
    - This can be done by installing the GitHub CLI and triggering authentication in the browser (TBD)
    - For now it can be set up manually


In [None]:
%%capture
! pip install pytorch-lightning
! pip install pytorch-lightning-bolts
! pip install ipywidgets
! pip install torchmetrics
! pip install efficientnet_pytorch


In [None]:

# NOTE(review): disabled helper that would read a GitHub personal access token
# from Google Drive into GITHUB_PAT. While it stays commented out, GITHUB_PAT
# is not assigned anywhere in the visible cells, although the `git clone` cell
# further down interpolates it.
# Before re-enabling: `os` and STORAGE_ROOT are first defined in a later cell,
# and `readline()` keeps the trailing newline of the token line.
# GITHUB_PAT_PATH=os.path.join(STORAGE_ROOT, 'MyDrive/Gatech/github_pat_colab.txt')

# with open(GITHUB_PAT_PATH) as reader:
#     GITHUB_PAT = reader.readline()


In [None]:
import os
# Path constants
STORAGE_ROOT='/content/drive'
CHECKPOINT_ROOT=os.path.join(STORAGE_ROOT, 'MyDrive/Gatech/CARZAM/checkpoints/color')
PREDICTION_ROOT=os.path.join(STORAGE_ROOT, 'MyDrive/Gatech/CARZAM/predictions/color')
DATASET_ROOT=os.path.join(STORAGE_ROOT, 'MyDrive/Gatech/CARZAM/Datasets')

from google.colab import drive
drive.mount(STORAGE_ROOT)
!mkdir -p "{CHECKPOINT_ROOT}"
!mkdir -p "{PREDICTION_ROOT}"
!ln -s "{CHECKPOINT_ROOT}" checkpoints/
!ln -s "{PREDICTION_ROOT}" predictions/


### Set up codebase from GitHub -- (TBD)

In [None]:
!git clone "https://{GITHUB_PAT}@github.com/piyengar/vehicle-predictor.git" ./code

## Predict vehicle colors using SVM

In [1]:
# Imports
import numpy as np
from tqdm.notebook import tqdm
from time import sleep
from sklearn.linear_model import SGDClassifier
from color.dataset import VeriDataset
from color import ColorDataModule, valid_archs, ColorPredictionWriter

In [2]:
# Dataset used for training.
# Supported names: VRIC, Cars196, VehicleID, BoxCars116k, CompCars, Veri, Combined
train_dataset_name = 'Veri'

# Optimisation settings
lr = 0.02          # constant learning rate, passed to SGDClassifier as eta0
batch_size = 128   # mini-batch size handed to the data module
max_epochs = 10    # number of passes over the training data

In [None]:
def _run_epoch(model, dataloader, description, classes=None):
    """Run one pass over ``dataloader`` and return the accuracy over all samples.

    Args:
        model: estimator exposing ``partial_fit`` and ``score``.
        dataloader: iterable of ``(x, _, y)`` batches; ``x`` and ``y`` must
            support ``.numpy()`` and ``x`` must support ``.view``.
        description: label shown on the progress bar.
        classes: full list of class labels. When given, the model is updated
            with ``partial_fit`` on every batch before the batch is scored
            (training pass); when ``None`` the model is only scored
            (validation pass).

    Returns:
        Sample-weighted accuracy over every batch seen, or ``0.0`` if the
        dataloader yielded no samples.
    """
    stats = {'accuracy': 0.0}
    batches = tqdm(enumerate(dataloader), total=len(dataloader), postfix=stats)
    batches.set_description(description)
    n_correct = 0.0
    n_seen = 0
    for mini_batch_idx, (x, _, y) in batches:
        n = len(x)
        if n == 0:
            continue
        # Flatten every sample to a 1-D feature vector for the linear model.
        x = x.view(n, -1).numpy()
        y = y.numpy()
        if classes is not None:
            model.partial_fit(x, y, classes=classes)
        # Weight each batch by its size: a plain mean of per-batch scores
        # over-weights a smaller final batch.
        n_correct += model.score(x, y) * n
        n_seen += n
        stats['accuracy'] = n_correct / n_seen
        batches.set_postfix(stats)
    return stats['accuracy']


# Colours the data module is restricted to.
# NOTE(review): class ids below are 0..len-1; presumably the labels produced
# by the data module are indices into this list -- confirm in ColorDataModule.
allowed_color_list = [
    'black',
    'white',
    'red',
    'yellow',
    'blue',
    'gray',
]

# Linear classifier trained incrementally with a constant learning rate.
model = SGDClassifier(learning_rate='constant', eta0=lr)

# init datamodule
dm = ColorDataModule(
    dataset_name=train_dataset_name,
    data_dir="dataset",
    batch_size=batch_size,
    allowed_color_list=allowed_color_list,
)
dm.setup('fit')
train_dataloader = dm.train_dataloader()
val_dataloader = dm.val_dataloader()

# partial_fit needs the complete label set because a single mini-batch may
# not contain every class.
classes = list(range(len(allowed_color_list)))

# Per-epoch accuracies, kept so results survive beyond the progress bars.
history = []
for epoch in range(max_epochs):
    train_accuracy = _run_epoch(model, train_dataloader, 'Training', classes=classes)
    val_accuracy = _run_epoch(model, val_dataloader, 'Validating')
    history.append({
        'epoch': epoch,
        'train_accuracy': train_accuracy,
        'val_accuracy': val_accuracy,
    })
    
