In [177]:

import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import os
import numpy as np
from sklearn.preprocessing import LabelEncoder
from torch_geometric.data import Data
from tqdm import tqdm
from torch_geometric.nn import Node2Vec
from itertools import permutations


In [19]:
# Preprocessing pipeline applied to every image before it reaches the CNN:
#   1. resize to 224x224 pixels,
#   2. convert the PIL image to a tensor,
#   3. normalize each of the three channels with the given mean / std
#      (these are the ImageNet channel statistics commonly used with
#      torchvision's pretrained models).
transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


In [79]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Keep only the convolutional part (`.features`) of the pretrained VGG16 so the
# network acts as a fixed feature extractor rather than a classifier.
# The model is only ever used for inference, so it is switched to eval mode
# explicitly; this keeps the extractor deterministic even if the backbone is
# later swapped for one with dropout / batch-norm layers.
# NOTE(review): newer torchvision releases prefer the `weights=` argument over
# `pretrained=True` - confirm the installed version before switching.
vgg16 = models.vgg16(pretrained=True).features.to(device).eval()



In [128]:

# A function to extract features for a given image
def extract_features(img_path):
    """Return the flattened VGG16 feature vector of the image at ``img_path``.

    The image is loaded as RGB, run through the module-level ``transform``
    pipeline, given a leading batch dimension and passed through the
    module-level ``vgg16`` network on ``device`` with gradients disabled.

    Args:
        img_path: Path of the image file to read.

    Returns:
        A 1-D numpy array holding the flattened network output.
    """
    # Image.open is lazy; the context manager guarantees that the underlying
    # file handle is released even when decoding / conversion raises.
    # convert() returns an independent image, so it stays usable afterwards.
    with Image.open(img_path) as img:
        img_rgb = img.convert("RGB")

    img_t = transform(img_rgb).unsqueeze(0).to(device)
    with torch.no_grad():  # inference only: no autograd graph needed
        features = vgg16(img_t)
    return features.cpu().numpy().flatten()



In [136]:

def extract_image_data(dataset_path, extract_features_fn = None, considered_eras = ('00', '01', '02', '03', '04', '05', '06'), early_stop= False, early_stop_after = 40):
    """Walk ``dataset_path`` and collect one record per matching ``.png`` file.

    A file matches when its name ends with ``.png`` and the first two
    characters of the part after the last ``_`` (the "era") are contained in
    ``considered_eras``. Every matching file gets a running integer id in
    traversal order.

    Args:
        dataset_path: Root directory to walk recursively.
        extract_features_fn: Optional callable ``img_path -> features``. When
            ``None`` the features slot of each record is ``None``.
        considered_eras: Collection of two-character era strings to keep.
        early_stop: When true, stop walking once more than
            ``early_stop_after`` images have been collected. The check runs
            after each directory, so the directory being processed is always
            finished first.
        early_stop_after: Image count threshold used by ``early_stop``.

    Returns:
        Dict mapping the basename of the containing directory (the character
        class) to a list of ``(node_id, features, era, img_path)`` tuples.
    """
    characters = {}
    next_id = 0
    for root, dirs, files in tqdm(os.walk(dataset_path)):
        # os.walk yields sub-directories in arbitrary (filesystem) order.
        # Sorting in place makes the traversal - and therefore the node ids
        # and the early-stop subset - reproducible across machines and runs.
        dirs.sort()

        # NOTE(review): only the last path component is used as the class key,
        # so equally named directories under different parents would share one
        # entry - confirm this is intended for the dataset layout.
        char_class = os.path.basename(root)

        for file in sorted(files):
            if not file.endswith(".png"):
                continue
            era = file.split("_")[-1][:2]
            if era not in considered_eras:
                continue

            img_path = os.path.join(root, file)
            if extract_features_fn is not None:
                features = extract_features_fn(img_path)
            else:
                features = None

            characters.setdefault(char_class, []).append((next_id, features, era, img_path))
            next_id += 1

        if early_stop and next_id > early_stop_after:
            break
    return characters



In [137]:
# Root folder of the image dataset, relative to the working directory.
dataset_path = "images_background"

# Walk the dataset and compute CNN features for every selected image.
# early_stop=True ends the walk after the first few dozen images, which keeps
# this cell fast while experimenting.
data_images_dict = extract_image_data(
    dataset_path,
    extract_features_fn=extract_features,
    early_stop=True,
)

8it [00:09,  1.20s/it]


In [138]:
# Fit a label encoder on the character-class names collected above.
char_class_encoder = LabelEncoder()
# NOTE(review): LabelEncoder.fit() returns the encoder itself, so
# `char_class_id_map` is an alias of `char_class_encoder`, not a dict.
char_class_id_map = char_class_encoder.fit(list(data_images_dict))


