In [20]:
import torch
from torch.fx import symbolic_trace
import torch.fx as fx

In [21]:
class MyModule(torch.nn.Module):
    """Toy module: add a learned offset, apply a linear layer, clamp to [0, 1]."""

    def __init__(self):
        super().__init__()
        # `param` is created before `linear` so a seeded run initialises
        # both in the same order every time.
        initial_offset = torch.rand(3, 4)
        self.param = torch.nn.Parameter(initial_offset)
        self.linear = torch.nn.Linear(in_features=4, out_features=5)

    def forward(self, x):
        """Compute ``clamp(linear(x + param), 0, 1)``."""
        shifted = x + self.param
        projected = self.linear(shifted)
        return projected.clamp(min=0.0, max=1.0)

In [22]:
module = MyModule()

In [23]:
symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)

In [24]:
print(symbolic_traced.graph)

graph():
    %x : [#users=1] = placeholder[target=x]
    %param : [#users=1] = get_attr[target=param]
    %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
    %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
    %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
    return clamp


In [25]:
print(symbolic_traced.code)




def forward(self, x):
    param = self.param
    add = x + param;  x = param = None
    linear = self.linear(add);  add = None
    clamp = linear.clamp(min = 0.0, max = 1.0);  linear = None
    return clamp
    


In [None]:
def transform(m: torch.nn.Module,
              tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
    """Template for a Graph-level FX transform.

    Args:
        m: The module to trace.
        tracer_class: Tracer type to instantiate; pass a ``torch.fx.Tracer``
            subclass to customize tracing behavior.

    Returns:
        A ``torch.fx.GraphModule`` built from ``m`` and the (possibly
        modified) traced Graph.
    """
    # Step 1: Acquire a Graph representing the code in `m`

    # NOTE: torch.fx.symbolic_trace is a wrapper around a call to
    # fx.Tracer.trace and constructing a GraphModule. We'll
    # split that out in our transform to allow the caller to
    # customize tracing behavior.
    graph : torch.fx.Graph = tracer_class().trace(m)

    # Step 2: Modify this Graph or create a new one
    # <... edit `graph` in place, or rebind `graph` to a new torch.fx.Graph ...>
    # (The previous `graph = ...` placeholder replaced the traced Graph with
    # the Ellipsis object, which GraphModule cannot consume.)

    # Check that the Graph is still well-formed after any edits.
    graph.lint()

    # Step 3: Construct a Module to return
    return torch.fx.GraphModule(m, graph)

In [None]:
def transform(m : torch.nn.Module) -> torch.nn.Module:
    """Template for a transform that edits a GraphModule's Graph in place.

    Args:
        m: The module to symbolically trace.

    Returns:
        The traced ``torch.fx.GraphModule`` with ``forward()`` regenerated
        from its Graph.
    """
    # `torch.nn.Module` is spelled out in the signature because this notebook
    # never imports a bare `nn` name.
    gm : torch.fx.GraphModule = torch.fx.symbolic_trace(m)

    # Modify gm.graph
    # <...>

    # Recompile the forward() method of `gm` from its Graph
    gm.recompile()

    return gm

In [10]:
class MyModule(torch.nn.Module):
    """Example module mixing attribute access, a submodule call, a tensor
    method and torch functions, so every FX opcode shows up in its graph."""

    def __init__(self):
        super().__init__()
        # `param` is registered (and drawn from the RNG) before `linear`.
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(in_features=4, out_features=5)

    def forward(self, x):
        """Return the top-3 row sums of ``relu(linear(x + linear.weight))``."""
        shifted = x + self.linear.weight
        activated = self.linear(shifted).relu()
        row_sums = torch.sum(activated, dim=-1)
        return torch.topk(row_sums, 3)

m = MyModule()
gm = torch.fx.symbolic_trace(m)

gm.graph.print_tabular()

opcode         name           target                                                   args                kwargs
-------------  -------------  -------------------------------------------------------  ------------------  -----------
placeholder    x              x                                                        ()                  {}
get_attr       linear_weight  linear.weight                                            ()                  {}
call_function  add            <built-in function add>                                  (x, linear_weight)  {}
call_module    linear         linear                                                   (add,)              {}
call_method    relu           relu                                                     (linear,)           {}
call_function  sum_1          <built-in method sum of type object at 0x7f55f11bd540>   (relu,)             {'dim': -1}
call_function  topk           <built-in method topk of type object at 0x7f55f11bd540>  (sum_1, 3)          {}

In [16]:
# Sample module
class M(torch.nn.Module):
    """Sample module whose forward pass is a single ``torch.add`` call."""

    def forward(self, x, y):
        # Must remain a `torch.add` call: that is the call_function target
        # the transform in this cell looks for.
        total = torch.add(x, y)
        return total

def transform(m: torch.nn.Module,
              tracer_class : type = fx.Tracer) -> torch.nn.Module:
    """Trace `m` and retarget every ``torch.add`` call to ``torch.mul``.

    Args:
        m: Module to trace.
        tracer_class: Tracer type to instantiate for tracing.

    Returns:
        An ``fx.GraphModule`` built from `m` and the rewritten Graph.
    """
    traced_graph : fx.Graph = tracer_class().trace(m)

    # The Graph is an ordered list of nodes. Collect the call_function
    # nodes whose target (the function being called) is torch.add ...
    add_calls = [
        candidate
        for candidate in traced_graph.nodes
        if candidate.op == 'call_function' and candidate.target == torch.add
    ]
    # ... and point each of them at torch.mul instead.
    for call in add_calls:
        call.target = torch.mul

    # Does some checks to make sure the Graph is well-formed.
    traced_graph.lint()

    return fx.GraphModule(m, traced_graph)

In [17]:
m = M()

In [18]:
m(1, 2)

tensor(3)

In [19]:
n = transform(m)

In [20]:
n(2, 3)

tensor(6)

In [7]:
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained('openai/whisper-tiny')

In [8]:
symbolic_traced : torch.fx.GraphModule = symbolic_trace(model)

TraceError: symbolically traced variables cannot be used as inputs to control flow

In [1]:
import torch  # This is all you need to use both PyTorch and TorchScript!
print(torch.__version__)
torch.manual_seed(191009)  # set the seed for reproducibility

2.0.1


<torch._C.Generator at 0x7f80688be110>

In [7]:
class MyDecisionGate(torch.nn.Module):
    """Gate with a data-dependent branch: pass `x` through or negate it."""

    def forward(self, x):
        # Positive element sum: return the input untouched.
        if x.sum() > 0:
            return x
        # Zero or negative sum: flip the sign.
        return -x

class MyCell(torch.nn.Module):
    """Cell computing ``tanh(dg(linear(x)) + h)`` with its own decision gate."""

    def __init__(self):
        super().__init__()
        self.dg = MyDecisionGate()
        self.linear = torch.nn.Linear(in_features=4, out_features=4)

    def forward(self, x, h):
        """Return the new hidden state twice, as an ``(output, state)`` pair."""
        gated = self.dg(self.linear(x))
        new_h = torch.tanh(gated + h)
        return new_h, new_h

In [8]:
my_cell = MyCell()
x = torch.rand(3, 4)
h = torch.rand(3, 4)
print(my_cell(x, h))

(tensor([[ 0.2155, -0.0853,  0.3186,  0.6389],
        [ 0.5300,  0.5265,  0.6738,  0.5114],
        [ 0.3993,  0.5134, -0.2223,  0.4750]], grad_fn=<TanhBackward0>), tensor([[ 0.2155, -0.0853,  0.3186,  0.6389],
        [ 0.5300,  0.5265,  0.6738,  0.5114],
        [ 0.3993,  0.5134, -0.2223,  0.4750]], grad_fn=<TanhBackward0>))


In [9]:
class MyCell(torch.nn.Module):
    """Branch-free cell computing ``tanh(linear(x) + h)``."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(in_features=4, out_features=4)

    def forward(self, x, h):
        """Return the new hidden state twice, as an ``(output, state)`` pair."""
        pre_activation = self.linear(x) + h
        new_h = torch.tanh(pre_activation)
        return new_h, new_h

my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell)
traced_cell(x, h)

MyCell(
  original_name=MyCell
  (linear): Linear(original_name=Linear)
)


(tensor([[ 0.0123,  0.0657,  0.7297,  0.6066],
         [-0.2156,  0.7301,  0.3347,  0.3326],
         [-0.2860, -0.1094,  0.4924,  0.5432]], grad_fn=<TanhBackward0>),
 tensor([[ 0.0123,  0.0657,  0.7297,  0.6066],
         [-0.2156,  0.7301,  0.3347,  0.3326],
         [-0.2860, -0.1094,  0.4924,  0.5432]], grad_fn=<TanhBackward0>))

In [10]:
print(traced_cell.graph)

graph(%self.1 : __torch__.MyCell,
      %x : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %h : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %linear : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="linear"](%self.1)
  %20 : Tensor = prim::CallMethod[name="forward"](%linear, %x)
  %11 : int = prim::Constant[value=1]() # /var/folders/by/rktr_w596p97pmt8_cbknvs80000gn/T/ipykernel_4017/260609686.py:7:0
  %12 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::add(%20, %h, %11) # /var/folders/by/rktr_w596p97pmt8_cbknvs80000gn/T/ipykernel_4017/260609686.py:7:0
  %13 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::tanh(%12) # /var/folders/by/rktr_w596p97pmt8_cbknvs80000gn/T/ipykernel_4017/260609686.py:7:0
  %14 : (Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu), Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu)) = prim::TupleConstruct(%13, %13)
  return (%14)



In [11]:
print(traced_cell.code)

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  linear = self.linear
  _0 = torch.tanh(torch.add((linear).forward(x, ), h))
  return (_0, _0)



In [12]:
print(my_cell(x, h))
print(traced_cell(x, h))

(tensor([[ 0.0123,  0.0657,  0.7297,  0.6066],
        [-0.2156,  0.7301,  0.3347,  0.3326],
        [-0.2860, -0.1094,  0.4924,  0.5432]], grad_fn=<TanhBackward0>), tensor([[ 0.0123,  0.0657,  0.7297,  0.6066],
        [-0.2156,  0.7301,  0.3347,  0.3326],
        [-0.2860, -0.1094,  0.4924,  0.5432]], grad_fn=<TanhBackward0>))
(tensor([[ 0.0123,  0.0657,  0.7297,  0.6066],
        [-0.2156,  0.7301,  0.3347,  0.3326],
        [-0.2860, -0.1094,  0.4924,  0.5432]],
       grad_fn=<DifferentiableGraphBackward>), tensor([[ 0.0123,  0.0657,  0.7297,  0.6066],
        [-0.2156,  0.7301,  0.3347,  0.3326],
        [-0.2860, -0.1094,  0.4924,  0.5432]],
       grad_fn=<DifferentiableGraphBackward>))


In [13]:
class MyDecisionGate(torch.nn.Module):
    """Return `x` unchanged when its elements sum to a positive value, otherwise `-x`."""

    def forward(self, x):
        # Data-dependent Python branch: which arm runs depends on the values
        # in `x`, not just on its shape or type.
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    """Cell computing ``tanh(dg(linear(x)) + h)`` with a caller-supplied gate `dg`."""

    def __init__(self, dg):
        # `dg` is kept as `self.dg` and called on the linear output in forward().
        super(MyCell, self).__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        # Linear layer -> gate -> add `h` -> tanh.
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        # The same tensor object is returned in both tuple positions.
        return new_h, new_h

my_cell = MyCell(MyDecisionGate())
traced_cell = torch.jit.trace(my_cell, (x, h))  # control flow erased

print(traced_cell.dg.code)
print(traced_cell.code)

def forward(self,
    argument_1: Tensor) -> Tensor:
  return torch.neg(argument_1)

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  dg = self.dg
  linear = self.linear
  _0 = torch.add((dg).forward((linear).forward(x, ), ), h)
  _1 = torch.tanh(_0)
  return (_1, _1)



  if x.sum() > 0:


In [14]:
scripted_gate = torch.jit.script(MyDecisionGate())

my_cell = MyCell(scripted_gate)
scripted_cell = torch.jit.script(my_cell)

print(scripted_gate.code)
print(scripted_cell.code)

def forward(self,
    x: Tensor) -> Tensor:
  if bool(torch.gt(torch.sum(x), 0)):
    _0 = x
  else:
    _0 = torch.neg(x)
  return _0

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  dg = self.dg
  linear = self.linear
  _0 = torch.add((dg).forward((linear).forward(x, ), ), h)
  new_h = torch.tanh(_0)
  return (new_h, new_h)



In [15]:
# New inputs
x, h = torch.rand(3, 4), torch.rand(3, 4)
print(scripted_cell(x, h))

(tensor([[0.1185, 0.8707, 0.5898, 0.7261],
        [0.4802, 0.2835, 0.4302, 0.9021],
        [0.4490, 0.9199, 0.6838, 0.6419]], grad_fn=<TanhBackward0>), tensor([[0.1185, 0.8707, 0.5898, 0.7261],
        [0.4802, 0.2835, 0.4302, 0.9021],
        [0.4490, 0.9199, 0.6838, 0.6419]], grad_fn=<TanhBackward0>))


In [1]:
#from transformers.utils.fx import symbolic_trace
from transformers import AutoConfig, AutoFeatureExtractor, AutoTokenizer, AutoProcessor, AutoModelForSpeechSeq2Seq
from datasets import load_dataset
import torch
from torch.fx import symbolic_trace
import torch.fx as fx

In [2]:
model_name_or_path = 'openai/whisper-small'
data_dir = 'mozilla-foundation/common_voice_11_0'

In [3]:
# load dataset
# Purpose of this cell: load the dataset, config, feature extractor,
# tokenizer, model and processor named by `data_dir` / `model_name_or_path`,
# then build the tensors for one example.
print('loading dataset from {}'.format(data_dir))

# Streaming mode: the "zh-CN" test split is requested with streaming=True.
raw_datasets = load_dataset(data_dir, "zh-CN", split="test", streaming=True)
# NOTE(review): `text_column_name` is not referenced again in the cells visible here.
text_column_name = 'sentence'


# model, tokenizer, feature extractor, processor

model_config = AutoConfig.from_pretrained(
    model_name_or_path,
    #cache_dir=args.cache_dir,
    #revision=args.model_revision,
    #use_auth_token=True if args.use_auth_token else None,
)

# Overwrite both fields on the loaded config with empty lists.
model_config.update({"forced_decoder_ids": [], "suppress_tokens": []})


feature_extractor = AutoFeatureExtractor.from_pretrained(
    model_name_or_path,
    #cache_dir=args.cache_dir,
    #revision=args.model_revision,
    #use_auth_token=True if args.use_auth_token else None,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    #cache_dir=args.cache_dir,
    #use_fast=model_args.use_fast_tokenizer,
    #revision=model_args.model_revision,
    #use_auth_token=True if model_args.use_auth_token else None,
)

# Configure the tokenizer prefix for Chinese transcription.
tokenizer.set_prefix_tokens(language='chinese', task='transcribe')


# The model is built with the modified `model_config` from above.
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_name_or_path,
    config=model_config,
    #cache_dir=args.cache_dir,
    #revision=args.model_revision,
    #use_auth_token=True if args.use_auth_token else None,
)

# Fail early: `decoder_start_token_id` is needed below to build `decoder_input_ids`.
if model.config.decoder_start_token_id is None:
    raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

processor = AutoProcessor.from_pretrained(model_name_or_path)
# Sets `forced_decoder_ids` on the model's config from the processor's prompt ids
# (presumably replacing the empty list set earlier -- confirm `model.config` is `model_config`).
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language='chinese', task='transcribe')


# Plain alias; no copy or transformation is applied.
dataset = raw_datasets

model.eval()


# Take the first example yielded by the dataset iterator.
sample = next(iter(dataset))


# Run the processor on the example's audio array, requesting PyTorch tensors
# ("pt") and an attention mask.
inputs = processor(
    sample['audio']["array"],
    sampling_rate=feature_extractor.sampling_rate,
    return_attention_mask=True,
    return_tensors="pt")

input_features = inputs.input_features
attention_mask = inputs.attention_mask
# One-element tensor holding the decoder start token, reshaped to a single row.
decoder_input_ids = torch.tensor([model.config.decoder_start_token_id]).reshape(1, -1)

loading dataset from mozilla-foundation/common_voice_11_0


Reading metadata...: 10581it [00:00, 19630.32it/s]


In [28]:
# This call passes a list of input names, which is the Hugging Face tracer's
# API. The `symbolic_trace` name imported at the top of this notebook is
# torch.fx's, whose second positional argument is `concrete_args`, so the HF
# tracer is imported here under its own name instead of relying on a
# previously executed (now commented-out) import.
from transformers.utils.fx import symbolic_trace as hf_symbolic_trace

# Expected to raise NotImplementedError: Whisper is not in the HF tracer's
# list of supported models (see the recorded output below).
hf_symbolic_trace(model, input_names=['input_features', 'attention_mask', 'decoder_input_ids'])

NotImplementedError: Model WhisperForConditionalGeneration is not supported yet, supported models: AlbertForMaskedLM, AlbertForMultipleChoice, AlbertForPreTraining, AlbertForQuestionAnswering, AlbertForSequenceClassification, AlbertForTokenClassification, AlbertModel, AltCLIPModel, AltCLIPTextModel, AltCLIPVisionModel, BartForCausalLM, BartForConditionalGeneration, BartForQuestionAnswering, BartForSequenceClassification, BartModel, BertForMaskedLM, BertForMultipleChoice, BertForNextSentencePrediction, BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification, BertForTokenClassification, BertLMHeadModel, BertModel, BlenderbotForCausalLM, BlenderbotForConditionalGeneration, BlenderbotModel, BlenderbotSmallForCausalLM, BlenderbotSmallForConditionalGeneration, BlenderbotSmallModel, BloomForCausalLM, BloomForQuestionAnswering, BloomForSequenceClassification, BloomForTokenClassification, BloomModel, CLIPModel, CLIPTextModel, CLIPTextModelWithProjection, CLIPVisionModel, CLIPVisionModelWithProjection, ConvNextBackbone, ConvNextForImageClassification, ConvNextModel, DebertaForMaskedLM, DebertaForQuestionAnswering, DebertaForSequenceClassification, DebertaForTokenClassification, DebertaModel, DebertaV2ForMaskedLM, DebertaV2ForMultipleChoice, DebertaV2ForQuestionAnswering, DebertaV2ForSequenceClassification, DebertaV2ForTokenClassification, DebertaV2Model, DistilBertForMaskedLM, DistilBertForMultipleChoice, DistilBertForQuestionAnswering, DistilBertForSequenceClassification, DistilBertForTokenClassification, DistilBertModel, DonutSwinModel, ElectraForCausalLM, ElectraForMaskedLM, ElectraForMultipleChoice, ElectraForPreTraining, ElectraForQuestionAnswering, ElectraForSequenceClassification, ElectraForTokenClassification, ElectraModel, GPT2DoubleHeadsModel, GPT2ForQuestionAnswering, GPT2ForSequenceClassification, GPT2ForTokenClassification, GPT2LMHeadModel, GPT2Model, GPTJForCausalLM, GPTJForQuestionAnswering, GPTJForSequenceClassification, GPTJModel, 
GPTNeoForCausalLM, GPTNeoForQuestionAnswering, GPTNeoForSequenceClassification, GPTNeoForTokenClassification, GPTNeoModel, GitVisionModel, HubertForCTC, HubertForSequenceClassification, HubertModel, LayoutLMForMaskedLM, LayoutLMForQuestionAnswering, LayoutLMForSequenceClassification, LayoutLMForTokenClassification, LayoutLMModel, LxmertForPreTraining, LxmertForQuestionAnswering, LxmertModel, M2M100ForConditionalGeneration, M2M100Model, MBartForCausalLM, MBartForConditionalGeneration, MBartForQuestionAnswering, MBartForSequenceClassification, MBartModel, MT5ForConditionalGeneration, MT5ForQuestionAnswering, MT5Model, MarianForCausalLM, MarianMTModel, MarianModel, MegatronBertForCausalLM, MegatronBertForMaskedLM, MegatronBertForMultipleChoice, MegatronBertForNextSentencePrediction, MegatronBertForPreTraining, MegatronBertForQuestionAnswering, MegatronBertForSequenceClassification, MegatronBertForTokenClassification, MegatronBertModel, MobileBertForMaskedLM, MobileBertForMultipleChoice, MobileBertForNextSentencePrediction, MobileBertForPreTraining, MobileBertForQuestionAnswering, MobileBertForSequenceClassification, MobileBertForTokenClassification, MobileBertModel, NezhaForMaskedLM, NezhaForMultipleChoice, NezhaForNextSentencePrediction, NezhaForPreTraining, NezhaForQuestionAnswering, NezhaForSequenceClassification, NezhaForTokenClassification, NezhaModel, OPTForCausalLM, OPTForQuestionAnswering, OPTForSequenceClassification, OPTModel, PLBartForCausalLM, PLBartForConditionalGeneration, PLBartForSequenceClassification, PLBartModel, PeftModelForCausalLM, PeftModelForSeq2SeqLM, PegasusForCausalLM, PegasusForConditionalGeneration, PegasusModel, ResNetBackbone, ResNetForImageClassification, ResNetModel, RobertaForCausalLM, RobertaForMaskedLM, RobertaForMultipleChoice, RobertaForQuestionAnswering, RobertaForSequenceClassification, RobertaForTokenClassification, RobertaModel, SegformerForImageClassification, SegformerForSemanticSegmentation, SegformerModel, 
Speech2Text2Decoder, Speech2Text2ForCausalLM, Speech2TextForConditionalGeneration, Speech2TextModel, SwinBackbone, SwinForImageClassification, SwinForMaskedImageModeling, SwinModel, T5ForConditionalGeneration, T5ForQuestionAnswering, T5Model, TrOCRDecoder, TrOCRForCausalLM, ViTForImageClassification, ViTForMaskedImageModeling, ViTModel, Wav2Vec2ForCTC, Wav2Vec2ForMaskedLM, Wav2Vec2ForPreTraining, Wav2Vec2ForSequenceClassification, Wav2Vec2Model, XGLMForCausalLM, XGLMModel

In [4]:
encoder = model.get_encoder()

In [5]:
symbolic_traced : torch.fx.GraphModule = symbolic_trace(encoder)

In [6]:
print(symbolic_traced.code)




def forward(self, input_features, attention_mask = None, head_mask = None, output_attentions = None, output_hidden_states = None, return_dict = None):
    conv1 = self.conv1(input_features);  input_features = None
    gelu = torch._C._nn.gelu(conv1);  conv1 = None
    conv2 = self.conv2(gelu);  gelu = None
    gelu_1 = torch._C._nn.gelu(conv2);  conv2 = None
    permute = gelu_1.permute(0, 2, 1);  gelu_1 = None
    embed_positions_weight = self.embed_positions.weight
    add = permute + embed_positions_weight;  permute = embed_positions_weight = None
    dropout = torch.nn.functional.dropout(add, p = 0.0, training = False, inplace = False);  add = None
    getitem = head_mask[0]
    layers_0_self_attn_layer_norm = getattr(self.layers, "0").self_attn_layer_norm(dropout)
    size = layers_0_self_attn_layer_norm.size()
    getitem_1 = size[0]
    getitem_2 = size[1]
    getitem_3 = size[2];  size = None
    layers_0_self_attn_q_proj = getattr(self.layers, "0").self_attn.q_proj(layers_0

In [4]:
symbolic_traced : torch.fx.GraphModule = symbolic_trace(model)

TraceError: symbolically traced variables cannot be used as inputs to control flow