In [None]:
import torch
from torch.autograd import Variable
!pip install wandb==0.10.7

In [None]:
# Display the bytes currently occupied by tensors on the CUDA device.
torch.cuda.memory_allocated(device='cuda')

In [None]:
from torch.utils.checkpoint import checkpoint_sequential
import torch.nn as nn

# Seed so the checkpointed vs. reference comparison below is reproducible.
torch.manual_seed(0)

# create a simple Sequential model
model = nn.Sequential(
    nn.Linear(100, 50),
    nn.ReLU(),
    nn.Linear(50, 20),
    nn.ReLU(),
    nn.Linear(20, 5),
    nn.ReLU()
)

# create the model inputs: a leaf tensor that requires grad
# (`Variable` has been a deprecated no-op wrapper since PyTorch 0.4)
input_var = torch.randn(1, 100, requires_grad=True)

# set the number of checkpoint segments
segments = 2

# the modules in the order the model executes them (public API instead of the
# private `_modules` dict)
modules = list(model.children())

# now call the checkpoint API and get the output
out = checkpoint_sequential(modules, segments, input_var)

# run the backward pass; for simplicity we backprop on out.sum() instead of a loss
model.zero_grad()
out.sum().backward()

# save the output and parameter gradients for comparison with the
# non-checkpointed run below
output_checkpointed = out.detach().clone()
grad_checkpointed = {
    name: param.grad.detach().clone() for name, param in model.named_parameters()
}

# reference run without checkpointing: checkpointing must not change the
# numbers, only the memory/compute trade-off
model.zero_grad()
out_reference = model(input_var)
out_reference.sum().backward()
assert torch.allclose(output_checkpointed, out_reference.detach())
for name, param in model.named_parameters():
    assert torch.allclose(grad_checkpointed[name], param.grad), name

In [None]:
# Move the demo model's parameters onto the current CUDA device; the call
# returns the module, so the notebook displays its repr.
model.cuda()

In [None]:
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.utils.checkpoint as checkpoint

def conv_bn(inp, oup, stride, conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d, nlin_layer=nn.ReLU,
            bn_momentum=np.sqrt(0.1)):
    """3x3 convolution -> normalisation -> non-linearity.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        stride: convolution stride.
        conv_layer: convolution factory, called as conv_layer(inp, oup, 3, stride, 1, bias=False).
        norm_layer: normalisation factory, called as norm_layer(oup, momentum=bn_momentum).
        nlin_layer: activation factory, called as nlin_layer(inplace=True).
        bn_momentum: running-statistics momentum for the norm layer; defaults to
            sqrt(0.1), the same value `conv_1x1_bn` uses.

    Returns:
        nn.Sequential(conv, norm, activation).
    """
    return nn.Sequential(
        conv_layer(inp, oup, 3, stride, 1, bias=False),
        # Pass momentum by keyword: BatchNorm2d's second positional argument is
        # `eps`, so the previous `norm_layer(oup, np.sqrt(0.1))` set eps=0.316.
        norm_layer(oup, momentum=bn_momentum),
        nlin_layer(inplace=True)
    )

def conv_1x1_bn(inp, oup, conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d, nlin_layer=nn.ReLU):
    """Pointwise (1x1, stride 1, no padding) convolution -> norm (momentum sqrt(0.1)) -> activation."""
    pointwise = conv_layer(inp, oup, 1, 1, 0, bias=False)
    norm = norm_layer(oup, momentum=np.sqrt(0.1))
    activation = nlin_layer(inplace=True)
    return nn.Sequential(pointwise, norm, activation)

class Flatten(nn.Module):
    """Flatten every dimension except the leading (batch) dimension."""

    def forward(self, x):
        # Read the batch size directly instead of unpacking (and ignoring) C/H/W,
        # so inputs of any rank >= 1 work; `reshape` (unlike `view`) also accepts
        # non-contiguous tensors such as permuted feature maps.
        return x.reshape(x.size(0), -1)

class Hswish(nn.Module):
    """Hard-swish activation: x * relu6(x + 3) / 6.

    Args:
        inplace: forwarded to relu6, which acts on the temporary `x + 3`,
            so the caller's tensor is never overwritten.
    """

    def __init__(self, inplace=True):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        shifted = x + 3.
        gate = F.relu6(shifted, inplace=self.inplace)
        return x * gate / 6.

class Hsigmoid(nn.Module):
    """Hard-sigmoid activation: relu6(x + 3) / 6, a piecewise-linear sigmoid in [0, 1].

    Args:
        inplace: forwarded to relu6, which acts on the temporary `x + 3`.
    """

    def __init__(self, inplace=True):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        clipped = F.relu6(x + 3., inplace=self.inplace)
        return clipped / 6.

class SEModule(nn.Module):
    """Squeeze-and-excitation: rescale each channel of a 4-D input by a learned gate.

    Args:
        channel: number of input channels.
        reduction: factor by which the hidden layer shrinks `channel`.
    """

    def __init__(self, channel, reduction=4):
        super().__init__()
        hidden = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            Hsigmoid(),  # hard variant used in place of nn.Sigmoid()
        )

    def forward(self, x):
        batch, channels, _, _ = x.size()
        squeezed = self.avg_pool(x).view(batch, channels)
        gate = self.fc(squeezed).view(batch, channels, 1, 1)
        return x * gate.expand_as(x)

class Identity(nn.Module):
    """Pass-through module used where SEModule is disabled.

    `channel` is accepted and ignored so it can be constructed exactly like SEModule.
    """

    def __init__(self, channel):
        super().__init__()

    def forward(self, x):
        return x

def make_divisible(x, divisible_by=8):
    """Round `x` up to the nearest multiple of `divisible_by` and return it as an int."""
    quotient = x * 1. / divisible_by
    return int(np.ceil(quotient) * divisible_by)

class FuSe(nn.Module):
    """FuSe block: 1x1 expand -> parallel 1xk / kx1 depthwise convs -> concat -> SE -> 1x1 project.

    Args:
        inp: input channels.
        oup: output channels.
        kernel: depthwise kernel length (3 or 5).
        stride: stride (1 or 2) used by both depthwise branches.
        exp: expanded channel count; the concatenated branches carry 2*exp channels.
        se: apply SEModule to the concatenated features (Identity otherwise).
        nl: 'RE' selects nn.ReLU, 'HS' selects Hswish; anything else raises
            NotImplementedError.
        j: row of this block in FuSeConv's setting table. Every block except
            j == 4 and j == 9 has its BatchNorm momenta replaced by sqrt(momentum).
    """

    def __init__(self, inp, oup, kernel, stride, exp, se=False, nl='RE', j=0):
        super(FuSe, self).__init__()
        assert stride in [1, 2]
        assert kernel in [3, 5]
        padding = (kernel - 1) // 2
        conv_layer = nn.Conv2d
        norm_layer = nn.BatchNorm2d
        if nl == 'RE':
            nlin_layer = nn.ReLU  # or ReLU6
        elif nl == 'HS':
            nlin_layer = Hswish
        else:
            raise NotImplementedError
        SELayer = SEModule if se else Identity
        # 1x1 expansion
        self.conv1 = conv_layer(inp, exp, 1, 1, 0, bias=False)
        self.bn1 = norm_layer(exp)
        self.nl1 = nlin_layer(inplace=True)
        # Two depthwise branches: a 1 x k (horizontal) and a k x 1 (vertical) kernel.
        self.conv2_h = conv_layer(exp, exp, kernel_size=(1, kernel), stride=stride,
                                  padding=(0, padding), groups=exp, bias=False)
        self.bn2_h = norm_layer(exp)
        self.conv2_v = conv_layer(exp, exp, kernel_size=(kernel, 1), stride=stride,
                                  padding=(padding, 0), groups=exp, bias=False)
        self.bn2_v = norm_layer(exp)
        # The two branches are concatenated, so everything below sees 2*exp channels.
        self.se1 = SELayer(2 * exp)
        self.nl2 = nlin_layer(inplace=True)
        self.conv3 = conv_layer(2 * exp, oup, 1, 1, 0, bias=False)
        self.bn3 = norm_layer(oup)
        if j != 4 and j != 9:
            # NOTE(review): FuSeConv runs feature indices 5 and 10 (blocks j=4 and
            # j=9) without checkpointing; presumably this compensates for BN running
            # stats being updated twice when a checkpointed block is recomputed.
            # If so, matching one update of momentum m over two updates would need
            # 1 - sqrt(1 - m), not sqrt(m) -- confirm the intent.
            for bn in (self.bn1, self.bn2_h, self.bn2_v, self.bn3):
                bn.momentum = np.sqrt(bn.momentum)

    def forward(self, x, i=0):
        """Apply the block to `x`.

        `i` is accepted for backward compatibility and ignored: the former
        `i == 5` branch executed exactly the same operations as the default path.
        """
        out = self.nl1(self.bn1(self.conv1(x)))
        out_h = self.bn2_h(self.conv2_h(out))
        out_v = self.bn2_v(self.conv2_v(out))
        out = torch.cat([out_h, out_v], 1)
        out = self.nl2(self.se1(out))
        return self.bn3(self.conv3(out))

