In [1]:
# Self-contained memout assertion

##########################################################################
#   IMPORT BLOCK                                                         #
##########################################################################
import gc 
import os
import sys 
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

import torch
import torch.nn as nn
import torch.nn.functional as F 
from torch.autograd import Variable

import config
import prebuilt_loss_functions as plf
import loss_functions as lf 
import utils.pytorch_utils as utils
import utils.image_utils as img_utils
import cifar10.cifar_loader as cifar_loader
import cifar10.cifar_resnets as cifar_resnets
import adversarial_attacks as aa
import adversarial_training as advtrain
import adversarial_evaluation as adveval
import checkpoints
import subprocess
import time 

# Function taken from https://discuss.pytorch.org/t/memory-leaks-in-trans-conv/12492
def get_gpu_memory_map():
    """Query ``nvidia-smi`` for the GPU memory currently in use.

    Runs ``nvidia-smi --query-gpu=memory.used --format=csv,nounits,noheader``
    and tries to interpret its entire output as one number (nvidia-smi
    reports ``memory.used`` in MiB).

    Returns:
        float: the reported value, when the whole output parses as a single
        number. Otherwise the raw, unparsed ``nvidia-smi`` output is returned
        unchanged (for example when it holds more than one line).

    Raises:
        subprocess.CalledProcessError: if ``nvidia-smi`` exits non-zero.
        OSError: if the ``nvidia-smi`` executable cannot be launched.
    """
    result = subprocess.check_output(
        [
            'nvidia-smi', '--query-gpu=memory.used',
            '--format=csv,nounits,noheader'
        ])
    try:
        return float(result)
    except ValueError:
        # Only an unparsable payload falls back to the raw output; anything
        # else (KeyboardInterrupt, unexpected errors) must propagate instead
        # of being silently swallowed.
        return result


In [None]:
# Minimal probe: read GPU memory before building anything, after a
# forward/backward pass, and again after deleting the model objects.
# NOTE(review): MyModel and PartialLoss are not defined or imported in the
# cells visible here -- this cell presumably relies on them already existing
# in the kernel; confirm it survives Restart & Run All.
print "BASE STATE", get_gpu_memory_map()
x = Variable(torch.randn(1, 1000)).cuda()
model = MyModel().cuda()

part_loss = PartialLoss(model)
out = torch.sum(part_loss.forward(x))

out.backward()
print "PEAK STATE", get_gpu_memory_map()
# Only `model` and `part_loss` are deleted; `x` and `out` are still bound
# when the final reading below is taken.
del model 
del part_loss

# Release cached CUDA blocks, then wait before taking the final reading.
torch.cuda.empty_cache()
time.sleep(2)

print "OUT STATE", get_gpu_memory_map()


In [None]:
# Display the current nvidia-smi memory reading as this cell's output.
get_gpu_memory_map()

In [None]:


##########################################################################
#   FUNCTION BLOCK                                                       #
##########################################################################



def memout_example():
    """Attack every 'val' batch with LInfPGD while logging GPU memory.

    Loads the 'val' split through ``cifar_loader.load_cifar_data``, restores
    a ``resnet32`` from the 'half_trained_madry.th' checkpoint, builds a
    ``PerceptualXentropy`` loss and an ``LInfPGD`` attack around it, and then
    attacks each batch in turn. At the start of every iteration the nvidia-smi
    reading from ``get_gpu_memory_map`` is printed twice: once right after
    ``gc.collect()``, and once more after ``torch.cuda.empty_cache()`` and a
    two-second pause.

    Returns:
        None. The function is run for its printed output only.
    """
    # Must remain the first statement, and the function must stay
    # parameterless: vars() lists every local, parameters included.
    assert vars() == {}

    # build persistent data
    val_loader = cifar_loader.load_cifar_data('val', normalize=False,
                                              batch_size=16, use_gpu=True)

    base_model = cifar_resnets.resnet32()
    adv_trained_net = checkpoints.load_state_dict_from_filename(
        'half_trained_madry.th', base_model)
    adv_trained_net.cuda()
    cifar_normer = utils.DifferentiableNormalize(mean=config.CIFAR10_MEANS,
                                                 std=config.CIFAR10_STDS)
    perceptual_loss = plf.PerceptualXentropy(adv_trained_net,
                                             normalizer=cifar_normer,
                                             use_gpu=True)
    pgd_attack_obj = aa.LInfPGD(adv_trained_net, cifar_normer,
                                perceptual_loss, use_gpu=True)

    # now loop through batches and show that there's something hanging somewhere...
    for batch_no, (batch, labels) in enumerate(val_loader):

        # clean up garbage and clear cuda cache as much as possible
        gc.collect()
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print("BATCH NUMBER: %s" % batch_no)
        print("GPU MEMORY: %s" % get_gpu_memory_map())
        torch.cuda.empty_cache()
        time.sleep(2)
        print("GPU MEMORY: %s" % get_gpu_memory_map())

        # Keep named handles on the GPU copies so they can be dropped
        # explicitly below instead of living as anonymous temporaries.
        batch_gpu = batch.cuda()
        labels_gpu = labels.cuda()
        adv_images = pgd_attack_obj.attack(batch_gpu, labels_gpu,
                                           l_inf_bound=8.0 / 255.0,
                                           num_iterations=10,
                                           verbose=False)

        # Drop every per-batch reference before the next iteration.
        # (The former `batch.cpu()` / `labels.cpu()` calls were removed:
        # .cpu() returns a tensor rather than moving one in place, and the
        # results were discarded, so they had no effect.)
        del adv_images
        del batch_gpu
        del labels_gpu
        del batch
        del labels

    return
    
    
