In [1]:
import torch
from torch import nn
import pybuda
from PIL import Image
import torchvision
from torch.utils.data import Dataset, DataLoader

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Fetch the pretrained ASL classifier together with its image preprocessor.
from transformers import AutoImageProcessor, AutoModelForImageClassification

# One checkpoint id feeds both loaders so the preprocessor and the weights
# always come from the same Hugging Face repository.
MODEL_ID = "dyllanesl/ASL_Classifier"

processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)



In [3]:
# Create the handle for the Tenstorrent device that the model is placed on.
tt0 = pybuda.TTDevice(
    name="tt_device_0",  # free-form label used for tracking; it appears in PyBuda log lines
    arch=pybuda.BackendDevice.Grayskull  # target architecture is hard-coded to Grayskull
)

In [4]:
# Wrap the Hugging Face PyTorch model in a PyBuda module.
pybuda_module = pybuda.PyTorchModule(
    name = "asl_model",  # tracking name; it prefixes queue names in the logs (e.g. "asl_model.output_add_636")
    module=model  # the torch model that is being targeted for compilation
)

# Place the wrapped module on the `tt0` device.
tt0.place_module(module=pybuda_module)

In [5]:
from datasets import load_dataset

# Load the ASL dataset from Hugging Face and keep only its 'train' split.
dataset = load_dataset('raulit04/ASL_Dataset1')['train']

Using custom data configuration raulit04--ASL_Dataset1-d033ce9363c88848
Reusing dataset parquet (/home/user/.cache/huggingface/datasets/raulit04___parquet/raulit04--ASL_Dataset1-d033ce9363c88848/0.0.0/7328ef7ee03eaf3f86ae40594d46a1cec86161704e02dd19f232d81eee72ade8)
100%|██████████| 1/1 [00:00<00:00, 571.43it/s]


In [6]:
# Build label <-> index lookup tables from the labels present in the dataset.
# NOTE(review): indices follow the order returned by dataset.unique('label');
# id_to_label is later indexed with the model's argmax, so this order is
# assumed to match the model's output class order -- TODO confirm.
label_list = dataset.unique('label')

label_to_id = dict(zip(label_list, range(len(label_list))))
id_to_label = dict(enumerate(label_to_id))

In [7]:
# STEP 1 : Set PyBuda configuration parameters
import os  # NOTE(review): `os` is not used by any cell visible in this notebook
# The global compiler config is obtained through an underscore-prefixed
# (private) accessor; the attributes set below mutate that shared object.
compiler_cfg = pybuda.config._get_global_compiler_config()
compiler_cfg.balancer_policy = "Ribbon"
# NOTE(review): the recorded inference output further down is bfloat16 even
# though Float32 is requested here -- confirm this override has the intended effect.
compiler_cfg.default_df_override = pybuda.DataFormat.Float32


In [8]:
def setup_image(image: "Image.Image", image_processor=None):
    """Convert an image to RGB and preprocess it into a model input tensor.

    Args:
        image: Image object exposing ``convert(mode)`` (a PIL image in this
            notebook).
        image_processor: Optional callable accepting ``images=`` and
            ``return_tensors=`` and returning a mapping with a
            ``'pixel_values'`` entry. Defaults to the notebook-level
            ``processor``.

    Returns:
        The ``'pixel_values'`` entry produced by the processor.
    """
    if image_processor is None:
        image_processor = processor
    # `convert` returns a new image rather than modifying in place, so the
    # result must be captured; otherwise non-RGB inputs reach the processor
    # unconverted.
    rgb_image = image.convert("RGB")
    processed_tensor = image_processor(images=rgb_image, return_tensors='pt')
    return processed_tensor['pixel_values']

In [None]:
# output = pybuda_module.run(input_tensor)  # executes compilation (if first time) + runtime
# print('output: ', output)

In [18]:
def get_prediction_given_tensor(input_tensor):
    """Run one inference pass on the device and return the predicted label.

    Args:
        input_tensor: Preprocessed pixel values for one image, as produced
            by `setup_image`.

    Returns:
        The `id_to_label` entry for the index with the highest score.
    """
    tt0.push_to_inputs((input_tensor,))
    output_q = pybuda.run_inference()
    output = output_q.get()
    output_tensor = output[0].value()
    # `.item()` only works on a single-element result, so this expects the
    # argmax over the last axis to yield exactly one index (one image per call).
    pred = output_tensor.argmax(-1).item()
    return id_to_label[pred]

In [None]:
# Evaluate the model on every example and report the fraction predicted correctly.
dataset_length = len(dataset)
correct = 0
# Iterate row by row instead of materialising the whole 'image' column
# (previously done twice: once for len() and once for zip()).
for i, example in enumerate(dataset):
    image, label = example['image'], example['label']
    if i == 0:
        print('image: ')
        display(image)
    readied_tensor = setup_image(image)
    guessed_label = get_prediction_given_tensor(readied_tensor)

    # Only the first few predictions are echoed to keep the output short.
    if i < 5:
        print('actual label: ', label, ' guessed label: ', guessed_label)
    if i == 5:
        print("you get the deal. I'll just print out the accuracy")
    if label == guessed_label:
        correct += 1

# An empty dataset would otherwise raise ZeroDivisionError.
if dataset_length:
    print('accuracy: ', correct/dataset_length)
else:
    print('accuracy: ', 'n/a (dataset is empty)')

In [11]:

# Single-image check: preprocess the first dataset image, push it to the
# device, run inference, and read the result back from the output queue.
image_tensor = setup_image(dataset['image'][0])
tt0.push_to_inputs((image_tensor,))
output_q = pybuda.run_inference()
output = output_q.get() # get last value from output queue

2024-08-02 23:59:43.473 | DEBUG    | pybuda.run.impl:_run_forward:644 - Running concurrent device forward: TTDevice 'tt_device_0'


[32m2024-08-02 23:59:43.479[0m | [1m[38;2;100;149;237mINFO    [0m | [36mRuntime        [0m - Running program 'run_fwd_0' with params [("$p_loop_count", "1")]


2024-08-02 23:59:43.478 | DEBUG    | pybuda.device:run_next_command:429 - Received RUN_FORWARD command on TTDevice 'tt_device_0' / 280610
2024-08-02 23:59:43.478 | DEBUG    | pybuda.ttdevice:forward:906 - Starting forward on TTDevice 'tt_device_0'
2024-08-02 23:59:43.478 | DEBUG    | pybuda.backend:feeder_thread_main:171 - Run feeder thread cmd: fwd
2024-08-02 23:59:43.479 | DEBUG    | pybuda.backend:read_queues:345 - Reading output queue asl_model.output_add_636
2024-08-02 23:59:43.481 | DEBUG    | pybuda.device_connector:pusher_thread_main:163 - Pusher thread pushing tensors
2024-08-02 23:59:43.481 | DEBUG    | pybuda.backend:push_to_queues:452 - Pushing to queue pixel_values
2024-08-02 23:59:43.541 | DEBUG    | pybuda.backend:read_queues:415 - Done reading queues
2024-08-02 23:59:43.542 | DEBUG    | pybuda.backend:pop_queues:421 - Popping from queue asl_model.output_add_636


In [17]:
# Display the tensor held by the first entry read from the output queue.
output[0].value()

tensor([[12.0000,  0.2148,  0.5625,  0.4688,  2.8906,  0.6445, -0.5430, -0.4180,
         -0.3945,  0.1621, -1.3828,  1.3281, -0.7812, -1.0391, -1.9688, -0.1777,
         -1.5234,  0.8789, -2.1406, -0.0403, -0.7500, -1.5703, -1.4922, -0.3027,
         -1.5781, -2.1875]], dtype=torch.bfloat16, requires_grad=True)

In [16]:
# Index of the largest value; argmax() without a dim works on the flattened tensor.
output[0].value().argmax().item()

0

In [None]:
# Shut PyBuda down once inference is finished.
pybuda.shutdown()

In [None]:
# Detach the placed modules from the device.
# NOTE(review): this cell sits after pybuda.shutdown(); confirm that calling
# remove_modules() is valid at that point, or run it before shutting down.
tt0.remove_modules()

In [None]:
import pybuda
import torch


# Sample PyTorch module
class PyTorchTestModule(torch.nn.Module):
    """Two learnable 32x32 matrices, each multiplied with one activation input."""

    def __init__(self):
        super().__init__()
        # Parameters are created in this order (weights1, then weights2) so
        # that seeded initialisation stays the same.
        self.weights1 = torch.nn.Parameter(torch.rand(32, 32))
        self.weights2 = torch.nn.Parameter(torch.rand(32, 32))

    def forward(self, act1, act2):
        """Return ``(act1 @ weights1 + act2 @ weights2, act1 @ weights1)``."""
        first_product = act1 @ self.weights1
        second_product = act2 @ self.weights2
        return first_product + second_product, first_product


def test_module_direct_pytorch():
    """Smoke-test PyBuda by running one inference pass on a tiny torch module.

    Builds two random (4, 32, 32) inputs, wraps `PyTorchTestModule` in a
    `pybuda.PyTorchModule`, runs it once, and prints the result followed by
    a success message.
    """
    act1, act2 = (torch.rand(4, 32, 32) for _ in range(2))
    # The wrapper converts the torch module to PyBuda before running it.
    wrapped = pybuda.PyTorchModule("direct_pt", PyTorchTestModule())
    result = wrapped.run(act1, act2)
    print(result)
    print("PyBuda installation was a success!")


# Run the installation smoke test when this code executes as the main module.
if __name__ == "__main__":
    test_module_direct_pytorch()