In [1]:
import os
import time

# Environment options for JAX: set them before anything imports JAX.
# pulse2percept may import JAX itself (its models accept engine='jax'), and
# JAX reads JAX_* options when it is first imported, so setting these after
# the pulse2percept imports may silently have no effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["JAX_LOG_COMPILES"] = "True"

import jax
import numpy as np
from pulse2percept.implants import ArgusII, DiskElectrode, ElectrodeArray, ProsthesisSystem
from pulse2percept.models import (AxonMapModel, AxonMapSpatial, BiphasicAxonMapModel,
                                  BiphasicAxonMapSpatial, Model)
from pulse2percept.stimuli import BiphasicPulseTrain, Stimulus

# Show which accelerator JAX picked up.
jax.devices()

[GpuDevice(id=0, process_index=0)]

## Setup

In [26]:
# Reference model (library implementation) using the JAX engine.
model = BiphasicAxonMapModel(axlambda=800, engine='jax', min_ax_sensitivity=5e-3)
model.build()

implant = ArgusII()
# Three simultaneously active electrodes. The dict order fixes the electrode
# order used when per-electrode parameter arrays are built later.
stim = Stimulus({
    "A5": BiphasicPulseTrain(20, 1, 0.45),
    "A10": BiphasicPulseTrain(30, 1, 0.45),
    "A4": BiphasicPulseTrain(20, 1, 0.45),
})
implant.stim = stim

In [34]:
# Compare the wall-clock cost of the whole call with the JAX time it reports.
# NOTE(review): in this environment predict_percept returns a number (used as
# the JAX time in ms) rather than a Percept -- confirm against the library build.
# perf_counter is monotonic and high-resolution, unlike time.time().
start = time.perf_counter()
jax_time = model.predict_percept(implant)
real_time = (time.perf_counter() - start) * 1000
print(f"Took {round(real_time, 1)}ms, {round(jax_time, 1)}ms jax ({round(jax_time / real_time * 100, 0)}%)")

Took 3.0ms, 2.4ms jax (78.0%)


In [38]:
# NOTE(review): BiphasicAxonMapGPU is defined in the In [40] cell further down;
# on a fresh kernel that cell must run before this one (consider moving it up).
modelnb = BiphasicAxonMapGPU(axlambda=800, engine='jax', min_ax_sensitivity=5e-3)
# Trailing ';' suppresses the repr of the value returned by build().
modelnb.build();




In [47]:
# Same measurement for the notebook's GPU model, whose predict_percept returns
# the elapsed JAX time in ms. perf_counter is monotonic and high-resolution.
start = time.perf_counter()
jax_time = modelnb.predict_percept(implant)
real_time = (time.perf_counter() - start) * 1000
print(f"(Old) Took {round(real_time, 1)}ms, {round(jax_time, 1)}ms jax ({round(jax_time / real_time * 100, 0)}%)")

(Old) Took 174.9ms, 174.3ms jax (100.0%)


In [40]:
from jax import jit, vmap
import jax.numpy as jnp
def predict_one_point(axon, brights, sizes, streaks, x, y, rho, axlambda):
    """Summed electrode contribution at every segment of a single axon.

    Parameters
    ----------
    axon : array, shape (n_segments, 3)
        Columns 0 and 1 are segment x/y positions; column 2 is a per-segment
        weight raised to ``1 / streaks`` (presumably axon sensitivity).
    brights, sizes, streaks, x, y : array, shape (n_electrodes,)
        Per-electrode brightness, size and streak factors and positions.
    rho : float
        Width parameter of the Gaussian distance falloff.
    axlambda : float
        Unused; kept so the signature matches the vmapped call below.

    Returns
    -------
    array, shape (n_segments,)
        Sum over electrodes of each electrode's contribution at each segment.
    """
    # Squared distance from every segment to every electrode: (n_segments, n_electrodes).
    d2_el = (axon[:, 0, None] - x) ** 2 + (axon[:, 1, None] - y) ** 2
    intensities = (brights * jnp.exp(-d2_el / (2. * rho ** 2 * sizes))
                   * (axon[:, 2, None] ** (1. / streaks)))
    return jnp.sum(intensities, axis=1)


