Issue description
I'm finetuning a BERT model. The model trains just fine on GPU, but when I use a TPU on Kaggle kernels or Colab it produces the following stack trace:
Exception in device=TPU:5: torch_xla/csrc/aten_xla_type.cpp:163 : Check failed: at::canCast( resultType, out.scalar_type())
*** Begin stack trace ***
tensorflow::CurrentStackTrace()
torch_xla::AtenXlaType::mul_(at::Tensor&, at::Tensor const&)
at::Tensor::mul_(at::Tensor const&) const
at::native::binary_cross_entropy_with_logits(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, long)
torch_xla::AtenXlaType::binary_cross_entropy_with_logits(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&, long)
c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoRuntimeFunctor_<at::Tensor (*)(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&, long), at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&, long> >, at::Tensor (at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&, long)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&, long)
at::binary_cross_entropy_with_logits(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&, long)
at::binary_cross_entropy_with_logits(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&, long)
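For what it's worth, the same promotion rule seems reproducible on CPU with a plain in-place mul_ between an integer and a float tensor (a minimal sketch of the failing check, not necessarily the exact path hit on TPU):

import torch

# In-place mul_: the promoted result type (float) cannot be cast back
# to the output tensor's dtype (long), which trips the same kind of
# canCast check shown in the stack trace above.
a = torch.ones(3, dtype=torch.long)   # integer tensor
b = torch.ones(3, dtype=torch.float)  # float tensor
try:
    a.mul_(b)
except RuntimeError as e:
    print(e)  # result type Float can't be cast to the desired output type Long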
Code example
I believe the error happens when I calculate the loss with nn.BCEWithLogitsLoss(), possibly due to some dtype mismatch. This is how I'm doing it in training:
def train_step(engine, batch):
    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]
    token_type_ids = batch["token_type_ids"]
    labels = batch["label"].unsqueeze(-1)

    # Move the batch to the TPU device if it is not already there
    if input_ids.device != device:
        input_ids = input_ids.to(device, non_blocking=True, dtype=torch.long)
        attention_mask = attention_mask.to(device, non_blocking=True, dtype=torch.long)
        token_type_ids = token_type_ids.to(device, non_blocking=True, dtype=torch.long)
        labels = labels.to(device, non_blocking=True, dtype=torch.float)

    model.train()
    with autocast(enabled=with_amp):
        output = model(input_ids, attention_mask, token_type_ids).float()
        loss = criterion(output, labels)  # criterion is nn.BCEWithLogitsLoss()
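A workaround I'm considering (untested, and assuming the culprit is the labels keeping their integer dtype when the batch already lives on the TPU device, so the .to(..., dtype=torch.float) branch above is skipped) is to cast them unconditionally before computing the loss:

# Hypothetical workaround sketch: always hand BCEWithLogitsLoss a float target,
# regardless of whether the batch is already on the device.
labels = batch["label"].unsqueeze(-1).to(device, dtype=torch.float, non_blocking=True)

# ... forward pass unchanged ...
loss = criterion(output.float(), labels)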