In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from PIL import Image

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
abo_metadata = ***Metadata CSV***
abo_images = ***Image Directory***

In [3]:
df = pd.read_csv(abo_metadata)
df.head()

Unnamed: 0,spin_id,azimuth,image_id,height,width,path
0,61c91265,0,41wqHws7a6L,248,1075,61/61c91265/61c91265_00.jpg
1,61c91265,1,41++eZZHP9L,248,1075,61/61c91265/61c91265_01.jpg
2,61c91265,2,41YF86LhGDL,248,1075,61/61c91265/61c91265_02.jpg
3,61c91265,3,41I5Zz-kbAL,248,1075,61/61c91265/61c91265_03.jpg
4,61c91265,4,41lAQM2Ys5L,248,1075,61/61c91265/61c91265_04.jpg


In [4]:
df[df['spin_id'] == "61c91265"]

Unnamed: 0,spin_id,azimuth,image_id,height,width,path
0,61c91265,0,41wqHws7a6L,248,1075,61/61c91265/61c91265_00.jpg
1,61c91265,1,41++eZZHP9L,248,1075,61/61c91265/61c91265_01.jpg
2,61c91265,2,41YF86LhGDL,248,1075,61/61c91265/61c91265_02.jpg
3,61c91265,3,41I5Zz-kbAL,248,1075,61/61c91265/61c91265_03.jpg
4,61c91265,4,41lAQM2Ys5L,248,1075,61/61c91265/61c91265_04.jpg
...,...,...,...,...,...,...
67,61c91265,67,41vc8QEtYOL,248,1075,61/61c91265/61c91265_67.jpg
68,61c91265,68,41DIITDX4hL,248,1075,61/61c91265/61c91265_68.jpg
69,61c91265,69,41Ptx3uwALL,248,1075,61/61c91265/61c91265_69.jpg
70,61c91265,70,41fLE91QPVL,248,1075,61/61c91265/61c91265_70.jpg


In [5]:
class AboSpinsDataset(Dataset):

    def __init__(self, df, id_name="spin_id", seq_len=10, negative_ratio=0.5, seed=None):
        self.df = df
        self.id_col = "spin_id"
        self.unique_ids = self.df[self.id_col].unique()
        self.seq_len = seq_len
        self.neg_ratio = negative_ratio
        self.n = len(self.unique_ids)
        
        if seed is not None:
            torch.manual_seed(seed)
    
    def __len__(self):
        return self.n

    def __getitem__(self, index):
        spin_id = self.unique_ids[index]

        # Determine if this sample is a negative sample or not
        negative = True if torch.rand(1)[0] < self.neg_ratio else False 

        id_rows = self.df[self.df[self.id_col] == spin_id]
        seq_len = min(len(id_rows), self.seq_len)
        
        samples = id_rows.sample(seq_len)

        # If negative sample, randomly find another image that is not share the same id
        if negative:
            x_matches = samples.iloc[:-1]
            y_match = self.df[self.df[self.id_col] != spin_id].sample(1)
        else:
            x_matches = samples.iloc[:-1]
            y_match = samples.iloc[-1:]
        
        paths = pd.concat([x_matches, y_match], ignore_index=True)['path'].tolist()
        return paths, not negative


In [6]:
abo_data = AboSpinsDataset(df)

In [7]:
x, label = abo_data.__getitem__(1)
x

['a8/a8d1f30b/a8d1f30b_28.jpg',
 'a8/a8d1f30b/a8d1f30b_16.jpg',
 'a8/a8d1f30b/a8d1f30b_68.jpg',
 'a8/a8d1f30b/a8d1f30b_14.jpg',
 'a8/a8d1f30b/a8d1f30b_39.jpg',
 'a8/a8d1f30b/a8d1f30b_66.jpg',
 'a8/a8d1f30b/a8d1f30b_62.jpg',
 'a8/a8d1f30b/a8d1f30b_41.jpg',
 'a8/a8d1f30b/a8d1f30b_33.jpg',
 'a8/a8d1f30b/a8d1f30b_34.jpg']

In [8]:
from transformers import AutoProcessor, AutoModel

processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
model = AutoModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")

In [12]:
def collate_fn(batch):
    features = []
    labels = []
    for x, label in batch:
        imgs = [Image.open(abo_images + path) for path in x]
        inputs = processor(images=imgs, return_tensors='pt')
        
        features.append(model.get_image_features(**inputs).unsqueeze(0))
        labels += [1] if label else [0]

    return torch.cat(features), labels

In [13]:
abo_loader = DataLoader(abo_data, batch_size=4, shuffle=True, collate_fn=collate_fn)

In [14]:
for batch in abo_loader:
    print(batch[0].size())
    break

torch.Size([4, 10, 1024])