# Map predict_one_point over the leading (location) axis of axon_segments.
# Built and jitted once here: re-wrapping a fresh vmap in jit on every call
# prevents JAX from reusing its compiled code between calls.
_predict_all_points = jit(vmap(predict_one_point,
                               in_axes=(0, None, None, None, None, None, None, None)))


def gpu_biphasic_axon_map(amp, freq, pdur, x, y,  # per ACTIVE electrode
                          axon_segments, rho, axlambda, thresh_percept):
    """Biphasic axon-map response at every location of ``axon_segments``.

    Parameters
    ----------
    amp, freq, pdur, x, y : array, shape (n_electrodes,)
        Amplitude, frequency, phase duration and position of each active electrode.
    axon_segments : array, shape (n_space, axon_length, 3)
        Axon segments per location, in the layout predict_one_point expects.
    rho, axlambda : float
        Set the lower bounds 10**2 / rho**2 and 10**2 / axlambda**2 on the
        size and streak factors; rho also sets the distance falloff.
    thresh_percept : float
        Responses not strictly greater than this are set to 0.

    Returns
    -------
    array, shape (n_space,)
        Per location, the maximum over axon segments of the summed electrode
        contributions, thresholded.
    """
    min_size = 10 ** 2 / rho ** 2
    min_streak = 10 ** 2 / axlambda ** 2

    # Per-electrode F, G, H terms (brightness, size, streak).
    scaled_amps = amp / (0.8825 + 0.27 * pdur)
    brights = 1.84 * scaled_amps + 0.2 * freq + 3.0986
    sizes = jnp.maximum(1.081 * scaled_amps - 0.3533764, min_size)
    streaks = jnp.maximum(1.56 - 0.54 * pdur ** 0.21, min_streak)

    # (n_space, axon_length) -> brightest segment of each axon: (n_space,)
    per_segment = _predict_all_points(axon_segments, brights, sizes, streaks, x, y, rho, axlambda)
    I = jnp.max(per_segment, axis=1)
    return jnp.where(I > thresh_percept, I, 0.)
