# utils

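> Utility helpers: gradient-state checks, seeding, config loading, SSL data-loader lookup and ResNet encoder construction.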
 



In [None]:
#| default_exp utils

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#|export
from fastcore.test import *
from fastai.vision.all import *
import torch
from torchvision.models import resnet18, resnet34, resnet50
from torchvision.models import ResNet18_Weights, ResNet34_Weights, ResNet50_Weights
import random 
import os 
import numpy as np
import yaml
import configparser
from types import SimpleNamespace
import importlib
from nbdev import config


In [None]:
#|export
# nbdev project configuration (presumably loaded from the repo's settings.ini — confirm).
cfg = config.get_config()
# Top-level package name; get_ssl_dls builds `<PACKAGE_NAME>.<dataset>_dataloading` module paths from it.
PACKAGE_NAME = cfg.lib_name

In [None]:
#| export
def test_grad_on(model):
    """
    Test that all grads are on for modules with parameters.
    """
    for name, module in model.named_modules():
        # Check each parameter in the module
        for param_name, param in module.named_parameters(recurse=False):
            assert param.requires_grad, f"Gradients are off for {name}.{param_name}"

def test_grad_off(model):
    """
    Test that all non-batch norm grads are off, but batch norm grads are on.
    """
    for name, module in model.named_modules():
        # Distinguish between BatchNorm and other layers
        if isinstance(module, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)):
            for param_name, param in module.named_parameters(recurse=False):
                assert param.requires_grad, f"BatchNorm parameter does not require grad in {name}.{param_name}"
        else:
            for param_name, param in module.named_parameters(recurse=False):
                assert not param.requires_grad, f"Gradients are on for non-BatchNorm layer {name}.{param_name}"

In [None]:
#| export
def seed_everything(seed=42):
    """
    Seed Python, NumPy and PyTorch RNGs and put cuDNN in deterministic mode.

    Args:
        seed: Integer seed applied to every generator.

    Note:
        ``PYTHONHASHSEED`` is read at interpreter start-up, so setting it here only
        affects child processes launched afterwards, not the current process.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds every visible CUDA device (no-op / lazily deferred without CUDA).
    torch.cuda.manual_seed_all(seed)
    # Deterministic kernels alone are not enough: with benchmark mode on, cuDNN
    # autotuning may select different algorithms from run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [None]:
#| export

def adjust_config_with_derived_values(config):
    """
    Fill in config fields that follow from other fields, mutating `config` in place.

    - ``n_in`` becomes 3 when ``config.dataset == 'cifar10'``.
    - ``encoder_dimension`` is looked up from ``config.arch``:
      512 for resnet18/resnet34, 2048 for resnet50.

    Fields are left untouched when the dataset/arch is not listed above.

    Returns:
        The same ``config`` object.
    """
    if config.dataset == 'cifar10':
        config.n_in = 3

    # Pairs are compared with `==` (not hashed) so any arch value is accepted.
    arch_to_dim = (('resnet18', 512), ('resnet34', 512), ('resnet50', 2048))
    matched_dim = next((dim for arch_name, dim in arch_to_dim if config.arch == arch_name), None)
    if matched_dim is not None:
        config.encoder_dimension = matched_dim

    return config

def load_config(file_path):
    """
    Load a YAML config file into a ``SimpleNamespace`` and fill derived fields.

    Args:
        file_path: Path to a YAML file whose top level is a mapping.

    Returns:
        A SimpleNamespace with one attribute per top-level key, passed through
        ``adjust_config_with_derived_values``.

    Raises:
        TypeError: if the file is empty or its top level is not a mapping.
    """
    # Explicit encoding so parsing does not depend on the platform's locale.
    with open(file_path, 'r', encoding='utf-8') as f:
        raw = yaml.safe_load(f)

    # safe_load returns None for an empty file and a list/scalar for non-mapping
    # YAML; SimpleNamespace(**raw) would otherwise fail with an opaque TypeError.
    if not isinstance(raw, dict):
        raise TypeError(
            f"Config file '{file_path}' must contain a YAML mapping at the top level, "
            f"got {type(raw).__name__}."
        )

    config = SimpleNamespace(**raw)
    return adjust_config_with_derived_values(config)

In [None]:
#| export

def get_ssl_dls(dataset,bs,device):
    """
    Build the self-supervised training DataLoaders for `dataset`.

    Imports ``{PACKAGE_NAME}.{dataset}_dataloading`` and calls its
    ``get_bt_{dataset}_train_dls(bs=bs, device=device)``.

    Args:
        dataset: Dataset name used to derive the module and function names.
        bs: Batch size forwarded to the loader function.
        device: Device forwarded to the loader function.

    Returns:
        Whatever the dataset-specific loader function returns.

    Raises:
        ImportError: if the dataloading module (or its parent package) does not exist.
            A missing dependency imported *inside* an existing module is re-raised unchanged.
        AttributeError: if the module does not define the expected loader function.
        RuntimeError: if the loader function raises; the original error is chained as ``__cause__``.
    """
    module_path = f"{PACKAGE_NAME}.{dataset}_dataloading"
    try:
        dataloading_module = importlib.import_module(module_path)
    except ModuleNotFoundError as e:
        # Only translate the error when the module that is missing is the one we
        # asked for (or a parent package of it); otherwise the real culprit is a
        # dependency of the dataloading module and must not be masked.
        missing = e.name or ''
        if missing and not (module_path == missing or module_path.startswith(missing + '.')):
            raise
        raise ImportError(f"Could not find a data loading module for '{dataset}'. "
                          f"Make sure '{module_path}' exists and is correctly named.") from None

    # Loader functions follow the naming convention get_bt_<dataset>_train_dls.
    func_name = f"get_bt_{dataset}_train_dls"
    try:
        data_loader_func = getattr(dataloading_module, func_name)
    except AttributeError:
        raise AttributeError(f"The function '{func_name}' was not found in '{module_path}'. "
                             "Ensure it is defined and named correctly.") from None

    try:
        dls_train = data_loader_func(bs=bs,device=device)
    except Exception as e:
        # Chain the cause so the loader's own traceback is not lost.
        raise RuntimeError(f"An error occurred while calling '{func_name}' from '{module_path}': {e}") from e

    return dls_train


In [None]:
#| export

#| export

@torch.no_grad()
def get_resnet_encoder(model,n_in=3):
    """
    Turn a classification model into a flattened feature encoder.

    All top-level children except the last one are kept (for torchvision ResNets
    the last child is the `fc` head). The body is built with fastai's
    `create_body` (no pretrained-weight handling, `n_in` forwarded), and a
    `Flatten` module named 'flatten' is appended.
    """
    # Index of the last top-level child: everything before it is kept.
    head_index = len(list(model.children())) - 1
    encoder = create_body(model, n_in=n_in, pretrained=False, cut=head_index)
    encoder.add_module('flatten', torch.nn.Flatten())
    return encoder

# @torch.no_grad()
# def create_resnet50_encoder(weight_type):

#     #pretrained=True if 'weight_type' in ['bt_pretrain', 'supervised_pretrain'] else False

#     if weight_type == 'bt_pretrain': model = torch.hub.load('facebookresearch/barlowtwins:main', 'resnet50')
    
#     elif weight_type == 'no_pretrain': model = resnet50()

#     elif weight_type == 'supervised_pretrain': model = resnet50(weights='IMAGENET1K_V2')

#     #ignore the 'pretrained=False' argument here. Just means we use the weights above 
#     #(which themselves are either pretrained or not)
#     encoder = get_resnet_encoder(model)

#     return encoder

@torch.no_grad()
def resnet_arch_to_encoder(arch:str,weight_type='random'):
    """
    Build a 3-channel ResNet and return its flattened encoder via `get_resnet_encoder`.

    Args:
        arch: One of 'resnet18', 'resnet34', 'resnet50'.
        weight_type: 'random' (no pretrained weights), 'supervised_pretrained'
            (torchvision IMAGENET1K weights) or 'bt_pretrained' (Barlow Twins
            weights from torch.hub 'facebookresearch/barlowtwins:main'; resnet50 only).

    Returns:
        The encoder module produced by `get_resnet_encoder`.

    Raises:
        ValueError: for an unknown `arch`, an unknown `weight_type`, or
            'bt_pretrained' combined with an arch other than 'resnet50'.
    """
    n_in = 3
    archs = ('resnet18', 'resnet34', 'resnet50')
    weight_types = ('bt_pretrained', 'supervised_pretrained', 'random')

    if arch not in archs:
        raise ValueError(f"Architecture not recognized: {arch!r}. Expected one of {archs}.")
    if weight_type not in weight_types:
        raise ValueError(f"Weight type not recognized: {weight_type!r}. Expected one of {weight_types}.")
    if weight_type == 'bt_pretrained' and arch != 'resnet50':
        raise ValueError(f"'bt_pretrained' weights are only available for 'resnet50', got {arch!r}.")

    # Constructors are wrapped in lambdas so only the selected model is built
    # (and only its weights are downloaded).
    builders = {
        ('resnet50', 'bt_pretrained'): lambda: torch.hub.load('facebookresearch/barlowtwins:main', 'resnet50'),
        ('resnet50', 'supervised_pretrained'): lambda: resnet50(weights=ResNet50_Weights.IMAGENET1K_V2),
        ('resnet50', 'random'): lambda: resnet50(),
        ('resnet34', 'supervised_pretrained'): lambda: resnet34(weights=ResNet34_Weights.IMAGENET1K_V1),
        ('resnet34', 'random'): lambda: resnet34(),
        ('resnet18', 'supervised_pretrained'): lambda: resnet18(weights=ResNet18_Weights.IMAGENET1K_V1),
        ('resnet18', 'random'): lambda: resnet18(),
    }
    _model = builders[(arch, weight_type)]()

    return get_resnet_encoder(_model,n_in)



In [None]:

# Regenerate the library's Python modules from the notebooks' exported cells (nbdev).
import nbdev
nbdev.nbdev_export()