In [8]:
##########################################################################
#   IMPORT BLOCK                                                         #
##########################################################################
import gc 

import torch
import torch.nn as nn
import torch.nn.functional as F 
from torch.autograd import Variable
import subprocess
import time 

# Function taken from https://discuss.pytorch.org/t/memory-leaks-in-trans-conv/12492
def get_gpu_memory_map(gpu_index=0):
    """Return the memory currently in use on one GPU, as reported by nvidia-smi.

    Runs ``nvidia-smi --query-gpu=memory.used --format=csv,nounits,noheader``,
    which prints one bare number per GPU, one GPU per line.

    Args:
        gpu_index: position of the wanted GPU in nvidia-smi's output.
            Defaults to 0, so single-GPU machines behave exactly as before.

    Returns:
        float: the ``memory.used`` reading for the selected GPU
        (nvidia-smi documents this field in MiB -- TODO confirm for the
        installed driver).

    Raises:
        ValueError: if nvidia-smi printed no reading at ``gpu_index`` or the
            reading is not a number.
        subprocess.CalledProcessError: if nvidia-smi exits with an error.
        OSError: if the nvidia-smi executable cannot be started.
    """
    result = subprocess.check_output(
        [
            'nvidia-smi', '--query-gpu=memory.used',
            '--format=csv,nounits,noheader'
        ])

    # check_output returns bytes on Python 3; decode so line handling is
    # the same on both interpreters.
    if isinstance(result, bytes):
        result = result.decode('utf-8')

    # One non-empty line per GPU. Calling float() on the whole output only
    # worked when exactly one GPU was present.
    readings = [line.strip() for line in result.splitlines() if line.strip()]

    try:
        reading = readings[gpu_index]
    except IndexError:
        raise ValueError(
            "nvidia-smi reported %d GPU reading(s); nothing at index %r"
            % (len(readings), gpu_index))

    return float(reading)


##########################################################################
#   CLASS BLOCK                                                          #
##########################################################################

class ClassA(object):
    """Base class whose only state is a ``nets`` list."""

    def __init__(self):
        # Each instance gets its own, initially empty, list.
        self.nets = list()
        
class BadSubclass(ClassA):
    """Wrap a classifier and return its forward output with size-1 dims squeezed.

    The classifier is stored both as ``self.classifier`` and as an entry in
    the ``nets`` list created by ClassA.
    """

    def __init__(self, classifier):
        super(BadSubclass, self).__init__()
        self.classifier = classifier
        self.nets.append(classifier)

    def forward(self, inp):
        """Call the classifier's ``forward`` on ``inp`` and squeeze the result."""
        raw_output = self.classifier.forward(inp)
        return raw_output.squeeze()
    
class GoodSubclass(BadSubclass):
    """Variant of BadSubclass whose forward output is reduced to a scalar sum.

    Construction is identical to BadSubclass: the classifier is stored as
    ``self.classifier`` and appended to ``self.nets``.
    """

    def __init__(self, classifier):
        # Delegate to the direct parent. The previous code called
        # super(BadSubclass, self).__init__(), which skipped
        # BadSubclass.__init__ and then repeated its body by hand; the
        # resulting state is the same, without the duplication.
        super(GoodSubclass, self).__init__(classifier)

    def forward(self, inp):
        """Return torch.sum over BadSubclass.forward(inp)."""
        return torch.sum(super(GoodSubclass, self).forward(inp))
        
        
class MyModel(nn.Module):
    """Two linear layers (1000 -> 10000 -> 1000), each followed by a ReLU."""

    def __init__(self):
        super(MyModel, self).__init__()
        self.fc1 = nn.Linear(1000, 10000)
        self.fc2 = nn.Linear(10000, 1000)
        # Ordered view of the layers above, used by forward().
        self.fcs = [self.fc1, self.fc2]

    def forward(self, x):
        """Apply every layer in ``self.fcs`` in order, with ReLU after each one."""
        hidden = x
        for layer in self.fcs:
            hidden = F.relu(layer(hidden))
        return hidden  # 1000 dimension output

    