from copy import deepcopy
from pulse2percept.percepts import Percept
class BiphasicAxonMapGPUSpatial(AxonMapSpatial):
    """JAX implementation of the biphasic axon-map spatial model, used for benchmarking.

    NOTE: ``_predict_spatial`` returns the elapsed wall-clock time of the JAX
    computation in milliseconds, not the percept itself.
    """

    def __init__(self, **params):
        super(BiphasicAxonMapGPUSpatial, self).__init__(**params)

    def _predict_spatial(self, earray, stim):
        """Evaluate the model on ``stim`` and return the JAX wall time in ms.

        Parameters
        ----------
        earray : ElectrodeArray
            Supplies the (x, y) position of every stimulated electrode.
        stim : Stimulus
            Each electrode's metadata must provide 'amp', 'freq' and 'phase_dur'.

        Returns
        -------
        float
            Milliseconds spent in the jitted call, including device execution.
        """
        assert isinstance(earray, ElectrodeArray)
        assert isinstance(stim, Stimulus)

        electrode_meta = stim.metadata['electrodes']

        def per_electrode(key):
            # One float32 value per stimulated electrode, in stim.electrodes order.
            return np.array([electrode_meta[str(e)]['metadata'][key] for e in stim.electrodes],
                            dtype="float32")

        amps = per_electrode('amp')
        freqs = per_electrode('freq')
        pdurs = per_electrode('phase_dur')
        x = np.array([earray[e].x for e in stim.electrodes], dtype=np.float32)
        y = np.array([earray[e].y for e in stim.electrodes], dtype=np.float32)

        start = time.perf_counter()
        percept = self._jitted_axon_map(amps, freqs, pdurs, x, y, self.axon_contrib,
                                        self.rho, self.axlambda, self.thresh_percept)
        # JAX dispatches asynchronously; wait for the result so the timing
        # covers the actual computation, not just the dispatch.
        percept.block_until_ready()
        return (time.perf_counter() - start) * 1000

    def _build(self):
        """Build the axon map, then place ``axon_contrib`` on the first JAX device."""
        super(BiphasicAxonMapGPUSpatial, self)._build()
        self.axon_contrib = jax.device_put(self.axon_contrib, jax.devices()[0])

    @staticmethod
    def predict_one_point(axon, brights, sizes, streaks, x, y, rho, axlambda):
        """Summed electrode contribution at every segment of one axon.

        ``axon`` is (n_segments, 3) with x, y and a per-segment weight
        (presumably axon sensitivity); the electrode arguments are
        (n_electrodes,). ``axlambda`` is unused. Returns (n_segments,).
        """
        d2_el = (axon[:, 0, None] - x) ** 2 + (axon[:, 1, None] - y) ** 2
        intensities = (brights * jnp.exp(-d2_el / (2. * rho ** 2 * sizes))
                       * (axon[:, 2, None] ** (1. / streaks)))
        return jnp.sum(intensities, axis=1)

    @staticmethod
    def gpu_biphasic_axon_map(amp, freq, pdur, x, y,  # per ACTIVE electrode
                              axon_segments, rho, axlambda, thresh_percept):
        """Thresholded axon-map response, shape (n_space,).

        ``axon_segments`` is (n_space, axon_length, 3). Responses not strictly
        greater than ``thresh_percept`` are set to 0.
        """
        min_size = 10 ** 2 / rho ** 2
        min_streak = 10 ** 2 / axlambda ** 2

        # Per-electrode F, G, H terms (brightness, size, streak).
        scaled_amps = amp / (0.8825 + 0.27 * pdur)
        brights = 1.84 * scaled_amps + 0.2 * freq + 3.0986
        sizes = jnp.maximum(1.081 * scaled_amps - 0.3533764, min_size)
        streaks = jnp.maximum(1.56 - 0.54 * pdur ** 0.21, min_streak)

        per_segment = vmap(BiphasicAxonMapGPUSpatial.predict_one_point,
                           in_axes=(0, None, None, None, None, None, None, None))(
            axon_segments, brights, sizes, streaks, x, y, rho, axlambda)
        I = jnp.max(per_segment, axis=1)
        return jnp.where(I > thresh_percept, I, 0.)

    # Compiled once for the class rather than wrapping a fresh bound method in
    # jit on every call, so JAX can reuse compiled code across calls. rho,
    # axlambda and thresh_percept are static (hashable floats).
    _jitted_axon_map = staticmethod(
        jit(gpu_biphasic_axon_map.__func__, static_argnums=(6, 7, 8)))
        



