In [None]:
# Clone the img-latex repository that the rest of this notebook works inside.
!git clone https://github.com/vishalpathak24/img-latex.git

In [None]:
!pip install torch

In [None]:
!pip install 'transformers[torch]'

In [None]:
!pip install torchserve torch-model-archiver torch-workflow-archiver

In [None]:
!pip install captum

In [4]:
# Move into the cloned repository. The path is relative to the current working
# directory, so the result depends on where the kernel currently is.
%cd img-latex

/home/studio-lab-user/sagemaker-studiolab-notebooks/img-latex/img-latex


In [3]:
# Show the current working directory.
!pwd

/home/studio-lab-user/sagemaker-studiolab-notebooks/img-latex


In [16]:
from transformers import VisionEncoderDecoderModel
from transformers.models.nougat import NougatTokenizerFast
from nougat_latex import NougatLaTexProcessor
from transformers import AutoTokenizer
import torch
import numpy as np

# Model Basic Usage

In [None]:
# Checkpoint identifier shared by the model, tokenizer and image processor.
model_name = "Norm/nougat-latex-base"

# Run on the GPU when one is visible to PyTorch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the encoder-decoder weights, then place them on the selected device.
model = VisionEncoderDecoderModel.from_pretrained(model_name)
model = model.to(device)

# Text tokenizer and image preprocessor belonging to the same checkpoint.
tokenizer = NougatTokenizerFast.from_pretrained(model_name)
latex_processor = NougatLaTexProcessor.from_pretrained(model_name)

In [None]:
from PIL import Image

# Open the sample formula image and make sure it ends up in RGB mode:
# images already in RGB are used as-is, anything else is converted.
image = Image.open("../sample-images/lt-2.jpg")
image = image if image.mode == "RGB" else image.convert('RGB')

In [None]:
# Preprocess the image into a PyTorch tensor for the encoder.
processor_output = latex_processor(image, return_tensors="pt")
pixel_values = processor_output.pixel_values

# The decoder is primed with the BOS token only; no extra special tokens are added.
bos_encoding = tokenizer(
    tokenizer.bos_token,
    add_special_tokens=False,
    return_tensors="pt",
)
decoder_input_ids = bos_encoding.input_ids

In [None]:
# Display the pixel tensor produced by the processor in the previous cell.
pixel_values

In [None]:
# Beam-search settings for decoding; collected up front so the call itself stays short.
generation_kwargs = dict(
    decoder_input_ids=decoder_input_ids.to(device),
    max_length=model.decoder.config.max_length,
    early_stopping=True,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
    use_cache=True,
    num_beams=5,
    # Never emit the unknown token.
    bad_words_ids=[[tokenizer.unk_token_id]],
    return_dict_in_generate=True,
)

# Inference only: no gradients are needed.
with torch.no_grad():
    outputs = model.generate(pixel_values.to(device), **generation_kwargs)

In [None]:
# Decode the first generated sequence back to text.
decoded_batch = tokenizer.batch_decode(outputs.sequences)
sequence = decoded_batch[0]

# Strip the special-token markers (EOS, PAD, BOS) so only the LaTeX string remains.
for special_token in (tokenizer.eos_token, tokenizer.pad_token, tokenizer.bos_token):
    sequence = sequence.replace(special_token, "")

print(sequence)

# Testing basic inference

In [11]:
# initialize
model_dir = 'model_dir'
model_name = "Norm/nougat-latex-base"
model = VisionEncoderDecoderModel.from_pretrained(model_dir)
latex_processor = NougatLaTexProcessor.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
decoder_input_ids = tokenizer(
    tokenizer.bos_token, add_special_tokens=False, return_tensors="pt"
).input_ids

Config of the encoder: <class 'transformers.models.donut.modeling_donut_swin.DonutSwinModel'> is overwritten by shared encoder config: DonutSwinConfig {
  "attention_probs_dropout_prob": 0.0,
  "depths": [
    2,
    2,
    14,
    2
  ],
  "drop_path_rate": 0.1,
  "embed_dim": 128,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 1024,
  "image_size": [
    224,
    560
  ],
  "initializer_range": 0.02,
  "layer_norm_eps": 1e-05,
  "mlp_ratio": 4.0,
  "model_type": "donut-swin",
  "num_channels": 3,
  "num_heads": [
    4,
    8,
    16,
    32
  ],
  "num_layers": 4,
  "patch_size": 4,
  "qkv_bias": true,
  "transformers_version": "4.47.0",
  "use_absolute_embeddings": false,
  "window_size": 7
}

