Skip to content

[pt2e to tosa] face AttributeError #1161

@kris-himax

Description

@kris-himax

Hi @Jerry-Ge ,

I have successfully run the https://github.com/pytorch/executorch/blob/main/examples/arm/run.sh example. Now I am trying to modify it to run a quantized int8 PyTorch model, which needs to pass through Vela and run on the FVP for the Arm Ethos-U55.

I use a PyTorch MNIST classification CNN model and quantize it to int8 with convert_pt2e. The results of the int8 model seem correct.
I then want to export it to ExecuTorch with the Arm Ethos-U55 backend, but I hit `AttributeError: 'ReshapeAttribute' object has no attribute 'NewshapeAsNumpy'. Did you mean: 'NewShapeAsNumpy'?` while running `edge = edge.to_backend(ArmPartitioner)`.
How could I fix it?

The following is my export code.

from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from torch.ao.quantization import get_default_qconfig_mapping
from torch.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.quantization import QuantStub, DeQuantStub
import cv2
import numpy as np

import argparse
import logging

import torch
import torch._export as export

from executorch.backends.arm.arm_backend import ArmPartitioner
from executorch.exir import EdgeCompileConfig

from ..portable.utils import export_to_edge, save_pte_program
class Net(nn.Module):
    """Small CNN classifier: three convolutions, two max-pools, two linear layers.

    ``forward`` returns softmax probabilities over 10 outputs. ``fc1`` takes
    32 flattened features, so the input must reduce to a 32x1x1 feature map
    after ``conv3`` (a 1x28x28 image does).
    """

    def __init__(self):
        super().__init__()
        # Layers are registered in the same order as before (conv1..conv3, fc1, fc2).
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, stride=1)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1)
        self.fc1 = nn.Linear(in_features=32, out_features=64)
        self.fc2 = nn.Linear(in_features=64, out_features=10)

    def forward(self, x):
        """Return per-class probabilities (softmax over dim 1) for the batch ``x``."""
        features = F.max_pool2d(self.conv1(x), 2, stride=2)
        features = F.max_pool2d(self.conv2(features), 2, stride=2)
        flattened = torch.flatten(self.conv3(features), 1)
        scores = self.fc2(self.fc1(flattened))
        return F.softmax(scores, dim=1)


def train(args, model, device, train_loader, optimizer, epoch):
    """Run one pass over ``train_loader``, updating ``model`` with ``optimizer``.

    Progress is printed every ``args.log_interval`` batches. When
    ``args.dry_run`` is set, the loop stops right after the first logged batch.
    """
    model.train()
    for step, batch in enumerate(train_loader):
        inputs, labels = batch
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        loss = F.cross_entropy(model(inputs), labels)
        loss.backward()
        optimizer.step()

        # Only batches on the logging interval report progress (and may stop a dry run).
        if step % args.log_interval != 0:
            continue

        samples_seen = step * len(inputs)
        samples_total = len(train_loader.dataset)
        percent_done = 100. * step / len(train_loader)
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, samples_seen, samples_total, percent_done, loss.item()))
        if args.dry_run:
            break


def test(model, device, test_loader):
    """Evaluate ``model`` on ``test_loader`` and print average loss and accuracy.

    Nothing is returned; the summary is written to stdout.
    """
    # model.eval() stays disabled here, exactly as in the original script.
    # NOTE(review): F.nll_loss expects log-probabilities, while Net in this file
    # returns softmax probabilities, so the printed loss is presumably not a true
    # NLL -- confirm before relying on it. Accuracy is unaffected.
    loss_sum = 0
    num_correct = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            scores = model(inputs)
            # Accumulate the summed (not averaged) loss of this batch.
            loss_sum += F.nll_loss(scores, labels, reduction='sum').item()
            # Predicted class is the index of the largest score per sample.
            predicted = scores.argmax(dim=1, keepdim=True)
            num_correct += predicted.eq(labels.view_as(predicted)).sum().item()

    loss_sum /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        loss_sum, num_correct, len(test_loader.dataset),
        100. * num_correct / len(test_loader.dataset)))

