In [40]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tqdm import notebook as tqdm
from IPython.display import Video
from IPython.display import HTML
import argparse
import sys
import os

import time
from math import sin, cos, radians, pi
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.animation as manimation
from matplotlib import animation, rc
import cv2

import numpy as onp
import jax
import jax.numpy as np
from jax import jit, grad, random, partial, lax
from jax.experimental import optimizers
from jax.experimental import stax
from jax.experimental.stax import Dense, Relu, LogSoftmax


rng = random.PRNGKey(0)

init_b1_params, net = stax.serial(
                                Dense(8), Relu,
                                Dense(2))
        
_, b1_params = init_b1_params(rng, (-1, 10))

init_params = b1_params

In [64]:

def init(total_state,n_body,fea_num,G,data):
    data[0][0][0]=500; #heavy mass at center
    data[0][0][1:5]=0.0; #that does not move

    '''For another bodies'''
    for i in range(1,n_body):
        data[0][i][0] = onp.random.rand()*8.98+0.02; #mass
        distance = onp.random.rand()*90.0+50.0;
        theta = onp.random.rand()*360; #angle in degree
        theta_rad = pi/2 - radians(theta)  #angle in radians
        data[0][i][1] = distance*cos(theta_rad); #posx
        data[0][i][2] = distance*sin(theta_rad); #posy
        '''Calc Velocity'''
        pos_norm=norm(data[0][i][1:3]) #norm of the position
        data[0][i][3] = -1*data[0][i][2] / pos_norm * (G*data[0][0][0]/pos_norm**2) * distance/1000; #velx
        data[0][i][4] = data[0][i][1]/pos_norm * (G*data[0][0][0]/pos_norm**2) * distance/1000;    #vely
    
     
    return data; 

In [76]:
    
@jit
def norm(x):
    return np.sqrt(np.sum(x**2));
@jit
def get_f(reciever,sender):
    diff=sender[1:3]-reciever[1:3];
    distance=norm(diff)
    distance = lax.cond(distance<1, distance, lambda x:1., distance, lambda x: x)
    return G*reciever[0]*sender[0]/(distance**3)*diff;
@jit 
def calc(cur_state,n_body,next_state,f_mat,f_sum,acc,diff_t):
    
    def outer_loop(i,input_values):
        next_state=input_values[0]
        f_sum=input_values[1]
        f_mat=input_values[2]
        acc=input_values[3]
        diff_t=input_values[4]
        
        def inner_loop(j,f_mat):
            
            def cond_fun(f_mat):
                f=get_f(cur_state[i][:3],cur_state[j][:3]); 
                

                f_mat = jax.ops.index_update(f_mat, jax.ops.index[i,j], f_mat[i,j]+f)

                f_mat = jax.ops.index_update(f_mat, jax.ops.index[i,j], f_mat[i,j]-f)
                return f_mat
                
            
            return lax.cond(j!=i, f_mat, cond_fun, f_mat, lambda x: x)
            
        
        init_val = f_mat
        start = i+1
        stop = n_body
        lax.fori_loop(start, stop, inner_loop, init_val)
                

        f_sum = jax.ops.index_update(f_sum, jax.ops.index[i], np.sum(f_mat[i],axis=0))
        acc = jax.ops.index_update(acc, jax.ops.index[i], f_sum[i]/cur_state[i][0])
        next_state = jax.ops.index_update(next_state, jax.ops.index[i,0], cur_state[i][0])

#         v_ = net(b1_params,cur_state.flatten())
      
        next_state = jax.ops.index_update(next_state, jax.ops.index[i,3:5],  cur_state[i][3:5]+acc[i]*diff_t)
        next_state = jax.ops.index_update(next_state, jax.ops.index[i,1:3], cur_state[i][1:3]+next_state[i][3:5]*diff_t)
    
        return [next_state,f_sum,f_mat,acc,diff_t]
    
    init_val_i = [next_state,f_sum,f_mat,acc,diff_t]
    start_i = 0
    stop_i = n_body
    o = lax.fori_loop(start_i, stop_i, outer_loop, init_val_i)
    return o[0]

def gen(n_body,total_state,fea_num,data,next_state,f_mat,f_sum,acc,diff_t):
    
    for i in tqdm.tqdm(range(1,total_state)):
        data = jax.ops.index_update(data, jax.ops.index[i], calc(data[i-1],n_body,next_state,f_mat,f_sum,acc,diff_t))
    return data;


grad_gen = grad(gen)