Config of the decoder: <class 'transformers.models.mbart.modeling_mbart.MBartForCausalLM'> is overwritten by shared decoder config: MBartConfig {
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "add_cross_attention": true,
  "add_final_layer_norm": true

In [12]:
import os
import io
import PIL.Image as Image

from array import array

def readimage(path):
    """Read a file and return its raw contents as a mutable byte buffer.

    Args:
        path: Location of the file to read.

    Returns:
        bytearray: The complete contents of the file.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    # Binary mode: the bytes are handed on untouched (no decoding, no newline translation).
    with open(path, "rb") as f:
        return bytearray(f.read())
    
# Raw bytes of the sample image, kept in memory for the cells below.
image_bytes = readimage('../sample-images/lt-2.jpg')

In [29]:
# Decode the image from the in-memory bytes read above.
image = Image.open(io.BytesIO(image_bytes))
# Apply the same RGB normalisation as the basic-usage cell so that grayscale,
# palette or RGBA files reach the processor with three channels as well.
if image.mode != "RGB":
    image = image.convert("RGB")
px_val = latex_processor(image).pixel_values

In [30]:
# Look at the first image's pixel array from the processor output.
px_val[0]

array([[[-2.117904 , -2.117904 , -2.117904 , ..., -2.117904 ,
         -2.117904 , -2.117904 ],
        [-2.117904 , -2.117904 , -2.117904 , ..., -2.117904 ,
         -2.117904 , -2.117904 ],
        [-2.117904 , -2.117904 , -2.117904 , ..., -2.117904 ,
         -2.117904 , -2.117904 ],
        ...,
        [-2.117904 , -2.117904 , -2.117904 , ..., -2.117904 ,
         -2.117904 , -2.117904 ],
        [-2.117904 , -2.117904 , -2.117904 , ..., -2.117904 ,
         -2.117904 , -2.117904 ],
        [-2.117904 , -2.117904 , -2.117904 , ..., -2.117904 ,
         -2.117904 , -2.117904 ]],

       [[-2.0357141, -2.0357141, -2.0357141, ..., -2.0357141,
         -2.0357141, -2.0357141],
        [-2.0357141, -2.0357141, -2.0357141, ..., -2.0357141,
         -2.0357141, -2.0357141],
        [-2.0357141, -2.0357141, -2.0357141, ..., -2.0357141,
         -2.0357141, -2.0357141],
        ...,
        [-2.0357141, -2.0357141, -2.0357141, ..., -2.0357141,
         -2.0357141, -2.0357141],
        [-2.

In [33]:
# Duplicate the single preprocessed image to form a batch of two, then wrap it as a tensor.
pixel_batch = np.array([px_val[0]] * 2)
input_data = torch.tensor(pixel_batch)


In [72]:
# Check the dimensions of the stacked input batch.
input_data.shape

torch.Size([2, 3, 224, 560])

In [76]:
# First row of the decoder start ids built from the tokenizer's BOS token.
decoder_input_ids[0]

tensor([0])

In [77]:
# One start-id row per batch item (two items, matching the image batch built earlier).
# NOTE(review): this round-trips the tensor through NumPy, which presumably only
# works while decoder_input_ids lives on the CPU — confirm before moving to GPU.
start_rows = np.array([decoder_input_ids[0]] * 2)
decoder_strt_inputs = torch.tensor(start_rows)

In [78]:
# Display the batched decoder start ids.
decoder_strt_inputs

tensor([[0],
        [0]])

In [85]:
# Beam-search settings for the batched run, gathered in one mapping.
batch_generation_kwargs = {
    "decoder_input_ids": decoder_strt_inputs,
    "max_length": model.decoder.config.max_length,
    "early_stopping": True,
    "pad_token_id": tokenizer.pad_token_id,
    "eos_token_id": tokenizer.eos_token_id,
    "use_cache": True,
    "num_beams": 5,
    # Never emit the unknown token.
    "bad_words_ids": [[tokenizer.unk_token_id]],
    "return_dict_in_generate": True,
}

# Generate token sequences for every image in the batch.
outputs = model.generate(input_data, **batch_generation_kwargs)

In [86]:
# Token ids of the generated sequences, one row per batch item.
outputs.sequences

tensor([[    0,    82, 13727,   113,    82,   707,   113,    39,   115,   113,
            41,   115,   115,    82,  1459,    82,   867,    30,    44,    82,
          1459,   113,    82,   707,   113,    42,   115,   113,    41,   115,
           115,    82,   747,    31,     2],
        [    0,    82, 13727,   113,    82,   707,   113,    39,   115,   113,
            41,   115,   115,    82,  1459,    82,   867,    30,    44,    82,
          1459,   113,    82,   707,   113,    42,   115,   113,    41,   115,
           115,    82,   747,    31,     2]])

In [87]:
# Special-token markers to strip from each decoded string, in removal order.
special_tokens = (tokenizer.eos_token, tokenizer.pad_token, tokenizer.bos_token)

# Decode every generated row and clean it, keeping one string per batch item.
sequence = []
for decoded_text in tokenizer.batch_decode(outputs.sequences):
    for special_token in special_tokens:
        decoded_text = decoded_text.replace(special_token, "")
    sequence.append(decoded_text)

print(sequence)

['\\textstyle{\\frac{1}{3}}\\times\\left(6\\times{\\frac{4}{3}}\\right)', '\\textstyle{\\frac{1}{3}}\\times\\left(6\\times{\\frac{4}{3}}\\right)']


# Creating the MAR archive for inference

In [None]:
import shutil

# Zip the contents of the nougat_latex directory into nougat.zip; the archiver
# cell below ships that zip as an extra file.
shutil.make_archive(base_name='nougat', format='zip', root_dir='nougat_latex')

In [None]:
# Persist the model into ./model_dir; the archiver cell and the cell that reloads
# the model from `model_dir` both read from this directory.
model.save_pretrained("model_dir")

In [91]:
# Package the saved weights, the custom handler, the zipped nougat_latex package and
# the model/generation configs into an `img-latex` archive under ../../model-store.
# NOTE(review): the export directory presumably has to exist beforehand — confirm.
!torch-model-archiver \
    --model-name img-latex \
    --version 1.0 \
    --serialized-file model_dir/model.safetensors \
    --handler handler.py \
    --extra-files "nougat.zip,model_dir/config.json,model_dir/generation_config.json" \
    --export-path ../../model-store --force

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




In [None]:
# # Saving Config
# model.config.save_pretrained(".")
# # Saving Model
# torch.save(model.state_dict(), 'model.pt')

In [None]:
# !torch-model-archiver --model-name img-latex \
# --version 1.0 --model-file model.py \
# --serialized-file model.pt \
# --handler handler.py \
# --extra-files "nougat.zip, config.json"

In [None]:
!mkdir ../../model-store/
!mv img-latex.mar ../../model-store/

In [None]:
!pip install torchserve

In [89]:
!torchserve --model-store model-store/ --models img-latex=img-latex.mar --ts-config img-latex/config.properties &

OSError: Background processes not supported.

In [90]:
# Shut down the running TorchServe instance, if any.
!torchserve --stop

TorchServe is not currently running.


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [None]:
# Loading the model from a saved state dict

In [None]:
# Dump the model object's attribute dictionary for inspection.
model.__dict__

In [None]:
# Path of the state-dict file to load.
# NOTE(review): the only cell that writes model.pt is commented out above — confirm the file exists.
model_pt_path = 'model.pt'
# Echo the device selected earlier in the notebook.
device

In [None]:
# weights_only=True restricts unpickling to tensors and plain containers, so loading
# the checkpoint cannot execute arbitrary code embedded in the file.
state_dict = torch.load(model_pt_path, map_location=device, weights_only=True)

In [None]:
# Inspect the loaded state dict.
state_dict

In [None]:
# Inspect the loaded state dict (repeat of the previous cell).
state_dict

In [None]:
# A bare VisionEncoderDecoderModel() raises because it needs either a config or an
# encoder/decoder pair. Build the empty architecture from the already-loaded model's
# config, then fill it with the saved weights.
VisionEncoderDecoderModel(config=model.config).load_state_dict(state_dict)

In [None]:
# Load the saved weights into the already-instantiated model.
model.load_state_dict(state_dict)

In [None]:
# Drop any previously bound handler class / instance so the cells below start clean.
# pop() with a default keeps this cell runnable on a fresh kernel, where a bare `del`
# would raise NameError because neither name exists yet.
for _stale_name in ("LatexHandler", "lh"):
    globals().pop(_stale_name, None)

In [None]:
from handler import LatexHandler

In [None]:
class Context:
    """Minimal stand-in for the context object passed to the handler's ``initialize``.

    It exposes only ``system_properties``, carrying the model directory entry.
    """

    # Shared class-level mapping; instances read it through normal attribute lookup.
    system_properties = dict(model_dir='model_dir')

In [None]:
# Instantiate the minimal context stub passed to the handler's initialize() below.
context = Context()

In [None]:
# Create the handler instance to exercise directly from the notebook.
lh = LatexHandler()

In [None]:
# Initialize the handler with the stub context.
lh.initialize(context)

In [None]:
# Single-item request list: one dict whose 'data' field holds the raw image bytes read earlier.
data = [{'data':image_bytes}]

In [None]:
# Run only the handler's preprocessing step on the request list.
lh.preprocess(data)

In [None]:
# Pull the raw bytes back out of the first request entry.
b = data[0]['data']

In [None]:
# Build a small 1-D integer array (numpy is re-imported here; it is also imported near the top).
import numpy as np
x = np.array([1,2,3])

In [None]:
# Display the array.
x