@torch.no_grad()
def calibrate(model, data_loader):
    """Feed every input batch of ``data_loader`` through ``model`` with gradients off.

    The targets of each ``(input, target)`` pair are ignored and the model
    outputs are discarded; the calls are made only for their side effects on
    ``model``.
    """
    # model.eval() stays disabled here, exactly as in the original script.
    for sample, _target in data_loader:
        model(sample)
def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    parser.add_argument(
        "-d",
        "--delegate",
        action="store_true",
        required=False,
        default=False,
        help="Flag for producing ArmBackend delegated model",
    )
    return parser.parse_args()


def main():
    """Quantize the MNIST CNN with PT2E and save it as an ExecuTorch program.

    Steps, in order: load the float checkpoint, capture the pre-autograd
    graph, prepare it with the XNNPACK quantizer, calibrate on the MNIST
    training set, convert to int8, lower to the edge dialect, optionally
    delegate to the Arm backend (``--delegate``), and write the program with
    ``save_pte_program``.

    The whole quantize/export path runs on CPU: the example input and the
    calibration batches are CPU tensors, so placing the model on CUDA would
    make capture and calibration fail with a device mismatch. ``--no-cuda``
    therefore only influences the DataLoader settings. The training-related
    flags are still accepted so existing command lines keep working, but this
    function does not use them.
    """
    from torch._export import capture_pre_autograd_graph
    from torch.ao.quantization.quantize_pt2e import (
        convert_pt2e,
        prepare_pt2e,
    )
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )

    args = _parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    # Loader settings are unchanged from the original script: the extra
    # options (including shuffling) are only applied when CUDA is usable.
    train_kwargs = {'batch_size': args.batch_size}
    if use_cuda:
        train_kwargs.update({'num_workers': 1,
                             'pin_memory': True,
                             'shuffle': True})

    transform = transforms.Compose([
        transforms.ToTensor()
    ])
    train_dataset = datasets.MNIST('./data', train=True, download=True,
                                   transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)

    # Load the float weights onto CPU regardless of where they were saved.
    # weights_only=True restricts unpickling to tensors/containers, which is
    # all a state_dict needs.
    state_dict = torch.load(
        "./pytorch_mnist_cnn_floating.pt",
        map_location="cpu",
        weights_only=True,
    )
    model_to_quantize = Net()
    model_to_quantize.load_state_dict(state_dict)
    model_to_quantize.eval()

    example_inputs = (torch.randn(1, 1, 28, 28),)
    exported_model = capture_pre_autograd_graph(model_to_quantize, example_inputs)

    quantizer = XNNPACKQuantizer()
    quantizer.set_global(get_symmetric_quantization_config())

    prepared_model = prepare_pt2e(exported_model, quantizer)
    print(prepared_model.graph)

    # Calibration must run between prepare_pt2e and convert_pt2e.
    calibrate(prepared_model, train_loader)

    quantized_model = convert_pt2e(prepared_model)
    print("convert_pt2e(prepared_model)done ")

    edge = export_to_edge(
        quantized_model,
        example_inputs,
        edge_compile_config=EdgeCompileConfig(
            _check_ir_validity=False,
        ),
    )
    print("export_to_edge done ")
    logging.info("Exported graph:\n%s", edge.exported_program().graph)

    model_name = "pytorch_mnist_cnn_ptq_qnnpack"
    if args.delegate:
        # NOTE(review): the AttributeError reported in this issue is raised
        # inside this call, not by this script; presumably a version mismatch
        # in the backend's serialization dependency -- confirm.
        edge = edge.to_backend(ArmPartitioner)
        logging.info("Lowered graph:\n%s", edge.exported_program().graph)
        # Only report the lowering step when it actually ran.
        print("edge.to_backend(ArmPartitioner) done ")
        model_name += "_arm_delegate"

    exec_prog = edge.to_executorch()
    print("edge.to_executorch() done ")
    save_pte_program(exec_prog.buffer, model_name)



# Run the export pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()

Screenshot from 2023-11-07 16-54-56
Screenshot from 2023-11-07 16-48-22

Metadata

Metadata

Assignees

Labels

partner: arm — For backend delegation, kernels, demo, etc. from the 3rd-party partner, Arm
triaged — This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions