From a790f86bf2c67b404bf34ee254e85f01329e4fa7 Mon Sep 17 00:00:00 2001
From: Bin Bao
Date: Thu, 15 Aug 2024 11:04:09 -0700
Subject: [PATCH] Remove build/model_aoti.py

Summary: DSOModel is not used anywhere, since we use torch._export.aot_load
to load the AOTI-compiled model.so.
---
 build/model_aoti.py | 65 ---------------------------------------------
 1 file changed, 65 deletions(-)
 delete mode 100644 build/model_aoti.py

diff --git a/build/model_aoti.py b/build/model_aoti.py
deleted file mode 100644
index 10560d957..000000000
--- a/build/model_aoti.py
+++ /dev/null
@@ -1,65 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import torch
-import torch.nn as nn
-from torch._inductor.codecache import AsyncCompile
-
-# with open("./dso_model.h", "rb") as f:
-#     dso_src = f.read().decode("utf-8")
-
-dso_src = ""
-
-src = """
-#include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>
-#include <torch/torch.h>
-
-#define MODELPATH "***my_model.so***"
-
-torch::inductor::AOTIModelContainerRunnerCpu *transformer_dso =
-new torch::inductor::AOTIModelContainerRunnerCpu(MODELPATH, 1);
-
-extern "C" void kernel(long *tokens, long *pos, float *logits)
-{
-    torch::Tensor token_tensor = torch::from_blob(
-        tokens, {1, 1}, torch::kLong);
-    torch::Tensor pos_tensor = torch::from_blob(pos, { 1 }, torch::kLong);
-    std::vector<torch::Tensor> inputs{token_tensor, pos_tensor};
-
-    std::vector<torch::Tensor> result = transformer_dso -> run(inputs);
-    std::memcpy(logits, result[0].data_ptr(), result[0].numel()*sizeof(float));
-}
-
-"""
-
-
-class DSOModel(nn.Module):
-    def __init__(self, config, dso_path) -> None:
-        super().__init__()
-        self.config = config
-
-        # build transformer model
-        global src, dso_src
-
-        src = src.replace("***my_model.so***", str(dso_path))
-        async_compile = AsyncCompile()
-        self.transformer_model = async_compile.cpp_pybinding(
-            ["long *", "long *", "float *"], dso_src + src
-        )
-        async_compile.wait(globals())
-        del async_compile
-
-    def forward(self, x, input_pos):
-        vocab_size = self.config.vocab_size  # 32000
-        assert x.dim() == 2 and x.size(0) == 1 and x.size(1) == 1
-        logits = torch.empty(1, 1, vocab_size)
-        x = x.to(torch.long)
-        input_pos = input_pos.to(torch.long)
-        self.transformer_model(x, input_pos, logits)
-        return logits
-
-    def setup_caches(self, max_batch_size, max_seq_length):
-        pass
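
For reference, the replacement loading path goes through torch._export.aot_load,
roughly as sketched below ("model.so" and the example inputs here are
hypothetical placeholders, not paths or shapes taken from this repo):

    import torch

    # aot_load returns a callable wrapping the AOTI-compiled shared library,
    # replacing the hand-rolled C++ runner that DSOModel compiled via AsyncCompile.
    runner = torch._export.aot_load("model.so", device="cpu")

    # Call it with the same inputs the model was exported with; for a
    # single-token decode step that would look something like:
    tokens = torch.zeros(1, 1, dtype=torch.long)
    input_pos = torch.zeros(1, dtype=torch.long)
    logits = runner(tokens, input_pos)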