In [1]:
import spacy
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from doc_classifier import doc_classifier
from sklearn.model_selection import train_test_split
from sklearn.metrics.pairwise import cosine_similarity
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

from transformers import AutoModel, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
df = pd.read_csv("E:\Data\datasets\imdb_long_text_dataset.csv")

In [3]:
df.head()

Unnamed: 0,review,sentiment,token_lengths
0,So im not a big fan of Boll's work but then ag...,negative,563
1,"""The Cell"" is an exotic masterpiece, a dizzyin...",positive,749
2,'War movie' is a Hollywood genre that has been...,positive,845
3,"Taut and organically gripping, Edward Dmytryk'...",positive,608
4,One of the most significant quotes from the en...,positive,908


In [4]:
X_train, X_test, y_train, y_test = train_test_split(df['review'], df['sentiment'], stratify=df['sentiment'], test_size=0.2)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, stratify=y_train, test_size=0.2)

In [5]:
print ("y_test\n", y_test.value_counts(normalize=True), '\n')
print ("y_val\n", y_val.value_counts(normalize=True), '\n')
print ("y_train\n", y_train.value_counts(normalize=True), '\n')

y_test
 sentiment
positive    0.514089
negative    0.485911
Name: proportion, dtype: float64 

y_val
 sentiment
positive    0.514605
negative    0.485395
Name: proportion, dtype: float64 

y_train
 sentiment
positive    0.514071
negative    0.485929
Name: proportion, dtype: float64 



In [6]:
# # Loading spacy as sentence chunker
# nlp = spacy.load("en_core_web_sm")

In [7]:
# Chunking each document
# X_train_chunked = [nlp(each_sent) for each_sent in X_train]
# X_train_chunked = [[sent for sent in nlp(each_sent).sents] for each_sent in X_train[:20]]


# # Spacy takes too long, will chunk lexically first
# X_train_chunked = [each_sent.split(". ") for each_sent in X_train]
# X_val_chunked = [each_sent.split(". ") for each_sent in X_val
# X_test_chunked = [each_sent.split(". ") for each_sent in X_test]

# Spacy takes too long, will chunk lexically first
X_train_chunked = [each_sent.split(". ") for each_sent in X_train[:20]]
X_val_chunked = [each_sent.split(". ") for each_sent in X_val[:10]]
X_test_chunked = [each_sent.split(". ") for each_sent in X_test[:10]]

In [8]:
# Check max chunk length
print ("Max chunk length for X_train: ", max([len(chunks) for chunks in X_train_chunked]))
print ("Max chunk length for X_val: ", max([len(chunks) for chunks in X_val_chunked]))
print ("Max chunk length for X_test: ", max([len(chunks) for chunks in X_test_chunked]))

Max chunk length for X_train:  44
Max chunk length for X_val:  28
Max chunk length for X_test:  37


In [9]:
len(X_train)

4655

In [10]:
len(X_train_chunked)

20

In [11]:
if torch.cuda.is_available():
    device = torch.device('cuda')

In [12]:
encoder_path = "D:\\DSAI\\Pre-Trained Models\\distilbert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(encoder_path)
encoder = AutoModel.from_pretrained(encoder_path)



In [13]:
X_train_tokenized = [[tokenizer(each_sent, padding='max_length', truncation=True, return_tensors='pt') for each_sent in each_doc] for each_doc in X_train_chunked]
X_val_tokenized = [[tokenizer(each_sent, padding='max_length', truncation=True, return_tensors='pt') for each_sent in each_doc] for each_doc in X_val_chunked]
X_test_tokenized = [[tokenizer(each_sent, padding='max_length', truncation=True, return_tensors='pt') for each_sent in each_doc] for each_doc in X_test_chunked]

In [14]:
# X_train_tokenized[0][0]

In [15]:
# X_train_tokenized[0][0]['input_ids'].clone().detach()

In [16]:
# train_seq = [[sent['input_ids'].clone().detach() for sent in doc] for doc in X_train_tokenized]
# train_mask = [[sent['attention_mask'].clone().detach() for sent in doc] for doc in X_train_tokenized]

# val_seq = [[sent['input_ids'].clone().detach() for sent in doc] for doc in X_val_tokenized]
# val_mask = [[sent['attention_mask'].clone().detach() for sent in doc] for doc in X_val_tokenized]

# test_seq = [[sent['input_ids'].clone().detach() for sent in doc] for doc in X_test_tokenized]
# test_mask = [[sent['attention_mask'].clone().detach() for sent in doc] for doc in X_test_tokenized]

# train_label = torch.tensor(y_train.map({'positive':1, 'negative':0}).tolist())
# val_label = torch.tensor(y_val.map({'positive':1, 'negative':0}).tolist())
# test_label = torch.tensor(y_test.map({'positive':1, 'negative':0}).tolist())

In [17]:
def get_doc_tensors(document_corpus, embedding_to_extract, max_chunks=30, max_sentence_token_len=512):
    doc_list = []
    for doc in document_corpus:
        sent_list = []
        for sent in doc:
            sent_list.append(sent[embedding_to_extract].clone().detach()[0])
            
        sent_seqs = torch.stack(sent_list, dim=0)
    
        if sent_seqs.size()[0] < max_chunks: # keep it below 30 sentences for now
            empty_sent_to_pad = torch.zeros(max_chunks-sent_seqs.size()[0], max_sentence_token_len)
    
            sent_seqs = torch.cat((empty_sent_to_pad, sent_seqs), dim=0)
    
        else:
            sent_seqs = sent_seqs[:max_chunks, :]
    
        doc_list.append(sent_seqs)

    return torch.stack(doc_list, dim=0)

In [18]:
train_seq = get_doc_tensors(X_train_tokenized, embedding_to_extract='input_ids', max_chunks=10)
train_mask = get_doc_tensors(X_train_tokenized, embedding_to_extract='attention_mask', max_chunks=10)

val_seq = get_doc_tensors(X_val_tokenized, embedding_to_extract='input_ids', max_chunks=10)
val_mask = get_doc_tensors(X_val_tokenized, embedding_to_extract='attention_mask', max_chunks=10)

test_seq = get_doc_tensors(X_test_tokenized, embedding_to_extract='input_ids', max_chunks=10)
test_mask = get_doc_tensors(X_test_tokenized, embedding_to_extract='attention_mask', max_chunks=10)

# train_label = torch.tensor(y_train.map({'positive':1, 'negative':0}).tolist())
# val_label = torch.tensor(y_val.map({'positive':1, 'negative':0}).tolist())
# test_label = torch.tensor(y_test.map({'positive':1, 'negative':0}).tolist())

train_label = torch.tensor(y_train.map({'positive':1, 'negative':0}).tolist()[:20])
val_label = torch.tensor(y_val.map({'positive':1, 'negative':0}).tolist()[:10])
test_label = torch.tensor(y_test.map({'positive':1, 'negative':0}).tolist()[:10])

In [19]:
# doc_list = []
# for doc in X_train_tokenized:
#     sent_list = []
#     for sent in doc:
#         sent_list.append(sent['input_ids'].clone().detach()[0])
        
#     sent_seqs = torch.stack(sent_list, dim=0)

#     if sent_seqs.size()[0] < 30: # keep it below 30 sentences for now
#         empty_sent_to_pad = torch.zeros(30-sent_seqs.size()[0], 512)

#         sent_seqs = torch.cat((empty_sent_to_pad, sent_seqs), dim=0)

#     else:
#         sent_seqs = sent_seqs[:30, :]

#     doc_list.append(sent_seqs)

In [20]:
# sent_seqs.size()

In [21]:
# torch.zeros(5, 512).size()

In [22]:
# sent_seqs[:2, :]

In [23]:
# sent_list

In [24]:
# doc_seq = torch.stack(doc_list, dim=0) # Need to pas to max length for this one! Else the shape wont fit

# Probably have to do the torch.zeros method and slowly fill in the tensor??
## Dont need can just manuall pad the fucking thing.. damn annoying - Solo

In [25]:
# doc_seq.size()

In [26]:
# len(X_train_tokenized)

In [27]:
# len(X_train_tokenized[-1])

In [28]:
# len(X_train_tokenized[-1])

In [29]:
# sent_seqs[-1].size()

In [30]:
# doc_seq.size()

In [31]:
# # FOR TRAINING
# # Define batch size
# batch_size = 8

# # Wrap tensors
# train_data = TensorDataset(doc_seq)
# # Sampler for sampling the data during training
# train_sampler = SequentialSampler(train_data)
# # Dataloader for train set
# train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)

In [32]:
# train_data

In [33]:
# for num, batch in enumerate(train_dataloader):
#     see = batch

#     print (see[0].size())

In [34]:
# see[0][-1]

In [35]:
# sent_seqs

In [36]:
# torch.equal(see[0][-1], sent_seqs)

In [37]:
# doc_seq[-2]

In [38]:
# see[0][-2]

In [39]:
# torch.equal(see[0][-2], doc_seq[-2])

In [40]:
# y_train.map({'positive':1, 'negative':0})

In [41]:
# torch.tensor(y_train.map({'positive':1, 'negative':0}).tolist())

In [42]:
# FOR TRAINING
# Define batch size
batch_size = 2

# Wrap tensors
train_data = TensorDataset(train_seq, train_mask, train_label)
# Sampler for sampling the data during training
# train_sampler = SequentialSampler(train_data)
train_sampler = RandomSampler(train_data)
# Dataloader for train set
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)


# Wrap tensors
val_data = TensorDataset(val_seq, val_mask, val_label)
# Sampler for sampling the data during validation for training
val_sampler = SequentialSampler(val_data)
# Dataloader for val set
val_dataloader = DataLoader(val_data, sampler=val_sampler, batch_size=batch_size)


# Wrap tensors
test_data = TensorDataset(test_seq, test_mask, test_label)
# Sampler for sampling the data for testing
test_sampler = SequentialSampler(test_data)
# Dataloader for test set
test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=batch_size)

In [43]:
# for step, batch in enumerate(train_dataloader):
#     batch = [r.to(device) for r in batch]
#     train_seq, train_mask, train_label = batch

#     print(torch.cuda.memory_summary())    

In [44]:
# # FOR TRAINING
# # Define batch size
# batch_size = 8

# # Wrap tensors
# train_data = TensorDataset(train_seq, train_mask, train_label)
# # Sampler for sampling the data during training
# train_sampler = RandomSampler(train_data)
# # Dataloader for train set
# train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)


# # Wrap tensors
# val_data = TensorDataset(val_seq, val_mask, val_label)
# # Sampler for sampling the data during validation for training
# val_sampler = SequentialSampler(val_data)
# # Dataloader for val set
# val_dataloader = DataLoader(val_data, sampler=val_sampler, batch_size=batch_size)


# # Wrap tensors
# test_data = TensorDataset(test_seq, test_mask, test_label)
# # Sampler for sampling the data for testing
# test_sampler = SequentialSampler(test_data)
# # Dataloader for test set
# test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=batch_size)

In [45]:
# len(train_anchor_seq)

In [46]:
# len(train_anchor_seq[0])

In [47]:
# train_anchor_seq[0][0]

In [48]:
# len(X_train_tokenized)

In [49]:
# X_train_chunked[0]

In [50]:
# len(X_train_tokenized[0])

---

In [51]:
# for step, batch in enumerate(train_dataloader):
#     if step == 0:
#         train_seq, train_mask, train_label = batch

In [52]:
# see_tok = tokenizer(['first sentence', 'second sentence'], padding='max_length', truncation=True, return_tensors='pt')
# # see_tok = tokenizer(['second sentence'], padding='max_length', truncation=True, return_tensors='pt')

In [53]:
# see_embed = encoder(see_tok['input_ids'], see_tok['attention_mask'])

In [54]:
# see_tok # Seems like there is CLS tokens for distilbert as well (id "101")

In [55]:
# see_embed.last_hidden_state.size()

In [56]:
# a = np.array([[[1,1,1], [2,2,2]]])

In [57]:
# a = np.zeros((2,1,3))

In [58]:
# b = np.ones((8,10,3))

In [59]:
# b

In [60]:
# a[0,0] = [2,2,2]
# a[1,0] = [1,1,1]
# a[2,0] = [1,1,1]
# a[3,0] = [1,1,1]
# a[4,0] = [1,1,1]

In [61]:
# a

In [62]:
# a[:,-1,:]

In [63]:
# train_seq.size()

In [64]:
# train_seq.size()[0]

In [65]:
# train_seq.size()[1]

In [66]:
# train_seq

---

---
## Training

---

In [67]:
model = doc_classifier(encoder, dropout=0.2, device=device)

In [68]:
model = model.to(device)

In [69]:
from torch.optim import AdamW

# Define optimiser
optimizer = AdamW(model.parameters(), lr=2e-5)

In [70]:
weight = np.array(y_train.value_counts()[0]/y_train.value_counts()[1])

  weight = np.array(y_train.value_counts()[0]/y_train.value_counts()[1])
  weight = np.array(y_train.value_counts()[0]/y_train.value_counts()[1])


In [71]:
# Converting list of class weights to a tensor
weights = torch.tensor(weight, dtype=torch.float)

# Push weights to GPU
weights = weights.to(device)

# Define loss function
cross_entropy = nn.BCEWithLogitsLoss(pos_weight=weights)

In [72]:
def train(train_dataloader):
    model.train()
    
    total_loss, total_accuracy = 0, 0
    
    # Empty list to save model predictions
    total_preds = []
    
    # Iterate over batches
    for step, batch in enumerate(train_dataloader):
        # Progress update for every 50 batches
        if step%10==0 and not step==0:
            print ('Batch {:>5,} of {:>5,}.'.format(step, len(train_dataloader)))

        print("Before sending batch to GPU")
        print("torch.cuda.memory_allocated: %fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
        print("torch.cuda.memory_reserved: %fGB"%(torch.cuda.memory_reserved(0)/1024/1024/1024))
        print("torch.cuda.max_memory_reserved: %fGB"%(torch.cuda.max_memory_reserved(0)/1024/1024/1024))
        print ("-----")

        # Push batch to GPU
        batch = [r.to(device) for r in batch]
        train_input_seq, train_input_mask, train_input_label = batch

        print("After sending batch to GPU")
        print("torch.cuda.memory_allocated: %fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
        print("torch.cuda.memory_reserved: %fGB"%(torch.cuda.memory_reserved(0)/1024/1024/1024))
        print("torch.cuda.max_memory_reserved: %fGB"%(torch.cuda.max_memory_reserved(0)/1024/1024/1024))
        print ("-----")

        # Clear previously calculated gradients
        model.zero_grad()

        print("Before passing through model")
        print("torch.cuda.memory_allocated: %fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
        print("torch.cuda.memory_reserved: %fGB"%(torch.cuda.memory_reserved(0)/1024/1024/1024))
        print("torch.cuda.max_memory_reserved: %fGB"%(torch.cuda.max_memory_reserved(0)/1024/1024/1024))
        print ("-----")


        # Get model predictions for the current batch
        train_output = model(train_input_seq, train_input_mask)

        # def getActivation(name):
        #     # the hook signature
        #     def hook(model, input, output):
        #         activation[name] = output.detach()
        #         print(output.detach())
        #     return hook
        
        # h = model.encoder.transformer.register_forward_hook(getActivation('output_layer_norm'))
        # print (h.remove())

        print("After passing through model")
        print("torch.cuda.memory_allocated: %fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
        print("torch.cuda.memory_reserved: %fGB"%(torch.cuda.memory_reserved(0)/1024/1024/1024))
        print("torch.cuda.max_memory_reserved: %fGB"%(torch.cuda.max_memory_reserved(0)/1024/1024/1024))
        print ("-----")
        
        """
        nn.CosineSimilarity measures similarity between 2 outputs, the more similar, the bigger the score.
        However for triplet loss, the positive cases are supposed to be closer and have a smaller score.
        To make things easier, we flipped the negative and positive positions
        i.e. loss(anchor, positive, negative) --> loss(anchor, negative, positive)
        """

        # print (train_output, train_input_label)
        # print (torch.squeeze(train_output))
        
        # Compute loss 
        # loss = cross_entropy(train_output, train_input_label)
        loss = cross_entropy(torch.squeeze(train_output), train_input_label.float())

        # Add on to the total loss
        total_loss = total_loss + loss.item()

        print("Before backpropagation")
        print("torch.cuda.memory_allocated: %fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
        print("torch.cuda.memory_reserved: %fGB"%(torch.cuda.memory_reserved(0)/1024/1024/1024))
        print("torch.cuda.max_memory_reserved: %fGB"%(torch.cuda.max_memory_reserved(0)/1024/1024/1024))
        print ("-----")

        # Backward pass to calculate gradients
        loss.backward()

        # Update parameters
        optimizer.step()

        print("After backpropagation")
        print("torch.cuda.memory_allocated: %fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
        print("torch.cuda.memory_reserved: %fGB"%(torch.cuda.memory_reserved(0)/1024/1024/1024))
        print("torch.cuda.max_memory_reserved: %fGB"%(torch.cuda.max_memory_reserved(0)/1024/1024/1024))
        print("=====")

    # Compute training loss of the epoch
    avg_loss = total_loss / len(train_dataloader)

    return avg_loss

In [73]:
def evaluate(val_dataloader):
    print ('\nEvaluating...')
    
    # Deactivate dropout layers
    model.eval()
    
    total_loss, total_accuracy = 0, 0
    
    # Empty list to save model predictions
    total_preds = []
    
    # Iterate over batches
    for step, batch in enumerate(val_dataloader):
        # Progress update for every 50 batches
        if step%10==0 and not step==0:
            print ('Batch {:>5,} of {:>5,}.'.format(step, len(val_dataloader)))

        # Push batch to GPU
        batch = [t.to(device) for t in batch]
        val_input_seq, val_input_mask, val_input_label = batch

        # Deactivate autograd()
        with torch.no_grad():

            
            # Get model predictions for the current batch
            val_output = model(val_input_seq, val_input_mask)
        
            """
            nn.CosineSimilarity measures similarity between 2 outputs, the more similar, the bigger the score.
            However for triplet loss, the positive cases are supposed to be closer and have a smaller score.
            To make things easier, we flipped the negative and positive positions
            i.e. loss(anchor, positive, negative) --> loss(anchor, negative, positive)
            """

            # Compute loss 
            # loss = cross_entropy(val_output, val_input_label)
            loss = cross_entropy(torch.squeeze(val_output), val_input_label.float())

            total_loss = total_loss + loss.item()

    # Compute the validation loss of the epoch
    avg_loss = total_loss / len(val_dataloader)

    return avg_loss

In [74]:
epochs = 10

# Set initial loss to infinite
best_valid_loss = float('inf')

# Empty lists to store training and validation loss of each epoch
train_losses = []
valid_losses = []

# For each epoch
for epoch in range(epochs):
    print ('\nEpoch {:}/ {:}'.format(epoch+1, epochs))
    
    # Train model
    train_loss = train(train_dataloader)
    
    # Evaluate model
    valid_loss = evaluate(val_dataloader)
    
    # Save the best model
    if valid_loss<best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'test_model.pt')
        
    train_losses.append(train_loss)
    valid_losses.append(valid_loss)
    
    print (f"\nTraining Loss: {train_loss:.5f}")
    print (f"Validation Loss: {valid_loss:.5f}")


Epoch 1/ 10
Before sending batch to GPU
torch.cuda.memory_allocated: 0.257084GB
torch.cuda.memory_reserved: 0.285156GB
torch.cuda.max_memory_reserved: 0.285156GB
-----
After sending batch to GPU
torch.cuda.memory_allocated: 0.257161GB
torch.cuda.memory_reserved: 0.285156GB
torch.cuda.max_memory_reserved: 0.285156GB
-----
Before passing through model
torch.cuda.memory_allocated: 0.257161GB
torch.cuda.memory_reserved: 0.285156GB
torch.cuda.max_memory_reserved: 0.285156GB
-----
After passing through model
torch.cuda.memory_allocated: 6.415343GB
torch.cuda.memory_reserved: 6.441406GB
torch.cuda.max_memory_reserved: 6.441406GB
-----
Before backpropagation
torch.cuda.memory_allocated: 6.415344GB
torch.cuda.memory_reserved: 6.441406GB
torch.cuda.max_memory_reserved: 6.441406GB
-----
After backpropagation
torch.cuda.memory_allocated: 1.072346GB
torch.cuda.memory_reserved: 6.824219GB
torch.cuda.max_memory_reserved: 6.824219GB
=====
Before sending batch to GPU
torch.cuda.memory_allocated: 1.072

KeyboardInterrupt: 

In [75]:
torch.cuda.empty_cache()

In [76]:
model.encoder

DistilBertModel(
  (embeddings): Embeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer): Transformer(
    (layer): ModuleList(
      (0-5): 6 x TransformerBlock(
        (attention): MultiHeadSelfAttention(
          (dropout): Dropout(p=0.1, inplace=False)
          (q_lin): Linear(in_features=768, out_features=768, bias=True)
          (k_lin): Linear(in_features=768, out_features=768, bias=True)
          (v_lin): Linear(in_features=768, out_features=768, bias=True)
          (out_lin): Linear(in_features=768, out_features=768, bias=True)
        )
        (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (ffn): FFN(
          (dropout): Dropout(p=0.1, inplace=False)
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): Li

In [77]:
activation = {}
def get_activation_attn(name):
    def hook(model, input, output):
        activation[name] = output[0].detach()
    return hook

def get_activation(name):
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook

In [78]:
model.encoder.transformer.layer[0].attention.register_forward_hook(get_activation_attn('out_lin'))
model.register_forward_hook(get_activation('attn'))
model.register_forward_hook(get_activation('fc'))

<torch.utils.hooks.RemovableHandle at 0x2f0ae7d84c0>

In [79]:
for each, batch in enumerate(train_dataloader):
    # Push batch to GPU
    batch = [r.to(device) for r in batch]
    train_input_seq, train_input_mask, train_input_label = batch

    # Clear previously calculated gradients
    model.zero_grad()

    # Get model predictions for the current batch
    train_output = model(train_input_seq, train_input_mask)

In [80]:
train_output

tensor([[0.2922],
        [0.3030]], device='cuda:0', grad_fn=<AddmmBackward0>)

In [82]:
print(activation['out_lin'])
print("-----")
print(activation['attn'])
print("-----")
print(activation['fc'])

tensor([[[-0.0935, -0.2781,  0.0464,  ...,  0.2565,  0.1163,  0.4425],
         [ 0.3404,  0.9756, -0.3835,  ...,  0.1363,  0.3854, -0.0914],
         [-0.1750,  0.8955,  0.5584,  ...,  0.2086,  0.5893,  0.0634],
         ...,
         [ 0.5768,  0.2107,  0.3207,  ..., -0.3443,  0.2117, -0.1269],
         [ 0.5465, -0.0526,  0.4327,  ...,  0.0870,  0.1898, -0.2098],
         [ 0.5152,  0.1919,  0.3169,  ..., -0.2005,  0.2022, -0.0837]]],
       device='cuda:0')
-----
tensor([[0.2922],
        [0.3030]], device='cuda:0')
-----
tensor([[0.2922],
        [0.3030]], device='cuda:0')


---
Based on what I understand, during forward pass, the model saves the activation (or the intermediate output) values as well. This would mean that the memory requirements will scale with batch size.. i.e. for each layer, the `intermediate output memory requirements` X `batch size` will be the additional memory requirements as batch size scales.

See the following article for further explanation on why gradient accumulation reduces memory requirements - https://medium.com/@mccartni/implications-of-batch-size-on-llm-training-and-inference-3320cb48d610#:~:text=Larger%20batch%20sizes%20mean%20more,which%20increases%20the%20memory%20footprint.

---