In [77]:
def transform(data):
    traj=[]
    for body in range(data.shape[1]):
        traj.append(data[:,body,:].T)
    return np.asarray(traj)

def plot_trajectories3D(trajectories,frames=100):
        """Trajectories is a list of list containing a triplet of lists 
           for x, y and z co-ordinates for each object for each instance in time.
        """
           
        fig = plt.figure()
        ax = plt.axes()
        
        
        lines=[]
        empty_xs = []
        empty_ys = []
        empty_zs = []
        for i,_ in enumerate(trajectories):
            des = ['ro-','bo-','go-','yo-','mo-','co-','ro-','bo-','go-','ko-','yo-','mo-','co-','ro-','bo-','go-','ko-','yo-','mo-','co-']
            line, = ax.plot([], [], des[i])
            lines.append(line)
            empty_xs.append([])
            empty_ys.append([])

        # initialization function: plot the background of each frame
        def init():
            for line in lines:
                line.set_data([], [])
                plt.style.use('dark_background')
            return lines

        # animation function. This is called sequentially
        def animate(i):
            for j,traj in enumerate(trajectories):
                if i<len(traj[0]):
                    try:
                        empty_xs[j].append(traj[0][i*int(total_state/frames[-1])])
                        empty_ys[j].append(traj[1][i*int(total_state/frames[-1])]) 

                        lines[j].set_data(empty_xs[j],empty_ys[j])
                        
                    except:
                        pass
                    
            try:
                ax.set_xlim((-5+np.min(empty_xs), 5+np.max(empty_xs)))
                ax.set_ylim((-5+np.min(empty_ys), 5+np.max(empty_ys)))
            except:
                pass

            return (lines)


        # call the animator. blit=True means only re-draw the parts that have changed.
        anim = animation.FuncAnimation(fig, animate, init_func=init,
                                       frames=frames, interval=40, blit=True)

        return anim

In [80]:
plt.close('all')
total_state=1
fea_num=5 #[mass,x,y,x_vel,y_vel]
G=10**5 #G = 6.67428e-11
diff_t=0.001 # time step
n_body=5
next_state=np.zeros((n_body,fea_num),dtype=float);
f_mat=np.zeros((n_body,n_body,2),dtype=float);
f_sum=np.zeros((n_body,2),dtype=float);
acc=np.zeros((n_body,2),dtype=float);
placeholder=onp.zeros((total_state,n_body,fea_num),dtype=float)
dp=init(total_state,n_body,fea_num,G,placeholder);



data=gen(n_body,total_state,fea_num,dp,next_state,f_mat,f_sum,acc,diff_t);
print(grad_gen(n_body,total_state,fea_num,dp,next_state,f_mat,f_sum,acc,diff_t))
xy=data[:,:,1:3];
trajectories = transform(xy)
frames = range(100)
anim = plot_trajectories3D(trajectories,frames=frames)
HTML(anim.to_html5_video())

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




TypeError: Primal inputs to reverse-mode differentiation must be of float or complex type, got type int32

In [None]:
# @jit
# def init(total_state,n_body,fea_num,G,data):
#     data = jax.ops.index_update(data, jax.ops.index[0,0,0], 500)
#     data = jax.ops.index_update(data, jax.ops.index[1:5], 0.)

#     '''For another bodies'''
    
    
    
#     def body_fun(i,data):
        
#         data = jax.ops.index_update(data, jax.ops.index[0,i,0], onp.random.rand()*8.98+0.02)
#         print(onp.random.rand())
#         print(n_body,i)
#         distance = onp.random.rand()*90.0+50.0;
#         theta = onp.random.rand()*360.; #angle in degree
#         theta_rad = pi/2 - radians(theta)  #angle in radians
        

#         data = jax.ops.index_update(data, jax.ops.index[0,i,1], distance*cos(theta_rad))

#         data = jax.ops.index_update(data, jax.ops.index[0,i,2], distance*sin(theta_rad))
#         '''Calc Velocity'''
#         pos_norm=norm(data[0][i][1:3]) #norm of the position

#         data = jax.ops.index_update(data, jax.ops.index[0,i,3], -1*data[0][i][2] / pos_norm * (G*data[0][0][0]/pos_norm**2) * distance/1000)
#         data = jax.ops.index_update(data, jax.ops.index[0,i,4], data[0][i][1]/pos_norm * (G*data[0][0][0]/pos_norm**2) * distance/1000)
        
#         return data
        
#     init_val = data
#     start = 1
#     stop = n_body
#     data = lax.fori_loop(start, stop, body_fun, init_val)
    
     
#     return data;  