In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
%cd /content/drive/MyDrive/VRD-IU

/content/drive/MyDrive/VRD-IU


In [9]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import torchvision.transforms as transforms
import pickle
from transformers import AutoImageProcessor
import torch
class CompVisualDataset(Dataset):
    def __init__(self, pickle_file,image_path_root):
        super().__init__()
        with open(pickle_file, 'rb') as file:
            data = pickle.load(file)
        self.components = []
        self.root_path = image_path_root
        for k in data.keys():
            self.components.extend(data[k]['components'])

    def __len__(self):
        return len(self.components)

    def __getitem__(self, index):
        comp = self.components[index]
        img = Image.open(os.path.join(self.root_path, f"{comp['object_id']}.png")).convert("RGB")
        return img, comp['object_id']

def collate_fn(batch):
    imgs = [e[0] for e in batch]
    object_ids = [e[1] for e in batch]
    return imgs, object_ids

In [10]:
train_dataset = CompVisualDataset('train_data.pkl','train_components')
train_dataloader = DataLoader(train_dataset,batch_size=32, collate_fn= collate_fn, num_workers=2)
val_dataset = CompVisualDataset('val_data.pkl','val_components')
val_dataloader = DataLoader(val_dataset,batch_size=32, collate_fn= collate_fn, num_workers=2)

In [11]:
from transformers import  DonutSwinModel
model = DonutSwinModel.from_pretrained("./donut_encoder")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using: {device}")
model.to(device)
model.eval()

Using: cuda


DonutSwinModel(
  (embeddings): DonutSwinEmbeddings(
    (patch_embeddings): DonutSwinPatchEmbeddings(
      (projection): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
    )
    (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): DonutSwinEncoder(
    (layers): ModuleList(
      (0): DonutSwinStage(
        (blocks): ModuleList(
          (0): DonutSwinLayer(
            (layernorm_before): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
            (attention): DonutSwinAttention(
              (self): DonutSwinSelfAttention(
                (query): Linear(in_features=128, out_features=128, bias=True)
                (key): Linear(in_features=128, out_features=128, bias=True)
                (value): Linear(in_features=128, out_features=128, bias=True)
                (dropout): Dropout(p=0.0, inplace=False)
              )
              (output): DonutSwinSelfOutput(
                (dense): Linear(in

In [12]:
image_processor = AutoImageProcessor.from_pretrained("nielsr/donut-base")

In [13]:
from tqdm import tqdm
import os
def extract_features(dataloader, feature_path):
  if not os.path.exists(feature_path):
    os.makedirs(feature_path)
  with torch.no_grad():
      for imgs, object_ids in tqdm(dataloader):
          image_inputs = image_processor(imgs, return_tensors="pt").to(device)
          outputs = model(**image_inputs)
          features = outputs.pooler_output.detach().cpu()
          for idx, obj_id in enumerate(object_ids):
            torch.save(features[idx],os.path.join(feature_path,f"{obj_id}.pt"))

In [None]:
extract_features(train_dataloader,  'train_visual_features')
print("Extraction completed for training set!")
extract_features(val_dataloader,  'val_visual_features')
print("Extraction completed for validation set!")

  0%|          | 0/170 [00:00<?, ?it/s]