class FuSeConv(nn.Module):
    """FuSe-based image classifier instrumented for activation-checkpointing experiments.

    Feature modules whose index (position in `self.features`) is listed in
    `self.cpn` run normally. Every other feature module is wrapped in
    `torch.utils.checkpoint.checkpoint`, so its activations are recomputed
    during backward instead of being stored.

    Memory profiling:
        `self.mem` collects CUDA bytes allocated relative to `self.start_mem`
        (taken at the start of `forward`) after each recorded step. It is reset
        on every forward call. The forward pass itself records `self.fmem_len`
        entries; entries appended afterwards come from checkpoint recomputation
        during backward (or from callers). Without CUDA every value is 0.

    NOTE(review): a checkpointed module runs its forward twice per step, so any
    BatchNorm inside it updates its running statistics twice; the momentum
    tweaks in FuSe / conv_bn appear to target this -- confirm they are intended.

    Args:
        n_class: number of output classes.
        input_size: only validated (must be a multiple of 32); not otherwise used.
        dropout: dropout probability before the final linear layer.
        cpn: indices of feature modules to run WITHOUT checkpointing.
    """

    def __init__(self, n_class=100, input_size=224, dropout=0.2, cpn=(5, 10, 14)):
        super(FuSeConv, self).__init__()
        input_channel = 16
        last_channel = 1024
        self.mem = []
        self.fmem_len = 0
        self.start_mem = 0
        fuse_setting = [
           # k, exp, c,  se,     nl,  s,
           [3, 16,  16,  True,  'RE', 2],
           [3, 72,  24,  False, 'RE', 2],
           [3, 88,  24,  False, 'RE', 1],
           [5, 96,  40,  True,  'HS', 2],
           [5, 240, 40,  True,  'HS', 1],
           [5, 240, 40,  True,  'HS', 1],
           [5, 120, 48,  True,  'HS', 1],
           [5, 144, 48,  True,  'HS', 1],
           [5, 288, 96,  True,  'HS', 2],
           [5, 576, 96,  True,  'HS', 1],
           [5, 576, 96,  True,  'HS', 1],
        ]
        # building first layer
        assert input_size % 32 == 0
        layers = [conv_bn(3, input_channel, 2, nlin_layer=Hswish)]
        # building mobile blocks; block j ends up at index j + 1 of self.features
        for j, (k, exp, c, se, nl, s) in enumerate(fuse_setting):
            layers.append(FuSe(input_channel, c, k, s, exp, se, nl, j))
            input_channel = c
        # building last several layers
        last_conv = 576
        layers.append(conv_1x1_bn(input_channel, last_conv, nlin_layer=Hswish))
        layers.append(nn.AdaptiveAvgPool2d(1))
        layers.append(nn.Conv2d(last_conv, last_channel, 1, 1, 0))
        layers.append(Hswish(inplace=True))
        self.features = nn.Sequential(*layers)
        # building classifier
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(last_channel, n_class),
        )
        self._initialize_weights()
        # Indices of self.features that run WITHOUT checkpointing. (The old
        # `self.modules` list is gone: it shadowed nn.Module.modules().)
        self.cpn = list(cpn)

    @staticmethod
    def _allocated():
        """CUDA bytes currently allocated, or 0 when CUDA is unavailable."""
        return torch.cuda.memory_allocated('cuda') if torch.cuda.is_available() else 0

    def _record(self):
        """Append current allocation (relative to `self.start_mem`) to `self.mem`."""
        self.mem.append(self._allocated() - self.start_mem)

    def custom(self, module):
        """Wrap `module` for checkpoint(): run it on the first input, then record memory.

        checkpoint() calls the wrapper again when it recomputes the forward pass
        during backward, so it appends to `self.mem` then as well.
        """
        def custom_forward(*inputs):
            output = module(inputs[0])
            self._record()
            return output
        return custom_forward

    def forward(self, x):
        """Run the network and profile memory; see the class docstring for `mem`."""
        if torch.cuda.is_available():
            # reset_max_memory_allocated is deprecated in favour of this call.
            torch.cuda.reset_peak_memory_stats('cuda')
        # Fresh profile per call: previously `mem` grew across calls, so
        # `fmem_len` no longer marked where this call's backward entries start.
        self.mem = []
        self.start_mem = self._allocated()
        for i, module in enumerate(self.features):
            if i in self.cpn:
                # Run normally; activations are kept for backward.
                x = module(x)
                self._record()
            else:
                # Activations are recomputed in backward; custom_forward records.
                x = checkpoint.checkpoint(self.custom(module), x)
        x = x.mean(3).mean(2)
        self._record()
        x = self.classifier(x)
        self._record()
        self.fmem_len = len(self.mem)
        return x

    def _initialize_weights(self):
        """Kaiming-normal convs, unit/zero BatchNorm affine, N(0, 0.01) linear weights; zero biases."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)


In [None]:
fc = FuSeConv().to('cuda')
# NOTE(review): the CPU tensor is the autograd leaf, so backward copies the input
# gradient back to the host inside the timed region. Creating the leaf on the GPU
# avoids that, but x.grad would then count toward the final memory reading.
x = torch.randn(1, 3, 1000, 1000, requires_grad=True).to('cuda')
# CUDA kernels run asynchronously: drain pending work before starting the clock
# and again before stopping it, otherwise only launch overhead is measured.
torch.cuda.synchronize()
start_time = time.perf_counter()
out = fc(x)
fc.zero_grad()
out.sum().backward(retain_graph=False)
fc.mem.append(torch.cuda.memory_allocated('cuda') - fc.start_mem)
torch.cuda.synchronize()
compute_time = time.perf_counter() - start_time
print('time:', compute_time)
print('mem: ', fc.mem, fc.fmem_len)

In [None]:
import matplotlib.pyplot as pyplot

# Earlier run, kept for reference:
#peak_mem=[151164416,196726272,154889216,163175424,214085632]
#peak_comp=[0.1281,0.1269,0.1382,0.1248,0.1292]

# Peak memory (bytes) and compute time (s) per checkpointing experiment.
peak_mem = [3404800, 436736000, 2928128, 3347456, 2435072, 2155520]
peak_comp = [0.21, 0.1857, 0.184, 0.1937, 0.1468, 0.1486]
labels = ['exp1', 'exp2', 'exp3', 'exp4', 'exp5', 'exp6']
markers = ['s', 'o', 'D', 'v', 'o', 's']
colors = ['r', 'g', 'b', 'purple', 'y', 'g']
assert len(peak_mem) == len(peak_comp) == len(labels) == len(markers) == len(colors)

fig, ax = pyplot.subplots()
for mem, comp, label, mark, color in zip(peak_mem, peak_comp, labels, markers, colors):
    ax.plot(mem, comp, color=color,
            marker=mark,
            markerfacecolor='None',
            markeredgecolor=color,
            linestyle='None',
            label=label)
# exp2 uses over 100x the memory of the other runs; a log axis keeps the rest
# from collapsing onto a single point.
ax.set_xscale('log')
ax.set(xlabel='Peak GPU memory (in bytes)',
       ylabel='Compute time (in s)',
       title='Compute vs Memory tradeoff')
ax.legend(numpoints=1)
pyplot.show()

In [None]:
from matplotlib import pyplot as plt

fig, ax = plt.subplots()
ax.plot(fc.mem, marker='.')
# Entries before fmem_len were recorded by the forward pass; later ones come from
# checkpoint recomputation during backward and the post-backward reading.
ax.axvline(fc.fmem_len - 0.5, color='grey', linestyle='--', label='backward starts')
ax.set(xlabel='Recording step',
       ylabel='CUDA bytes allocated above forward start',
       title='FuSeConv memory profile')
ax.legend();

In [None]:
import wandb

RUN_NAME = 'exp3'  # change per experiment so runs stay distinguishable in W&B
wandb.init(project='fuse_conv_ckp', name=RUN_NAME)
model = FuSeConv().to('cuda')
# NOTE(review): watch(log='all') adds parameter/gradient logging hooks, which add
# work inside the timed region below; compare against the un-instrumented run with care.
wandb.watch(model, log='all')
# `inp` instead of `input`, which shadowed the builtin.
inp = torch.randn(1, 3, 1000, 1000, requires_grad=True).to('cuda')
# Synchronize around the timed region: CUDA kernels run asynchronously.
torch.cuda.synchronize()
start_time = time.perf_counter()
out = model(inp)
model.zero_grad()
out.sum().backward(retain_graph=False)
model.mem.append(torch.cuda.memory_allocated('cuda') - model.start_mem)
torch.cuda.synchronize()
compute_time = time.perf_counter() - start_time
memory = model.mem
for mem_bytes in memory:
    wandb.log({'GPU_memory': mem_bytes})
wandb.log({'Compute time': compute_time})
wandb.log({'Backprop_start_step': model.fmem_len})
# Close the run so a later wandb.init() in this kernel starts a fresh one.
wandb.finish()

In [None]:
!pip install pytorch-pretrained-bert

In [None]:
# BERT: tokenize a sentence pair and build the model input tensors
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM

# Load the pre-trained tokenizer (vocabulary) for bert-base-uncased.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Tokenized input
text = "Who was Jim Henson ? Jim Henson was a puppeteer"
tokenized_text = tokenizer.tokenize(text)

# Mask a token that we will try to predict back with `BertForMaskedLM`
masked_index = 6
tokenized_text[masked_index] = '[MASK]'
assert tokenized_text == ['who', 'was', 'jim', 'henson', '?', 'jim', '[MASK]', 'was', 'a', 'puppet', '##eer']

# Convert token to vocabulary indices
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
# Segment (sentence A/B) ids: 0 for sentence A "who was jim henson ?",
# 1 for sentence B "jim [MASK] was a puppet ##eer" (see the BERT paper).
segments_ids = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
assert len(segments_ids) == len(indexed_tokens), "one segment id per token"

# Convert inputs to PyTorch tensors (batch of one sequence)
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])


In [None]:
# Pre-trained encoder switched to inference mode (eval() disables dropout).
model = BertModel.from_pretrained('bert-base-uncased').eval()

# The model returns two values; only the per-layer hidden states are kept.
encoded_layers, _ = model(tokens_tensor, segments_tensors)

# bert-base-uncased stacks 12 encoder layers, giving one hidden-state tensor per layer.
assert len(encoded_layers) == 12

In [None]:
import os
import logging
import shutil
import tempfile
import json
from urllib.parse import urlparse
from pathlib import Path
from typing import Optional, Tuple, Union, IO, Callable, Set
from hashlib import sha256
from functools import wraps

from tqdm import tqdm

import boto3
from botocore.exceptions import ClientError
import requests

# Module-level logger named after the running module.
logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

# Root directory for cached downloads: $PYTORCH_PRETRAINED_BERT_CACHE when set,
# otherwise ~/.pytorch_pretrained_bert.
PYTORCH_PRETRAINED_BERT_CACHE = Path(
    os.environ.get('PYTORCH_PRETRAINED_BERT_CACHE', Path.home() / '.pytorch_pretrained_bert')
)


def url_to_filename(url: str, etag: str = None) -> str:
    """
    Map `url` (and optionally `etag`) to a deterministic cache filename.

    The name is the SHA-256 hex digest of the UTF-8 encoded url. When a
    non-empty `etag` is given, '.' followed by the SHA-256 hex digest of the
    etag is appended.
    """
    parts = [sha256(url.encode('utf-8')).hexdigest()]
    if etag:
        parts.append(sha256(etag.encode('utf-8')).hexdigest())
    return '.'.join(parts)


def filename_to_url(filename: str, cache_dir: str = None) -> Tuple[str, str]:
    """
    Return the (url, etag) pair recorded in the JSON metadata next to a cached file.

    `cache_dir` defaults to PYTORCH_PRETRAINED_BERT_CACHE; `etag` may be ``None``.
    Raise ``FileNotFoundError`` if the cached file or its ``.json`` metadata is missing.
    """
    base_dir = PYTORCH_PRETRAINED_BERT_CACHE if cache_dir is None else cache_dir
    cache_path = os.path.join(base_dir, filename)
    meta_path = cache_path + '.json'

    # The data file is checked before its metadata, matching the error a caller sees first.
    for required in (cache_path, meta_path):
        if not os.path.exists(required):
            raise FileNotFoundError("file {} not found".format(required))

    with open(meta_path) as meta_file:
        metadata = json.load(meta_file)
    return metadata['url'], metadata['etag']


def cached_path(url_or_filename: Union[str, Path], cache_dir: str = None) -> str:
    """
    Resolve a URL or local path to a local file path.

    http(s)/s3 URLs are fetched through the cache (downloading if needed) and
    the cached path is returned. An existing local path is returned unchanged.
    Raises FileNotFoundError for a non-existent plain path and ValueError for
    anything with an unsupported scheme.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
    target = str(url_or_filename) if isinstance(url_or_filename, Path) else url_or_filename
    scheme = urlparse(target).scheme

    if scheme in ('http', 'https', 's3'):
        return get_from_cache(target, cache_dir)
    if os.path.exists(target):
        return target
    if scheme == '':
        raise FileNotFoundError("file {} not found".format(target))
    raise ValueError("unable to parse {} as a URL or as a local path".format(target))


def split_s3_path(url: str) -> Tuple[str, str]:
    """Split ``s3://bucket/key`` into ``(bucket, key)``, dropping one leading '/' from the key.

    Raises ValueError when either the bucket or the path component is empty.
    """
    parsed = urlparse(url)
    bucket_name, key = parsed.netloc, parsed.path
    if not (bucket_name and key):
        raise ValueError("bad s3 path {}".format(url))
    return bucket_name, (key[1:] if key.startswith("/") else key)


def s3_request(func: Callable):
    """
    Decorator for S3 helpers whose first argument is the URL.

    A botocore ``ClientError`` whose error code means "not found" ("404" or
    "NoSuchKey") is re-raised as ``FileNotFoundError`` naming the URL; every
    other ``ClientError`` propagates unchanged.
    """

    @wraps(func)
    def wrapper(url: str, *args, **kwargs):
        try:
            return func(url, *args, **kwargs)
        except ClientError as exc:
            # S3 error codes are strings; non-numeric ones (e.g. "NoSuchBucket")
            # made the previous int() conversion raise ValueError and hide the
            # original error.
            error_code = str(exc.response.get("Error", {}).get("Code", ""))
            if error_code in ("404", "NoSuchKey"):
                raise FileNotFoundError("file {} not found".format(url)) from exc
            raise

    return wrapper


@s3_request
def s3_etag(url: str) -> Optional[str]:
    """Return the ETag of the S3 object addressed by the ``s3://bucket/key`` URL."""
    s3 = boto3.resource("s3")
    bucket_name, key = split_s3_path(url)
    remote_object = s3.Object(bucket_name, key)
    return remote_object.e_tag


@s3_request
def s3_get(url: str, temp_file: IO) -> None:
    """Download the S3 object at `url` into the writable binary file object `temp_file`."""
    s3 = boto3.resource("s3")
    bucket_name, key = split_s3_path(url)
    bucket = s3.Bucket(bucket_name)
    bucket.download_fileobj(key, temp_file)


def http_get(url: str, temp_file: IO) -> None:
    """
    Stream the body of an HTTP GET on `url` into `temp_file` with a tqdm progress bar.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status, so an
            error page is never written out as if it were the requested file.
    """
    req = requests.get(url, stream=True)
    try:
        req.raise_for_status()
        content_length = req.headers.get('Content-Length')
        total = int(content_length) if content_length is not None else None
        progress = tqdm(unit="B", total=total)
        try:
            for chunk in req.iter_content(chunk_size=1024):
                if chunk:  # filter out keep-alive new chunks
                    progress.update(len(chunk))
                    temp_file.write(chunk)
        finally:
            progress.close()
    finally:
        # Release the streamed connection even if the download fails part-way.
        req.close()


def get_from_cache(url: str, cache_dir: str = None) -> str:
    """
    Given a URL, look for the corresponding file in the local cache.
    If it's not there, download it. Then return the path to the cached file.

    The cache filename is derived from the URL and its ETag (see url_to_filename),
    and a ``<cache file>.json`` metadata file records both.

    Raises:
        IOError: if the HEAD request for an http(s) URL does not return 200.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE

    os.makedirs(cache_dir, exist_ok=True)

    # Get eTag to add to filename, if it exists.
    if url.startswith("s3://"):
        etag = s3_etag(url)
    else:
        response = requests.head(url, allow_redirects=True)
        if response.status_code != 200:
            raise IOError("HEAD request failed for url {} with status code {}"
                          .format(url, response.status_code))
        etag = response.headers.get("ETag")

    filename = url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    if not os.path.exists(cache_path):
        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with tempfile.NamedTemporaryFile() as temp_file:
            logger.info("%s not found in cache, downloading to %s", url, temp_file.name)

            # GET file object
            if url.startswith("s3://"):
                s3_get(url, temp_file)
            else:
                http_get(url, temp_file)

            # we are copying the file before closing it, so flush to avoid truncation
            temp_file.flush()
            # shutil.copyfileobj() starts at the current position, so go to the start
            temp_file.seek(0)

            # Copy into a sibling ".incomplete" file and rename it into place at the
            # end: an interrupted copy then can never leave a truncated file at
            # `cache_path`, which the exists() check above would treat as a hit.
            partial_path = cache_path + '.incomplete'
            logger.info("copying %s to cache at %s", temp_file.name, cache_path)
            with open(partial_path, 'wb') as cache_file:
                shutil.copyfileobj(temp_file, cache_file)

            logger.info("creating metadata file for %s", cache_path)
            meta = {'url': url, 'etag': etag}
            meta_path = cache_path + '.json'
            with open(meta_path, 'w') as meta_file:
                json.dump(meta, meta_file)

            # Same directory, so os.replace is an atomic rename.
            os.replace(partial_path, cache_path)

            logger.info("removing temp file %s", temp_file.name)

    return cache_path


def read_set_from_file(filename: str) -> Set[str]:
    '''
    Return the distinct lines of `filename`, each with trailing whitespace removed.
    Expected file format is one item per line.
    '''
    with open(filename, 'r') as file_:
        return {line.rstrip() for line in file_}


def get_file_extension(path: str, dot=True, lower: bool = True):
    """Return the extension of `path` ('' if none).

    The leading '.' is kept unless `dot` is False; the result is lower-cased
    unless `lower` is False.
    """
    _, ext = os.path.splitext(path)
    if not dot:
        ext = ext[1:]
    if lower:
        ext = ext.lower()
    return ext

In [None]:
"""PyTorch BERT model."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import copy
import json
import math
import logging
import tarfile
import tempfile
import shutil

import torch
from torch import nn
from torch.nn import CrossEntropyLoss

#from .file_utils import cached_path

# Root logging at INFO with timestamps, level and logger name.
# (basicConfig does nothing if the root logger already has handlers.)
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Download URL of the pre-trained archive for each supported model name.
PRETRAINED_MODEL_ARCHIVE_MAP = dict([
    ('bert-base-uncased', "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz"),
    ('bert-large-uncased', "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz"),
    ('bert-base-cased', "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz"),
    ('bert-large-cased', "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz"),
    ('bert-base-multilingual-uncased', "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz"),
    ('bert-base-multilingual-cased', "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz"),
    ('bert-base-chinese', "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz"),
])
# File names of the config and the weights inside an extracted archive.
CONFIG_NAME = 'bert_config.json'
WEIGHTS_NAME = 'pytorch_model.bin'

def gelu(x):
    """Exact GELU activation: x * Phi(x), where Phi is the standard normal CDF (computed via erf).

    For information: OpenAI GPT uses the tanh approximation instead, which gives
    slightly different results:
        0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    """
    scaled = x / math.sqrt(2.0)
    return x * 0.5 * (1.0 + torch.erf(scaled))


def swish(x):
    """Swish activation: the input scaled by its own sigmoid."""
    gate = torch.sigmoid(x)
    return x * gate


# Activation lookup used when a config gives `hidden_act` as a string.
ACT2FN = dict(gelu=gelu, relu=torch.nn.functional.relu, swish=swish)


class BertConfig(object):
    """Configuration class to store the configuration of a `BertModel`.
    """
    def __init__(self,
                 vocab_size_or_config_json_file,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 hidden_act="gelu",
                 hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 max_position_embeddings=512,
                 type_vocab_size=2,
                 initializer_range=0.02):
        """Constructs BertConfig.
        Args:
            vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`
                (int), or the path (str / os.PathLike) of a JSON config file whose keys
                become attributes; in the latter case the other arguments are ignored.
            hidden_size: Size of the encoder layers and the pooler layer.
            num_hidden_layers: Number of hidden layers in the Transformer encoder.
            num_attention_heads: Number of attention heads for each attention layer in
                the Transformer encoder.
            intermediate_size: The size of the "intermediate" (i.e., feed-forward)
                layer in the Transformer encoder.
            hidden_act: The non-linear activation function (function or string) in the
                encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
            hidden_dropout_prob: The dropout probability for all fully connected
                layers in the embeddings, encoder, and pooler.
            attention_probs_dropout_prob: The dropout ratio for the attention
                probabilities.
            max_position_embeddings: The maximum sequence length that this model might
                ever be used with. Typically set this to something large just in case
                (e.g., 512 or 1024 or 2048).
            type_vocab_size: The vocabulary size of the `token_type_ids` passed into
                `BertModel`.
            initializer_range: The stddev of the truncated_normal_initializer for
                initializing all weight matrices.
        Raises:
            ValueError: if the first argument is neither an int nor a path.
        """
        if isinstance(vocab_size_or_config_json_file, (str, os.PathLike)):
            # JSON is UTF-8 by specification; don't depend on the platform locale.
            with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
                json_config = json.loads(reader.read())
            for key, value in json_config.items():
                self.__dict__[key] = value
        elif isinstance(vocab_size_or_config_json_file, int):
            self.vocab_size = vocab_size_or_config_json_file
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.hidden_act = hidden_act
            self.intermediate_size = intermediate_size
            self.hidden_dropout_prob = hidden_dropout_prob
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.max_position_embeddings = max_position_embeddings
            self.type_vocab_size = type_vocab_size
            self.initializer_range = initializer_range
        else:
            raise ValueError("First argument must be either a vocabulary size (int) "
                             "or the path to a pretrained model config file (str)")

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a config of this class from a Python dictionary of parameters.

        Starts from the defaults (with a placeholder vocab_size of -1) and overwrites
        them with every key in `json_object`.
        """
        # `cls` rather than a hard-coded BertConfig so subclasses get their own type.
        config = cls(vocab_size_or_config_json_file=-1)
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a config of this class from a json file of parameters."""
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary (a deep copy of its attributes)."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string (indented, keys sorted, trailing newline)."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"


class BertLayerNorm(nn.Module):
    def __init__(self, config, variance_epsilon=1e-12):
        """TF-style layer norm over the last dimension (epsilon inside the square root).

        Args:
            config: object providing `hidden_size`, the length of the normalised dimension.
            variance_epsilon: added to the variance before taking the square root.
        """
        super().__init__()
        width = config.hidden_size
        self.gamma = nn.Parameter(torch.ones(width))
        self.beta = nn.Parameter(torch.zeros(width))
        self.variance_epsilon = variance_epsilon

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        centered = x - mean
        variance = centered.pow(2).mean(-1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.variance_epsilon)
        return self.gamma * normalized + self.beta


class BertEmbeddings(nn.Module):
    """Sum of word, position and token-type embeddings, followed by layer norm and dropout.
    """
    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = BertLayerNorm(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None):
        """Embed 2-D `input_ids`; dim 1 is treated as the sequence axis for positions.

        `token_type_ids` defaults to all zeros (every token in the first segment).
        """
        positions = torch.arange(input_ids.size(1), dtype=torch.long, device=input_ids.device)
        positions = positions.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        summed = (self.word_embeddings(input_ids)
                  + self.position_embeddings(positions)
                  + self.token_type_embeddings(token_type_ids))
        return self.dropout(self.LayerNorm(summed))


class BertSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention (without the output projection)."""

    def __init__(self, config):
        super().__init__()
        heads, width = config.num_attention_heads, config.hidden_size
        if width % heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.num_attention_heads = heads
        self.attention_head_size = int(width / heads)
        self.all_head_size = heads * self.attention_head_size

        self.query = nn.Linear(width, self.all_head_size)
        self.key = nn.Linear(width, self.all_head_size)
        self.value = nn.Linear(width, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Split the last dim into heads: (d0, d1, all_head_size) -> (d0, heads, d1, head_size)."""
        *leading, _ = x.size()
        split = x.view(*leading, self.num_attention_heads, self.attention_head_size)
        return split.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask):
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Raw scores are query.key dot products scaled by sqrt(head_size); the mask
        # is additive (precomputed for all layers in BertModel forward() function).
        scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        scores = scores / math.sqrt(self.attention_head_size)
        scores = scores + attention_mask

        # Dropout on the normalised probabilities drops entire tokens to attend to,
        # which is unusual but taken from the original Transformer paper.
        probs = self.dropout(scores.softmax(dim=-1))

        context = torch.matmul(probs, value_layer).permute(0, 2, 1, 3).contiguous()
        return context.view(*context.size()[:-2], self.all_head_size)


class BertSelfOutput(nn.Module):
    """Project the attention output, apply dropout, add the residual input, then layer-normalise."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)


class BertAttention(nn.Module):
    """Self-attention followed by its output projection and residual LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)

    def forward(self, input_tensor, attention_mask):
        """Attend over ``input_tensor``; the same tensor is the residual input."""
        context = self.self(input_tensor, attention_mask)
        return self.output(context, input_tensor)


class BertIntermediate(nn.Module):
    """Feed-forward expansion: hidden_size -> intermediate_size, then the activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        act = config.hidden_act
        # A string names an entry in ACT2FN; anything else is used as the callable itself.
        self.intermediate_act_fn = ACT2FN[act] if isinstance(act, str) else act

    def forward(self, hidden_states):
        """Apply the dense expansion followed by the configured activation."""
        return self.intermediate_act_fn(self.dense(hidden_states))


class BertOutput(nn.Module):
    """Feed-forward contraction back to hidden_size with a residual LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Return LayerNorm(dropout(dense(hidden_states)) + input_tensor)."""
        contracted = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(contracted + input_tensor)


class BertLayer(nn.Module):
    """One transformer encoder layer, instrumented for activation-checkpointing experiments.

    The attention sub-block goes through ``checkpoint.checkpoint`` unless the
    layer index passed to ``forward`` is listed in ``self.ckpn``. The
    intermediate and output sub-blocks are always checkpointed. After each
    layer, GPU memory relative to the notebook-level global ``start_mem`` is
    appended to the global list ``mem``; both globals must exist before
    ``forward`` runs.
    """

    def __init__(self, config):
        super().__init__()
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)
        # Layer indices whose attention sub-block runs WITHOUT checkpointing.
        self.ckpn = [5, 11]

    def custom(self, module):
        """Wrap a two-argument ``module`` so that every call logs GPU memory.

        checkpoint.checkpoint invokes the wrapper in the forward pass and again
        when it recomputes activations during backward, so the 'In backprop!'
        message also appears during the forward pass.
        """
        def custom_forward(*inputs):
            global mem, start_mem
            print('In backprop!')
            result = module(inputs[0], inputs[1])
            print('inside custom: ',
                  torch.cuda.memory_allocated('cuda'),
                  torch.cuda.memory_allocated('cuda') - start_mem,
                  result.device,
                  start_mem)
            mem.append(torch.cuda.memory_allocated('cuda') - start_mem)
            return result
        return custom_forward

    def forward(self, hidden_states, attention_mask, i):
        """Run the layer; ``i`` is this layer's index inside the encoder stack."""
        if i not in self.ckpn:
            attention_output = checkpoint.checkpoint(
                self.custom(self.attention), hidden_states, attention_mask)
        else:
            attention_output = self.attention(hidden_states, attention_mask)
        intermediate_output = checkpoint.checkpoint(self.intermediate, attention_output)
        layer_output = checkpoint.checkpoint(self.output, intermediate_output, attention_output)
        mem.append(torch.cuda.memory_allocated('cuda') - start_mem)
        return layer_output


class BertEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` independent deep copies of one BertLayer.

    ``forward`` passes each layer its index so BertLayer can decide what to
    checkpoint. At the end it stores ``len(mem)`` in the global ``fmem_len``,
    which marks where the forward-pass memory samples end.
    """

    def __init__(self, config):
        super().__init__()
        template = BertLayer(config)
        self.layer = nn.ModuleList(
            [copy.deepcopy(template) for _ in range(config.num_hidden_layers)])
        # Only referenced by disabled layer-level checkpointing experiments
        # (another setting tried: [3, 6, 11]).
        self.ckpn = [2, 11]

    def custom(self, module):
        """Wrap a two-argument ``module`` so that every call logs GPU memory to ``mem``."""
        def custom_forward(*inputs):
            global mem, start_mem
            print('In backprop!')
            result = module(inputs[0], inputs[1])
            print('inside custom: ',
                  torch.cuda.memory_allocated('cuda'),
                  torch.cuda.memory_allocated('cuda') - start_mem,
                  result.device,
                  start_mem)
            mem.append(torch.cuda.memory_allocated('cuda') - start_mem)
            return result
        return custom_forward

    def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
        """Run every layer in order.

        Returns:
            A list holding every layer's output when ``output_all_encoded_layers``
            is true, otherwise a one-element list with the final output.
        """
        global fmem_len
        collected = []
        for layer_index, layer_module in enumerate(self.layer):
            hidden_states = layer_module(hidden_states, attention_mask, layer_index)
            if output_all_encoded_layers:
                collected.append(hidden_states)
        if not output_all_encoded_layers:
            collected.append(hidden_states)
        fmem_len = len(mem)
        print('Done with forward pass!')
        return collected

class BertPooler(nn.Module):
    """Pools a sequence by passing the first token's hidden state through dense + tanh."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        """Return tanh(dense(hidden_states[:, 0]))."""
        first_token = hidden_states[:, 0]
        return self.activation(self.dense(first_token))


class BertPredictionHeadTransform(nn.Module):
    """dense -> activation -> LayerNorm applied before the LM output projection."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        act = config.hidden_act
        # A string names an entry in ACT2FN; anything else is used as the callable itself.
        self.transform_act_fn = ACT2FN[act] if isinstance(act, str) else act
        self.LayerNorm = BertLayerNorm(config)

    def forward(self, hidden_states):
        """Return LayerNorm(act(dense(hidden_states)))."""
        activated = self.transform_act_fn(self.dense(hidden_states))
        return self.LayerNorm(activated)


class BertLMPredictionHead(nn.Module):
    """Masked-LM head whose output projection is tied to the input embedding matrix."""

    def __init__(self, config, bert_model_embedding_weights):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        vocab_size = bert_model_embedding_weights.size(0)
        hidden_size = bert_model_embedding_weights.size(1)
        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
        self.decoder.weight = bert_model_embedding_weights
        self.bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, hidden_states):
        """Return decoder(transform(hidden_states)) + bias."""
        logits = self.decoder(self.transform(hidden_states))
        return logits + self.bias


class BertOnlyMLMHead(nn.Module):
    """Wrapper exposing only the masked-LM prediction head."""

    def __init__(self, config, bert_model_embedding_weights):
        super().__init__()
        self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)

    def forward(self, sequence_output):
        """Return the LM prediction scores for ``sequence_output``."""
        return self.predictions(sequence_output)


class BertOnlyNSPHead(nn.Module):
    """Wrapper exposing only the 2-way next-sentence classifier."""

    def __init__(self, config):
        super().__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):
        """Return the two next-sentence logits for each pooled vector."""
        return self.seq_relationship(pooled_output)


class BertPreTrainingHeads(nn.Module):
    """Both pre-training heads: masked-LM scores and next-sentence logits."""

    def __init__(self, config, bert_model_embedding_weights):
        super().__init__()
        self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        """Return ``(prediction_scores, seq_relationship_score)``."""
        return self.predictions(sequence_output), self.seq_relationship(pooled_output)


class PreTrainedBertModel(nn.Module):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    def __init__(self, config, *inputs, **kwargs):
        super(PreTrainedBertModel, self).__init__()
        if not isinstance(config, BertConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                ))
        self.config = config

    def init_bert_weights(self, module):
        """ Initialize the weights of one sub-module; intended for ``self.apply``.

        * ``nn.Linear`` / ``nn.Embedding`` weights ~ N(0, config.initializer_range).
        * ``BertLayerNorm`` is reset to the identity transform (gamma = 1, beta = 0).
        * ``nn.Linear`` biases are zeroed.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, BertLayerNorm):
            # A freshly initialised LayerNorm must be the identity. Drawing gamma
            # from N(0, initializer_range), as before, scaled every normalised
            # activation down to ~0 (with random sign) in untrained models.
            module.beta.data.zero_()
            module.gamma.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    @classmethod
    def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs):
        """
        Instantiate a PreTrainedBertModel from a pre-trained model file.
        Download and cache the pre-trained model file if needed.

        Params:
            pretrained_model_name: either:
                - a str with the name of a pre-trained model to load selected in the list of:
                    . `bert-base-uncased`
                    . `bert-large-uncased`
                    . `bert-base-cased`
                    . `bert-base-multilingual`
                    . `bert-base-chinese`
                - a path or url to a pretrained model archive containing:
                    . `bert_config.json` a configuration file for the model
                    . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
            cache_dir: optional directory passed to `cached_path` for downloads.
            *inputs, **kwargs: additional input for the specific Bert class
                (ex: num_labels for BertForSequenceClassification)

        Returns:
            The model with pretrained weights loaded, or None when the archive
            cannot be found (the failure is logged).
        """
        if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
            archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
        else:
            archive_file = pretrained_model_name
        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
        except FileNotFoundError:
            logger.error(
                "Model name '{}' was not found in model name list ({}). "
                "We assumed '{}' was a path or url but couldn't find any file "
                "associated to this path or url.".format(
                    pretrained_model_name,
                    ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
                    archive_file))
            return None
        if resolved_archive_file == archive_file:
            logger.info("loading archive file {}".format(archive_file))
        else:
            logger.info("loading archive file {} from cache at {}".format(
                archive_file, resolved_archive_file))
        tempdir = None
        # Everything that reads from the (possibly temporary) serialization dir
        # sits inside try/finally so the temp dir is removed even when loading fails.
        try:
            if os.path.isdir(resolved_archive_file):
                serialization_dir = resolved_archive_file
            else:
                # Extract archive to temp dir.
                # NOTE(review): extractall trusts member paths inside the archive;
                # only load archives from trusted sources.
                tempdir = tempfile.mkdtemp()
                logger.info("extracting archive file {} to temp dir {}".format(
                    resolved_archive_file, tempdir))
                with tarfile.open(resolved_archive_file, 'r:gz') as archive:
                    archive.extractall(tempdir)
                serialization_dir = tempdir
            # Load config
            config_file = os.path.join(serialization_dir, CONFIG_NAME)
            config = BertConfig.from_json_file(config_file)
            logger.info("Model config {}".format(config))
            # Instantiate model.
            model = cls(config, *inputs, **kwargs)
            weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
            # Remap GPU-saved tensors to CPU when CUDA is unavailable; on a CUDA host
            # this is the default behaviour (map_location=None).
            # NOTE(review): torch.load unpickles the file -- only use trusted weights.
            state_dict = torch.load(
                weights_path,
                map_location=None if torch.cuda.is_available() else 'cpu')
        finally:
            if tempdir:
                # Clean up temp dir
                shutil.rmtree(tempdir)

        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=''):
            # Recursively load each sub-module's own parameters/buffers.
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')
        # Checkpoint keys are prefixed with 'bert.' unless the model wraps BertModel as `.bert`.
        load(model, prefix='' if hasattr(model, 'bert') else 'bert.')
        if len(missing_keys) > 0:
            logger.info("Weights of {} not initialized from pretrained model: {}".format(
                model.__class__.__name__, missing_keys))
        if len(unexpected_keys) > 0:
            logger.info("Weights from pretrained model not used in {}: {}".format(
                model.__class__.__name__, unexpected_keys))
        if len(error_msgs) > 0:
            # Previously collected and silently dropped (e.g. shape mismatches).
            logger.error("Error(s) in loading state_dict for {}:\n\t{}".format(
                model.__class__.__name__, "\n\t".join(error_msgs)))
        return model


class BertModel(PreTrainedBertModel):
    """BERT model ("Bidirectional Embedding Representations from a Transformer").
    Params:
        config: a BertConfig class instance with the configuration to build a new model
    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
            a `sentence B` token (see BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
            input sequence length in the current batch. It's the mask that we typically use for attention when
            a batch has varying length sentences.
        `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
    Outputs: Tuple of (encoded_layers, pooled_output)
        `encoded_layers`: controlled by `output_all_encoded_layers` argument:
            - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
                of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
                encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
            - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
                to the last attention block of shape [batch_size, sequence_length, hidden_size],
        `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
            classifier pretrained on top of the hidden state associated to the first character of the
            input (`CLF`) to train on the Next-Sentence task (see BERT's paper).
    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
    config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
    model = modeling.BertModel(config=config)
    all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
    ```
    """
    def __init__(self, config):
        super(BertModel, self).__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):
        """Embed the tokens, run the encoder and pool the final layer.

        Returns:
            ``(encoded_layers, pooled_output)`` as described in the class docstring.
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # Run the embeddings directly, NOT through checkpoint.checkpoint. The
        # reentrant checkpoint only tracks gradients through its *inputs*, and
        # input_ids/token_type_ids are integer tensors that never require grad.
        # Checkpointing here returned an embedding_output with requires_grad=False.
        # That silently cut the embeddings, and every checkpointed encoder segment
        # before the first non-checkpointed one, out of the backward pass.
        embedding_output = self.embeddings(input_ids, token_type_ids)
        encoded_layers = self.encoder(embedding_output,
                                      extended_attention_mask,
                                      output_all_encoded_layers=output_all_encoded_layers)
        sequence_output = encoded_layers[-1]
        # Call the pooler (the previous `self.pooler,sequence_output` built a tuple).
        pooled_output = self.pooler(sequence_output)
        if not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]
        return encoded_layers, pooled_output

In [None]:
# Benchmark one forward + backward pass of the instrumented BertModel.
# `mem` collects GPU memory samples (bytes above `start_mem`). BertLayer and
# BertEncoder append to it through module-level globals, and `fmem_len` marks
# where the forward-pass samples end.
# NOTE(review): `tokens_tensor` / `segments_tensors` must come from an earlier cell.
mem, fmem_len = [], 0
model = BertModel.from_pretrained('bert-base-uncased').to('cuda')
torch.cuda.reset_max_memory_allocated('cuda')
start_mem = torch.cuda.memory_allocated('cuda')
mem.append(torch.cuda.memory_allocated('cuda') - start_mem)
# CUDA kernels are launched asynchronously. Synchronize before each clock read so
# the timings cover finished GPU work rather than only kernel launches.
torch.cuda.synchronize()
start_time = time.time()
encoded_layers, _ = model(tokens_tensor.to('cuda'), segments_tensors.to('cuda'))
torch.cuda.synchronize()
ftime = time.time() - start_time
mem.append(torch.cuda.memory_allocated('cuda') - start_mem)
fmem_len = len(mem)
# Dummy scalar objective: sum of every element of every (3-D) encoder layer output.
loss = sum(sum(sum(sum(encoded_layers))))
model.zero_grad()
mem.append(torch.cuda.memory_allocated('cuda') - start_mem)
loss.backward()
torch.cuda.synchronize()
btime = time.time() - (start_time + ftime)
mem.append(torch.cuda.memory_allocated('cuda') - start_mem)
# Second post-backward sample (the `del loss` that used to precede it is disabled).
mem.append(torch.cuda.memory_allocated('cuda') - start_mem)
# We have a hidden state for each of the 12 layers in model bert-base-uncased
assert len(encoded_layers) == 12
ttime = time.time() - start_time
print('f,b,t:', ftime, btime, ttime)
print('mem: ', mem)

In [None]:
# Show the full GPU-memory trace recorded by the benchmark.
print(f"{mem}")

In [None]:
import wandb

# Same benchmark as above, with the memory trace and timings logged to W&B.
# NOTE(review): wandb.watch(log='all') hooks gradient/parameter logging into the
# model, which may itself allocate memory -- keep that in mind when comparing
# these traces with the un-logged run.
wandb.init(project='bert_model_ckp', name='exp6')
mem, fmem_len = [], 0
model = BertModel.from_pretrained('bert-base-uncased').to('cuda')
wandb.watch(model, log='all')
torch.cuda.reset_max_memory_allocated('cuda')
start_mem = torch.cuda.memory_allocated('cuda')
mem.append(torch.cuda.memory_allocated('cuda') - start_mem)
# CUDA work is asynchronous: synchronize before each clock read so the timings
# measure completed GPU work.
torch.cuda.synchronize()
start_time = time.time()
encoded_layers, _ = model(tokens_tensor.to('cuda'), segments_tensors.to('cuda'))
torch.cuda.synchronize()
ftime = time.time() - start_time
mem.append(torch.cuda.memory_allocated('cuda') - start_mem)
loss = sum(sum(sum(sum(encoded_layers))))
model.zero_grad()
loss.backward(retain_graph=False)
torch.cuda.synchronize()
btime = time.time() - (start_time + ftime)
mem.append(torch.cuda.memory_allocated('cuda') - start_mem)
assert len(encoded_layers) == 12
ttime = time.time() - start_time
print('Forward pass time:,Back-prop time,Total time:', ftime, btime, ttime)
print('GPU memory over time (in bytes): ', mem)
print('Forward pass length: ', fmem_len)
# `fmem_len` is set by BertEncoder.forward when the encoder finishes. A named loop
# variable is used so IPython's `_` (last output) is not clobbered.
for sample in mem[0:fmem_len]:
    wandb.log({'Forward pass GPU_memory': sample})
for sample in mem[fmem_len:]:
    wandb.log({'Backprop GPU_memory': sample})
wandb.log({'Forward pass time': ftime, 'Back-prop time': btime, 'Total time': ttime})

In [None]:
# Memory trace so far, plus the bytes still allocated above the baseline
# (presumably held by tensors that are still referenced, e.g. `encoded_layers`).
residual_bytes = torch.cuda.memory_allocated('cuda') - start_mem
print(mem)
print(residual_bytes)

In [None]:
#del model
# Release unoccupied blocks held by PyTorch's CUDA caching allocator back to the
# driver. Tensors that are still referenced (e.g. `model`) are not freed by this.
torch.cuda.empty_cache()


In [None]:
from matplotlib import pyplot as plt

# Plot the part of the memory trace recorded after the forward pass
# (samples from index `fmem_len` onward).
fig, ax = plt.subplots()
ax.plot(mem[fmem_len:])
ax.set(title='GPU memory during back-prop',
       xlabel='sample index (after forward pass)',
       ylabel='bytes allocated above start_mem')
