In [1]:
import os

import torch

from PIL import Image
import moviepy.editor as mpy

from torch.nn import functional as F
from torch.autograd import Variable
from math import ceil
from models import TSN
import torchvision
from transforms import *

# Read the category labels from the first column of the header-less
# category file; the number of labels is the classifier's class count.
import pandas as pd
categories = pd.read_csv(
    '/home/ec2-user/mnt/giphy_dataset/category.txt',
    header=None,
).iloc[:, 0]
num_class = categories.shape[0]


def load_model(useGPU=True, weight_file=None):
    """Build the TSN classifier and load its trained weights for inference.

    Args:
        useGPU: When true, the model is moved to the default CUDA device
            once the weights are loaded; otherwise it stays on the CPU.
        weight_file: Path to the checkpoint file. When omitted, the
            checkpoint path this notebook was originally written against
            is used.

    Returns:
        The TSN model in eval mode, with ``requires_grad`` switched off on
        every parameter.
    """
    if weight_file is None:
        weight_file = '/home/ec2-user/gif-recommendations/model_train/trn_moments_model/trn_pytorch/model/TRN_custom_RGB_InceptionV3_TRNmultiscale_segment8_checkpoint.pth.tar'

    # NOTE(review): the positional 8 is presumably the number of segments
    # (frames) per clip and should match `samples_num` in process_gif - confirm.
    model = TSN(num_class,
                8,
                'RGB',
                base_model='InceptionV3',
                consensus_type='TRNmultiscale',
                img_feature_dim=256, print_spec=False)

    # Always deserialize onto the CPU: this works regardless of which device
    # the checkpoint was saved from, and load_state_dict copies the values
    # into the model's own parameters anyway.
    # NOTE(review): torch.load unpickles the file - only load checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(weight_file,
                            map_location=lambda storage, loc: storage)

    # Drop the first dotted component of every key (presumably the `module.`
    # prefix added by DataParallel when the checkpoint was saved - confirm).
    base_dict = {key.partition('.')[2]: value
                 for key, value in checkpoint['state_dict'].items()}
    model.load_state_dict(base_dict)

    if useGPU:
        model = model.cuda()

    model.eval()

    # Inference only: no gradients are needed for any parameter.
    for p in model.parameters():
        p.requires_grad = False

    return model


def load_transform(net=None):
    """Build the frame preprocessing pipeline for a model.

    Args:
        net: Model whose ``input_size``, ``scale_size``, ``input_mean`` and
            ``input_std`` attributes configure the pipeline. When omitted,
            the module-level ``model`` is used, so that global must already
            exist when this function is called.

    Returns:
        A ``torchvision.transforms.Compose`` whose first step converts each
        frame array to a PIL image, followed by the group oversample, stack,
        tensor conversion and normalization steps from ``transforms``.
    """
    if net is None:
        # Backward-compatible default: fall back to the notebook-level model.
        net = model

    transform = torchvision.transforms.Compose([
        torchvision.transforms.Lambda(lambda frames: [Image.fromarray(x) for x in frames]),
        GroupOverSample(net.input_size, net.scale_size),
        Stack(roll=True),
        ToTorchFormatTensor(div=False),
        GroupNormalize(net.input_mean, net.input_std),
    ])
    return transform

    
# Number of highest-probability categories that process_gif prints and returns.
TOP_K = 5

def evenly_spaced_sampling(array, n):
    """Pick up to `n` elements of `array` at evenly spaced positions.

    The first element is always included when anything is returned. If `n`
    exceeds the sequence length, every element is returned once. An empty
    sequence or `n == 0` gives an empty list.
    """
    total = len(array)
    if total == 0 or n == 0:
        return []

    # Never ask for more samples than there are elements.
    count = total if n > total else n

    picked = []
    for step in range(count):
        position = ceil(step * total / count)
        picked.append(array[position])
    return picked

def display_gif(gif_id):
    """Return an IPython HTML object with an <img> tag for the GIPHY-hosted GIF."""
    from IPython.display import HTML
    markup = "<img src='https://media.giphy.com/media/{}/giphy.gif'>".format(gif_id)
    return HTML(markup)

