In [7]:
import torch
from torch import nn
import torchvision.models as models
from submodule_cv import (ChunkLookupException, setup_log_file,
                                                  gpu_selector, PatchHanger, EarlyStopping, mixup_data)


class ModelTrainer(PatchHanger):
    """Build a model from a config location and profile its GPU memory usage.

    NOTE(review): ``load_model_config``, ``build_model`` and the ``optimizer_type``
    attribute are not defined in this class; presumably they come from
    ``PatchHanger`` or elsewhere in the project — TODO confirm.
    """

    def __init__(self, model_config_location):
        """Load the model configuration and build the model.

        Args:
            model_config_location: where the model config lives; stored on
                ``self`` before ``load_model_config`` is called.
        """
        self.model_config_location = model_config_location
        self.model_config = self.load_model_config()
        self.model = self.build_model()

    @staticmethod
    def _require_cuda():
        """Fail early with a clear message when CUDA is unavailable.

        Every measurement below goes through ``torch.cuda.memory_allocated``,
        so without CUDA nothing here can work.
        """
        if not torch.cuda.is_available():
            raise RuntimeError(
                "CUDA is not available: GPU memory profiling relies on "
                "torch.cuda.memory_allocated and needs an NVIDIA GPU with a working driver.")

    def _resolve_optimizer_type(self, optimizer_type):
        """Return ``optimizer_type`` if given, otherwise ``self.optimizer_type``."""
        return self.optimizer_type if optimizer_type is None else optimizer_type

    @staticmethod
    def _optimizer_state_multiplier(optimizer):
        """Count the parameter-sized state tensors ``optimizer`` keeps per parameter.

        Args:
            optimizer (torch.optim.Optimizer): an Adam/AdamW, RMSprop or SGD instance.

        Returns:
            int: how many extra copies of the parameters the optimizer state needs.

        Raises:
            ValueError: if the optimizer type is not recognised.
        """
        settings = optimizer.defaults
        if isinstance(optimizer, (torch.optim.Adam, torch.optim.AdamW)):
            # exp_avg and exp_avg_sq, plus max_exp_avg_sq when amsgrad is enabled.
            return 2 + int(bool(settings.get("amsgrad", False)))
        if isinstance(optimizer, torch.optim.RMSprop):
            # square_avg, plus momentum_buffer (momentum > 0) and grad_avg (centered).
            return (1
                    + int(settings.get("momentum", 0) > 0)
                    + int(bool(settings.get("centered", False))))
        if isinstance(optimizer, torch.optim.SGD):
            # A momentum_buffer is only allocated when momentum is non-zero.
            return int(settings.get("momentum", 0) != 0)
        raise ValueError("Unsupported optimizer. Look up how many moments are " +
                         "stored by your optimizer and add a case to the optimizer checker.")

    def estimate_memory(self, sample_input, use_amp=False, device=0, optimizer_type=None):
        """Predict the maximum memory usage of the model.

        The estimate is: model weights + forward-pass memory (halved when
        ``use_amp`` as a rough approximation) + gradients (same size as the
        weights) + optimizer state (a multiple of the weight size).

        Args:
            sample_input (torch.Tensor): A sample input to the network.
            use_amp (bool): whether to estimate based on using mixed precision
            device (int or torch.device): the CUDA device to measure on
            optimizer_type (callable, optional): optimizer class/factory called as
                ``optimizer_type(params, lr=...)``; defaults to ``self.optimizer_type``.

        Returns:
            The estimated memory in bytes (a float when ``use_amp`` is set).

        Raises:
            RuntimeError: if CUDA is not available.
            ValueError: if the optimizer type is not supported.

        Side effect: ``self.model`` is left on ``device``.
        """
        self._require_cuda()
        optimizer_type = self._resolve_optimizer_type(optimizer_type)

        # Reset the model to CPU so moving it back measures the weights' footprint.
        self.model.cpu()
        optimizer = optimizer_type(self.model.parameters(), lr=.001)
        # Resolve the optimizer-state size up front so unsupported optimizers
        # fail before the (potentially expensive) forward pass.
        state_multiplier = self._optimizer_state_multiplier(optimizer)

        before_model = torch.cuda.memory_allocated(device)
        self.model.to(device)
        after_model = torch.cuda.memory_allocated(device)
        model_memory = after_model - before_model

        # Keep `output` alive until measured so the autograd graph (saved
        # activations) and the device copy of the input are counted.
        output = self.model(sample_input.to(device)).sum()
        after_forward = torch.cuda.memory_allocated(device)
        del output
        amp_multiplier = .5 if use_amp else 1
        forward_pass_memory = (after_forward - after_model) * amp_multiplier

        gradient_memory = model_memory
        gradient_moment_memory = state_multiplier * gradient_memory
        return model_memory + forward_pass_memory + gradient_memory + gradient_moment_memory

    def test_memory(self, in_size=224, batch_size=1, use_amp=False, device=0,
                    optimizer_type=None, in_channels=3):
        """Print the memory estimate, then measured memory over three training steps.

        Args:
            in_size (int): height and width of the random sample input.
            batch_size (int): batch dimension of the sample input.
            use_amp (bool): run the forward pass under CUDA autocast.
            device (int or torch.device): the CUDA device to measure on.
            optimizer_type (callable, optional): optimizer class/factory; defaults
                to ``self.optimizer_type``.
            in_channels (int): channel dimension of the sample input.

        Raises:
            RuntimeError: if CUDA is not available (raised by ``estimate_memory``).
        """
        optimizer_type = self._resolve_optimizer_type(optimizer_type)
        sample_input = torch.randn(batch_size, in_channels, in_size, in_size,
                                   dtype=torch.float32)

        max_mem_est = self.estimate_memory(sample_input, use_amp=use_amp, device=device,
                                           optimizer_type=optimizer_type)
        print("Maximum Memory Estimate", max_mem_est)
        optimizer = optimizer_type(self.model.parameters(), lr=.001)
        # estimate_memory leaves the model on `device`, so this already includes the weights.
        print("Beginning mem:", torch.cuda.memory_allocated(device), "Note - this may be higher than 0, which is due to PyTorch caching. Don't worry too much about this number")
        self.model.to(device)
        print("After model to device:", torch.cuda.memory_allocated(device))
        for i in range(3):
            print("Iteration", i)
            with torch.autocast(device_type="cuda", enabled=use_amp):
                a = torch.cuda.memory_allocated(device)
                out = self.model(sample_input.to(device)).sum()  # Taking the sum here just to get a scalar output
                b = torch.cuda.memory_allocated(device)
            print("1 - After forward pass", torch.cuda.memory_allocated(device))
            print("2 - Memory consumed by forward pass", b - a)
            out.backward()
            print("3 - After backward pass", torch.cuda.memory_allocated(device))
            optimizer.step()
            print("4 - After optimizer step", torch.cuda.memory_allocated(device))

ModuleNotFoundError: No module named 'pynvml'

In [4]:
# NOTE(review): `test_memory` is defined in this notebook only as a ModelTrainer
# method, and no module-level `test_memory` is visible here, so on a fresh kernel
# this call presumably raises NameError — it likely relied on a function from an
# earlier version of the notebook. Call it on an instance instead, e.g.
# `ModelTrainer(<config location>).test_memory(batch_size=64)`.
# The captured output below ends with torch reporting no NVIDIA driver, so this
# cell needs a CUDA-capable machine.
test_memory(batch_size=64)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /home/pouya/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
55.0%IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100.0%


RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx