Model training works on GPU but not on PyTorch/XLA #2785

@ahmedo42

Issue description

I'm fine-tuning a BERT model. It trains just fine on GPU, but when I use a TPU on Kaggle kernels or Colab it produces the following stack trace:

Exception in device=TPU:5: torch_xla/csrc/aten_xla_type.cpp:163 : Check failed: at::canCast( resultType, out.scalar_type()) 
*** Begin stack trace ***
	tensorflow::CurrentStackTrace()
	
	torch_xla::AtenXlaType::mul_(at::Tensor&, at::Tensor const&)
	at::Tensor::mul_(at::Tensor const&) const
	at::native::binary_cross_entropy_with_logits(at::Tensor const&, at::Tensor const&, at::Tensor const&, at::Tensor const&, long)
	torch_xla::AtenXlaType::binary_cross_entropy_with_logits(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&, long)
	c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoRuntimeFunctor_<at::Tensor (*)(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&, long), at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&, long> >, at::Tensor (at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&, long)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&, long)
	at::binary_cross_entropy_with_logits(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&, long)
	
	at::binary_cross_entropy_with_logits(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::optional<at::Tensor> const&, long)
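
For context, this is roughly how I'm launching the run on the TPU (a sketch of the standard 8-core xmp.spawn pattern; run_training here is just a placeholder for my actual training loop, which builds the model, dataloaders and the trainer):

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    # each of the 8 TPU cores runs this function in its own process
    device = xm.xla_device()
    run_training(device)  # placeholder for the actual training loop

xmp.spawn(_mp_fn, nprocs=8, start_method="fork")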

Code example

I believe the error happens when I compute the loss with nn.BCEWithLogitsLoss(), possibly due to some dtype mismatch.

This is how I'm doing it in the training step:

def train_step(engine, batch):
    # model, criterion, device and with_amp are defined in the enclosing scope
    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]
    token_type_ids = batch["token_type_ids"]
    labels = batch["label"].unsqueeze(-1)

    # move the batch to the device only if it is not already there
    if input_ids.device != device:
        input_ids = input_ids.to(device, non_blocking=True, dtype=torch.long)
        attention_mask = attention_mask.to(device, non_blocking=True, dtype=torch.long)
        token_type_ids = token_type_ids.to(device, non_blocking=True, dtype=torch.long)
        labels = labels.to(device, non_blocking=True, dtype=torch.float)

    model.train()

    with autocast(enabled=with_amp):
        output = model(input_ids, attention_mask, token_type_ids).float()
        loss = criterion(output, labels)
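
To try to narrow it down, here is a minimal sketch of the kind of call I suspect is failing (the shapes and the int64-labels assumption are mine; I have not confirmed this is exactly what happens inside my training step):

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()
criterion = nn.BCEWithLogitsLoss()

logits = torch.randn(8, 1, device=device)            # float32, like the model output after .float()
labels = torch.randint(0, 2, (8, 1), device=device)  # int64, as labels come out of the dataloader

try:
    # guess: the in-place mul_ inside binary_cross_entropy_with_logits cannot cast its
    # float result back into an integer tensor, matching the canCast check in the trace
    loss = criterion(logits, labels)
except RuntimeError as err:
    print("fails:", err)

# casting the labels to the logits' floating dtype makes the same call go through
loss = criterion(logits, labels.to(logits.dtype))
print(loss)

If that is indeed the problem, the "if input_ids.device != device" guard in train_step above would skip the cast of labels to torch.float whenever the batch is already on the TPU, which might explain why the same code runs fine on GPU (where the batch starts on the CPU and the cast always happens).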
