In [1]:
import torch
import numpy as np
import pandas as pd
import os
import ast

from IPython.display import HTML
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer

In [2]:
from models import PEPEModel

# Prepare Files

Follow the instructions [here](https://github.com/xingyaoww/gif-reply/tree/main/data) and [here](https://github.com/xingyaoww/gif-reply/tree/main/src/pepe) (the second link is only needed for the Oscar pretrained model directory) to prepare the files referenced below.


In [3]:
# Paths to the artifacts prepared in the section above. Each one can be
# overridden with an environment variable of the same name; the defaults are
# placeholders that must be replaced with real local paths.
INFERRED_FEATURE_PATH = os.environ.get(
    "INFERRED_FEATURE_PATH",
    "/path/to/your/gif-pepe-inferred-features-top1k.csv/or/gif-pepe-inferred-features.csv")
PEPE_MODEL_CKPT = os.environ.get(
    "PEPE_MODEL_CKPT",
    "/path/to/your/PEPE-model-checkpoint.pth")
GIF_ID_TO_GIPHY_INDEX_MAPPING_FILE = os.environ.get(
    "GIF_ID_TO_GIPHY_INDEX_MAPPING_FILE",
    "/path/to/your/gif-id-to-giphy-id-mapping.csv")
OSCAR_PRETRAINED_MODEL_DIR = os.environ.get(
    "OSCAR_PRETRAINED_MODEL_DIR",
    "/path/to/your/ep_67_588997/")

# Load Inferred Features

In [5]:
def load_inferred_feature(feature_path: str, banning_gifs=None):
    """Load precomputed PEPE GIF features from a CSV file.

    The CSV must have a ``gif_id`` column and a ``gif_feature`` column whose
    cells are string representations of Python lists of numbers.

    Args:
        feature_path: path (or file-like object) accepted by ``pd.read_csv``.
        banning_gifs: optional collection of gif ids to exclude. ``None``
            (the default) excludes nothing.

    Returns:
        dict with
          - ``"gif_features"``: 2-D array, one row per kept GIF (rows in CSV order),
          - ``"gif_index_to_id"``: list mapping row index -> gif id,
          - ``"gif_id_to_index"``: dict mapping gif id -> row index.

    Raises:
        ValueError: if no GIFs remain after loading and filtering.
    """
    # None instead of a mutable `set()` default, so no state is shared between calls.
    banned = set(banning_gifs) if banning_gifs else set()

    gif_ds = pd.read_csv(feature_path)
    if banned:
        # Filter before parsing so banned rows never go through literal_eval.
        gif_ds = gif_ds[~gif_ds["gif_id"].isin(list(banned))]
    if gif_ds.empty:
        raise ValueError(f"no GIF features left after loading/filtering {feature_path!r}")

    # literal_eval safely parses the stored list repr (no arbitrary code execution).
    features = [np.array(ast.literal_eval(raw)) for raw in gif_ds["gif_feature"]]
    gif_index_to_id = gif_ds["gif_id"].to_list()
    return {
        "gif_features": np.stack(features),
        "gif_index_to_id": gif_index_to_id,
        "gif_id_to_index": {gif_id: idx for idx, gif_id in enumerate(gif_index_to_id)},
    }

In [6]:
# Load every precomputed GIF feature (no GIFs banned).
inferred_feature = load_inferred_feature(feature_path=INFERRED_FEATURE_PATH)

# Load `PEPE`

In [7]:
bertweet_tokenizer = AutoTokenizer.from_pretrained("vinai/bertweet-base")

def tokenizeTweet(tweet, max_length=128):
    """Encode a tweet into BERTweet token ids, truncated to at most ``max_length`` tokens.

    Args:
        tweet: the (normalized) tweet text.
        max_length: truncation limit; 128 is the default for BERTweet.

    Returns:
        The list of token ids produced by ``bertweet_tokenizer.encode``.
    """
    return bertweet_tokenizer.encode(tweet, max_length=max_length, truncation=True)

class PEPERetrieval:
    """Retrieve GIFs for a tweet using a trained PEPE model and precomputed GIF features.

    Args:
        checkpoint_path: path to a PEPE ``state_dict`` checkpoint loaded with ``torch.load``.
            ``torch.load`` unpickles the file, so only load trusted checkpoints.
        pretrained_oscar_path: pretrained Oscar model directory passed to ``PEPEModel``.
        inferred_feature: dict with keys ``gif_features`` (2-D array, one row per GIF),
            ``gif_index_to_id`` and ``gif_id_to_index``, e.g. as built by
            ``load_inferred_feature``.
    """

    def __init__(self, checkpoint_path, pretrained_oscar_path, inferred_feature):
        print("loading PEPE model.")
        self.model = PEPEModel(pretrained_oscar_path)
        # GPU use stays opt-in via CUDA_VISIBLE_DEVICES, but is only honoured when
        # torch can actually see a device (e.g. CUDA_VISIBLE_DEVICES=-1 or a
        # CPU-only torch build would otherwise crash in `.cuda()`).
        self._use_cuda = bool(os.environ.get("CUDA_VISIBLE_DEVICES")) and torch.cuda.is_available()
        if self._use_cuda:
            self.model = self.model.cuda()
            # Map to the current visible GPU, whatever device the checkpoint was saved on.
            map_location = torch.device("cuda")
        else:
            map_location = torch.device("cpu")

        print(self.model.load_state_dict(
            torch.load(checkpoint_path, map_location=map_location)))
        # Inference only: disable training-mode behaviour such as dropout.
        # Assumes PEPEModel is a torch.nn.Module (it exposes .cuda()/load_state_dict).
        self.model.eval()
        print("PEPE model loaded.")

        self.gif_features = inferred_feature.get("gif_features")
        self.gif_index_to_id = inferred_feature.get("gif_index_to_id")
        self.gif_id_to_index = inferred_feature.get("gif_id_to_index")

    def _tweet_to_tweet_feature_PEPE(self, normalized_tweet: str):
        """Encode a normalized tweet with the PEPE text encoder into a squeezed numpy array."""
        # Build the (1, seq_len) id batch directly as int64, with no float round-trip.
        token_ids = torch.tensor([tokenizeTweet(normalized_tweet)], dtype=torch.long)
        if self._use_cuda:
            token_ids = token_ids.cuda()
        with torch.no_grad():  # no autograd graph needed at inference time
            feature = self.model.extract_tweet_feature(token_ids)
        return feature.detach().cpu().squeeze().numpy()

    def retrieve(self, normalized_tweet: str, k=10):
        """Return the ids of the ``k`` GIFs with the highest dot-product score, best first.

        Returns an empty list when ``k <= 0`` and every GIF when ``k`` exceeds their count.
        NOTE(review): ranks by raw dot product while ``get_similarity`` uses cosine
        similarity; these agree only if the features are L2-normalized. Confirm.
        """
        if k <= 0:
            return []
        tweet_feature = self._tweet_to_tweet_feature_PEPE(normalized_tweet)
        scores = tweet_feature @ self.gif_features.T
        # argsort is ascending: reverse it, then keep the first k for best-first order.
        top_indices = scores.argsort()[::-1][:k]
        return [self.gif_index_to_id[i] for i in top_indices.tolist()]

    def get_similarity(self, normalized_tweet: str, gif_id: str):
        """Cosine similarity (a Python float) between a tweet and one GIF's precomputed feature.

        Raises:
            KeyError: if ``gif_id`` is not among the loaded GIF features.
        """
        gif_idx = self.gif_id_to_index.get(gif_id)
        if gif_idx is None:
            # Indexing with None would add an axis and silently use the whole matrix.
            raise KeyError(f"unknown gif_id: {gif_id!r}")
        tweet_feature = self._tweet_to_tweet_feature_PEPE(normalized_tweet)
        return float(cosine_similarity(tweet_feature.reshape(1, -1),
                                       self.gif_features[gif_idx].reshape(1, -1))[0, 0])

Special tokens have been added in the vocabulary, make sure the associated word embedding are fine-tuned or trained.


In [8]:
# Build the retriever: load the checkpoint and attach the precomputed GIF features.
PEPE = PEPERetrieval(
    checkpoint_path=PEPE_MODEL_CKPT,
    pretrained_oscar_path=OSCAR_PRETRAINED_MODEL_DIR,
    inferred_feature=inferred_feature,
)

loading PEPE model.
<All keys matched successfully>
PEPE model loaded.


# Have Fun!

In [9]:
def visualize_giphy_gif(giphy_id):    
    giphy_link = f"https://media.giphy.com/media/{giphy_id}/giphy.mp4"
    return HTML(f"""
        <video autoplay loop muted>
            <source src="{giphy_link}" type="video/mp4" />
        </video>
    """)

# gif_id -> GIPHY id lookup; on duplicate gif_ids the later row wins.
gif_id_to_giphy_id = {
    gif_id: giphy_id
    for gif_id, giphy_id in pd.read_csv(GIF_ID_TO_GIPHY_INDEX_MAPPING_FILE)[["gif_id", "giphy_id"]].itertuples(index=False)
}

In [10]:
# Top-10 GIF ids for a sample tweet, best match first.
recommended_gif_ids = PEPE.retrieve("Hello! Nice to meet you!", k=10)
recommended_gif_ids

['ffe7e6e2f0f0f0f0ffe7e6e0f0f0f0f8ffe7e7e0f0f0f0f0',
 'ffe7e6e2f0f0f0f0ffe7e7e0f0f0f0f8ffe7e7e0f0f0f0f0',
 '0f1b090103c3c3c71b1b0d0147c7c7c71b1b090103c3c7c7',
 '00203c4f6ffffefc0000167b73fffefc0030144f6ffffefc',
 'ff00086fefcfcf87ff00084fefefefc7ff00084fefefef87',
 '7e3e5c5c0c0c0c0e7e3e5c5c0c0c0c0e7e3e5c5c0c181c0e',
 'ddd8f0f76061f98478fffed0c3f3c0c066802c0e83dfffef',
 '0400149c99833f3f0400541c9d033f3f0400149c1d033f3f',
 'c6e6c0d6cfcbc0c3c6e6c0decfcbc0c3c6e6c0decfcbc0c3',
 '0400149c99833f3f0400549c9d033f3e0400149c1d033f3f']

In [11]:
# Translate dataset gif ids to GIPHY ids (None where no mapping exists).
recomended_giphy_ids = [gif_id_to_giphy_id.get(gif_id) for gif_id in recommended_gif_ids]
recomended_giphy_ids

['UMOQRDqoPx8UE',
 'GNvWLzRFMm3Be',
 '4iJlMbTNSaOgE',
 '3o7ZeObEUcfLbktUkg',
 'cJSDRt8csBx0A7YFfh',
 'jt2YKsUUtsKCA',
 'Lkoj36QKG8KDS',
 'Vccpm1O9gV1g4',
 '3oriO04qxVReM5rJEA',
 'EtbzbLf34qgms']

In [12]:
# Show the second-ranked recommendation.
visualize_giphy_gif(giphy_id=recomended_giphy_ids[1])

In [13]:
# Show the third-ranked recommendation.
visualize_giphy_gif(giphy_id=recomended_giphy_ids[2])