##########################################################################
#   EXAMPLE BLOCK                                                        #
##########################################################################

def memout_example(bad_or_good, num_loops=10, settle_seconds=5):
    """Build a model on the GPU, backprop once, free everything, and log GPU memory.

    Every iteration prints the reading from get_gpu_memory_map() at three
    points: before anything is allocated (BASE STATE), right after the
    gradient computation (PEAK STATE), and after all local references have
    been deleted, the garbage collector has run and the CUDA cache has been
    emptied (OUT STATE).

    All code is shared between the two variants except the wrapper class
    and the way gradients are requested.

    Args:
        bad_or_good: ``'bad'`` uses BadSubclass and
            ``torch.autograd.backward([output], grad_variables=[inp])``;
            ``'good'`` uses GoodSubclass and ``output.backward()``.
        num_loops: number of iterations to run (default 10, the value that
            used to be hard-coded).
        settle_seconds: seconds to sleep after cleanup before the OUT STATE
            reading (default 5, the value that used to be hard-coded).

    Raises:
        AssertionError: if ``bad_or_good`` is neither ``'bad'`` nor ``'good'``.
    """
    assert bad_or_good in ['bad', 'good']

    if bad_or_good == 'bad':
        subclass = BadSubclass

        def grad_method(output, inp):
            # NOTE(review): `grad_variables` is the argument name this
            # notebook was written against; newer PyTorch releases call it
            # `grad_tensors` -- confirm against the installed version.
            return torch.autograd.backward([output], grad_variables=[inp])
    else:
        subclass = GoodSubclass

        def grad_method(output, inp):
            return output.backward()

    # Loop through, pick a random input, run it through the model,
    # then compute gradients, then clean up as much as possible.
    #
    # Each message is built as a single string so it prints identically
    # under Python 2's print statement and Python 3's print function.
    for i in range(num_loops):
        print("LOOP: (%s) | BASE STATE %s" % (i, get_gpu_memory_map()))
        x = Variable(torch.randn(1, 1000)).cuda()
        model = MyModel().cuda()

        example = subclass(model)
        out = example.forward(x)
        grad_method(out, x)
        print("LOOP: (%s) | PEAK STATE %s" % (i, get_gpu_memory_map()))

        # Drop every local reference before asking Python and the CUDA
        # allocator to release memory.
        del model
        del example
        del out
        del x
        gc.collect()
        torch.cuda.empty_cache()
        time.sleep(settle_seconds)

        print("LOOP: (%s) | OUT  STATE %s" % (i, get_gpu_memory_map()))
        print('-' * 29)  # separator as wide as one log line



In [9]:
# Run the 'bad' configuration (BadSubclass + torch.autograd.backward).
# In the recorded output below, the OUT STATE reading of iteration 0 stays
# above that iteration's BASE STATE, and iteration 1 starts from the higher value.
memout_example('bad')

LOOP: (0) | BASE STATE 5307.0
LOOP: (0) | PEAK STATE 6453.0
LOOP: (0) | OUT  STATE 6071.0
-----------------------------
LOOP: (1) | BASE STATE 6071.0
LOOP: (1) | PEAK STATE 7217.0


KeyboardInterrupt: 

In [10]:
# Run the 'good' configuration (GoodSubclass + output.backward()).
# In the recorded output below, OUT STATE reads the same value (6453.0)
# after every completed iteration.
memout_example('good')

LOOP: (0) | BASE STATE 6835.0
LOOP: (0) | PEAK STATE 7639.0
LOOP: (0) | OUT  STATE 6453.0
-----------------------------
LOOP: (1) | BASE STATE 6453.0
LOOP: (1) | PEAK STATE 7639.0
LOOP: (1) | OUT  STATE 6453.0
-----------------------------
LOOP: (2) | BASE STATE 6453.0
LOOP: (2) | PEAK STATE 7639.0


KeyboardInterrupt: 

In [11]:
# Length of one sample log line; the result (29) is the width of the
# '-' * 29 separator printed by memout_example.
len('LOOP: (1) | PEAK STATE 3652.0')

29