##########################################################################
#   BREAK THE PLANET BLOCK                                               #
##########################################################################
# Run the experiment and print whatever it returns, then report completion.
print(memout_example())
print("SOMEHOW THIS WORKED??")

Files already downloaded and verified
BATCH NUMBER: 0
GPU MEMORY: 431.0
GPU MEMORY: 431.0
BATCH NUMBER: 1
GPU MEMORY: 649.0
GPU MEMORY: 583.0
BATCH NUMBER: 2
GPU MEMORY: 649.0
GPU MEMORY: 583.0
BATCH NUMBER: 3
GPU MEMORY: 649.0
GPU MEMORY: 583.0
BATCH NUMBER: 4
GPU MEMORY: 649.0
GPU MEMORY: 583.0
BATCH NUMBER: 5
GPU MEMORY: 649.0
GPU MEMORY: 583.0
BATCH NUMBER: 6
GPU MEMORY: 649.0
GPU MEMORY: 583.0
BATCH NUMBER: 7
GPU MEMORY: 649.0
GPU MEMORY: 583.0
BATCH NUMBER: 8
GPU MEMORY: 649.0
GPU MEMORY: 583.0
BATCH NUMBER: 9
GPU MEMORY: 649.0
GPU MEMORY: 583.0
BATCH NUMBER: 10
GPU MEMORY: 649.0
GPU MEMORY: 583.0
BATCH NUMBER: 11
GPU MEMORY: 649.0
GPU MEMORY: 583.0
BATCH NUMBER: 12
GPU MEMORY: 649.0
GPU MEMORY: 583.0
BATCH NUMBER: 13
GPU MEMORY: 649.0
GPU MEMORY: 583.0
BATCH NUMBER: 14
GPU MEMORY: 649.0
GPU MEMORY: 583.0
BATCH NUMBER: 15
GPU MEMORY: 649.0
GPU MEMORY: 583.0
BATCH NUMBER: 16
GPU MEMORY: 649.0
GPU MEMORY: 583.0
BATCH NUMBER: 17
GPU MEMORY: 649.0
GPU MEMORY: 583.0
BATCH NUMBER: 18
G

In [None]:

def memout_example_loss_direct():
    """Run the LPIPS regularizer forward/backward per batch, logging GPU memory.

    Loads the 'val' split through ``cifar_loader.load_cifar_data``, builds a
    single ``lf.LpipsRegularization`` object, and for every batch: collects
    garbage, empties the CUDA cache, waits two seconds, prints the nvidia-smi
    reading from ``get_gpu_memory_map``, then runs ``setup_attack_batch``,
    ``forward`` and a backward pass on the summed output.

    Returns:
        None. The function is run for its printed output only.
    """
    # Must remain the first statement, and the function must stay
    # parameterless: vars() lists every local, parameters included.
    assert vars() == {}

    # build persistent data
    # (The word "data" used to sit on its own line as a bare expression,
    # which raises NameError when no such global exists.)
    val_loader = cifar_loader.load_cifar_data('val', normalize=False,
                                              batch_size=16, use_gpu=True)

    lpips = lf.LpipsRegularization(None, use_gpu=True)  # need to setup attack batch
    for batch_no, (batch, labels) in enumerate(val_loader):
        gc.collect()
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print("BATCH NUMBER: %s" % batch_no)
        torch.cuda.empty_cache()
        time.sleep(2)
        print("GPU MEMORY: %s" % get_gpu_memory_map())
        batch = Variable(batch.cuda(), requires_grad=True)
        lpips.setup_attack_batch(batch)
        output = torch.sum(lpips.forward(batch))
        output.backward()

# NOTE(review): memout_example_loss_direct has no return statement, so `foo`
# is bound to None whenever the call completes.
foo = memout_example_loss_direct()

In [None]:
# NOTE(review): `foo` comes from memout_example_loss_direct(), which has no
# return statement; if `foo` is None this line raises AttributeError.
foo.shape