<a href="https://colab.research.google.com/github/srippa/dvlp/blob/main/colab_play.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Run these cells when working in Colab

In [None]:
# Print the NVIDIA GPU(s) and driver status visible to this runtime (fails if no NVIDIA driver is present).
!nvidia-smi

In [None]:
# Mount Google Drive so its contents are reachable under /content/drive/ in this Colab runtime.
from google.colab import drive

drive.mount(mountpoint='/content/drive/')

In [None]:
# git repo will be in /content/dvlp

GIT_USERNAME = "srippa" 
GIT_TOKEN = "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"           
GIT_REPOSITORY = "dvlp"     

!git clone https://{GIT_TOKEN}@github.com/{GIT_USERNAME}/{GIT_REPOSITORY}

In [None]:
!git push

In [None]:
 !pip install pytorch_lightning

In [None]:
from pathlib import Path

ROOT = Path('/content')     # default for the drive
PROJ = 'dvlp'       # path to your project on Drive
PROJECT_PATH = ROOT / PROJ

DATA_ROOT_DIR  = Path('/content/datasets/emnist')

!ls {PROJECT_PATH}
%cd {PROJECT_PATH}
!pwd

# Run this cell when working locally

In [None]:
from pathlib import Path

ROOT = Path('/opt/dvlp/')                             # default for the code
PROJ = 'dvlp'                                         # path to your project on Drive
PROJECT_PATH = ROOT / PROJ

DATA_ROOT_DIR  = Path('/opt/datasets/emnist')
# !mkdir "{PROJECT_PATH}"I    # in case we haven't created it already   
!ls {PROJECT_PATH}
%cd {PROJECT_PATH}
!pwd

!ls -lh {DATA_ROOT_DIR}



# Code shared by the Colab and local environments

In [None]:
%load_ext autoreload
%reload_ext autoreload
%autoreload 2

%matplotlib inline

from typing import Callable
from importlib.util import find_spec

import numpy as np
import torch

import matplotlib.pyplot as plt
plt.set_cmap('gray')

import pytorch_lightning as pl
import torchmetrics

from datasets.emnist_ds.ds import EMNIST
from models.emnist_cnn import CNN
from lit_models.base import BaseLitModel

In [None]:
# Build the EMNIST data module and materialise its splits
# (prepare_data()/setup() follow the LightningDataModule lifecycle).
data = EMNIST(DATA_ROOT_DIR)
data.prepare_data()
data.setup()
print(data)

# Sanity-check split sizes and the dataset type backing each split.
print(f'Train data: {len(data.data_train)}, {type(data.data_train)}')
print(f'Test data : {len(data.data_test)}, {type(data.data_test)}')

# Inspect one test batch: input shape/dtype/value range and label range.
x, y = next(iter(data.test_dataloader()))
print(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())
print(y.shape, y.dtype, y.min(), y.max())

In [None]:
# Show a 3x3 grid of random test images titled with their mapped labels.
fig, axes = plt.subplots(3, 3, figsize=(9, 9))
for ax in axes.flat:
    idx = np.random.randint(len(data.data_test))
    img, lbl = data.data_test[idx]
    ax.imshow(img.reshape(28, 28), cmap='gray')
    ax.set_title(data.mapping[lbl])

# Train a CNN

In [None]:
# Train the CNN on EMNIST with PyTorch Lightning.
print(pl.__version__)

SEED = 42
MAX_EPOCHS = 5
AVAIL_GPUS = min(1, torch.cuda.device_count())  # use at most one GPU

pl.seed_everything(SEED)  # reproducible weight init / shuffling across re-runs

model = CNN(data_config=data.config())
# Presumably BaseLitModel keeps `model` as a submodule, so training updates `model`
# in place (the prediction cell below relies on this).
lit_model = BaseLitModel(model=model)

# `gpus=` was deprecated in Lightning 1.7 and removed in 2.0; `accelerator`/`devices`
# is accepted by recent 1.x releases and by 2.x.
trainer = pl.Trainer(
    accelerator='gpu' if AVAIL_GPUS else 'cpu',
    devices=1,
    max_epochs=MAX_EPOCHS,
)
trainer.fit(lit_model, datamodule=data)

In [None]:
# Show the trained model's predictions on 9 random test images.
model.eval()  # inference behaviour for layers such as dropout/batch-norm, if the CNN has any
device = next(model.parameters()).device  # the model may still be on the GPU after training

fig, axes = plt.subplots(3, 3, figsize=(9, 9))
with torch.no_grad():  # no autograd graph needed for inference
    for ax in axes.flat:
        rand_i = np.random.randint(len(data.data_test))
        image, label = data.data_test[rand_i]

        image_for_model = image.unsqueeze(0).to(device)  # (1, 1, 28, 28), on the model's device
        logits = model(image_for_model)  # (1, C)
        pred_ind = logits.argmax(-1).item()  # plain int index into the mapping
        pred_label = data.mapping[pred_ind]

        ax.imshow(image.reshape(28, 28), cmap='gray')
        ax.set_title(f'Correct: {data.mapping[label]}, Pred: {pred_label}')