In [None]:
import os
import sys
sys.path.append('/kaggle/input/raft-pytorch')
import numpy as np
import cv2
import matplotlib.pyplot as plt
import torch

from glob import glob
from PIL import Image
from tqdm import tqdm
from PIL import Image

In [None]:
import ast
import json
import pickle
import time
from os import walk

import matplotlib.animation as animation
import matplotlib.cm as cm
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)


In [None]:
# Sub-directories of ../input/dataflow are integer-named; sort them numerically, not lexically.
a = sorted(map(int, os.listdir("../input/dataflow")))

In [None]:
# Back to strings so the ids can be joined into filesystem paths.
a = list(map(str, a))
# Offset used to name the output folders (start+il+1) in the flow-export loop.
# NOTE(review): this offset is independent of the a[40:80] slice below — confirm the two agree.
start = 120 + 40

In [None]:
# Keep only this batch of directories (sorted positions 40..79) for the current run.
a = a[40:80]

In [None]:
# One record per selected directory: its frame folder ("imgs") and its annotation file.
data = [
    {
        "filename": os.path.join("../input/dataflow", clip_id, "imgs"),
        "annot": os.path.join("../input/dataflow", clip_id, "annotation.json"),
    }
    for clip_id in a
]

In [None]:
def _frame_index(name):
    """Integer frame index encoded in a filename stem (e.g. '7.jpg' or '007.jpg' -> 7)."""
    return int(os.path.splitext(name)[0])


def list_frame_files(frames_dir):
    """Return the .jpg frame paths directly inside ``frames_dir``, in numeric frame order.

    The real filenames are kept (rather than re-formatting the index), so both
    '7.jpg' and '007.jpg' naming schemes produce existing paths. Non-jpg entries
    are ignored instead of crashing the integer parse.
    """
    names = [n for n in os.listdir(frames_dir) if n.lower().endswith('.jpg')]
    return [os.path.join(frames_dir, n) for n in sorted(names, key=_frame_index)]


def load_annotation(path):
    """Parse one annotation file.

    Strict JSON is tried first because ast.literal_eval rejects JSON literals
    such as true/false/null; Python-literal-style files still parse via the
    literal_eval fallback.
    """
    with open(path) as f:
        text = f.read()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        return ast.literal_eval(text)


folder = []       # per-clip list of frame paths, parallel to `data`
annotations = []  # per-clip parsed annotation, parallel to `data`
for entry in data:
    annotations.append(load_annotation(entry["annot"]))
    folder.append(list_frame_files(entry["filename"]))

In [None]:
# imagelist=[]
# for d1,files in zip(annotations,folder):
#     positions=[]
#     for i in range(len(d1)):
#         positionx,positiony=[int(x) for x in d1[i]["position"]]
#         top=int(d1[i]["bbox"]['top'])
#         left=int(d1[i]["bbox"]['left'])
#         right=int(d1[i]["bbox"]['right'])
#         bottom=int(d1[i]["bbox"]['bottom'])
#         positions.append([positionx,positiony,top,left,right,bottom]) 
#     #print(positions)
#     lis=[]
#     for i in range(len(files)):
#         #print(files[i])
#         image = cv2.imread(os.path.join(files[i]))
#         height, width, channels = image.shape
#         for j in positions:
#             start_point = (j[3],j[2])
#             end_point = (j[4], j[5])
#             color = (0,0,255)
#             thickness = 2
#             image = cv2.rectangle(image, start_point, end_point, color, thickness)
#         imagelist.append(image)
#     #imagelist.append(lis)
# Independent copy of the per-clip frame path lists. np.copy(folder) tried to build
# a 2-D NumPy array, which fails (ValueError on NumPy >= 1.24) as soon as two clips
# have a different number of frames; a plain list of lists has no such restriction.
fold = [list(frames) for frames in folder]

# RAFT introduction

This notebook uses **RAFT: Recurrent All-Pairs Field Transforms for Optical Flow**, introduced at ECCV 2020 by Zachary Teed and Jia Deng of Princeton University, where it won the Best Paper Award.
* https://arxiv.org/abs/2003.12039
* https://github.com/princeton-vl/RAFT (licensed under the BSD 3-Clause License)

