
Commit

Apply AMP option to embed() wherever applicable (#45; wip)
Signed-off-by: Christopher Schröder <chschroeder@users.noreply.github.com>
chschroeder committed Mar 17, 2024
1 parent fe3c78b commit 53836d7
Showing 10 changed files with 106 additions and 185 deletions.
53 changes: 40 additions & 13 deletions small_text/integrations/pytorch/classifiers/kimcnn.py
@@ -74,7 +74,28 @@ class KimCNNEmbeddingMixin(EmbeddingMixin):

def embed(self, data_set, return_proba=False, embedding_method=EMBEDDING_METHOD_POOLED,
module_selector=lambda x: x['fc'], pbar='tqdm'):
"""Embeds each sample in the given `data_set`.
The embedding is created by using the underlying sentence transformer model.
Parameters
----------
data_set : PytorchTextClassificationDataset
The dataset for which embeddings (and class probabilities) will be computed.
return_proba : bool
Also return the class probabilities for `data_set`.
embedding_method : str, default='pooled'
Embedding method to use ['pooled', 'gradient'].
pbar : 'tqdm' or None, default='tqdm'
Displays a progress bar if 'tqdm' is passed.
Returns
-------
embeddings : np.ndarray
Embeddings in the shape (N, hidden_layer_dimensionality).
proba : np.ndarray
Class probabilities in the shape (N, num_classes) for `data_set` (only if `return_predictions` is `True`).
"""
if self.model is None:
raise ValueError('Model is not trained. Please call fit() first.')
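For context, a minimal usage sketch of the AMP-enabled embed() call. It mirrors the KimCNNClassifier and AMPArguments construction used in the tests further down in this diff; the embedding matrix dimensions and the training dataset are placeholders, not part of this commit.

import torch
import numpy as np

from small_text.integrations.pytorch.classifiers.base import AMPArguments
from small_text.integrations.pytorch.classifiers.kimcnn import KimCNNClassifier

def kimcnn_embeddings_with_amp(train_set, num_classes=6):
    # train_set: a PytorchTextClassificationDataset prepared elsewhere (assumed).
    embedding_matrix = torch.Tensor(np.random.rand(10_000, 100))  # illustrative vocab size / dim
    amp_args = AMPArguments(use_amp=True, device_type='cuda', dtype=torch.bfloat16)

    clf = KimCNNClassifier(num_classes, embedding_matrix=embedding_matrix, amp_args=amp_args)
    clf.fit(train_set)  # embed() raises a ValueError if the model has not been fitted

    # Pooled embeddings plus class probabilities, as documented above.
    return clf.embed(train_set, return_proba=True, embedding_method='pooled')

Passing embedding_method='gradient' would instead return the gradient-based embeddings computed in _create_gradient_embedding below.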

@@ -116,10 +137,14 @@ def _create_embeddings(self, batch, embedding_method='pooled', module_selector=l
logits = self.model._dropout_and_fc(embeddings)
elif embedding_method == self.EMBEDDING_METHOD_GRADIENT:
best_label, logits = self._get_best_and_softmax(text)
embeddings = self.create_embedding(best_label, logits, module_selector, text)
embeddings = self._create_gradient_embedding(best_label, logits, module_selector, text)
else:
raise ValueError(f'Invalid embedding method: {embedding_method}')

if self.amp_args.use_amp:
embeddings = embeddings.float()
logits = logits.float()

return text.size(0), logits, embeddings

def _get_best_and_softmax(self, text):
@@ -134,7 +159,7 @@ def _get_best_and_softmax(self, text):

return best_label, logits

def _create_embedding(self, best_label, logits, module_selector, text):
def _create_gradient_embedding(self, best_label, logits, module_selector, text):

batch_len = text.size(0)

@@ -151,19 +176,21 @@ def _create_embedding(self, best_label, logits, module_selector, text):
for c in range(self.num_classes):
loss = self.criterion(sm, torch.LongTensor([c] * batch_len).to(self.device))

for k in range(batch_len):
self.model.zero_grad()
loss[k].backward(retain_graph=True)
with torch.autocast(device_type=self.amp_args.device_type, dtype=self.amp_args.dtype,
enabled=False):
for k in range(batch_len):
self.model.zero_grad()
loss[k].backward(retain_graph=True)

modules = dict({name: module for name, module in self.model.named_modules()})
params = module_selector(modules).weight.grad.flatten()
modules = dict({name: module for name, module in self.model.named_modules()})
params = module_selector(modules).weight.grad.flatten()

with torch.no_grad():
sm_prob = sm_t[c][k]
if c == best_label[k]:
arr[k, grad_size*c:grad_size*(c+1)] = (1-sm_prob)*params
else:
arr[k, grad_size*c:grad_size*(c+1)] = -1*sm_prob*params
with torch.no_grad():
sm_prob = sm_t[c][k]
if c == best_label[k]:
arr[k, grad_size*c:grad_size*(c+1)] = (1-sm_prob)*params
else:
arr[k, grad_size*c:grad_size*(c+1)] = -1*sm_prob*params

self.criterion.reduction = reduction_tmp
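The torch.autocast(..., enabled=False) block added above opts the per-sample backward passes and gradient reads out of mixed precision, presumably so the gradient embedding is assembled in full float32 even when embed() itself runs under AMP. A generic sketch of this nesting pattern in plain PyTorch (not small_text code):

import torch

def forward_then_full_precision(model, x, device_type='cuda', dtype=torch.bfloat16):
    # Outer region: the forward pass may run in reduced precision.
    with torch.autocast(device_type=device_type, dtype=dtype, enabled=True):
        logits = model(x)
        # Inner region with enabled=False: autocast is switched off here; the explicit
        # .float() cast brings the reduced-precision output back to float32 before backward.
        with torch.autocast(device_type=device_type, dtype=dtype, enabled=False):
            loss = logits.float().sum()
            loss.backward()
    return logits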

33 changes: 17 additions & 16 deletions small_text/integrations/transformers/classifiers/classification.py
@@ -147,10 +149,12 @@ def embed(self, data_set, return_proba=False, embedding_method=EMBEDDING_METHOD_
Parameters
----------
data_set : TransformersDataset
The dataset for which embeddings (and class probabilities) will be computed.
return_proba : bool
Also return the class probabilities for `data_set`.
embedding_method : str
Embedding method to use [avg, cls].
Embedding method to use ['avg', 'cls'].
hidden_layer_index : int, default=-1
Index of the hidden layer.
pbar : 'tqdm' or None, default='tqdm'
@@ -161,7 +163,7 @@ def embed(self, data_set, return_proba=False, embedding_method=EMBEDDING_METHOD_
embeddings : np.ndarray
Embeddings in the shape (N, hidden_layer_dimensionality).
proba : np.ndarray
Class probabilities for `data_set` (only if `return_predictions` is `True`).
Class probabilities in the shape (N, num_classes) for `data_set` (only if `return_proba` is `True`).
"""

if self.model is None:
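(Usage sketch, mirroring the new test added at the bottom of this diff: AMP for embed() is enabled by passing AMPArguments through the classifier kwargs. train_set stands in for a TransformersDataset prepared elsewhere.)

import torch

from small_text.integrations.pytorch.classifiers.base import AMPArguments
from small_text.integrations.transformers import (
    TransformerBasedClassificationFactory,
    TransformerModelArguments,
)

def transformer_embeddings_with_amp(train_set, num_classes=6):
    classifier_kwargs = {
        'amp_args': AMPArguments(use_amp=True, device_type='cuda', dtype=torch.bfloat16),
        'num_epochs': 1
    }
    clf_factory = TransformerBasedClassificationFactory(
        TransformerModelArguments('sshleifer/tiny-distilroberta-base'),
        num_classes,
        kwargs=classifier_kwargs)

    clf = clf_factory.new()
    clf.fit(train_set)

    # 'cls' takes the [CLS] token representation of the selected hidden layer; 'avg' averages the token representations.
    return clf.embed(train_set, return_proba=True, embedding_method='cls')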
@@ -507,11 +509,15 @@ def _train_loop_process_batches(self, num_epoch, sub_train_, sub_valid_, weights
if not stop:
with torch.autocast(enabled=self.amp_args.use_amp, device_type=self.amp_args.device_type,
dtype=self.amp_args.dtype):
loss, acc = self._train_single_batch(x, masks, cls, weight, optimizer, scaler)
loss, logits, cls = self._train_forward(x, masks, cls, weight, optimizer)
del x, masks
self._train_backward(loss, optimizer, scaler)

scheduler.step()

train_loss += loss
train_acc += acc
train_loss += loss.detach().item()
train_acc += self.sum_up_accuracy_(logits, cls)
del cls

if validate_every and i % validate_every == 0:
valid_loss, valid_acc = self.validate(sub_valid_)
@@ -549,10 +555,7 @@ def _create_collate_fn(self, use_sample_weights=False):
return partial(transformers_collate_fn, multi_label=self.multi_label,
num_classes=self.num_classes, use_sample_weights=use_sample_weights)

def _train_single_batch(self, x, masks, cls, weight, optimizer, scaler):

train_loss = 0.
train_acc = 0.
def _train_forward(self, x, masks, cls, weight, optimizer):

optimizer.zero_grad()

@@ -565,6 +568,11 @@ def _train_single_batch(self, x, masks, cls, weight, optimizer, scaler):
loss = loss * weight
loss = loss.mean()

del outputs

return loss, logits, cls

def _train_backward(self, loss, optimizer, scaler):
scaler.scale(loss).backward()
scaler.unscale_(optimizer)

@@ -573,13 +581,6 @@ def _train_single_batch(self, x, masks, cls, weight, optimizer, scaler):
scaler.step(optimizer)
scaler.update()

train_loss += loss.detach().item()
train_acc += self.sum_up_accuracy_(logits, cls)

del x, masks, cls, loss, outputs

return train_loss, train_acc
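The refactoring above splits the former _train_single_batch into a forward part (run under autocast by the caller) and a backward part driven by the GradScaler. A generic sketch of that pattern, as an illustration rather than the library's exact code; the gradient-clipping value is an assumption:

import torch

def train_step(model, x, y, criterion, optimizer, scaler,
               device_type='cuda', dtype=torch.bfloat16, use_amp=True):
    optimizer.zero_grad()

    # Forward pass and loss under autocast (cf. _train_forward).
    with torch.autocast(device_type=device_type, dtype=dtype, enabled=use_amp):
        logits = model(x)
        loss = criterion(logits, y).mean()

    # Backward pass, unscaling, and optimizer step via the GradScaler (cf. _train_backward).
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # illustrative value
    scaler.step(optimizer)
    scaler.update()

    return loss.detach().item(), logits.detach()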

def _compute_loss(self, cls, outputs):
if self.num_classes == 2:
logits = outputs.logits
20 changes: 12 additions & 8 deletions small_text/integrations/transformers/classifiers/setfit.py
@@ -81,6 +81,8 @@ def embed(self, data_set, return_proba=False, pbar='tqdm'):
Parameters
----------
data_set : TextDataset
The dataset for which embeddings (and class probabilities) will be computed.
return_proba : bool
Also return the class probabilities for `data_set`.
pbar : 'tqdm' or None, default='tqdm'
@@ -91,7 +93,7 @@ def embed(self, data_set, return_proba=False, pbar='tqdm'):
embeddings : np.ndarray
Embeddings in the shape (N, hidden_layer_dimensionality).
proba : np.ndarray
Class probabilities for `data_set` (only if `return_predictions` is `True`).
Class probabilities in the shape (N, num_classes) for `data_set` (only if `return_proba` is `True`).
"""

if self.model is None:
@@ -104,13 +106,15 @@ def embed(self, data_set, return_proba=False, pbar='tqdm'):

num_batches = int(np.ceil(len(data_set) / self.mini_batch_size))
with build_pbar_context(pbar, tqdm_kwargs={'total': len(data_set)}) as pbar:
for batch in np.array_split(data_set.x, num_batches, axis=0):

batch_embeddings, probas = self._create_embeddings(batch)
pbar.update(batch_embeddings.shape[0])
embeddings.extend(batch_embeddings.tolist())
if return_proba:
predictions.extend(probas.tolist())
with torch.autocast(device_type=self.amp_args.device_type, dtype=self.amp_args.dtype,
enabled=self.amp_args.use_amp):
for batch in np.array_split(data_set.x, num_batches, axis=0):

batch_embeddings, probas = self._create_embeddings(batch)
pbar.update(batch_embeddings.shape[0])
embeddings.extend(batch_embeddings.tolist())
if return_proba:
predictions.extend(probas.tolist())

if return_proba:
return np.array(embeddings), np.array(predictions)
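As a usage sketch for the SetFit path (hedged: the SetFitModelArguments / SetFitClassification construction below follows the library's usual pattern, and passing amp_args to the SetFit classifier is assumed to mirror the other classifiers touched by this commit; train_set is a TextDataset prepared elsewhere):

import torch

from small_text.integrations.pytorch.classifiers.base import AMPArguments
from small_text.integrations.transformers.classifiers.setfit import (
    SetFitClassification,
    SetFitModelArguments,
)

def setfit_embeddings_with_amp(train_set, num_classes=6):
    amp_args = AMPArguments(use_amp=True, device_type='cuda', dtype=torch.bfloat16)
    clf = SetFitClassification(
        SetFitModelArguments('sentence-transformers/paraphrase-mpnet-base-v2'),
        num_classes,
        amp_args=amp_args)  # assumed keyword, by analogy with the other classifiers in this diff
    clf.fit(train_set)

    # Embeddings come from the underlying sentence transformer; proba from the classification head.
    return clf.embed(train_set, return_proba=True)

With use_amp=False the autocast block above becomes a no-op, so the non-AMP behavior is unchanged.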
@@ -71,6 +71,7 @@ def test_kimcnn(self, device='cuda'):
active_learner = PoolBasedActiveLearner(clf_factory, query_strategy, train)
indices_labeled = initialize_active_learner(active_learner, train.y)

output = f'{perform_active_learning(active_learner, train, indices_labeled, test, num_iterations)}'
with np.printoptions(edgeitems=10, linewidth=np.inf):
output = f'{perform_active_learning(active_learner, train, indices_labeled, test, num_iterations)}'

verify(output)
@@ -119,7 +119,7 @@ def test_embed_with_amp_enabled(self):
kwargs['module_selector'] = default_module_selector

embedding_matrix = torch.Tensor(np.random.rand(len(tokenizer.get_vocab()), 100))
amp_args = AMPArguments(use_amp=True, device_type='cuda')
amp_args = AMPArguments(use_amp=True, device_type='cuda', dtype=torch.bfloat16)
classifier = KimCNNClassifier(6,
embedding_matrix=embedding_matrix,
amp_args=amp_args)
@@ -492,19 +492,19 @@ def test_with_amp_args_configured(self):
amp_args = clf.amp_args
self.assertIsNotNone(amp_args)
self.assertFalse(amp_args.use_amp)
self.assertEqual('cpu', clf.amp_args.device_type)
self.assertEqual(torch.bfloat16, clf.amp_args.dtype)
self.assertEqual('cuda', clf.amp_args.device_type)
self.assertEqual(torch.float16, clf.amp_args.dtype)

clf.initialize_transformer(clf.cache_dir)
amp_args = clf.amp_args
self.assertIsNotNone(amp_args)
self.assertFalse(amp_args.use_amp)
self.assertEqual('cpu', clf.amp_args.device_type)
self.assertEqual(torch.bfloat16, clf.amp_args.dtype)
self.assertEqual('cuda', clf.amp_args.device_type)
self.assertEqual(torch.float16, clf.amp_args.dtype)

clf.model = clf.model.to('cuda')
amp_args = clf.amp_args
self.assertIsNotNone(amp_args)
self.assertFalse(amp_args.use_amp)
self.assertTrue(amp_args.use_amp)
self.assertEqual('cuda', clf.amp_args.device_type)
self.assertEqual(torch.bfloat16, clf.amp_args.dtype)
self.assertEqual(torch.float16, clf.amp_args.dtype)
@@ -2,11 +2,14 @@
import pytest
from unittest import mock

import torch

from small_text.integrations.pytorch.exceptions import PytorchNotFoundError
from tests.utils.datasets import twenty_news_transformers
from tests.utils.testing import assert_array_not_equal

try:
from small_text.integrations.pytorch.classifiers.base import AMPArguments
from small_text.integrations.transformers import (
TransformerBasedClassificationFactory,
TransformerBasedEmbeddingMixin,
@@ -121,6 +124,29 @@ def test_embed_with_proba(self):
self.assertEqual(clf.model.config.hidden_size, embeddings.shape[1])
self.assertEqual(len(train_set), proba.shape[0])

def test_embed_with_amp_args(self):
classifier_kwargs = {
'amp_args': AMPArguments(use_amp=True, device_type='cuda', dtype=torch.bfloat16),
'num_epochs': 1
}
clf_factory = TransformerBasedClassificationFactory(
TransformerModelArguments('sshleifer/tiny-distilroberta-base'),
self.num_classes,
kwargs=classifier_kwargs)

train_set = twenty_news_transformers(20, num_labels=self.num_classes)

clf = clf_factory.new()
clf.fit(train_set)

embeddings, proba = clf.embed(train_set,
return_proba=True,
embedding_method=self.embedding_method)
self.assertEqual(2, len(embeddings.shape))
self.assertEqual(len(train_set), embeddings.shape[0])
self.assertEqual(clf.model.config.hidden_size, embeddings.shape[1])
self.assertEqual(len(train_set), proba.shape[0])


@pytest.mark.pytorch
class EmbeddingAvgBinaryClassificationTest(unittest.TestCase, _EmbeddingTest):
(Diffs for the remaining changed files were not loaded in this view.)
