Backend worker monitoring thread interrupted or backend worker process died. #537

Closed
fancyerii opened this issue Jul 17, 2020 · 8 comments

@fancyerii

I used the Hugging Face Transformers library to train a text classification model, and I am trying to deploy it with TorchServe.

  1. Training and saving the model
    The full code can be found here. The code is not well organized, but the model is very simple:
class OrderClassifier(nn.Module):
    def __init__(self, n_classes, path):
        super(OrderClassifier, self).__init__()

        self.bert = BertModel.from_pretrained(path)
        self.drop = nn.Dropout(p=0.3)
        self.out = nn.Linear(self.bert.config.hidden_size, n_classes)

    def forward(self, input_ids, attention_mask):
        _, pooled_output = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        output = self.drop(pooled_output)
        return self.out(output)

After training, I saved the whole model with the best validation accuracy:

    if val_acc > best_accuracy:
        torch.save(model, 'best_model.bin')
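
For context, a minimal hedged sketch of the surrounding checkpoint logic (the training and evaluation helpers here are only illustrative, not the actual code):

best_accuracy = 0.0
for epoch in range(num_epochs):
    train_one_epoch(model, train_loader)   # hypothetical training step
    val_acc = evaluate(model, val_loader)  # hypothetical validation step
    if val_acc > best_accuracy:
        best_accuracy = val_acc
        # torch.save(model, ...) pickles the entire module object, so the
        # OrderClassifier class definition must be available again wherever
        # the file is later loaded
        torch.save(model, 'best_model.bin')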
  2. Creating the handler testtorchserving.py
    The full code is here.
    It seems the initialization part throws an exception, so I paste that part of the code here:
    def initialize(self, ctx):
        self.manifest = ctx.manifest

        properties = ctx.system_properties
        model_dir = properties.get("model_dir")
        self.device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")

        # Read model serialize/pt file
        print("model: "+model_dir+"/best_model.bin")
        self.model = torch.load(model_dir+"/best_model.bin")
        #AutoModelForSequenceClassification.from_pretrained(model_dir)
        print("load model success")
        self.tokenizer = BertTokenizer.from_pretrained(model_dir)
        print("load tokenizer success")
        self.model.to(self.device)
        self.model.eval()
        print("eval success")

Because I can't find any clue in ts_log.log, I added print statements to find out where the problem occurs. It seems the problem occurs in this line:

 self.model = torch.load(model_dir+"/best_model.bin")
  3. Creating the mar file
torch-model-archiver --model-name "order" --version 1.0 --serialized-file ./model/best_model.bin --extra-files "/home/lili/data/huggface/bert-base-chinese/config.json,/home/lili/data/huggface/bert-base-chinese/vocab.txt" --handler "./testtorchserving.py"
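
As a side note, --extra-files is a comma-separated list of additional files that are extracted into the model directory alongside the handler; a hedged variant of the same command, assuming the OrderClassifier definition lives in a separate model.py, would be:

torch-model-archiver --model-name "order" --version 1.0 --serialized-file ./model/best_model.bin --extra-files "./model.py,/home/lili/data/huggface/bert-base-chinese/config.json,/home/lili/data/huggface/bert-base-chinese/vocab.txt" --handler "./testtorchserving.py"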
  4. Running TorchServe
mkdir model_store && mv order.mar model_store && torchserve --start --model-store model_store --models order=order.mar
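
Once the worker is up, the model can be smoke-tested through the standard inference API on the default port 8080; a minimal hedged example, assuming the handler accepts a plain-text payload in sample.txt:

curl -X POST http://127.0.0.1:8080/predictions/order -T sample.txt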
  5. The error log
    The full log is here.

I found the relevant log lines:

2020-07-17 08:57:53,309 [INFO ] W-9000-order_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - model: /tmp/models/5886359598784a97ace9c91df12d99590ade3efe/best_model.bin
2020-07-17 08:57:53,309 [INFO ] W-9000-order_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - Backend worker process died.
2020-07-17 08:57:53,309 [INFO ] W-9000-order_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - Traceback (most recent call last):
2020-07-17 08:57:53,309 [INFO ] W-9000-order_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -   File "/home/lili/env-huggface/lib/python3.6/site-packages/ts/model_service_worker.py", line 175, in <module>
2020-07-17 08:57:53,309 [INFO ] W-9000-order_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -     worker.run_server()
2020-07-17 08:57:53,309 [INFO ] W-9000-order_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -   File "/home/lili/env-huggface/lib/python3.6/site-packages/ts/model_service_worker.py", line 147, in run_server
2020-07-17 08:57:53,309 [INFO ] W-9000-order_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -     self.handle_connection(cl_socket)
2020-07-17 08:57:53,310 [INFO ] W-9000-order_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -   File "/home/lili/env-huggface/lib/python3.6/site-packages/ts/model_service_worker.py", line 111, in handle_connection
2020-07-17 08:57:53,310 [INFO ] W-9000-order_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -     service, result, code = self.load_model(msg)
2020-07-17 08:57:53,310 [INFO ] W-9000-order_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -   File "/home/lili/env-huggface/lib/python3.6/site-packages/ts/model_service_worker.py", line 84, in load_model
2020-07-17 08:57:53,310 [INFO ] epollEventLoopGroup-4-1 org.pytorch.serve.wlm.WorkerThread - 9000 Worker disconnected. WORKER_STARTED
2020-07-17 08:57:53,310 [INFO ] W-9000-order_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -     service = model_loader.load(model_name, model_dir, handler, gpu, batch_size)
2020-07-17 08:57:53,310 [DEBUG] W-9000-order_1.0 org.pytorch.serve.wlm.WorkerThread - System state is : WORKER_STARTED
2020-07-17 08:57:53,310 [INFO ] W-9000-order_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -   File "/home/lili/env-huggface/lib/python3.6/site-packages/ts/model_loader.py", line 102, in load
2020-07-17 08:57:53,310 [INFO ] W-9000-order_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -     entry_point(None, service.context)
2020-07-17 08:57:53,311 [INFO ] W-9000-order_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -   File "/tmp/models/5886359598784a97ace9c91df12d99590ade3efe/testtorchserving.py", line 129, in handle
2020-07-17 08:57:53,311 [INFO ] W-9000-order_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -     raise e
2020-07-17 08:57:53,311 [DEBUG] W-9000-order_1.0 org.pytorch.serve.wlm.WorkerThread - Backend worker monitoring thread interrupted or backend worker process died.
java.lang.InterruptedException
	at java.base/java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.awaitNanos(AbstractQueuedSynchronizer.java:1668)
	at java.base/java.util.concurrent.ArrayBlockingQueue.poll(ArrayBlockingQueue.java:435)
	at org.pytorch.serve.wlm.WorkerThread.run(WorkerThread.java:129)
	at java.base/java.util.concurrent.Executors$RunnableAdapter.call(Executors.java:515)
	at java.base/java.util.concurrent.FutureTask.run(FutureTask.java:264)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1130)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:630)
	at java.base/java.lang.Thread.run(Thread.java:832)
2020-07-17 08:57:53,311 [INFO ] W-9000-order_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -   File "/tmp/models/5886359598784a97ace9c91df12d99590ade3efe/testtorchserving.py", line 118, in handle
2020-07-17 08:57:53,312 [INFO ] W-9000-order_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -     _service.initialize(context)
2020-07-17 08:57:53,312 [INFO ] W-9000-order_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -   File "/tmp/models/5886359598784a97ace9c91df12d99590ade3efe/testtorchserving.py", line 49, in initialize
2020-07-17 08:57:53,313 [INFO ] W-9000-order_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -     self.model = torch.load(model_dir+"/best_model.bin")
2020-07-17 08:57:53,313 [WARN ] W-9000-order_1.0 org.pytorch.serve.wlm.BatchAggregator - Load model failed: order, error: Worker died.

It prints "model: /tmp/models/5886359598784a97ace9c91df12d99590ade3efe/best_model.bin", which comes right before self.model = torch.load(model_dir+"/best_model.bin"), and print("load model success") is never reached.
So I guess the torch.load(model_dir+"/best_model.bin") call failed.

So I tried to load this model in a standalone script to check whether it is good.

import torch
from transformers import BertTokenizer, BertModel
from torch import nn
class OrderClassifier(nn.Module):
    def __init__(self, n_classes, path):
        super(OrderClassifier, self).__init__()

        self.bert = BertModel.from_pretrained(path)
        self.drop = nn.Dropout(p=0.3)
        self.out = nn.Linear(self.bert.config.hidden_size, n_classes)

    def forward(self, input_ids, attention_mask):
        _, pooled_output = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        output = self.drop(pooled_output)
        return self.out(output)


model_dir="/tmp/models/5886359598784a97ace9c91df12d99590ade3efe"
model = torch.load(model_dir + "/best_model.bin")
print(model) 
tokenizer = BertTokenizer.from_pretrained(model_dir)

The code above executes correctly.
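
One difference between this script and the worker is that here OrderClassifier is defined in __main__, while in TorchServe the worker's __main__ is model_service_worker.py. A hedged way to rule out that difference is to load the file from a script that neither defines nor imports the class:

# repro_load.py: OrderClassifier is deliberately not defined or imported here,
# mimicking the module layout inside the TorchServe worker
import torch

model_dir = "/tmp/models/5886359598784a97ace9c91df12d99590ade3efe"
# may raise if the pickle references __main__.OrderClassifier
model = torch.load(model_dir + "/best_model.bin")
print(type(model))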

So what's wrong with it?

@dhaniram-kshirsagar dhaniram-kshirsagar self-assigned this Jul 17, 2020
@harshbafna
Contributor

@fancyerii: From the shared logs, I can see the following exception as the root cause of the failure while creating the worker processes:

AttributeError: Can't get attribute 'OrderClassifier' on <module '__main__' from '/home/lili/env-huggface/lib/python3.6/site-packages/ts/model_service_worker.py'>

This looks like an environment-related issue. Please share the following details:

  • torchserve version:
  • torch version:
  • torchvision version [if any]:
  • torchtext version [if any]:
  • torchaudio version [if any]:
  • java version:
  • Operating System and version:
  • Installed using source? [yes/no]:
  • Are you using a docker container? [yes/no]:
  • Training DataSet used for training the model, so that we can reproduce the issue at our end.
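
For context, torch.load unpickles the object by looking up OrderClassifier on the module recorded at save time (here '__main__', which inside the worker is model_service_worker.py). A minimal hedged sketch of a handler-side workaround, assuming the class is defined in testtorchserving.py itself:

import __main__
import torch

# inside initialize(), before the torch.load call: register the class on the
# worker's __main__ module so the pickle lookup can find it
setattr(__main__, "OrderClassifier", OrderClassifier)
self.model = torch.load(model_dir + "/best_model.bin")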

@harshbafna harshbafna added the triaged_wait Waiting for the Reporter's resp label Jul 18, 2020
@fancyerii
Author

I have copied all of the OrderClassifier code into the handler (testtorchserving.py), but it seems that it is not being imported?

torchserve version: 0.1.1
torch version: 1.5.0+cu101
torchvision version [if any]: 0.6.0+cu101
torchtext version [if any]: 0.6.0
torchaudio version [if any]:
java version: jdk-14.0.2
Operating System and version: Ubuntu 16.04
Installed using source? [yes/no]: no
Are you using a docker container? [yes/no]: no
Training DataSet used for training the model, so that we can reproduce the issue at our end.

@harshbafna harshbafna removed the triaged_wait Waiting for the Reporter's resp label Jul 20, 2020
@harshbafna
Contributor

@fancyerii: Somehow the torch.load() call in your handler failed to unpickle the serialized model. Could you please share the training datasets below, or the model's mar file?

train_data = read_data("/home/lili/data/order_data_v2/train_data_v2.csv")
test_data = read_data("/home/lili/data/order_data_v2/test_data_v2.csv")

@dhaniram-kshirsagar
Contributor

@fancyerii There is a problem in the way you are saving and loading your model, and it is generally not the recommended way to save a model for production or external deployment. What you are doing is described here -> Save/Load Entire Model.

The PyTorch recommended way ->

Save:

torch.save(model.state_dict(), PATH)

Load:

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
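
Applied to this model, a minimal hedged sketch of the state_dict approach (n_classes=2 and the presence of the base BERT weights in model_dir are assumptions, not taken from the issue):

# training side: save only the weights
torch.save(model.state_dict(), 'best_model_state.bin')

# handler side, inside initialize(): rebuild the module, then load the weights
# (OrderClassifier(..., path=model_dir) assumes the pretrained BERT files needed
# by BertModel.from_pretrained() are also packaged into model_dir)
self.model = OrderClassifier(n_classes=2, path=model_dir)
state_dict = torch.load(model_dir + "/best_model_state.bin", map_location=self.device)
self.model.load_state_dict(state_dict)
self.model.to(self.device)
self.model.eval()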

@harshbafna harshbafna added the triaged_wait Waiting for the Reporter's resp label Jul 30, 2020
@Bartlett-Will

I have a similar problem, and the lines that save the model come from the transformers library:

model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
model_to_save.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

Then to load the model you can use transformers.PreTrainedModel.from_pretrained.

Will this cause problems with Torchserve?

@Bartlett-Will

This seems related to #283

@prashantsail
Contributor

Hi @Bartlett-Will

save_pretrained() internally uses the PyTorch recommended way of saving the model, so you would not run into the issue @fancyerii is facing.

I believe this query was regarding your issue #617. Please have a look at the Huggingface Transformers Example; a resolution is provided for your issue as well. Please confirm if it works for you.
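
For reference, a minimal hedged sketch of an initialize() built on from_pretrained(), assuming save_pretrained() wrote config.json and pytorch_model.bin into the archive's model_dir:

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

def initialize(self, ctx):
    properties = ctx.system_properties
    model_dir = properties.get("model_dir")
    self.device = torch.device("cuda:" + str(properties.get("gpu_id"))
                               if torch.cuda.is_available() else "cpu")
    # from_pretrained() reads the config.json and pytorch_model.bin written by save_pretrained()
    self.model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
    self.model.to(self.device)
    self.model.eval()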

@maaquib
Collaborator

maaquib commented Sep 8, 2020

@fancyerii @Bartlett-Will Closing this due to lack of activity. Please reopen if this is still an issue.

@maaquib maaquib closed this as completed Sep 8, 2020