In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib widget

from dataset import OCRDataset
from trainer import Trainer
from transformers import VisionEncoderDecoderModel, BeitImageProcessor, RobertaTokenizer
from transformers import BeitConfig, RobertaConfig, VisionEncoderDecoderConfig
from PIL import Image
import requests
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torchvision.io import read_image
from torch.utils.data import DataLoader
import torchvision
import os
import time
import matplotlib.pyplot as plt
import numpy as np
from peft import LoraConfig, get_peft_model

In [None]:
# Load and plot the losses
# Id suffix of the saved loss histories (files models/{train,test}_losses_<id>.pt).
CHECKPOINT_ID = 59

# map_location="cpu": if the histories contain CUDA tensors, they still load on
# a machine without that GPU.
train_losses = torch.load(f"models/train_losses_{CHECKPOINT_ID}.pt", map_location="cpu")
test_losses = torch.load(f"models/test_losses_{CHECKPOINT_ID}.pt", map_location="cpu")

fig, ax = plt.subplots()
ax.plot(train_losses, label="train")
ax.plot(test_losses, label="test")
# NOTE(review): the x axis is the position in the saved history (epoch or step?)
# — confirm against the training loop and add an xlabel.
ax.set(title=f"Loss history (checkpoint {CHECKPOINT_ID})", ylabel="loss")
ax.legend();

In [None]:
# Load the model
model_path = "models/model_59.pt"

# Prefer GPU 1 (the originally chosen device), but fall back to GPU 0 on
# single-GPU machines instead of failing with an invalid device ordinal.
PREFERRED_GPU = 1
if torch.cuda.device_count() > PREFERRED_GPU:
    device = torch.device(f'cuda:{PREFERRED_GPU}')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

tokenizer = RobertaTokenizer.from_pretrained('PlanTL-GOB-ES/roberta-base-bne')

# Only the vision encoder of the TrOCR checkpoint is reused; it is written to
# disk so that from_encoder_decoder_pretrained can load it by path below.
model_trocr = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-small-stage1")
encoder = model_trocr.encoder
encoder.save_pretrained("pretrained_encoder")
encoder_config = encoder.config
# Override the encoder input resolution to (height, width) = (64, 2304).
# NOTE(review): must match the image size OCRDataset produces and the size used
# during training — confirm.
encoder_config.image_size = (64, 2304)
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "pretrained_encoder",
    'PlanTL-GOB-ES/roberta-base-bne',
    encoder_config=encoder_config,
)

# The LoRA wrapping must mirror the training setup exactly: the (strict)
# load_state_dict call below fails on any missing/unexpected key.
config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=["query", "value"],
    lora_dropout=0.1,
    bias="none",
    # NOTE(review): there does not appear to be a "classifier" module in this
    # encoder-decoder, so this is likely a no-op; kept to match training.
    modules_to_save=["classifier"],
)
model = get_peft_model(model, config)

# Special tokens and generation defaults read by model.generate().
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model.config.vocab_size = model.config.decoder.vocab_size
model.config.eos_token_id = tokenizer.sep_token_id
model.config.max_length = 64
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4

# map_location loads the tensors straight onto the selected device, even if the
# checkpoint was saved from a different GPU index.
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
# Nothing above switches the model to eval mode explicitly; without this,
# dropout layers (e.g. lora_dropout=0.1) could stay active during generation
# and make predictions non-deterministic.
model.eval()

# Load the dataset
dataset = OCRDataset(characters_mode="typed", tokenizer=tokenizer, device=device, test=True)
# shuffle=True: every next(iter(dataloader)) yields a random test sample.
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)




In [None]:
# Draw one random test sample (the dataloader shuffles) and compare the
# ground-truth transcription with the model's prediction.
image, label_ids, mask = next(iter(dataloader))
image = image.to(device)

# Number of real label tokens = positions where the mask equals 1.
# NOTE(review): the slice below assumes those tokens form a prefix of the
# sequence (padding only at the end) — confirm against OCRDataset.
nb_ids_label = int((mask[0] == 1).sum())
label_text = tokenizer.batch_decode(label_ids[:, :nb_ids_label], skip_special_tokens=True)

# Make inference
with torch.no_grad():
    generated_ids = model.generate(image)
generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

print("True label: ", label_text[0])
print("Generated text: ", generated_text[0])

fig, ax = plt.subplots()
# Drop the batch axis and move the channel axis last for imshow; assumes a
# (1, C, H, W) batch. NOTE(review): if the dataset normalizes pixels outside
# [0, 1], imshow clips them — confirm the value range.
ax.imshow(image.squeeze(0).permute(1, 2, 0).detach().cpu().numpy())
ax.set_title("Input image")
ax.axis("off")
plt.show()