
RuntimeError: Unknown device when trying to run AlbertForMaskedLM on colab tpu #1909

Closed
ayanyuegupta opened this issue Apr 11, 2020 · 8 comments
Labels
stale Has not had recent activity


@ayanyuegupta

ayanyuegupta commented Apr 11, 2020

Hi,

I am running the following code on Colab, taken from the example here: https://huggingface.co/transformers/model_doc/albert.html#albertformaskedlm

import os
import torch
import torch_xla
import torch_xla.core.xla_model as xm

# Fail early if this Colab runtime has no TPU attached
assert os.environ['COLAB_TPU_ADDR']

dev = xm.xla_device()

from transformers import AlbertTokenizer, AlbertForMaskedLM

tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
# .to(dev) moves the model's parameters and registered buffers to the TPU
model = AlbertForMaskedLM.from_pretrained('albert-base-v2').to(dev)
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1

data = input_ids.to(dev)

# masked_lm_labels is the transformers 2.x keyword (later releases use labels=)
outputs = model(data, masked_lm_labels=data)
loss, prediction_scores = outputs[:2]

I haven't changed the example code except to move input_ids and the model onto the TPU device with .to(dev). Everything seems to be moved to the TPU without problems: printing data gives tensor([[ 2, 10975, 15, 51, 1952, 25, 10901, 3]], device='xla:1')

However, when I run this code, I get the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-5-f756487db8f7> in <module>()
      1 
----> 2 outputs = model(data, masked_lm_labels=data)
      3 loss, prediction_scores = outputs[:2]

9 frames
/usr/local/lib/python3.6/dist-packages/transformers/modeling_albert.py in forward(self, hidden_states, attention_mask, head_mask)
    277         attention_output = self.attention(hidden_states, attention_mask, head_mask)
    278         ffn_output = self.ffn(attention_output[0])
--> 279         ffn_output = self.activation(ffn_output)
    280         ffn_output = self.ffn_output(ffn_output)
    281         hidden_states = self.full_layer_layer_norm(ffn_output + attention_output[0])

RuntimeError: Unknown device

Does anyone know what's going on?

@dlibenzi
Collaborator

From a quick peek, it looks like the to(xla_device) call on the model did not fully convert the model.
Can you try printing ffn_output.device after line 278?
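Printing one intermediate tensor's device checks one layer. A broader check, sketched below, walks every submodule and reports where its parameters, registered buffers, and any plain tensor attributes live. The helper `report_tensor_locations` and the `Toy` module are hypothetical. The sketch illustrates one likely way `.to(device)` can leave state behind: `nn.Module.to()` converts parameters and registered buffers, but a tensor stored as an ordinary attribute is not touched. The demo converts dtype instead of device so it runs on CPU; the same skipping applies to device moves.

```python
import torch
import torch.nn as nn


def report_tensor_locations(model: nn.Module):
    """Return (qualified_name, kind, device, dtype) for every tensor a module holds.

    kind is "param", "buffer", or "plain" (a tensor stored as an ordinary
    attribute, which nn.Module.to() does not convert).
    """
    rows = []
    for mod_name, module in model.named_modules():
        prefix = f"{mod_name}." if mod_name else ""
        for name, p in module.named_parameters(recurse=False):
            rows.append((prefix + name, "param", p.device, p.dtype))
        for name, b in module.named_buffers(recurse=False):
            rows.append((prefix + name, "buffer", b.device, b.dtype))
        # Parameters and buffers are kept in _parameters/_buffers; any other
        # tensor assigned as an attribute lands in the instance __dict__.
        for name, value in vars(module).items():
            if isinstance(value, torch.Tensor):
                rows.append((prefix + name, "plain", value.device, value.dtype))
    return rows


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.register_buffer("scale", torch.ones(4))
        self.table = torch.zeros(4)  # plain attribute: skipped by .to()

    def forward(self, x):
        return self.linear(x) * self.scale + self.table


model = Toy().to(torch.float64)  # stands in for .to(xla_device) on CPU
for row in report_tensor_locations(model):
    print(row)
```

Any row whose device (or here, dtype) differs from the rest points at state that the `.to()` call did not reach.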

@dlibenzi
Collaborator

AFAICT this code does a lot of tensor creation/manipulation on the default (CPU) device, plus Python scalar mode ops.
@jysohn23 Can you take a look, since you worked on the Huggingface code?

https://github.com/huggingface/transformers/blob/7972a4019f4bc9f85fd358f42249b90f9cd27c68/src/transformers/modeling_albert.py#L184
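The pattern being described is a tensor created inside `forward` without a `device=` argument, which lands on the default (CPU) device no matter where the module lives, and can then fail when combined with XLA tensors. A minimal sketch of the difference follows; the function names are hypothetical and this is not the code in `modeling_albert.py`. Taking `device` (and `dtype`) from an input tensor, or using a `*_like` constructor, keeps new tensors next to the data.

```python
import torch


def make_mask_cpu_default(hidden_states: torch.Tensor) -> torch.Tensor:
    # Allocated on the default device (CPU) with the default dtype,
    # regardless of where hidden_states lives.
    batch, seq_len, _ = hidden_states.shape
    return torch.zeros(batch, seq_len)


def make_mask_device_aware(hidden_states: torch.Tensor) -> torch.Tensor:
    # Follows the input's device and dtype (CPU, GPU, or XLA alike).
    batch, seq_len, _ = hidden_states.shape
    return torch.zeros(batch, seq_len, device=hidden_states.device,
                       dtype=hidden_states.dtype)


def make_mask_like(hidden_states: torch.Tensor) -> torch.Tensor:
    # *_like constructors copy device and dtype from their argument.
    return torch.zeros_like(hidden_states[..., 0])


x = torch.randn(2, 3, 5, dtype=torch.float64)
print(make_mask_cpu_default(x).dtype, make_mask_device_aware(x).dtype)
```

The dtype difference printed here is the CPU-visible counterpart of the device difference that matters on a TPU.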

@ayanyuegupta
Author

ayanyuegupta commented Apr 11, 2020

> From a quick peek, looks like the to(xla_device) on the model did not fully convert the model.
> Can you try to print ffn_output.device after line 278?

Hi! I'm not sure how to do this in Colab. Do I navigate to the directory containing modeling_albert.py and add a print statement with a text editor through Colab's terminal? I tried vim/nano but it didn't work.
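One way to get that information without editing the installed `modeling_albert.py` is a forward hook, which runs after a submodule's `forward` and receives its output. The sketch below hooks every submodule whose qualified name ends with a given suffix and records the output device. `record_output_devices` is a hypothetical helper, and filtering on `"ffn"` assumes the ALBERT submodule is registered as `self.ffn`, as line 278 of the traceback suggests. The demo uses a small stand-in model on CPU; in Colab one would attach the hooks and then call `model(data, masked_lm_labels=data)`.

```python
import torch
import torch.nn as nn


def record_output_devices(model: nn.Module, suffix: str):
    """Hook submodules whose qualified name ends with `suffix`.

    Returns (log, handles): log collects (name, device) on every forward
    call; call h.remove() on each handle to detach the hooks afterwards.
    """
    log = []
    handles = []
    for name, module in model.named_modules():
        if name.endswith(suffix):
            # name=name binds the current loop value into the closure
            def hook(mod, inputs, output, name=name):
                log.append((name, output.device))
            handles.append(module.register_forward_hook(hook))
    return log, handles


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.ffn = nn.Linear(8, 8)

    def forward(self, x):
        return torch.relu(self.ffn(x))


model = nn.Sequential(Block(), Block())
log, handles = record_output_devices(model, "ffn")
model(torch.randn(2, 8))
for h in handles:
    h.remove()
print(log)
```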

@jysohn23
Collaborator

Hey @goggoloid, I got the above running in this colab notebook.

Basically, there were other tensors still on the default CPU device, so we needed to call:

model = xm.send_cpu_data_to_device(model, xm.xla_device())

Also, Albert on huggingface/transformers:master currently uses a torch.jit-scripted gelu activation, but this branch https://github.com/huggingface/transformers/tree/fix-jit-tpu contains the PR that should stop that activation from being jitted.
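For illustration, the tanh-approximation GELU (the variant Hugging Face often registers as `gelu_new`) needs no TorchScript at all: written as ordinary tensor ops it runs on any device that supports those ops. This is a sketch under that assumption, not the code in the linked branch, and `gelu_tanh` is a hypothetical name. The comparison against `torch.nn.functional.gelu(..., approximate="tanh")` assumes PyTorch 1.12 or newer, where that keyword exists.

```python
import math

import torch
import torch.nn.functional as F


def gelu_tanh(x: torch.Tensor) -> torch.Tensor:
    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
    # plain eager ops with no torch.jit.script decorator.
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))
    return 0.5 * x * (1.0 + torch.tanh(inner))


x = torch.linspace(-5, 5, steps=101)
print(torch.allclose(gelu_tanh(x), F.gelu(x, approximate="tanh"), atol=1e-6))
```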

FYI, we also have a TPU GLUE runner in huggingface now too: https://github.com/huggingface/transformers/blob/fix-jit-tpu/examples/run_tpu_glue.py

We'll add one for LM finetuning soon too.

@jysohn23 jysohn23 self-assigned this Apr 11, 2020
@dlibenzi
Collaborator

@jysohn23 That is not the optimal fix. For models whose parameter logic is a bit tricky, I think there is an _apply() function to override somewhere.
I also noticed that the model creates CPU tensors (IIRC with zeros()) instead of using the device the model has been placed on.
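A minimal sketch of the `_apply()` idea, assuming the awkward state is a tensor that should not be a parameter. `nn.Module.to()`, `.cuda()`, `.double()` and similar methods all go through the private `_apply(fn)` method, so an override can pass extra tensors through the same `fn`. Because `_apply` is private, its signature has changed between releases (newer versions add a `recurse` argument), so the override forwards any extra arguments. Registering the tensor as a buffer usually achieves the same without touching private API; `persistent=False` (PyTorch 1.6+) keeps it out of `state_dict()`. Both class names are hypothetical.

```python
import torch
import torch.nn as nn


class WithApplyOverride(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.table = torch.arange(4.0)  # plain tensor attribute

    def _apply(self, fn, *args, **kwargs):
        # Let nn.Module convert parameters, buffers and children first,
        # then push the extra tensor through the same conversion function.
        module = super()._apply(fn, *args, **kwargs)
        self.table = fn(self.table)
        return module

    def forward(self, x):
        return self.linear(x) + self.table


class WithBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        # Buffers are moved/cast by .to(); persistent=False keeps the
        # tensor out of state_dict().
        self.register_buffer("table", torch.arange(4.0), persistent=False)

    def forward(self, x):
        return self.linear(x) + self.table


a = WithApplyOverride().to(torch.float64)
b = WithBuffer().to(torch.float64)
print(a.table.dtype, b.table.dtype, "table" in b.state_dict())
```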

@ayanyuegupta
Author

Thanks everyone!

@stale

stale bot commented May 12, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@stale stale bot added the stale Has not had recent activity label May 12, 2020
@stale stale bot closed this as completed May 19, 2020
@mobassir94

Dear @dlibenzi @jysohn23 @goggoloid @pietern @ezyang, I was trying to convert this PyTorch GPU code to PyTorch XLA: https://www.kaggle.com/khyeh0719/pytorch-efficientnet-baseline-train-amp-aug

My updated prepare_dataloader, train_one_epoch and valid_one_epoch code looks like this:

def prepare_dataloader(df, trn_idx, val_idx, data_root='../input/cassava-leaf-disease-classification/train_images/'):
    
    from catalyst.data.sampler import BalanceClassSampler
    
    train_ = df.loc[trn_idx,:].reset_index(drop=True)
    valid_ = df.loc[val_idx,:].reset_index(drop=True)
        
    train_ds = CassavaDataset(train_, data_root, transforms=get_train_transforms(), output_label=True, one_hot_label=False, do_fmix=False, do_cutmix=False)
    valid_ds = CassavaDataset(valid_, data_root, transforms=get_valid_transforms(), output_label=True)
    
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_ds,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)

    
    train_loader = torch.utils.data.DataLoader(
        train_ds,
        batch_size=CFG['train_bs'],
        sampler=train_sampler,
        pin_memory=False,
        drop_last=False,
        #shuffle=True,        
        num_workers=CFG['num_workers'],
        #sampler=BalanceClassSampler(labels=train_['label'].values, mode="downsampling")
    )
    
    
    valid_sampler = torch.utils.data.distributed.DistributedSampler(
        valid_ds,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        )
    
    val_loader = torch.utils.data.DataLoader(
        valid_ds, 
        batch_size=CFG['valid_bs'],
        sampler=valid_sampler,
        num_workers=CFG['num_workers'],
        shuffle=False,
        pin_memory=False,
    )
    return train_loader, val_loader

def train_one_epoch(epoch, model, loss_fn, optimizer, train_loader, device, scheduler=None, schd_batch_update=False):
    model.train()

    t = time.time()
    running_loss = None
    model = xm.send_cpu_data_to_device(model, device)
    model = model.to(device)

    pbar = tqdm(enumerate(train_loader), total=len(train_loader))
    for step, (imgs, image_labels) in pbar:
        imgs = imgs.to(device).float()
        image_labels = image_labels.to(device).long()

        #print(image_labels.shape, exam_label.shape)
        with autocast():
            image_preds = model(imgs)   #output = model(input)
            #print(image_preds.shape, exam_pred.shape)

            loss = loss_fn(image_preds, image_labels)
            
            scaler.scale(loss).backward()

            if running_loss is None:
                running_loss = loss.item()
            else:
                running_loss = running_loss * .99 + loss.item() * .01

            if ((step + 1) %  CFG['accum_iter'] == 0) or ((step + 1) == len(train_loader)):
                # may unscale_ here if desired (e.g., to allow clipping unscaled gradients)

                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad() 
                
                if scheduler is not None and schd_batch_update:
                    scheduler.step()

            if ((step + 1) % CFG['verbose_step'] == 0) or ((step + 1) == len(train_loader)):
                description = f'epoch {epoch} loss: {running_loss:.4f}'
                
                pbar.set_description(description)
                
    if scheduler is not None and not schd_batch_update:
        scheduler.step()
        
def valid_one_epoch(epoch, model, loss_fn, val_loader, device, scheduler=None, schd_loss_update=False):
    model.eval()

    t = time.time()
    loss_sum = 0
    sample_num = 0
    image_preds_all = []
    image_targets_all = []
    
    pbar = tqdm(enumerate(val_loader), total=len(val_loader))
    for step, (imgs, image_labels) in pbar:
        imgs = imgs.to(device).float()
        image_labels = image_labels.to(device).long()
        
        image_preds = model(imgs)   #output = model(input)
        #print(image_preds.shape, exam_pred.shape)
        image_preds_all += [torch.argmax(image_preds, 1).detach().cpu().numpy()]
        image_targets_all += [image_labels.detach().cpu().numpy()]
        
        loss = loss_fn(image_preds, image_labels)
        
        loss_sum += loss.item()*image_labels.shape[0]
        sample_num += image_labels.shape[0]  

        if ((step + 1) % CFG['verbose_step'] == 0) or ((step + 1) == len(val_loader)):
            description = f'epoch {epoch} loss: {loss_sum/sample_num:.4f}'
            pbar.set_description(description)
    
    image_preds_all = np.concatenate(image_preds_all)
    image_targets_all = np.concatenate(image_targets_all)
    print('validation multi-class accuracy = {:.4f}'.format((image_preds_all==image_targets_all).mean()))
    
    if scheduler is not None:
        if schd_loss_update:
            scheduler.step(loss_sum/sample_num)
        else:
            scheduler.step()
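The two `DistributedSampler`s in `prepare_dataloader` take `num_replicas` and `rank` from the XLA runtime. With a world size of 1 (which is what `xm.xrt_world_size()` printed in this report), each sampler yields every index. The sketch below passes explicit values so the sharding can be checked on CPU without a TPU; `shard_indices`, the `range` dataset, and the two-replica setup are illustrative assumptions.

```python
from torch.utils.data.distributed import DistributedSampler


def shard_indices(dataset_len: int, num_replicas: int, shuffle: bool, seed: int = 0):
    """Return the list of indices each rank would see for one epoch."""
    dataset = range(dataset_len)  # any object with __len__ works here
    shards = []
    for rank in range(num_replicas):
        # Explicit num_replicas/rank means no process group is required.
        sampler = DistributedSampler(dataset, num_replicas=num_replicas,
                                     rank=rank, shuffle=shuffle, seed=seed)
        sampler.set_epoch(0)  # call once per epoch when shuffle=True
        shards.append(list(sampler))
    return shards


print(shard_indices(10, num_replicas=2, shuffle=False))
print(shard_indices(10, num_replicas=1, shuffle=False))
```

When the dataset length is not divisible by `num_replicas`, the sampler pads by repeating indices unless `drop_last=True`.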

Then I designed the train_model() function like this:

def train_model():
    # for training only, need nightly build pytorch
    seed_everything(CFG['seed'])
    folds = StratifiedKFold(n_splits=CFG['fold_num'], shuffle=True, random_state=CFG['seed']).split(np.arange(train.shape[0]), train.label.values)
    
    for fold, (trn_idx, val_idx) in enumerate(folds):
        # we'll train fold 0 first
        if fold > 0:
            break 

        print('Training with {} started'.format(fold))

        print(len(trn_idx), len(val_idx))
        train_loader, val_loader = prepare_dataloader(train, trn_idx, val_idx, data_root='../input/cassava-leaf-disease-classification/train_images/')

        #device = torch.device(CFG['device'])
        
        device = xm.xla_device()
        
        print(device)
        
        
        
        
        model = CassvaImgClassifier(CFG['model_arch'], train.label.nunique(), pretrained=True)
        
        #https://stackoverflow.com/questions/61157314/runtimeerror-unknown-device-when-trying-to-run-albertformaskedlm-on-colab-tpu
        #global model
        
        model = xm.send_cpu_data_to_device(model, device)
        
        model = model.to(device)
        
        scaler = GradScaler()   
        optimizer = torch.optim.Adam(model.parameters(), lr=CFG['lr'], weight_decay=CFG['weight_decay'])
        #scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=0.1, step_size=CFG['epochs']-1)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=CFG['T_0'], T_mult=1, eta_min=CFG['min_lr'], last_epoch=-1)
        #scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer=optimizer, pct_start=0.1, div_factor=25, 
        #                                                max_lr=CFG['lr'], epochs=CFG['epochs'], steps_per_epoch=len(train_loader))
        
        loss_tr = nn.CrossEntropyLoss().to(device) #MyCrossEntropyLoss().to(device)
        loss_fn = nn.CrossEntropyLoss().to(device)
        
        
        
        for epoch in range(CFG['epochs']):
            para_loader = pl.ParallelLoader(train_loader, [device])
        
            train_one_epoch(epoch, model, loss_tr, optimizer, para_loader.per_device_loader(device), device, scheduler=scheduler, schd_batch_update=False)

            with torch.no_grad():
                para_loader = pl.ParallelLoader(val_loader, [device])
                valid_one_epoch(epoch, model, loss_fn, para_loader.per_device_loader(device), device, scheduler=None, schd_loss_update=False)

            #torch.save(model.state_dict(),'{}_fold_{}_{}'.format(CFG['model_arch'], fold, epoch))
            
            xm.save(model.state_dict(),'{}_fold_{}_{}'.format(CFG['model_arch'], fold, epoch))
            
        #torch.save(model.cnn_model.state_dict(),'{}/cnn_model_fold_{}_{}'.format(CFG['model_path'], fold, CFG['tag']))
        del model, optimizer, train_loader, val_loader, scaler, scheduler
        #torch.cuda.empty_cache()
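`train_one_epoch` uses `autocast()` and `scaler`, which come from `torch.cuda.amp` and target CUDA rather than XLA. Also, `train_model()` creates `scaler = GradScaler()` as a local variable, so unless a global `scaler` exists elsewhere in the notebook, `train_one_epoch` cannot see it. A device-agnostic step without CUDA AMP might look like the sketch below. Assumptions: `train_step` is a hypothetical name and it omits the gradient accumulation of the original; on a TPU it calls `xm.optimizer_step(optimizer)`, which reduces gradients across replicas before stepping, and it falls back to `optimizer.step()` when `torch_xla` is not importable so the sketch also runs on CPU. The running loss stays a tensor and `.item()` is called only when logging, which avoids forcing a device-to-host sync on every step.

```python
import torch
import torch.nn as nn

try:  # torch_xla is only importable on a TPU runtime
    import torch_xla.core.xla_model as xm
except ImportError:
    xm = None


def train_step(model, loss_fn, optimizer, imgs, labels, running_loss=None):
    """Run one forward/backward/update and return the updated running loss.

    The running loss is a detached tensor; call .item() only when logging.
    """
    optimizer.zero_grad()
    loss = loss_fn(model(imgs), labels)
    loss.backward()
    if xm is not None:
        # Reduces gradients across replicas, then calls optimizer.step()
        xm.optimizer_step(optimizer)
    else:
        optimizer.step()
    loss = loss.detach()
    if running_loss is None:
        return loss
    return running_loss * 0.99 + loss * 0.01


# CPU usage example with a toy model and a fixed batch
torch.manual_seed(0)
model = nn.Linear(4, 3)
model.train()
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
imgs, labels = torch.randn(8, 4), torch.randint(0, 3, (8,))

running = None
for step in range(20):
    running = train_step(model, loss_fn, optimizer, imgs, labels, running)
print(f"running loss: {running.item():.4f}")  # .item() only when logging
```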

Then, when I try to start the training process using this code block:

# Start training processes

def _mp_fn(rank, flags):
    global acc_list
    torch.set_default_tensor_type('torch.FloatTensor')
    res = train_model()

FLAGS={}
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=1, start_method='fork')

I then get this error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-50-1bffc6490088> in <module>
      7 
      8 FLAGS={}
----> 9 xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=1, start_method='fork')

/opt/conda/lib/python3.7/site-packages/torch_xla/distributed/xla_multiprocessing.py in spawn(fn, args, nprocs, join, daemon, start_method)
    385   pf_cfg = _pre_fork_setup(nprocs)
    386   if pf_cfg.num_devices == 1:
--> 387     _start_fn(0, pf_cfg, fn, args)
    388   else:
    389     return torch.multiprocessing.start_processes(

/opt/conda/lib/python3.7/site-packages/torch_xla/distributed/xla_multiprocessing.py in _start_fn(index, pf_cfg, fn, args)
    322   # environment must be fully setup before doing so.
    323   _setup_replication()
--> 324   fn(gindex, *args)
    325 
    326 

<ipython-input-50-1bffc6490088> in _mp_fn(rank, flags)
      4     global acc_list
      5     torch.set_default_tensor_type('torch.FloatTensor')
----> 6     res = train_model()
      7 
      8 FLAGS={}

<ipython-input-47-eb6667460f21> in train_model()
     47             para_loader = pl.ParallelLoader(train_loader, [device])
     48 
---> 49             train_one_epoch(epoch, model, loss_tr, optimizer, para_loader.per_device_loader(device), device, scheduler=scheduler, schd_batch_update=False)
     50 
     51             with torch.no_grad():

<ipython-input-46-7c8958aeba56> in train_one_epoch(epoch, model, loss_fn, optimizer, train_loader, device, scheduler, schd_batch_update)
     59         #print(image_labels.shape, exam_label.shape)
     60         with autocast():
---> 61             image_preds = model(imgs)   #output = model(input)
     62             #print(image_preds.shape, exam_pred.shape)
     63 

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

<ipython-input-17-efdefc1bcd78> in forward(self, x)
     13         '''
     14     def forward(self, x):
---> 15         x = self.model(x)
     16         return x

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

/kaggle/input/pytorch-image-models/pytorch-image-models-master/timm/models/efficientnet.py in forward(self, x)
    388 
    389     def forward(self, x):
--> 390         x = self.forward_features(x)
    391         x = self.global_pool(x)
    392         x = x.flatten(1)

/kaggle/input/pytorch-image-models/pytorch-image-models-master/timm/models/efficientnet.py in forward_features(self, x)
    380         x = self.conv_stem(x)
    381         x = self.bn1(x)
--> 382         x = self.act1(x)
    383         x = self.blocks(x)
    384         x = self.conv_head(x)

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

/kaggle/input/pytorch-image-models/pytorch-image-models-master/timm/models/layers/activations.py in forward(self, x)
     94 
     95     def forward(self, x):
---> 96         return swish(x, self.inplace)
     97 
     98 

/kaggle/input/pytorch-image-models/pytorch-image-models-master/timm/models/layers/activations.py in swish(x, _inplace)
     46 
     47     def swish(x, _inplace=False):
---> 48         return SwishJitAutoFn.apply(x)
     49 
     50 

/kaggle/input/pytorch-image-models/pytorch-image-models-master/timm/models/layers/activations.py in forward(ctx, x)
     37         def forward(ctx, x):
     38             ctx.save_for_backward(x)
---> 39             return swish_jit_fwd(x)
     40 
     41         @staticmethod

RuntimeError: Unknown device

I need help with two things:

  1. If I use nprocs=8, I get the error Exception: process 0 terminated with exit code 17. Where should I change my code to get rid of this error?
  2. If I use nprocs=1, I get RuntimeError: Unknown device. How can I solve this?

xm.xrt_world_size() prints 1,

and xm.get_xla_supported_devices()

returns: ['xla:1', 'xla:2', 'xla:3', 'xla:4', 'xla:5', 'xla:6', 'xla:7', 'xla:8']
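The `nprocs=1` traceback ends inside timm's `SwishJitAutoFn`, a TorchScript-backed swish, which is the same kind of jitted activation discussed earlier in this thread for Albert's gelu. One workaround to try, sketched under assumptions, is to swap those activation modules for `nn.SiLU` (swish, available since PyTorch 1.7) after building the model and before moving it to the XLA device. `replace_modules` is a hypothetical helper and `JitSwishStandIn` only stands in for timm's module class; the real class depends on the installed timm version, so inspect `type(model.model.act1)` first to find it.

```python
import torch
import torch.nn as nn


def replace_modules(model: nn.Module, target_type: type, factory) -> int:
    """Replace every submodule of `target_type` with `factory()`, in place.

    Returns the number of modules replaced.
    """
    count = 0
    for name, child in model.named_children():
        if isinstance(child, target_type):
            setattr(model, name, factory())  # re-registers under the same name
            count += 1
        else:
            count += replace_modules(child, target_type, factory)
    return count


class JitSwishStandIn(nn.Module):
    """Hypothetical stand-in for a jit-backed swish module such as timm's."""

    def forward(self, x):
        return x * torch.sigmoid(x)


net = nn.Sequential(nn.Linear(4, 4), JitSwishStandIn(),
                    nn.Sequential(nn.Linear(4, 4), JitSwishStandIn()))
x = torch.randn(3, 4)
before = net(x)
n = replace_modules(net, JitSwishStandIn, nn.SiLU)
after = net(x)
print(n, torch.allclose(before, after, atol=1e-6))
```

Since swish and SiLU compute the same function, outputs should be unchanged; only the implementation behind the activation differs.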
