In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib widget

from dataset import OCRDataset
from trainer import Trainer
from transformers import VisionEncoderDecoderModel, BeitImageProcessor, RobertaTokenizer
from transformers import BeitConfig, RobertaConfig, VisionEncoderDecoderConfig
from PIL import Image
import requests
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torchvision.io import read_image
from torch.utils.data import DataLoader
import torchvision
import os
import time
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Use CUDA when PyTorch can see a GPU, otherwise run everything on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
# Tokenizer for the same checkpoint the model's decoder is loaded from in the model cell;
# its special-token ids (CLS, PAD) are copied into the model config there.
tokenizer = RobertaTokenizer.from_pretrained('PlanTL-GOB-ES/roberta-base-bne')

In [None]:
# Build the OCR model: a BEiT image encoder feeding a RoBERTa text decoder, both
# initialised from pretrained checkpoints.
#
# Input images are (height, width) = IMAGE_SIZE. BeitConfig's default patch_size is 16,
# so both sides must stay multiples of 16 (32 x 1200 -> a 2 x 75 patch grid).
IMAGE_SIZE = (32, 1200)
ENCODER_CHECKPOINT = 'microsoft/beit-base-patch16-224'
DECODER_CHECKPOINT = 'PlanTL-GOB-ES/roberta-base-bne'

# NOTE(review): BeitConfig() starts from library defaults, not from the encoder
# checkpoint's own config. Any checkpoint weights the default architecture lacks are
# discarded at load time; check the load warnings and confirm this is intended.
encoder_config = BeitConfig()
encoder_config.image_size = IMAGE_SIZE

model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    ENCODER_CHECKPOINT,
    DECODER_CHECKPOINT,
    encoder_config=encoder_config,
)
# Decoding starts from the tokenizer's CLS token; padding uses its PAD token.
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id

# DataParallel splits each batch across the visible GPUs. The wrapper does not expose
# the wrapped model's methods, so helpers such as `generate` must be called on
# `model.module`.
model = nn.DataParallel(model)
model.to(device)

In [None]:


# Training hyper-parameters for this run.
BATCH_SIZE = 10
LEARNING_RATE = 5e-5
NB_EPOCHS = 20

train_dataset = OCRDataset(tokenizer=tokenizer, device=device, test=False)
test_dataset = OCRDataset(tokenizer=tokenizer, device=device, test=True)
# Only the training split needs a random order; evaluation runs over a fixed order.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Trainer is given the model and optimizer only, no loss function. Presumably it uses
# the loss returned by the model's forward pass; TODO confirm in trainer.py.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=LEARNING_RATE)
trainer = Trainer(model=model, optimizer=optimizer, device=device, nb_epochs=NB_EPOCHS)
trainer.train(train_dataloader=train_dataloader, test_dataloader=test_dataloader)

In [None]:
# Sanity check: run generation on one training sample and print it next to that
# sample's ground-truth transcription.
image, label, mask = next(iter(train_dataloader))

# nn.DataParallel does not expose the wrapped model's `generate`, so unwrap it first.
generator = model.module if isinstance(model, nn.DataParallel) else model
# Inference mode: training may have left dropout enabled.
generator.eval()

sample_image = image[0:1].to(device)  # slice (not index) to keep the batch dimension
generated_ids = generator.generate(sample_image)
generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

# Ground truth for the same sample: keep only the leading positions its mask marks
# as real tokens. Assumes mask is (batch, seq_len) like label and labels are
# right-padded; TODO confirm in OCRDataset.
n_label_tokens = int((mask[0] == 1).sum())
label_ids = label[0:1, :n_label_tokens]
# Labels may use -100 as the loss ignore index; the tokenizer cannot decode negative ids.
label_ids = label_ids.masked_fill(label_ids == -100, tokenizer.pad_token_id)
reference_text = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

print("reference:", reference_text)
print("generated:", generated_text)