In [1]:
from ptflops import get_model_complexity_info
import torchvision.models as models

#model = models.resnet50()
#flops, params = get_model_complexity_info(model, (3, 224, 224), as_strings=True)
#print('FLOPs:', flops)
#print('Parameters:', params)

In [2]:
import sys
import os
os.chdir('../')

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Subset
from tqdm import tqdm
from omegaconf import OmegaConf
from hydra import compose, initialize
import matplotlib.pyplot as plt

from ablations.models import BayesianMLP
from experiments.utils import get_configs, get_datasets
import random
%load_ext autoreload
%autoreload 2

In [36]:
# Point to your config directory and name
config_dir = "configs/"  # adjust as needed
config_name = "camelyon17_best.yaml"

with initialize(config_path=config_dir, version_base='1.2'):
    args = compose(config_name=config_name)

# cfg is now a DictConfig
print(args)
dataset = get_datasets(args)
model_config, train_config = get_configs(args)
trainloader = torch.utils.data.DataLoader(
    dataset['train'],
    batch_size=train_config.batch_size,
    shuffle=True,
    num_workers=train_config.num_workers,
    pin_memory=train_config.pin_memory,
    persistent_workers=True,
)

valloader = torch.utils.data.DataLoader(
    dataset['valid'],
    batch_size=train_config.batch_size,
    shuffle=False,
    num_workers=train_config.num_workers,
    pin_memory=train_config.pin_memory,
    persistent_workers=True,
)

oodloader = torch.utils.data.DataLoader(
    dataset['dpddm_ood'],
    batch_size = train_config.batch_size,
    shuffle=False,
    num_workers=train_config.num_workers,
    pin_memory=True, 
    persistent_workers=True
)

{'train': {'disagreement_epochs': 5, 'disagreement_optimizer': 'torch.optim.AdamW', 'disagreement_wd': 0.0001, 'disagreement_lr': 0.01, 'disagreement_batch_size': 64, 'disagreement_alpha': 0.8, 'num_epochs': 2, 'batch_size': 256, 'lr': 1e-05, 'wd': 0.0001, 'optimizer': 'torch.optim.AdamW', 'clip_val': 1, 'val_freq': 1, 'num_workers': 10, 'pin_memory': True}, 'dataset': {'name': 'camelyon17', 'num_classes': 2, 'data_dir': '/voyager/datasets/', 'frac': 1.0, 'download': False}, 'dpddm': {'Phi_size': 1000, 'n_post_samples': 5000, 'data_sample_size': 100, 'temp': 2, 'n_repeats': 100}, 'model': {'name': 'resnet_model', 'resnet_type': 'resnet34', 'hidden_dim': 1000, 'resnet_pretrained': True, 'freeze_features': True, 'reg_weight_factor': 100, 'param': 'diagonal', 'prior_scale': 5, 'wishart_scale': 2, 'return_ood': False}, 'wandb_cfg': {'project': 'wilds_dpddm', 'entity': 'viet', 'job_type': 'train', 'log_artifacts': True}, 'monitor_type': 'bayesian', 'from_pretrained': False, 'seed': 57, 'sel

In [4]:
for k in dataset:
    print(dataset[k].X.shape)

torch.Size([358, 9])
torch.Size([120, 9])
torch.Size([120, 9])
torch.Size([119, 9])
torch.Size([323, 9])


In [37]:
from bayesian_dpddm.models import MLPModel, ConvModel, ResNetModel

model = ResNetModel(model_config, train_size=len(dataset['train']))
model.out_layer = nn.Linear(model_config.hidden_dim, model_config.out_features)

In [42]:
flops, params = get_model_complexity_info(model, (3,32, 32), as_strings=True)
print('FLOPs:', flops)
print('Parameters:', params)

ResNet(
  21.8 M, 100.000% Params, 75.52 MMac, 99.881% MACs, 
  (conv1): Conv2d(9.41 k, 0.043% Params, 2.41 MMac, 3.186% MACs, 3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(128, 0.001% Params, 32.77 KMac, 0.043% MACs, 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(0, 0.000% Params, 16.38 KMac, 0.022% MACs, inplace=True)
  (maxpool): MaxPool2d(0, 0.000% Params, 16.38 KMac, 0.022% MACs, kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    221.95 k, 1.018% Params, 14.23 MMac, 18.821% MACs, 
    (0): BasicBlock(
      73.98 k, 0.339% Params, 4.74 MMac, 6.274% MACs, 
      (conv1): Conv2d(36.86 k, 0.169% Params, 2.36 MMac, 3.121% MACs, 64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, 0.001% Params, 8.19 KMac, 0.011% MACs, 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(0, 0.000% Par

## Camelyon17

In [None]:
import torch
from ptflops import get_model_complexity_info
from torchvision.models import resnet34

model = resnet34()
macs, params = get_model_complexity_info(model, (3, 224, 224), print_per_layer_stat=False)
print(macs, params)

3.68 GMac 21.8 M


In [None]:
fwd_flops = 7.36 # GFLOPs
single_input_flops = fwd_flops * 3
epoch_flops = single_input_flops * len(dataset['train'])
training_flops = epoch_flops * train_config.num_epochs
training_flops # 1.56 PFLOPs

# D3M 7.36 for dis rate, 7.36 TFLOPs or  
7.36 * 100 * 1000 

# Detectron
single_input_flops * 100 * 5 * 1000

11040000.0

## CIFAR

In [None]:
# D3M
fwd_flops = 1.3 # GFLOPs
single_input_flops = fwd_flops * 3
epoch_flops = single_input_flops * len(dataset['train'])
training_flops = epoch_flops * train_config.num_epochs
training_flops # 1.56 PFLOPs


# 1.341 for dis rate, 0 for psuedolabels, * 1000 * 100 = 1.341 TFLOPs or 0.13 PFLOPs

1560000.0

In [35]:
# Detectron
single_input_flops * 100 * 5 * 1000

1950000.0000000002

## MLP Model FLOPs

In [14]:
# one forward pass 
fwd_flops = (160 + 272 * 4 + 34 *24 + 78) * 2
# one backward pass 
total_flops_per_input = fwd_flops * 3
total_flops_per_epoch = total_flops_per_input * len(dataset['train'])
fwd_flops, total_flops_per_epoch * train_config.num_epochs # total training flops

# 230 MFLOPs for training, 9 mil flops for compute dis rate. * 1000 so 9 bil flops, or 9GFLOPs.


(4284, 230050800)

In [None]:
# one forward pass 
fwd_flops = (160 + 272 * 4 + 34 + 78) * 2
# one backward pass 
total_flops_per_input = fwd_flops * 3
total_flops_per_epoch = total_flops_per_input * len(dataset['train'])
fwd_flops, total_flops_per_epoch * train_config.num_epochs # total training flops

total_flops_per_input * 458 * 5 * 1000

# 146 MFLOPs, 


18686400000

In [16]:
len(dataset['train'])

358