class BiphasicAxonMapGPU(Model):
    """Model wrapping :class:`BiphasicAxonMapGPUSpatial` with no temporal model.

    This is a benchmarking harness: ``predict_percept`` returns whatever the
    spatial model's ``_predict_spatial`` returns (elapsed JAX time in ms),
    not a Percept.
    """

    def __init__(self, **params):
        super(BiphasicAxonMapGPU, self).__init__(spatial=BiphasicAxonMapGPUSpatial(), temporal=None, **params)

    def _check_ready(self, implant):
        """Raise unless the model is built and ``implant`` is a ProsthesisSystem."""
        if not self.is_built:
            # NotBuiltError was never imported in this notebook; import it lazily.
            from pulse2percept.models import NotBuiltError
            raise NotBuiltError("You must call ``build`` first.")
        if not isinstance(implant, ProsthesisSystem):
            raise TypeError(("'implant' must be a ProsthesisSystem object, "
                             "not %s.") % type(implant))

    @staticmethod
    def _check_biphasic(stim):
        """Raise TypeError unless every electrode of ``stim`` is an undelayed BiphasicPulseTrain."""
        if isinstance(stim, BiphasicPulseTrain):
            return
        # A Stimulus can still qualify if each electrode carries a biphasic pulse train.
        for ele, params in stim.metadata['electrodes'].items():
            if params['type'] is not BiphasicPulseTrain or params['metadata']['delay_dur'] != 0:
                raise TypeError("All stimuli must be BiphasicPulseTrains with no delay dur (Failing electrode: %s)" % (ele))

    def predict_percept(self, implant, t_percept=None):
        """Run the GPU axon-map model on ``implant.stim``.

        Parameters
        ----------
        implant : ProsthesisSystem
            Implant whose ``stim`` is evaluated.
        t_percept : ignored
            Accepted for API compatibility with ``Model.predict_percept``.

        Returns
        -------
        float or None
            The spatial model's result (elapsed JAX time in ms), or None when
            the implant carries no stimulus.

        Raises
        ------
        NotBuiltError
            If ``build`` has not been called.
        TypeError
            If ``implant`` is not a ProsthesisSystem or a stimulus is not an
            undelayed BiphasicPulseTrain.
        """
        self._check_ready(implant)
        if implant.stim is None:
            return None
        self._check_biphasic(implant.stim)

        stim = Stimulus(implant.stim)  # make sure stimulus is in proper format
        # TODO: wrap the response in a Percept (reshaped to self.grid) once
        # _predict_spatial returns the response array instead of its timing.
        return self._predict_spatial(implant.earray, stim)

    def predict_percept_batched(self, implant, stim_list, t_percept=None):
        """Evaluate many stimuli in one vmapped, jitted call and print timings.

        All stimuli must drive the same number of electrodes so that their
        parameters stack into a single (n_stims, 5, n_electrodes) array.

        Returns
        -------
        array, shape (n_stims, n_space)
            Thresholded axon-map response of each stimulus.

        Raises
        ------
        ValueError
            If ``stim_list`` is empty.
        NotBuiltError, TypeError
            As for ``predict_percept``.
        """
        setup_start = time.perf_counter()
        if len(stim_list) == 0:
            raise ValueError("'stim_list' must contain at least one stimulus.")
        self._check_ready(implant)
        for stim in stim_list:
            self._check_biphasic(stim)

        batch = []
        for stim in stim_list:
            electrode_meta = stim.metadata['electrodes']
            amps = np.array([electrode_meta[str(e)]['metadata']['amp'] for e in stim.electrodes], dtype="float32")
            freqs = np.array([electrode_meta[str(e)]['metadata']['freq'] for e in stim.electrodes], dtype="float32")
            pdurs = np.array([electrode_meta[str(e)]['metadata']['phase_dur'] for e in stim.electrodes], dtype="float32")
            x = np.array([implant.earray[e].x for e in stim.electrodes], dtype=np.float32)
            y = np.array([implant.earray[e].y for e in stim.electrodes], dtype=np.float32)
            batch.append(np.vstack([amps, freqs, pdurs, x, y]))
        batch = np.array(batch)  # (n_stims, 5, n_electrodes)
        print("Setup: %.3fms" % ((time.perf_counter() - setup_start) * 1000))

        start = time.perf_counter()
        # gpu_biphasic_axon_map takes the five per-electrode arrays separately;
        # pass each row of the batch as its own vmapped argument.
        resps = self._batched_axon_map(batch[:, 0], batch[:, 1], batch[:, 2], batch[:, 3], batch[:, 4],
                                       self.axon_contrib, self.rho, self.axlambda,
                                       self.thresh_percept).block_until_ready()
        elapsed_ms = (time.perf_counter() - start) * 1000
        print("Jax time: %.3f ms, %.3f ms per sample" % (elapsed_ms, elapsed_ms / len(stim_list)))
        return resps

    # Compiled once for the class: vmapped over the stimulus axis of the five
    # per-electrode arrays; the axon map and the scalar parameters are shared,
    # and the scalars are static for jit.
    _batched_axon_map = staticmethod(jit(
        vmap(gpu_biphasic_axon_map, in_axes=(0, 0, 0, 0, 0, None, None, None, None)),
        static_argnums=(6, 7, 8)))


In [7]:
# Sanity check: compare the percept of a BiphasicAxonMapModel built with its
# default parameters against a previously computed ``percept``.
# FIXME(review): ``percept`` is not defined anywhere in the visible notebook, so
# this cell raises NameError on a fresh kernel (it presumably came from a
# deleted cell). Define the percept to compare against -- built with matching
# model parameters (e.g. axlambda, min_ax_sensitivity) -- before this cell.
# NOTE(review): In [34] shows predict_percept returning a number in this
# environment; confirm that m1.predict_percept returns a Percept with `.data`.
m1 = BiphasicAxonMapModel()
m1.build()
p1 = m1.predict_percept(implant)
# Largest absolute element-wise difference between the two percepts.
np.max(np.abs(p1.data - percept.data))

NameError: name 'percept' is not defined