In [1]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import torchvision.transforms as transforms
import pickle
from transformers import AutoImageProcessor
import torch
class CompVisualDataset(Dataset):
    def __init__(self, pickle_file,image_path_root):
        super().__init__()
        with open(pickle_file, 'rb') as file:
            data = pickle.load(file)
        self.components = []
        self.image_paths = []
        for k in data.keys():
            self.components.extend(data[k]['components'])
            self.image_paths.extend([ f"{os.path.join(image_path_root,k)}_page-{comp['page']}.png" for comp in data[k]['components']])


    def __len__(self):
        return len(self.components)

    def __getitem__(self, index):
        comp = self.components[index]
        img = Image.open(self.image_paths[index]).convert("RGB")
        bbox = comp['bbox']
        cropped_img = transforms.functional.crop(img,top=bbox[1],left=bbox[0],height=bbox[3],width=bbox[2])
        return cropped_img, comp['object_id']
    
def collate_fn(batch):
    imgs = [e[0] for e in batch]
    object_ids = [e[1] for e in batch]
    return imgs, object_ids

In [2]:
train_dataset = CompVisualDataset('train_data.pkl','train/train')
train_dataloader = DataLoader(train_dataset,batch_size=10, shuffle=True, collate_fn= collate_fn)

In [3]:
from transformers import  DonutSwinModel
model = DonutSwinModel.from_pretrained("./donut_encoder")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

DonutSwinModel(
  (embeddings): DonutSwinEmbeddings(
    (patch_embeddings): DonutSwinPatchEmbeddings(
      (projection): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
    )
    (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): DonutSwinEncoder(
    (layers): ModuleList(
      (0): DonutSwinStage(
        (blocks): ModuleList(
          (0): DonutSwinLayer(
            (layernorm_before): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
            (attention): DonutSwinAttention(
              (self): DonutSwinSelfAttention(
                (query): Linear(in_features=128, out_features=128, bias=True)
                (key): Linear(in_features=128, out_features=128, bias=True)
                (value): Linear(in_features=128, out_features=128, bias=True)
                (dropout): Dropout(p=0.0, inplace=False)
              )
              (output): DonutSwinSelfOutput(
                (dense): Linear(in

In [4]:
image_processor = AutoImageProcessor.from_pretrained("nielsr/donut-base")

In [5]:
from tqdm import tqdm
extracted_feat = {}
with torch.no_grad():
    for imgs, object_ids in tqdm(train_dataloader):
        image_inputs = image_processor(imgs, return_tensors="pt").to(device)
        outputs = model(**image_inputs)
        features = outputs.pooler_output.detach().cpu()
        extracted_feat.update(zip(object_ids,list(features)))

  0%|          | 0/4333 [00:29<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 1.46 GiB. GPU 0 has a total capacity of 1.95 GiB of which 1.07 GiB is free. Including non-PyTorch memory, this process has 888.00 MiB memory in use. Of the allocated memory 847.02 MiB is allocated by PyTorch, and 8.98 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
feature_path = "train_visual_feat.pth"
torch.save(extracted_feat,feature_path)