In brief, RAFT has the following features:
* Recurrent, iterative refinement of the optical-flow estimate
* Computes all-pairs pixel-wise correlations between the two input images once and reuses them in every subsequent recurrent update
* Lightweight, with fast inference and high accuracy

![RAFT architecture image from https://github.com/princeton-vl/RAFT](https://github.com/princeton-vl/RAFT/raw/master/RAFT.png)

My explanatory slides on RAFT (in Japanese) are available [here](https://speakerdeck.com/daigo0927/raft-recurrent-all-pairs-field-transforms-for-optical-flow).

# Run RAFT on sample images

In [None]:
from raft.core.raft import RAFT
from raft.core.utils import flow_viz
from raft.core.utils.utils import InputPadder
from raft.config import RAFTConfig

In [None]:
# Full-size RAFT (small=False) with dropout off, the standard (non-alternate)
# correlation implementation, and mixed precision disabled.
config = RAFTConfig(
    small=False,
    mixed_precision=False,
    alternate_corr=False,
    dropout=0,
)
model = RAFT(config)

In [None]:
# Prefer the GPU when one is visible; fall back to CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'device: {device}')

# Presumably Sintel-finetuned weights (per the filename); swap in the commented path for the "things" checkpoint.
weights_path = '/kaggle/input/raft-pytorch/raft-sintel.pth'
#weights_path = '/kaggle/input/raft-pytorch/raft-things.pth'

model.to(device)
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
ckpt = torch.load(weights_path, map_location=device)
model.load_state_dict(ckpt)

In [None]:
# image_files = glob('/kaggle/input/raft-pytorch/raft/demo-frames/*.png')
# image_files = sorted(image_files)

# print(f'Found {len(image_files)} images')
# print(sorted(image_files))

In [None]:
def load_image(imfile, device):
    """Load an image file as a float tensor batch for RAFT.

    Args:
        imfile: path to the image file.
        device: torch device the tensor is moved to.

    Returns:
        Tensor of shape (1, 3, H, W), dtype float32, with pixel values in [0, 255].
    """
    # convert('RGB') guarantees exactly 3 channels: grayscale inputs would give a
    # 2-D array (breaking the (2, 0, 1) permute) and RGBA/palette inputs a
    # 4-channel or index array that RAFT's 3-channel encoder cannot take.
    # The context manager releases the file handle deterministically.
    with Image.open(imfile) as pil_img:
        img = np.array(pil_img.convert('RGB'), dtype=np.uint8)
    img = torch.from_numpy(img).permute(2, 0, 1).float()
    return img[None].to(device)


def viz(img1, img2, flo):
    """Plot two input frames next to their colour-coded optical flow.

    Args:
        img1, img2: batched image tensors; only the first batch item is shown.
        flo: batched flow tensor; only the first batch item is shown.
    """
    def to_hwc(batch):
        # First batch item, channels moved last, as a NumPy array for matplotlib.
        return batch[0].permute(1, 2, 0).cpu().numpy()

    first = to_hwc(img1)
    second = to_hwc(img2)
    flow_rgb = flow_viz.flow_to_image(to_hwc(flo))

    panels = (
        ('input image1', first.astype(int)),
        ('input image2', second.astype(int)),
        ('estimated optical flow', flow_rgb),
    )
    fig, axes = plt.subplots(1, 3, figsize=(20, 4))
    for ax, (title, picture) in zip(axes, panels):
        ax.set_title(title)
        ax.imshow(picture)
    plt.show()

In [None]:
#len(imagelist[3])
#imagelist

In [None]:
# For every clip, compute RAFT optical flow between each pair of consecutive frames
# and save the colour-coded flow as <start+il+1>/<pair index>.jpeg.
model.eval()  # inference mode; set once instead of once per clip
for il, imagelist in enumerate(fold):
    out_dir = f'{start+il+1}'
    # exist_ok=True keeps the cell re-runnable (os.mkdir raised FileExistsError on a second run).
    os.makedirs(out_dir, exist_ok=True)
    # Pairing the list with itself shifted by one yields (frame t, frame t+1).
    pairs = zip(imagelist, imagelist[1:])
    for i, (file1, file2) in enumerate(tqdm(pairs, total=max(len(imagelist) - 1, 0))):
        image1 = load_image(file1, device)
        image2 = load_image(file2, device)

        padder = InputPadder(image1.shape)
        image1, image2 = padder.pad(image1, image2)

        with torch.no_grad():
            flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)

        # Crop the padding back off so the saved flow matches the original frame size.
        flow_up = padder.unpad(flow_up)

        #viz(image1, image2, flow_up)
        # First batch item, channels last, then colour-coded by flow_viz.
        flo = flow_up[0].permute(1, 2, 0).cpu().numpy()
        flo = flow_viz.flow_to_image(flo)
        filename = os.path.join(out_dir, f'{i}.jpeg')
        Image.fromarray(flo).save(filename)

When the `viz(...)` call above is enabled, the first and second columns show the input image pair and the right column shows the predicted optical flow.

In [None]:
# img=Image.open('./1/0.jpeg')
# plt.imshow(img)

In [None]:
# len(os.listdir('./1'))

# Run on NFL video

In [None]:
# video_file = '/kaggle/input/nfl-impact-detection/train/57583_000082_Endzone.mp4'

# cap = cv2.VideoCapture(video_file)

# frames = []
# while True:
#     has_frame, image = cap.read()
    
#     if has_frame:
#         image = image[:, :, ::-1] # convert BGR -> RGB
#         frames.append(image)
#     else:
#         break
# frames = np.stack(frames, axis=0)

# print(f'frame shape: {frames.shape}')    
# plt.imshow(frames[0])

In [None]:
# flows1=[]
# n_vis = len(frames)-1

# for i in range(n_vis):
#     image1 = torch.from_numpy(frames[i]).permute(2, 0, 1).float().to(device)
#     image2 = torch.from_numpy(frames[i+1]).permute(2, 0, 1).float().to(device)
    
#     image1 = image1[None].to(device)
#     image2 = image2[None].to(device)

#     padder = InputPadder(image1.shape)
#     image1, image2 = padder.pad(image1, image2)
    
#     with torch.no_grad():
#         flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
# #     viz(image1, image2, flow_low)
#     flo=flow_up
#     flo = flo[0].permute(1,2,0).cpu().numpy()
#     flo = flow_viz.flow_to_image(flo)
#     flows1.append(flo)
    

RAFT seems to capture the motion of each player.

In [None]:
# frames = [] # for storing the generated images
# fig = plt.figure()
# plt.axis('off')
# for i in range(len(flows1)):
#     frames.append([plt.imshow(flows1[i], cmap=cm.Greys_r,animated=True)])

# ani = animation.ArtistAnimation(fig, frames, interval=200, blit=True,
#                                 repeat_delay=1000)
# ani.save('movie.mp4')
# plt.show()

In [None]:
# from IPython.display import HTML
# from base64 import b64encode
# filename12='./movie.mp4'
# def play(filename12):
#     html = ''
#     video = open(filename12,'rb').read()
#     src = 'data:video/mp4;base64,' + b64encode(video).decode()
#     html += '<video width=1000 controls autoplay><source src="%s" type="video/mp4"></video>' % src 
#     return HTML(html)

# play('./movie.mp4')

In [None]:
# !wget --quiet https://github.com/matterport/Mask_RCNN/releases/download/v2.0/mask_rcnn_coco.h5

In [None]:
# import pixellib
# from pixellib.instance import instance_segmentation
# import cv2

# segment_video = instance_segmentation()
# segment_video.load_model("mask_rcnn_coco.h5")

In [None]:
# segment_video.process_video("../input/nfl-impact-detection/test/57906_000718_Endzone.mp4", show_bboxes = True, frames_per_second= 15, output_video_name="traffic_monitor.mp4")