plt.show()

In [None]:
# Inspect one hand-picked sample of the memory trace. The trace length depends on
# how many layers/sub-blocks were checkpointed, so check the bound explicitly.
MEM_SAMPLE_INDEX = 24
assert MEM_SAMPLE_INDEX < len(mem), f"mem has only {len(mem)} samples"
mem[MEM_SAMPLE_INDEX]

In [None]:
# Hand-copied results, one entry per experiment run.
# NOTE(review): presumably `memory` is GPU memory in bytes and `compute` is a
# time in seconds for each run -- confirm against the benchmark cells above.
memory = [
    3404800,
    436736000,
    2928128,
    3347456,
    2435072,
]
compute = [
    0.21,
    0.1857,
    0.184,
    0.1937,
    0.1468,
]

In [None]:
import torch
import torchvision.models as models  # NOTE(review): not used in this cell
import torch.autograd.profiler as profiler
# Presumably a warm-up call so one-time setup is kept out of the profile.
# NOTE(review): `fc` is not defined in the surrounding cells and `x` is only bound
# by later cells -- confirm both are defined before this runs, or it raises NameError.
y=fc(x)
# Profile one call, recording memory usage and input shapes per op.
with profiler.profile(profile_memory=True, record_shapes=True) as prof:
    y=fc(x)
#    y.sum().backward()
# Top-10 ops sorted by their own (self) CPU memory usage.
print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))


In [None]:
# NOTE(review): `out` is rebound by several cells (the checkpoint_sequential run and
# the plain run); which tensor this shows depends on execution order.
print(out.shape)

In [None]:
# NOTE(review): `FuSeConv` is not defined in the surrounding cells -- confirm it is
# defined or imported earlier, otherwise this raises NameError on a fresh kernel.
fuse_conv=FuSeConv()
fuse_conv

In [None]:

import torch.nn as nn

# Checkpointed run of a small MLP on the GPU. Its output and parameter gradients
# are compared with a plain (non-checkpointed) run of the same model below.
torch.cuda.reset_max_memory_allocated('cuda')
# create a simple Sequential model
model = nn.Sequential(
    nn.Linear(100, 50),
    nn.ReLU(),
    nn.Linear(50, 20),
    nn.ReLU(),
    nn.Linear(20, 5),
    nn.ReLU()
)
model.to('cuda')
# Create the input directly on the GPU so it is a leaf tensor. The previous
# `Variable(torch.randn(...), requires_grad=True).to('cuda')` produced a non-leaf
# copy whose `.grad` is never populated.
input_var = torch.randn(1, 100, device='cuda', requires_grad=True)
# set the number of checkpoint segments
segments = 2
# The modules in execution order, as checkpoint_sequential expects.
modules = list(model.children())
print(torch.cuda.memory_allocated('cuda'))
out = checkpoint_sequential(modules, segments, input_var)
print(torch.cuda.memory_allocated('cuda'))

# For simplicity, backprop on out.sum() rather than a real loss.
model.zero_grad()
out.sum().backward()

# Save the output and parameter gradients for comparison with the
# non-checkpointed run.
output_checkpointed = out.data.clone()
grad_checkpointed = {name: param.grad.data.clone()
                     for name, param in model.named_parameters()}


In [None]:
torch.cuda.empty_cache()
# memory_summary() returns a formatted multi-line string. Print it so the table
# renders, instead of being displayed as an escaped string repr.
print(torch.cuda.memory_summary())

In [None]:
# Plain (non-checkpointed) run of the same `model` on the same input data, for
# comparison with the checkpointed run above.
torch.cuda.reset_max_memory_allocated('cuda')
original = model
original.to('cuda')
# Fresh leaf tensor sharing input_var's storage (detached, then marked to require grad).
x = input_var.detach().requires_grad_(True).to('cuda')
# Forward pass, bracketed by allocator readings; keep a copy of the output.
print(torch.cuda.memory_allocated('cuda'))
out = original(x)
print(torch.cuda.memory_allocated('cuda'))
out_not_checkpointed = out.data.clone()
# Backward on out.sum() and snapshot every parameter gradient.
original.zero_grad()
out.sum().backward()
grad_not_checkpointed = {}
for name, param in original.named_parameters():
    print(name, param)
    grad_not_checkpointed[name] = param.grad.data.clone()
grad_not_checkpointed

In [None]:
print(output_checkpointed)
print(out_not_checkpointed)
# Checkpointing trades memory for recomputation only, so the same model on the
# same input must produce the same output either way.
assert torch.allclose(output_checkpointed, out_not_checkpointed), \
    "checkpointed and non-checkpointed outputs differ"

In [None]:
# Parameter gradients must also match between the checkpointed and plain runs.
assert grad_checkpointed.keys() == grad_not_checkpointed.keys()
for param_name in grad_checkpointed:
    assert torch.allclose(grad_checkpointed[param_name],
                          grad_not_checkpointed[param_name]), param_name
grad_checkpointed

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable, Function
import torch.utils.checkpoint as checkpoint
from collections import OrderedDict
class ConvBNReLU(nn.Module):
    """1x1 convolution (no bias) -> BatchNorm2d -> in-place ReLU."""

    def __init__(self, in_planes, out_planes):
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_planes)
        self.relu1 = nn.ReLU(inplace=True)

    def forward(self, x):
        y = self.conv1(x)
        y = self.bn1(y)
        return self.relu1(y)

class DummyNet(nn.Module):
    """Tiny CNN for watching GPU memory around one checkpointed sub-module.

    ``forward`` prints ``torch.cuda.memory_allocated('cuda')`` before and after
    each stage, so it needs a CUDA device.
    """

    def __init__(self):
        super().__init__()
        stem = OrderedDict([
            ('conv1', nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)),
            ('bn1', nn.BatchNorm2d(16)),
            ('relu1', nn.ReLU(inplace=True)),
        ])
        self.features = nn.Sequential(stem)
        # The module that we want to checkpoint
        self.module = ConvBNReLU(16, 64)
        self.final_module = ConvBNReLU(64, 64)

    def custom(self, module):
        """Adapt a single-input ``module`` to checkpoint's ``*inputs`` call convention."""
        def custom_forward(*inputs):
            return module(inputs[0])
        return custom_forward

    def forward(self, x):
        def report():
            print(torch.cuda.memory_allocated('cuda'))

        report()
        out1 = self.features(x)
        report()
        out2 = checkpoint.checkpoint(self.custom(self.module), out1)
        report()
        out3 = self.final_module(out2)
        report()
        return out3

In [None]:
# Run DummyNet once on the GPU, printing allocator readings around each step.
torch.cuda.reset_max_memory_allocated('cuda')
print(torch.cuda.memory_allocated('cuda'))
dn = DummyNet().to('cuda')
print(torch.cuda.memory_allocated('cuda'))
# Allocate the input directly on the GPU as a leaf tensor. `Variable(...).to('cuda')`
# made a non-leaf copy whose `.grad` would never be populated.
x = torch.randn(1, 3, 10, 10, device='cuda', requires_grad=True)
print(torch.cuda.memory_allocated('cuda'))
outi = dn(x)
print(torch.cuda.memory_allocated('cuda'))


In [None]:
# NOTE(review): these look like the eight memory_allocated('cuda') readings (bytes)
# printed by the DummyNet run cell: three from the cell before the call, four from
# DummyNet.forward, and one after it returns. As bare statements, only the last
# value is displayed; consider recording them in a named list or markdown instead.
488960
519168
520704
520704
535040
587264
639488
516608