In [4]:
import transformers
import torch as th


In [8]:
model = transformers.BertGenerationConfig()
model

BertGenerationConfig {
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 2,
  "eos_token_id": 1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert-generation",
  "num_attention_heads": 16,
  "num_hidden_layers": 24,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.38.2",
  "use_cache": true,
  "vocab_size": 50358
}

In [5]:
device = "cuda:0" if th.cuda.is_available() else "cpu"
print(f"Use device: {device}")

Use device: cpu


In [10]:
device = th.device('cpu')
x = th.rand((10000, 10000), dtype=th.float32)
y = th.rand((10000, 10000), dtype=th.float32)
x = x.to(device)
y = y.to(device)

In [11]:
x*y

tensor([[0.1850, 0.1728, 0.0564,  ..., 0.3192, 0.9514, 0.2227],
        [0.2243, 0.6131, 0.0697,  ..., 0.0159, 0.0100, 0.1830],
        [0.2376, 0.0264, 0.1001,  ..., 0.3630, 0.3957, 0.1977],
        ...,
        [0.4425, 0.0247, 0.0834,  ..., 0.0238, 0.3545, 0.0566],
        [0.3793, 0.3825, 0.7382,  ..., 0.7046, 0.0879, 0.4709],
        [0.2573, 0.6845, 0.1712,  ..., 0.5605, 0.3641, 0.2750]])

In [12]:
device = th.device('mps')
x = th.rand((10000, 10000), dtype=th.float32)
y = th.rand((10000, 10000), dtype=th.float32)
x = x.to(device)
y = y.to(device)

In [13]:
x*y

tensor([[0.2557, 0.3103, 0.0802,  ..., 0.0025, 0.1610, 0.2204],
        [0.0038, 0.1113, 0.0733,  ..., 0.0086, 0.1609, 0.1734],
        [0.1805, 0.0270, 0.2928,  ..., 0.2096, 0.0815, 0.1382],
        ...,
        [0.5709, 0.5700, 0.3107,  ..., 0.7539, 0.1653, 0.4005],
        [0.0097, 0.3807, 0.2067,  ..., 0.3045, 0.5589, 0.0166],
        [0.3125, 0.2030, 0.4086,  ..., 0.3201, 0.2113, 0.1653]],
       device='mps:0')

In [14]:
print(f"torch backend MPS is available? {th.backends.mps.is_available()}")
print(f"current PyTorch installation built with MPS activated? {th.backends.mps.is_built()}")
print(f"check the torch MPS backend: {th.device('mps')}")
print(f"test torch tensor on MPS: {th.tensor([1,2,3], device='mps')}")

torch backend MPS is available? True
current PyTorch installation built with MPS activated? True
check the torch MPS backend: mps
test torch tensor on MPS: tensor([1, 2, 3], device='mps:0')