In [183]:
def generate_graph_data_from_dict(data_images_dict):
    """Build a PyG ``Data`` graph from the per-class image records.

    Every image is a node. Within each character class all ordered pairs of
    distinct nodes are connected, so each class forms a directed clique.

    Args:
        data_images_dict: Dict mapping a class name to a list of records
            ``(node_id, features, era, img_path)``. ``features`` must be an
            equally sized numeric vector for every record and ``era`` must be
            convertible with ``int()`` (e.g. ``'03'``).

    Returns:
        ``Data(x, edge_index, y)`` where row ``i`` of ``x`` / ``y`` belongs to
        the node with the ``i``-th smallest node id, ``edge_index`` has shape
        ``(2, num_edges)`` and ``y`` holds the integer era of each node.

    Raises:
        ValueError: If node ids are duplicated, a record has no features, or
            an era cannot be converted to an integer.
    """
    records = []
    id_pairs = []
    for char_class, class_records in data_images_dict.items():
        node_ids = [rec[0] for rec in class_records]
        id_pairs.extend(permutations(node_ids, 2))  # edges inside one class
        records.extend(class_records)

    # Order rows by node id and address edges by row position. The ids of one
    # class are not guaranteed to be contiguous, so simply concatenating the
    # per-class lists could pair a node id with another node's features.
    records.sort(key=lambda rec: rec[0])
    row_of = {rec[0]: row for row, rec in enumerate(records)}
    if len(row_of) != len(records):
        raise ValueError("node ids must be unique across character classes")
    if any(rec[1] is None for rec in records):
        raise ValueError("every record needs a feature vector to build the node feature matrix")

    edges = [(row_of[src], row_of[dst]) for src, dst in id_pairs]
    # reshape(-1, 2) keeps the (2, num_edges) layout even when there are no edges.
    E = torch.tensor(edges, dtype=torch.long).reshape(-1, 2).t().contiguous()

    if records:
        # Stack into one ndarray first: building a tensor straight from a
        # list of ndarrays is very slow.
        stacked = np.stack([np.asarray(rec[1]) for rec in records])
        X = torch.as_tensor(stacked, dtype=torch.float)
    else:
        X = torch.empty((0, 0), dtype=torch.float)

    # Eras arrive as strings such as '00'; torch cannot build a tensor from
    # strings, so convert them to integers first.
    y = torch.tensor([int(rec[2]) for rec in records], dtype=torch.long)

    # Create PyG Data object
    return Data(x=X, edge_index=E, y=y)


In [184]:
def load_graph_from_data(data, device = None):
    """Create a Node2Vec model, its random-walk loader and an optimizer.

    Args:
        data: Graph object exposing ``edge_index``.
        device: Device to place the model on. ``None`` (default) selects
            ``'cuda'`` when available and ``'cpu'`` otherwise. An explicitly
            passed device is honoured as given.

    Returns:
        Tuple ``(model, loader, optimizer)``: the Node2Vec model on ``device``,
        the loader returned by ``model.loader(...)`` and a ``SparseAdam``
        optimizer over the model parameters.
    """
    # Only auto-detect when the caller did not choose a device; previously the
    # argument was overwritten unconditionally and therefore ignored.
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model = Node2Vec(data.edge_index, embedding_dim=128, walk_length=20,
                    context_size=10, walks_per_node=10,
                    num_negative_samples=1, p=1, q=1, sparse=True).to(device)
    # Data loader that yields batches of random walks for training.
    loader = model.loader(batch_size=128, shuffle=True, num_workers=4)
    # The model is built with sparse=True, so a sparse-aware optimizer is used.
    optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=0.01)
    return model, loader, optimizer

[(0, 1), (0, 2), (0, 3), (1, 0), (1, 2), (1, 3), (2, 0), (2, 1), (2, 3), (3, 0), (3, 1), (3, 2)]
[(4, 5), (4, 6), (4, 7), (4, 8), (5, 4), (5, 6), (5, 7), (5, 8), (6, 4), (6, 5), (6, 7), (6, 8), (7, 4), (7, 5), (7, 6), (7, 8), (8, 4), (8, 5), (8, 6), (8, 7)]
[(9, 10), (9, 11), (9, 12), (9, 13), (10, 9), (10, 11), (10, 12), (10, 13), (11, 9), (11, 10), (11, 12), (11, 13), (12, 9), (12, 10), (12, 11), (12, 13), (13, 9), (13, 10), (13, 11), (13, 12)]
[(14, 15), (14, 16), (14, 17), (14, 18), (14, 19), (15, 14), (15, 16), (15, 17), (15, 18), (15, 19), (16, 14), (16, 15), (16, 17), (16, 18), (16, 19), (17, 14), (17, 15), (17, 16), (17, 18), (17, 19), (18, 14), (18, 15), (18, 16), (18, 17), (18, 19), (19, 14), (19, 15), (19, 16), (19, 17), (19, 18)]
[(20, 21), (20, 22), (20, 23), (20, 24), (21, 20), (21, 22), (21, 23), (21, 24), (22, 20), (22, 21), (22, 23), (22, 24), (23, 20), (23, 21), (23, 22), (23, 24), (24, 20), (24, 21), (24, 22), (24, 23)]
[(25, 26), (25, 27), (25, 28), (25, 29), (25, 3