In [None]:
import numpy as np
import matplotlib.pyplot as plt
from cebra import CEBRA
import glob

import torch
import torchvision.transforms as transforms
from torchvision.models import resnet50
from PIL import Image

In [None]:
# Standard ImageNet preprocessing: resize, center-crop to 224x224, convert to a
# tensor and normalize with the ImageNet channel means / standard deviations.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases in
# favour of `weights=ResNet50_Weights.DEFAULT`; kept here for compatibility.
# NOTE(review): the full network ends in the classification `fc` layer, so the
# "features" are presumably the 1000 ImageNet logits. If the penultimate embedding
# was intended, replace `feature_extractor.fc` with `torch.nn.Identity()` -- confirm.
feature_extractor = resnet50(pretrained=True)
# The network is only used for inference. eval() makes BatchNorm layers use their
# stored running statistics instead of the statistics of each single-image batch
# (and keeps those running statistics from being updated during extraction).
feature_extractor.eval()

In [None]:
# Input frames: every PNG in the directory, sorted by path so that their order
# lines up position-by-position with the hand-written labels below.
monkey_directory = "data/monkey_play/frames/"
monkey_paths = sorted(glob.glob(f"{monkey_directory}*.png"))
# One label per sorted frame: the first 3 frames are class 0, the next 7 class 1.
monkey_play_labels = [0] * 3 + [1] * 7

In [None]:
# Run every frame through the ResNet feature extractor and pair the resulting
# feature row(s) with that frame's label.
# Frames and labels are matched by position, so the counts must agree; this also
# catches an empty/missing frame directory before torch.cat fails on an empty list.
if len(monkey_paths) != len(monkey_play_labels):
    raise ValueError(
        f"found {len(monkey_paths)} frames in {monkey_directory!r} "
        f"but {len(monkey_play_labels)} labels in monkey_play_labels"
    )

monkey_play_features = []
monkey_play_features_labels = []
with torch.no_grad():
    for frame_path, frame_label in zip(monkey_paths, monkey_play_labels):
        # Context manager closes the underlying file handle once preprocessed.
        with Image.open(frame_path) as frame:
            input_tensor = preprocess(frame.convert("RGB"))
        input_batch = input_tensor.unsqueeze(0)  # add a batch dimension of 1
        features = feature_extractor(input_batch)
        monkey_play_features.append(features)
        # One integer label per feature row (i.e. per frame), so the concatenated
        # labels form a 1-D vector aligned with the rows of the feature matrix.
        monkey_play_features_labels.append(
            torch.full((features.shape[0],), fill_value=frame_label, dtype=torch.long)
        )

monkey_play_features = torch.cat(monkey_play_features, dim=0)
monkey_play_features_labels = torch.cat(monkey_play_features_labels, dim=0)

print(monkey_play_features.shape, monkey_play_features_labels.shape)

In [None]:
# Number of optimisation steps CEBRA will run during fit().
max_iterations = 5000

# All hyper-parameters of the behaviour-conditioned CEBRA model, gathered in a
# single mapping so they can be inspected or tweaked in one place.
# NOTE(review): batch_size=1 gives the contrastive objective very few samples per
# step -- confirm this is intentional for a 10-frame dataset.
cebra_behavior_params = dict(
    model_architecture='offset1-model',
    batch_size=1,
    learning_rate=3e-4,
    temperature=1,
    output_dimension=3,
    max_iterations=max_iterations,
    distance='cosine',
    conditional='time_delta',
    device='cpu',
    verbose=True,
    time_offsets=1,
)
cebra_behavior_model = CEBRA(**cebra_behavior_params)

In [None]:
# Train CEBRA on the frame features, conditioned on the per-frame labels.
# The original call referenced `monkey_play_feature_labels` (missing the "s"),
# which raised NameError. The tensors are CPU tensors created under no_grad, so
# they are converted to NumPy arrays to match CEBRA's array-like fit() API.
cebra_behavior_model.fit(monkey_play_features.numpy(), monkey_play_features_labels.numpy())

In [None]:
# Persist the fitted model to disk under a fixed filename.
cebra_model_path = "cebra_monkey_play_model.pt"
cebra_behavior_model.save(cebra_model_path)