# Build the TSN model and load its trained weights onto the GPU
model = load_model(useGPU=True)

# Load the video frame transform (load_transform reads the `model` global defined above)
transform = load_transform()

def process_gif(gif_id, model=model, categories=categories, transform=transform, samples_num=8):
    """Classify a GIPHY GIF and print/return its top predictions.

    The GIF's MP4 rendition is fetched from media.giphy.com, `samples_num`
    evenly spaced frames are taken from it, and the transformed frames are
    run through `model`.

    Args:
        gif_id: GIPHY media id used to build the MP4 URL.
        model: Classifier to run; the input is moved to whatever device the
            model's parameters live on.
        categories: Label lookup indexed by class index.
        transform: Callable that turns the list of frames into a tensor.
        samples_num: Number of frames to sample from the clip.

    Returns:
        A list of up to `TOP_K` ``[class_index, probability]`` pairs, ordered
        from most to least probable.

    Raises:
        ValueError: If no frames could be sampled from the clip.
    """
    clip = mpy.VideoFileClip(f'https://media.giphy.com/media/{gif_id}/giphy.mp4')
    try:
        frames = evenly_spaced_sampling(list(clip.iter_frames()), samples_num)
    finally:
        # Explicitly release the clip's reader; dropping the reference alone
        # does not guarantee that.
        clip.close()

    if not frames:
        raise ValueError(f'GIF {gif_id!r} yielded no frames')

    data = transform(frames)

    # Follow the model's device so CPU-loaded models work too.
    device = next(model.parameters()).device

    with torch.no_grad():
        # NOTE(review): assumes `data` is (frames * 3, H, W) so the view yields
        # one 3-channel image per row - confirm against the transform.
        input_var = data.view(-1, 3, data.size(1), data.size(2)).unsqueeze(0).to(device)

        # Make video prediction: average the softmax over the first dimension
        # of the output, then sort descending.
        logits = model(input_var)
        h_x = torch.mean(F.softmax(logits, 1), dim=0)
        probs, idx = h_x.sort(0, True)

    idx_np = idx.cpu().numpy()
    probs_np = probs.cpu().numpy()

    # Never index past the number of available classes.
    top_k = min(TOP_K, len(idx_np))

    for i in range(top_k):
        print(f'{float(probs_np[i]):.8f} -> {categories[int(idx_np[i])]}')

    predictions = [[idx_np[i], probs_np[i]] for i in range(top_k)]
    return predictions

Multi-Scale Temporal Relation Network Module in use ['8-frame relation', '7-frame relation', '6-frame relation', '5-frame relation', '4-frame relation', '3-frame relation', '2-frame relation']
Freezing BatchNorm2D except the first one.


In [2]:
# Show every category label the classifier can predict.
print(categories.tolist())

['smh', 'love', 'thumbs-up', 'ok', 'good-luck', 'disappointed', 'sorry', 'hi', 'animals', 'hello', 'angry', 'shocked', 'shrug', 'yay', 'popcorn', 'mind-blown', 'happy', 'thank-you', 'smile', 'dislike', 'k', 'shake-head', 'facepalm', 'shame', 'bored', 'eye-roll', 'party', 'yes', 'hot', 'sad', 'confused', 'lol', 'dancing', 'congratulations', 'nope', 'what', 'waiting', 'hug', 'laughing', 'classics', 'mad', 'whatever', 'omg', 'why', 'wow', 'do-want', 'like', 'bye', 'celebration', 'thanks', 'excited', 'scared', 'tired', 'applause', 'flirting', 'wtf', 'youre-welcome', 'dance', 'good-job', 'high-five', 'no', 'crying']


In [3]:
# Choose the GIPHY id to classify.
gifId = 'yZ2FSn86bf2co' # paste your gifId here
# process_gif prints the top predictions; its returned list is discarded here.
process_gif(gifId)
# As the last expression of the cell, the returned HTML object becomes the cell output.
display_gif(gifId)

0.99661732 -> what
0.00172772 -> flirting
0.00042690 -> shocked
0.00040669 -> shame
0.00029753 -> classics
