In [1]:
# Compute device used by every model load in this notebook. Pinned to CPU here;
# for GPU runs use 'cuda' if torch.cuda.is_available() else 'cpu' (torch is only
# imported further down, so that expression cannot live in this cell as-is).
DEVICE = "cpu"

### Imports

#### Standard Imports

In [2]:
import os
import sys

In [3]:
import logging

In [4]:
from pathlib import Path

In [5]:
import gc

In [6]:
import pickle

In [7]:
from itertools import product
from functools import reduce

In [8]:
import numpy as np

In [9]:
import matplotlib.pyplot as plt

In [10]:
import torch

In [11]:
from PIL import Image, ImageDraw, ImageOps
import skimage

In [12]:
from tqdm.notebook import tqdm

In [13]:
from prettytable import PrettyTable

---

In [14]:
%load_ext rich

#### Custom Imports

In [15]:
# Make the project modules imported below (imagelib, rollout, inference, ...)
# importable; they presumably live in the parent directory of this notebook.
# Guarded so that re-running this cell does not keep appending '..'.
if '..' not in sys.path:
    sys.path.append('..')

---

In [16]:
# Autoreload Custom Modules
%load_ext autoreload
%autoreload 1

---

In [17]:
from imagelib import Im
%aimport imagelib

In [18]:
from rollout import Rollout, rollout
%aimport rollout

In [19]:
from inference import infer
%aimport inference

#### Config

In [20]:
# Wide numpy repr so arrays are not wrapped across many lines in cell output.
np.set_printoptions(linewidth=1_000)

In [21]:
# Only surface matplotlib log records at ERROR or above.
plt.set_loglevel(level='error')

In [22]:
# Verbose logging for this notebook: the root logger goes to DEBUG, so every
# logger without its own level inherits DEBUG. PIL (imported above) emits debug
# records while decoding images, so quiet it explicitly -- the same treatment
# matplotlib gets via plt.set_loglevel.
logging.basicConfig(level=logging.DEBUG)
logging.getLogger('PIL').setLevel(logging.WARNING)

In [23]:
# Full debug output from the project's imagelib module.
logging.getLogger(name='imagelib').setLevel(level=logging.DEBUG)

In [24]:
# Full debug output from the project's rollout module.
logging.getLogger(name='rollout').setLevel(level=logging.DEBUG)

In [25]:
# Full debug output from the project's inference module.
logging.getLogger(name='inference').setLevel(level=logging.DEBUG)

---

In [26]:
# Accept 'cuda' as well as 'cuda:N' / torch.device('cuda', N).
if str(DEVICE).startswith('cuda'):
    # Start from a clean CUDA allocator state. Collect garbage first so tensors
    # kept alive only by reference cycles are freed before empty_cache() releases
    # the allocator's unoccupied cached memory.
    logging.debug("> INIT / Clearing CUDA cache and collecting garbage.")
    gc.collect()
    torch.cuda.empty_cache()
    # memory_summary() returns a report string; log it instead of discarding it.
    logging.debug(torch.cuda.memory_summary(DEVICE, abbreviated=True))

### Functions

In [27]:
# Load Model
def load_model(model_name, model_version, *, device=DEVICE):
    """Load a model and its preprocessing callable, switch it to eval mode, return both.

    Parameters
    ----------
    model_name : str
        Selects the loader: ``'clip'`` -> ``clip.load``, ``'flatnet'`` ->
        ``flatnet.FlatNet.load``, ``'flatnetlite'`` -> ``flatnet.FlatNetLite.load``.
    model_version : str
        Forwarded unchanged to the loader's ``load``; in this notebook either a
        CLIP architecture name (e.g. ``'ViT-L/14'``) or a FlatNet ``.pt`` checkpoint path.
    device : str, keyword-only
        Target device, forwarded to ``load``; defaults to the notebook-wide ``DEVICE``.

    Returns
    -------
    tuple
        ``(model, preprocess)`` as returned by the loader, with ``model.eval()`` applied.

    Raises
    ------
    ValueError
        If ``model_name`` is not one of the supported names.
    RuntimeError
        Re-raised with its original type and traceback when loading fails; for a
        CUDA device the allocator cache is cleared (and its summary logged) first.
    """
    try:
        if model_name == 'clip':
            import clip as Model
        elif model_name == 'flatnet':
            from flatnet import FlatNet as Model
        elif model_name == 'flatnetlite':
            from flatnet import FlatNetLite as Model
        else:
            raise ValueError(f'> Invalid model name {model_name}.')

        model, preprocess = Model.load(model_version, device=device)
        # Evaluation mode: the model is only used for inference / inspection here.
        model.eval()
        return model, preprocess
    except RuntimeError:
        logging.error(f'> Model loading to {device} failed.')
        # Accept 'cuda' as well as 'cuda:N' / torch.device('cuda', N).
        if str(device).startswith('cuda'):
            logging.debug('> Clearing CUDA cache and collecting garbage.')
            # Free unreachable tensors first so empty_cache() can release their memory.
            gc.collect()
            torch.cuda.empty_cache()
            # memory_summary() returns a report string; log it rather than discard it.
            logging.debug(torch.cuda.memory_summary(device, abbreviated=True))
        # Bare raise keeps the original exception subtype (e.g. a CUDA
        # out-of-memory error) and traceback instead of re-wrapping it.
        raise

