In [None]:
from google.colab import drive

# Colab only: expose Google Drive under MOUNT_POINT so project files can be read/written.
MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT)

In [None]:
# Work from the project folder on the mounted Drive so that relative paths used
# later in this notebook (e.g. 'train_data.pkl', 'train_textual_features') resolve there.
%cd /content/drive/MyDrive/VRD-IU

In [None]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import pickle
from transformers import AutoTokenizer
import torch
class CompTextDataset(Dataset):
    """Flat dataset of document components loaded from a pickle file.

    The pickle is read as a mapping whose values each hold a ``'components'``
    list of dicts. Components whose ``'bbox'`` equals ``[0.0, 0.0, 0.0, 0.0]``
    are dropped. Every remaining component becomes one item.

    Each item is ``(text, object_id)``: ``text`` is the component's ``'text'``
    entry when present, otherwise its ``'category'`` entry.
    """

    def __init__(self, pickle_file):
        """Load and flatten all components from ``pickle_file``.

        Args:
            pickle_file: path to the pickled annotation mapping.
        """
        super().__init__()
        # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
        with open(pickle_file, 'rb') as file:
            data = pickle.load(file)
        # Keep every component except those with an all-zero bbox
        # (presumably placeholders without a location — TODO confirm with data owner).
        self.components = [
            comp
            for doc in data.values()
            for comp in doc['components']
            if comp['bbox'] != [0.0, 0.0, 0.0, 0.0]
        ]

    def __len__(self):
        """Return the number of retained components."""
        return len(self.components)

    def __getitem__(self, index):
        """Return ``(text, object_id)`` for the component at ``index``.

        Falls back to the component's ``'category'`` when it has no ``'text'``.
        """
        comp = self.components[index]
        try:
            text = comp['text']
        except KeyError:
            # Only a missing 'text' key triggers the fallback; a bare except would
            # also silently hide unrelated errors (including KeyboardInterrupt).
            text = comp['category']
        return text, comp['object_id']

In [None]:
# Training annotations, resolved relative to the working directory chosen by %cd.
TRAIN_PICKLE = 'train_data.pkl'
train_dataset = CompTextDataset(TRAIN_PICKLE)

In [None]:
from transformers import XLMRobertaModel

# Encoder checkpoint; should match the checkpoint the tokenizer is loaded from.
ENCODER_CHECKPOINT = "FacebookAI/xlm-roberta-base"

model = XLMRobertaModel.from_pretrained(ENCODER_CHECKPOINT)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using: {device}")
# Move to the selected device and switch to inference mode (e.g. dropout off).
# The trailing ';' stops IPython from echoing the entire module repr as cell output.
model.to(device).eval();

In [None]:

# Tokenizer for the XLM-R encoder loaded above.
TOKENIZER_CHECKPOINT = "FacebookAI/xlm-roberta-base"
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_CHECKPOINT)

In [None]:
from tqdm import tqdm
import os
def extract_features(dataloader, feature_path, encoder=None, text_tokenizer=None, target_device=None):
    """Encode every text in ``dataloader`` and save one feature tensor per object.

    Each batch yielded by ``dataloader`` is unpacked as ``(texts, object_ids)``.
    For every object, ``<feature_path>/<object_id>.pt`` is written containing
    that object's row of the encoder's ``pooler_output``.

    Args:
        dataloader: iterable of ``(texts, object_ids)`` batches.
        feature_path: output directory; created if it does not exist.
        encoder: model to run; defaults to the notebook-global ``model``.
        text_tokenizer: tokenizer; defaults to the notebook-global ``tokenizer``.
        target_device: device for tokenized inputs; defaults to the global ``device``.
    """
    # Fall back to the notebook globals so existing two-argument calls keep working.
    encoder = model if encoder is None else encoder
    text_tokenizer = tokenizer if text_tokenizer is None else text_tokenizer
    target_device = device if target_device is None else target_device

    os.makedirs(feature_path, exist_ok=True)
    with torch.no_grad():
        for texts, object_ids in tqdm(dataloader):
            text_inputs = text_tokenizer(
                texts, return_tensors="pt", padding=True, truncation=True
            ).to(target_device)
            outputs = encoder(**text_inputs)
            # NOTE(review): pooler_output comes from the model's pooler head; whether
            # it is a better sentence feature than mean-pooled last_hidden_state for
            # this task is unverified.
            features = outputs.pooler_output.detach().cpu()
            for idx, obj_id in enumerate(object_ids):
                # The default collate turns integer ids into a tensor; unwrap so the
                # file is named e.g. "7.pt" rather than "tensor(7).pt".
                if isinstance(obj_id, torch.Tensor):
                    obj_id = obj_id.item()
                # features[idx] is a view into the whole batch; torch.save would write
                # the entire batch storage into every file, so save a compact copy.
                torch.save(features[idx].clone(), os.path.join(feature_path, f"{obj_id}.pt"))

In [None]:
# Unshuffled batches (DataLoader default); order is irrelevant because each
# feature is saved under its own object_id.
# NOTE(review): 6 workers may exceed the CPUs on a Colab runtime; PyTorch warns then.
BATCH_SIZE = 256
NUM_WORKERS = 6
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)

In [None]:
# One .pt file per training component is written into this directory.
TRAIN_FEATURE_DIR = 'train_textual_features'
extract_features(train_dataloader, TRAIN_FEATURE_DIR)
print("Extraction completed for training set!")