In [28]:
def count_parameters(model_name, model_version, *, show_table=False, **kwargs):
    """Print a one-line summary of a model's trainable-parameter count.

    The model is loaded with ``load_model`` and ``numel()`` is summed over every
    parameter whose ``requires_grad`` is True. The summary line is
    ``<model_name> <total> <model_version>``, with name and total column-aligned.

    Parameters
    ----------
    model_name, model_version
        Forwarded to ``load_model``.
    show_table : bool, keyword-only, default False
        If True, first print a per-parameter PrettyTable breakdown
        (parameter name, element count).
    **kwargs
        Additional keyword arguments for ``load_model`` (e.g. ``device``).
    """
    model, *_ = load_model(model_name, model_version, **kwargs)

    # (name, element count) for every trainable parameter, in registration order.
    trainable = [
        (name, parameter.numel())
        for name, parameter in model.named_parameters()
        if parameter.requires_grad
    ]
    total_params = sum(count for _, count in trainable)

    # Only build the (potentially very long) table when it will actually be shown.
    if show_table:
        table = PrettyTable(["Modules", "Parameters"])
        for name, count in trainable:
            table.add_row([name, count])
        print(table)

    print(f'{model_name:<12} {total_params:>12} {model_version}')

## Main: Trainable-Parameter Counts

In [29]:
# (model_name, model_version) pairs consumed by count_parameters / load_model.
# CLIP versions are architecture names; FlatNet versions are checkpoint paths
# laid out as ../results/__models/<run-id>_<stem>/<stem>.pt.
clip_vit_14 = ('clip', 'ViT-L/14')
clip_resnet = ('clip', 'RN50x16')

_flatnetlite_stem = 'mnist_flatnetlite_60000-10000-0-0-0-10000-0-0-128-256-8-0.001-0.2-True-0.045-5-True-True-8-28-28-0.0-False-circles'
flatnetlite = ('flatnetlite', f'../results/__models/1_{_flatnetlite_stem}/{_flatnetlite_stem}.pt')

In [30]:
# FlatNet checkpoints; each path is ../results/__models/<run-id>_<stem>/<stem>.pt.
_flatnet2_stem = 'mnist_cnn_60000-10000-0-0-0-10000-0-0-random-128-256-8-0.001-0.2-False-0.045-5-True-True-8-28-28-0.0-False-circles-False-True-False'
_flatnet7_stem = 'mnist_cnn_60000-10000-0-0-10000-10000-0-0-random-128-256-8-0.001-0.0-False-0.045-5-True-True-8-28-28-0.0-False-circles-False-True-False'

flatnet2 = ('flatnet', f'../results/__models/2_{_flatnet2_stem}/{_flatnet2_stem}.pt')
flatnet7 = ('flatnet', f'../results/__models/7_{_flatnet7_stem}/{_flatnet7_stem}.pt')

In [31]:
# Trainable-parameter counts for both CLIP backbones.
for clip_spec in (clip_vit_14, clip_resnet):
    count_parameters(*clip_spec)

clip            427616513 ViT-L/14
clip            290979217 RN50x16


In [32]:
# Trainable-parameter count for the FlatNetLite checkpoint.
count_parameters(flatnetlite[0], flatnetlite[1])

flatnetlite      23912330 ../results/__models/1_mnist_flatnetlite_60000-10000-0-0-0-10000-0-0-128-256-8-0.001-0.2-True-0.045-5-True-True-8-28-28-0.0-False-circles/mnist_flatnetlite_60000-10000-0-0-0-10000-0-0-128-256-8-0.001-0.2-True-0.045-5-True-True-8-28-28-0.0-False-circles.pt


In [33]:
# Trainable-parameter counts for both FlatNet checkpoints.
for flatnet_spec in (flatnet2, flatnet7):
    count_parameters(*flatnet_spec)

flatnet          95835914 ../results/__models/2_mnist_cnn_60000-10000-0-0-0-10000-0-0-random-128-256-8-0.001-0.2-False-0.045-5-True-True-8-28-28-0.0-False-circles-False-True-False/mnist_cnn_60000-10000-0-0-0-10000-0-0-random-128-256-8-0.001-0.2-False-0.045-5-True-True-8-28-28-0.0-False-circles-False-True-False.pt
flatnet          95835914 ../results/__models/7_mnist_cnn_60000-10000-0-0-10000-10000-0-0-random-128-256-8-0.001-0.0-False-0.045-5-True-True-8-28-28-0.0-False-circles-False-True-False/mnist_cnn_60000-10000-0-0-10000-10000-0-0-random-128-256-8-0.001-0.0-False-0.045-5-True-True-8-28-28-0.0-False-circles-False-True-False.pt
